refactor: use atomic renames for uploaded files (#372)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Alessandro (Ale) Segala
2025-03-23 13:21:44 -07:00
committed by GitHub
parent b8dcda8049
commit 26b2de4f00
4 changed files with 108 additions and 37 deletions

View File

@@ -1,5 +1,9 @@
package model
import (
"strconv"
)
type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null"`
Type string
@@ -9,6 +13,11 @@ type AppConfigVariable struct {
DefaultValue string
}
func (a *AppConfigVariable) IsTrue() bool {
ok, _ := strconv.ParseBool(a.Value)
return ok
}
type AppConfig struct {
// General
AppName AppConfigVariable

View File

@@ -59,7 +59,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error)
return nil, 0, &common.InvalidUUIDError{}
}
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID)
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
file, err := os.Open(profilePicturePath)
if err == nil {
// Get the file size
@@ -94,7 +94,8 @@ func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
err := uuid.Validate(userID)
if err != nil {
return &common.InvalidUUIDError{}
}
@@ -105,20 +106,14 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
}
// Ensure the directory exists
profilePictureDir := fmt.Sprintf("%s/profile-pictures", common.EnvConfig.UploadPath)
if err := os.MkdirAll(profilePictureDir, os.ModePerm); err != nil {
profilePictureDir := common.EnvConfig.UploadPath + "/profile-pictures"
err = os.MkdirAll(profilePictureDir, os.ModePerm)
if err != nil {
return err
}
// Create the profile picture file
createdProfilePicture, err := os.Create(fmt.Sprintf("%s/%s.png", profilePictureDir, userID))
if err != nil {
return err
}
defer createdProfilePicture.Close()
// Copy the image to the file
_, err = io.Copy(createdProfilePicture, profilePicture)
err = utils.SaveFileStream(profilePicture, profilePictureDir+"/"+userID+".png")
if err != nil {
return err
}
@@ -133,12 +128,12 @@ func (s *UserService) DeleteUser(userID string) error {
}
// Disallow deleting the user if it is an LDAP user and LDAP is enabled
if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return &common.LdapUserUpdateError{}
}
// Delete the profile picture
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID)
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) {
return err
}
@@ -175,7 +170,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
}
// Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{}
}
@@ -199,7 +194,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
}
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error {
isDisabled := s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.Value != "true"
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
}
@@ -376,7 +371,7 @@ func (s *UserService) ResetProfilePicture(userID string) error {
}
// Build path to profile picture
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID)
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
// Check if file exists and delete it
if _, err := os.Stat(profilePicturePath); err == nil {

View File

@@ -1,10 +1,15 @@
package utils
import (
"errors"
"fmt"
"hash/crc64"
"io"
"mime/multipart"
"os"
"path/filepath"
"strconv"
"time"
"github.com/pocket-id/pocket-id/backend/resources"
)
@@ -69,14 +74,70 @@ func SaveFile(file *multipart.FileHeader, dst string) error {
return err
}
out, err := os.Create(dst)
if err != nil {
return err
return SaveFileStream(src, dst)
}
defer out.Close()
_, err = io.Copy(out, src)
return err
// SaveFileStream saves a stream to a file.
func SaveFileStream(r io.Reader, dstFileName string) error {
// Our strategy is to save to a separate file and then rename it to override the original file
// First, get a temp file name that doesn't exist already
var tmpFileName string
var i int64
for {
seed := strconv.FormatInt(time.Now().UnixNano()+i, 10)
suffix := crc64.Checksum([]byte(dstFileName+seed), crc64.MakeTable(crc64.ISO))
tmpFileName = dstFileName + "." + strconv.FormatUint(suffix, 10)
exists, err := FileExists(tmpFileName)
if err != nil {
return fmt.Errorf("failed to check if file '%s' exists: %w", tmpFileName, err)
}
if !exists {
break
}
i++
}
// Write to the temporary file
tmpFile, err := os.Create(tmpFileName)
if err != nil {
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpFileName, err)
}
n, err := io.Copy(tmpFile, r)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = tmpFile.Close()
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to write to file '%s': %w", tmpFileName, err)
}
err = tmpFile.Close()
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to close stream to file '%s': %w", tmpFileName, err)
}
if n == 0 {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return errors.New("no data written")
}
// Rename to the final file, which overrides existing files
// This is an atomic operation
err = os.Rename(tmpFileName, dstFileName)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to rename file '%s': %w", dstFileName, err)
}
return nil
}
// FileExists returns true if a file exists on disk and is a regular file

View File

@@ -3,22 +3,24 @@ package profilepicture
import (
"bytes"
"fmt"
"github.com/disintegration/imageorient"
"github.com/disintegration/imaging"
"github.com/pocket-id/pocket-id/backend/resources"
"golang.org/x/image/font"
"golang.org/x/image/font/opentype"
"golang.org/x/image/math/fixed"
"image"
"image/color"
"io"
"strings"
"github.com/disintegration/imageorient"
"github.com/disintegration/imaging"
"golang.org/x/image/font"
"golang.org/x/image/font/opentype"
"golang.org/x/image/math/fixed"
"github.com/pocket-id/pocket-id/backend/resources"
)
const profilePictureSize = 300
// CreateProfilePicture resizes the profile picture to a square
func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) {
func CreateProfilePicture(file io.Reader) (io.Reader, error) {
img, _, err := imageorient.Decode(file)
if err != nil {
return nil, fmt.Errorf("failed to decode image: %w", err)
@@ -26,13 +28,17 @@ func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) {
img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos)
var buf bytes.Buffer
err = imaging.Encode(&buf, img, imaging.PNG)
pr, pw := io.Pipe()
go func() {
err = imaging.Encode(pw, img, imaging.PNG)
if err != nil {
return nil, fmt.Errorf("failed to encode image: %v", err)
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %v", err))
return
}
pw.Close()
}()
return &buf, nil
return pr, nil
}
// CreateDefaultProfilePicture creates a profile picture with the initials