fix: use transactions when operations involve multiple database queries (#392)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-06 06:04:08 -07:00
committed by GitHub
parent c810fec8c4
commit ec626ee797
33 changed files with 1401 additions and 501 deletions

View File

@@ -38,7 +38,9 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
// Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
}))
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
}
@@ -53,22 +55,31 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
return client, nil
}
func (s *LdapService) SyncAll() error {
err := s.SyncUsers()
func (s *LdapService) SyncAll(ctx context.Context) error {
// Start a transaction
tx := s.db.Begin()
err := s.SyncUsers(ctx, tx)
if err != nil {
return fmt.Errorf("failed to sync users: %w", err)
}
err = s.SyncGroups()
err = s.SyncGroups(ctx, tx)
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
// Commit the changes
err = tx.Commit().Error
if err != nil {
return fmt.Errorf("failed to commit changes to database: %w", err)
}
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups() error {
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -112,7 +123,7 @@ func (s *LdapService) SyncGroups() error {
// Try to find the group in the database
var databaseGroup model.UserGroup
s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup)
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
// Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(groupMemberOfAttribute)
@@ -122,7 +133,7 @@ func (s *LdapService) SyncGroups() error {
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
var databaseUser model.User
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// The user collides with a non-LDAP user, so we skip it
@@ -143,39 +154,51 @@ func (s *LdapService) SyncGroups() error {
}
if databaseGroup.ID == "" {
newGroup, err := s.groupService.Create(syncGroup)
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
} else {
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
return err
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
if err != nil {
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
continue
}
}
}
// Get all LDAP groups from the database
var ldapGroupsInDb []model.UserGroup
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err))
err = tx.
WithContext(ctx).
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch groups from database: %v", err)
}
// Delete groups that no longer exist in LDAP
for _, group := range ldapGroupsInDb {
if _, exists := ldapGroupIDs[*group.LdapID]; !exists {
if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil {
err = tx.
WithContext(ctx).
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
Error
if err != nil {
log.Printf("Failed to delete group %s with: %v", group.Name, err)
} else {
log.Printf("Deleted group %s", group.Name)
@@ -187,7 +210,7 @@ func (s *LdapService) SyncGroups() error {
}
//nolint:gocognit
func (s *LdapService) SyncUsers() error {
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -241,7 +264,7 @@ func (s *LdapService) SyncUsers() error {
// Get the user from the database
var databaseUser model.User
s.db.Where("ldap_id = ?", ldapId).First(&databaseUser)
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser)
// Check if user is admin by checking if they are in the admin group
isAdmin := false
@@ -261,68 +284,75 @@ func (s *LdapService) SyncUsers() error {
}
if databaseUser.ID == "" {
_, err = s.userService.CreateUser(newUser)
_, err = s.userService.createUserInternal(ctx, newUser, tx)
if err != nil {
log.Printf("Error syncing user %s: %s", newUser.Username, err)
log.Printf("Error syncing user %s: %v", newUser.Username, err)
}
} else {
_, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true)
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if err != nil {
log.Printf("Error syncing user %s: %s", newUser.Username, err)
log.Printf("Error syncing user %s: %v", newUser.Username, err)
}
}
// Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" {
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err)
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
}
}
}
// Get all LDAP users from the database
var ldapUsersInDb []model.User
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err))
err = tx.
WithContext(ctx).
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
log.Printf("Failed to fetch users from database: %v", err)
}
// Delete users that no longer exist in LDAP
for _, user := range ldapUsersInDb {
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
if err := s.userService.DeleteUser(user.ID, true); err != nil {
if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil {
log.Printf("Failed to delete user %s with: %v", user.Username, err)
} else {
log.Printf("Deleted user %s", user.Username)
}
}
}
return nil
}
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error {
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_, err := url.ParseRequestURI(pictureString)
if err == nil {
ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
var req *http.Request
req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
response, err := http.DefaultClient.Do(req)
var res *http.Response
res, err = http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err)
}
defer response.Body.Close()
reader = response.Body
defer res.Body.Close()
reader = res.Body
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
// If the photo is a base64 encoded string, decode it
reader = bytes.NewReader(decodedPhoto)
} else {
// If the photo is a string, we assume that it's a binary string
reader = bytes.NewReader([]byte(pictureString))