fix: improve LDAP error handling (#425)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-12 15:38:19 -07:00
committed by GitHub
parent 72061ba427
commit 796bc7ed34
6 changed files with 239 additions and 103 deletions

View File

@@ -1,6 +1,7 @@
package common
import (
"errors"
"fmt"
"net/http"
)
@@ -21,6 +22,12 @@ func (e *AlreadyInUseError) Error() string {
}
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
func (e *AlreadyInUseError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AlreadyInUseError
x := &AlreadyInUseError{}
return errors.As(target, &x)
}
type SetupAlreadyCompletedError struct{}
func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }

View File

@@ -160,7 +160,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
return
}
group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input, false)
group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input)
if err != nil {
_ = c.Error(err)
return

View File

@@ -15,6 +15,7 @@ import (
"time"
"github.com/go-ldap/ldap/v3"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm"
@@ -28,7 +29,12 @@ type LdapService struct {
}
func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService {
return &LdapService{db: db, appConfigService: appConfigService, userService: userService, groupService: groupService}
return &LdapService{
db: db,
appConfigService: appConfigService,
userService: userService,
groupService: groupService,
}
}
func (s *LdapService) createClient() (*ldap.Conn, error) {
@@ -39,19 +45,15 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
}
// Setup LDAP connection
ldapURL := dbConfig.LdapUrl.Value
skipTLSVerify := dbConfig.LdapSkipCertVerify.IsTrue()
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
client, err := ldap.DialURL(dbConfig.LdapUrl.Value, ldap.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: dbConfig.LdapSkipCertVerify.IsTrue(), //nolint:gosec
}))
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
}
// Bind as service account
bindDn := dbConfig.LdapBindDn.Value
bindPassword := dbConfig.LdapBindPassword.Value
err = client.Bind(bindDn, bindPassword)
err = client.Bind(dbConfig.LdapBindDn.Value, dbConfig.LdapBindPassword.Value)
if err != nil {
return nil, fmt.Errorf("failed to bind to LDAP: %w", err)
}
@@ -65,12 +67,19 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
tx.Rollback()
}()
err := s.SyncUsers(ctx, tx)
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
return fmt.Errorf("failed to create LDAP client: %w", err)
}
defer client.Close()
err = s.SyncUsers(ctx, tx, client)
if err != nil {
return fmt.Errorf("failed to sync users: %w", err)
}
err = s.SyncGroups(ctx, tx)
err = s.SyncGroups(ctx, tx, client)
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
@@ -85,16 +94,9 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
}
//nolint:gocognit
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
dbConfig := s.appConfigService.GetDbConfig()
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
return fmt.Errorf("failed to create LDAP client: %w", err)
}
defer client.Close()
searchAttrs := []string{
dbConfig.LdapAttributeGroupName.Value,
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
@@ -115,11 +117,9 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
}
// Create a mapping for groups that exist
ldapGroupIDs := make(map[string]bool)
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries {
var membersUserId []string
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
// Skip groups without a valid LDAP ID
@@ -128,29 +128,40 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
continue
}
ldapGroupIDs[ldapId] = true
ldapGroupIDs[ldapId] = struct{}{}
// Try to find the group in the database
var databaseGroup model.UserGroup
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseGroup).
Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
}
// Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
membersUserId := make([]string, 0, len(groupMembers))
for _, member := range groupMembers {
// Normal output of this would be CN=username,ou=people,dc=example,dc=com
// Splitting at the "=" and "," then just grabbing the username for that string
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
ldapId := getDNProperty("uid", member)
if ldapId == "" {
continue
}
var databaseUser model.User
err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
if err != nil {
err = tx.
WithContext(ctx).
Where("username = ? AND ldap_id IS NOT NULL", ldapId).
First(&databaseUser).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// The user collides with a non-LDAP user, so we skip it
continue
} else {
return err
}
} else if err != nil {
return fmt.Errorf("failed to query for existing user '%s': %w", ldapId, err)
}
membersUserId = append(membersUserId, databaseUser.ID)
@@ -165,26 +176,22 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
if databaseGroup.ID == "" {
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
}
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
}
} else {
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
}
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
}
}
}
@@ -197,38 +204,33 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch groups from database: %v", err)
return fmt.Errorf("failed to fetch groups from database: %w", err)
}
// Delete groups that no longer exist in LDAP
for _, group := range ldapGroupsInDb {
if _, exists := ldapGroupIDs[*group.LdapID]; !exists {
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
continue
}
err = tx.
WithContext(ctx).
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
Error
if err != nil {
log.Printf("Failed to delete group %s with: %v", group.Name, err)
} else {
log.Printf("Deleted group %s", group.Name)
}
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
}
log.Printf("Deleted group '%s'", group.Name)
}
return nil
}
//nolint:gocognit
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
dbConfig := s.appConfigService.GetDbConfig()
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
return fmt.Errorf("failed to create LDAP client: %w", err)
}
defer client.Close()
searchAttrs := []string{
"memberOf",
"sn",
@@ -253,11 +255,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
result, err := client.Search(searchReq)
if err != nil {
fmt.Println(fmt.Errorf("failed to query LDAP: %w", err))
return fmt.Errorf("failed to query LDAP: %w", err)
}
// Create a mapping for users that exist
ldapUserIDs := make(map[string]bool)
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries {
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
@@ -268,17 +270,26 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
continue
}
ldapUserIDs[ldapId] = true
ldapUserIDs[ldapId] = struct{}{}
// Get the user from the database
var databaseUser model.User
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser)
err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseUser).
Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
}
// Check if user is admin by checking if they are in the admin group
isAdmin := false
for _, group := range value.GetAttributeValues("memberOf") {
if strings.Contains(group, dbConfig.LdapAttributeAdminGroup.Value) {
if getDNProperty("cn", group) == dbConfig.LdapAttributeAdminGroup.Value {
isAdmin = true
break
}
}
@@ -292,20 +303,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
}
if databaseUser.ID == "" {
_, err = s.userService.createUserInternal(ctx, newUser, tx)
if err != nil {
log.Printf("Error syncing user %s: %v", newUser.Username, err)
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Skipping creating LDAP user '%s': %v", newUser.Username, err)
continue
} else if err != nil {
return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
}
} else {
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if err != nil {
log.Printf("Error syncing user %s: %v", newUser.Username, err)
if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Skipping updating LDAP user '%s': %v", newUser.Username, err)
continue
} else if err != nil {
return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
}
}
// Save profile picture
if pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value); pictureString != "" {
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
if pictureString != "" {
err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
if err != nil {
// This is not a fatal error
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
}
}
@@ -319,18 +339,21 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch users from database: %v", err)
return fmt.Errorf("failed to fetch users from database: %w", err)
}
// Delete users that no longer exist in LDAP
for _, user := range ldapUsersInDb {
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil {
log.Printf("Failed to delete user %s with: %v", user.Username, err)
} else {
log.Printf("Deleted user %s", user.Username)
if _, exists := ldapUserIDs[*user.LdapID]; exists {
continue
}
err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
if err != nil {
return fmt.Errorf("failed to delete user '%s': %w", user.Username, err)
}
log.Printf("Deleted user '%s'", user.Username)
}
return nil
@@ -367,9 +390,28 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
}
// Update the profile picture
if err := s.userService.UpdateProfilePicture(userId, reader); err != nil {
err = s.userService.UpdateProfilePicture(userId, reader)
if err != nil {
return fmt.Errorf("failed to update profile picture: %w", err)
}
return nil
}
// getDNProperty returns the value of a property from a LDAP identifier
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
func getDNProperty(property string, str string) string {
// Example format is "CN=username,ou=people,dc=example,dc=com"
// First we split at the comma
property = strings.ToLower(property)
l := len(property) + 1
for _, v := range strings.Split(str, ",") {
v = strings.TrimSpace(v)
if len(v) > l && strings.ToLower(v)[0:l] == property+"=" {
return v[l:]
}
}
// CN not found, return an empty string
return ""
}

