refactor: use atomic renames for uploaded files (#372)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Alessandro (Ale) Segala
2025-03-23 13:21:44 -07:00
committed by GitHub
parent b8dcda8049
commit 26b2de4f00
4 changed files with 108 additions and 37 deletions

View File

@@ -1,5 +1,9 @@
package model package model
import (
"strconv"
)
type AppConfigVariable struct { type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null"` Key string `gorm:"primaryKey;not null"`
Type string Type string
@@ -9,6 +13,11 @@ type AppConfigVariable struct {
DefaultValue string DefaultValue string
} }
func (a *AppConfigVariable) IsTrue() bool {
ok, _ := strconv.ParseBool(a.Value)
return ok
}
type AppConfig struct { type AppConfig struct {
// General // General
AppName AppConfigVariable AppName AppConfigVariable

View File

@@ -59,7 +59,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error)
return nil, 0, &common.InvalidUUIDError{} return nil, 0, &common.InvalidUUIDError{}
} }
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
file, err := os.Open(profilePicturePath) file, err := os.Open(profilePicturePath)
if err == nil { if err == nil {
// Get the file size // Get the file size
@@ -94,7 +94,8 @@ func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error { func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
// Validate the user ID to prevent directory traversal // Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil { err := uuid.Validate(userID)
if err != nil {
return &common.InvalidUUIDError{} return &common.InvalidUUIDError{}
} }
@@ -105,20 +106,14 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
} }
// Ensure the directory exists // Ensure the directory exists
profilePictureDir := fmt.Sprintf("%s/profile-pictures", common.EnvConfig.UploadPath) profilePictureDir := common.EnvConfig.UploadPath + "/profile-pictures"
if err := os.MkdirAll(profilePictureDir, os.ModePerm); err != nil { err = os.MkdirAll(profilePictureDir, os.ModePerm)
if err != nil {
return err return err
} }
// Create the profile picture file // Create the profile picture file
createdProfilePicture, err := os.Create(fmt.Sprintf("%s/%s.png", profilePictureDir, userID)) err = utils.SaveFileStream(profilePicture, profilePictureDir+"/"+userID+".png")
if err != nil {
return err
}
defer createdProfilePicture.Close()
// Copy the image to the file
_, err = io.Copy(createdProfilePicture, profilePicture)
if err != nil { if err != nil {
return err return err
} }
@@ -133,12 +128,12 @@ func (s *UserService) DeleteUser(userID string) error {
} }
// Disallow deleting the user if it is an LDAP user and LDAP is enabled // Disallow deleting the user if it is an LDAP user and LDAP is enabled
if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return &common.LdapUserUpdateError{} return &common.LdapUserUpdateError{}
} }
// Delete the profile picture // Delete the profile picture
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) { if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) {
return err return err
} }
@@ -175,7 +170,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
} }
// Disallow updating the user if it is an LDAP group and LDAP is enabled // Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{} return model.User{}, &common.LdapUserUpdateError{}
} }
@@ -199,7 +194,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
} }
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error { func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error {
isDisabled := s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.Value != "true" isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue()
if isDisabled { if isDisabled {
return &common.OneTimeAccessDisabledError{} return &common.OneTimeAccessDisabledError{}
} }
@@ -376,7 +371,7 @@ func (s *UserService) ResetProfilePicture(userID string) error {
} }
// Build path to profile picture // Build path to profile picture
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
// Check if file exists and delete it // Check if file exists and delete it
if _, err := os.Stat(profilePicturePath); err == nil { if _, err := os.Stat(profilePicturePath); err == nil {

View File

@@ -1,10 +1,15 @@
package utils package utils
import ( import (
"errors"
"fmt"
"hash/crc64"
"io" "io"
"mime/multipart" "mime/multipart"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"time"
"github.com/pocket-id/pocket-id/backend/resources" "github.com/pocket-id/pocket-id/backend/resources"
) )
@@ -69,14 +74,70 @@ func SaveFile(file *multipart.FileHeader, dst string) error {
return err return err
} }
out, err := os.Create(dst) return SaveFileStream(src, dst)
if err != nil { }
return err
}
defer out.Close()
_, err = io.Copy(out, src) // SaveFileStream saves a stream to a file.
return err func SaveFileStream(r io.Reader, dstFileName string) error {
// Our strategy is to save to a separate file and then rename it to override the original file
// First, get a temp file name that doesn't exist already
var tmpFileName string
var i int64
for {
seed := strconv.FormatInt(time.Now().UnixNano()+i, 10)
suffix := crc64.Checksum([]byte(dstFileName+seed), crc64.MakeTable(crc64.ISO))
tmpFileName = dstFileName + "." + strconv.FormatUint(suffix, 10)
exists, err := FileExists(tmpFileName)
if err != nil {
return fmt.Errorf("failed to check if file '%s' exists: %w", tmpFileName, err)
}
if !exists {
break
}
i++
}
// Write to the temporary file
tmpFile, err := os.Create(tmpFileName)
if err != nil {
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpFileName, err)
}
n, err := io.Copy(tmpFile, r)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = tmpFile.Close()
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to write to file '%s': %w", tmpFileName, err)
}
err = tmpFile.Close()
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to close stream to file '%s': %w", tmpFileName, err)
}
if n == 0 {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return errors.New("no data written")
}
// Rename to the final file, which overrides existing files
// This is an atomic operation
err = os.Rename(tmpFileName, dstFileName)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to rename file '%s': %w", dstFileName, err)
}
return nil
} }
// FileExists returns true if a file exists on disk and is a regular file // FileExists returns true if a file exists on disk and is a regular file

View File

@@ -3,22 +3,24 @@ package profilepicture
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/disintegration/imageorient"
"github.com/disintegration/imaging"
"github.com/pocket-id/pocket-id/backend/resources"
"golang.org/x/image/font"
"golang.org/x/image/font/opentype"
"golang.org/x/image/math/fixed"
"image" "image"
"image/color" "image/color"
"io" "io"
"strings" "strings"
"github.com/disintegration/imageorient"
"github.com/disintegration/imaging"
"golang.org/x/image/font"
"golang.org/x/image/font/opentype"
"golang.org/x/image/math/fixed"
"github.com/pocket-id/pocket-id/backend/resources"
) )
const profilePictureSize = 300 const profilePictureSize = 300
// CreateProfilePicture resizes the profile picture to a square // CreateProfilePicture resizes the profile picture to a square
func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) { func CreateProfilePicture(file io.Reader) (io.Reader, error) {
img, _, err := imageorient.Decode(file) img, _, err := imageorient.Decode(file)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode image: %w", err) return nil, fmt.Errorf("failed to decode image: %w", err)
@@ -26,13 +28,17 @@ func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) {
img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos) img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos)
var buf bytes.Buffer pr, pw := io.Pipe()
err = imaging.Encode(&buf, img, imaging.PNG) go func() {
if err != nil { err = imaging.Encode(pw, img, imaging.PNG)
return nil, fmt.Errorf("failed to encode image: %v", err) if err != nil {
} _ = pw.CloseWithError(fmt.Errorf("failed to encode image: %v", err))
return
}
pw.Close()
}()
return &buf, nil return pr, nil
} }
// CreateDefaultProfilePicture creates a profile picture with the initials // CreateDefaultProfilePicture creates a profile picture with the initials