fix: use transactions when operations involve multiple database queries (#392)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-06 06:04:08 -07:00
committed by GitHub
parent c810fec8c4
commit ec626ee797
33 changed files with 1401 additions and 501 deletions

View File

@@ -1,13 +1,15 @@
package service
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type UserGroupService struct {
@@ -19,8 +21,11 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG
return &UserGroupService{db: db, appConfigService: appConfigService}
}
func (s *UserGroupService) List(name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{})
func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.
WithContext(ctx).
Preload("CustomClaims").
Model(&model.UserGroup{})
if name != "" {
query = query.Where("name LIKE ?", "%"+name+"%")
@@ -42,26 +47,59 @@ func (s *UserGroupService) List(name string, sortedPaginationRequest utils.Sorte
return groups, response, err
}
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) {
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error
func (s *UserGroupService) Get(ctx context.Context, id string) (group model.UserGroup, err error) {
return s.getInternal(ctx, id, s.db)
}
func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.DB) (group model.UserGroup, err error) {
err = tx.
WithContext(ctx).
Where("id = ?", id).
Preload("CustomClaims").
Preload("Users").
First(&group).
Error
return group, err
}
func (s *UserGroupService) Delete(id string) error {
func (s *UserGroupService) Delete(ctx context.Context, id string) error {
tx := s.db.Begin()
var group model.UserGroup
if err := s.db.Where("id = ?", id).First(&group).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id = ?", id).
First(&group).
Error
if err != nil {
return err
}
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
err = tx.Rollback().Error
if err != nil {
return err
}
return &common.LdapUserGroupUpdateError{}
}
return s.db.Delete(&group).Error
err = tx.
WithContext(ctx).
Delete(&group).
Error
if err != nil {
return err
}
return tx.Commit().Error
}
func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
return s.createInternal(ctx, input, s.db)
}
func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGroupCreateDto, tx *gorm.DB) (group model.UserGroup, err error) {
group = model.UserGroup{
FriendlyName: input.FriendlyName,
Name: input.Name,
@@ -71,7 +109,12 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
group.LdapID = &input.LdapID
}
if err := s.db.Preload("Users").Create(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("Users").
Create(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
}
@@ -80,8 +123,26 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
return group, nil
}
func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
group, err = s.Get(id)
func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) {
tx := s.db.Begin()
group, err = s.updateInternal(ctx, id, input, allowLdapUpdate, tx)
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, allowLdapUpdate bool, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
@@ -94,7 +155,12 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
group.Name = input.Name
group.FriendlyName = input.FriendlyName
if err := s.db.Preload("Users").Save(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("Users").
Save(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
}
@@ -103,8 +169,26 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
return group, nil
}
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) {
group, err = s.Get(id)
func (s *UserGroupService) UpdateUsers(ctx context.Context, id string, userIds []string) (group model.UserGroup, err error) {
tx := s.db.Begin()
group, err = s.updateUsersInternal(ctx, id, userIds, tx)
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
tx.Rollback()
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, userIds []string, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil {
return model.UserGroup{}, err
}
@@ -112,28 +196,59 @@ func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model
// Fetch the users based on the userIds
var users []model.User
if len(userIds) > 0 {
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil {
err := tx.
WithContext(ctx).
Where("id IN (?)", userIds).
Find(&users).
Error
if err != nil {
return model.UserGroup{}, err
}
}
// Replace the current users with the new set of users
if err := s.db.Model(&group).Association("Users").Replace(users); err != nil {
err = tx.
WithContext(ctx).
Model(&group).
Association("Users").
Replace(users)
if err != nil {
return model.UserGroup{}, err
}
// Save the updated group
if err := s.db.Save(&group).Error; err != nil {
err = tx.
WithContext(ctx).
Save(&group).
Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) {
func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (int64, error) {
// We only perform select queries here, so we can rollback in all cases
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var group model.UserGroup
if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil {
err := tx.
WithContext(ctx).
Preload("Users").
Where("id = ?", id).
First(&group).
Error
if err != nil {
return 0, err
}
return s.db.Model(&group).Association("Users").Count(), nil
count := tx.
WithContext(ctx).
Model(&group).
Association("Users").
Count()
return count, nil
}