View File

@@ -0,0 +1,73 @@
package service
import (
"testing"
)
func TestGetDNProperty(t *testing.T) {
tests := []struct {
name string
property string
dn string
expectedResult string
}{
{
name: "simple case",
property: "cn",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "username",
},
{
name: "property not found",
property: "uid",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "",
},
{
name: "mixed case property",
property: "CN",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "username",
},
{
name: "mixed case DN",
property: "cn",
dn: "CN=username,OU=people,DC=example,DC=com",
expectedResult: "username",
},
{
name: "spaces in DN",
property: "cn",
dn: "cn=username, ou=people, dc=example, dc=com",
expectedResult: "username",
},
{
name: "value with special characters",
property: "cn",
dn: "cn=user.name+123,ou=people,dc=example,dc=com",
expectedResult: "user.name+123",
},
{
name: "empty DN",
property: "cn",
dn: "",
expectedResult: "",
},
{
name: "empty property",
property: "",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getDNProperty(tt.property, tt.dn)
if result != tt.expectedResult {
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
tt.property, tt.dn, result, tt.expectedResult)
}
})
}
}

View File

@@ -122,13 +122,13 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro
return group, nil
}
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx)
group, err = s.updateInternal(ctx, id, input, false, tx)
if err != nil {
return model.UserGroup{}, err
}
@@ -141,14 +141,14 @@ func (s *UserGroupService) Update(ctx context.Context, id string, input dto.User
return group, nil
}
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) {
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, isLdapSync bool, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
// Disallow updating the group if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
if !isLdapSync && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
}
@@ -160,10 +160,9 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
Preload("Users").
Save(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
}
} else if err != nil {
return model.UserGroup{}, err
}
return group, nil

View File

@@ -184,7 +184,7 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
First(&user).
Error
if err != nil {
return err
return fmt.Errorf("failed to load user to delete: %w", err)
}
// Disallow deleting the user if it is an LDAP user and LDAP is enabled
@@ -199,7 +199,12 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
return err
}
return tx.WithContext(ctx).Delete(&user).Error
err = tx.WithContext(ctx).Delete(&user).Error
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}
func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
@@ -208,7 +213,7 @@ func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (
tx.Rollback()
}()
user, err := s.createUserInternal(ctx, input, tx)
user, err := s.createUserInternal(ctx, input, false, tx)
if err != nil {
return model.User{}, err
}
@@ -221,7 +226,7 @@ func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (
return user, nil
}
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, tx *gorm.DB) (model.User, error) {
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) {
user := model.User{
FirstName: input.FirstName,
LastName: input.LastName,
@@ -236,10 +241,15 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
err := tx.WithContext(ctx).Create(&user).Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
// Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
if !isLdapSync {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
} else {
err = s.checkDuplicatedFields(ctx, user, tx)
}
return model.User{}, err
} else if err != nil {
return model.User{}, err
@@ -266,7 +276,7 @@ func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser
return user, nil
}
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool, tx *gorm.DB) (model.User, error) {
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) {
var user model.User
err := tx.
WithContext(ctx).
@@ -278,7 +288,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
}
// Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
if !isLdapSync && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{}
}
@@ -296,10 +306,15 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
Save(&user).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
// Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
if !isLdapSync {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
} else {
err = s.checkDuplicatedFields(ctx, user, tx)
}
return user, err
} else if err != nil {
return user, err