diff --git a/backend/internal/dto/user_dto.go b/backend/internal/dto/user_dto.go index 401b15b8..f3b2b980 100644 --- a/backend/internal/dto/user_dto.go +++ b/backend/internal/dto/user_dto.go @@ -1,6 +1,9 @@ package dto import ( + "errors" + + "github.com/gin-gonic/gin/binding" "github.com/pocket-id/pocket-id/backend/internal/utils" ) @@ -29,6 +32,17 @@ type UserCreateDto struct { LdapID string `json:"-"` } +func (u UserCreateDto) Validate() error { + e, ok := binding.Validator.Engine().(interface { + Struct(s any) error + }) + if !ok { + return errors.New("validator does not implement the expected interface") + } + + return e.Struct(u) +} + type OneTimeAccessTokenCreateDto struct { UserID string `json:"userId"` TTL utils.JSONDuration `json:"ttl" binding:"ttl"` diff --git a/backend/internal/dto/user_dto_test.go b/backend/internal/dto/user_dto_test.go new file mode 100644 index 00000000..181e6e8e --- /dev/null +++ b/backend/internal/dto/user_dto_test.go @@ -0,0 +1,89 @@ +package dto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUserCreateDto_Validate(t *testing.T) { + testCases := []struct { + name string + input UserCreateDto + wantErr string + }{ + { + name: "valid input", + input: UserCreateDto{ + Username: "testuser", + Email: "test@example.com", + FirstName: "John", + LastName: "Doe", + }, + wantErr: "", + }, + { + name: "missing username", + input: UserCreateDto{ + Email: "test@example.com", + FirstName: "John", + LastName: "Doe", + }, + wantErr: "Field validation for 'Username' failed on the 'required' tag", + }, + { + name: "username contains invalid characters", + input: UserCreateDto{ + Username: "test/ser", + Email: "test@example.com", + FirstName: "John", + LastName: "Doe", + }, + wantErr: "Field validation for 'Username' failed on the 'username' tag", + }, + { + name: "invalid email", + input: UserCreateDto{ + Username: "testuser", + Email: "not-an-email", + FirstName: "John", + LastName: "Doe", + }, + wantErr: "Field validation for 'Email' failed on the 'email' tag", + }, + { + name: "first name too short", + input: UserCreateDto{ + Username: "testuser", + Email: "test@example.com", + FirstName: "", + LastName: "Doe", + }, + wantErr: "Field validation for 'FirstName' failed on the 'required' tag", + }, + { + name: "last name too long", + input: UserCreateDto{ + Username: "testuser", + Email: "test@example.com", + FirstName: "John", + LastName: "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", + }, + wantErr: "Field validation for 'LastName' failed on the 'max' tag", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.Validate() + + if tc.wantErr == "" { + require.NoError(t, err) + return + } + + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + }) + } +} diff --git a/backend/internal/dto/user_group_dto.go b/backend/internal/dto/user_group_dto.go index 57c2b22a..09b3d1dc 100644 --- a/backend/internal/dto/user_group_dto.go +++ b/backend/internal/dto/user_group_dto.go @@ -1,6 +1,9 @@ package dto import ( + "errors" + + "github.com/gin-gonic/gin/binding" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" ) @@ -39,6 +42,17 @@ type UserGroupCreateDto struct { LdapID string `json:"-"` } +func (g UserGroupCreateDto) Validate() error { + e, ok := binding.Validator.Engine().(interface { + Struct(s any) error + }) + if !ok { + return errors.New("validator does not implement the expected interface") + } + + return e.Struct(g) +} + type UserGroupUpdateUsersDto struct { UserIDs []string `json:"userIds" binding:"required"` } diff --git a/backend/internal/dto/validations.go b/backend/internal/dto/validations.go index 80c3ea56..7a497282 100644 --- a/backend/internal/dto/validations.go +++ b/backend/internal/dto/validations.go @@ -10,29 +10,29 @@ import ( "github.com/go-playground/validator/v10" ) +// [a-zA-Z0-9] : The username must start with an alphanumeric character +// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols +// [a-zA-Z0-9]$ : The username must end with an alphanumeric character +var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$") + +var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$") + func init() { v := binding.Validator.Engine().(*validator.Validate) - // [a-zA-Z0-9] : The username must start with an alphanumeric character - // [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols - // [a-zA-Z0-9]$ : The username must end with an alphanumeric character - var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$") - - var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$") - // Maximum allowed value for TTLs const maxTTL = 31 * 24 * time.Hour // Errors here are development-time ones err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool { - return validateUsernameRegex.MatchString(fl.Field().String()) + return ValidateUsername(fl.Field().String()) }) if err != nil { panic("Failed to register custom validation for username: " + err.Error()) } err = v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool { - return validateClientIDRegex.MatchString(fl.Field().String()) + return ValidateClientID(fl.Field().String()) }) if err != nil { panic("Failed to register custom validation for client_id: " + err.Error()) @@ -50,3 +50,13 @@ func init() { panic("Failed to register custom validation for ttl: " + err.Error()) } } + +// ValidateUsername validates username inputs +func ValidateUsername(username string) bool { + return validateUsernameRegex.MatchString(username) +} + +// ValidateClientID validates client ID inputs +func ValidateClientID(clientID string) bool { + return validateClientIDRegex.MatchString(clientID) +} diff --git a/backend/internal/dto/validations_test.go b/backend/internal/dto/validations_test.go new file mode 100644 index 00000000..f6449068 --- /dev/null +++ b/backend/internal/dto/validations_test.go @@ -0,0 +1,58 @@ +package dto + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateUsername(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"valid simple", "user123", true}, + {"valid with dot", "user.name", true}, + {"valid with underscore", "user_name", true}, + {"valid with hyphen", "user-name", true}, + {"valid with at", "user@name", true}, + {"starts with symbol", ".username", false}, + {"ends with non-alphanumeric", "username-", false}, + {"contains space", "user name", false}, + {"empty", "", false}, + {"only special chars", "-._@", false}, + {"valid long", "a1234567890_b.c-d@e", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, ValidateUsername(tt.input)) + }) + } +} + +func TestValidateClientID(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"valid simple", "client123", true}, + {"valid with dot", "client.id", true}, + {"valid with underscore", "client_id", true}, + {"valid with hyphen", "client-id", true}, + {"valid with all", "client.id-123_abc", true}, + {"contains space", "client id", false}, + {"contains at", "client@id", false}, + {"empty", "", false}, + {"only special chars", "-._", true}, + {"invalid char", "client!id", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, ValidateClientID(tt.input)) + }) + } +} diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go index 9567503c..fd564d26 100644 --- a/backend/internal/service/ldap_service.go +++ b/backend/internal/service/ldap_service.go @@ -179,10 +179,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap. } } + username = norm.NFC.String(username) + var databaseUser model.User err = tx. WithContext(ctx). - Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)). + Where("username = ? AND ldap_id IS NOT NULL", username). First(&databaseUser). Error if errors.Is(err, gorm.ErrRecordNotFound) { @@ -202,6 +204,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap. } dto.Normalize(syncGroup) + err = syncGroup.Validate() + if err != nil { + slog.WarnContext(ctx, "LDAP user group object is not valid", slog.Any("error", err)) + continue + } + if databaseGroup.ID == "" { newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx) if err != nil { @@ -347,6 +355,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C } dto.Normalize(newUser) + err = newUser.Validate() + if err != nil { + slog.WarnContext(ctx, "LDAP user object is not valid", slog.Any("error", err)) + continue + } + if databaseUser.ID == "" { _, err = s.userService.createUserInternal(ctx, newUser, true, tx) if errors.Is(err, &common.AlreadyInUseError{}) {