diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index acd002fb..ab558411 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -57,7 +57,17 @@ jobs: strategy: fail-fast: false matrix: - db: [sqlite, postgres, sqlite-s3] + include: + - db: sqlite + storage: fs + - db: postgres + storage: fs + - db: sqlite + storage: s3 + - db: sqlite + storage: database + - db: postgres + storage: database steps: - uses: actions/checkout@v5 @@ -71,65 +81,74 @@ jobs: node-version: 22 - name: Cache Playwright Browsers - uses: actions/cache@v3 + uses: actions/cache@v4 id: playwright-cache with: path: ~/.cache/ms-playwright key: ${{ runner.os }}-playwright-${{ hashFiles('pnpm-lock.yaml') }} - name: Cache PostgreSQL Docker image - if: matrix.db == 'postgres' - uses: actions/cache@v3 + uses: actions/cache@v4 id: postgres-cache with: path: /tmp/postgres-image.tar key: postgres-17-${{ runner.os }} - - name: Pull and save PostgreSQL image if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit != 'true' run: | docker pull postgres:17 docker save postgres:17 > /tmp/postgres-image.tar - - name: Load PostgreSQL image if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit == 'true' run: docker load < /tmp/postgres-image.tar - name: Cache LLDAP Docker image - uses: actions/cache@v3 + uses: actions/cache@v4 id: lldap-cache with: path: /tmp/lldap-image.tar key: lldap-stable-${{ runner.os }} - - name: Pull and save LLDAP image if: steps.lldap-cache.outputs.cache-hit != 'true' run: | - docker pull nitnelave/lldap:stable - docker save nitnelave/lldap:stable > /tmp/lldap-image.tar - + docker pull lldap/lldap:2025-05-19 + docker save lldap/lldap:2025-05-19 > /tmp/lldap-image.tar - name: Load LLDAP image if: steps.lldap-cache.outputs.cache-hit == 'true' run: docker load < /tmp/lldap-image.tar - name: Cache Localstack S3 Docker image - if: matrix.db == 'sqlite-s3' - uses: actions/cache@v3 + if: matrix.storage == 's3' + uses: actions/cache@v4 id: s3-cache with: path: /tmp/localstack-s3-image.tar key: localstack-s3-latest-${{ runner.os }} - - name: Pull and save Localstack S3 image - if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit != 'true' + if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit != 'true' run: | docker pull localstack/localstack:s3-latest docker save localstack/localstack:s3-latest > /tmp/localstack-s3-image.tar - - name: Load Localstack S3 image - if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit == 'true' + if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit == 'true' run: docker load < /tmp/localstack-s3-image.tar + - name: Cache AWS CLI Docker image + if: matrix.storage == 's3' + uses: actions/cache@v4 + id: aws-cli-cache + with: + path: /tmp/aws-cli-image.tar + key: aws-cli-latest-${{ runner.os }} + - name: Pull and save AWS CLI image + if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit != 'true' + run: | + docker pull amazon/aws-cli:latest + docker save amazon/aws-cli:latest > /tmp/aws-cli-image.tar + - name: Load AWS CLI image + if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit == 'true' + run: docker load < /tmp/aws-cli-image.tar + - name: Download Docker image artifact uses: actions/download-artifact@v4 with: @@ -147,26 +166,20 @@ jobs: if: steps.playwright-cache.outputs.cache-hit != 'true' run: pnpm exec playwright install --with-deps chromium - - name: Run Docker Container (sqlite) with LDAP - if: matrix.db == 'sqlite' + - name: Run Docker containers working-directory: ./tests/setup run: | - docker compose up -d - docker compose logs -f pocket-id &> /tmp/backend.log & + DOCKER_COMPOSE_FILE=docker-compose.yml - - name: Run Docker Container (postgres) with LDAP - if: matrix.db == 'postgres' - working-directory: ./tests/setup - run: | - docker compose -f docker-compose-postgres.yml up -d - docker compose -f docker-compose-postgres.yml logs -f pocket-id &> /tmp/backend.log & + export FILE_BACKEND="${{ matrix.storage }}" + if [ "${{ matrix.db }}" = "postgres" ]; then + DOCKER_COMPOSE_FILE=docker-compose-postgres.yml + elif [ "${{ matrix.storage }}" = "s3" ]; then + DOCKER_COMPOSE_FILE=docker-compose-s3.yml + fi - - name: Run Docker Container (sqlite-s3) with LDAP + S3 - if: matrix.db == 'sqlite-s3' - working-directory: ./tests/setup - run: | - docker compose -f docker-compose-s3.yml up -d - docker compose -f docker-compose-s3.yml logs -f pocket-id &> /tmp/backend.log & + docker compose -f "$DOCKER_COMPOSE_FILE" up -d + docker compose -f "$DOCKER_COMPOSE_FILE" logs -f pocket-id &> /tmp/backend.log & - name: Run Playwright tests working-directory: ./tests @@ -176,7 +189,7 @@ jobs: uses: actions/upload-artifact@v4 if: always() && github.event.pull_request.head.ref != 'i18n_crowdin' with: - name: playwright-report-${{ matrix.db }} + name: playwright-report-${{ matrix.db }}-${{ matrix.storage }} path: tests/.report include-hidden-files: true retention-days: 15 @@ -185,7 +198,7 @@ jobs: uses: actions/upload-artifact@v4 if: always() && github.event.pull_request.head.ref != 'i18n_crowdin' with: - name: backend-${{ matrix.db }} + name: backend-${{ matrix.db }}-${{ matrix.storage }} path: /tmp/backend.log include-hidden-files: true retention-days: 15 diff --git a/backend/internal/bootstrap/bootstrap.go b/backend/internal/bootstrap/bootstrap.go index a4e22700..ba1ea993 100644 --- a/backend/internal/bootstrap/bootstrap.go +++ b/backend/internal/bootstrap/bootstrap.go @@ -22,12 +22,20 @@ func Bootstrap(ctx context.Context) error { } slog.InfoContext(ctx, "Pocket ID is starting") + // Connect to the database + db, err := NewDatabase() + if err != nil { + return fmt.Errorf("failed to initialize database: %w", err) + } + // Initialize the file storage backend var fileStorage storage.FileStorage switch common.EnvConfig.FileBackend { case storage.TypeFileSystem: fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath) + case storage.TypeDatabase: + fileStorage, err = storage.NewDatabaseStorage(db) case storage.TypeS3: s3Cfg := storage.S3Config{ Bucket: common.EnvConfig.S3Bucket, @@ -43,7 +51,7 @@ func Bootstrap(ctx context.Context) error { err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend) } if err != nil { - return fmt.Errorf("failed to initialize file storage: %w", err) + return fmt.Errorf("failed to initialize file storage (backend: %s): %w", common.EnvConfig.FileBackend, err) } imageExtensions, err := initApplicationImages(ctx, fileStorage) @@ -51,12 +59,6 @@ func Bootstrap(ctx context.Context) error { return fmt.Errorf("failed to initialize application images: %w", err) } - // Connect to the database - db, err := NewDatabase() - if err != nil { - return fmt.Errorf("failed to initialize database: %w", err) - } - // Create all services svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage) if err != nil { diff --git a/backend/internal/bootstrap/services_bootstrap.go b/backend/internal/bootstrap/services_bootstrap.go index 31a9967e..21254627 100644 --- a/backend/internal/bootstrap/services_bootstrap.go +++ b/backend/internal/bootstrap/services_bootstrap.go @@ -66,7 +66,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService) svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage) - svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService) + svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService, fileStorage) svc.apiKeyService = service.NewApiKeyService(db, svc.emailService) svc.versionService = service.NewVersionService(httpClient) diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go index 9ee14bae..ea2f2a2b 100644 --- a/backend/internal/common/env_config.go +++ b/backend/internal/common/env_config.go @@ -180,12 +180,14 @@ func validateEnvConfig(config *EnvConfigSchema) error { if config.KeysStorage == "file" { return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'") } + case "database": + // All good, these are valid values case "", "fs": if config.UploadPath == "" { config.UploadPath = defaultFsUploadPath } default: - return errors.New("invalid FILE_BACKEND value. Must be 'fs' or 's3'") + return errors.New("invalid FILE_BACKEND value. Must be 'fs', 'database', or 's3'") } // Validate LOCAL_IPV6_RANGES diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 607be22e..86d26e14 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -587,7 +587,6 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { } c.Status(http.StatusNoContent) - } // deleteClientLogoHandler godoc @@ -614,7 +613,6 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { } c.Status(http.StatusNoContent) - } // updateAllowedUserGroupsHandler godoc diff --git a/backend/internal/model/storage.go b/backend/internal/model/storage.go new file mode 100644 index 00000000..f5668b87 --- /dev/null +++ b/backend/internal/model/storage.go @@ -0,0 +1,17 @@ +package model + +import ( + datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" +) + +type Storage struct { + Path string `gorm:"primaryKey"` + Data []byte + Size int64 + ModTime datatype.DateTime + CreatedAt datatype.DateTime +} + +func (Storage) TableName() string { + return "storage" +} diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go index ff978138..5c5aee45 100644 --- a/backend/internal/service/e2etest_service.go +++ b/backend/internal/service/e2etest_service.go @@ -426,7 +426,8 @@ func (s *TestService) ResetDatabase() error { } func (s *TestService) ResetApplicationImages(ctx context.Context) error { - if err := s.fileStorage.DeleteAll(ctx, "/"); err != nil { + err := s.fileStorage.DeleteAll(ctx, "/") + if err != nil { slog.ErrorContext(ctx, "Error removing uploads", slog.Any("error", err)) return err } @@ -445,7 +446,8 @@ func (s *TestService) ResetApplicationImages(ctx context.Context) error { if err != nil { return err } - if err := s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile); err != nil { + err = s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile) + if err != nil { srcFile.Close() return err } diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go index 731c7e66..47fe91e8 100644 --- a/backend/internal/service/ldap_service.go +++ b/backend/internal/service/ldap_service.go @@ -11,12 +11,14 @@ import ( "log/slog" "net/http" "net/url" + "path" "strings" "time" "unicode/utf8" "github.com/go-ldap/ldap/v3" "github.com/google/uuid" + "github.com/pocket-id/pocket-id/backend/internal/storage" "github.com/pocket-id/pocket-id/backend/internal/utils" "golang.org/x/text/unicode/norm" "gorm.io/gorm" @@ -32,15 +34,23 @@ type LdapService struct { appConfigService *AppConfigService userService *UserService groupService *UserGroupService + fileStorage storage.FileStorage } -func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService { +type savePicture struct { + userID string + username string + picture string +} + +func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService { return &LdapService{ db: db, httpClient: httpClient, appConfigService: appConfigService, userService: userService, groupService: groupService, + fileStorage: fileStorage, } } @@ -68,12 +78,6 @@ func (s *LdapService) createClient() (*ldap.Conn, error) { } func (s *LdapService) SyncAll(ctx context.Context) error { - // Start a transaction - tx := s.db.Begin() - defer func() { - tx.Rollback() - }() - // Setup LDAP connection client, err := s.createClient() if err != nil { @@ -81,7 +85,13 @@ func (s *LdapService) SyncAll(ctx context.Context) error { } defer client.Close() - err = s.SyncUsers(ctx, tx, client) + // Start a transaction + tx := s.db.Begin() + defer func() { + tx.Rollback() + }() + + savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client) if err != nil { return fmt.Errorf("failed to sync users: %w", err) } @@ -97,6 +107,25 @@ func (s *LdapService) SyncAll(ctx context.Context) error { return fmt.Errorf("failed to commit changes to database: %w", err) } + // Now that we've committed the transaction, we can perform operations on the storage layer + // First, save all new pictures + for _, sp := range savePictures { + err = s.saveProfilePicture(ctx, sp.userID, sp.picture) + if err != nil { + // This is not a fatal error + slog.Warn("Error saving profile picture for LDAP user", slog.String("username", sp.username), slog.Any("error", err)) + } + } + + // Delete all old files + for _, path := range deleteFiles { + err = s.fileStorage.Delete(ctx, path) + if err != nil { + // This is not a fatal error + slog.Error("Failed to delete file after LDAP sync", slog.String("path", path), slog.Any("error", err)) + } + } + return nil } @@ -266,7 +295,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap. } //nolint:gocognit -func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error { +func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) { dbConfig := s.appConfigService.GetDbConfig() searchAttrs := []string{ @@ -294,11 +323,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C result, err := client.Search(searchReq) if err != nil { - return fmt.Errorf("failed to query LDAP: %w", err) + return nil, nil, fmt.Errorf("failed to query LDAP: %w", err) } // Create a mapping for users that exist ldapUserIDs := make(map[string]struct{}, len(result.Entries)) + savePictures = make([]savePicture, 0, len(result.Entries)) for _, value := range result.Entries { ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)) @@ -329,13 +359,13 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C Error if err != nil { - return fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err) + return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err) } } if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { // This could error with ErrRecordNotFound and we want to ignore that here - return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err) + return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err) } // Check if user is admin by checking if they are in the admin group @@ -369,32 +399,35 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C continue } + userID := databaseUser.ID if databaseUser.ID == "" { - _, err = s.userService.createUserInternal(ctx, newUser, true, tx) + createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx) if errors.Is(err, &common.AlreadyInUseError{}) { slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err)) continue } else if err != nil { - return fmt.Errorf("error creating user '%s': %w", newUser.Username, err) + return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err) } + userID = createdUser.ID } else { _, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx) if errors.Is(err, &common.AlreadyInUseError{}) { slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err)) continue } else if err != nil { - return fmt.Errorf("error updating user '%s': %w", newUser.Username, err) + return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err) } } // Save profile picture pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value) if pictureString != "" { - err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString) - if err != nil { - // This is not a fatal error - slog.Warn("Error saving profile picture for user", slog.String("username", newUser.Username), slog.Any("error", err)) - } + // Storage operations must be executed outside of a transaction + savePictures = append(savePictures, savePicture{ + userID: databaseUser.ID, + username: userID, + picture: pictureString, + }) } } @@ -406,10 +439,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C Select("id, username, ldap_id, disabled"). Error if err != nil { - return fmt.Errorf("failed to fetch users from database: %w", err) + return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err) } // Mark users as disabled or delete users that no longer exist in LDAP + deleteFiles = make([]string, 0, len(ldapUserIDs)) for _, user := range ldapUsersInDb { // Skip if the user ID exists in the fetched LDAP results if _, exists := ldapUserIDs[*user.LdapID]; exists { @@ -417,26 +451,30 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C } if dbConfig.LdapSoftDeleteUsers.IsTrue() { - err = s.userService.disableUserInternal(ctx, user.ID, tx) + err = s.userService.disableUserInternal(ctx, tx, user.ID) if err != nil { - return fmt.Errorf("failed to disable user %s: %w", user.Username, err) + return nil, nil, fmt.Errorf("failed to disable user %s: %w", user.Username, err) } slog.Info("Disabled user", slog.String("username", user.Username)) } else { - err = s.userService.deleteUserInternal(ctx, user.ID, true, tx) - target := &common.LdapUserUpdateError{} - if errors.As(err, &target) { - return fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username) - } else if err != nil { - return fmt.Errorf("failed to delete user %s: %w", user.Username, err) + err = s.userService.deleteUserInternal(ctx, tx, user.ID, true) + if err != nil { + target := &common.LdapUserUpdateError{} + if errors.As(err, &target) { + return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username) + } + return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err) } slog.Info("Deleted user", slog.String("username", user.Username)) + + // Storage operations must be executed outside of a transaction + deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png")) } } - return nil + return savePictures, deleteFiles, nil } func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error { diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 08b9ecae..2d42f452 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -12,7 +12,6 @@ import ( "io" "log/slog" "mime/multipart" - "net" "net/http" "net/url" "path" @@ -679,19 +678,21 @@ func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID strin } func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) { - return s.getClientInternal(ctx, clientID, s.db) + return s.getClientInternal(ctx, clientID, s.db, false) } -func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) { +func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) { var client model.OidcClient - err := tx. + q := tx. WithContext(ctx). Preload("CreatedBy"). - Preload("AllowedUserGroups"). - First(&client, "id = ?", clientID). - Error - if err != nil { - return model.OidcClient{}, err + Preload("AllowedUserGroups") + if forUpdate { + q = q.Clauses(clause.Locking{Strength: "UPDATE"}) + } + q = q.First(&client, "id = ?", clientID) + if q.Error != nil { + return model.OidcClient{}, q.Error } return client, nil } @@ -724,11 +725,6 @@ func (s *OidcService) ListClients(ctx context.Context, name string, listRequestO } func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { - tx := s.db.Begin() - defer func() { - tx.Rollback() - }() - client := model.OidcClient{ Base: model.Base{ ID: input.ID, @@ -737,7 +733,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea } updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto) - err := tx. + err := s.db. WithContext(ctx). Create(&client). Error @@ -748,62 +744,65 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea return model.OidcClient{}, err } + // All storage operations must be executed outside of a transaction if input.LogoURL != nil { - err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true) + err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true) if err != nil { return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err) } } if input.DarkLogoURL != nil { - err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false) + err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false) if err != nil { return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err) } } - err = tx.Commit().Error - if err != nil { - return model.OidcClient{}, err - } - return client, nil } func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) { tx := s.db.Begin() - defer func() { tx.Rollback() }() + defer func() { + tx.Rollback() + }() var client model.OidcClient - if err := tx.WithContext(ctx). + err := tx.WithContext(ctx). Preload("CreatedBy"). - First(&client, "id = ?", clientID).Error; err != nil { + First(&client, "id = ?", clientID).Error + if err != nil { return model.OidcClient{}, err } updateOIDCClientModelFromDto(&client, &input) - if err := tx.WithContext(ctx).Save(&client).Error; err != nil { + err = tx.WithContext(ctx).Save(&client).Error + if err != nil { return model.OidcClient{}, err } + err = tx.Commit().Error + if err != nil { + return model.OidcClient{}, err + } + + // All storage operations must be executed outside of a transaction if input.LogoURL != nil { - err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true) + err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true) if err != nil { return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err) } } if input.DarkLogoURL != nil { - err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false) + err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false) if err != nil { return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err) } } - if err := tx.Commit().Error; err != nil { - return model.OidcClient{}, err - } return client, nil } @@ -836,12 +835,24 @@ func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error { err := s.db. WithContext(ctx). Where("id = ?", clientID). + Clauses(clause.Returning{}). Delete(&client). Error if err != nil { return err } + // Delete images if present + // Note that storage operations must be done outside of a transaction + if client.ImageType != nil && *client.ImageType != "" { + old := path.Join("oidc-client-images", client.ID+"."+*client.ImageType) + _ = s.fileStorage.Delete(ctx, old) + } + if client.DarkImageType != nil && *client.DarkImageType != "" { + old := path.Join("oidc-client-images", client.ID+"-dark."+*client.DarkImageType) + _ = s.fileStorage.Delete(ctx, old) + } + return nil } @@ -941,57 +952,12 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil return err } defer reader.Close() - if err := s.fileStorage.Save(ctx, imagePath, reader); err != nil { - return err - } - - tx := s.db.Begin() - - err = s.updateClientLogoType(ctx, tx, clientID, fileType, light) - if err != nil { - tx.Rollback() - return err - } - - return tx.Commit().Error -} - -func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error { - tx := s.db.Begin() - defer func() { - tx.Rollback() - }() - - var client model.OidcClient - err := tx. - WithContext(ctx). - First(&client, "id = ?", clientID). - Error + err = s.fileStorage.Save(ctx, imagePath, reader) if err != nil { return err } - if client.ImageType == nil { - return errors.New("image not found") - } - - oldImageType := *client.ImageType - client.ImageType = nil - - err = tx. - WithContext(ctx). - Save(&client). - Error - if err != nil { - return err - } - - imagePath := path.Join("oidc-client-images", client.ID+"."+oldImageType) - if err := s.fileStorage.Delete(ctx, imagePath); err != nil { - return err - } - - err = tx.Commit().Error + err = s.updateClientLogoType(ctx, clientID, fileType, light) if err != nil { return err } @@ -999,7 +965,31 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err return nil } +func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error { + return s.deleteClientLogoInternal(ctx, clientID, "", func(client *model.OidcClient) (string, error) { + if client.ImageType == nil { + return "", errors.New("image not found") + } + + oldImageType := *client.ImageType + client.ImageType = nil + return oldImageType, nil + }) +} + func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) error { + return s.deleteClientLogoInternal(ctx, clientID, "-dark", func(client *model.OidcClient) (string, error) { + if client.DarkImageType == nil { + return "", errors.New("image not found") + } + + oldImageType := *client.DarkImageType + client.DarkImageType = nil + return oldImageType, nil + }) +} + +func (s *OidcService) deleteClientLogoInternal(ctx context.Context, clientID string, imagePathSuffix string, setClientImage func(*model.OidcClient) (string, error)) error { tx := s.db.Begin() defer func() { tx.Rollback() @@ -1014,13 +1004,11 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) return err } - if client.DarkImageType == nil { - return errors.New("image not found") + oldImageType, err := setClientImage(&client) + if err != nil { + return err } - oldImageType := *client.DarkImageType - client.DarkImageType = nil - err = tx. WithContext(ctx). Save(&client). @@ -1029,12 +1017,14 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) return err } - imagePath := path.Join("oidc-client-images", client.ID+"-dark."+oldImageType) - if err := s.fileStorage.Delete(ctx, imagePath); err != nil { + err = tx.Commit().Error + if err != nil { return err } - err = tx.Commit().Error + // All storage operations must be performed outside of a database transaction + imagePath := path.Join("oidc-client-images", client.ID+imagePathSuffix+"."+oldImageType) + err = s.fileStorage.Delete(ctx, imagePath) if err != nil { return err } @@ -1048,7 +1038,7 @@ func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, in tx.Rollback() }() - client, err = s.getClientInternal(ctx, id, tx) + client, err = s.getClientInternal(ctx, id, tx, true) if err != nil { return model.OidcClient{}, err } @@ -1831,7 +1821,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use tx.Rollback() }() - client, err := s.getClientInternal(ctx, clientID, tx) + client, err := s.getClientInternal(ctx, clientID, tx, false) if err != nil { return nil, err } @@ -1976,7 +1966,25 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str return s.IsUserGroupAllowedToAuthorize(user, client), nil } -func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string, light bool) error { +var errLogoTooLarge = errors.New("logo is too large") + +func httpClientWithCheckRedirect(source *http.Client, checkRedirect func(req *http.Request, via []*http.Request) error) *http.Client { + if source == nil { + source = http.DefaultClient + } + + // Create a new client that clones the transport + client := &http.Client{ + Transport: source.Transport, + } + + // Assign the CheckRedirect function + client.CheckRedirect = checkRedirect + + return client +} + +func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, clientID string, raw string, light bool) error { u, err := url.Parse(raw) if err != nil { return err @@ -1985,18 +1993,29 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx * ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second) defer cancel() - r := net.Resolver{} - ips, err := r.LookupIPAddr(ctx, u.Hostname()) - if err != nil || len(ips) == 0 { - return fmt.Errorf("cannot resolve hostname") + // Prevents SSRF by allowing only public IPs + ok, err := utils.IsURLPrivate(ctx, u) + if err != nil { + return err + } else if ok { + return errors.New("private IP addresses are not allowed") } - // Prevents SSRF by allowing only public IPs - for _, addr := range ips { - if utils.IsPrivateIP(addr.IP) { - return fmt.Errorf("private IP addresses are not allowed") + // We need to check this on redirects too + client := httpClientWithCheckRedirect(s.httpClient, func(r *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") } - } + + ok, err := utils.IsURLPrivate(r.Context(), r.URL) + if err != nil { + return err + } else if ok { + return errors.New("private IP addresses are not allowed") + } + + return nil + }) req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil) if err != nil { @@ -2005,7 +2024,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx * req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher") req.Header.Set("Accept", "image/*") - resp, err := s.httpClient.Do(req) + resp, err := client.Do(req) if err != nil { return err } @@ -2017,7 +2036,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx * const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB if resp.ContentLength > maxLogoSize { - return fmt.Errorf("logo is too large") + return errLogoTooLarge } // Prefer extension in path if supported @@ -2037,48 +2056,70 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx * } imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext) - if err := s.fileStorage.Save(ctx, imagePath, io.LimitReader(resp.Body, maxLogoSize+1)); err != nil { + err = s.fileStorage.Save(ctx, imagePath, utils.NewLimitReader(resp.Body, maxLogoSize+1)) + if errors.Is(err, utils.ErrSizeExceeded) { + return errLogoTooLarge + } else if err != nil { return err } - if err := s.updateClientLogoType(ctx, tx, clientID, ext, light); err != nil { + err = s.updateClientLogoType(ctx, clientID, ext, light) + if err != nil { return err } return nil } -func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string, light bool) error { +func (s *OidcService) updateClientLogoType(ctx context.Context, clientID string, ext string, light bool) error { var darkSuffix string if !light { darkSuffix = "-dark" } + tx := s.db.Begin() + defer func() { + tx.Rollback() + }() + + // We need to acquire an update lock for the row to be locked, since we'll update it later var client model.OidcClient - if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil { - return err + err := tx. + WithContext(ctx). + Clauses(clause.Locking{Strength: "UPDATE"}). + First(&client, "id = ?", clientID). + Error + if err != nil { + return fmt.Errorf("failed to look up client: %w", err) } + var currentType *string if light { currentType = client.ImageType + client.ImageType = &ext } else { currentType = client.DarkImageType + client.DarkImageType = &ext } + + err = tx. + WithContext(ctx). + Save(&client). + Error + if err != nil { + return fmt.Errorf("failed to save updated client: %w", err) + } + + err = tx.Commit().Error + if err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + + // Storage operations must be executed outside of a transaction if currentType != nil && *currentType != ext { old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType) _ = s.fileStorage.Delete(ctx, old) } - var column string - if light { - column = "image_type" - } else { - column = "dark_image_type" - } - - return tx.WithContext(ctx). - Model(&model.OidcClient{}). - Where("id = ?", clientID). - Update(column, ext). - Error + return nil } diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index 5d4d04df..0f4b5b6c 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -8,7 +8,10 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "io" "net/http" + "strconv" + "strings" "testing" "time" @@ -21,6 +24,7 @@ import ( "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/model" + "github.com/pocket-id/pocket-id/backend/internal/storage" testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing" ) @@ -537,3 +541,435 @@ func TestValidateCodeVerifier_Plain(t *testing.T) { require.False(t, validateCodeVerifier("NOT!VALID", codeChallenge, true)) }) } + +func TestOidcService_updateClientLogoType(t *testing.T) { + // Create a test database + db := testutils.NewDatabaseForTest(t) + + // Create database storage + dbStorage, err := storage.NewDatabaseStorage(db) + require.NoError(t, err) + + // Init the OidcService + s := &OidcService{ + db: db, + fileStorage: dbStorage, + } + + // Create a test client + client := model.OidcClient{ + Name: "Test Client", + CallbackURLs: model.UrlList{"https://example.com/callback"}, + } + err = db.Create(&client).Error + require.NoError(t, err) + + // Helper function to check if a file exists in storage + fileExists := func(t *testing.T, path string) bool { + t.Helper() + _, _, err := dbStorage.Open(t.Context(), path) + return err == nil + } + + // Helper function to create a dummy file in storage + createDummyFile := func(t *testing.T, path string) { + t.Helper() + err := dbStorage.Save(t.Context(), path, strings.NewReader("dummy content")) + require.NoError(t, err) + } + + t.Run("Updates light logo type for client without previous logo", func(t *testing.T) { + // Update the logo type + err := s.updateClientLogoType(t.Context(), client.ID, "png", true) + require.NoError(t, err) + + // Verify the client was updated + var updatedClient model.OidcClient + err = db.First(&updatedClient, "id = ?", client.ID).Error + require.NoError(t, err) + require.NotNil(t, updatedClient.ImageType) + assert.Equal(t, "png", *updatedClient.ImageType) + }) + + t.Run("Updates dark logo type for client without previous dark logo", func(t *testing.T) { + // Update the dark logo type + err := s.updateClientLogoType(t.Context(), client.ID, "jpg", false) + require.NoError(t, err) + + // Verify the client was updated + var updatedClient model.OidcClient + err = db.First(&updatedClient, "id = ?", client.ID).Error + require.NoError(t, err) + require.NotNil(t, updatedClient.DarkImageType) + assert.Equal(t, "jpg", *updatedClient.DarkImageType) + }) + + t.Run("Updates light logo type and deletes old file when type changes", func(t *testing.T) { + // Create the old PNG file in storage + oldPath := "oidc-client-images/" + client.ID + ".png" + createDummyFile(t, oldPath) + require.True(t, fileExists(t, oldPath), "Old file should exist before update") + + // Client currently has a PNG logo, update to WEBP + err := s.updateClientLogoType(t.Context(), client.ID, "webp", true) + require.NoError(t, err) + + // Verify the client was updated + var updatedClient model.OidcClient + err = db.First(&updatedClient, "id = ?", client.ID).Error + require.NoError(t, err) + require.NotNil(t, updatedClient.ImageType) + assert.Equal(t, "webp", *updatedClient.ImageType) + + // Old PNG file should be deleted + assert.False(t, fileExists(t, oldPath), "Old PNG file should have been deleted") + }) + + t.Run("Updates dark logo type and deletes old file when type changes", func(t *testing.T) { + // Create the old JPG dark file in storage + oldPath := "oidc-client-images/" + client.ID + "-dark.jpg" + createDummyFile(t, oldPath) + require.True(t, fileExists(t, oldPath), "Old dark file should exist before update") + + // Client currently has a JPG dark logo, update to WEBP + err := s.updateClientLogoType(t.Context(), client.ID, "webp", false) + require.NoError(t, err) + + // Verify the client was updated + var updatedClient model.OidcClient + err = db.First(&updatedClient, "id = ?", client.ID).Error + require.NoError(t, err) + require.NotNil(t, updatedClient.DarkImageType) + assert.Equal(t, "webp", *updatedClient.DarkImageType) + + // Old JPG dark file should be deleted + assert.False(t, fileExists(t, oldPath), "Old JPG dark file should have been deleted") + }) + + t.Run("Does not delete file when type remains the same", func(t *testing.T) { + // Create the WEBP file in storage + webpPath := "oidc-client-images/" + client.ID + ".webp" + createDummyFile(t, webpPath) + require.True(t, fileExists(t, webpPath), "WEBP file should exist before update") + + // Update to the same type (WEBP) + err := s.updateClientLogoType(t.Context(), client.ID, "webp", true) + require.NoError(t, err) + + // Verify the client still has WEBP + var updatedClient model.OidcClient + err = db.First(&updatedClient, "id = ?", client.ID).Error + require.NoError(t, err) + require.NotNil(t, updatedClient.ImageType) + assert.Equal(t, "webp", *updatedClient.ImageType) + + // WEBP file should still exist since type didn't change + assert.True(t, fileExists(t, webpPath), "WEBP file should still exist") + }) + + t.Run("Returns error for non-existent client", func(t *testing.T) { + err := s.updateClientLogoType(t.Context(), "non-existent-client-id", "png", true) + require.Error(t, err) + require.ErrorContains(t, err, "failed to look up client") + }) +} + +func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { + // Create a test database + db := testutils.NewDatabaseForTest(t) + + // Create database storage + dbStorage, err := storage.NewDatabaseStorage(db) + require.NoError(t, err) + + // Create a test client + client := model.OidcClient{ + Name: "Test Client", + CallbackURLs: model.UrlList{"https://example.com/callback"}, + } + err = db.Create(&client).Error + require.NoError(t, err) + + // Helper function to check if a file exists in storage + fileExists := func(t *testing.T, path string) bool { + t.Helper() + _, _, err := dbStorage.Open(t.Context(), path) + return err == nil + } + + // Helper function to get file content from storage + getFileContent := func(t *testing.T, path string) []byte { + t.Helper() + reader, _, err := dbStorage.Open(t.Context(), path) + require.NoError(t, err) + defer reader.Close() + content, err := io.ReadAll(reader) + require.NoError(t, err) + return content + } + + t.Run("Successfully downloads and saves PNG logo from URL", func(t *testing.T) { + // Create mock PNG content + pngContent := []byte("fake-png-content") + + // Create a mock HTTP response with headers + //nolint:bodyclose + pngResponse := testutils.NewMockResponse(http.StatusOK, string(pngContent)) + pngResponse.Header.Set("Content-Type", "image/png") + + // Create a mock HTTP client with responses + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/logo.png": pngResponse, + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + // Init the OidcService with mock HTTP client + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + // Download and save the logo + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo.png", true) + require.NoError(t, err) + + // Verify the file was saved + logoPath := "oidc-client-images/" + client.ID + ".png" + require.True(t, fileExists(t, logoPath), "Logo file should exist in storage") + + // Verify the content + savedContent := getFileContent(t, logoPath) + assert.Equal(t, pngContent, savedContent) + + // Verify the client was updated + var updatedClient model.OidcClient + err = db.First(&updatedClient, "id = ?", client.ID).Error + require.NoError(t, err) + require.NotNil(t, updatedClient.ImageType) + assert.Equal(t, "png", *updatedClient.ImageType) + }) + + t.Run("Successfully downloads and saves dark logo", func(t *testing.T) { + // Create mock WEBP content + webpContent := []byte("fake-webp-content") + + //nolint:bodyclose + webpResponse := testutils.NewMockResponse(http.StatusOK, string(webpContent)) + webpResponse.Header.Set("Content-Type", "image/webp") + + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/dark-logo.webp": webpResponse, + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + // Download and save the dark logo + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/dark-logo.webp", false) + require.NoError(t, err) + + // Verify the dark logo file was saved + darkLogoPath := "oidc-client-images/" + client.ID + "-dark.webp" + require.True(t, fileExists(t, darkLogoPath), "Dark logo file should exist in storage") + + // Verify the content + savedContent := getFileContent(t, darkLogoPath) + assert.Equal(t, webpContent, savedContent) + + // Verify the client was updated + var updatedClient model.OidcClient + err = db.First(&updatedClient, "id = ?", client.ID).Error + require.NoError(t, err) + require.NotNil(t, updatedClient.DarkImageType) + assert.Equal(t, "webp", *updatedClient.DarkImageType) + }) + + t.Run("Detects extension from URL path", func(t *testing.T) { + svgContent := []byte("") + + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/icon.svg": testutils.NewMockResponse(http.StatusOK, string(svgContent)), + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/icon.svg", true) + require.NoError(t, err) + + // Verify SVG file was saved + logoPath := "oidc-client-images/" + client.ID + ".svg" + require.True(t, fileExists(t, logoPath), "SVG logo should exist") + }) + + t.Run("Detects extension from Content-Type when path has no extension", func(t *testing.T) { + jpgContent := []byte("fake-jpg-content") + + //nolint:bodyclose + jpgResponse := testutils.NewMockResponse(http.StatusOK, string(jpgContent)) + jpgResponse.Header.Set("Content-Type", "image/jpeg") + + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/logo": jpgResponse, + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo", true) + require.NoError(t, err) + + // Verify JPG file was saved (jpeg extension is normalized to jpg) + logoPath := "oidc-client-images/" + client.ID + ".jpg" + require.True(t, fileExists(t, logoPath), "JPG logo should exist") + }) + + t.Run("Returns error for invalid URL", func(t *testing.T) { + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: &http.Client{}, + } + + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "://invalid-url", true) + require.Error(t, err) + }) + + t.Run("Returns error for non-200 status code", func(t *testing.T) { + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/not-found.png": testutils.NewMockResponse(http.StatusNotFound, "Not Found"), + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/not-found.png", true) + require.Error(t, err) + require.ErrorContains(t, err, "failed to fetch logo") + }) + + t.Run("Returns error for too large content", func(t *testing.T) { + // Create content larger than 2MB (maxLogoSize) + largeContent := strings.Repeat("x", 2<<20+100) // 2.1MB + + //nolint:bodyclose + largeResponse := testutils.NewMockResponse(http.StatusOK, largeContent) + largeResponse.Header.Set("Content-Type", "image/png") + largeResponse.Header.Set("Content-Length", strconv.Itoa(len(largeContent))) + + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/large.png": largeResponse, + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/large.png", true) + require.Error(t, err) + require.ErrorIs(t, err, errLogoTooLarge) + }) + + t.Run("Returns error for unsupported file type", func(t *testing.T) { + //nolint:bodyclose + textResponse := testutils.NewMockResponse(http.StatusOK, "text content") + textResponse.Header.Set("Content-Type", "text/plain") + + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/file.txt": textResponse, + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/file.txt", true) + require.Error(t, err) + var fileTypeErr *common.FileTypeNotSupportedError + require.ErrorAs(t, err, &fileTypeErr) + }) + + t.Run("Returns error for non-existent client", func(t *testing.T) { + //nolint:bodyclose + pngResponse := testutils.NewMockResponse(http.StatusOK, "content") + pngResponse.Header.Set("Content-Type", "image/png") + + mockResponses := map[string]*http.Response{ + //nolint:bodyclose + "https://example.com/logo.png": pngResponse, + } + httpClient := &http.Client{ + Transport: &testutils.MockRoundTripper{ + Responses: mockResponses, + }, + } + + s := &OidcService{ + db: db, + fileStorage: dbStorage, + httpClient: httpClient, + } + + err := s.downloadAndSaveLogoFromURL(t.Context(), "non-existent-client-id", "https://example.com/logo.png", true) + require.Error(t, err) + require.ErrorContains(t, err, "failed to look up client") + }) +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 6cef3176..fa92a27d 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -17,6 +17,7 @@ import ( "github.com/google/uuid" "go.opentelemetry.io/otel/trace" "gorm.io/gorm" + "gorm.io/gorm/clause" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" @@ -101,9 +102,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io. profilePicturePath := path.Join("profile-pictures", userID+".png") // Try custom profile picture - if file, size, err := s.fileStorage.Open(ctx, profilePicturePath); err == nil { + file, size, err := s.fileStorage.Open(ctx, profilePicturePath) + if err == nil { return file, size, nil - } else if err != nil && !errors.Is(err, fs.ErrNotExist) { + } else if !errors.Is(err, fs.ErrNotExist) { return nil, 0, err } @@ -120,9 +122,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io. // Try cached default for initials defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png") - if file, size, err := s.fileStorage.Open(ctx, defaultPicturePath); err == nil { + file, size, err = s.fileStorage.Open(ctx, defaultPicturePath) + if err == nil { return file, size, nil - } else if err != nil && !errors.Is(err, fs.ErrNotExist) { + } else if !errors.Is(err, fs.ErrNotExist) { return nil, 0, err } @@ -133,12 +136,13 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io. } // Save the default picture for future use (in a goroutine to avoid blocking) - //nolint:contextcheck defaultPictureBytes := defaultPicture.Bytes() //nolint:contextcheck go func() { - if err := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)); err != nil { - slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", err)) + // Use bytes.NewReader because we need an io.ReadSeeker + rErr := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)) + if rErr != nil { + slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", rErr)) } }() @@ -182,17 +186,30 @@ func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, f } func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error { - return s.db.Transaction(func(tx *gorm.DB) error { - return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx) + err := s.db.Transaction(func(tx *gorm.DB) error { + return s.deleteUserInternal(ctx, tx, userID, allowLdapDelete) }) + if err != nil { + return fmt.Errorf("failed to delete user '%s': %w", userID, err) + } + + // Storage operations must be executed outside of a transaction + profilePicturePath := path.Join("profile-pictures", userID+".png") + err = s.fileStorage.Delete(ctx, profilePicturePath) + if err != nil && !storage.IsNotExist(err) { + return fmt.Errorf("failed to delete profile picture for user '%s': %w", userID, err) + } + + return nil } -func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error { +func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userID string, allowLdapDelete bool) error { var user model.User err := tx. WithContext(ctx). Where("id = ?", userID). + Clauses(clause.Locking{Strength: "UPDATE"}). First(&user). Error if err != nil { @@ -204,11 +221,6 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all return &common.LdapUserUpdateError{} } - profilePicturePath := path.Join("profile-pictures", userID+".png") - if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil { - return err - } - err = tx.WithContext(ctx).Delete(&user).Error if err != nil { return fmt.Errorf("failed to delete user: %w", err) @@ -286,16 +298,27 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User, // Apply default user groups var groupIDs []string - if v := config.SignupDefaultUserGroupIDs.Value; v != "" && v != "[]" { - if err := json.Unmarshal([]byte(v), &groupIDs); err != nil { + v := config.SignupDefaultUserGroupIDs.Value + if v != "" && v != "[]" { + err := json.Unmarshal([]byte(v), &groupIDs) + if err != nil { return fmt.Errorf("invalid SignupDefaultUserGroupIDs JSON: %w", err) } if len(groupIDs) > 0 { var groups []model.UserGroup - if err := tx.WithContext(ctx).Where("id IN ?", groupIDs).Find(&groups).Error; err != nil { + err = tx.WithContext(ctx). + Where("id IN ?", groupIDs). + Find(&groups). + Error + if err != nil { return fmt.Errorf("failed to find default user groups: %w", err) } - if err := tx.WithContext(ctx).Model(user).Association("UserGroups").Replace(groups); err != nil { + + err = tx.WithContext(ctx). + Model(user). + Association("UserGroups"). + Replace(groups) + if err != nil { return fmt.Errorf("failed to associate default user groups: %w", err) } } @@ -303,12 +326,15 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User, // Apply default custom claims var claims []dto.CustomClaimCreateDto - if v := config.SignupDefaultCustomClaims.Value; v != "" && v != "[]" { - if err := json.Unmarshal([]byte(v), &claims); err != nil { + v = config.SignupDefaultCustomClaims.Value + if v != "" && v != "[]" { + err := json.Unmarshal([]byte(v), &claims) + if err != nil { return fmt.Errorf("invalid SignupDefaultCustomClaims JSON: %w", err) } if len(claims) > 0 { - if _, err := s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx); err != nil { + _, err = s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx) + if err != nil { return fmt.Errorf("failed to apply default custom claims: %w", err) } } @@ -345,6 +371,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd err := tx. WithContext(ctx). Where("id = ?", userID). + Clauses(clause.Locking{Strength: "UPDATE"}). First(&user). Error if err != nil { @@ -416,13 +443,11 @@ func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context var userId string err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error - if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { // Do not return error if user not found to prevent email enumeration - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil - } else { - return err - } + return nil + } else if err != nil { + return err } return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute) @@ -513,7 +538,9 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri var oneTimeAccessToken model.OneTimeAccessToken err := tx. WithContext(ctx). - Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User"). + Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())). + Preload("User"). + Clauses(clause.Locking{Strength: "UPDATE"}). First(&oneTimeAccessToken). Error if err != nil { @@ -679,7 +706,7 @@ func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) er return nil } -func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx *gorm.DB) error { +func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error { return tx. WithContext(ctx). Model(&model.User{}). @@ -720,6 +747,7 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd err := tx. WithContext(ctx). Where("token = ?", signupData.Token). + Clauses(clause.Locking{Strength: "UPDATE"}). First(&signupToken). Error if err != nil { diff --git a/backend/internal/service/version_service.go b/backend/internal/service/version_service.go index 2a12930a..d52fd6f9 100644 --- a/backend/internal/service/version_service.go +++ b/backend/internal/service/version_service.go @@ -58,7 +58,7 @@ func (s *VersionService) GetLatestVersion(ctx context.Context) (string, error) { } if payload.TagName == "" { - return "", fmt.Errorf("GitHub API returned empty tag name") + return "", errors.New("GitHub API returned empty tag name") } return strings.TrimPrefix(payload.TagName, "v"), nil diff --git a/backend/internal/storage/database.go b/backend/internal/storage/database.go new file mode 100644 index 00000000..2c8779dc --- /dev/null +++ b/backend/internal/storage/database.go @@ -0,0 +1,226 @@ +package storage + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/pocket-id/pocket-id/backend/internal/model" + datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +var TypeDatabase = "database" + +type databaseStorage struct { + db *gorm.DB +} + +// NewDatabaseStorage creates a new database storage provider +func NewDatabaseStorage(db *gorm.DB) (FileStorage, error) { + if db == nil { + return nil, errors.New("database connection is required") + } + return &databaseStorage{db: db}, nil +} + +func (s *databaseStorage) Type() string { + return TypeDatabase +} + +func (s *databaseStorage) Save(ctx context.Context, relativePath string, data io.Reader) error { + // Normalize the path + relativePath = filepath.ToSlash(filepath.Clean(relativePath)) + + // Read all data into memory + b, err := io.ReadAll(data) + if err != nil { + return fmt.Errorf("failed to read data: %w", err) + } + + now := datatype.DateTime(time.Now()) + storage := model.Storage{ + Path: relativePath, + Data: b, + Size: int64(len(b)), + ModTime: now, + CreatedAt: now, + } + + // Use upsert: insert or update on conflict + result := s.db. + WithContext(ctx). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "path"}}, + DoUpdates: clause.AssignmentColumns([]string{"data", "size", "mod_time"}), + }). + Create(&storage) + + if result.Error != nil { + return fmt.Errorf("failed to save file to database: %w", result.Error) + } + + return nil +} + +func (s *databaseStorage) Open(ctx context.Context, relativePath string) (io.ReadCloser, int64, error) { + relativePath = filepath.ToSlash(filepath.Clean(relativePath)) + + var storage model.Storage + result := s.db. + WithContext(ctx). + Where("path = ?", relativePath). + First(&storage) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, 0, os.ErrNotExist + } + return nil, 0, fmt.Errorf("failed to read file from database: %w", result.Error) + } + + reader := io.NopCloser(bytes.NewReader(storage.Data)) + return reader, storage.Size, nil +} + +func (s *databaseStorage) Delete(ctx context.Context, relativePath string) error { + relativePath = filepath.ToSlash(filepath.Clean(relativePath)) + + result := s.db. + WithContext(ctx). + Where("path = ?", relativePath). + Delete(&model.Storage{}) + if result.Error != nil { + return fmt.Errorf("failed to delete file from database: %w", result.Error) + } + + return nil +} + +func (s *databaseStorage) DeleteAll(ctx context.Context, prefix string) error { + prefix = filepath.ToSlash(filepath.Clean(prefix)) + + // If empty prefix, delete all + if isRootPath(prefix) { + result := s.db. + WithContext(ctx). + Where("1 = 1"). // Delete everything + Delete(&model.Storage{}) + if result.Error != nil { + return fmt.Errorf("failed to delete all files from database: %w", result.Error) + } + return nil + } + + // Ensure prefix ends with / for proper prefix matching + if !strings.HasSuffix(prefix, "/") { + prefix += "/" + } + + query := s.db.WithContext(ctx) + query = addPathPrefixClause(s.db.Name(), query, prefix) + result := query.Delete(&model.Storage{}) + if result.Error != nil { + return fmt.Errorf("failed to delete files with prefix '%s' from database: %w", prefix, result.Error) + } + + return nil +} + +func (s *databaseStorage) List(ctx context.Context, prefix string) ([]ObjectInfo, error) { + prefix = filepath.ToSlash(filepath.Clean(prefix)) + + var storageItems []model.Storage + query := s.db.WithContext(ctx) + + if !isRootPath(prefix) { + // Ensure prefix matching + if !strings.HasSuffix(prefix, "/") { + prefix += "/" + } + query = addPathPrefixClause(s.db.Name(), query, prefix) + } + + result := query. + Select("path", "size", "mod_time"). + Find(&storageItems) + if result.Error != nil { + return nil, fmt.Errorf("failed to list files from database: %w", result.Error) + } + + objects := make([]ObjectInfo, 0, len(storageItems)) + for _, item := range storageItems { + // Filter out directory-like paths (those that contain additional slashes after the prefix) + relativePath := strings.TrimPrefix(item.Path, prefix) + if strings.ContainsRune(relativePath, '/') { + continue + } + + objects = append(objects, ObjectInfo{ + Path: item.Path, + Size: item.Size, + ModTime: time.Time(item.ModTime), + }) + } + + return objects, nil +} + +func (s *databaseStorage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error { + root = filepath.ToSlash(filepath.Clean(root)) + + var storageItems []model.Storage + query := s.db.WithContext(ctx) + + if !isRootPath(root) { + // Ensure root matching + if !strings.HasSuffix(root, "/") { + root += "/" + } + query = addPathPrefixClause(s.db.Name(), query, root) + } + + result := query. + Select("path", "size", "mod_time"). + Find(&storageItems) + if result.Error != nil { + return fmt.Errorf("failed to walk files from database: %w", result.Error) + } + + for _, item := range storageItems { + err := fn(ObjectInfo{ + Path: item.Path, + Size: item.Size, + ModTime: time.Time(item.ModTime), + }) + if err != nil { + return err + } + } + + return nil +} + +func isRootPath(path string) bool { + return path == "" || path == "/" || path == "." +} + +func addPathPrefixClause(dialect string, query *gorm.DB, prefix string) *gorm.DB { + // In SQLite, we use "GLOB" which can use the index + switch dialect { + case "sqlite": + return query.Where("path GLOB ?", prefix+"*") + case "postgres": + return query.Where("path LIKE ?", prefix+"%") + default: + // Indicates a development-time error + panic(fmt.Errorf("unsupported database dialect: %s", dialect)) + } +} diff --git a/backend/internal/storage/database_test.go b/backend/internal/storage/database_test.go new file mode 100644 index 00000000..208fb7b7 --- /dev/null +++ b/backend/internal/storage/database_test.go @@ -0,0 +1,148 @@ +package storage + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + testingutil "github.com/pocket-id/pocket-id/backend/internal/utils/testing" +) + +func TestDatabaseStorageOperations(t *testing.T) { + ctx := context.Background() + db := testingutil.NewDatabaseForTest(t) + store, err := NewDatabaseStorage(db) + require.NoError(t, err) + + t.Run("type should be database", func(t *testing.T) { + assert.Equal(t, TypeDatabase, store.Type()) + }) + + t.Run("save, open and list files", func(t *testing.T) { + err := store.Save(ctx, "images/logo.png", bytes.NewBufferString("logo-data")) + require.NoError(t, err) + + reader, size, err := store.Open(ctx, "images/logo.png") + require.NoError(t, err) + defer reader.Close() + + contents, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, []byte("logo-data"), contents) + assert.Equal(t, int64(len(contents)), size) + + err = store.Save(ctx, "images/nested/child.txt", bytes.NewBufferString("child")) + require.NoError(t, err) + + files, err := store.List(ctx, "images") + require.NoError(t, err) + require.Len(t, files, 1) + assert.Equal(t, "images/logo.png", files[0].Path) + assert.Equal(t, int64(len("logo-data")), files[0].Size) + }) + + t.Run("save should update existing file", func(t *testing.T) { + err := store.Save(ctx, "test/update.txt", bytes.NewBufferString("original")) + require.NoError(t, err) + + err = store.Save(ctx, "test/update.txt", bytes.NewBufferString("updated")) + require.NoError(t, err) + + reader, size, err := store.Open(ctx, "test/update.txt") + require.NoError(t, err) + defer reader.Close() + + contents, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, []byte("updated"), contents) + assert.Equal(t, int64(len("updated")), size) + }) + + t.Run("delete files individually", func(t *testing.T) { + err := store.Save(ctx, "images/delete-me.txt", bytes.NewBufferString("temp")) + require.NoError(t, err) + + require.NoError(t, store.Delete(ctx, "images/delete-me.txt")) + _, _, err = store.Open(ctx, "images/delete-me.txt") + require.Error(t, err) + assert.True(t, IsNotExist(err)) + }) + + t.Run("delete missing file should not error", func(t *testing.T) { + require.NoError(t, store.Delete(ctx, "images/missing.txt")) + }) + + t.Run("delete all files", func(t *testing.T) { + require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a"))) + require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b"))) + require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c"))) + require.NoError(t, store.DeleteAll(ctx, "/")) + + _, _, err := store.Open(ctx, "cleanup/a.txt") + require.Error(t, err) + assert.True(t, IsNotExist(err)) + + _, _, err = store.Open(ctx, "cleanup/b.txt") + require.Error(t, err) + assert.True(t, IsNotExist(err)) + + _, _, err = store.Open(ctx, "cleanup/nested/c.txt") + require.Error(t, err) + assert.True(t, IsNotExist(err)) + }) + + t.Run("delete all files under a prefix", func(t *testing.T) { + require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a"))) + require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b"))) + require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c"))) + require.NoError(t, store.DeleteAll(ctx, "cleanup")) + + _, _, err := store.Open(ctx, "cleanup/a.txt") + require.Error(t, err) + assert.True(t, IsNotExist(err)) + + _, _, err = store.Open(ctx, "cleanup/b.txt") + require.Error(t, err) + assert.True(t, IsNotExist(err)) + + _, _, err = store.Open(ctx, "cleanup/nested/c.txt") + require.Error(t, err) + assert.True(t, IsNotExist(err)) + }) + + t.Run("walk files", func(t *testing.T) { + require.NoError(t, store.Save(ctx, "walk/file1.txt", bytes.NewBufferString("1"))) + require.NoError(t, store.Save(ctx, "walk/file2.txt", bytes.NewBufferString("2"))) + require.NoError(t, store.Save(ctx, "walk/nested/file3.txt", bytes.NewBufferString("3"))) + + var paths []string + err := store.Walk(ctx, "walk", func(info ObjectInfo) error { + paths = append(paths, info.Path) + return nil + }) + require.NoError(t, err) + assert.Len(t, paths, 3) + assert.Contains(t, paths, "walk/file1.txt") + assert.Contains(t, paths, "walk/file2.txt") + assert.Contains(t, paths, "walk/nested/file3.txt") + }) +} + +func TestNewDatabaseStorage(t *testing.T) { + t.Run("should return error with nil database", func(t *testing.T) { + _, err := NewDatabaseStorage(nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "database connection is required") + }) + + t.Run("should create storage with valid database", func(t *testing.T) { + db := testingutil.NewDatabaseForTest(t) + store, err := NewDatabaseStorage(db) + require.NoError(t, err) + assert.NotNil(t, store) + }) +} diff --git a/backend/internal/utils/ip_util.go b/backend/internal/utils/ip_util.go index 9832046b..be2ab83c 100644 --- a/backend/internal/utils/ip_util.go +++ b/backend/internal/utils/ip_util.go @@ -1,7 +1,10 @@ package utils import ( + "context" + "errors" "net" + "net/url" "strings" "github.com/pocket-id/pocket-id/backend/internal/common" @@ -56,6 +59,23 @@ func IsPrivateIP(ip net.IP) bool { return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip) } +func IsURLPrivate(ctx context.Context, u *url.URL) (bool, error) { + var r net.Resolver + ips, err := r.LookupIPAddr(ctx, u.Hostname()) + if err != nil || len(ips) == 0 { + return false, errors.New("cannot resolve hostname") + } + + // Prevents SSRF by allowing only public IPs + for _, addr := range ips { + if IsPrivateIP(addr.IP) { + return true, nil + } + } + + return false, nil +} + func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool { for _, ipNet := range ipNets { if ipNet.Contains(ip) { diff --git a/backend/internal/utils/ip_util_test.go b/backend/internal/utils/ip_util_test.go index 01c7bf68..5da1eb68 100644 --- a/backend/internal/utils/ip_util_test.go +++ b/backend/internal/utils/ip_util_test.go @@ -1,8 +1,14 @@ package utils import ( + "context" "net" + "net/url" "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/pocket-id/pocket-id/backend/internal/common" ) @@ -20,9 +26,8 @@ func TestIsLocalhostIP(t *testing.T) { for _, tt := range tests { ip := net.ParseIP(tt.ip) - if got := IsLocalhostIP(ip); got != tt.expected { - t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected) - } + got := IsLocalhostIP(ip) + assert.Equal(t, tt.expected, got) } } @@ -40,9 +45,8 @@ func TestIsPrivateLanIP(t *testing.T) { for _, tt := range tests { ip := net.ParseIP(tt.ip) - if got := IsPrivateLanIP(ip); got != tt.expected { - t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected) - } + got := IsPrivateLanIP(ip) + assert.Equal(t, tt.expected, got) } } @@ -59,9 +63,9 @@ func TestIsTailscaleIP(t *testing.T) { for _, tt := range tests { ip := net.ParseIP(tt.ip) - if got := IsTailscaleIP(ip); got != tt.expected { - t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected) - } + + got := IsTailscaleIP(ip) + assert.Equal(t, tt.expected, got) } } @@ -86,16 +90,17 @@ func TestIsLocalIPv6(t *testing.T) { for _, tt := range tests { ip := net.ParseIP(tt.ip) - if got := IsLocalIPv6(ip); got != tt.expected { - t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected) - } + got := IsLocalIPv6(ip) + assert.Equal(t, tt.expected, got) } } func TestIsPrivateIP(t *testing.T) { // Save and restore env config origRanges := common.EnvConfig.LocalIPv6Ranges - defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }() + t.Cleanup(func() { + common.EnvConfig.LocalIPv6Ranges = origRanges + }) common.EnvConfig.LocalIPv6Ranges = "fd00::/8" localIPv6Ranges = nil // reset @@ -115,9 +120,8 @@ func TestIsPrivateIP(t *testing.T) { for _, tt := range tests { ip := net.ParseIP(tt.ip) - if got := IsPrivateIP(ip); got != tt.expected { - t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected) - } + got := IsPrivateIP(ip) + assert.Equal(t, tt.expected, got) } } @@ -138,22 +142,202 @@ func TestListContainsIP(t *testing.T) { for _, tt := range tests { ip := net.ParseIP(tt.ip) - if got := listContainsIP(list, ip); got != tt.expected { - t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected) - } + got := listContainsIP(list, ip) + assert.Equal(t, tt.expected, got) } } func TestInit_LocalIPv6Ranges(t *testing.T) { // Save and restore env config origRanges := common.EnvConfig.LocalIPv6Ranges - defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }() + t.Cleanup(func() { + common.EnvConfig.LocalIPv6Ranges = origRanges + }) common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7" localIPv6Ranges = nil loadLocalIPv6Ranges() - if len(localIPv6Ranges) != 2 { - t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges)) + assert.Len(t, localIPv6Ranges, 2) +} + +func TestIsURLPrivate(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + tests := []struct { + name string + urlStr string + expectPriv bool + expectError bool + }{ + { + name: "localhost by name", + urlStr: "http://localhost", + expectPriv: true, + expectError: false, + }, + { + name: "localhost with port", + urlStr: "http://localhost:8080", + expectPriv: true, + expectError: false, + }, + { + name: "127.0.0.1 IP", + urlStr: "http://127.0.0.1", + expectPriv: true, + expectError: false, + }, + { + name: "127.0.0.1 with port", + urlStr: "http://127.0.0.1:3000", + expectPriv: true, + expectError: false, + }, + { + name: "IPv6 loopback", + urlStr: "http://[::1]", + expectPriv: true, + expectError: false, + }, + { + name: "IPv6 loopback with port", + urlStr: "http://[::1]:8080", + expectPriv: true, + expectError: false, + }, + { + name: "private IP 10.x.x.x", + urlStr: "http://10.0.0.1", + expectPriv: true, + expectError: false, + }, + { + name: "private IP 192.168.x.x", + urlStr: "http://192.168.1.1", + expectPriv: true, + expectError: false, + }, + { + name: "private IP 172.16.x.x", + urlStr: "http://172.16.0.1", + expectPriv: true, + expectError: false, + }, + { + name: "Tailscale IP", + urlStr: "http://100.64.0.1", + expectPriv: true, + expectError: false, + }, + { + name: "public IP - Google DNS", + urlStr: "http://8.8.8.8", + expectPriv: false, + expectError: false, + }, + { + name: "public IP - Cloudflare DNS", + urlStr: "http://1.1.1.1", + expectPriv: false, + expectError: false, + }, + { + name: "invalid hostname", + urlStr: "http://this-should-not-resolve-ever-123456789.invalid", + expectPriv: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.Parse(tt.urlStr) + require.NoError(t, err, "Failed to parse URL %s", tt.urlStr) + + isPriv, err := IsURLPrivate(ctx, u) + + if tt.expectError { + require.Error(t, err, "IsURLPrivate(%s) expected error but got none", tt.urlStr) + } else { + require.NoError(t, err, "IsURLPrivate(%s) unexpected error", tt.urlStr) + assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr) + } + }) } } + +func TestIsURLPrivate_WithDomainName(t *testing.T) { + // Note: These tests rely on actual DNS resolution + // They test real public domains to ensure they are not flagged as private + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + tests := []struct { + name string + urlStr string + expectPriv bool + }{ + { + name: "Google public domain", + urlStr: "https://www.google.com", + expectPriv: false, + }, + { + name: "GitHub public domain", + urlStr: "https://github.com", + expectPriv: false, + }, + { + // localhost.localtest.me is a well-known domain that resolves to 127.0.0.1 + name: "localhost.localtest.me resolves to 127.0.0.1", + urlStr: "http://localhost.localtest.me", + expectPriv: true, + }, + { + // 10.0.0.1.nip.io resolves to 10.0.0.1 (private IP) + name: "nip.io domain resolving to private 10.x IP", + urlStr: "http://10.0.0.1.nip.io", + expectPriv: true, + }, + { + // 192.168.1.1.nip.io resolves to 192.168.1.1 (private IP) + name: "nip.io domain resolving to private 192.168.x IP", + urlStr: "http://192.168.1.1.nip.io", + expectPriv: true, + }, + { + // 127.0.0.1.nip.io resolves to 127.0.0.1 (localhost) + name: "nip.io domain resolving to localhost", + urlStr: "http://127.0.0.1.nip.io", + expectPriv: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.Parse(tt.urlStr) + require.NoError(t, err, "Failed to parse URL %s", tt.urlStr) + + isPriv, err := IsURLPrivate(ctx, u) + if err != nil { + t.Skipf("DNS resolution failed for %s (network issue?): %v", tt.urlStr, err) + return + } + + assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr) + }) + } +} + +func TestIsURLPrivate_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + cancel() // Cancel immediately + + u, err := url.Parse("http://example.com") + require.NoError(t, err, "Failed to parse URL") + + _, err = IsURLPrivate(ctx, u) + assert.Error(t, err, "IsURLPrivate with cancelled context expected error but got none") +} diff --git a/backend/internal/utils/stream_util.go b/backend/internal/utils/stream_util.go new file mode 100644 index 00000000..ce77ddf6 --- /dev/null +++ b/backend/internal/utils/stream_util.go @@ -0,0 +1,34 @@ +package utils + +import ( + "errors" + "io" +) + +var ErrSizeExceeded = errors.New("stream size exceeded") + +// LimitReader is like io.LimitReader but throws an error if the stream exceeds the max size +// io.LimitReader instead just returns io.EOF +// Adapted from https://github.com/golang/go/issues/51115#issuecomment-1079761212 +type LimitReader struct { + io.ReadCloser + N int64 +} + +func NewLimitReader(r io.ReadCloser, limit int64) *LimitReader { + return &LimitReader{r, limit} +} + +func (r *LimitReader) Read(p []byte) (n int, err error) { + if r.N <= 0 { + return 0, ErrSizeExceeded + } + + if int64(len(p)) > r.N { + p = p[0:r.N] + } + + n, err = r.ReadCloser.Read(p) + r.N -= int64(n) + return +} diff --git a/backend/resources/migrations/postgres/20251110000000_storage_table.down.sql b/backend/resources/migrations/postgres/20251110000000_storage_table.down.sql new file mode 100644 index 00000000..c914273a --- /dev/null +++ b/backend/resources/migrations/postgres/20251110000000_storage_table.down.sql @@ -0,0 +1 @@ +DROP TABLE storage; diff --git a/backend/resources/migrations/postgres/20251110000000_storage_table.up.sql b/backend/resources/migrations/postgres/20251110000000_storage_table.up.sql new file mode 100644 index 00000000..337d4daa --- /dev/null +++ b/backend/resources/migrations/postgres/20251110000000_storage_table.up.sql @@ -0,0 +1,9 @@ +-- The "storage" table contains file data stored in the database +CREATE TABLE storage +( + path TEXT NOT NULL PRIMARY KEY, + data BYTEA NOT NULL, + size BIGINT NOT NULL, + mod_time TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL +); diff --git a/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.down.sql b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.down.sql new file mode 100644 index 00000000..2098d8dc --- /dev/null +++ b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.down.sql @@ -0,0 +1 @@ +DROP INDEX idx_api_keys_expires_at; diff --git a/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.up.sql b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.up.sql new file mode 100644 index 00000000..af1e6cad --- /dev/null +++ b/backend/resources/migrations/postgres/20251115000000_api_keys_expires_at_index.up.sql @@ -0,0 +1 @@ +CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at); diff --git a/backend/resources/migrations/sqlite/20251110000000_storage_table.down.sql b/backend/resources/migrations/sqlite/20251110000000_storage_table.down.sql new file mode 100644 index 00000000..2135db93 --- /dev/null +++ b/backend/resources/migrations/sqlite/20251110000000_storage_table.down.sql @@ -0,0 +1,6 @@ +PRAGMA foreign_keys=OFF; +BEGIN; +DROP TABLE storage; + +COMMIT; +PRAGMA foreign_keys=ON; diff --git a/backend/resources/migrations/sqlite/20251110000000_storage_table.up.sql b/backend/resources/migrations/sqlite/20251110000000_storage_table.up.sql new file mode 100644 index 00000000..12dd4dcb --- /dev/null +++ b/backend/resources/migrations/sqlite/20251110000000_storage_table.up.sql @@ -0,0 +1,14 @@ +PRAGMA foreign_keys=OFF; +BEGIN; +-- The "storage" table contains file data stored in the database +CREATE TABLE storage +( + path TEXT NOT NULL PRIMARY KEY, + data BLOB NOT NULL, + size INTEGER NOT NULL, + mod_time DATETIME NOT NULL, + created_at DATETIME NOT NULL +); + +COMMIT; +PRAGMA foreign_keys=ON; diff --git a/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.down.sql b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.down.sql new file mode 100644 index 00000000..ec8c10cb --- /dev/null +++ b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.down.sql @@ -0,0 +1,5 @@ +PRAGMA foreign_keys=OFF; +BEGIN; +DROP INDEX idx_api_keys_expires_at; +COMMIT; +PRAGMA foreign_keys=ON; diff --git a/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.up.sql b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.up.sql new file mode 100644 index 00000000..899b4d02 --- /dev/null +++ b/backend/resources/migrations/sqlite/20251115000000_api_keys_expires_at_index.up.sql @@ -0,0 +1,5 @@ +PRAGMA foreign_keys=OFF; +BEGIN; +CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at); +COMMIT; +PRAGMA foreign_keys=ON; diff --git a/tests/setup/docker-compose-postgres.yml b/tests/setup/docker-compose-postgres.yml index 0171a91b..09539b75 100644 --- a/tests/setup/docker-compose-postgres.yml +++ b/tests/setup/docker-compose-postgres.yml @@ -20,8 +20,10 @@ services: file: docker-compose.yml service: pocket-id environment: + - APP_ENV=test - DB_PROVIDER=postgres - DB_CONNECTION_STRING=postgres://postgres:postgres@postgres:5432/pocket-id + - FILE_BACKEND=${FILE_BACKEND} depends_on: postgres: condition: service_healthy diff --git a/tests/setup/docker-compose-s3.yml b/tests/setup/docker-compose-s3.yml index 689a44c9..8159ee0b 100644 --- a/tests/setup/docker-compose-s3.yml +++ b/tests/setup/docker-compose-s3.yml @@ -13,7 +13,7 @@ services: retries: 10 create-bucket: - image: amazon/aws-cli + image: amazon/aws-cli:latest environment: AWS_ACCESS_KEY_ID: test AWS_SECRET_ACCESS_KEY: test @@ -28,15 +28,14 @@ services: file: docker-compose.yml service: pocket-id environment: - FILE_BACKEND: s3 - S3_BUCKET: pocket-id-test - S3_REGION: us-east-1 - S3_ENDPOINT: http://localstack-s3:4566 - S3_ACCESS_KEY_ID: test - S3_SECRET_ACCESS_KEY: test - S3_FORCE_PATH_STYLE: true - KEYS_STORAGE: database - ENCRYPTION_KEY: test + - S3_BUCKET=pocket-id-test + - S3_REGION=us-east-1 + - S3_ENDPOINT=http://localstack-s3:4566 + - S3_ACCESS_KEY_ID=test + - S3_SECRET_ACCESS_KEY=test + - S3_FORCE_PATH_STYLE=true + - KEYS_STORAGE=database + - ENCRYPTION_KEY=test1234test1234test1234test1234 depends_on: create-bucket: condition: service_completed_successfully diff --git a/tests/setup/docker-compose.yml b/tests/setup/docker-compose.yml index 03db178f..74e1e778 100644 --- a/tests/setup/docker-compose.yml +++ b/tests/setup/docker-compose.yml @@ -14,6 +14,7 @@ services: - '1411:1411' environment: - APP_ENV=test + - FILE_BACKEND=${FILE_BACKEND} build: args: - BUILD_TAGS=e2etest