mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-20 17:25:43 +03:00
feat: add database storage backend (#1091)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
12125713a2
commit
29a1d3b778
83
.github/workflows/e2e-tests.yml
vendored
83
.github/workflows/e2e-tests.yml
vendored
@@ -57,7 +57,17 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
db: [sqlite, postgres, sqlite-s3]
|
include:
|
||||||
|
- db: sqlite
|
||||||
|
storage: fs
|
||||||
|
- db: postgres
|
||||||
|
storage: fs
|
||||||
|
- db: sqlite
|
||||||
|
storage: s3
|
||||||
|
- db: sqlite
|
||||||
|
storage: database
|
||||||
|
- db: postgres
|
||||||
|
storage: database
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v5
|
||||||
@@ -71,65 +81,74 @@ jobs:
|
|||||||
node-version: 22
|
node-version: 22
|
||||||
|
|
||||||
- name: Cache Playwright Browsers
|
- name: Cache Playwright Browsers
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v4
|
||||||
id: playwright-cache
|
id: playwright-cache
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/ms-playwright
|
path: ~/.cache/ms-playwright
|
||||||
key: ${{ runner.os }}-playwright-${{ hashFiles('pnpm-lock.yaml') }}
|
key: ${{ runner.os }}-playwright-${{ hashFiles('pnpm-lock.yaml') }}
|
||||||
|
|
||||||
- name: Cache PostgreSQL Docker image
|
- name: Cache PostgreSQL Docker image
|
||||||
if: matrix.db == 'postgres'
|
uses: actions/cache@v4
|
||||||
uses: actions/cache@v3
|
|
||||||
id: postgres-cache
|
id: postgres-cache
|
||||||
with:
|
with:
|
||||||
path: /tmp/postgres-image.tar
|
path: /tmp/postgres-image.tar
|
||||||
key: postgres-17-${{ runner.os }}
|
key: postgres-17-${{ runner.os }}
|
||||||
|
|
||||||
- name: Pull and save PostgreSQL image
|
- name: Pull and save PostgreSQL image
|
||||||
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit != 'true'
|
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit != 'true'
|
||||||
run: |
|
run: |
|
||||||
docker pull postgres:17
|
docker pull postgres:17
|
||||||
docker save postgres:17 > /tmp/postgres-image.tar
|
docker save postgres:17 > /tmp/postgres-image.tar
|
||||||
|
|
||||||
- name: Load PostgreSQL image
|
- name: Load PostgreSQL image
|
||||||
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit == 'true'
|
if: matrix.db == 'postgres' && steps.postgres-cache.outputs.cache-hit == 'true'
|
||||||
run: docker load < /tmp/postgres-image.tar
|
run: docker load < /tmp/postgres-image.tar
|
||||||
|
|
||||||
- name: Cache LLDAP Docker image
|
- name: Cache LLDAP Docker image
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v4
|
||||||
id: lldap-cache
|
id: lldap-cache
|
||||||
with:
|
with:
|
||||||
path: /tmp/lldap-image.tar
|
path: /tmp/lldap-image.tar
|
||||||
key: lldap-stable-${{ runner.os }}
|
key: lldap-stable-${{ runner.os }}
|
||||||
|
|
||||||
- name: Pull and save LLDAP image
|
- name: Pull and save LLDAP image
|
||||||
if: steps.lldap-cache.outputs.cache-hit != 'true'
|
if: steps.lldap-cache.outputs.cache-hit != 'true'
|
||||||
run: |
|
run: |
|
||||||
docker pull nitnelave/lldap:stable
|
docker pull lldap/lldap:2025-05-19
|
||||||
docker save nitnelave/lldap:stable > /tmp/lldap-image.tar
|
docker save lldap/lldap:2025-05-19 > /tmp/lldap-image.tar
|
||||||
|
|
||||||
- name: Load LLDAP image
|
- name: Load LLDAP image
|
||||||
if: steps.lldap-cache.outputs.cache-hit == 'true'
|
if: steps.lldap-cache.outputs.cache-hit == 'true'
|
||||||
run: docker load < /tmp/lldap-image.tar
|
run: docker load < /tmp/lldap-image.tar
|
||||||
|
|
||||||
- name: Cache Localstack S3 Docker image
|
- name: Cache Localstack S3 Docker image
|
||||||
if: matrix.db == 'sqlite-s3'
|
if: matrix.storage == 's3'
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v4
|
||||||
id: s3-cache
|
id: s3-cache
|
||||||
with:
|
with:
|
||||||
path: /tmp/localstack-s3-image.tar
|
path: /tmp/localstack-s3-image.tar
|
||||||
key: localstack-s3-latest-${{ runner.os }}
|
key: localstack-s3-latest-${{ runner.os }}
|
||||||
|
|
||||||
- name: Pull and save Localstack S3 image
|
- name: Pull and save Localstack S3 image
|
||||||
if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit != 'true'
|
if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit != 'true'
|
||||||
run: |
|
run: |
|
||||||
docker pull localstack/localstack:s3-latest
|
docker pull localstack/localstack:s3-latest
|
||||||
docker save localstack/localstack:s3-latest > /tmp/localstack-s3-image.tar
|
docker save localstack/localstack:s3-latest > /tmp/localstack-s3-image.tar
|
||||||
|
|
||||||
- name: Load Localstack S3 image
|
- name: Load Localstack S3 image
|
||||||
if: matrix.db == 'sqlite-s3' && steps.s3-cache.outputs.cache-hit == 'true'
|
if: matrix.storage == 's3' && steps.s3-cache.outputs.cache-hit == 'true'
|
||||||
run: docker load < /tmp/localstack-s3-image.tar
|
run: docker load < /tmp/localstack-s3-image.tar
|
||||||
|
|
||||||
|
- name: Cache AWS CLI Docker image
|
||||||
|
if: matrix.storage == 's3'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
id: aws-cli-cache
|
||||||
|
with:
|
||||||
|
path: /tmp/aws-cli-image.tar
|
||||||
|
key: aws-cli-latest-${{ runner.os }}
|
||||||
|
- name: Pull and save AWS CLI image
|
||||||
|
if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit != 'true'
|
||||||
|
run: |
|
||||||
|
docker pull amazon/aws-cli:latest
|
||||||
|
docker save amazon/aws-cli:latest > /tmp/aws-cli-image.tar
|
||||||
|
- name: Load AWS CLI image
|
||||||
|
if: matrix.storage == 's3' && steps.aws-cli-cache.outputs.cache-hit == 'true'
|
||||||
|
run: docker load < /tmp/aws-cli-image.tar
|
||||||
|
|
||||||
- name: Download Docker image artifact
|
- name: Download Docker image artifact
|
||||||
uses: actions/download-artifact@v4
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
@@ -147,26 +166,20 @@ jobs:
|
|||||||
if: steps.playwright-cache.outputs.cache-hit != 'true'
|
if: steps.playwright-cache.outputs.cache-hit != 'true'
|
||||||
run: pnpm exec playwright install --with-deps chromium
|
run: pnpm exec playwright install --with-deps chromium
|
||||||
|
|
||||||
- name: Run Docker Container (sqlite) with LDAP
|
- name: Run Docker containers
|
||||||
if: matrix.db == 'sqlite'
|
|
||||||
working-directory: ./tests/setup
|
working-directory: ./tests/setup
|
||||||
run: |
|
run: |
|
||||||
docker compose up -d
|
DOCKER_COMPOSE_FILE=docker-compose.yml
|
||||||
docker compose logs -f pocket-id &> /tmp/backend.log &
|
|
||||||
|
|
||||||
- name: Run Docker Container (postgres) with LDAP
|
export FILE_BACKEND="${{ matrix.storage }}"
|
||||||
if: matrix.db == 'postgres'
|
if [ "${{ matrix.db }}" = "postgres" ]; then
|
||||||
working-directory: ./tests/setup
|
DOCKER_COMPOSE_FILE=docker-compose-postgres.yml
|
||||||
run: |
|
elif [ "${{ matrix.storage }}" = "s3" ]; then
|
||||||
docker compose -f docker-compose-postgres.yml up -d
|
DOCKER_COMPOSE_FILE=docker-compose-s3.yml
|
||||||
docker compose -f docker-compose-postgres.yml logs -f pocket-id &> /tmp/backend.log &
|
fi
|
||||||
|
|
||||||
- name: Run Docker Container (sqlite-s3) with LDAP + S3
|
docker compose -f "$DOCKER_COMPOSE_FILE" up -d
|
||||||
if: matrix.db == 'sqlite-s3'
|
docker compose -f "$DOCKER_COMPOSE_FILE" logs -f pocket-id &> /tmp/backend.log &
|
||||||
working-directory: ./tests/setup
|
|
||||||
run: |
|
|
||||||
docker compose -f docker-compose-s3.yml up -d
|
|
||||||
docker compose -f docker-compose-s3.yml logs -f pocket-id &> /tmp/backend.log &
|
|
||||||
|
|
||||||
- name: Run Playwright tests
|
- name: Run Playwright tests
|
||||||
working-directory: ./tests
|
working-directory: ./tests
|
||||||
@@ -176,7 +189,7 @@ jobs:
|
|||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
|
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
|
||||||
with:
|
with:
|
||||||
name: playwright-report-${{ matrix.db }}
|
name: playwright-report-${{ matrix.db }}-${{ matrix.storage }}
|
||||||
path: tests/.report
|
path: tests/.report
|
||||||
include-hidden-files: true
|
include-hidden-files: true
|
||||||
retention-days: 15
|
retention-days: 15
|
||||||
@@ -185,7 +198,7 @@ jobs:
|
|||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
|
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
|
||||||
with:
|
with:
|
||||||
name: backend-${{ matrix.db }}
|
name: backend-${{ matrix.db }}-${{ matrix.storage }}
|
||||||
path: /tmp/backend.log
|
path: /tmp/backend.log
|
||||||
include-hidden-files: true
|
include-hidden-files: true
|
||||||
retention-days: 15
|
retention-days: 15
|
||||||
|
|||||||
@@ -22,12 +22,20 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
slog.InfoContext(ctx, "Pocket ID is starting")
|
slog.InfoContext(ctx, "Pocket ID is starting")
|
||||||
|
|
||||||
|
// Connect to the database
|
||||||
|
db, err := NewDatabase()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the file storage backend
|
// Initialize the file storage backend
|
||||||
var fileStorage storage.FileStorage
|
var fileStorage storage.FileStorage
|
||||||
|
|
||||||
switch common.EnvConfig.FileBackend {
|
switch common.EnvConfig.FileBackend {
|
||||||
case storage.TypeFileSystem:
|
case storage.TypeFileSystem:
|
||||||
fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath)
|
fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath)
|
||||||
|
case storage.TypeDatabase:
|
||||||
|
fileStorage, err = storage.NewDatabaseStorage(db)
|
||||||
case storage.TypeS3:
|
case storage.TypeS3:
|
||||||
s3Cfg := storage.S3Config{
|
s3Cfg := storage.S3Config{
|
||||||
Bucket: common.EnvConfig.S3Bucket,
|
Bucket: common.EnvConfig.S3Bucket,
|
||||||
@@ -43,7 +51,7 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend)
|
err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize file storage: %w", err)
|
return fmt.Errorf("failed to initialize file storage (backend: %s): %w", common.EnvConfig.FileBackend, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
imageExtensions, err := initApplicationImages(ctx, fileStorage)
|
imageExtensions, err := initApplicationImages(ctx, fileStorage)
|
||||||
@@ -51,12 +59,6 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
return fmt.Errorf("failed to initialize application images: %w", err)
|
return fmt.Errorf("failed to initialize application images: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the database
|
|
||||||
db, err := NewDatabase()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create all services
|
// Create all services
|
||||||
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
|
svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
|
|||||||
|
|
||||||
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
|
||||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
|
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService, svc.customClaimService, svc.appImagesService, fileStorage)
|
||||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService, fileStorage)
|
||||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||||
|
|
||||||
svc.versionService = service.NewVersionService(httpClient)
|
svc.versionService = service.NewVersionService(httpClient)
|
||||||
|
|||||||
@@ -180,12 +180,14 @@ func validateEnvConfig(config *EnvConfigSchema) error {
|
|||||||
if config.KeysStorage == "file" {
|
if config.KeysStorage == "file" {
|
||||||
return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
|
return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
|
||||||
}
|
}
|
||||||
|
case "database":
|
||||||
|
// All good, these are valid values
|
||||||
case "", "fs":
|
case "", "fs":
|
||||||
if config.UploadPath == "" {
|
if config.UploadPath == "" {
|
||||||
config.UploadPath = defaultFsUploadPath
|
config.UploadPath = defaultFsUploadPath
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return errors.New("invalid FILE_BACKEND value. Must be 'fs' or 's3'")
|
return errors.New("invalid FILE_BACKEND value. Must be 'fs', 'database', or 's3'")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate LOCAL_IPV6_RANGES
|
// Validate LOCAL_IPV6_RANGES
|
||||||
|
|||||||
@@ -587,7 +587,6 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
c.Status(http.StatusNoContent)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteClientLogoHandler godoc
|
// deleteClientLogoHandler godoc
|
||||||
@@ -614,7 +613,6 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
c.Status(http.StatusNoContent)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateAllowedUserGroupsHandler godoc
|
// updateAllowedUserGroupsHandler godoc
|
||||||
|
|||||||
17
backend/internal/model/storage.go
Normal file
17
backend/internal/model/storage.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Storage struct {
|
||||||
|
Path string `gorm:"primaryKey"`
|
||||||
|
Data []byte
|
||||||
|
Size int64
|
||||||
|
ModTime datatype.DateTime
|
||||||
|
CreatedAt datatype.DateTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Storage) TableName() string {
|
||||||
|
return "storage"
|
||||||
|
}
|
||||||
@@ -426,7 +426,8 @@ func (s *TestService) ResetDatabase() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
||||||
if err := s.fileStorage.DeleteAll(ctx, "/"); err != nil {
|
err := s.fileStorage.DeleteAll(ctx, "/")
|
||||||
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Error removing uploads", slog.Any("error", err))
|
slog.ErrorContext(ctx, "Error removing uploads", slog.Any("error", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -445,7 +446,8 @@ func (s *TestService) ResetApplicationImages(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile); err != nil {
|
err = s.fileStorage.Save(ctx, path.Join("application-images", file.Name()), srcFile)
|
||||||
|
if err != nil {
|
||||||
srcFile.Close()
|
srcFile.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,12 +11,14 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/go-ldap/ldap/v3"
|
"github.com/go-ldap/ldap/v3"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
"golang.org/x/text/unicode/norm"
|
"golang.org/x/text/unicode/norm"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -32,15 +34,23 @@ type LdapService struct {
|
|||||||
appConfigService *AppConfigService
|
appConfigService *AppConfigService
|
||||||
userService *UserService
|
userService *UserService
|
||||||
groupService *UserGroupService
|
groupService *UserGroupService
|
||||||
|
fileStorage storage.FileStorage
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService {
|
type savePicture struct {
|
||||||
|
userID string
|
||||||
|
username string
|
||||||
|
picture string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
|
||||||
return &LdapService{
|
return &LdapService{
|
||||||
db: db,
|
db: db,
|
||||||
httpClient: httpClient,
|
httpClient: httpClient,
|
||||||
appConfigService: appConfigService,
|
appConfigService: appConfigService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
groupService: groupService,
|
groupService: groupService,
|
||||||
|
fileStorage: fileStorage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,12 +78,6 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *LdapService) SyncAll(ctx context.Context) error {
|
func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||||
// Start a transaction
|
|
||||||
tx := s.db.Begin()
|
|
||||||
defer func() {
|
|
||||||
tx.Rollback()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Setup LDAP connection
|
// Setup LDAP connection
|
||||||
client, err := s.createClient()
|
client, err := s.createClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -81,7 +85,13 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
err = s.SyncUsers(ctx, tx, client)
|
// Start a transaction
|
||||||
|
tx := s.db.Begin()
|
||||||
|
defer func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
|
savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to sync users: %w", err)
|
return fmt.Errorf("failed to sync users: %w", err)
|
||||||
}
|
}
|
||||||
@@ -97,6 +107,25 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
|||||||
return fmt.Errorf("failed to commit changes to database: %w", err)
|
return fmt.Errorf("failed to commit changes to database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now that we've committed the transaction, we can perform operations on the storage layer
|
||||||
|
// First, save all new pictures
|
||||||
|
for _, sp := range savePictures {
|
||||||
|
err = s.saveProfilePicture(ctx, sp.userID, sp.picture)
|
||||||
|
if err != nil {
|
||||||
|
// This is not a fatal error
|
||||||
|
slog.Warn("Error saving profile picture for LDAP user", slog.String("username", sp.username), slog.Any("error", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete all old files
|
||||||
|
for _, path := range deleteFiles {
|
||||||
|
err = s.fileStorage.Delete(ctx, path)
|
||||||
|
if err != nil {
|
||||||
|
// This is not a fatal error
|
||||||
|
slog.Error("Failed to delete file after LDAP sync", slog.String("path", path), slog.Any("error", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,7 +295,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit
|
//nolint:gocognit
|
||||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
|
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) {
|
||||||
dbConfig := s.appConfigService.GetDbConfig()
|
dbConfig := s.appConfigService.GetDbConfig()
|
||||||
|
|
||||||
searchAttrs := []string{
|
searchAttrs := []string{
|
||||||
@@ -294,11 +323,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
|
|
||||||
result, err := client.Search(searchReq)
|
result, err := client.Search(searchReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to query LDAP: %w", err)
|
return nil, nil, fmt.Errorf("failed to query LDAP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a mapping for users that exist
|
// Create a mapping for users that exist
|
||||||
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
||||||
|
savePictures = make([]savePicture, 0, len(result.Entries))
|
||||||
|
|
||||||
for _, value := range result.Entries {
|
for _, value := range result.Entries {
|
||||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||||
@@ -329,13 +359,13 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
Error
|
Error
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||||
return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user is admin by checking if they are in the admin group
|
// Check if user is admin by checking if they are in the admin group
|
||||||
@@ -369,32 +399,35 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userID := databaseUser.ID
|
||||||
if databaseUser.ID == "" {
|
if databaseUser.ID == "" {
|
||||||
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||||
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
||||||
}
|
}
|
||||||
|
userID = createdUser.ID
|
||||||
} else {
|
} else {
|
||||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
||||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||||
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save profile picture
|
// Save profile picture
|
||||||
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
|
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
|
||||||
if pictureString != "" {
|
if pictureString != "" {
|
||||||
err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
|
// Storage operations must be executed outside of a transaction
|
||||||
if err != nil {
|
savePictures = append(savePictures, savePicture{
|
||||||
// This is not a fatal error
|
userID: databaseUser.ID,
|
||||||
slog.Warn("Error saving profile picture for user", slog.String("username", newUser.Username), slog.Any("error", err))
|
username: userID,
|
||||||
}
|
picture: pictureString,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -406,10 +439,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
Select("id, username, ldap_id, disabled").
|
Select("id, username, ldap_id, disabled").
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to fetch users from database: %w", err)
|
return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark users as disabled or delete users that no longer exist in LDAP
|
// Mark users as disabled or delete users that no longer exist in LDAP
|
||||||
|
deleteFiles = make([]string, 0, len(ldapUserIDs))
|
||||||
for _, user := range ldapUsersInDb {
|
for _, user := range ldapUsersInDb {
|
||||||
// Skip if the user ID exists in the fetched LDAP results
|
// Skip if the user ID exists in the fetched LDAP results
|
||||||
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
||||||
@@ -417,26 +451,30 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
if dbConfig.LdapSoftDeleteUsers.IsTrue() {
|
if dbConfig.LdapSoftDeleteUsers.IsTrue() {
|
||||||
err = s.userService.disableUserInternal(ctx, user.ID, tx)
|
err = s.userService.disableUserInternal(ctx, tx, user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to disable user %s: %w", user.Username, err)
|
return nil, nil, fmt.Errorf("failed to disable user %s: %w", user.Username, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Disabled user", slog.String("username", user.Username))
|
slog.Info("Disabled user", slog.String("username", user.Username))
|
||||||
} else {
|
} else {
|
||||||
err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
|
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
||||||
|
if err != nil {
|
||||||
target := &common.LdapUserUpdateError{}
|
target := &common.LdapUserUpdateError{}
|
||||||
if errors.As(err, &target) {
|
if errors.As(err, &target) {
|
||||||
return fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
||||||
} else if err != nil {
|
}
|
||||||
return fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Deleted user", slog.String("username", user.Username))
|
slog.Info("Deleted user", slog.String("username", user.Username))
|
||||||
|
|
||||||
|
// Storage operations must be executed outside of a transaction
|
||||||
|
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return savePictures, deleteFiles, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
@@ -679,19 +678,21 @@ func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
|
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
|
||||||
return s.getClientInternal(ctx, clientID, s.db)
|
return s.getClientInternal(ctx, clientID, s.db, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
|
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) {
|
||||||
var client model.OidcClient
|
var client model.OidcClient
|
||||||
err := tx.
|
q := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Preload("CreatedBy").
|
Preload("CreatedBy").
|
||||||
Preload("AllowedUserGroups").
|
Preload("AllowedUserGroups")
|
||||||
First(&client, "id = ?", clientID).
|
if forUpdate {
|
||||||
Error
|
q = q.Clauses(clause.Locking{Strength: "UPDATE"})
|
||||||
if err != nil {
|
}
|
||||||
return model.OidcClient{}, err
|
q = q.First(&client, "id = ?", clientID)
|
||||||
|
if q.Error != nil {
|
||||||
|
return model.OidcClient{}, q.Error
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
@@ -724,11 +725,6 @@ func (s *OidcService) ListClients(ctx context.Context, name string, listRequestO
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||||
tx := s.db.Begin()
|
|
||||||
defer func() {
|
|
||||||
tx.Rollback()
|
|
||||||
}()
|
|
||||||
|
|
||||||
client := model.OidcClient{
|
client := model.OidcClient{
|
||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: input.ID,
|
ID: input.ID,
|
||||||
@@ -737,7 +733,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
|||||||
}
|
}
|
||||||
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
|
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
|
||||||
|
|
||||||
err := tx.
|
err := s.db.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Create(&client).
|
Create(&client).
|
||||||
Error
|
Error
|
||||||
@@ -748,62 +744,65 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
|||||||
return model.OidcClient{}, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// All storage operations must be executed outside of a transaction
|
||||||
if input.LogoURL != nil {
|
if input.LogoURL != nil {
|
||||||
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
|
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.DarkLogoURL != nil {
|
if input.DarkLogoURL != nil {
|
||||||
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
|
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit().Error
|
|
||||||
if err != nil {
|
|
||||||
return model.OidcClient{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() { tx.Rollback() }()
|
defer func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
var client model.OidcClient
|
var client model.OidcClient
|
||||||
if err := tx.WithContext(ctx).
|
err := tx.WithContext(ctx).
|
||||||
Preload("CreatedBy").
|
Preload("CreatedBy").
|
||||||
First(&client, "id = ?", clientID).Error; err != nil {
|
First(&client, "id = ?", clientID).Error
|
||||||
|
if err != nil {
|
||||||
return model.OidcClient{}, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
updateOIDCClientModelFromDto(&client, &input)
|
updateOIDCClientModelFromDto(&client, &input)
|
||||||
|
|
||||||
if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
|
err = tx.WithContext(ctx).Save(&client).Error
|
||||||
|
if err != nil {
|
||||||
return model.OidcClient{}, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = tx.Commit().Error
|
||||||
|
if err != nil {
|
||||||
|
return model.OidcClient{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// All storage operations must be executed outside of a transaction
|
||||||
if input.LogoURL != nil {
|
if input.LogoURL != nil {
|
||||||
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
|
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.DarkLogoURL != nil {
|
if input.DarkLogoURL != nil {
|
||||||
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
|
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit().Error; err != nil {
|
|
||||||
return model.OidcClient{}, err
|
|
||||||
}
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -836,12 +835,24 @@ func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
|
|||||||
err := s.db.
|
err := s.db.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Where("id = ?", clientID).
|
Where("id = ?", clientID).
|
||||||
|
Clauses(clause.Returning{}).
|
||||||
Delete(&client).
|
Delete(&client).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete images if present
|
||||||
|
// Note that storage operations must be done outside of a transaction
|
||||||
|
if client.ImageType != nil && *client.ImageType != "" {
|
||||||
|
old := path.Join("oidc-client-images", client.ID+"."+*client.ImageType)
|
||||||
|
_ = s.fileStorage.Delete(ctx, old)
|
||||||
|
}
|
||||||
|
if client.DarkImageType != nil && *client.DarkImageType != "" {
|
||||||
|
old := path.Join("oidc-client-images", client.ID+"-dark."+*client.DarkImageType)
|
||||||
|
_ = s.fileStorage.Delete(ctx, old)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -941,57 +952,12 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer reader.Close()
|
defer reader.Close()
|
||||||
if err := s.fileStorage.Save(ctx, imagePath, reader); err != nil {
|
err = s.fileStorage.Save(ctx, imagePath, reader)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tx := s.db.Begin()
|
|
||||||
|
|
||||||
err = s.updateClientLogoType(ctx, tx, clientID, fileType, light)
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Commit().Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
|
||||||
tx := s.db.Begin()
|
|
||||||
defer func() {
|
|
||||||
tx.Rollback()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var client model.OidcClient
|
|
||||||
err := tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
First(&client, "id = ?", clientID).
|
|
||||||
Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if client.ImageType == nil {
|
err = s.updateClientLogoType(ctx, clientID, fileType, light)
|
||||||
return errors.New("image not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
oldImageType := *client.ImageType
|
|
||||||
client.ImageType = nil
|
|
||||||
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Save(&client).
|
|
||||||
Error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
imagePath := path.Join("oidc-client-images", client.ID+"."+oldImageType)
|
|
||||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit().Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -999,7 +965,31 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||||
|
return s.deleteClientLogoInternal(ctx, clientID, "", func(client *model.OidcClient) (string, error) {
|
||||||
|
if client.ImageType == nil {
|
||||||
|
return "", errors.New("image not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldImageType := *client.ImageType
|
||||||
|
client.ImageType = nil
|
||||||
|
return oldImageType, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) error {
|
func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) error {
|
||||||
|
return s.deleteClientLogoInternal(ctx, clientID, "-dark", func(client *model.OidcClient) (string, error) {
|
||||||
|
if client.DarkImageType == nil {
|
||||||
|
return "", errors.New("image not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldImageType := *client.DarkImageType
|
||||||
|
client.DarkImageType = nil
|
||||||
|
return oldImageType, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) deleteClientLogoInternal(ctx context.Context, clientID string, imagePathSuffix string, setClientImage func(*model.OidcClient) (string, error)) error {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
defer func() {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
@@ -1014,13 +1004,11 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if client.DarkImageType == nil {
|
oldImageType, err := setClientImage(&client)
|
||||||
return errors.New("image not found")
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldImageType := *client.DarkImageType
|
|
||||||
client.DarkImageType = nil
|
|
||||||
|
|
||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Save(&client).
|
Save(&client).
|
||||||
@@ -1029,12 +1017,14 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
imagePath := path.Join("oidc-client-images", client.ID+"-dark."+oldImageType)
|
err = tx.Commit().Error
|
||||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit().Error
|
// All storage operations must be performed outside of a database transaction
|
||||||
|
imagePath := path.Join("oidc-client-images", client.ID+imagePathSuffix+"."+oldImageType)
|
||||||
|
err = s.fileStorage.Delete(ctx, imagePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1048,7 +1038,7 @@ func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, in
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err = s.getClientInternal(ctx, id, tx)
|
client, err = s.getClientInternal(ctx, id, tx, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.OidcClient{}, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
@@ -1831,7 +1821,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := s.getClientInternal(ctx, clientID, tx)
|
client, err := s.getClientInternal(ctx, clientID, tx, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1976,7 +1966,25 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
|
|||||||
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string, light bool) error {
|
var errLogoTooLarge = errors.New("logo is too large")
|
||||||
|
|
||||||
|
func httpClientWithCheckRedirect(source *http.Client, checkRedirect func(req *http.Request, via []*http.Request) error) *http.Client {
|
||||||
|
if source == nil {
|
||||||
|
source = http.DefaultClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new client that clones the transport
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: source.Transport,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign the CheckRedirect function
|
||||||
|
client.CheckRedirect = checkRedirect
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, clientID string, raw string, light bool) error {
|
||||||
u, err := url.Parse(raw)
|
u, err := url.Parse(raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1985,19 +1993,30 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
|||||||
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
r := net.Resolver{}
|
// Prevents SSRF by allowing only public IPs
|
||||||
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
ok, err := utils.IsURLPrivate(ctx, u)
|
||||||
if err != nil || len(ips) == 0 {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot resolve hostname")
|
return err
|
||||||
|
} else if ok {
|
||||||
|
return errors.New("private IP addresses are not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prevents SSRF by allowing only public IPs
|
// We need to check this on redirects too
|
||||||
for _, addr := range ips {
|
client := httpClientWithCheckRedirect(s.httpClient, func(r *http.Request, via []*http.Request) error {
|
||||||
if utils.IsPrivateIP(addr.IP) {
|
if len(via) >= 10 {
|
||||||
return fmt.Errorf("private IP addresses are not allowed")
|
return errors.New("stopped after 10 redirects")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ok, err := utils.IsURLPrivate(r.Context(), r.URL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if ok {
|
||||||
|
return errors.New("private IP addresses are not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -2005,7 +2024,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
|||||||
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
|
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
|
||||||
req.Header.Set("Accept", "image/*")
|
req.Header.Set("Accept", "image/*")
|
||||||
|
|
||||||
resp, err := s.httpClient.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -2017,7 +2036,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
|||||||
|
|
||||||
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
|
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
|
||||||
if resp.ContentLength > maxLogoSize {
|
if resp.ContentLength > maxLogoSize {
|
||||||
return fmt.Errorf("logo is too large")
|
return errLogoTooLarge
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefer extension in path if supported
|
// Prefer extension in path if supported
|
||||||
@@ -2037,48 +2056,70 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
|||||||
}
|
}
|
||||||
|
|
||||||
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext)
|
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext)
|
||||||
if err := s.fileStorage.Save(ctx, imagePath, io.LimitReader(resp.Body, maxLogoSize+1)); err != nil {
|
err = s.fileStorage.Save(ctx, imagePath, utils.NewLimitReader(resp.Body, maxLogoSize+1))
|
||||||
|
if errors.Is(err, utils.ErrSizeExceeded) {
|
||||||
|
return errLogoTooLarge
|
||||||
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.updateClientLogoType(ctx, tx, clientID, ext, light); err != nil {
|
err = s.updateClientLogoType(ctx, clientID, ext, light)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string, light bool) error {
|
func (s *OidcService) updateClientLogoType(ctx context.Context, clientID string, ext string, light bool) error {
|
||||||
var darkSuffix string
|
var darkSuffix string
|
||||||
if !light {
|
if !light {
|
||||||
darkSuffix = "-dark"
|
darkSuffix = "-dark"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tx := s.db.Begin()
|
||||||
|
defer func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// We need to acquire an update lock for the row to be locked, since we'll update it later
|
||||||
var client model.OidcClient
|
var client model.OidcClient
|
||||||
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
|
err := tx.
|
||||||
return err
|
WithContext(ctx).
|
||||||
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||||
|
First(&client, "id = ?", clientID).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to look up client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var currentType *string
|
var currentType *string
|
||||||
if light {
|
if light {
|
||||||
currentType = client.ImageType
|
currentType = client.ImageType
|
||||||
|
client.ImageType = &ext
|
||||||
} else {
|
} else {
|
||||||
currentType = client.DarkImageType
|
currentType = client.DarkImageType
|
||||||
|
client.DarkImageType = &ext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = tx.
|
||||||
|
WithContext(ctx).
|
||||||
|
Save(&client).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to save updated client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit().Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Storage operations must be executed outside of a transaction
|
||||||
if currentType != nil && *currentType != ext {
|
if currentType != nil && *currentType != ext {
|
||||||
old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType)
|
old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType)
|
||||||
_ = s.fileStorage.Delete(ctx, old)
|
_ = s.fileStorage.Delete(ctx, old)
|
||||||
}
|
}
|
||||||
|
|
||||||
var column string
|
return nil
|
||||||
if light {
|
|
||||||
column = "image_type"
|
|
||||||
} else {
|
|
||||||
column = "dark_image_type"
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.WithContext(ctx).
|
|
||||||
Model(&model.OidcClient{}).
|
|
||||||
Where("id = ?", clientID).
|
|
||||||
Update(column, ext).
|
|
||||||
Error
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,6 +24,7 @@ import (
|
|||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -537,3 +541,435 @@ func TestValidateCodeVerifier_Plain(t *testing.T) {
|
|||||||
require.False(t, validateCodeVerifier("NOT!VALID", codeChallenge, true))
|
require.False(t, validateCodeVerifier("NOT!VALID", codeChallenge, true))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOidcService_updateClientLogoType(t *testing.T) {
|
||||||
|
// Create a test database
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
|
// Create database storage
|
||||||
|
dbStorage, err := storage.NewDatabaseStorage(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Init the OidcService
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := model.OidcClient{
|
||||||
|
Name: "Test Client",
|
||||||
|
CallbackURLs: model.UrlList{"https://example.com/callback"},
|
||||||
|
}
|
||||||
|
err = db.Create(&client).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Helper function to check if a file exists in storage
|
||||||
|
fileExists := func(t *testing.T, path string) bool {
|
||||||
|
t.Helper()
|
||||||
|
_, _, err := dbStorage.Open(t.Context(), path)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create a dummy file in storage
|
||||||
|
createDummyFile := func(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
err := dbStorage.Save(t.Context(), path, strings.NewReader("dummy content"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Updates light logo type for client without previous logo", func(t *testing.T) {
|
||||||
|
// Update the logo type
|
||||||
|
err := s.updateClientLogoType(t.Context(), client.ID, "png", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the client was updated
|
||||||
|
var updatedClient model.OidcClient
|
||||||
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedClient.ImageType)
|
||||||
|
assert.Equal(t, "png", *updatedClient.ImageType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updates dark logo type for client without previous dark logo", func(t *testing.T) {
|
||||||
|
// Update the dark logo type
|
||||||
|
err := s.updateClientLogoType(t.Context(), client.ID, "jpg", false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the client was updated
|
||||||
|
var updatedClient model.OidcClient
|
||||||
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedClient.DarkImageType)
|
||||||
|
assert.Equal(t, "jpg", *updatedClient.DarkImageType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updates light logo type and deletes old file when type changes", func(t *testing.T) {
|
||||||
|
// Create the old PNG file in storage
|
||||||
|
oldPath := "oidc-client-images/" + client.ID + ".png"
|
||||||
|
createDummyFile(t, oldPath)
|
||||||
|
require.True(t, fileExists(t, oldPath), "Old file should exist before update")
|
||||||
|
|
||||||
|
// Client currently has a PNG logo, update to WEBP
|
||||||
|
err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the client was updated
|
||||||
|
var updatedClient model.OidcClient
|
||||||
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedClient.ImageType)
|
||||||
|
assert.Equal(t, "webp", *updatedClient.ImageType)
|
||||||
|
|
||||||
|
// Old PNG file should be deleted
|
||||||
|
assert.False(t, fileExists(t, oldPath), "Old PNG file should have been deleted")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Updates dark logo type and deletes old file when type changes", func(t *testing.T) {
|
||||||
|
// Create the old JPG dark file in storage
|
||||||
|
oldPath := "oidc-client-images/" + client.ID + "-dark.jpg"
|
||||||
|
createDummyFile(t, oldPath)
|
||||||
|
require.True(t, fileExists(t, oldPath), "Old dark file should exist before update")
|
||||||
|
|
||||||
|
// Client currently has a JPG dark logo, update to WEBP
|
||||||
|
err := s.updateClientLogoType(t.Context(), client.ID, "webp", false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the client was updated
|
||||||
|
var updatedClient model.OidcClient
|
||||||
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedClient.DarkImageType)
|
||||||
|
assert.Equal(t, "webp", *updatedClient.DarkImageType)
|
||||||
|
|
||||||
|
// Old JPG dark file should be deleted
|
||||||
|
assert.False(t, fileExists(t, oldPath), "Old JPG dark file should have been deleted")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Does not delete file when type remains the same", func(t *testing.T) {
|
||||||
|
// Create the WEBP file in storage
|
||||||
|
webpPath := "oidc-client-images/" + client.ID + ".webp"
|
||||||
|
createDummyFile(t, webpPath)
|
||||||
|
require.True(t, fileExists(t, webpPath), "WEBP file should exist before update")
|
||||||
|
|
||||||
|
// Update to the same type (WEBP)
|
||||||
|
err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the client still has WEBP
|
||||||
|
var updatedClient model.OidcClient
|
||||||
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedClient.ImageType)
|
||||||
|
assert.Equal(t, "webp", *updatedClient.ImageType)
|
||||||
|
|
||||||
|
// WEBP file should still exist since type didn't change
|
||||||
|
assert.True(t, fileExists(t, webpPath), "WEBP file should still exist")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns error for non-existent client", func(t *testing.T) {
|
||||||
|
err := s.updateClientLogoType(t.Context(), "non-existent-client-id", "png", true)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "failed to look up client")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) {
|
||||||
|
// Create a test database
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
|
// Create database storage
|
||||||
|
dbStorage, err := storage.NewDatabaseStorage(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := model.OidcClient{
|
||||||
|
Name: "Test Client",
|
||||||
|
CallbackURLs: model.UrlList{"https://example.com/callback"},
|
||||||
|
}
|
||||||
|
err = db.Create(&client).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Helper function to check if a file exists in storage
|
||||||
|
fileExists := func(t *testing.T, path string) bool {
|
||||||
|
t.Helper()
|
||||||
|
_, _, err := dbStorage.Open(t.Context(), path)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to get file content from storage
|
||||||
|
getFileContent := func(t *testing.T, path string) []byte {
|
||||||
|
t.Helper()
|
||||||
|
reader, _, err := dbStorage.Open(t.Context(), path)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer reader.Close()
|
||||||
|
content, err := io.ReadAll(reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Successfully downloads and saves PNG logo from URL", func(t *testing.T) {
|
||||||
|
// Create mock PNG content
|
||||||
|
pngContent := []byte("fake-png-content")
|
||||||
|
|
||||||
|
// Create a mock HTTP response with headers
|
||||||
|
//nolint:bodyclose
|
||||||
|
pngResponse := testutils.NewMockResponse(http.StatusOK, string(pngContent))
|
||||||
|
pngResponse.Header.Set("Content-Type", "image/png")
|
||||||
|
|
||||||
|
// Create a mock HTTP client with responses
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/logo.png": pngResponse,
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init the OidcService with mock HTTP client
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download and save the logo
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo.png", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the file was saved
|
||||||
|
logoPath := "oidc-client-images/" + client.ID + ".png"
|
||||||
|
require.True(t, fileExists(t, logoPath), "Logo file should exist in storage")
|
||||||
|
|
||||||
|
// Verify the content
|
||||||
|
savedContent := getFileContent(t, logoPath)
|
||||||
|
assert.Equal(t, pngContent, savedContent)
|
||||||
|
|
||||||
|
// Verify the client was updated
|
||||||
|
var updatedClient model.OidcClient
|
||||||
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedClient.ImageType)
|
||||||
|
assert.Equal(t, "png", *updatedClient.ImageType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Successfully downloads and saves dark logo", func(t *testing.T) {
|
||||||
|
// Create mock WEBP content
|
||||||
|
webpContent := []byte("fake-webp-content")
|
||||||
|
|
||||||
|
//nolint:bodyclose
|
||||||
|
webpResponse := testutils.NewMockResponse(http.StatusOK, string(webpContent))
|
||||||
|
webpResponse.Header.Set("Content-Type", "image/webp")
|
||||||
|
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/dark-logo.webp": webpResponse,
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download and save the dark logo
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/dark-logo.webp", false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the dark logo file was saved
|
||||||
|
darkLogoPath := "oidc-client-images/" + client.ID + "-dark.webp"
|
||||||
|
require.True(t, fileExists(t, darkLogoPath), "Dark logo file should exist in storage")
|
||||||
|
|
||||||
|
// Verify the content
|
||||||
|
savedContent := getFileContent(t, darkLogoPath)
|
||||||
|
assert.Equal(t, webpContent, savedContent)
|
||||||
|
|
||||||
|
// Verify the client was updated
|
||||||
|
var updatedClient model.OidcClient
|
||||||
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, updatedClient.DarkImageType)
|
||||||
|
assert.Equal(t, "webp", *updatedClient.DarkImageType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Detects extension from URL path", func(t *testing.T) {
|
||||||
|
svgContent := []byte("<svg></svg>")
|
||||||
|
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/icon.svg": testutils.NewMockResponse(http.StatusOK, string(svgContent)),
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/icon.svg", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify SVG file was saved
|
||||||
|
logoPath := "oidc-client-images/" + client.ID + ".svg"
|
||||||
|
require.True(t, fileExists(t, logoPath), "SVG logo should exist")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Detects extension from Content-Type when path has no extension", func(t *testing.T) {
|
||||||
|
jpgContent := []byte("fake-jpg-content")
|
||||||
|
|
||||||
|
//nolint:bodyclose
|
||||||
|
jpgResponse := testutils.NewMockResponse(http.StatusOK, string(jpgContent))
|
||||||
|
jpgResponse.Header.Set("Content-Type", "image/jpeg")
|
||||||
|
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/logo": jpgResponse,
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify JPG file was saved (jpeg extension is normalized to jpg)
|
||||||
|
logoPath := "oidc-client-images/" + client.ID + ".jpg"
|
||||||
|
require.True(t, fileExists(t, logoPath), "JPG logo should exist")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns error for invalid URL", func(t *testing.T) {
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: &http.Client{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "://invalid-url", true)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns error for non-200 status code", func(t *testing.T) {
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/not-found.png": testutils.NewMockResponse(http.StatusNotFound, "Not Found"),
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/not-found.png", true)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "failed to fetch logo")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns error for too large content", func(t *testing.T) {
|
||||||
|
// Create content larger than 2MB (maxLogoSize)
|
||||||
|
largeContent := strings.Repeat("x", 2<<20+100) // 2.1MB
|
||||||
|
|
||||||
|
//nolint:bodyclose
|
||||||
|
largeResponse := testutils.NewMockResponse(http.StatusOK, largeContent)
|
||||||
|
largeResponse.Header.Set("Content-Type", "image/png")
|
||||||
|
largeResponse.Header.Set("Content-Length", strconv.Itoa(len(largeContent)))
|
||||||
|
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/large.png": largeResponse,
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/large.png", true)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, errLogoTooLarge)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns error for unsupported file type", func(t *testing.T) {
|
||||||
|
//nolint:bodyclose
|
||||||
|
textResponse := testutils.NewMockResponse(http.StatusOK, "text content")
|
||||||
|
textResponse.Header.Set("Content-Type", "text/plain")
|
||||||
|
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/file.txt": textResponse,
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/file.txt", true)
|
||||||
|
require.Error(t, err)
|
||||||
|
var fileTypeErr *common.FileTypeNotSupportedError
|
||||||
|
require.ErrorAs(t, err, &fileTypeErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns error for non-existent client", func(t *testing.T) {
|
||||||
|
//nolint:bodyclose
|
||||||
|
pngResponse := testutils.NewMockResponse(http.StatusOK, "content")
|
||||||
|
pngResponse.Header.Set("Content-Type", "image/png")
|
||||||
|
|
||||||
|
mockResponses := map[string]*http.Response{
|
||||||
|
//nolint:bodyclose
|
||||||
|
"https://example.com/logo.png": pngResponse,
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Transport: &testutils.MockRoundTripper{
|
||||||
|
Responses: mockResponses,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &OidcService{
|
||||||
|
db: db,
|
||||||
|
fileStorage: dbStorage,
|
||||||
|
httpClient: httpClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.downloadAndSaveLogoFromURL(t.Context(), "non-existent-client-id", "https://example.com/logo.png", true)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "failed to look up client")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
@@ -101,9 +102,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
|||||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||||
|
|
||||||
// Try custom profile picture
|
// Try custom profile picture
|
||||||
if file, size, err := s.fileStorage.Open(ctx, profilePicturePath); err == nil {
|
file, size, err := s.fileStorage.Open(ctx, profilePicturePath)
|
||||||
|
if err == nil {
|
||||||
return file, size, nil
|
return file, size, nil
|
||||||
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,9 +122,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
|||||||
|
|
||||||
// Try cached default for initials
|
// Try cached default for initials
|
||||||
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
|
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
|
||||||
if file, size, err := s.fileStorage.Open(ctx, defaultPicturePath); err == nil {
|
file, size, err = s.fileStorage.Open(ctx, defaultPicturePath)
|
||||||
|
if err == nil {
|
||||||
return file, size, nil
|
return file, size, nil
|
||||||
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
} else if !errors.Is(err, fs.ErrNotExist) {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,12 +136,13 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Save the default picture for future use (in a goroutine to avoid blocking)
|
// Save the default picture for future use (in a goroutine to avoid blocking)
|
||||||
//nolint:contextcheck
|
|
||||||
defaultPictureBytes := defaultPicture.Bytes()
|
defaultPictureBytes := defaultPicture.Bytes()
|
||||||
//nolint:contextcheck
|
//nolint:contextcheck
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)); err != nil {
|
// Use bytes.NewReader because we need an io.ReadSeeker
|
||||||
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", err))
|
rErr := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes))
|
||||||
|
if rErr != nil {
|
||||||
|
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", rErr))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -182,17 +186,30 @@ func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, f
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
|
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
|
||||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
|
return s.deleteUserInternal(ctx, tx, userID, allowLdapDelete)
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete user '%s': %w", userID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Storage operations must be executed outside of a transaction
|
||||||
|
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
||||||
|
err = s.fileStorage.Delete(ctx, profilePicturePath)
|
||||||
|
if err != nil && !storage.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to delete profile picture for user '%s': %w", userID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
|
func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userID string, allowLdapDelete bool) error {
|
||||||
var user model.User
|
var user model.User
|
||||||
|
|
||||||
err := tx.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Where("id = ?", userID).
|
Where("id = ?", userID).
|
||||||
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||||
First(&user).
|
First(&user).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -204,11 +221,6 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
|
|||||||
return &common.LdapUserUpdateError{}
|
return &common.LdapUserUpdateError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
profilePicturePath := path.Join("profile-pictures", userID+".png")
|
|
||||||
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.WithContext(ctx).Delete(&user).Error
|
err = tx.WithContext(ctx).Delete(&user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete user: %w", err)
|
return fmt.Errorf("failed to delete user: %w", err)
|
||||||
@@ -286,16 +298,27 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
|
|||||||
|
|
||||||
// Apply default user groups
|
// Apply default user groups
|
||||||
var groupIDs []string
|
var groupIDs []string
|
||||||
if v := config.SignupDefaultUserGroupIDs.Value; v != "" && v != "[]" {
|
v := config.SignupDefaultUserGroupIDs.Value
|
||||||
if err := json.Unmarshal([]byte(v), &groupIDs); err != nil {
|
if v != "" && v != "[]" {
|
||||||
|
err := json.Unmarshal([]byte(v), &groupIDs)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("invalid SignupDefaultUserGroupIDs JSON: %w", err)
|
return fmt.Errorf("invalid SignupDefaultUserGroupIDs JSON: %w", err)
|
||||||
}
|
}
|
||||||
if len(groupIDs) > 0 {
|
if len(groupIDs) > 0 {
|
||||||
var groups []model.UserGroup
|
var groups []model.UserGroup
|
||||||
if err := tx.WithContext(ctx).Where("id IN ?", groupIDs).Find(&groups).Error; err != nil {
|
err = tx.WithContext(ctx).
|
||||||
|
Where("id IN ?", groupIDs).
|
||||||
|
Find(&groups).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("failed to find default user groups: %w", err)
|
return fmt.Errorf("failed to find default user groups: %w", err)
|
||||||
}
|
}
|
||||||
if err := tx.WithContext(ctx).Model(user).Association("UserGroups").Replace(groups); err != nil {
|
|
||||||
|
err = tx.WithContext(ctx).
|
||||||
|
Model(user).
|
||||||
|
Association("UserGroups").
|
||||||
|
Replace(groups)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("failed to associate default user groups: %w", err)
|
return fmt.Errorf("failed to associate default user groups: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -303,12 +326,15 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
|
|||||||
|
|
||||||
// Apply default custom claims
|
// Apply default custom claims
|
||||||
var claims []dto.CustomClaimCreateDto
|
var claims []dto.CustomClaimCreateDto
|
||||||
if v := config.SignupDefaultCustomClaims.Value; v != "" && v != "[]" {
|
v = config.SignupDefaultCustomClaims.Value
|
||||||
if err := json.Unmarshal([]byte(v), &claims); err != nil {
|
if v != "" && v != "[]" {
|
||||||
|
err := json.Unmarshal([]byte(v), &claims)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("invalid SignupDefaultCustomClaims JSON: %w", err)
|
return fmt.Errorf("invalid SignupDefaultCustomClaims JSON: %w", err)
|
||||||
}
|
}
|
||||||
if len(claims) > 0 {
|
if len(claims) > 0 {
|
||||||
if _, err := s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx); err != nil {
|
_, err = s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("failed to apply default custom claims: %w", err)
|
return fmt.Errorf("failed to apply default custom claims: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -345,6 +371,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
|||||||
err := tx.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Where("id = ?", userID).
|
Where("id = ?", userID).
|
||||||
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||||
First(&user).
|
First(&user).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -416,14 +443,12 @@ func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context
|
|||||||
|
|
||||||
var userId string
|
var userId string
|
||||||
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
|
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
|
||||||
if err != nil {
|
|
||||||
// Do not return error if user not found to prevent email enumeration
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
// Do not return error if user not found to prevent email enumeration
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
|
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
|
||||||
}
|
}
|
||||||
@@ -513,7 +538,9 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
|
|||||||
var oneTimeAccessToken model.OneTimeAccessToken
|
var oneTimeAccessToken model.OneTimeAccessToken
|
||||||
err := tx.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
|
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).
|
||||||
|
Preload("User").
|
||||||
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||||
First(&oneTimeAccessToken).
|
First(&oneTimeAccessToken).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -679,7 +706,7 @@ func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx *gorm.DB) error {
|
func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error {
|
||||||
return tx.
|
return tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Model(&model.User{}).
|
Model(&model.User{}).
|
||||||
@@ -720,6 +747,7 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
|
|||||||
err := tx.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Where("token = ?", signupData.Token).
|
Where("token = ?", signupData.Token).
|
||||||
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||||
First(&signupToken).
|
First(&signupToken).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func (s *VersionService) GetLatestVersion(ctx context.Context) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if payload.TagName == "" {
|
if payload.TagName == "" {
|
||||||
return "", fmt.Errorf("GitHub API returned empty tag name")
|
return "", errors.New("GitHub API returned empty tag name")
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.TrimPrefix(payload.TagName, "v"), nil
|
return strings.TrimPrefix(payload.TagName, "v"), nil
|
||||||
|
|||||||
226
backend/internal/storage/database.go
Normal file
226
backend/internal/storage/database.go
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
var TypeDatabase = "database"
|
||||||
|
|
||||||
|
type databaseStorage struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabaseStorage creates a new database storage provider
|
||||||
|
func NewDatabaseStorage(db *gorm.DB) (FileStorage, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, errors.New("database connection is required")
|
||||||
|
}
|
||||||
|
return &databaseStorage{db: db}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *databaseStorage) Type() string {
|
||||||
|
return TypeDatabase
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *databaseStorage) Save(ctx context.Context, relativePath string, data io.Reader) error {
|
||||||
|
// Normalize the path
|
||||||
|
relativePath = filepath.ToSlash(filepath.Clean(relativePath))
|
||||||
|
|
||||||
|
// Read all data into memory
|
||||||
|
b, err := io.ReadAll(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := datatype.DateTime(time.Now())
|
||||||
|
storage := model.Storage{
|
||||||
|
Path: relativePath,
|
||||||
|
Data: b,
|
||||||
|
Size: int64(len(b)),
|
||||||
|
ModTime: now,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use upsert: insert or update on conflict
|
||||||
|
result := s.db.
|
||||||
|
WithContext(ctx).
|
||||||
|
Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "path"}},
|
||||||
|
DoUpdates: clause.AssignmentColumns([]string{"data", "size", "mod_time"}),
|
||||||
|
}).
|
||||||
|
Create(&storage)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("failed to save file to database: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *databaseStorage) Open(ctx context.Context, relativePath string) (io.ReadCloser, int64, error) {
|
||||||
|
relativePath = filepath.ToSlash(filepath.Clean(relativePath))
|
||||||
|
|
||||||
|
var storage model.Storage
|
||||||
|
result := s.db.
|
||||||
|
WithContext(ctx).
|
||||||
|
Where("path = ?", relativePath).
|
||||||
|
First(&storage)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, 0, os.ErrNotExist
|
||||||
|
}
|
||||||
|
return nil, 0, fmt.Errorf("failed to read file from database: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := io.NopCloser(bytes.NewReader(storage.Data))
|
||||||
|
return reader, storage.Size, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *databaseStorage) Delete(ctx context.Context, relativePath string) error {
|
||||||
|
relativePath = filepath.ToSlash(filepath.Clean(relativePath))
|
||||||
|
|
||||||
|
result := s.db.
|
||||||
|
WithContext(ctx).
|
||||||
|
Where("path = ?", relativePath).
|
||||||
|
Delete(&model.Storage{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("failed to delete file from database: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *databaseStorage) DeleteAll(ctx context.Context, prefix string) error {
|
||||||
|
prefix = filepath.ToSlash(filepath.Clean(prefix))
|
||||||
|
|
||||||
|
// If empty prefix, delete all
|
||||||
|
if isRootPath(prefix) {
|
||||||
|
result := s.db.
|
||||||
|
WithContext(ctx).
|
||||||
|
Where("1 = 1"). // Delete everything
|
||||||
|
Delete(&model.Storage{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("failed to delete all files from database: %w", result.Error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure prefix ends with / for proper prefix matching
|
||||||
|
if !strings.HasSuffix(prefix, "/") {
|
||||||
|
prefix += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
query := s.db.WithContext(ctx)
|
||||||
|
query = addPathPrefixClause(s.db.Name(), query, prefix)
|
||||||
|
result := query.Delete(&model.Storage{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("failed to delete files with prefix '%s' from database: %w", prefix, result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *databaseStorage) List(ctx context.Context, prefix string) ([]ObjectInfo, error) {
|
||||||
|
prefix = filepath.ToSlash(filepath.Clean(prefix))
|
||||||
|
|
||||||
|
var storageItems []model.Storage
|
||||||
|
query := s.db.WithContext(ctx)
|
||||||
|
|
||||||
|
if !isRootPath(prefix) {
|
||||||
|
// Ensure prefix matching
|
||||||
|
if !strings.HasSuffix(prefix, "/") {
|
||||||
|
prefix += "/"
|
||||||
|
}
|
||||||
|
query = addPathPrefixClause(s.db.Name(), query, prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := query.
|
||||||
|
Select("path", "size", "mod_time").
|
||||||
|
Find(&storageItems)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list files from database: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
objects := make([]ObjectInfo, 0, len(storageItems))
|
||||||
|
for _, item := range storageItems {
|
||||||
|
// Filter out directory-like paths (those that contain additional slashes after the prefix)
|
||||||
|
relativePath := strings.TrimPrefix(item.Path, prefix)
|
||||||
|
if strings.ContainsRune(relativePath, '/') {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
objects = append(objects, ObjectInfo{
|
||||||
|
Path: item.Path,
|
||||||
|
Size: item.Size,
|
||||||
|
ModTime: time.Time(item.ModTime),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return objects, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *databaseStorage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
|
||||||
|
root = filepath.ToSlash(filepath.Clean(root))
|
||||||
|
|
||||||
|
var storageItems []model.Storage
|
||||||
|
query := s.db.WithContext(ctx)
|
||||||
|
|
||||||
|
if !isRootPath(root) {
|
||||||
|
// Ensure root matching
|
||||||
|
if !strings.HasSuffix(root, "/") {
|
||||||
|
root += "/"
|
||||||
|
}
|
||||||
|
query = addPathPrefixClause(s.db.Name(), query, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := query.
|
||||||
|
Select("path", "size", "mod_time").
|
||||||
|
Find(&storageItems)
|
||||||
|
if result.Error != nil {
|
||||||
|
return fmt.Errorf("failed to walk files from database: %w", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range storageItems {
|
||||||
|
err := fn(ObjectInfo{
|
||||||
|
Path: item.Path,
|
||||||
|
Size: item.Size,
|
||||||
|
ModTime: time.Time(item.ModTime),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRootPath(path string) bool {
|
||||||
|
return path == "" || path == "/" || path == "."
|
||||||
|
}
|
||||||
|
|
||||||
|
func addPathPrefixClause(dialect string, query *gorm.DB, prefix string) *gorm.DB {
|
||||||
|
// In SQLite, we use "GLOB" which can use the index
|
||||||
|
switch dialect {
|
||||||
|
case "sqlite":
|
||||||
|
return query.Where("path GLOB ?", prefix+"*")
|
||||||
|
case "postgres":
|
||||||
|
return query.Where("path LIKE ?", prefix+"%")
|
||||||
|
default:
|
||||||
|
// Indicates a development-time error
|
||||||
|
panic(fmt.Errorf("unsupported database dialect: %s", dialect))
|
||||||
|
}
|
||||||
|
}
|
||||||
148
backend/internal/storage/database_test.go
Normal file
148
backend/internal/storage/database_test.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
testingutil "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDatabaseStorageOperations(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
db := testingutil.NewDatabaseForTest(t)
|
||||||
|
store, err := NewDatabaseStorage(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("type should be database", func(t *testing.T) {
|
||||||
|
assert.Equal(t, TypeDatabase, store.Type())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("save, open and list files", func(t *testing.T) {
|
||||||
|
err := store.Save(ctx, "images/logo.png", bytes.NewBufferString("logo-data"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
reader, size, err := store.Open(ctx, "images/logo.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
contents, err := io.ReadAll(reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("logo-data"), contents)
|
||||||
|
assert.Equal(t, int64(len(contents)), size)
|
||||||
|
|
||||||
|
err = store.Save(ctx, "images/nested/child.txt", bytes.NewBufferString("child"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
files, err := store.List(ctx, "images")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, files, 1)
|
||||||
|
assert.Equal(t, "images/logo.png", files[0].Path)
|
||||||
|
assert.Equal(t, int64(len("logo-data")), files[0].Size)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("save should update existing file", func(t *testing.T) {
|
||||||
|
err := store.Save(ctx, "test/update.txt", bytes.NewBufferString("original"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = store.Save(ctx, "test/update.txt", bytes.NewBufferString("updated"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
reader, size, err := store.Open(ctx, "test/update.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
contents, err := io.ReadAll(reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("updated"), contents)
|
||||||
|
assert.Equal(t, int64(len("updated")), size)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete files individually", func(t *testing.T) {
|
||||||
|
err := store.Save(ctx, "images/delete-me.txt", bytes.NewBufferString("temp"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, store.Delete(ctx, "images/delete-me.txt"))
|
||||||
|
_, _, err = store.Open(ctx, "images/delete-me.txt")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, IsNotExist(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete missing file should not error", func(t *testing.T) {
|
||||||
|
require.NoError(t, store.Delete(ctx, "images/missing.txt"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete all files", func(t *testing.T) {
|
||||||
|
require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a")))
|
||||||
|
require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b")))
|
||||||
|
require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c")))
|
||||||
|
require.NoError(t, store.DeleteAll(ctx, "/"))
|
||||||
|
|
||||||
|
_, _, err := store.Open(ctx, "cleanup/a.txt")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, IsNotExist(err))
|
||||||
|
|
||||||
|
_, _, err = store.Open(ctx, "cleanup/b.txt")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, IsNotExist(err))
|
||||||
|
|
||||||
|
_, _, err = store.Open(ctx, "cleanup/nested/c.txt")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, IsNotExist(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("delete all files under a prefix", func(t *testing.T) {
|
||||||
|
require.NoError(t, store.Save(ctx, "cleanup/a.txt", bytes.NewBufferString("a")))
|
||||||
|
require.NoError(t, store.Save(ctx, "cleanup/b.txt", bytes.NewBufferString("b")))
|
||||||
|
require.NoError(t, store.Save(ctx, "cleanup/nested/c.txt", bytes.NewBufferString("c")))
|
||||||
|
require.NoError(t, store.DeleteAll(ctx, "cleanup"))
|
||||||
|
|
||||||
|
_, _, err := store.Open(ctx, "cleanup/a.txt")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, IsNotExist(err))
|
||||||
|
|
||||||
|
_, _, err = store.Open(ctx, "cleanup/b.txt")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, IsNotExist(err))
|
||||||
|
|
||||||
|
_, _, err = store.Open(ctx, "cleanup/nested/c.txt")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, IsNotExist(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("walk files", func(t *testing.T) {
|
||||||
|
require.NoError(t, store.Save(ctx, "walk/file1.txt", bytes.NewBufferString("1")))
|
||||||
|
require.NoError(t, store.Save(ctx, "walk/file2.txt", bytes.NewBufferString("2")))
|
||||||
|
require.NoError(t, store.Save(ctx, "walk/nested/file3.txt", bytes.NewBufferString("3")))
|
||||||
|
|
||||||
|
var paths []string
|
||||||
|
err := store.Walk(ctx, "walk", func(info ObjectInfo) error {
|
||||||
|
paths = append(paths, info.Path)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, paths, 3)
|
||||||
|
assert.Contains(t, paths, "walk/file1.txt")
|
||||||
|
assert.Contains(t, paths, "walk/file2.txt")
|
||||||
|
assert.Contains(t, paths, "walk/nested/file3.txt")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDatabaseStorage(t *testing.T) {
|
||||||
|
t.Run("should return error with nil database", func(t *testing.T) {
|
||||||
|
_, err := NewDatabaseStorage(nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "database connection is required")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should create storage with valid database", func(t *testing.T) {
|
||||||
|
db := testingutil.NewDatabaseForTest(t)
|
||||||
|
store, err := NewDatabaseStorage(db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, store)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
@@ -56,6 +59,23 @@ func IsPrivateIP(ip net.IP) bool {
|
|||||||
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
|
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsURLPrivate(ctx context.Context, u *url.URL) (bool, error) {
|
||||||
|
var r net.Resolver
|
||||||
|
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
return false, errors.New("cannot resolve hostname")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevents SSRF by allowing only public IPs
|
||||||
|
for _, addr := range ips {
|
||||||
|
if IsPrivateIP(addr.IP) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
|
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
|
||||||
for _, ipNet := range ipNets {
|
for _, ipNet := range ipNets {
|
||||||
if ipNet.Contains(ip) {
|
if ipNet.Contains(ip) {
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
)
|
)
|
||||||
@@ -20,9 +26,8 @@ func TestIsLocalhostIP(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ip := net.ParseIP(tt.ip)
|
ip := net.ParseIP(tt.ip)
|
||||||
if got := IsLocalhostIP(ip); got != tt.expected {
|
got := IsLocalhostIP(ip)
|
||||||
t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
assert.Equal(t, tt.expected, got)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,9 +45,8 @@ func TestIsPrivateLanIP(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ip := net.ParseIP(tt.ip)
|
ip := net.ParseIP(tt.ip)
|
||||||
if got := IsPrivateLanIP(ip); got != tt.expected {
|
got := IsPrivateLanIP(ip)
|
||||||
t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
assert.Equal(t, tt.expected, got)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,9 +63,9 @@ func TestIsTailscaleIP(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ip := net.ParseIP(tt.ip)
|
ip := net.ParseIP(tt.ip)
|
||||||
if got := IsTailscaleIP(ip); got != tt.expected {
|
|
||||||
t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
got := IsTailscaleIP(ip)
|
||||||
}
|
assert.Equal(t, tt.expected, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,16 +90,17 @@ func TestIsLocalIPv6(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ip := net.ParseIP(tt.ip)
|
ip := net.ParseIP(tt.ip)
|
||||||
if got := IsLocalIPv6(ip); got != tt.expected {
|
got := IsLocalIPv6(ip)
|
||||||
t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected)
|
assert.Equal(t, tt.expected, got)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsPrivateIP(t *testing.T) {
|
func TestIsPrivateIP(t *testing.T) {
|
||||||
// Save and restore env config
|
// Save and restore env config
|
||||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
t.Cleanup(func() {
|
||||||
|
common.EnvConfig.LocalIPv6Ranges = origRanges
|
||||||
|
})
|
||||||
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
|
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
|
||||||
localIPv6Ranges = nil // reset
|
localIPv6Ranges = nil // reset
|
||||||
@@ -115,9 +120,8 @@ func TestIsPrivateIP(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ip := net.ParseIP(tt.ip)
|
ip := net.ParseIP(tt.ip)
|
||||||
if got := IsPrivateIP(ip); got != tt.expected {
|
got := IsPrivateIP(ip)
|
||||||
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
assert.Equal(t, tt.expected, got)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,22 +142,202 @@ func TestListContainsIP(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
ip := net.ParseIP(tt.ip)
|
ip := net.ParseIP(tt.ip)
|
||||||
if got := listContainsIP(list, ip); got != tt.expected {
|
got := listContainsIP(list, ip)
|
||||||
t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
assert.Equal(t, tt.expected, got)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInit_LocalIPv6Ranges(t *testing.T) {
|
func TestInit_LocalIPv6Ranges(t *testing.T) {
|
||||||
// Save and restore env config
|
// Save and restore env config
|
||||||
origRanges := common.EnvConfig.LocalIPv6Ranges
|
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||||
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
t.Cleanup(func() {
|
||||||
|
common.EnvConfig.LocalIPv6Ranges = origRanges
|
||||||
|
})
|
||||||
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
|
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
|
||||||
localIPv6Ranges = nil
|
localIPv6Ranges = nil
|
||||||
loadLocalIPv6Ranges()
|
loadLocalIPv6Ranges()
|
||||||
|
|
||||||
if len(localIPv6Ranges) != 2 {
|
assert.Len(t, localIPv6Ranges, 2)
|
||||||
t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges))
|
}
|
||||||
|
|
||||||
|
func TestIsURLPrivate(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
urlStr string
|
||||||
|
expectPriv bool
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "localhost by name",
|
||||||
|
urlStr: "http://localhost",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localhost with port",
|
||||||
|
urlStr: "http://localhost:8080",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "127.0.0.1 IP",
|
||||||
|
urlStr: "http://127.0.0.1",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "127.0.0.1 with port",
|
||||||
|
urlStr: "http://127.0.0.1:3000",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback",
|
||||||
|
urlStr: "http://[::1]",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback with port",
|
||||||
|
urlStr: "http://[::1]:8080",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "private IP 10.x.x.x",
|
||||||
|
urlStr: "http://10.0.0.1",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "private IP 192.168.x.x",
|
||||||
|
urlStr: "http://192.168.1.1",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "private IP 172.16.x.x",
|
||||||
|
urlStr: "http://172.16.0.1",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Tailscale IP",
|
||||||
|
urlStr: "http://100.64.0.1",
|
||||||
|
expectPriv: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "public IP - Google DNS",
|
||||||
|
urlStr: "http://8.8.8.8",
|
||||||
|
expectPriv: false,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "public IP - Cloudflare DNS",
|
||||||
|
urlStr: "http://1.1.1.1",
|
||||||
|
expectPriv: false,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid hostname",
|
||||||
|
urlStr: "http://this-should-not-resolve-ever-123456789.invalid",
|
||||||
|
expectPriv: false,
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
u, err := url.Parse(tt.urlStr)
|
||||||
|
require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
|
||||||
|
|
||||||
|
isPriv, err := IsURLPrivate(ctx, u)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err, "IsURLPrivate(%s) expected error but got none", tt.urlStr)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err, "IsURLPrivate(%s) unexpected error", tt.urlStr)
|
||||||
|
assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsURLPrivate_WithDomainName(t *testing.T) {
|
||||||
|
// Note: These tests rely on actual DNS resolution
|
||||||
|
// They test real public domains to ensure they are not flagged as private
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
urlStr string
|
||||||
|
expectPriv bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Google public domain",
|
||||||
|
urlStr: "https://www.google.com",
|
||||||
|
expectPriv: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GitHub public domain",
|
||||||
|
urlStr: "https://github.com",
|
||||||
|
expectPriv: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// localhost.localtest.me is a well-known domain that resolves to 127.0.0.1
|
||||||
|
name: "localhost.localtest.me resolves to 127.0.0.1",
|
||||||
|
urlStr: "http://localhost.localtest.me",
|
||||||
|
expectPriv: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// 10.0.0.1.nip.io resolves to 10.0.0.1 (private IP)
|
||||||
|
name: "nip.io domain resolving to private 10.x IP",
|
||||||
|
urlStr: "http://10.0.0.1.nip.io",
|
||||||
|
expectPriv: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// 192.168.1.1.nip.io resolves to 192.168.1.1 (private IP)
|
||||||
|
name: "nip.io domain resolving to private 192.168.x IP",
|
||||||
|
urlStr: "http://192.168.1.1.nip.io",
|
||||||
|
expectPriv: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// 127.0.0.1.nip.io resolves to 127.0.0.1 (localhost)
|
||||||
|
name: "nip.io domain resolving to localhost",
|
||||||
|
urlStr: "http://127.0.0.1.nip.io",
|
||||||
|
expectPriv: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
u, err := url.Parse(tt.urlStr)
|
||||||
|
require.NoError(t, err, "Failed to parse URL %s", tt.urlStr)
|
||||||
|
|
||||||
|
isPriv, err := IsURLPrivate(ctx, u)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("DNS resolution failed for %s (network issue?): %v", tt.urlStr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectPriv, isPriv, "IsURLPrivate(%s)", tt.urlStr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsURLPrivate_ContextCancellation(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
u, err := url.Parse("http://example.com")
|
||||||
|
require.NoError(t, err, "Failed to parse URL")
|
||||||
|
|
||||||
|
_, err = IsURLPrivate(ctx, u)
|
||||||
|
assert.Error(t, err, "IsURLPrivate with cancelled context expected error but got none")
|
||||||
|
}
|
||||||
|
|||||||
34
backend/internal/utils/stream_util.go
Normal file
34
backend/internal/utils/stream_util.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrSizeExceeded = errors.New("stream size exceeded")
|
||||||
|
|
||||||
|
// LimitReader is like io.LimitReader but throws an error if the stream exceeds the max size
|
||||||
|
// io.LimitReader instead just returns io.EOF
|
||||||
|
// Adapted from https://github.com/golang/go/issues/51115#issuecomment-1079761212
|
||||||
|
type LimitReader struct {
|
||||||
|
io.ReadCloser
|
||||||
|
N int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimitReader(r io.ReadCloser, limit int64) *LimitReader {
|
||||||
|
return &LimitReader{r, limit}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *LimitReader) Read(p []byte) (n int, err error) {
|
||||||
|
if r.N <= 0 {
|
||||||
|
return 0, ErrSizeExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(len(p)) > r.N {
|
||||||
|
p = p[0:r.N]
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = r.ReadCloser.Read(p)
|
||||||
|
r.N -= int64(n)
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP TABLE storage;
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
-- The "storage" table contains file data stored in the database
|
||||||
|
CREATE TABLE storage
|
||||||
|
(
|
||||||
|
path TEXT NOT NULL PRIMARY KEY,
|
||||||
|
data BYTEA NOT NULL,
|
||||||
|
size BIGINT NOT NULL,
|
||||||
|
mod_time TIMESTAMPTZ NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP INDEX idx_api_keys_expires_at;
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at);
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
PRAGMA foreign_keys=OFF;
|
||||||
|
BEGIN;
|
||||||
|
DROP TABLE storage;
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
PRAGMA foreign_keys=ON;
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
PRAGMA foreign_keys=OFF;
|
||||||
|
BEGIN;
|
||||||
|
-- The "storage" table contains file data stored in the database
|
||||||
|
CREATE TABLE storage
|
||||||
|
(
|
||||||
|
path TEXT NOT NULL PRIMARY KEY,
|
||||||
|
data BLOB NOT NULL,
|
||||||
|
size INTEGER NOT NULL,
|
||||||
|
mod_time DATETIME NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
PRAGMA foreign_keys=ON;
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
PRAGMA foreign_keys=OFF;
|
||||||
|
BEGIN;
|
||||||
|
DROP INDEX idx_api_keys_expires_at;
|
||||||
|
COMMIT;
|
||||||
|
PRAGMA foreign_keys=ON;
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
PRAGMA foreign_keys=OFF;
|
||||||
|
BEGIN;
|
||||||
|
CREATE INDEX idx_api_keys_expires_at ON api_keys(expires_at);
|
||||||
|
COMMIT;
|
||||||
|
PRAGMA foreign_keys=ON;
|
||||||
@@ -20,8 +20,10 @@ services:
|
|||||||
file: docker-compose.yml
|
file: docker-compose.yml
|
||||||
service: pocket-id
|
service: pocket-id
|
||||||
environment:
|
environment:
|
||||||
|
- APP_ENV=test
|
||||||
- DB_PROVIDER=postgres
|
- DB_PROVIDER=postgres
|
||||||
- DB_CONNECTION_STRING=postgres://postgres:postgres@postgres:5432/pocket-id
|
- DB_CONNECTION_STRING=postgres://postgres:postgres@postgres:5432/pocket-id
|
||||||
|
- FILE_BACKEND=${FILE_BACKEND}
|
||||||
depends_on:
|
depends_on:
|
||||||
postgres:
|
postgres:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ services:
|
|||||||
retries: 10
|
retries: 10
|
||||||
|
|
||||||
create-bucket:
|
create-bucket:
|
||||||
image: amazon/aws-cli
|
image: amazon/aws-cli:latest
|
||||||
environment:
|
environment:
|
||||||
AWS_ACCESS_KEY_ID: test
|
AWS_ACCESS_KEY_ID: test
|
||||||
AWS_SECRET_ACCESS_KEY: test
|
AWS_SECRET_ACCESS_KEY: test
|
||||||
@@ -28,15 +28,14 @@ services:
|
|||||||
file: docker-compose.yml
|
file: docker-compose.yml
|
||||||
service: pocket-id
|
service: pocket-id
|
||||||
environment:
|
environment:
|
||||||
FILE_BACKEND: s3
|
- S3_BUCKET=pocket-id-test
|
||||||
S3_BUCKET: pocket-id-test
|
- S3_REGION=us-east-1
|
||||||
S3_REGION: us-east-1
|
- S3_ENDPOINT=http://localstack-s3:4566
|
||||||
S3_ENDPOINT: http://localstack-s3:4566
|
- S3_ACCESS_KEY_ID=test
|
||||||
S3_ACCESS_KEY_ID: test
|
- S3_SECRET_ACCESS_KEY=test
|
||||||
S3_SECRET_ACCESS_KEY: test
|
- S3_FORCE_PATH_STYLE=true
|
||||||
S3_FORCE_PATH_STYLE: true
|
- KEYS_STORAGE=database
|
||||||
KEYS_STORAGE: database
|
- ENCRYPTION_KEY=test1234test1234test1234test1234
|
||||||
ENCRYPTION_KEY: test
|
|
||||||
depends_on:
|
depends_on:
|
||||||
create-bucket:
|
create-bucket:
|
||||||
condition: service_completed_successfully
|
condition: service_completed_successfully
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ services:
|
|||||||
- '1411:1411'
|
- '1411:1411'
|
||||||
environment:
|
environment:
|
||||||
- APP_ENV=test
|
- APP_ENV=test
|
||||||
|
- FILE_BACKEND=${FILE_BACKEND}
|
||||||
build:
|
build:
|
||||||
args:
|
args:
|
||||||
- BUILD_TAGS=e2etest
|
- BUILD_TAGS=e2etest
|
||||||
|
|||||||
Reference in New Issue
Block a user