fix: ensure file descriptors are closed + other bugs (#413)

This commit is contained in:
Alessandro (Ale) Segala
2025-04-04 01:04:36 -07:00
committed by GitHub
parent 980780e48b
commit 2f7646105e
6 changed files with 52 additions and 15 deletions

View File

@@ -250,6 +250,9 @@ func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
_ = c.Error(err)
return
}
if picture != nil {
defer picture.Close()
}
c.Header("Cache-Control", "public, max-age=300")

View File

@@ -38,9 +38,9 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures() error {
}
// Create a map to track which initials are in use
initialsInUse := make(map[string]bool)
initialsInUse := make(map[string]struct{})
for _, user := range users {
initialsInUse[user.Initials()] = true
initialsInUse[user.Initials()] = struct{}{}
}
defaultPicturesDir := common.EnvConfig.UploadPath + "/profile-pictures/defaults"
@@ -63,7 +63,7 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures() error {
initials := strings.TrimSuffix(filename, ".png")
// If these initials aren't used by any user, delete the file
if !initialsInUse[initials] {
if _, ok := initialsInUse[initials]; !ok {
filePath := filepath.Join(defaultPicturesDir, filename)
if err := os.Remove(filePath); err != nil {
log.Printf("Failed to delete unused default profile picture %s: %v", filePath, err)

View File

@@ -6,6 +6,7 @@ import (
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
type User struct {
@@ -66,14 +67,9 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
func (u User) FullName() string { return u.FirstName + " " + u.LastName }
func (u User) Initials() string {
initials := ""
if len(u.FirstName) > 0 {
initials += string(u.FirstName[0])
}
if len(u.LastName) > 0 {
initials += string(u.LastName[0])
}
return strings.ToUpper(initials)
return strings.ToUpper(
utils.GetFirstCharacter(u.FirstName) + utils.GetFirstCharacter(u.LastName),
)
}
type OneTimeAccessToken struct {

View File

@@ -54,7 +54,7 @@ func (s *UserService) GetUser(userID string) (model.User, error) {
return user, err
}
func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error) {
func (s *UserService) GetProfilePicture(userID string) (io.ReadCloser, int64, error) {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
return nil, 0, &common.InvalidUUIDError{}
@@ -99,7 +99,7 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error)
}
// Save the default picture for future use (in a goroutine to avoid blocking)
defaultPictureCopy := bytes.NewBuffer(defaultPicture.Bytes())
defaultPictureBytes := defaultPicture.Bytes()
go func() {
// Ensure the directory exists
err = os.MkdirAll(defaultProfilePicturesDir, os.ModePerm)
@@ -107,12 +107,12 @@ func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error)
log.Printf("Failed to create directory for default profile picture: %v", err)
return
}
if err := utils.SaveFileStream(defaultPictureCopy, defaultPicturePath); err != nil {
if err := utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath); err != nil {
log.Printf("Failed to cache default profile picture for initials %s: %v", user.Initials(), err)
}
}()
return defaultPicture, int64(defaultPicture.Len()), nil
return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(defaultPicture.Len()), nil
}
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {

View File

@@ -102,3 +102,16 @@ func CamelCaseToScreamingSnakeCase(s string) string {
// Convert to uppercase
return strings.ToUpper(snake)
}
// GetFirstCharacter returns the first non-whitespace character of the string, correctly handling Unicode
func GetFirstCharacter(str string) string {
for _, c := range str {
if unicode.IsSpace(c) {
continue
}
return string(c)
}
// Empty string case
return ""
}

View File

@@ -103,3 +103,28 @@ func TestCamelCaseToSnakeCase(t *testing.T) {
})
}
}
func TestGetFirstCharacter(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"empty string", "", ""},
{"single character", "a", "a"},
{"multiple characters", "hello", "h"},
{"unicode character", "étoile", "é"},
{"special character", "!test", "!"},
{"number as first character", "123abc", "1"},
{"whitespace as first character", " hello", "h"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetFirstCharacter(tt.input)
if result != tt.expected {
t.Errorf("GetFirstCharacter(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}