mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-13 16:53:03 +03:00
fix: use transactions when operations involve multiple database queries (#392)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
c810fec8c4
commit
ec626ee797
@@ -38,7 +38,9 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
// Setup LDAP connection
|
||||
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
|
||||
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
|
||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
|
||||
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
|
||||
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
|
||||
}
|
||||
@@ -53,22 +55,31 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) SyncAll() error {
|
||||
err := s.SyncUsers()
|
||||
func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
// Start a transaction
|
||||
tx := s.db.Begin()
|
||||
|
||||
err := s.SyncUsers(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users: %w", err)
|
||||
}
|
||||
|
||||
err = s.SyncGroups()
|
||||
err = s.SyncGroups(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync groups: %w", err)
|
||||
}
|
||||
|
||||
// Commit the changes
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes to database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncGroups() error {
|
||||
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
@@ -112,7 +123,7 @@ func (s *LdapService) SyncGroups() error {
|
||||
|
||||
// Try to find the group in the database
|
||||
var databaseGroup model.UserGroup
|
||||
s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup)
|
||||
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
|
||||
|
||||
// Get group members and add to the correct Group
|
||||
groupMembers := value.GetAttributeValues(groupMemberOfAttribute)
|
||||
@@ -122,7 +133,7 @@ func (s *LdapService) SyncGroups() error {
|
||||
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
|
||||
|
||||
var databaseUser model.User
|
||||
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
|
||||
err := tx.WithContext(ctx).Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// The user collides with a non-LDAP user, so we skip it
|
||||
@@ -143,39 +154,51 @@ func (s *LdapService) SyncGroups() error {
|
||||
}
|
||||
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.Create(syncGroup)
|
||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
} else {
|
||||
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
}
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
|
||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
}
|
||||
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
|
||||
return err
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing group %s: %v", syncGroup.Name, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Get all LDAP groups from the database
|
||||
var ldapGroupsInDb []model.UserGroup
|
||||
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
|
||||
fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err))
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to fetch groups from database: %v", err)
|
||||
}
|
||||
|
||||
// Delete groups that no longer exist in LDAP
|
||||
for _, group := range ldapGroupsInDb {
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; !exists {
|
||||
if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to delete group %s with: %v", group.Name, err)
|
||||
} else {
|
||||
log.Printf("Deleted group %s", group.Name)
|
||||
@@ -187,7 +210,7 @@ func (s *LdapService) SyncGroups() error {
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncUsers() error {
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
@@ -241,7 +264,7 @@ func (s *LdapService) SyncUsers() error {
|
||||
|
||||
// Get the user from the database
|
||||
var databaseUser model.User
|
||||
s.db.Where("ldap_id = ?", ldapId).First(&databaseUser)
|
||||
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseUser)
|
||||
|
||||
// Check if user is admin by checking if they are in the admin group
|
||||
isAdmin := false
|
||||
@@ -261,68 +284,75 @@ func (s *LdapService) SyncUsers() error {
|
||||
}
|
||||
|
||||
if databaseUser.ID == "" {
|
||||
_, err = s.userService.CreateUser(newUser)
|
||||
_, err = s.userService.createUserInternal(ctx, newUser, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing user %s: %s", newUser.Username, err)
|
||||
log.Printf("Error syncing user %s: %v", newUser.Username, err)
|
||||
}
|
||||
} else {
|
||||
_, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true)
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
||||
if err != nil {
|
||||
log.Printf("Error syncing user %s: %s", newUser.Username, err)
|
||||
log.Printf("Error syncing user %s: %v", newUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save profile picture
|
||||
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" {
|
||||
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil {
|
||||
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err)
|
||||
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
|
||||
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get all LDAP users from the database
|
||||
var ldapUsersInDb []model.User
|
||||
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
|
||||
fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err))
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to fetch users from database: %v", err)
|
||||
}
|
||||
|
||||
// Delete users that no longer exist in LDAP
|
||||
for _, user := range ldapUsersInDb {
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
|
||||
if err := s.userService.DeleteUser(user.ID, true); err != nil {
|
||||
if err := s.userService.deleteUserInternal(ctx, user.ID, true, tx); err != nil {
|
||||
log.Printf("Failed to delete user %s with: %v", user.Username, err)
|
||||
} else {
|
||||
log.Printf("Deleted user %s", user.Username)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error {
|
||||
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
||||
var reader io.Reader
|
||||
|
||||
if _, err := url.ParseRequestURI(pictureString); err == nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_, err := url.ParseRequestURI(pictureString)
|
||||
if err == nil {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
var res *http.Response
|
||||
res, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download profile picture: %w", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
reader = response.Body
|
||||
defer res.Body.Close()
|
||||
|
||||
reader = res.Body
|
||||
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
|
||||
// If the photo is a base64 encoded string, decode it
|
||||
reader = bytes.NewReader(decodedPhoto)
|
||||
|
||||
} else {
|
||||
// If the photo is a string, we assume that it's a binary string
|
||||
reader = bytes.NewReader([]byte(pictureString))
|
||||
|
||||
Reference in New Issue
Block a user