mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 09:13:20 +03:00
fix: ensure users imported from LDAP have fields validated (#923)
This commit is contained in:
committed by
GitHub
parent
92edc26a30
commit
42155238b7
@@ -1,6 +1,9 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
@@ -29,6 +32,17 @@ type UserCreateDto struct {
|
||||
LdapID string `json:"-"`
|
||||
}
|
||||
|
||||
func (u UserCreateDto) Validate() error {
|
||||
e, ok := binding.Validator.Engine().(interface {
|
||||
Struct(s any) error
|
||||
})
|
||||
if !ok {
|
||||
return errors.New("validator does not implement the expected interface")
|
||||
}
|
||||
|
||||
return e.Struct(u)
|
||||
}
|
||||
|
||||
type OneTimeAccessTokenCreateDto struct {
|
||||
UserID string `json:"userId"`
|
||||
TTL utils.JSONDuration `json:"ttl" binding:"ttl"`
|
||||
|
||||
89
backend/internal/dto/user_dto_test.go
Normal file
89
backend/internal/dto/user_dto_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserCreateDto_Validate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input UserCreateDto
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid input",
|
||||
input: UserCreateDto{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "missing username",
|
||||
input: UserCreateDto{
|
||||
Email: "test@example.com",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
},
|
||||
wantErr: "Field validation for 'Username' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
name: "username contains invalid characters",
|
||||
input: UserCreateDto{
|
||||
Username: "test/ser",
|
||||
Email: "test@example.com",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
},
|
||||
wantErr: "Field validation for 'Username' failed on the 'username' tag",
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
input: UserCreateDto{
|
||||
Username: "testuser",
|
||||
Email: "not-an-email",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
},
|
||||
wantErr: "Field validation for 'Email' failed on the 'email' tag",
|
||||
},
|
||||
{
|
||||
name: "first name too short",
|
||||
input: UserCreateDto{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
FirstName: "",
|
||||
LastName: "Doe",
|
||||
},
|
||||
wantErr: "Field validation for 'FirstName' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
name: "last name too long",
|
||||
input: UserCreateDto{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
FirstName: "John",
|
||||
LastName: "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
|
||||
},
|
||||
wantErr: "Field validation for 'LastName' failed on the 'max' tag",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.input.Validate()
|
||||
|
||||
if tc.wantErr == "" {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tc.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
@@ -39,6 +42,17 @@ type UserGroupCreateDto struct {
|
||||
LdapID string `json:"-"`
|
||||
}
|
||||
|
||||
func (g UserGroupCreateDto) Validate() error {
|
||||
e, ok := binding.Validator.Engine().(interface {
|
||||
Struct(s any) error
|
||||
})
|
||||
if !ok {
|
||||
return errors.New("validator does not implement the expected interface")
|
||||
}
|
||||
|
||||
return e.Struct(g)
|
||||
}
|
||||
|
||||
type UserGroupUpdateUsersDto struct {
|
||||
UserIDs []string `json:"userIds" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -10,29 +10,29 @@ import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
||||
|
||||
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
||||
|
||||
func init() {
|
||||
v := binding.Validator.Engine().(*validator.Validate)
|
||||
|
||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
||||
|
||||
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
||||
|
||||
// Maximum allowed value for TTLs
|
||||
const maxTTL = 31 * 24 * time.Hour
|
||||
|
||||
// Errors here are development-time ones
|
||||
err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||
return validateUsernameRegex.MatchString(fl.Field().String())
|
||||
return ValidateUsername(fl.Field().String())
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to register custom validation for username: " + err.Error())
|
||||
}
|
||||
|
||||
err = v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
||||
return validateClientIDRegex.MatchString(fl.Field().String())
|
||||
return ValidateClientID(fl.Field().String())
|
||||
})
|
||||
if err != nil {
|
||||
panic("Failed to register custom validation for client_id: " + err.Error())
|
||||
@@ -50,3 +50,13 @@ func init() {
|
||||
panic("Failed to register custom validation for ttl: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUsername validates username inputs
|
||||
func ValidateUsername(username string) bool {
|
||||
return validateUsernameRegex.MatchString(username)
|
||||
}
|
||||
|
||||
// ValidateClientID validates client ID inputs
|
||||
func ValidateClientID(clientID string) bool {
|
||||
return validateClientIDRegex.MatchString(clientID)
|
||||
}
|
||||
|
||||
58
backend/internal/dto/validations_test.go
Normal file
58
backend/internal/dto/validations_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"valid simple", "user123", true},
|
||||
{"valid with dot", "user.name", true},
|
||||
{"valid with underscore", "user_name", true},
|
||||
{"valid with hyphen", "user-name", true},
|
||||
{"valid with at", "user@name", true},
|
||||
{"starts with symbol", ".username", false},
|
||||
{"ends with non-alphanumeric", "username-", false},
|
||||
{"contains space", "user name", false},
|
||||
{"empty", "", false},
|
||||
{"only special chars", "-._@", false},
|
||||
{"valid long", "a1234567890_b.c-d@e", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, ValidateUsername(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateClientID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"valid simple", "client123", true},
|
||||
{"valid with dot", "client.id", true},
|
||||
{"valid with underscore", "client_id", true},
|
||||
{"valid with hyphen", "client-id", true},
|
||||
{"valid with all", "client.id-123_abc", true},
|
||||
{"contains space", "client id", false},
|
||||
{"contains at", "client@id", false},
|
||||
{"empty", "", false},
|
||||
{"only special chars", "-._", true},
|
||||
{"invalid char", "client!id", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, ValidateClientID(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -179,10 +179,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
}
|
||||
}
|
||||
|
||||
username = norm.NFC.String(username)
|
||||
|
||||
var databaseUser model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)).
|
||||
Where("username = ? AND ldap_id IS NOT NULL", username).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -202,6 +204,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
}
|
||||
dto.Normalize(syncGroup)
|
||||
|
||||
err = syncGroup.Validate()
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "LDAP user group object is not valid", slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||
if err != nil {
|
||||
@@ -347,6 +355,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
}
|
||||
dto.Normalize(newUser)
|
||||
|
||||
err = newUser.Validate()
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "LDAP user object is not valid", slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
|
||||
if databaseUser.ID == "" {
|
||||
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
|
||||
Reference in New Issue
Block a user