mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-17 01:11:38 +03:00
fix: ensure users imported from LDAP have fields validated (#923)
This commit is contained in:
committed by
GitHub
parent
92edc26a30
commit
42155238b7
@@ -1,6 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin/binding"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,6 +32,17 @@ type UserCreateDto struct {
|
|||||||
LdapID string `json:"-"`
|
LdapID string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u UserCreateDto) Validate() error {
|
||||||
|
e, ok := binding.Validator.Engine().(interface {
|
||||||
|
Struct(s any) error
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return errors.New("validator does not implement the expected interface")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.Struct(u)
|
||||||
|
}
|
||||||
|
|
||||||
type OneTimeAccessTokenCreateDto struct {
|
type OneTimeAccessTokenCreateDto struct {
|
||||||
UserID string `json:"userId"`
|
UserID string `json:"userId"`
|
||||||
TTL utils.JSONDuration `json:"ttl" binding:"ttl"`
|
TTL utils.JSONDuration `json:"ttl" binding:"ttl"`
|
||||||
|
|||||||
89
backend/internal/dto/user_dto_test.go
Normal file
89
backend/internal/dto/user_dto_test.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserCreateDto_Validate(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
input UserCreateDto
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid input",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing username",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Email: "test@example.com",
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'Username' failed on the 'required' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username contains invalid characters",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "test/ser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'Username' failed on the 'username' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid email",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "not-an-email",
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'Email' failed on the 'email' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "first name too short",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
FirstName: "",
|
||||||
|
LastName: "Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'FirstName' failed on the 'required' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "last name too long",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: "test@example.com",
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'LastName' failed on the 'max' tag",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := tc.input.Validate()
|
||||||
|
|
||||||
|
if tc.wantErr == "" {
|
||||||
|
require.NoError(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, tc.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin/binding"
|
||||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,6 +42,17 @@ type UserGroupCreateDto struct {
|
|||||||
LdapID string `json:"-"`
|
LdapID string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g UserGroupCreateDto) Validate() error {
|
||||||
|
e, ok := binding.Validator.Engine().(interface {
|
||||||
|
Struct(s any) error
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return errors.New("validator does not implement the expected interface")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.Struct(g)
|
||||||
|
}
|
||||||
|
|
||||||
type UserGroupUpdateUsersDto struct {
|
type UserGroupUpdateUsersDto struct {
|
||||||
UserIDs []string `json:"userIds" binding:"required"`
|
UserIDs []string `json:"userIds" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,29 +10,29 @@ import (
|
|||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||||
|
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||||
|
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||||
|
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
||||||
|
|
||||||
|
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
v := binding.Validator.Engine().(*validator.Validate)
|
v := binding.Validator.Engine().(*validator.Validate)
|
||||||
|
|
||||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
|
||||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
|
||||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
|
||||||
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
|
||||||
|
|
||||||
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
|
||||||
|
|
||||||
// Maximum allowed value for TTLs
|
// Maximum allowed value for TTLs
|
||||||
const maxTTL = 31 * 24 * time.Hour
|
const maxTTL = 31 * 24 * time.Hour
|
||||||
|
|
||||||
// Errors here are development-time ones
|
// Errors here are development-time ones
|
||||||
err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||||
return validateUsernameRegex.MatchString(fl.Field().String())
|
return ValidateUsername(fl.Field().String())
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Failed to register custom validation for username: " + err.Error())
|
panic("Failed to register custom validation for username: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
err = v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
||||||
return validateClientIDRegex.MatchString(fl.Field().String())
|
return ValidateClientID(fl.Field().String())
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Failed to register custom validation for client_id: " + err.Error())
|
panic("Failed to register custom validation for client_id: " + err.Error())
|
||||||
@@ -50,3 +50,13 @@ func init() {
|
|||||||
panic("Failed to register custom validation for ttl: " + err.Error())
|
panic("Failed to register custom validation for ttl: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateUsername validates username inputs
|
||||||
|
func ValidateUsername(username string) bool {
|
||||||
|
return validateUsernameRegex.MatchString(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateClientID validates client ID inputs
|
||||||
|
func ValidateClientID(clientID string) bool {
|
||||||
|
return validateClientIDRegex.MatchString(clientID)
|
||||||
|
}
|
||||||
|
|||||||
58
backend/internal/dto/validations_test.go
Normal file
58
backend/internal/dto/validations_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateUsername(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"valid simple", "user123", true},
|
||||||
|
{"valid with dot", "user.name", true},
|
||||||
|
{"valid with underscore", "user_name", true},
|
||||||
|
{"valid with hyphen", "user-name", true},
|
||||||
|
{"valid with at", "user@name", true},
|
||||||
|
{"starts with symbol", ".username", false},
|
||||||
|
{"ends with non-alphanumeric", "username-", false},
|
||||||
|
{"contains space", "user name", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
{"only special chars", "-._@", false},
|
||||||
|
{"valid long", "a1234567890_b.c-d@e", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, ValidateUsername(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateClientID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"valid simple", "client123", true},
|
||||||
|
{"valid with dot", "client.id", true},
|
||||||
|
{"valid with underscore", "client_id", true},
|
||||||
|
{"valid with hyphen", "client-id", true},
|
||||||
|
{"valid with all", "client.id-123_abc", true},
|
||||||
|
{"contains space", "client id", false},
|
||||||
|
{"contains at", "client@id", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
{"only special chars", "-._", true},
|
||||||
|
{"invalid char", "client!id", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, ValidateClientID(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -179,10 +179,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
username = norm.NFC.String(username)
|
||||||
|
|
||||||
var databaseUser model.User
|
var databaseUser model.User
|
||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)).
|
Where("username = ? AND ldap_id IS NOT NULL", username).
|
||||||
First(&databaseUser).
|
First(&databaseUser).
|
||||||
Error
|
Error
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
@@ -202,6 +204,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
}
|
}
|
||||||
dto.Normalize(syncGroup)
|
dto.Normalize(syncGroup)
|
||||||
|
|
||||||
|
err = syncGroup.Validate()
|
||||||
|
if err != nil {
|
||||||
|
slog.WarnContext(ctx, "LDAP user group object is not valid", slog.Any("error", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if databaseGroup.ID == "" {
|
if databaseGroup.ID == "" {
|
||||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -347,6 +355,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
}
|
}
|
||||||
dto.Normalize(newUser)
|
dto.Normalize(newUser)
|
||||||
|
|
||||||
|
err = newUser.Validate()
|
||||||
|
if err != nil {
|
||||||
|
slog.WarnContext(ctx, "LDAP user object is not valid", slog.Any("error", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if databaseUser.ID == "" {
|
if databaseUser.ID == "" {
|
||||||
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||||
|
|||||||
Reference in New Issue
Block a user