feat: support wildcard callback URLs

This commit is contained in:
Elias Schneider
2025-01-20 11:19:23 +01:00
parent 3f02d08109
commit 8a1db0cb4a
5 changed files with 13 additions and 24 deletions

View File

@@ -16,7 +16,7 @@ type OidcClientDto struct {
type OidcClientCreateDto struct { type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"` Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"` CallbackURLs []string `json:"callbackURLs" binding:"required"`
IsPublic bool `json:"isPublic"` IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"` PkceEnabled bool `json:"pkceEnabled"`
} }

View File

@@ -4,21 +4,9 @@ import (
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"log" "log"
"net/url"
"regexp" "regexp"
) )
var validateUrlList validator.Func = func(fl validator.FieldLevel) bool {
urls := fl.Field().Interface().([]string)
for _, u := range urls {
_, err := url.ParseRequestURI(u)
if err != nil {
return false
}
}
return true
}
var validateUsername validator.Func = func(fl validator.FieldLevel) bool { var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
// [a-zA-Z0-9] : The username must start with an alphanumeric character // [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols // [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
@@ -36,11 +24,6 @@ var validateClaimKey validator.Func = func(fl validator.FieldLevel) bool {
} }
func init() { func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
if v, ok := binding.Validator.Engine().(*validator.Validate); ok { if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("username", validateUsername); err != nil { if err := v.RegisterValidation("username", validateUsername); err != nil {
log.Fatalf("Failed to register custom validation: %v", err) log.Fatalf("Failed to register custom validation: %v", err)

View File

@@ -83,8 +83,6 @@ func handleValidationError(validationErrors validator.ValidationErrors) string {
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param()) errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
case "max": case "max":
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param()) errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
case "urlList":
errorMessage = fmt.Sprintf("%s must be a list of valid URLs", fieldName)
default: default:
errorMessage = fmt.Sprintf("%s is invalid", fieldName) errorMessage = fmt.Sprintf("%s is invalid", fieldName)
} }

View File

@@ -14,7 +14,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"mime/multipart" "mime/multipart"
"os" "os"
"slices" "regexp"
"strings" "strings"
"time" "time"
) )
@@ -432,8 +432,16 @@ func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL s
if inputCallbackURL == "" { if inputCallbackURL == "" {
return client.CallbackURLs[0], nil return client.CallbackURLs[0], nil
} }
if slices.Contains(client.CallbackURLs, inputCallbackURL) {
return inputCallbackURL, nil for _, callbackPattern := range client.CallbackURLs {
regexPattern := strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$"
matched, err := regexp.MatchString(regexPattern, inputCallbackURL)
if err != nil {
return "", err
}
if matched {
return inputCallbackURL, nil
}
} }
return "", &common.OidcInvalidCallbackURLError{} return "", &common.OidcInvalidCallbackURLError{}

View File

@@ -36,7 +36,7 @@
const formSchema = z.object({ const formSchema = z.object({
name: z.string().min(2).max(50), name: z.string().min(2).max(50),
callbackURLs: z.array(z.string().url()).nonempty(), callbackURLs: z.array(z.string()).nonempty(),
isPublic: z.boolean(), isPublic: z.boolean(),
pkceEnabled: z.boolean() pkceEnabled: z.boolean()
}); });