mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-11 15:53:00 +03:00
860 lines
24 KiB
Go
860 lines
24 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log/slog"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
|
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
|
|
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
|
|
)
|
|
|
|
type UserService struct {
|
|
db *gorm.DB
|
|
jwtService *JwtService
|
|
auditLogService *AuditLogService
|
|
emailService *EmailService
|
|
appConfigService *AppConfigService
|
|
customClaimService *CustomClaimService
|
|
appImagesService *AppImagesService
|
|
fileStorage storage.FileStorage
|
|
}
|
|
|
|
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService, customClaimService *CustomClaimService, appImagesService *AppImagesService, fileStorage storage.FileStorage) *UserService {
|
|
return &UserService{
|
|
db: db,
|
|
jwtService: jwtService,
|
|
auditLogService: auditLogService,
|
|
emailService: emailService,
|
|
appConfigService: appConfigService,
|
|
customClaimService: customClaimService,
|
|
appImagesService: appImagesService,
|
|
fileStorage: fileStorage,
|
|
}
|
|
}
|
|
|
|
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, listRequestOptions utils.ListRequestOptions) ([]model.User, utils.PaginationResponse, error) {
|
|
var users []model.User
|
|
query := s.db.WithContext(ctx).
|
|
Model(&model.User{}).
|
|
Preload("UserGroups").
|
|
Preload("CustomClaims")
|
|
|
|
if searchTerm != "" {
|
|
searchPattern := "%" + searchTerm + "%"
|
|
query = query.Where(
|
|
"email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
|
|
searchPattern, searchPattern, searchPattern, searchPattern)
|
|
}
|
|
|
|
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &users)
|
|
|
|
return users, pagination, err
|
|
}
|
|
|
|
func (s *UserService) GetUser(ctx context.Context, userID string) (model.User, error) {
|
|
return s.getUserInternal(ctx, userID, s.db)
|
|
}
|
|
|
|
func (s *UserService) getUserInternal(ctx context.Context, userID string, tx *gorm.DB) (model.User, error) {
|
|
var user model.User
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Preload("UserGroups").
|
|
Preload("CustomClaims").
|
|
Where("id = ?", userID).
|
|
First(&user).
|
|
Error
|
|
return user, err
|
|
}
|
|
|
|
func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.ReadCloser, int64, error) {
|
|
// Validate the user ID to prevent directory traversal
|
|
if err := uuid.Validate(userID); err != nil {
|
|
return nil, 0, &common.InvalidUUIDError{}
|
|
}
|
|
|
|
user, err := s.GetUser(ctx, userID)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
|
|
|
// Try custom profile picture
|
|
file, size, err := s.fileStorage.Open(ctx, profilePicturePath)
|
|
if err == nil {
|
|
return file, size, nil
|
|
} else if !errors.Is(err, fs.ErrNotExist) {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Try default global profile picture
|
|
if s.appImagesService.IsDefaultProfilePictureSet() {
|
|
reader, size, _, err := s.appImagesService.GetImage(ctx, "default-profile-picture")
|
|
if err == nil {
|
|
return reader, size, nil
|
|
}
|
|
if !errors.Is(err, &common.ImageNotFoundError{}) {
|
|
return nil, 0, err
|
|
}
|
|
}
|
|
|
|
// Try cached default for initials
|
|
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
|
|
file, size, err = s.fileStorage.Open(ctx, defaultPicturePath)
|
|
if err == nil {
|
|
return file, size, nil
|
|
} else if !errors.Is(err, fs.ErrNotExist) {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Create and return generated default with initials
|
|
defaultPicture, err := profilepicture.CreateDefaultProfilePicture(user.Initials())
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Save the default picture for future use (in a goroutine to avoid blocking)
|
|
defaultPictureBytes := defaultPicture.Bytes()
|
|
//nolint:contextcheck
|
|
go func() {
|
|
// Use bytes.NewReader because we need an io.ReadSeeker
|
|
rErr := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes))
|
|
if rErr != nil {
|
|
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", rErr))
|
|
}
|
|
}()
|
|
|
|
return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(len(defaultPictureBytes)), nil
|
|
}
|
|
|
|
func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model.UserGroup, error) {
|
|
var user model.User
|
|
err := s.db.
|
|
WithContext(ctx).
|
|
Preload("UserGroups").
|
|
Where("id = ?", userID).
|
|
First(&user).
|
|
Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return user.UserGroups, nil
|
|
}
|
|
|
|
func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, file io.ReadSeeker) error {
|
|
// Validate the user ID to prevent directory traversal
|
|
err := uuid.Validate(userID)
|
|
if err != nil {
|
|
return &common.InvalidUUIDError{}
|
|
}
|
|
|
|
// Convert the image to a smaller square image
|
|
profilePicture, err := profilepicture.CreateProfilePicture(file)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
|
err = s.fileStorage.Save(ctx, profilePicturePath, profilePicture)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
|
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
|
return s.deleteUserInternal(ctx, tx, userID, allowLdapDelete)
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete user '%s': %w", userID, err)
|
|
}
|
|
|
|
// Storage operations must be executed outside of a transaction
|
|
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
|
err = s.fileStorage.Delete(ctx, profilePicturePath)
|
|
if err != nil && !storage.IsNotExist(err) {
|
|
return fmt.Errorf("failed to delete profile picture for user '%s': %w", userID, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userID string, allowLdapDelete bool) error {
|
|
var user model.User
|
|
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where("id = ?", userID).
|
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
First(&user).
|
|
Error
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load user to delete: %w", err)
|
|
}
|
|
|
|
// Disallow deleting the user if it is an LDAP user, LDAP is enabled, and the user is not disabled
|
|
if !allowLdapDelete && !user.Disabled && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
|
|
return &common.LdapUserUpdateError{}
|
|
}
|
|
|
|
err = tx.WithContext(ctx).Delete(&user).Error
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete user: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
user, err := s.createUserInternal(ctx, input, false, tx)
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
|
if s.appConfigService.GetDbConfig().RequireUserEmail.IsTrue() && input.Email == nil {
|
|
return model.User{}, &common.UserEmailNotSetError{}
|
|
}
|
|
|
|
user := model.User{
|
|
FirstName: input.FirstName,
|
|
LastName: input.LastName,
|
|
DisplayName: input.DisplayName,
|
|
Email: input.Email,
|
|
Username: input.Username,
|
|
IsAdmin: input.IsAdmin,
|
|
Locale: input.Locale,
|
|
Disabled: input.Disabled,
|
|
}
|
|
if input.LdapID != "" {
|
|
user.LdapID = &input.LdapID
|
|
}
|
|
|
|
err := tx.WithContext(ctx).Create(&user).Error
|
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
|
// Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
|
|
if !isLdapSync {
|
|
tx.Rollback()
|
|
// If we are here, the transaction is already aborted due to an error, so we pass s.db
|
|
err = s.checkDuplicatedFields(ctx, user, s.db)
|
|
} else {
|
|
err = s.checkDuplicatedFields(ctx, user, tx)
|
|
}
|
|
|
|
return model.User{}, err
|
|
} else if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
// Apply default groups and claims for new non-LDAP users
|
|
if !isLdapSync {
|
|
if err := s.applySignupDefaults(ctx, &user, tx); err != nil {
|
|
return model.User{}, err
|
|
}
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User, tx *gorm.DB) error {
|
|
config := s.appConfigService.GetDbConfig()
|
|
|
|
// Apply default user groups
|
|
var groupIDs []string
|
|
v := config.SignupDefaultUserGroupIDs.Value
|
|
if v != "" && v != "[]" {
|
|
err := json.Unmarshal([]byte(v), &groupIDs)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid SignupDefaultUserGroupIDs JSON: %w", err)
|
|
}
|
|
if len(groupIDs) > 0 {
|
|
var groups []model.UserGroup
|
|
err = tx.WithContext(ctx).
|
|
Where("id IN ?", groupIDs).
|
|
Find(&groups).
|
|
Error
|
|
if err != nil {
|
|
return fmt.Errorf("failed to find default user groups: %w", err)
|
|
}
|
|
|
|
err = tx.WithContext(ctx).
|
|
Model(user).
|
|
Association("UserGroups").
|
|
Replace(groups)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to associate default user groups: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Apply default custom claims
|
|
var claims []dto.CustomClaimCreateDto
|
|
v = config.SignupDefaultCustomClaims.Value
|
|
if v != "" && v != "[]" {
|
|
err := json.Unmarshal([]byte(v), &claims)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid SignupDefaultCustomClaims JSON: %w", err)
|
|
}
|
|
if len(claims) > 0 {
|
|
_, err = s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to apply default custom claims: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool) (model.User, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
user, err := s.updateUserInternal(ctx, userID, updatedUser, updateOwnUser, isLdapSync, tx)
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
|
if s.appConfigService.GetDbConfig().RequireUserEmail.IsTrue() && updatedUser.Email == nil {
|
|
return model.User{}, &common.UserEmailNotSetError{}
|
|
}
|
|
|
|
var user model.User
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where("id = ?", userID).
|
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
First(&user).
|
|
Error
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
// Check if this is an LDAP user and LDAP is enabled
|
|
isLdapUser := user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue()
|
|
allowOwnAccountEdit := s.appConfigService.GetDbConfig().AllowOwnAccountEdit.IsTrue()
|
|
|
|
if !isLdapSync && (isLdapUser || (!allowOwnAccountEdit && updateOwnUser)) {
|
|
// Restricted update: Only locale can be changed when:
|
|
// - User is from LDAP, OR
|
|
// - User is editing their own account but global setting disallows self-editing
|
|
// (Exception: LDAP sync operations can update everything)
|
|
user.Locale = updatedUser.Locale
|
|
} else {
|
|
// Full update: Allow updating all personal fields
|
|
user.FirstName = updatedUser.FirstName
|
|
user.LastName = updatedUser.LastName
|
|
user.DisplayName = updatedUser.DisplayName
|
|
user.Email = updatedUser.Email
|
|
user.Username = updatedUser.Username
|
|
user.Locale = updatedUser.Locale
|
|
|
|
// Admin-only fields: Only allow updates when not updating own account
|
|
if !updateOwnUser {
|
|
user.IsAdmin = updatedUser.IsAdmin
|
|
user.Disabled = updatedUser.Disabled
|
|
}
|
|
}
|
|
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Save(&user).
|
|
Error
|
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
|
// Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
|
|
if !isLdapSync {
|
|
tx.Rollback()
|
|
// If we are here, the transaction is already aborted due to an error, so we pass s.db
|
|
err = s.checkDuplicatedFields(ctx, user, s.db)
|
|
} else {
|
|
err = s.checkDuplicatedFields(ctx, user, tx)
|
|
}
|
|
|
|
return user, err
|
|
} else if err != nil {
|
|
return user, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, userID string, ttl time.Duration) error {
|
|
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsAdminEnabled.IsTrue()
|
|
if isDisabled {
|
|
return &common.OneTimeAccessDisabledError{}
|
|
}
|
|
|
|
return s.requestOneTimeAccessEmailInternal(ctx, userID, "", ttl)
|
|
}
|
|
|
|
func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error {
|
|
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsUnauthenticatedEnabled.IsTrue()
|
|
if isDisabled {
|
|
return &common.OneTimeAccessDisabledError{}
|
|
}
|
|
|
|
var userId string
|
|
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// Do not return error if user not found to prevent email enumeration
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
|
|
}
|
|
|
|
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, ttl time.Duration) error {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
user, err := s.GetUser(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if user.Email == nil {
|
|
return &common.UserEmailNotSetError{}
|
|
}
|
|
|
|
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// We use a background context here as this is running in a goroutine
|
|
//nolint:contextcheck
|
|
go func() {
|
|
span := trace.SpanFromContext(ctx)
|
|
innerCtx := trace.ContextWithSpan(context.Background(), span)
|
|
|
|
link := common.EnvConfig.AppURL + "/lc"
|
|
linkWithCode := link + "/" + oneTimeAccessToken
|
|
|
|
// Add redirect path to the link
|
|
if strings.HasPrefix(redirectPath, "/") {
|
|
encodedRedirectPath := url.QueryEscape(redirectPath)
|
|
linkWithCode = linkWithCode + "?redirect=" + encodedRedirectPath
|
|
}
|
|
|
|
errInternal := SendEmail(innerCtx, s.emailService, email.Address{
|
|
Name: user.FullName(),
|
|
Email: *user.Email,
|
|
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
|
|
Code: oneTimeAccessToken,
|
|
LoginLink: link,
|
|
LoginLinkWithCode: linkWithCode,
|
|
ExpirationString: utils.DurationToString(ttl),
|
|
})
|
|
if errInternal != nil {
|
|
slog.ErrorContext(innerCtx, "Failed to send one-time access token email", slog.Any("error", errInternal), slog.String("address", *user.Email))
|
|
return
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, ttl time.Duration) (string, error) {
|
|
return s.createOneTimeAccessTokenInternal(ctx, userID, ttl, s.db)
|
|
}
|
|
|
|
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, ttl time.Duration, tx *gorm.DB) (string, error) {
|
|
oneTimeAccessToken, err := NewOneTimeAccessToken(userID, ttl)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
err = tx.WithContext(ctx).Create(oneTimeAccessToken).Error
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return oneTimeAccessToken.Token, nil
|
|
}
|
|
|
|
func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
var oneTimeAccessToken model.OneTimeAccessToken
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).
|
|
Preload("User").
|
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
First(&oneTimeAccessToken).
|
|
Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
|
}
|
|
return model.User{}, "", err
|
|
}
|
|
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Delete(&oneTimeAccessToken).
|
|
Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
return oneTimeAccessToken.User, accessToken, nil
|
|
}
|
|
|
|
func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroupIds []string) (user model.User, err error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
user, err = s.getUserInternal(ctx, id, tx)
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
// Fetch the groups based on userGroupIds
|
|
var groups []model.UserGroup
|
|
if len(userGroupIds) > 0 {
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where("id IN (?)", userGroupIds).
|
|
Find(&groups).
|
|
Error
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
}
|
|
|
|
// Replace the current groups with the new set of groups
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Model(&user).
|
|
Association("UserGroups").
|
|
Replace(groups)
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
// Save the updated user
|
|
err = tx.WithContext(ctx).Save(&user).Error
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *UserService) SignUpInitialAdmin(ctx context.Context, signUpData dto.SignUpDto) (model.User, string, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
var userCount int64
|
|
if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
if userCount != 0 {
|
|
return model.User{}, "", &common.SetupAlreadyCompletedError{}
|
|
}
|
|
|
|
userToCreate := dto.UserCreateDto{
|
|
FirstName: signUpData.FirstName,
|
|
LastName: signUpData.LastName,
|
|
DisplayName: strings.TrimSpace(signUpData.FirstName + " " + signUpData.LastName),
|
|
Username: signUpData.Username,
|
|
Email: signUpData.Email,
|
|
IsAdmin: true,
|
|
}
|
|
|
|
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
token, err := s.jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
return user, token, nil
|
|
}
|
|
|
|
func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User, tx *gorm.DB) error {
|
|
var result struct {
|
|
Found bool
|
|
}
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND email = ?) AS found`, user.ID, user.Email).
|
|
First(&result).
|
|
Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.Found {
|
|
return &common.AlreadyInUseError{Property: "email"}
|
|
}
|
|
|
|
err = tx.
|
|
WithContext(ctx).
|
|
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND username = ?) AS found`, user.ID, user.Username).
|
|
First(&result).
|
|
Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.Found {
|
|
return &common.AlreadyInUseError{Property: "username"}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ResetProfilePicture deletes a user's custom profile picture
|
|
func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) error {
|
|
// Validate the user ID to prevent directory traversal
|
|
if err := uuid.Validate(userID); err != nil {
|
|
return &common.InvalidUUIDError{}
|
|
}
|
|
|
|
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
|
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
|
|
return fmt.Errorf("failed to delete profile picture: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error {
|
|
return tx.
|
|
WithContext(ctx).
|
|
Model(&model.User{}).
|
|
Where("id = ?", userID).
|
|
Update("disabled", true).
|
|
Error
|
|
}
|
|
|
|
func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int) (model.SignupToken, error) {
|
|
signupToken, err := NewSignupToken(ttl, usageLimit)
|
|
if err != nil {
|
|
return model.SignupToken{}, err
|
|
}
|
|
|
|
err = s.db.WithContext(ctx).Create(signupToken).Error
|
|
if err != nil {
|
|
return model.SignupToken{}, err
|
|
}
|
|
|
|
return *signupToken, nil
|
|
}
|
|
|
|
func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAddress, userAgent string) (model.User, string, error) {
|
|
tx := s.db.Begin()
|
|
defer func() {
|
|
tx.Rollback()
|
|
}()
|
|
|
|
tokenProvided := signupData.Token != ""
|
|
|
|
config := s.appConfigService.GetDbConfig()
|
|
if config.AllowUserSignups.Value != "open" && !tokenProvided {
|
|
return model.User{}, "", &common.OpenSignupDisabledError{}
|
|
}
|
|
|
|
var signupToken model.SignupToken
|
|
if tokenProvided {
|
|
err := tx.
|
|
WithContext(ctx).
|
|
Where("token = ?", signupData.Token).
|
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
First(&signupToken).
|
|
Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
|
}
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
if !signupToken.IsValid() {
|
|
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
|
|
}
|
|
}
|
|
|
|
userToCreate := dto.UserCreateDto{
|
|
Username: signupData.Username,
|
|
Email: signupData.Email,
|
|
FirstName: signupData.FirstName,
|
|
LastName: signupData.LastName,
|
|
DisplayName: strings.TrimSpace(signupData.FirstName + " " + signupData.LastName),
|
|
}
|
|
|
|
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
accessToken, err := s.jwtService.GenerateAccessToken(user)
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
if tokenProvided {
|
|
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
|
|
"signupToken": signupToken.Token,
|
|
}, tx)
|
|
|
|
signupToken.UsageCount++
|
|
|
|
err = tx.WithContext(ctx).Save(&signupToken).Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
|
|
}
|
|
} else {
|
|
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
|
|
"method": "open_signup",
|
|
}, tx)
|
|
}
|
|
|
|
err = tx.Commit().Error
|
|
if err != nil {
|
|
return model.User{}, "", err
|
|
}
|
|
|
|
return user, accessToken, nil
|
|
}
|
|
|
|
func (s *UserService) ListSignupTokens(ctx context.Context, listRequestOptions utils.ListRequestOptions) ([]model.SignupToken, utils.PaginationResponse, error) {
|
|
var tokens []model.SignupToken
|
|
query := s.db.WithContext(ctx).Model(&model.SignupToken{})
|
|
|
|
pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &tokens)
|
|
return tokens, pagination, err
|
|
}
|
|
|
|
func (s *UserService) DeleteSignupToken(ctx context.Context, tokenID string) error {
|
|
return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error
|
|
}
|
|
|
|
func NewOneTimeAccessToken(userID string, ttl time.Duration) (*model.OneTimeAccessToken, error) {
|
|
// If expires at is less than 15 minutes, use a 6-character token instead of 16
|
|
tokenLength := 16
|
|
if ttl <= 15*time.Minute {
|
|
tokenLength = 6
|
|
}
|
|
|
|
randomString, err := utils.GenerateRandomAlphanumericString(tokenLength)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().Round(time.Second)
|
|
o := &model.OneTimeAccessToken{
|
|
UserID: userID,
|
|
ExpiresAt: datatype.DateTime(now.Add(ttl)),
|
|
Token: randomString,
|
|
}
|
|
|
|
return o, nil
|
|
}
|
|
|
|
func NewSignupToken(ttl time.Duration, usageLimit int) (*model.SignupToken, error) {
|
|
// Generate a random token
|
|
randomString, err := utils.GenerateRandomAlphanumericString(16)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now().Round(time.Second)
|
|
token := &model.SignupToken{
|
|
Token: randomString,
|
|
ExpiresAt: datatype.DateTime(now.Add(ttl)),
|
|
UsageLimit: usageLimit,
|
|
UsageCount: 0,
|
|
}
|
|
|
|
return token, nil
|
|
}
|