mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-12 16:22:58 +03:00
feat: support wildcard callback URLs
This commit is contained in:
@@ -16,7 +16,7 @@ type OidcClientDto struct {
|
|||||||
|
|
||||||
type OidcClientCreateDto struct {
|
type OidcClientCreateDto struct {
|
||||||
Name string `json:"name" binding:"required,max=50"`
|
Name string `json:"name" binding:"required,max=50"`
|
||||||
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
|
CallbackURLs []string `json:"callbackURLs" binding:"required"`
|
||||||
IsPublic bool `json:"isPublic"`
|
IsPublic bool `json:"isPublic"`
|
||||||
PkceEnabled bool `json:"pkceEnabled"`
|
PkceEnabled bool `json:"pkceEnabled"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,9 @@ import (
|
|||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"log"
|
"log"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
var validateUrlList validator.Func = func(fl validator.FieldLevel) bool {
|
|
||||||
urls := fl.Field().Interface().([]string)
|
|
||||||
for _, u := range urls {
|
|
||||||
_, err := url.ParseRequestURI(u)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
||||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||||
@@ -36,11 +24,6 @@ var validateClaimKey validator.Func = func(fl validator.FieldLevel) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
|
||||||
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
|
|
||||||
log.Fatalf("Failed to register custom validation: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||||
if err := v.RegisterValidation("username", validateUsername); err != nil {
|
if err := v.RegisterValidation("username", validateUsername); err != nil {
|
||||||
log.Fatalf("Failed to register custom validation: %v", err)
|
log.Fatalf("Failed to register custom validation: %v", err)
|
||||||
|
|||||||
@@ -83,8 +83,6 @@ func handleValidationError(validationErrors validator.ValidationErrors) string {
|
|||||||
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
|
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
|
||||||
case "max":
|
case "max":
|
||||||
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
|
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
|
||||||
case "urlList":
|
|
||||||
errorMessage = fmt.Sprintf("%s must be a list of valid URLs", fieldName)
|
|
||||||
default:
|
default:
|
||||||
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
|
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -432,8 +432,16 @@ func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL s
|
|||||||
if inputCallbackURL == "" {
|
if inputCallbackURL == "" {
|
||||||
return client.CallbackURLs[0], nil
|
return client.CallbackURLs[0], nil
|
||||||
}
|
}
|
||||||
if slices.Contains(client.CallbackURLs, inputCallbackURL) {
|
|
||||||
return inputCallbackURL, nil
|
for _, callbackPattern := range client.CallbackURLs {
|
||||||
|
regexPattern := strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$"
|
||||||
|
matched, err := regexp.MatchString(regexPattern, inputCallbackURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if matched {
|
||||||
|
return inputCallbackURL, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", &common.OidcInvalidCallbackURLError{}
|
return "", &common.OidcInvalidCallbackURLError{}
|
||||||
|
|||||||
@@ -36,7 +36,7 @@
|
|||||||
|
|
||||||
const formSchema = z.object({
|
const formSchema = z.object({
|
||||||
name: z.string().min(2).max(50),
|
name: z.string().min(2).max(50),
|
||||||
callbackURLs: z.array(z.string().url()).nonempty(),
|
callbackURLs: z.array(z.string()).nonempty(),
|
||||||
isPublic: z.boolean(),
|
isPublic: z.boolean(),
|
||||||
pkceEnabled: z.boolean()
|
pkceEnabled: z.boolean()
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user