mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-18 01:11:26 +03:00
fix: improve LDAP error handling (#425)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
72061ba427
commit
796bc7ed34
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
@@ -21,6 +22,12 @@ func (e *AlreadyInUseError) Error() string {
|
||||
}
|
||||
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
|
||||
|
||||
func (e *AlreadyInUseError) Is(target error) bool {
|
||||
// Ignore the field property when checking if an error is of the type AlreadyInUseError
|
||||
x := &AlreadyInUseError{}
|
||||
return errors.As(target, &x)
|
||||
}
|
||||
|
||||
type SetupAlreadyCompletedError struct{}
|
||||
|
||||
func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }
|
||||
|
||||
@@ -160,7 +160,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input, false)
|
||||
group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"gorm.io/gorm"
|
||||
@@ -28,7 +29,12 @@ type LdapService struct {
|
||||
}
|
||||
|
||||
func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService {
|
||||
return &LdapService{db: db, appConfigService: appConfigService, userService: userService, groupService: groupService}
|
||||
return &LdapService{
|
||||
db: db,
|
||||
appConfigService: appConfigService,
|
||||
userService: userService,
|
||||
groupService: groupService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
@@ -39,19 +45,15 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
}
|
||||
|
||||
// Setup LDAP connection
|
||||
ldapURL := dbConfig.LdapUrl.Value
|
||||
skipTLSVerify := dbConfig.LdapSkipCertVerify.IsTrue()
|
||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
|
||||
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
|
||||
client, err := ldap.DialURL(dbConfig.LdapUrl.Value, ldap.DialWithTLSConfig(&tls.Config{
|
||||
InsecureSkipVerify: dbConfig.LdapSkipCertVerify.IsTrue(), //nolint:gosec
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
||||
}
|
||||
|
||||
// Bind as service account
|
||||
bindDn := dbConfig.LdapBindDn.Value
|
||||
bindPassword := dbConfig.LdapBindPassword.Value
|
||||
err = client.Bind(bindDn, bindPassword)
|
||||
err = client.Bind(dbConfig.LdapBindDn.Value, dbConfig.LdapBindPassword.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to bind to LDAP: %w", err)
|
||||
}
|
||||
@@ -65,12 +67,19 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
err := s.SyncUsers(ctx, tx)
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create LDAP client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
err = s.SyncUsers(ctx, tx, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users: %w", err)
|
||||
}
|
||||
|
||||
err = s.SyncGroups(ctx, tx)
|
||||
err = s.SyncGroups(ctx, tx, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync groups: %w", err)
|
||||
}
|
||||
@@ -85,16 +94,9 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
|
||||
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create LDAP client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
searchAttrs := []string{
|
||||
dbConfig.LdapAttributeGroupName.Value,
|
||||
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
|
||||
@@ -115,11 +117,9 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
|
||||
}
|
||||
|
||||
// Create a mapping for groups that exist
|
||||
ldapGroupIDs := make(map[string]bool)
|
||||
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
var membersUserId []string
|
||||
|
||||
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
|
||||
|
||||
// Skip groups without a valid LDAP ID
|
||||
@@ -128,29 +128,40 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
|
||||
continue
|
||||
}
|
||||
|
||||
ldapGroupIDs[ldapId] = true
|
||||
ldapGroupIDs[ldapId] = struct{}{}
|
||||
|
||||
// Try to find the group in the database
|
||||
var databaseGroup model.UserGroup
|
||||
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("ldap_id = ?", ldapId).
|
||||
First(&databaseGroup).
|
||||
Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
|
||||
}
|
||||
|
||||
// Get group members and add to the correct Group
|
||||
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
|
||||
membersUserId := make([]string, 0, len(groupMembers))
|
||||
for _, member := range groupMembers {
|
||||
// Normal output of this would be CN=username,ou=people,dc=example,dc=com
|
||||
// Splitting at the "=" and "," then just grabbing the username for that string
|
||||
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
|
||||
ldapId := getDNProperty("uid", member)
|
||||
if ldapId == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var databaseUser model.User
|
||||
err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
|
||||
if err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("username = ? AND ldap_id IS NOT NULL", ldapId).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// The user collides with a non-LDAP user, so we skip it
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to query for existing user '%s': %w", ldapId, err)
|
||||
}
|
||||
|
||||
membersUserId = append(membersUserId, databaseUser.ID)
|
||||
@@ -165,26 +176,22 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,38 +204,33 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to fetch groups from database: %v", err)
|
||||
return fmt.Errorf("failed to fetch groups from database: %w", err)
|
||||
}
|
||||
|
||||
// Delete groups that no longer exist in LDAP
|
||||
for _, group := range ldapGroupsInDb {
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; !exists {
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to delete group %s with: %v", group.Name, err)
|
||||
} else {
|
||||
log.Printf("Deleted group %s", group.Name)
|
||||
}
|
||||
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
|
||||
}
|
||||
|
||||
log.Printf("Deleted group '%s'", group.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create LDAP client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
searchAttrs := []string{
|
||||
"memberOf",
|
||||
"sn",
|
||||
@@ -253,11 +255,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
|
||||
|
||||
result, err := client.Search(searchReq)
|
||||
if err != nil {
|
||||
fmt.Println(fmt.Errorf("failed to query LDAP: %w", err))
|
||||
return fmt.Errorf("failed to query LDAP: %w", err)
|
||||
}
|
||||
|
||||
// Create a mapping for users that exist
|
||||
ldapUserIDs := make(map[string]bool)
|
||||
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
|
||||
@@ -268,17 +270,26 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
|
||||
continue
|
||||
}
|
||||
|
||||
ldapUserIDs[ldapId] = true
|
||||
ldapUserIDs[ldapId] = struct{}{}
|
||||
|
||||
// Get the user from the database
|
||||
var databaseUser model.User
|
||||
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser)
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("ldap_id = ?", ldapId).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
||||
}
|
||||
|
||||
// Check if user is admin by checking if they are in the admin group
|
||||
isAdmin := false
|
||||
for _, group := range value.GetAttributeValues("memberOf") {
|
||||
if strings.Contains(group, dbConfig.LdapAttributeAdminGroup.Value) {
|
||||
if getDNProperty("cn", group) == dbConfig.LdapAttributeAdminGroup.Value {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,20 +303,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
|
||||
}
|
||||
|
||||
if databaseUser.ID == "" {
|
||||
_, err = s.userService.createUserInternal(ctx, newUser, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing user %s: %v", newUser.Username, err)
|
||||
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
log.Printf("Skipping creating LDAP user '%s': %v", newUser.Username, err)
|
||||
continue
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
} else {
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing user %s: %v", newUser.Username, err)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
log.Printf("Skipping updating LDAP user '%s': %v", newUser.Username, err)
|
||||
continue
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save profile picture
|
||||
if pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value); pictureString != "" {
|
||||
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
|
||||
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
|
||||
if pictureString != "" {
|
||||
err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
|
||||
if err != nil {
|
||||
// This is not a fatal error
|
||||
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
|
||||
}
|
||||
}
|
||||
@@ -319,18 +339,21 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to fetch users from database: %v", err)
|
||||
return fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
}
|
||||
|
||||
// Delete users that no longer exist in LDAP
|
||||
for _, user := range ldapUsersInDb {
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
|
||||
if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil {
|
||||
log.Printf("Failed to delete user %s with: %v", user.Username, err)
|
||||
} else {
|
||||
log.Printf("Deleted user %s", user.Username)
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user '%s': %w", user.Username, err)
|
||||
}
|
||||
|
||||
log.Printf("Deleted user '%s'", user.Username)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -367,9 +390,28 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
|
||||
}
|
||||
|
||||
// Update the profile picture
|
||||
if err := s.userService.UpdateProfilePicture(userId, reader); err != nil {
|
||||
err = s.userService.UpdateProfilePicture(userId, reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update profile picture: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNProperty returns the value of a property from a LDAP identifier
|
||||
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
|
||||
func getDNProperty(property string, str string) string {
|
||||
// Example format is "CN=username,ou=people,dc=example,dc=com"
|
||||
// First we split at the comma
|
||||
property = strings.ToLower(property)
|
||||
l := len(property) + 1
|
||||
for _, v := range strings.Split(str, ",") {
|
||||
v = strings.TrimSpace(v)
|
||||
if len(v) > l && strings.ToLower(v)[0:l] == property+"=" {
|
||||
return v[l:]
|
||||
}
|
||||
}
|
||||
|
||||
// CN not found, return an empty string
|
||||
return ""
|
||||
}
|
||||
|
||||
73
backend/internal/service/ldap_service_test.go
Normal file
73
backend/internal/service/ldap_service_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDNProperty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
property string
|
||||
dn string
|
||||
expectedResult string
|
||||
}{
|
||||
{
|
||||
name: "simple case",
|
||||
property: "cn",
|
||||
dn: "cn=username,ou=people,dc=example,dc=com",
|
||||
expectedResult: "username",
|
||||
},
|
||||
{
|
||||
name: "property not found",
|
||||
property: "uid",
|
||||
dn: "cn=username,ou=people,dc=example,dc=com",
|
||||
expectedResult: "",
|
||||
},
|
||||
{
|
||||
name: "mixed case property",
|
||||
property: "CN",
|
||||
dn: "cn=username,ou=people,dc=example,dc=com",
|
||||
expectedResult: "username",
|
||||
},
|
||||
{
|
||||
name: "mixed case DN",
|
||||
property: "cn",
|
||||
dn: "CN=username,OU=people,DC=example,DC=com",
|
||||
expectedResult: "username",
|
||||
},
|
||||
{
|
||||
name: "spaces in DN",
|
||||
property: "cn",
|
||||
dn: "cn=username, ou=people, dc=example, dc=com",
|
||||
expectedResult: "username",
|
||||
},
|
||||
{
|
||||
name: "value with special characters",
|
||||
property: "cn",
|
||||
dn: "cn=user.name+123,ou=people,dc=example,dc=com",
|
||||
expectedResult: "user.name+123",
|
||||
},
|
||||
{
|
||||
name: "empty DN",
|
||||
property: "cn",
|
||||
dn: "",
|
||||
expectedResult: "",
|
||||
},
|
||||
{
|
||||
name: "empty property",
|
||||
property: "",
|
||||
dn: "cn=username,ou=people,dc=example,dc=com",
|
||||
expectedResult: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getDNProperty(tt.property, tt.dn)
|
||||
if result != tt.expectedResult {
|
||||
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
|
||||
tt.property, tt.dn, result, tt.expectedResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -122,13 +122,13 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
|
||||
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx)
|
||||
group, err = s.updateInternal(ctx, id, input, false, tx)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
@@ -141,14 +141,14 @@ func (s *UserGroupService) Update(ctx context.Context, id string, input dto.User
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) {
|
||||
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, isLdapSync bool, tx *gorm.DB) (group model.UserGroup, err error) {
|
||||
group, err = s.getInternal(ctx, id, tx)
|
||||
if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
// Disallow updating the group if it is an LDAP group and LDAP is enabled
|
||||
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
|
||||
if !isLdapSync && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
|
||||
return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
|
||||
}
|
||||
|
||||
@@ -160,10 +160,9 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
|
||||
Preload("Users").
|
||||
Save(&group).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
|
||||
}
|
||||
} else if err != nil {
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
return group, nil
|
||||
|
||||
@@ -184,7 +184,7 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
|
||||
First(&user).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to load user to delete: %w", err)
|
||||
}
|
||||
|
||||
// Disallow deleting the user if it is an LDAP user and LDAP is enabled
|
||||
@@ -199,7 +199,12 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.WithContext(ctx).Delete(&user).Error
|
||||
err = tx.WithContext(ctx).Delete(&user).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
|
||||
@@ -208,7 +213,7 @@ func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
user, err := s.createUserInternal(ctx, input, tx)
|
||||
user, err := s.createUserInternal(ctx, input, false, tx)
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
@@ -221,7 +226,7 @@ func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, tx *gorm.DB) (model.User, error) {
|
||||
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
||||
user := model.User{
|
||||
FirstName: input.FirstName,
|
||||
LastName: input.LastName,
|
||||
@@ -236,10 +241,15 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
||||
|
||||
err := tx.WithContext(ctx).Create(&user).Error
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
// Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
|
||||
if !isLdapSync {
|
||||
tx.Rollback()
|
||||
|
||||
// If we are here, the transaction is already aborted due to an error, so we pass s.db
|
||||
err = s.checkDuplicatedFields(ctx, user, s.db)
|
||||
} else {
|
||||
err = s.checkDuplicatedFields(ctx, user, tx)
|
||||
}
|
||||
|
||||
return model.User{}, err
|
||||
} else if err != nil {
|
||||
return model.User{}, err
|
||||
@@ -266,7 +276,7 @@ func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool, tx *gorm.DB) (model.User, error) {
|
||||
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
||||
var user model.User
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
@@ -278,7 +288,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
}
|
||||
|
||||
// Disallow updating the user if it is an LDAP group and LDAP is enabled
|
||||
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
|
||||
if !isLdapSync && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
|
||||
return model.User{}, &common.LdapUserUpdateError{}
|
||||
}
|
||||
|
||||
@@ -296,10 +306,15 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
Save(&user).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
// Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
|
||||
if !isLdapSync {
|
||||
tx.Rollback()
|
||||
|
||||
// If we are here, the transaction is already aborted due to an error, so we pass s.db
|
||||
err = s.checkDuplicatedFields(ctx, user, s.db)
|
||||
} else {
|
||||
err = s.checkDuplicatedFields(ctx, user, tx)
|
||||
}
|
||||
|
||||
return user, err
|
||||
} else if err != nil {
|
||||
return user, err
|
||||
|
||||
Reference in New Issue
Block a user