diff --git a/backend/internal/model/app_config.go b/backend/internal/model/app_config.go index ac479705..678b7df0 100644 --- a/backend/internal/model/app_config.go +++ b/backend/internal/model/app_config.go @@ -1,5 +1,9 @@ package model +import ( + "strconv" +) + type AppConfigVariable struct { Key string `gorm:"primaryKey;not null"` Type string @@ -9,6 +13,11 @@ type AppConfigVariable struct { DefaultValue string } +func (a *AppConfigVariable) IsTrue() bool { + ok, _ := strconv.ParseBool(a.Value) + return ok +} + type AppConfig struct { // General AppName AppConfigVariable diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 8c63c260..bdea1675 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -59,7 +59,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error) return nil, 0, &common.InvalidUUIDError{} } - profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) + profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png" file, err := os.Open(profilePicturePath) if err == nil { // Get the file size @@ -94,7 +94,8 @@ func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) { func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error { // Validate the user ID to prevent directory traversal - if err := uuid.Validate(userID); err != nil { + err := uuid.Validate(userID) + if err != nil { return &common.InvalidUUIDError{} } @@ -105,20 +106,14 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error } // Ensure the directory exists - profilePictureDir := fmt.Sprintf("%s/profile-pictures", common.EnvConfig.UploadPath) - if err := os.MkdirAll(profilePictureDir, os.ModePerm); err != nil { + profilePictureDir := common.EnvConfig.UploadPath + "/profile-pictures" + err = os.MkdirAll(profilePictureDir, os.ModePerm) + if err != nil { return err } // Create the profile picture file - createdProfilePicture, err := os.Create(fmt.Sprintf("%s/%s.png", profilePictureDir, userID)) - if err != nil { - return err - } - defer createdProfilePicture.Close() - - // Copy the image to the file - _, err = io.Copy(createdProfilePicture, profilePicture) + err = utils.SaveFileStream(profilePicture, profilePictureDir+"/"+userID+".png") if err != nil { return err } @@ -133,12 +128,12 @@ func (s *UserService) DeleteUser(userID string) error { } // Disallow deleting the user if it is an LDAP user and LDAP is enabled - if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { + if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { return &common.LdapUserUpdateError{} } // Delete the profile picture - profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) + profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png" if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) { return err } @@ -175,7 +170,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u } // Disallow updating the user if it is an LDAP group and LDAP is enabled - if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { + if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { return model.User{}, &common.LdapUserUpdateError{} } @@ -199,7 +194,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u } func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error { - isDisabled := s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.Value != "true" + isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue() if isDisabled { return &common.OneTimeAccessDisabledError{} } @@ -376,7 +371,7 @@ func (s *UserService) ResetProfilePicture(userID string) error { } // Build path to profile picture - profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) + profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png" // Check if file exists and delete it if _, err := os.Stat(profilePicturePath); err == nil { diff --git a/backend/internal/utils/file_util.go b/backend/internal/utils/file_util.go index bc98c6c0..04ab8547 100644 --- a/backend/internal/utils/file_util.go +++ b/backend/internal/utils/file_util.go @@ -1,10 +1,15 @@ package utils import ( + "errors" + "fmt" + "hash/crc64" "io" "mime/multipart" "os" "path/filepath" + "strconv" + "time" "github.com/pocket-id/pocket-id/backend/resources" ) @@ -69,14 +74,70 @@ func SaveFile(file *multipart.FileHeader, dst string) error { return err } - out, err := os.Create(dst) - if err != nil { - return err - } - defer out.Close() + return SaveFileStream(src, dst) +} - _, err = io.Copy(out, src) - return err +// SaveFileStream saves a stream to a file. +func SaveFileStream(r io.Reader, dstFileName string) error { + // Our strategy is to save to a separate file and then rename it to override the original file + // First, get a temp file name that doesn't exist already + var tmpFileName string + var i int64 + for { + seed := strconv.FormatInt(time.Now().UnixNano()+i, 10) + suffix := crc64.Checksum([]byte(dstFileName+seed), crc64.MakeTable(crc64.ISO)) + tmpFileName = dstFileName + "." + strconv.FormatUint(suffix, 10) + exists, err := FileExists(tmpFileName) + if err != nil { + return fmt.Errorf("failed to check if file '%s' exists: %w", tmpFileName, err) + } + if !exists { + break + } + i++ + } + + // Write to the temporary file + tmpFile, err := os.Create(tmpFileName) + if err != nil { + return fmt.Errorf("failed to open file '%s' for writing: %w", tmpFileName, err) + } + + n, err := io.Copy(tmpFile, r) + if err != nil { + // Delete the temporary file; we ignore errors here + _ = tmpFile.Close() + _ = os.Remove(tmpFileName) + + return fmt.Errorf("failed to write to file '%s': %w", tmpFileName, err) + } + + err = tmpFile.Close() + if err != nil { + // Delete the temporary file; we ignore errors here + _ = os.Remove(tmpFileName) + + return fmt.Errorf("failed to close stream to file '%s': %w", tmpFileName, err) + } + + if n == 0 { + // Delete the temporary file; we ignore errors here + _ = os.Remove(tmpFileName) + + return errors.New("no data written") + } + + // Rename to the final file, which overrides existing files + // This is an atomic operation + err = os.Rename(tmpFileName, dstFileName) + if err != nil { + // Delete the temporary file; we ignore errors here + _ = os.Remove(tmpFileName) + + return fmt.Errorf("failed to rename file '%s': %w", dstFileName, err) + } + + return nil } // FileExists returns true if a file exists on disk and is a regular file diff --git a/backend/internal/utils/image/profile_picture.go b/backend/internal/utils/image/profile_picture.go index 9e4842c3..e83c4857 100644 --- a/backend/internal/utils/image/profile_picture.go +++ b/backend/internal/utils/image/profile_picture.go @@ -3,22 +3,24 @@ package profilepicture import ( "bytes" "fmt" - "github.com/disintegration/imageorient" - "github.com/disintegration/imaging" - "github.com/pocket-id/pocket-id/backend/resources" - "golang.org/x/image/font" - "golang.org/x/image/font/opentype" - "golang.org/x/image/math/fixed" "image" "image/color" "io" "strings" + + "github.com/disintegration/imageorient" + "github.com/disintegration/imaging" + "golang.org/x/image/font" + "golang.org/x/image/font/opentype" + "golang.org/x/image/math/fixed" + + "github.com/pocket-id/pocket-id/backend/resources" ) const profilePictureSize = 300 // CreateProfilePicture resizes the profile picture to a square -func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) { +func CreateProfilePicture(file io.Reader) (io.Reader, error) { img, _, err := imageorient.Decode(file) if err != nil { return nil, fmt.Errorf("failed to decode image: %w", err) @@ -26,13 +28,17 @@ func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) { img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos) - var buf bytes.Buffer - err = imaging.Encode(&buf, img, imaging.PNG) - if err != nil { - return nil, fmt.Errorf("failed to encode image: %v", err) - } + pr, pw := io.Pipe() + go func() { + err = imaging.Encode(pw, img, imaging.PNG) + if err != nil { + _ = pw.CloseWithError(fmt.Errorf("failed to encode image: %v", err)) + return + } + pw.Close() + }() - return &buf, nil + return pr, nil } // CreateDefaultProfilePicture creates a profile picture with the initials