fix: ensure users imported from LDAP have fields validated (#923)

This commit is contained in:
Alessandro (Ale) Segala
2025-09-09 00:31:49 -07:00
committed by GitHub
parent 92edc26a30
commit 42155238b7
6 changed files with 209 additions and 10 deletions

View File

@@ -1,6 +1,9 @@
package dto
import (
"errors"
"github.com/gin-gonic/gin/binding"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
@@ -29,6 +32,17 @@ type UserCreateDto struct {
LdapID string `json:"-"`
}
func (u UserCreateDto) Validate() error {
e, ok := binding.Validator.Engine().(interface {
Struct(s any) error
})
if !ok {
return errors.New("validator does not implement the expected interface")
}
return e.Struct(u)
}
type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId"`
TTL utils.JSONDuration `json:"ttl" binding:"ttl"`

View File

@@ -0,0 +1,89 @@
package dto
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUserCreateDto_Validate(t *testing.T) {
testCases := []struct {
name string
input UserCreateDto
wantErr string
}{
{
name: "valid input",
input: UserCreateDto{
Username: "testuser",
Email: "test@example.com",
FirstName: "John",
LastName: "Doe",
},
wantErr: "",
},
{
name: "missing username",
input: UserCreateDto{
Email: "test@example.com",
FirstName: "John",
LastName: "Doe",
},
wantErr: "Field validation for 'Username' failed on the 'required' tag",
},
{
name: "username contains invalid characters",
input: UserCreateDto{
Username: "test/ser",
Email: "test@example.com",
FirstName: "John",
LastName: "Doe",
},
wantErr: "Field validation for 'Username' failed on the 'username' tag",
},
{
name: "invalid email",
input: UserCreateDto{
Username: "testuser",
Email: "not-an-email",
FirstName: "John",
LastName: "Doe",
},
wantErr: "Field validation for 'Email' failed on the 'email' tag",
},
{
name: "first name too short",
input: UserCreateDto{
Username: "testuser",
Email: "test@example.com",
FirstName: "",
LastName: "Doe",
},
wantErr: "Field validation for 'FirstName' failed on the 'required' tag",
},
{
name: "last name too long",
input: UserCreateDto{
Username: "testuser",
Email: "test@example.com",
FirstName: "John",
LastName: "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
},
wantErr: "Field validation for 'LastName' failed on the 'max' tag",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.input.Validate()
if tc.wantErr == "" {
require.NoError(t, err)
return
}
require.Error(t, err)
require.ErrorContains(t, err, tc.wantErr)
})
}
}

View File

@@ -1,6 +1,9 @@
package dto
import (
"errors"
"github.com/gin-gonic/gin/binding"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
@@ -39,6 +42,17 @@ type UserGroupCreateDto struct {
LdapID string `json:"-"`
}
func (g UserGroupCreateDto) Validate() error {
e, ok := binding.Validator.Engine().(interface {
Struct(s any) error
})
if !ok {
return errors.New("validator does not implement the expected interface")
}
return e.Struct(g)
}
type UserGroupUpdateUsersDto struct {
UserIDs []string `json:"userIds" binding:"required"`
}

View File

@@ -10,29 +10,29 @@ import (
"github.com/go-playground/validator/v10"
)
// [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
func init() {
v := binding.Validator.Engine().(*validator.Validate)
// [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
// Maximum allowed value for TTLs
const maxTTL = 31 * 24 * time.Hour
// Errors here are development-time ones
err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
return validateUsernameRegex.MatchString(fl.Field().String())
return ValidateUsername(fl.Field().String())
})
if err != nil {
panic("Failed to register custom validation for username: " + err.Error())
}
err = v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
return validateClientIDRegex.MatchString(fl.Field().String())
return ValidateClientID(fl.Field().String())
})
if err != nil {
panic("Failed to register custom validation for client_id: " + err.Error())
@@ -50,3 +50,13 @@ func init() {
panic("Failed to register custom validation for ttl: " + err.Error())
}
}
// ValidateUsername validates username inputs
func ValidateUsername(username string) bool {
return validateUsernameRegex.MatchString(username)
}
// ValidateClientID validates client ID inputs
func ValidateClientID(clientID string) bool {
return validateClientIDRegex.MatchString(clientID)
}

View File

@@ -0,0 +1,58 @@
package dto
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateUsername(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"valid simple", "user123", true},
{"valid with dot", "user.name", true},
{"valid with underscore", "user_name", true},
{"valid with hyphen", "user-name", true},
{"valid with at", "user@name", true},
{"starts with symbol", ".username", false},
{"ends with non-alphanumeric", "username-", false},
{"contains space", "user name", false},
{"empty", "", false},
{"only special chars", "-._@", false},
{"valid long", "a1234567890_b.c-d@e", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, ValidateUsername(tt.input))
})
}
}
func TestValidateClientID(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"valid simple", "client123", true},
{"valid with dot", "client.id", true},
{"valid with underscore", "client_id", true},
{"valid with hyphen", "client-id", true},
{"valid with all", "client.id-123_abc", true},
{"contains space", "client id", false},
{"contains at", "client@id", false},
{"empty", "", false},
{"only special chars", "-._", true},
{"invalid char", "client!id", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, ValidateClientID(tt.input))
})
}
}

View File

@@ -179,10 +179,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
}
}
username = norm.NFC.String(username)
var databaseUser model.User
err = tx.
WithContext(ctx).
Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)).
Where("username = ? AND ldap_id IS NOT NULL", username).
First(&databaseUser).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -202,6 +204,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
}
dto.Normalize(syncGroup)
err = syncGroup.Validate()
if err != nil {
slog.WarnContext(ctx, "LDAP user group object is not valid", slog.Any("error", err))
continue
}
if databaseGroup.ID == "" {
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil {
@@ -347,6 +355,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
}
dto.Normalize(newUser)
err = newUser.Validate()
if err != nil {
slog.WarnContext(ctx, "LDAP user object is not valid", slog.Any("error", err))
continue
}
if databaseUser.ID == "" {
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {