mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-17 02:33:02 +03:00
feat: add database storage backend (#1091)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
12125713a2
commit
29a1d3b778
83
.github/workflows/e2e-tests.yml
vendored
83
.github/workflows/e2e-tests.yml
vendored
@@ -57,7 +57,17 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
db: [sqlite, postgres, sqlite-s3]
|
||||
include:
|
||||
- db: sqlite
|
||||
storage: fs
|
||||
- db: postgres
|
||||
storage: fs
|
||||
- db: sqlite
|
||||
storage: s3
|
||||
- db: sqlite
|
||||
storage: database
|
||||
- db: postgres
|
||||
storage: database
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
@@ -71,65 +81,74 @@ jobs:
|
||||
node-version: 22
|
||||
|
||||
- name: Cache Playwright Browsers
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
id: playwright-cache
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: ${{ runner.os }}-playwright-${{ hashFiles('pnpm-lock.yaml') }}
|
||||
|
||||
- name: Cache PostgreSQL Docker image
|
||||
if: matrix.db == 'postgres'
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
id: postgres-cache
|
||||
with:
|
||||
path: /tmp/postgres-image.tar
|
||||
key: postgres-17-${{ runner.os }}
|
||||
|
||||
- name: Pull and save PostgreSQL image
|
||||
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
docker pull postgres:17
|
||||
docker save postgres:17 > /tmp/postgres-image.tar
|
||||
|
||||
- name: Load PostgreSQL image
|
||||
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit == 'true'
|
||||
run: docker load < /tmp/postgres-image.tar
|
||||
|
||||
- name: Cache LLDAP Docker image
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
id: lldap-cache
|
||||
with:
|
||||
path: /tmp/lldap-image.tar
|
||||
key: lldap-stable-${{ runner.os }}
|
||||
|
||||
- name: Pull and save LLDAP image
|
||||
if: steps.lldap-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
docker pull nitnelave/lldap:stable
|
||||
docker save nitnelave/lldap:stable > /tmp/lldap-image.tar
|
||||
|
||||
docker pull lldap/lldap:2025-05-19
|
||||
docker save lldap/lldap:2025-05-19 > /tmp/lldap-image.tar
|
||||
- name: Load LLDAP image
|
||||
if: steps.lldap-cache.outputs.cache-hit == 'true'
|
||||
run: docker load < /tmp/lldap-image.tar
|
||||
|
||||
- name: Cache Localstack S3 Docker image
|
||||
if: matrix.db == 'sqlite-s3'
|
||||
uses: actions/cache@v3
|
||||
if: matrix.storage == 's3'
|
||||
uses: actions/cache@v4
|
||||
id: s3-cache
|
||||
with:
|
||||
path: /tmp/localstack-s3-image.tar
|
||||
key: localstack-s3-latest-${{ runner.os }}
|
||||
|
||||
- name: Pull and save Localstack S3 image
|
||||
if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit != 'true'
|
||||
if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
docker pull localstack/localstack:s3-latest
|
||||
docker save localstack/localstack:s3-latest > /tmp/localstack-s3-image.tar
|
||||
|
||||
- name: Load Localstack S3 image
|
||||
if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit == 'true'
|
||||
if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit == 'true'
|
||||
run: docker load < /tmp/localstack-s3-image.tar
|
||||
|
||||
- name: Cache AWS CLI Docker image
|
||||
if: matrix.storage == 's3'
|
||||
uses: actions/cache@v4
|
||||
id: aws-cli-cache
|
||||
with:
|
||||
path: /tmp/aws-cli-image.tar
|
||||
key: aws-cli-latest-${{ runner.os }}
|
||||
- name: Pull and save AWS CLI image
|
||||
if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
docker pull amazon/aws-cli:latest
|
||||
docker save amazon/aws-cli:latest > /tmp/aws-cli-image.tar
|
||||
- name: Load AWS CLI image
|
||||
if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit == 'true'
|
||||
run: docker load < /tmp/aws-cli-image.tar
|
||||
|
||||
- name: Download Docker image artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
@@ -147,26 +166,20 @@ jobs:
|
||||
if: steps.playwright-cache.outputs.cache-hit != 'true'
|
||||
run: pnpm exec playwright install --with-deps chromium
|
||||
|
||||
- name: Run Docker Container (sqlite) with LDAP
|
||||
if: matrix.db == 'sqlite'
|
||||
- name: Run Docker containers
|
||||
working-directory: ./tests/setup
|
||||
run: |
|
||||
docker compose up -d
|
||||
docker compose logs -f pocket-id &> /tmp/backend.log &
|
||||
DOCKER_COMPOSE_FILE=docker-compose.yml
|
||||
|
||||
- name: Run Docker Container (postgres) with LDAP
|
||||
if: matrix.db == 'postgres'
|
||||
working-directory: ./tests/setup
|
||||
run: |
|
||||
docker compose -f docker-compose-postgres.yml up -d
|
||||
docker compose -f docker-compose-postgres.yml logs -f pocket-id &> /tmp/backend.log &
|
||||
export FILE_BACKEND="${{ matrix.storage }}"
|
||||
if [ "${{ matrix.db }}" = "postgres" ]; then
|
||||
DOCKER_COMPOSE_FILE=docker-compose-postgres.yml
|
||||
elif [ "${{ matrix.storage }}" = "s3" ]; then
|
||||
DOCKER_COMPOSE_FILE=docker-compose-s3.yml
|
||||
fi
|
||||
|
||||
- name: Run Docker Container (sqlite-s3) with LDAP + S3
|
||||
if: matrix.db == 'sqlite-s3'
|
||||
working-directory: ./tests/setup
|
||||
run: |
|
||||
docker compose -f docker-compose-s3.yml up -d
|
||||
docker compose -f docker-compose-s3.yml logs -f pocket-id &> /tmp/backend.log &
|
||||
docker compose -f "$DOCKER_COMPOSE_FILE" up -d
|
||||
docker compose -f "$DOCKER_COMPOSE_FILE" logs -f pocket-id &> /tmp/backend.log &
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./tests
|
||||
@@ -176,7 +189,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
|
||||
with:
|
||||
name: playwright-report-${{ matrix.db }}
|
||||
name: playwright-report-${{ matrix.db }}-${{ matrix.storage }}
|
||||
path: tests/.report
|
||||
include-hidden-files: true
|
||||
retention-days: 15
|
||||
@@ -185,7 +198,7 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
|
||||
with:
|
||||
name: backend-${{ matrix.db }}
|
||||
name: backend-${{ matrix.db }}-${{ matrix.storage }}
|
||||
path: /tmp/backend.log
|
||||
include-hidden-files: true
|
||||
retention-days: 15
|
||||
|
||||
@@ -22,12 +22,20 @@ func Bootstrap(ctx context.Context) error {
|
||||
}
|
||||
slog.InfoContext(ctx, "Pocket ID is starting")
|
||||
|
||||
// Connect to the database
|
||||
db, err := NewDatabase()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
|
||||
// Initialize the file storage backend
|
||||
var fileStorage storage.FileStorage
|
||||
|
||||
switch common.EnvConfig.FileBackend {
|
||||
case storage.TypeFileSystem:
|
||||
fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath)
|
||||
case storage.TypeDatabase:
|
||||
fileStorage, err = storage.NewDatabaseStorage(db)
|
||||
case storage.TypeS3:
|
||||
s3Cfg := storage.S3Config{
|
||||
Bucket: common.EnvConfig.S3Bucket,
|
||||
@@ -43,7 +51,7 @@ func Bootstrap(ctx context.Context) error {
|
||||
err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize file storage: %w", err)
|
||||
return fmt.Errorf("failed to initialize file storage (backend: %s): %w", common.EnvConfig.FileBackend, err)
|
||||
}
|
||||
|
||||
imageExtensions, err := initApplicationImages(ctx, fileStorage)
|
||||
@@ -51,12 +59,6 @@ func Bootstrap(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to initialize application images: %w", err)
|
||||
}
|
||||
|
||||
// Connect to the database
|
||||
db, err := NewDatabase()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
|
||||
// Create all services
|
||||
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
|
||||
if err != nil {
|
||||
|
||||
@@ -66,7 +66,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
|
||||
|
||||
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
|
||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService, fileStorage)
|
||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||
|
||||
svc.versionService = service.NewVersionService(httpClient)
|
||||
|
||||
@@ -180,12 +180,14 @@ func validateEnvConfig(config *EnvConfigSchema) error {
|
||||
if config.KeysStorage == "file" {
|
||||
return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
|
||||
}
|
||||
case "database":
|
||||
// All good, these are valid values
|
||||
case "", "fs":
|
||||
if config.UploadPath == "" {
|
||||
config.UploadPath = defaultFsUploadPath
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid FILE_BACKEND value. Must be 'fs' or 's3'")
|
||||
return errors.New("invalid FILE_BACKEND value. Must be 'fs', 'database', or 's3'")
|
||||
}
|
||||
|
||||
// Validate LOCAL_IPV6_RANGES
|
||||
|
||||
@@ -587,7 +587,6 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
|
||||
}
|
||||
|
||||
// deleteClientLogoHandler godoc
|
||||
@@ -614,7 +613,6 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
|
||||
}
|
||||
|
||||
// updateAllowedUserGroupsHandler godoc
|
||||
|
||||
17
backend/internal/model/storage.go
Normal file
17
backend/internal/model/storage.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
Path string `gorm:"primaryKey"`
|
||||
Data []byte
|
||||
Size int64
|
||||
ModTime datatype.DateTime
|
||||
CreatedAt datatype.DateTime
|
||||
}
|
||||
|
||||
func (Storage) TableName() string {
|
||||
return "storage"
|
||||
}
|
||||
@@ -426,7 +426,8 @@ func (s *TestService) ResetDatabase() error {
|
||||
}
|
||||
|
||||
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
||||
if err := s.fileStorage.DeleteAll(ctx, "/"); err != nil {
|
||||
err := s.fileStorage.DeleteAll(ctx, "/")
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Error removing uploads", slog.Any("error", err))
|
||||
return err
|
||||
}
|
||||
@@ -445,7 +446,8 @@ func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile); err != nil {
|
||||
err = s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile)
|
||||
if err != nil {
|
||||
srcFile.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
"gorm.io/gorm"
|
||||
@@ -32,15 +34,23 @@ type LdapService struct {
|
||||
appConfigService *AppConfigService
|
||||
userService *UserService
|
||||
groupService *UserGroupService
|
||||
fileStorage storage.FileStorage
|
||||
}
|
||||
|
||||
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService {
|
||||
type savePicture struct {
|
||||
userID string
|
||||
username string
|
||||
picture string
|
||||
}
|
||||
|
||||
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
|
||||
return &LdapService{
|
||||
db: db,
|
||||
httpClient: httpClient,
|
||||
appConfigService: appConfigService,
|
||||
userService: userService,
|
||||
groupService: groupService,
|
||||
fileStorage: fileStorage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,12 +78,6 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
}
|
||||
|
||||
func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
// Start a transaction
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
if err != nil {
|
||||
@@ -81,7 +85,13 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
err = s.SyncUsers(ctx, tx, client)
|
||||
// Start a transaction
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users: %w", err)
|
||||
}
|
||||
@@ -97,6 +107,25 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to commit changes to database: %w", err)
|
||||
}
|
||||
|
||||
// Now that we've committed the transaction, we can perform operations on the storage layer
|
||||
// First, save all new pictures
|
||||
for _, sp := range savePictures {
|
||||
err = s.saveProfilePicture(ctx, sp.userID, sp.picture)
|
||||
if err != nil {
|
||||
// This is not a fatal error
|
||||
slog.Warn("Error saving profile picture for LDAP user", slog.String("username", sp.username), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Delete all old files
|
||||
for _, path := range deleteFiles {
|
||||
err = s.fileStorage.Delete(ctx, path)
|
||||
if err != nil {
|
||||
// This is not a fatal error
|
||||
slog.Error("Failed to delete file after LDAP sync", slog.String("path", path), slog.Any("error", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -266,7 +295,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
searchAttrs := []string{
|
||||
@@ -294,11 +323,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
|
||||
result, err := client.Search(searchReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query LDAP: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to query LDAP: %w", err)
|
||||
}
|
||||
|
||||
// Create a mapping for users that exist
|
||||
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
||||
savePictures = make([]savePicture, 0, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
@@ -329,13 +359,13 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
||||
return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
||||
}
|
||||
|
||||
// Check if user is admin by checking if they are in the admin group
|
||||
@@ -369,32 +399,35 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
continue
|
||||
}
|
||||
|
||||
userID := databaseUser.ID
|
||||
if databaseUser.ID == "" {
|
||||
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||
createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
||||
return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
userID = createdUser.ID
|
||||
} else {
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
||||
return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save profile picture
|
||||
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
|
||||
if pictureString != "" {
|
||||
err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
|
||||
if err != nil {
|
||||
// This is not a fatal error
|
||||
slog.Warn("Error saving profile picture for user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||
}
|
||||
// Storage operations must be executed outside of a transaction
|
||||
savePictures = append(savePictures, savePicture{
|
||||
userID: databaseUser.ID,
|
||||
username: userID,
|
||||
picture: pictureString,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -406,10 +439,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
Select("id, username, ldap_id, disabled").
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
}
|
||||
|
||||
// Mark users as disabled or delete users that no longer exist in LDAP
|
||||
deleteFiles = make([]string, 0, len(ldapUserIDs))
|
||||
for _, user := range ldapUsersInDb {
|
||||
// Skip if the user ID exists in the fetched LDAP results
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
||||
@@ -417,26 +451,30 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
}
|
||||
|
||||
if dbConfig.LdapSoftDeleteUsers.IsTrue() {
|
||||
err = s.userService.disableUserInternal(ctx, user.ID, tx)
|
||||
err = s.userService.disableUserInternal(ctx, tx, user.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to disable user %s: %w", user.Username, err)
|
||||
return nil, nil, fmt.Errorf("failed to disable user %s: %w", user.Username, err)
|
||||
}
|
||||
|
||||
slog.Info("Disabled user", slog.String("username", user.Username))
|
||||
} else {
|
||||
err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
|
||||
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
||||
if err != nil {
|
||||
target := &common.LdapUserUpdateError{}
|
||||
if errors.As(err, &target) {
|
||||
return fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted user", slog.String("username", user.Username))
|
||||
|
||||
// Storage operations must be executed outside of a transaction
|
||||
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return savePictures, deleteFiles, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
@@ -679,19 +678,21 @@ func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID strin
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
|
||||
return s.getClientInternal(ctx, clientID, s.db)
|
||||
return s.getClientInternal(ctx, clientID, s.db, false)
|
||||
}
|
||||
|
||||
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) {
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
q := tx.
|
||||
WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
Preload("AllowedUserGroups").
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
Preload("AllowedUserGroups")
|
||||
if forUpdate {
|
||||
q = q.Clauses(clause.Locking{Strength: "UPDATE"})
|
||||
}
|
||||
q = q.First(&client, "id = ?", clientID)
|
||||
if q.Error != nil {
|
||||
return model.OidcClient{}, q.Error
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
@@ -724,11 +725,6 @@ func (s *OidcService) ListClients(ctx context.Context, name string, listRequestO
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client := model.OidcClient{
|
||||
Base: model.Base{
|
||||
ID: input.ID,
|
||||
@@ -737,7 +733,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
||||
}
|
||||
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
|
||||
|
||||
err := tx.
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Create(&client).
|
||||
Error
|
||||
@@ -748,62 +744,65 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
// All storage operations must be executed outside of a transaction
|
||||
if input.LogoURL != nil {
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if input.DarkLogoURL != nil {
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() { tx.Rollback() }()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := tx.WithContext(ctx).
|
||||
err := tx.WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
First(&client, "id = ?", clientID).Error; err != nil {
|
||||
First(&client, "id = ?", clientID).Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
|
||||
err = tx.WithContext(ctx).Save(&client).Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
// All storage operations must be executed outside of a transaction
|
||||
if input.LogoURL != nil {
|
||||
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if input.DarkLogoURL != nil {
|
||||
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -836,12 +835,24 @@ func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", clientID).
|
||||
Clauses(clause.Returning{}).
|
||||
Delete(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete images if present
|
||||
// Note that storage operations must be done outside of a transaction
|
||||
if client.ImageType != nil && *client.ImageType != "" {
|
||||
old := path.Join("oidc-client-images", client.ID+"."+*client.ImageType)
|
||||
_ = s.fileStorage.Delete(ctx, old)
|
||||
}
|
||||
if client.DarkImageType != nil && *client.DarkImageType != "" {
|
||||
old := path.Join("oidc-client-images", client.ID+"-dark."+*client.DarkImageType)
|
||||
_ = s.fileStorage.Delete(ctx, old)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -941,57 +952,12 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
if err := s.fileStorage.Save(ctx, imagePath, reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
|
||||
err = s.updateClientLogoType(ctx, tx, clientID, fileType, light)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
err = s.fileStorage.Save(ctx, imagePath, reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if client.ImageType == nil {
|
||||
return errors.New("image not found")
|
||||
}
|
||||
|
||||
oldImageType := *client.ImageType
|
||||
client.ImageType = nil
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := path.Join("oidc-client-images", client.ID+"."+oldImageType)
|
||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
err = s.updateClientLogoType(ctx, clientID, fileType, light)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -999,7 +965,31 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||
return s.deleteClientLogoInternal(ctx, clientID, "", func(client *model.OidcClient) (string, error) {
|
||||
if client.ImageType == nil {
|
||||
return "", errors.New("image not found")
|
||||
}
|
||||
|
||||
oldImageType := *client.ImageType
|
||||
client.ImageType = nil
|
||||
return oldImageType, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) error {
|
||||
return s.deleteClientLogoInternal(ctx, clientID, "-dark", func(client *model.OidcClient) (string, error) {
|
||||
if client.DarkImageType == nil {
|
||||
return "", errors.New("image not found")
|
||||
}
|
||||
|
||||
oldImageType := *client.DarkImageType
|
||||
client.DarkImageType = nil
|
||||
return oldImageType, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OidcService) deleteClientLogoInternal(ctx context.Context, clientID string, imagePathSuffix string, setClientImage func(*model.OidcClient) (string, error)) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
@@ -1014,13 +1004,11 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
|
||||
return err
|
||||
}
|
||||
|
||||
if client.DarkImageType == nil {
|
||||
return errors.New("image not found")
|
||||
oldImageType, err := setClientImage(&client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldImageType := *client.DarkImageType
|
||||
client.DarkImageType = nil
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
@@ -1029,12 +1017,14 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := path.Join("oidc-client-images", client.ID+"-dark."+oldImageType)
|
||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
// All storage operations must be performed outside of a database transaction
|
||||
imagePath := path.Join("oidc-client-images", client.ID+imagePathSuffix+"."+oldImageType)
|
||||
err = s.fileStorage.Delete(ctx, imagePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1048,7 +1038,7 @@ func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, in
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err = s.getClientInternal(ctx, id, tx)
|
||||
client, err = s.getClientInternal(ctx, id, tx, true)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
@@ -1831,7 +1821,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.getClientInternal(ctx, clientID, tx)
|
||||
client, err := s.getClientInternal(ctx, clientID, tx, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1976,7 +1966,25 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
|
||||
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
||||
}
|
||||
|
||||
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string, light bool) error {
|
||||
var errLogoTooLarge = errors.New("logo is too large")
|
||||
|
||||
func httpClientWithCheckRedirect(source *http.Client, checkRedirect func(req *http.Request, via []*http.Request) error) *http.Client {
|
||||
if source == nil {
|
||||
source = http.DefaultClient
|
||||
}
|
||||
|
||||
// Create a new client that clones the transport
|
||||
client := &http.Client{
|
||||
Transport: source.Transport,
|
||||
}
|
||||
|
||||
// Assign the CheckRedirect function
|
||||
client.CheckRedirect = checkRedirect
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, clientID string, raw string, light bool) error {
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1985,19 +1993,30 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
r := net.Resolver{}
|
||||
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
||||
if err != nil || len(ips) == 0 {
|
||||
return fmt.Errorf("cannot resolve hostname")
|
||||
// Prevents SSRF by allowing only public IPs
|
||||
ok, err := utils.IsURLPrivate(ctx, u)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
return errors.New("private IP addresses are not allowed")
|
||||
}
|
||||
|
||||
// Prevents SSRF by allowing only public IPs
|
||||
for _, addr := range ips {
|
||||
if utils.IsPrivateIP(addr.IP) {
|
||||
return fmt.Errorf("private IP addresses are not allowed")
|
||||
// We need to check this on redirects too
|
||||
client := httpClientWithCheckRedirect(s.httpClient, func(r *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
|
||||
ok, err := utils.IsURLPrivate(r.Context(), r.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
return errors.New("private IP addresses are not allowed")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2005,7 +2024,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
|
||||
req.Header.Set("Accept", "image/*")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2017,7 +2036,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
|
||||
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
|
||||
if resp.ContentLength > maxLogoSize {
|
||||
return fmt.Errorf("logo is too large")
|
||||
return errLogoTooLarge
|
||||
}
|
||||
|
||||
// Prefer extension in path if supported
|
||||
@@ -2037,48 +2056,70 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
}
|
||||
|
||||
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext)
|
||||
if err := s.fileStorage.Save(ctx, imagePath, io.LimitReader(resp.Body, maxLogoSize+1)); err != nil {
|
||||
err = s.fileStorage.Save(ctx, imagePath, utils.NewLimitReader(resp.Body, maxLogoSize+1))
|
||||
if errors.Is(err, utils.ErrSizeExceeded) {
|
||||
return errLogoTooLarge
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.updateClientLogoType(ctx, tx, clientID, ext, light); err != nil {
|
||||
err = s.updateClientLogoType(ctx, clientID, ext, light)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string, light bool) error {
|
||||
func (s *OidcService) updateClientLogoType(ctx context.Context, clientID string, ext string, light bool) error {
|
||||
var darkSuffix string
|
||||
if !light {
|
||||
darkSuffix = "-dark"
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// We need to acquire an update lock for the row to be locked, since we'll update it later
|
||||
var client model.OidcClient
|
||||
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to look up client: %w", err)
|
||||
}
|
||||
|
||||
var currentType *string
|
||||
if light {
|
||||
currentType = client.ImageType
|
||||
client.ImageType = &ext
|
||||
} else {
|
||||
currentType = client.DarkImageType
|
||||
client.DarkImageType = &ext
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save updated client: %w", err)
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// Storage operations must be executed outside of a transaction
|
||||
if currentType != nil && *currentType != ext {
|
||||
old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType)
|
||||
_ = s.fileStorage.Delete(ctx, old)
|
||||
}
|
||||
|
||||
var column string
|
||||
if light {
|
||||
column = "image_type"
|
||||
} else {
|
||||
column = "dark_image_type"
|
||||
}
|
||||
|
||||
return tx.WithContext(ctx).
|
||||
Model(&model.OidcClient{}).
|
||||
Where("id = ?", clientID).
|
||||
Update(column, ext).
|
||||
Error
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,10 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +24,7 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
@@ -537,3 +541,435 @@ func TestValidateCodeVerifier_Plain(t *testing.T) {
|
||||
require.False(t, validateCodeVerifier("NOT!VALID", codeChallenge, true))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcService_updateClientLogoType(t *testing.T) {
|
||||
// Create a test database
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create database storage
|
||||
dbStorage, err := storage.NewDatabaseStorage(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Init the OidcService
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
}
|
||||
|
||||
// Create a test client
|
||||
client := model.OidcClient{
|
||||
Name: "Test Client",
|
||||
CallbackURLs: model.UrlList{"https://example.com/callback"},
|
||||
}
|
||||
err = db.Create(&client).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Helper function to check if a file exists in storage
|
||||
fileExists := func(t *testing.T, path string) bool {
|
||||
t.Helper()
|
||||
_, _, err := dbStorage.Open(t.Context(), path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Helper function to create a dummy file in storage
|
||||
createDummyFile := func(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
err := dbStorage.Save(t.Context(), path, strings.NewReader("dummy content"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("Updates light logo type for client without previous logo", func(t *testing.T) {
|
||||
// Update the logo type
|
||||
err := s.updateClientLogoType(t.Context(), client.ID, "png", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the client was updated
|
||||
var updatedClient model.OidcClient
|
||||
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedClient.ImageType)
|
||||
assert.Equal(t, "png", *updatedClient.ImageType)
|
||||
})
|
||||
|
||||
t.Run("Updates dark logo type for client without previous dark logo", func(t *testing.T) {
|
||||
// Update the dark logo type
|
||||
err := s.updateClientLogoType(t.Context(), client.ID, "jpg", false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the client was updated
|
||||
var updatedClient model.OidcClient
|
||||
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedClient.DarkImageType)
|
||||
assert.Equal(t, "jpg", *updatedClient.DarkImageType)
|
||||
})
|
||||
|
||||
t.Run("Updates light logo type and deletes old file when type changes", func(t *testing.T) {
|
||||
// Create the old PNG file in storage
|
||||
oldPath := "oidc-client-images/" + client.ID + ".png"
|
||||
createDummyFile(t, oldPath)
|
||||
require.True(t, fileExists(t, oldPath), "Old file should exist before update")
|
||||
|
||||
// Client currently has a PNG logo, update to WEBP
|
||||
err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the client was updated
|
||||
var updatedClient model.OidcClient
|
||||
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedClient.ImageType)
|
||||
assert.Equal(t, "webp", *updatedClient.ImageType)
|
||||
|
||||
// Old PNG file should be deleted
|
||||
assert.False(t, fileExists(t, oldPath), "Old PNG file should have been deleted")
|
||||
})
|
||||
|
||||
t.Run("Updates dark logo type and deletes old file when type changes", func(t *testing.T) {
|
||||
// Create the old JPG dark file in storage
|
||||
oldPath := "oidc-client-images/" + client.ID + "-dark.jpg"
|
||||
createDummyFile(t, oldPath)
|
||||
require.True(t, fileExists(t, oldPath), "Old dark file should exist before update")
|
||||
|
||||
// Client currently has a JPG dark logo, update to WEBP
|
||||
err := s.updateClientLogoType(t.Context(), client.ID, "webp", false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the client was updated
|
||||
var updatedClient model.OidcClient
|
||||
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedClient.DarkImageType)
|
||||
assert.Equal(t, "webp", *updatedClient.DarkImageType)
|
||||
|
||||
// Old JPG dark file should be deleted
|
||||
assert.False(t, fileExists(t, oldPath), "Old JPG dark file should have been deleted")
|
||||
})
|
||||
|
||||
t.Run("Does not delete file when type remains the same", func(t *testing.T) {
|
||||
// Create the WEBP file in storage
|
||||
webpPath := "oidc-client-images/" + client.ID + ".webp"
|
||||
createDummyFile(t, webpPath)
|
||||
require.True(t, fileExists(t, webpPath), "WEBP file should exist before update")
|
||||
|
||||
// Update to the same type (WEBP)
|
||||
err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the client still has WEBP
|
||||
var updatedClient model.OidcClient
|
||||
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedClient.ImageType)
|
||||
assert.Equal(t, "webp", *updatedClient.ImageType)
|
||||
|
||||
// WEBP file should still exist since type didn't change
|
||||
assert.True(t, fileExists(t, webpPath), "WEBP file should still exist")
|
||||
})
|
||||
|
||||
t.Run("Returns error for non-existent client", func(t *testing.T) {
|
||||
err := s.updateClientLogoType(t.Context(), "non-existent-client-id", "png", true)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "failed to look up client")
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) {
|
||||
// Create a test database
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create database storage
|
||||
dbStorage, err := storage.NewDatabaseStorage(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a test client
|
||||
client := model.OidcClient{
|
||||
Name: "Test Client",
|
||||
CallbackURLs: model.UrlList{"https://example.com/callback"},
|
||||
}
|
||||
err = db.Create(&client).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Helper function to check if a file exists in storage
|
||||
fileExists := func(t *testing.T, path string) bool {
|
||||
t.Helper()
|
||||
_, _, err := dbStorage.Open(t.Context(), path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Helper function to get file content from storage
|
||||
getFileContent := func(t *testing.T, path string) []byte {
|
||||
t.Helper()
|
||||
reader, _, err := dbStorage.Open(t.Context(), path)
|
||||
require.NoError(t, err)
|
||||
defer reader.Close()
|
||||
content, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
return content
|
||||
}
|
||||
|
||||
t.Run("Successfully downloads and saves PNG logo from URL", func(t *testing.T) {
|
||||
// Create mock PNG content
|
||||
pngContent := []byte("fake-png-content")
|
||||
|
||||
// Create a mock HTTP response with headers
|
||||
//nolint:bodyclose
|
||||
pngResponse := testutils.NewMockResponse(http.StatusOK, string(pngContent))
|
||||
pngResponse.Header.Set("Content-Type", "image/png")
|
||||
|
||||
// Create a mock HTTP client with responses
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/logo.png": pngResponse,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
// Init the OidcService with mock HTTP client
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
// Download and save the logo
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo.png", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file was saved
|
||||
logoPath := "oidc-client-images/" + client.ID + ".png"
|
||||
require.True(t, fileExists(t, logoPath), "Logo file should exist in storage")
|
||||
|
||||
// Verify the content
|
||||
savedContent := getFileContent(t, logoPath)
|
||||
assert.Equal(t, pngContent, savedContent)
|
||||
|
||||
// Verify the client was updated
|
||||
var updatedClient model.OidcClient
|
||||
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedClient.ImageType)
|
||||
assert.Equal(t, "png", *updatedClient.ImageType)
|
||||
})
|
||||
|
||||
t.Run("Successfully downloads and saves dark logo", func(t *testing.T) {
|
||||
// Create mock WEBP content
|
||||
webpContent := []byte("fake-webp-content")
|
||||
|
||||
//nolint:bodyclose
|
||||
webpResponse := testutils.NewMockResponse(http.StatusOK, string(webpContent))
|
||||
webpResponse.Header.Set("Content-Type", "image/webp")
|
||||
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/dark-logo.webp": webpResponse,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
// Download and save the dark logo
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/dark-logo.webp", false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the dark logo file was saved
|
||||
darkLogoPath := "oidc-client-images/" + client.ID + "-dark.webp"
|
||||
require.True(t, fileExists(t, darkLogoPath), "Dark logo file should exist in storage")
|
||||
|
||||
// Verify the content
|
||||
savedContent := getFileContent(t, darkLogoPath)
|
||||
assert.Equal(t, webpContent, savedContent)
|
||||
|
||||
// Verify the client was updated
|
||||
var updatedClient model.OidcClient
|
||||
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedClient.DarkImageType)
|
||||
assert.Equal(t, "webp", *updatedClient.DarkImageType)
|
||||
})
|
||||
|
||||
t.Run("Detects extension from URL path", func(t *testing.T) {
|
||||
svgContent := []byte("<svg></svg>")
|
||||
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/icon.svg": testutils.NewMockResponse(http.StatusOK, string(svgContent)),
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/icon.svg", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify SVG file was saved
|
||||
logoPath := "oidc-client-images/" + client.ID + ".svg"
|
||||
require.True(t, fileExists(t, logoPath), "SVG logo should exist")
|
||||
})
|
||||
|
||||
t.Run("Detects extension from Content-Type when path has no extension", func(t *testing.T) {
|
||||
jpgContent := []byte("fake-jpg-content")
|
||||
|
||||
//nolint:bodyclose
|
||||
jpgResponse := testutils.NewMockResponse(http.StatusOK, string(jpgContent))
|
||||
jpgResponse.Header.Set("Content-Type", "image/jpeg")
|
||||
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/logo": jpgResponse,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify JPG file was saved (jpeg extension is normalized to jpg)
|
||||
logoPath := "oidc-client-images/" + client.ID + ".jpg"
|
||||
require.True(t, fileExists(t, logoPath), "JPG logo should exist")
|
||||
})
|
||||
|
||||
t.Run("Returns error for invalid URL", func(t *testing.T) {
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "://invalid-url", true)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("Returns error for non-200 status code", func(t *testing.T) {
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/not-found.png": testutils.NewMockResponse(http.StatusNotFound, "Not Found"),
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/not-found.png", true)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "failed to fetch logo")
|
||||
})
|
||||
|
||||
t.Run("Returns error for too large content", func(t *testing.T) {
|
||||
// Create content larger than 2MB (maxLogoSize)
|
||||
largeContent := strings.Repeat("x", 2<<20+100) // 2.1MB
|
||||
|
||||
//nolint:bodyclose
|
||||
largeResponse := testutils.NewMockResponse(http.StatusOK, largeContent)
|
||||
largeResponse.Header.Set("Content-Type", "image/png")
|
||||
largeResponse.Header.Set("Content-Length", strconv.Itoa(len(largeContent)))
|
||||
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/large.png": largeResponse,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/large.png", true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, errLogoTooLarge)
|
||||
})
|
||||
|
||||
t.Run("Returns error for unsupported file type", func(t *testing.T) {
|
||||
//nolint:bodyclose
|
||||
textResponse := testutils.NewMockResponse(http.StatusOK, "text content")
|
||||
textResponse.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/file.txt": textResponse,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/file.txt", true)
|
||||
require.Error(t, err)
|
||||
var fileTypeErr *common.FileTypeNotSupportedError
|
||||
require.ErrorAs(t, err, &fileTypeErr)
|
||||
})
|
||||
|
||||
t.Run("Returns error for non-existent client", func(t *testing.T) {
|
||||
//nolint:bodyclose
|
||||
pngResponse := testutils.NewMockResponse(http.StatusOK, "content")
|
||||
pngResponse.Header.Set("Content-Type", "image/png")
|
||||
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
"https://example.com/logo.png": pngResponse,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
fileStorage: dbStorage,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
err := s.downloadAndSaveLogoFromURL(t.Context(), "non-existent-client-id", "https://example.com/logo.png", true)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "failed to look up client")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
@@ -101,9 +102,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||
|
||||
// Try custom profile picture
|
||||
if file, size, err := s.fileStorage.Open(ctx, profilePicturePath); err == nil {
|
||||
file, size, err := s.fileStorage.Open(ctx, profilePicturePath)
|
||||
if err == nil {
|
||||
return file, size, nil
|
||||
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@@ -120,9 +122,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
||||
|
||||
// Try cached default for initials
|
||||
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
|
||||
if file, size, err := s.fileStorage.Open(ctx, defaultPicturePath); err == nil {
|
||||
file, size, err = s.fileStorage.Open(ctx, defaultPicturePath)
|
||||
if err == nil {
|
||||
return file, size, nil
|
||||
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@@ -133,12 +136,13 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
||||
}
|
||||
|
||||
// Save the default picture for future use (in a goroutine to avoid blocking)
|
||||
//nolint:contextcheck
|
||||
defaultPictureBytes := defaultPicture.Bytes()
|
||||
//nolint:contextcheck
|
||||
go func() {
|
||||
if err := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)); err != nil {
|
||||
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", err))
|
||||
// Use bytes.NewReader because we need an io.ReadSeeker
|
||||
rErr := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes))
|
||||
if rErr != nil {
|
||||
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", rErr))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -182,17 +186,30 @@ func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, f
|
||||
}
|
||||
|
||||
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
return s.deleteUserInternal(ctx, tx, userID, allowLdapDelete)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user '%s': %w", userID, err)
|
||||
}
|
||||
|
||||
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
|
||||
// Storage operations must be executed outside of a transaction
|
||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||
err = s.fileStorage.Delete(ctx, profilePicturePath)
|
||||
if err != nil && !storage.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete profile picture for user '%s': %w", userID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userID string, allowLdapDelete bool) error {
|
||||
var user model.User
|
||||
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", userID).
|
||||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
First(&user).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -204,11 +221,6 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
|
||||
return &common.LdapUserUpdateError{}
|
||||
}
|
||||
|
||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.WithContext(ctx).Delete(&user).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
@@ -286,16 +298,27 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
|
||||
|
||||
// Apply default user groups
|
||||
var groupIDs []string
|
||||
if v := config.SignupDefaultUserGroupIDs.Value; v != "" && v != "[]" {
|
||||
if err := json.Unmarshal([]byte(v), &groupIDs); err != nil {
|
||||
v := config.SignupDefaultUserGroupIDs.Value
|
||||
if v != "" && v != "[]" {
|
||||
err := json.Unmarshal([]byte(v), &groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid SignupDefaultUserGroupIDs JSON: %w", err)
|
||||
}
|
||||
if len(groupIDs) > 0 {
|
||||
var groups []model.UserGroup
|
||||
if err := tx.WithContext(ctx).Where("id IN ?", groupIDs).Find(&groups).Error; err != nil {
|
||||
err = tx.WithContext(ctx).
|
||||
Where("id IN ?", groupIDs).
|
||||
Find(&groups).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find default user groups: %w", err)
|
||||
}
|
||||
if err := tx.WithContext(ctx).Model(user).Association("UserGroups").Replace(groups); err != nil {
|
||||
|
||||
err = tx.WithContext(ctx).
|
||||
Model(user).
|
||||
Association("UserGroups").
|
||||
Replace(groups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to associate default user groups: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -303,12 +326,15 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
|
||||
|
||||
// Apply default custom claims
|
||||
var claims []dto.CustomClaimCreateDto
|
||||
if v := config.SignupDefaultCustomClaims.Value; v != "" && v != "[]" {
|
||||
if err := json.Unmarshal([]byte(v), &claims); err != nil {
|
||||
v = config.SignupDefaultCustomClaims.Value
|
||||
if v != "" && v != "[]" {
|
||||
err := json.Unmarshal([]byte(v), &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid SignupDefaultCustomClaims JSON: %w", err)
|
||||
}
|
||||
if len(claims) > 0 {
|
||||
if _, err := s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx); err != nil {
|
||||
_, err = s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply default custom claims: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -345,6 +371,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", userID).
|
||||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
First(&user).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -416,14 +443,12 @@ func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context
|
||||
|
||||
var userId string
|
||||
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
|
||||
if err != nil {
|
||||
// Do not return error if user not found to prevent email enumeration
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Do not return error if user not found to prevent email enumeration
|
||||
return nil
|
||||
} else {
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
|
||||
}
|
||||
@@ -513,7 +538,9 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
|
||||
var oneTimeAccessToken model.OneTimeAccessToken
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
|
||||
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).
|
||||
Preload("User").
|
||||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
First(&oneTimeAccessToken).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -679,7 +706,7 @@ func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx *gorm.DB) error {
|
||||
func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error {
|
||||
return tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
@@ -720,6 +747,7 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("token = ?", signupData.Token).
|
||||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
First(&signupToken).
|
||||
Error
|
||||
if err != nil {
|
||||
|
||||
@@ -58,7 +58,7 @@ func (s *VersionService) GetLatestVersion(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
if payload.TagName == "" {
|
||||
return "", fmt.Errorf("GitHub API returned empty tag name")
|
||||
return "", errors.New("GitHub API returned empty tag name")
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(payload.TagName, "v"), nil
|
||||
|
||||
226
backend/internal/storage/database.go
Normal file
226
backend/internal/storage/database.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var TypeDatabase = "database"
|
||||
|
||||
type databaseStorage struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewDatabaseStorage creates a new database storage provider
|
||||
func NewDatabaseStorage(db *gorm.DB) (FileStorage, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("database connection is required")
|
||||
}
|
||||
return &databaseStorage{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *databaseStorage) Type() string {
|
||||
return TypeDatabase
|
||||
}
|
||||
|
||||
func (s *databaseStorage) Save(ctx context.Context, relativePath string, data io.Reader) error {
|
||||
// Normalize the path
|
||||
relativePath = filepath.ToSlash(filepath.Clean(relativePath))
|
||||
|
||||
// Read all data into memory
|
||||
b, err := io.ReadAll(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read data: %w", err)
|
||||
}
|
||||
|
||||
now := datatype.DateTime(time.Now())
|
||||
storage := model.Storage{
|
||||
Path: relativePath,
|
||||
Data: b,
|
||||
Size: int64(len(b)),
|
||||
ModTime: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
// Use upsert: insert or update on conflict
|
||||
result := s.db.
|
||||
WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "path"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"data", "size", "mod_time"}),
|
||||
}).
|
||||
Create(&storage)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to save file to database: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *databaseStorage) Open(ctx context.Context, relativePath string) (io.ReadCloser, int64, error) {
|
||||
relativePath = filepath.ToSlash(filepath.Clean(relativePath))
|
||||
|
||||
var storage model.Storage
|
||||
result := s.db.
|
||||
WithContext(ctx).
|
||||
Where("path = ?", relativePath).
|
||||
First(&storage)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, 0, os.ErrNotExist
|
||||
}
|
||||
return nil, 0, fmt.Errorf("failed to read file from database: %w", result.Error)
|
||||
}
|
||||
|
||||
reader := io.NopCloser(bytes.NewReader(storage.Data))
|
||||
return reader, storage.Size, nil
|
||||
}
|
||||
|
||||
func (s *databaseStorage) Delete(ctx context.Context, relativePath string) error {
|
||||
relativePath = filepath.ToSlash(filepath.Clean(relativePath))
|
||||
|
||||
result := s.db.
|
||||
WithContext(ctx).
|
||||
Where("path = ?", relativePath).
|
||||
Delete(&model.Storage{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete file from database: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *databaseStorage) DeleteAll(ctx context.Context, prefix string) error {
|
||||
prefix = filepath.ToSlash(filepath.Clean(prefix))
|
||||
|
||||
// If empty prefix, delete all
|
||||
if isRootPath(prefix) {
|
||||
result := s.db.
|
||||
WithContext(ctx).
|
||||
Where("1 = 1"). // Delete everything
|
||||
Delete(&model.Storage{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete all files from database: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure prefix ends with / for proper prefix matching
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
prefix += "/"
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx)
|
||||
query = addPathPrefixClause(s.db.Name(), query, prefix)
|
||||
result := query.Delete(&model.Storage{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete files with prefix '%s' from database: %w", prefix, result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *databaseStorage) List(ctx context.Context, prefix string) ([]ObjectInfo, error) {
|
||||
prefix = filepath.ToSlash(filepath.Clean(prefix))
|
||||
|
||||
var storageItems []model.Storage
|
||||
query := s.db.WithContext(ctx)
|
||||
|
||||
if !isRootPath(prefix) {
|
||||
// Ensure prefix matching
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
prefix += "/"
|
||||
}
|
||||
query = addPathPrefixClause(s.db.Name(), query, prefix)
|
||||
}
|
||||
|
||||
result := query.
|
||||
Select("path", "size", "mod_time").
|
||||
Find(&storageItems)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to list files from database: %w", result.Error)
|
||||
}
|
||||
|
||||
objects := make([]ObjectInfo, 0, len(storageItems))
|
||||
for _, item := range storageItems {
|
||||
// Filter out directory-like paths (those that contain additional slashes after the prefix)
|
||||
relativePath := strings.TrimPrefix(item.Path, prefix)
|
||||
if strings.ContainsRune(relativePath, '/') {
|
||||
continue
|
||||
}
|
||||
|
||||
objects = append(objects, ObjectInfo{
|
||||
Path: item.Path,
|
||||
Size: item.Size,
|
||||
ModTime: time.Time(item.ModTime),
|
||||
})
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
func (s *databaseStorage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
|
||||
root = filepath.ToSlash(filepath.Clean(root))
|
||||
|
||||
var storageItems []model.Storage
|
||||
query := s.db.WithContext(ctx)
|
||||
|
||||
if !isRootPath(root) {
|
||||
// Ensure root matching
|
||||
if !strings.HasSuffix(root, "/") {
|
||||
root += "/"
|
||||
}
|
||||
query = addPathPrefixClause(s.db.Name(), query, root)
|
||||
}
|
||||
|
||||
result := query.
|
||||
Select("path", "size", "mod_time").
|
||||
Find(&storageItems)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to walk files from database: %w", result.Error)
|
||||
}
|
||||
|
||||
for _, item := range storageItems {
|
||||
err := fn(ObjectInfo{
|
||||
Path: item.Path,
|
||||
Size: item.Size,
|
||||
ModTime: time.Time(item.ModTime),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRootPath(path string) bool {
|
||||
return path == "" || path == "/" || path == "."
|
||||
}
|
||||
|
||||
func addPathPrefixClause(dialect string, query *gorm.DB, prefix string) *gorm.DB {
|
||||
// In SQLite, we use "GLOB" which can use the index
|
||||
switch dialect {
|
||||
case "sqlite":
|
||||
return query.Where("path GLOB ?", prefix+"*")
|
||||
case "postgres":
|
||||
return query.Where("path LIKE ?", prefix+"%")
|
||||
default:
|
||||
// Indicates a development-time error
|
||||
panic(fmt.Errorf("unsupported database dialect: %s", dialect))
|
||||
}
|
||||
}
|
||||
148
backend/internal/storage/database_test.go
Normal file
148
backend/internal/storage/database_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
testingutil "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestDatabaseStorageOperations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := testingutil.NewDatabaseForTest(t)
|
||||
store, err := NewDatabaseStorage(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("type should be database", func(t *testing.T) {
|
||||
assert.Equal(t, TypeDatabase, store.Type())
|
||||
})
|
||||
|
||||
t.Run("save, open and list files", func(t *testing.T) {
|
||||
err := store.Save(ctx, "images/logo.png", bytes.NewBufferString("logo-data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, size, err := store.Open(ctx, "images/logo.png")
|
||||
require.NoError(t, err)
|
||||
defer reader.Close()
|
||||
|
||||
contents, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("logo-data"), contents)
|
||||
assert.Equal(t, int64(len(contents)), size)
|
||||
|
||||
err = store.Save(ctx, "images/nested/child.txt", bytes.NewBufferString("child"))
|
||||
require.NoError(t, err)
|
||||
|
||||
files, err := store.List(ctx, "images")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1)
|
||||
assert.Equal(t, "images/logo.png", files[0].Path)
|
||||
assert.Equal(t, int64(len("logo-data")), files[0].Size)
|
||||
})
|
||||
|
||||
t.Run("save should update existing file", func(t *testing.T) {
|
||||
err := store.Save(ctx, "test/update.txt", bytes.NewBufferString("original"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Save(ctx, "test/update.txt", bytes.NewBufferString("updated"))
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, size, err := store.Open(ctx, "test/update.txt")
|
||||
require.NoError(t, err)
|
||||
defer reader.Close()
|
||||
|
||||
contents, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("updated"), contents)
|
||||
assert.Equal(t, int64(len("updated")), size)
|
||||
})
|
||||
|
||||
t.Run("delete files individually", func(t *testing.T) {
|
||||
err := store.Save(ctx, "images/delete-me.txt", bytes.NewBufferString("temp"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, store.Delete(ctx, "images/delete-me.txt"))
|
||||
_, _, err = store.Open(ctx, "images/delete-me.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
})
|
||||
|
||||
t.Run("delete missing file should not error", func(t *testing.T) {
|
||||
require.NoError(t, store.Delete(ctx, "images/missing.txt"))
|
||||
})
|
||||
|
||||
t.Run("delete all files", func(t *testing.T) {
|
||||
require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a")))
|
||||
require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b")))
|
||||
require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c")))
|
||||
require.NoError(t, store.DeleteAll(ctx, "/"))
|
||||
|
||||
_, _, err := store.Open(ctx, "cleanup/a.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
|
||||
_, _, err = store.Open(ctx, "cleanup/b.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
|
||||
_, _, err = store.Open(ctx, "cleanup/nested/c.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
})
|
||||
|
||||
t.Run("delete all files under a prefix", func(t *testing.T) {
|
||||
require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a")))
|
||||
require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b")))
|
||||
require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c")))
|
||||
require.NoError(t, store.DeleteAll(ctx, "cleanup"))
|
||||
|
||||
_, _, err := store.Open(ctx, "cleanup/a.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
|
||||
_, _, err = store.Open(ctx, "cleanup/b.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
|
||||
_, _, err = store.Open(ctx, "cleanup/nested/c.txt")
|
||||
require.Error(t, err)
|
||||
assert.True(t, IsNotExist(err))
|
||||
})
|
||||
|
||||
t.Run("walk files", func(t *testing.T) {
|
||||
require.NoError(t, store.Save(ctx, "walk/file1.txt", bytes.NewBufferString("1")))
|
||||
require.NoError(t, store.Save(ctx, "walk/file2.txt", bytes.NewBufferString("2")))
|
||||
require.NoError(t, store.Save(ctx, "walk/nested/file3.txt", bytes.NewBufferString("3")))
|
||||
|
||||
var paths []string
|
||||
err := store.Walk(ctx, "walk", func(info ObjectInfo) error {
|
||||
paths = append(paths, info.Path)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, paths, 3)
|
||||
assert.Contains(t, paths, "walk/file1.txt")
|
||||
assert.Contains(t, paths, "walk/file2.txt")
|
||||
assert.Contains(t, paths, "walk/nested/file3.txt")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewDatabaseStorage(t *testing.T) {
|
||||
t.Run("should return error with nil database", func(t *testing.T) {
|
||||
_, err := NewDatabaseStorage(nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database connection is required")
|
||||
})
|
||||
|
||||
t.Run("should create storage with valid database", func(t *testing.T) {
|
||||
db := testingutil.NewDatabaseForTest(t)
|
||||
store, err := NewDatabaseStorage(db)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, store)
|
||||
})
|
||||
}
|
||||
@@ -1,7 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
@@ -56,6 +59,23 @@ func IsPrivateIP(ip net.IP) bool {
|
||||
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
|
||||
}
|
||||
|
||||
func IsURLPrivate(ctx context.Context, u *url.URL) (bool, error) {
|
||||
var r net.Resolver
|
||||
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
||||
if err != nil || len(ips) == 0 {
|
||||
return false, errors.New("cannot resolve hostname")
|
||||
}
|
||||
|
||||
// Prevents SSRF by allowing only public IPs
|
||||
for _, addr := range ips {
|
||||
if IsPrivateIP(addr.IP) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
|
||||
for _, ipNet := range ipNets {
|
||||
if ipNet.Contains(ip) {
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
@@ -20,9 +26,8 @@ func TestIsLocalhostIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsLocalhostIP(ip); got != tt.expected {
|
||||
t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsLocalhostIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,9 +45,8 @@ func TestIsPrivateLanIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsPrivateLanIP(ip); got != tt.expected {
|
||||
t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsPrivateLanIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,9 +63,9 @@ func TestIsTailscaleIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsTailscaleIP(ip); got != tt.expected {
|
||||
t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
|
||||
got := IsTailscaleIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,16 +90,17 @@ func TestIsLocalIPv6(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsLocalIPv6(ip); got != tt.expected {
|
||||
t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsLocalIPv6(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
// Save and restore env config
|
||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||
t.Cleanup(func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = origRanges
|
||||
})
|
||||
|
||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
|
||||
localIPv6Ranges = nil // reset
|
||||
@@ -115,9 +120,8 @@ func TestIsPrivateIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := IsPrivateIP(ip); got != tt.expected {
|
||||
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := IsPrivateIP(ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,22 +142,202 @@ func TestListContainsIP(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if got := listContainsIP(list, ip); got != tt.expected {
|
||||
t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
got := listContainsIP(list, ip)
|
||||
assert.Equal(t, tt.expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_LocalIPv6Ranges(t *testing.T) {
|
||||
// Save and restore env config
|
||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||
t.Cleanup(func() {
|
||||
common.EnvConfig.LocalIPv6Ranges = origRanges
|
||||
})
|
||||
|
||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
|
||||
localIPv6Ranges = nil
|
||||
loadLocalIPv6Ranges()
|
||||
|
||||
if len(localIPv6Ranges) != 2 {
|
||||
t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges))
|
||||
assert.Len(t, localIPv6Ranges, 2)
|
||||
}
|
||||
|
||||
func TestIsURLPrivate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectPriv bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "localhost by name",
|
||||
urlStr: "http://localhost",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "localhost with port",
|
||||
urlStr: "http://localhost:8080",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 IP",
|
||||
urlStr: "http://127.0.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with port",
|
||||
urlStr: "http://127.0.0.1:3000",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
urlStr: "http://[::1]",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback with port",
|
||||
urlStr: "http://[::1]:8080",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "private IP 10.x.x.x",
|
||||
urlStr: "http://10.0.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "private IP 192.168.x.x",
|
||||
urlStr: "http://192.168.1.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "private IP 172.16.x.x",
|
||||
urlStr: "http://172.16.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Tailscale IP",
|
||||
urlStr: "http://100.64.0.1",
|
||||
expectPriv: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "public IP - Google DNS",
|
||||
urlStr: "http://8.8.8.8",
|
||||
expectPriv: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "public IP - Cloudflare DNS",
|
||||
urlStr: "http://1.1.1.1",
|
||||
expectPriv: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid hostname",
|
||||
urlStr: "http://this-should-not-resolve-ever-123456789.invalid",
|
||||
expectPriv: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.urlStr)
|
||||
require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
|
||||
|
||||
isPriv, err := IsURLPrivate(ctx, u)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err, "IsURLPrivate(%s) expected error but got none", tt.urlStr)
|
||||
} else {
|
||||
require.NoError(t, err, "IsURLPrivate(%s) unexpected error", tt.urlStr)
|
||||
assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsURLPrivate_WithDomainName(t *testing.T) {
|
||||
// Note: These tests rely on actual DNS resolution
|
||||
// They test real public domains to ensure they are not flagged as private
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectPriv bool
|
||||
}{
|
||||
{
|
||||
name: "Google public domain",
|
||||
urlStr: "https://www.google.com",
|
||||
expectPriv: false,
|
||||
},
|
||||
{
|
||||
name: "GitHub public domain",
|
||||
urlStr: "https://github.com",
|
||||
expectPriv: false,
|
||||
},
|
||||
{
|
||||
// localhost.localtest.me is a well-known domain that resolves to 127.0.0.1
|
||||
name: "localhost.localtest.me resolves to 127.0.0.1",
|
||||
urlStr: "http://localhost.localtest.me",
|
||||
expectPriv: true,
|
||||
},
|
||||
{
|
||||
// 10.0.0.1.nip.io resolves to 10.0.0.1 (private IP)
|
||||
name: "nip.io domain resolving to private 10.x IP",
|
||||
urlStr: "http://10.0.0.1.nip.io",
|
||||
expectPriv: true,
|
||||
},
|
||||
{
|
||||
// 192.168.1.1.nip.io resolves to 192.168.1.1 (private IP)
|
||||
name: "nip.io domain resolving to private 192.168.x IP",
|
||||
urlStr: "http://192.168.1.1.nip.io",
|
||||
expectPriv: true,
|
||||
},
|
||||
{
|
||||
// 127.0.0.1.nip.io resolves to 127.0.0.1 (localhost)
|
||||
name: "nip.io domain resolving to localhost",
|
||||
urlStr: "http://127.0.0.1.nip.io",
|
||||
expectPriv: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.urlStr)
|
||||
require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
|
||||
|
||||
isPriv, err := IsURLPrivate(ctx, u)
|
||||
if err != nil {
|
||||
t.Skipf("DNS resolution failed for %s (network issue?): %v", tt.urlStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsURLPrivate_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
u, err := url.Parse("http://example.com")
|
||||
require.NoError(t, err, "Failed to parse URL")
|
||||
|
||||
_, err = IsURLPrivate(ctx, u)
|
||||
assert.Error(t, err, "IsURLPrivate with cancelled context expected error but got none")
|
||||
}
|
||||
|
||||
34
backend/internal/utils/stream_util.go
Normal file
34
backend/internal/utils/stream_util.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
var ErrSizeExceeded = errors.New("stream size exceeded")
|
||||
|
||||
// LimitReader is like io.LimitReader but throws an error if the stream exceeds the max size
|
||||
// io.LimitReader instead just returns io.EOF
|
||||
// Adapted from https://github.com/golang/go/issues/51115#issuecomment-1079761212
|
||||
type LimitReader struct {
|
||||
io.ReadCloser
|
||||
N int64
|
||||
}
|
||||
|
||||
func NewLimitReader(r io.ReadCloser, limit int64) *LimitReader {
|
||||
return &LimitReader{r, limit}
|
||||
}
|
||||
|
||||
func (r *LimitReader) Read(p []byte) (n int, err error) {
|
||||
if r.N <= 0 {
|
||||
return 0, ErrSizeExceeded
|
||||
}
|
||||
|
||||
if int64(len(p)) > r.N {
|
||||
p = p[0:r.N]
|
||||
}
|
||||
|
||||
n, err = r.ReadCloser.Read(p)
|
||||
r.N -= int64(n)
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE storage;
|
||||
@@ -0,0 +1,9 @@
|
||||
-- The "storage" table contains file data stored in the database
|
||||
CREATE TABLE storage
|
||||
(
|
||||
path TEXT NOT NULL PRIMARY KEY,
|
||||
data BYTEA NOT NULL,
|
||||
size BIGINT NOT NULL,
|
||||
mod_time TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX idx_api_keys_expires_at;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at);
|
||||
@@ -0,0 +1,6 @@
|
||||
PRAGMA foreign_keys=OFF;
|
||||
BEGIN;
|
||||
DROP TABLE storage;
|
||||
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys=ON;
|
||||
@@ -0,0 +1,14 @@
|
||||
PRAGMA foreign_keys=OFF;
|
||||
BEGIN;
|
||||
-- The "storage" table contains file data stored in the database
|
||||
CREATE TABLE storage
|
||||
(
|
||||
path TEXT NOT NULL PRIMARY KEY,
|
||||
data BLOB NOT NULL,
|
||||
size INTEGER NOT NULL,
|
||||
mod_time DATETIME NOT NULL,
|
||||
created_at DATETIME NOT NULL
|
||||
);
|
||||
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys=ON;
|
||||
@@ -0,0 +1,5 @@
|
||||
PRAGMA foreign_keys=OFF;
|
||||
BEGIN;
|
||||
DROP INDEX idx_api_keys_expires_at;
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys=ON;
|
||||
@@ -0,0 +1,5 @@
|
||||
PRAGMA foreign_keys=OFF;
|
||||
BEGIN;
|
||||
CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at);
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys=ON;
|
||||
@@ -20,8 +20,10 @@ services:
|
||||
file: docker-compose.yml
|
||||
service: pocket-id
|
||||
environment:
|
||||
- APP_ENV=test
|
||||
- DB_PROVIDER=postgres
|
||||
- DB_CONNECTION_STRING=postgres://postgres:postgres@postgres:5432/pocket-id
|
||||
- FILE_BACKEND=${FILE_BACKEND}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -13,7 +13,7 @@ services:
|
||||
retries: 10
|
||||
|
||||
create-bucket:
|
||||
image: amazon/aws-cli
|
||||
image: amazon/aws-cli:latest
|
||||
environment:
|
||||
AWS_ACCESS_KEY_ID: test
|
||||
AWS_SECRET_ACCESS_KEY: test
|
||||
@@ -28,15 +28,14 @@ services:
|
||||
file: docker-compose.yml
|
||||
service: pocket-id
|
||||
environment:
|
||||
FILE_BACKEND: s3
|
||||
S3_BUCKET: pocket-id-test
|
||||
S3_REGION: us-east-1
|
||||
S3_ENDPOINT: http://localstack-s3:4566
|
||||
S3_ACCESS_KEY_ID: test
|
||||
S3_SECRET_ACCESS_KEY: test
|
||||
S3_FORCE_PATH_STYLE: true
|
||||
KEYS_STORAGE: database
|
||||
ENCRYPTION_KEY: test
|
||||
- S3_BUCKET=pocket-id-test
|
||||
- S3_REGION=us-east-1
|
||||
- S3_ENDPOINT=http://localstack-s3:4566
|
||||
- S3_ACCESS_KEY_ID=test
|
||||
- S3_SECRET_ACCESS_KEY=test
|
||||
- S3_FORCE_PATH_STYLE=true
|
||||
- KEYS_STORAGE=database
|
||||
- ENCRYPTION_KEY=test1234test1234test1234test1234
|
||||
depends_on:
|
||||
create-bucket:
|
||||
condition: service_completed_successfully
|
||||
|
||||
@@ -14,6 +14,7 @@ services:
|
||||
- '1411:1411'
|
||||
environment:
|
||||
- APP_ENV=test
|
||||
- FILE_BACKEND=${FILE_BACKEND}
|
||||
build:
|
||||
args:
|
||||
- BUILD_TAGS=e2etest
|
||||
|
||||
Reference in New Issue
Block a user