Mirror of https://github.com/pocket-id/pocket-id.git

Compare commits: 19 commits (feat/user-... → v2/use-ite...)
Commits (SHA1):
19c793b81a, d8f060e64c, 8336e31aa6, 002da54c72, 8c1f1c8340, 176bac8123, c4bd20a90d,
900f8fe240, 1da694f8ad, 96aa2ce043, e06538a101, ddff3a2975, a738d9fe88, 39c1f93756,
e306d6eb58, 46793fe68a, e22822890f, 2694d79add, 98cf1f66c3
.github/workflows/e2e-tests.yml (vendored): 17 lines changed
@@ -171,7 +171,7 @@ jobs:
      run: |
        DOCKER_COMPOSE_FILE=docker-compose.yml

        export FILE_BACKEND="${{ matrix.storage }}"
        echo "FILE_BACKEND=${{ matrix.storage }}" > .env
        if [ "${{ matrix.db }}" = "postgres" ]; then
          DOCKER_COMPOSE_FILE=docker-compose-postgres.yml
        elif [ "${{ matrix.storage }}" = "s3" ]; then
@@ -179,7 +179,20 @@ jobs:
        fi

        docker compose -f "$DOCKER_COMPOSE_FILE" up -d
        docker compose -f "$DOCKER_COMPOSE_FILE" logs -f pocket-id &> /tmp/backend.log &

        {
          LOG_FILE="/tmp/backend.log"
          while true; do
            CID=$(docker compose -f "$DOCKER_COMPOSE_FILE" ps -q pocket-id)
            if [ -n "$CID" ]; then
              echo "[$(date)] Attaching logs for $CID" >> "$LOG_FILE"
              docker logs -f --since=0 "$CID" >> "$LOG_FILE" 2>&1
            else
              echo "[$(date)] Container not yet running…" >> "$LOG_FILE"
            fi
            sleep 1
          done
        } &

    - name: Run Playwright tests
      working-directory: ./tests
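A single backgrounded `docker compose logs -f` only follows one container lifetime, so a container that restarts (or is not yet created when the command starts) presumably drops out of the log capture; the new loop re-resolves the container ID every second and re-attaches. A rough Go equivalent of the same idea, shelling out to the Docker CLI with the log path and service name taken from the workflow above (everything else is illustrative):

package main

import (
	"os"
	"os/exec"
	"strings"
	"time"
)

func main() {
	// Append to the same log file the workflow uses.
	logFile, err := os.OpenFile("/tmp/backend.log", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
	if err != nil {
		panic(err)
	}
	defer logFile.Close()

	for {
		// Re-resolve the container ID on every iteration.
		out, _ := exec.Command("docker", "compose", "ps", "-q", "pocket-id").Output()
		if cid := strings.TrimSpace(string(out)); cid != "" {
			// "docker logs -f" blocks until the container stops; looping re-attaches after restarts.
			follow := exec.Command("docker", "logs", "-f", "--since=0", cid)
			follow.Stdout = logFile
			follow.Stderr = logFile
			_ = follow.Run()
		}
		time.Sleep(time.Second)
	}
}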
.gitignore (vendored): 1 line changed
@@ -15,6 +15,7 @@ node_modules
/backend/bin
pocket-id
/tests/test-results/*.json
.tmp/

# OS
.DS_Store
@@ -4,7 +4,7 @@ Pocket ID is a simple OIDC provider that allows users to authenticate with their

→ Try out the [Demo](https://demo.pocket-id.org)

<img src="https://github.com/user-attachments/assets/1e99ba44-76da-4b47-9b8a-dbe9b7f84512" width="1200"/>
<img src="https://github.com/user-attachments/assets/96ac549d-b897-404a-8811-f42b16ea58e2" width="1200"/>

The goal of Pocket ID is to be simple and easy to use. There are other self-hosted OIDC providers, like [Keycloak](https://www.keycloak.org/) or [ORY Hydra](https://www.ory.sh/hydra/), but they are often too complex for simple use cases.
@@ -1,9 +1,12 @@
package main

import (
	"fmt"
	"os"
	_ "time/tzdata"

	"github.com/pocket-id/pocket-id/backend/internal/cmds"
	"github.com/pocket-id/pocket-id/backend/internal/common"
)

// @title Pocket ID API
@@ -11,5 +14,9 @@ import (
// @description.markdown

func main() {
	if err := common.ValidateEnvConfig(&common.EnvConfig); err != nil {
		fmt.Fprintf(os.Stderr, "config error: %v\n", err)
		os.Exit(1)
	}
	cmds.Execute()
}
@@ -14,7 +14,6 @@ require (
	github.com/disintegration/imaging v1.6.2
	github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6
	github.com/emersion/go-smtp v0.24.0
	github.com/fxamacker/cbor/v2 v2.9.0
	github.com/gin-contrib/slog v1.2.0
	github.com/gin-gonic/gin v1.11.0
	github.com/glebarez/go-sqlite v1.22.0
@@ -84,6 +83,7 @@ require (
	github.com/disintegration/gift v1.2.1 // indirect
	github.com/dustin/go-humanize v1.0.1 // indirect
	github.com/felixge/httpsnoop v1.0.4 // indirect
	github.com/fxamacker/cbor/v2 v2.9.0 // indirect
	github.com/gabriel-vasile/mimetype v1.4.11 // indirect
	github.com/gin-contrib/sse v1.1.0 // indirect
	github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
@@ -24,8 +24,7 @@ func initApplicationImages(ctx context.Context, fileStorage storage.FileStorage)
	// Previous versions of images
	// If these are found, they are deleted
	legacyImageHashes := imageHashMap{
		"background.jpg":  mustDecodeHex("138d510030ed845d1d74de34658acabff562d306476454369a60ab8ade31933f"),
		"background.webp": mustDecodeHex("3fc436a66d6b872b01d96a4e75046c46b5c3e2daccd51e98ecdf98fd445599ab"),
		"background.jpg":  mustDecodeHex("138d510030ed845d1d74de34658acabff562d306476454369a60ab8ade31933f"),
	}

	sourceFiles, err := resources.FS.ReadDir("images")
@@ -7,6 +7,7 @@ import (
	"time"

	_ "github.com/golang-migrate/migrate/v4/source/file"
	"gorm.io/gorm"

	"github.com/pocket-id/pocket-id/backend/internal/common"
	"github.com/pocket-id/pocket-id/backend/internal/job"
@@ -15,6 +16,16 @@ import (
)

func Bootstrap(ctx context.Context) error {
	var shutdownFns []utils.Service
	defer func() { //nolint:contextcheck
		// Invoke all shutdown functions on exit
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := utils.NewServiceRunner(shutdownFns...).Run(shutdownCtx); err != nil {
			slog.Error("Error during graceful shutdown", "error", err)
		}
	}()

	// Initialize the observability stack, including the logger, distributed tracing, and metrics
	shutdownFns, httpClient, err := initObservability(ctx, common.EnvConfig.MetricsEnabled, common.EnvConfig.TracingEnabled)
	if err != nil {
@@ -22,15 +33,80 @@ func Bootstrap(ctx context.Context) error {
	}
	slog.InfoContext(ctx, "Pocket ID is starting")

	// Connect to the database
	db, err := NewDatabase()
	if err != nil {
		return fmt.Errorf("failed to initialize database: %w", err)
	}

	// Initialize the file storage backend
	var fileStorage storage.FileStorage
	fileStorage, err := InitStorage(ctx, db)
	if err != nil {
		return fmt.Errorf("failed to initialize file storage (backend: %s): %w", common.EnvConfig.FileBackend, err)
	}

	imageExtensions, err := initApplicationImages(ctx, fileStorage)
	if err != nil {
		return fmt.Errorf("failed to initialize application images: %w", err)
	}

	// Create all services
	svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
	if err != nil {
		return fmt.Errorf("failed to initialize services: %w", err)
	}

	waitUntil, err := svc.appLockService.Acquire(ctx, false)
	if err != nil {
		return fmt.Errorf("failed to acquire application lock: %w", err)
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(time.Until(waitUntil)):
	}

	shutdownFn := func(shutdownCtx context.Context) error {
		sErr := svc.appLockService.Release(shutdownCtx)
		if sErr != nil {
			return fmt.Errorf("failed to release application lock: %w", sErr)
		}
		return nil
	}
	shutdownFns = append(shutdownFns, shutdownFn)

	// Init the job scheduler
	scheduler, err := job.NewScheduler()
	if err != nil {
		return fmt.Errorf("failed to create job scheduler: %w", err)
	}
	err = registerScheduledJobs(ctx, db, svc, httpClient, scheduler)
	if err != nil {
		return fmt.Errorf("failed to register scheduled jobs: %w", err)
	}

	// Init the router
	router, err := initRouter(db, svc)
	if err != nil {
		return fmt.Errorf("failed to initialize router: %w", err)
	}

	// Run all background services
	// This call blocks until the context is canceled
	services := []utils.Service{svc.appLockService.RunRenewal, router}

	if common.EnvConfig.AppEnv != "test" {
		services = append(services, scheduler.Run)
	}

	err = utils.NewServiceRunner(services...).Run(ctx)
	if err != nil {
		return fmt.Errorf("failed to run services: %w", err)
	}

	return nil
}

func InitStorage(ctx context.Context, db *gorm.DB) (fileStorage storage.FileStorage, err error) {
	switch common.EnvConfig.FileBackend {
	case storage.TypeFileSystem:
		fileStorage, err = storage.NewFilesystemStorage(common.EnvConfig.UploadPath)
@@ -52,53 +128,8 @@ func Bootstrap(ctx context.Context) error {
		err = fmt.Errorf("unknown file storage backend: %s", common.EnvConfig.FileBackend)
	}
	if err != nil {
		return fmt.Errorf("failed to initialize file storage (backend: %s): %w", common.EnvConfig.FileBackend, err)
		return fileStorage, err
	}

	imageExtensions, err := initApplicationImages(ctx, fileStorage)
	if err != nil {
		return fmt.Errorf("failed to initialize application images: %w", err)
	}

	// Create all services
	svc, err := initServices(ctx, db, httpClient, imageExtensions, fileStorage)
	if err != nil {
		return fmt.Errorf("failed to initialize services: %w", err)
	}

	// Init the job scheduler
	scheduler, err := job.NewScheduler()
	if err != nil {
		return fmt.Errorf("failed to create job scheduler: %w", err)
	}
	err = registerScheduledJobs(ctx, db, svc, httpClient, scheduler)
	if err != nil {
		return fmt.Errorf("failed to register scheduled jobs: %w", err)
	}

	// Init the router
	router := initRouter(db, svc)

	// Run all background services
	// This call blocks until the context is canceled
	err = utils.
		NewServiceRunner(router, scheduler.Run).
		Run(ctx)
	if err != nil {
		return fmt.Errorf("failed to run services: %w", err)
	}

	// Invoke all shutdown functions
	// We give these a timeout of 5s
	// Note: we use a background context because the run context has been canceled already
	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer shutdownCancel()
	err = utils.
		NewServiceRunner(shutdownFns...).
		Run(shutdownCtx) //nolint:contextcheck
	if err != nil {
		slog.Error("Error shutting down services", slog.Any("error", err))
	}

	return nil
	return fileStorage, nil
}
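The structural change in Bootstrap is that shutdown work is now collected in shutdownFns and invoked by a deferred closure with a fresh 5-second background context, so early error returns also release resources such as the app lock; the old version only ran the shutdown functions after the service runner exited. A stripped-down sketch of the pattern, with illustrative names standing in for the utils package:

package main

import (
	"context"
	"fmt"
	"time"
)

// Service mirrors the shape suggested by utils.Service in the diff:
// a function that runs until its context is done or an error occurs.
type Service func(ctx context.Context) error

func bootstrap() error {
	var shutdownFns []Service
	defer func() {
		// Runs on every exit path, including early error returns; a fresh
		// background context is used because the run context may already be canceled.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		for _, fn := range shutdownFns {
			if err := fn(shutdownCtx); err != nil {
				fmt.Println("shutdown error:", err)
			}
		}
	}()

	// Immediately after acquiring a resource, register its release:
	shutdownFns = append(shutdownFns, func(ctx context.Context) error {
		fmt.Println("releasing app lock")
		return nil
	})

	return fmt.Errorf("simulated startup failure") // the lock is still released
}

func main() {
	fmt.Println("bootstrap:", bootstrap())
}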
@@ -12,12 +12,7 @@ import (
	"time"

	"github.com/glebarez/sqlite"
	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
	sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
	_ "github.com/golang-migrate/migrate/v4/source/github"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	slogGorm "github.com/orandin/slog-gorm"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
@@ -26,11 +21,10 @@ import (
	"github.com/pocket-id/pocket-id/backend/internal/common"
	"github.com/pocket-id/pocket-id/backend/internal/utils"
	sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
	"github.com/pocket-id/pocket-id/backend/resources"
)

func NewDatabase() (db *gorm.DB, err error) {
	db, err = connectDatabase()
	db, err = ConnectDatabase()
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
@@ -39,105 +33,15 @@ func NewDatabase() (db *gorm.DB, err error) {
		return nil, fmt.Errorf("failed to get sql.DB: %w", err)
	}

	// Choose the correct driver for the database provider
	var driver database.Driver
	switch common.EnvConfig.DbProvider {
	case common.DbProviderSqlite:
		driver, err = sqliteMigrate.WithInstance(sqlDb, &sqliteMigrate.Config{
			NoTxWrap: true,
		})
	case common.DbProviderPostgres:
		driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
	default:
		// Should never happen at this point
		return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to create migration driver: %w", err)
	}

	// Run migrations
	if err := migrateDatabase(driver); err != nil {
	if err := utils.MigrateDatabase(sqlDb); err != nil {
		return nil, fmt.Errorf("failed to run migrations: %w", err)
	}

	return db, nil
}

func migrateDatabase(driver database.Driver) error {
	// Embedded migrations via iofs
	path := "migrations/" + string(common.EnvConfig.DbProvider)
	source, err := iofs.New(resources.FS, path)
	if err != nil {
		return fmt.Errorf("failed to create embedded migration source: %w", err)
	}

	m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
	if err != nil {
		return fmt.Errorf("failed to create migration instance: %w", err)
	}

	requiredVersion, err := getRequiredMigrationVersion(path)
	if err != nil {
		return fmt.Errorf("failed to get last migration version: %w", err)
	}

	currentVersion, _, _ := m.Version()
	if currentVersion > requiredVersion {
		slog.Warn("Database version is newer than the application supports, possible downgrade detected", slog.Uint64("db_version", uint64(currentVersion)), slog.Uint64("app_version", uint64(requiredVersion)))
		if !common.EnvConfig.AllowDowngrade {
			return fmt.Errorf("database version (%d) is newer than application version (%d), downgrades are not allowed (set ALLOW_DOWNGRADE=true to enable)", currentVersion, requiredVersion)
		}
		slog.Info("Fetching migrations from GitHub to handle possible downgrades")
		return migrateDatabaseFromGitHub(driver, requiredVersion)
	}

	if err := m.Migrate(requiredVersion); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("failed to apply embedded migrations: %w", err)
	}
	return nil
}

func migrateDatabaseFromGitHub(driver database.Driver, version uint) error {
	srcURL := "github://pocket-id/pocket-id/backend/resources/migrations/" + string(common.EnvConfig.DbProvider)

	m, err := migrate.NewWithDatabaseInstance(srcURL, "pocket-id", driver)
	if err != nil {
		return fmt.Errorf("failed to create GitHub migration instance: %w", err)
	}

	if err := m.Migrate(version); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("failed to apply GitHub migrations: %w", err)
	}
	return nil
}

// getRequiredMigrationVersion reads the embedded migration files and returns the highest version number found.
func getRequiredMigrationVersion(path string) (uint, error) {
	entries, err := resources.FS.ReadDir(path)
	if err != nil {
		return 0, fmt.Errorf("failed to read migration directory: %w", err)
	}

	var maxVersion uint
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		var version uint
		n, err := fmt.Sscanf(name, "%d_", &version)
		if err == nil && n == 1 {
			if version > maxVersion {
				maxVersion = version
			}
		}
	}

	return maxVersion, nil
}

func connectDatabase() (db *gorm.DB, err error) {
func ConnectDatabase() (db *gorm.DB, err error) {
	var dialector gorm.Dialector

	// Choose the correct database provider
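The removed migrateDatabase logic pinned migrations to the highest version shipped with the binary: it scanned embedded migration filenames, and if the database reported a higher version it either failed (ALLOW_DOWNGRADE unset) or pulled the newer migration files through the github:// source and migrated down to the required version. A small self-contained sketch of the filename scan; the filename shapes are assumptions based on the "%d_" format used above:

package main

import "fmt"

// highestVersion mimics getRequiredMigrationVersion: migration filenames
// are expected to start with "<version>_", and the largest version wins.
func highestVersion(names []string) uint {
	var maxVersion uint
	for _, name := range names {
		var v uint
		// Sscanf parses the leading integer and requires the "_" right after it;
		// non-matching files (e.g. README.md) are simply skipped.
		if n, err := fmt.Sscanf(name, "%d_", &v); err == nil && n == 1 && v > maxVersion {
			maxVersion = v
		}
	}
	return maxVersion
}

func main() {
	files := []string{"1_init.up.sql", "3_add_kv.up.sql", "2_users.up.sql", "README.md"}
	fmt.Println(highestVersion(files)) // prints 3
}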
@@ -17,7 +17,7 @@ import (
func init() {
	registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
		func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
			testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService, svc.fileStorage)
			testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService, svc.appLockService, svc.fileStorage)
			if err != nil {
				slog.Error("Failed to initialize test service", slog.Any("error", err))
				os.Exit(1)
@@ -29,16 +29,7 @@ import (
// This is used to register additional controllers for tests
var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services)

func initRouter(db *gorm.DB, svc *services) utils.Service {
	runner, err := initRouterInternal(db, svc)
	if err != nil {
		slog.Error("Failed to init router", "error", err)
		os.Exit(1)
	}
	return runner
}

func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
func initRouter(db *gorm.DB, svc *services) (utils.Service, error) {
	// Set the appropriate Gin mode based on the environment
	switch common.EnvConfig.AppEnv {
	case common.AppEnvProduction:
@@ -198,7 +189,6 @@ func initLogger(r *gin.Engine) {
	"GET /api/application-images/logo",
	"GET /api/application-images/background",
	"GET /api/application-images/favicon",
	"GET /api/application-images/email",
	"GET /_app",
	"GET /fonts",
	"GET /healthz",
@@ -27,6 +27,7 @@ type services struct {
	apiKeyService  *service.ApiKeyService
	versionService *service.VersionService
	fileStorage    storage.FileStorage
	appLockService *service.AppLockService
}

// Initializes all services
@@ -40,6 +41,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, ima
	svc.fileStorage = fileStorage
	svc.appImagesService = service.NewAppImagesService(imageExtensions, fileStorage)
	svc.appLockService = service.NewAppLockService(db)

	svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
	if err != nil {
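The AppLockService wired in here is consumed with the same wait idiom in both Bootstrap above and the import command below: Acquire returns a waitUntil time, and the caller sleeps until that instant unless its context is canceled first. A hedged sketch of that idiom; the interpretation of waitUntil as the tail end of a previous holder's lease is an assumption, not something the diff states:

package main

import (
	"context"
	"fmt"
	"time"
)

// waitForLease blocks until waitUntil (presumably the moment any previous
// lock holder's lease can be considered expired) or until ctx is canceled.
func waitForLease(ctx context.Context, waitUntil time.Time) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(time.Until(waitUntil)): // fires immediately if waitUntil is in the past
		return nil
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println("lease wait:", waitForLease(ctx, time.Now().Add(100*time.Millisecond)))
}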
backend/internal/cmds/export.go (new file): 70 lines
@@ -0,0 +1,70 @@
package cmds

import (
	"context"
	"fmt"
	"io"
	"os"

	"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
	"github.com/pocket-id/pocket-id/backend/internal/service"
	"github.com/spf13/cobra"
)

type exportFlags struct {
	Path string
}

func init() {
	var flags exportFlags

	exportCmd := &cobra.Command{
		Use:   "export",
		Short: "Exports all data of Pocket ID into a ZIP file",
		RunE: func(cmd *cobra.Command, args []string) error {
			return runExport(cmd.Context(), flags)
		},
	}

	exportCmd.Flags().StringVarP(&flags.Path, "path", "p", "pocket-id-export.zip", "Path to the ZIP file to export the data to, or '-' to write to stdout")

	rootCmd.AddCommand(exportCmd)
}

// runExport orchestrates the export flow
func runExport(ctx context.Context, flags exportFlags) error {
	db, err := bootstrap.NewDatabase()
	if err != nil {
		return fmt.Errorf("failed to connect to database: %w", err)
	}

	storage, err := bootstrap.InitStorage(ctx, db)
	if err != nil {
		return fmt.Errorf("failed to initialize storage: %w", err)
	}

	exportService := service.NewExportService(db, storage)

	var w io.Writer
	if flags.Path == "-" {
		w = os.Stdout
	} else {
		file, err := os.Create(flags.Path)
		if err != nil {
			return fmt.Errorf("failed to create export file: %w", err)
		}
		defer file.Close()

		w = file
	}

	if err := exportService.ExportToZip(ctx, w); err != nil {
		return fmt.Errorf("failed to export data: %w", err)
	}

	if flags.Path != "-" {
		fmt.Printf("Exported data to %s\n", flags.Path)
	}

	return nil
}
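Going by the flag definitions above (and the pocket-id root command shown later in this diff), typical invocations would look like:

    pocket-id export                          # writes ./pocket-id-export.zip
    pocket-id export --path /backups/pid.zip  # explicit output path
    pocket-id export -p - > pid.zip           # '-' streams the ZIP to stdout

Streaming to stdout works because the export path only needs an io.Writer; the file is created only when a real path is given.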
backend/internal/cmds/import.go (new file): 191 lines
@@ -0,0 +1,191 @@
package cmds

import (
	"archive/zip"
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"time"

	"github.com/spf13/cobra"
	"gorm.io/gorm"

	"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
	"github.com/pocket-id/pocket-id/backend/internal/common"
	"github.com/pocket-id/pocket-id/backend/internal/service"
	"github.com/pocket-id/pocket-id/backend/internal/utils"
)

type importFlags struct {
	Path                  string
	Yes                   bool
	ForcefullyAcquireLock bool
}

func init() {
	var flags importFlags

	importCmd := &cobra.Command{
		Use:   "import",
		Short: "Imports all data of Pocket ID from a ZIP file",
		RunE: func(cmd *cobra.Command, args []string) error {
			return runImport(cmd.Context(), flags)
		},
	}

	importCmd.Flags().StringVarP(&flags.Path, "path", "p", "pocket-id-export.zip", "Path to the ZIP file to import the data from, or '-' to read from stdin")
	importCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Skip confirmation prompts")
	importCmd.Flags().BoolVarP(&flags.ForcefullyAcquireLock, "forcefully-acquire-lock", "", false, "Forcefully acquire the application lock by terminating the Pocket ID instance")

	rootCmd.AddCommand(importCmd)
}

// runImport handles the high-level orchestration of the import process
func runImport(ctx context.Context, flags importFlags) error {
	if !flags.Yes {
		ok, err := askForConfirmation()
		if err != nil {
			return fmt.Errorf("failed to get confirmation: %w", err)
		}
		if !ok {
			fmt.Println("Aborted")
			os.Exit(1)
		}
	}

	var (
		zipReader *zip.ReadCloser
		cleanup   func()
		err       error
	)

	if flags.Path == "-" {
		zipReader, cleanup, err = readZipFromStdin()
		defer cleanup()
	} else {
		zipReader, err = zip.OpenReader(flags.Path)
	}
	if err != nil {
		return fmt.Errorf("failed to open zip: %w", err)
	}
	defer zipReader.Close()

	db, err := bootstrap.ConnectDatabase()
	if err != nil {
		return err
	}

	err = acquireImportLock(ctx, db, flags.ForcefullyAcquireLock)
	if err != nil {
		return err
	}

	storage, err := bootstrap.InitStorage(ctx, db)
	if err != nil {
		return fmt.Errorf("failed to initialize storage: %w", err)
	}

	importService := service.NewImportService(db, storage)
	err = importService.ImportFromZip(ctx, &zipReader.Reader)
	if err != nil {
		return fmt.Errorf("failed to import data from zip: %w", err)
	}

	fmt.Println("Import completed successfully.")
	return nil
}

func acquireImportLock(ctx context.Context, db *gorm.DB, force bool) error {
	// Check if the kv table exists, in case we are starting from an empty database
	exists, err := utils.DBTableExists(db, "kv")
	if err != nil {
		return fmt.Errorf("failed to check if kv table exists: %w", err)
	}
	if !exists {
		// This either means the database is empty, or the import is into an old version of PocketID that doesn't support locks
		// In either case, there's no lock to acquire
		fmt.Println("Could not acquire a lock because the 'kv' table does not exist. This is fine if you're importing into a new database, but make sure that there isn't an instance of Pocket ID currently running and using the same database.")
		return nil
	}

	// Note that we do not call a deferred Release if the data was imported
	// This is because we are overriding the contents of the database, so the lock is automatically lost
	appLockService := service.NewAppLockService(db)

	opCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	waitUntil, err := appLockService.Acquire(opCtx, force)
	if err != nil {
		if errors.Is(err, service.ErrLockUnavailable) {
			//nolint:staticcheck
			return errors.New("Pocket ID must be stopped before importing data; please stop the running instance or run with --forcefully-acquire-lock to terminate the other instance")
		}
		return fmt.Errorf("failed to acquire application lock: %w", err)
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(time.Until(waitUntil)):
	}

	return nil
}

func askForConfirmation() (bool, error) {
	fmt.Println("WARNING: This feature is experimental and may not work correctly. Please create a backup before proceeding and report any issues you encounter.")
	fmt.Println()
	fmt.Println("WARNING: Import will erase all existing data at the following locations:")
	fmt.Printf("Database: %s\n", absolutePathOrOriginal(common.EnvConfig.DbConnectionString))
	fmt.Printf("Uploads Path: %s\n", absolutePathOrOriginal(common.EnvConfig.UploadPath))

	ok, err := utils.PromptForConfirmation("Do you want to continue?")
	if err != nil {
		return false, err
	}

	return ok, nil
}

// absolutePathOrOriginal returns the absolute path of the given path, or the original if it fails
func absolutePathOrOriginal(path string) string {
	abs, err := filepath.Abs(path)
	if err != nil {
		return path
	}
	return abs
}

func readZipFromStdin() (*zip.ReadCloser, func(), error) {
	tmpFile, err := os.CreateTemp("", "pocket-id-import-*.zip")
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create temporary file: %w", err)
	}

	cleanup := func() {
		_ = os.Remove(tmpFile.Name())
	}

	if _, err := io.Copy(tmpFile, os.Stdin); err != nil {
		tmpFile.Close()
		cleanup()
		return nil, nil, fmt.Errorf("failed to read data from stdin: %w", err)
	}

	if err := tmpFile.Close(); err != nil {
		cleanup()
		return nil, nil, fmt.Errorf("failed to close temporary file: %w", err)
	}

	r, err := zip.OpenReader(tmpFile.Name())
	if err != nil {
		cleanup()
		return nil, nil, err
	}

	return r, cleanup, nil
}
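Mirroring the export command, the import side reads the archive from a path or stdin (stdin is spooled to a temporary file first, since archive/zip needs random access), refuses to run while another instance holds the app lock, and prompts unless -y is passed. Typical invocations, going by the flags defined above:

    pocket-id import -p pocket-id-export.zip         # interactive confirmation
    pocket-id import --yes --path backup.zip         # non-interactive
    cat backup.zip | pocket-id import -y -p -        # read the archive from stdin
    pocket-id import -y --forcefully-acquire-lock    # take the lock from a running instance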
@@ -4,6 +4,7 @@ import (
	"context"
	"errors"
	"fmt"
	"os"
	"strings"

	"github.com/lestrrat-go/jwx/v3/jwa"
@@ -78,7 +79,7 @@ func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig
	}
	if !ok {
		fmt.Println("Aborted")
		return nil
		os.Exit(1)
	}
}
@@ -1,8 +1,6 @@
package cmds

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
@@ -69,78 +67,14 @@ func TestKeyRotate(t *testing.T) {

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Run("file storage", func(t *testing.T) {
				testKeyRotateWithFileStorage(t, tt.flags, tt.wantErr, tt.errMsg)
			})

			t.Run("database storage", func(t *testing.T) {
				testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg)
			})
			testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg)
		})
	}
}

func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
	// Create temporary directory for keys
	tempDir := t.TempDir()
	keysPath := filepath.Join(tempDir, "keys")
	err := os.MkdirAll(keysPath, 0755)
	require.NoError(t, err)

	// Set up file storage config
	envConfig := &common.EnvConfigSchema{
		KeysStorage: "file",
		KeysPath:    keysPath,
	}

	// Create test database
	db := testingutils.NewDatabaseForTest(t)

	// Initialize app config service and create instance
	appConfigService, err := service.NewAppConfigService(t.Context(), db)
	require.NoError(t, err)
	instanceID := appConfigService.GetDbConfig().InstanceID.Value

	// Check if key exists before rotation
	keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
	require.NoError(t, err)

	// Run the key rotation
	err = keyRotate(t.Context(), flags, db, envConfig)

	if wantErr {
		require.Error(t, err)
		if errMsg != "" {
			require.ErrorContains(t, err, errMsg)
		}
		return
	}

	require.NoError(t, err)

	// Verify key was created
	key, err := keyProvider.LoadKey()
	require.NoError(t, err)
	require.NotNil(t, key)

	// Verify the algorithm matches what we requested
	alg, _ := key.Algorithm()
	assert.NotEmpty(t, alg)
	if flags.Alg != "" {
		expectedAlg := flags.Alg
		if expectedAlg == "EdDSA" {
			// EdDSA keys should have the EdDSA algorithm
			assert.Equal(t, "EdDSA", alg.String())
		} else {
			assert.Equal(t, expectedAlg, alg.String())
		}
	}
}

func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
	// Set up database storage config
	envConfig := &common.EnvConfigSchema{
		KeysStorage:   "database",
		EncryptionKey: []byte("test-encryption-key-characters-long"),
	}
@@ -12,9 +12,10 @@ import (
)

var rootCmd = &cobra.Command{
	Use:   "pocket-id",
	Short: "A simple and easy-to-use OIDC provider that allows users to authenticate with their passkeys to your services.",
	Long:  "By default, this command starts the pocket-id server.",
	Use:          "pocket-id",
	Short:        "A simple and easy-to-use OIDC provider that allows users to authenticate with their passkeys to your services.",
	Long:         "By default, this command starts the pocket-id server.",
	SilenceUsage: true,
	Run: func(cmd *cobra.Command, args []string) {
		// Start the server
		err := bootstrap.Bootstrap(cmd.Context())
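The one functional change in this hunk is SilenceUsage: true. Without it, Cobra prints the full usage text whenever RunE returns an error, which is noisy for failures that have nothing to do with how the command was invoked. A tiny standalone reproduction of the behavior difference:

package main

import (
	"errors"

	"github.com/spf13/cobra"
)

func main() {
	cmd := &cobra.Command{
		Use:          "demo",
		SilenceUsage: true, // with this set, a RunE error prints only the error, not the usage text
		RunE: func(cmd *cobra.Command, args []string) error {
			return errors.New("simulated runtime failure")
		},
	}
	_ = cmd.Execute()
}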
@@ -38,40 +38,40 @@ const (
)

type EnvConfigSchema struct {
	AppEnv                          AppEnv     `env:"APP_ENV" options:"toLower"`
	LogLevel                        string     `env:"LOG_LEVEL" options:"toLower"`
	AppURL                          string     `env:"APP_URL" options:"toLower,trimTrailingSlash"`
	DbProvider                      DbProvider `env:"DB_PROVIDER" options:"toLower"`
	DbConnectionString              string     `env:"DB_CONNECTION_STRING" options:"file"`
	FileBackend                     string     `env:"FILE_BACKEND" options:"toLower"`
	UploadPath                      string     `env:"UPLOAD_PATH"`
	S3Bucket                        string     `env:"S3_BUCKET"`
	S3Region                        string     `env:"S3_REGION"`
	S3Endpoint                      string     `env:"S3_ENDPOINT"`
	S3AccessKeyID                   string     `env:"S3_ACCESS_KEY_ID"`
	S3SecretAccessKey               string     `env:"S3_SECRET_ACCESS_KEY"`
	S3ForcePathStyle                bool       `env:"S3_FORCE_PATH_STYLE"`
	S3DisableDefaultIntegrityChecks bool       `env:"S3_DISABLE_DEFAULT_INTEGRITY_CHECKS"`
	KeysPath                        string     `env:"KEYS_PATH"`
	KeysStorage                     string     `env:"KEYS_STORAGE"`
	EncryptionKey                   []byte     `env:"ENCRYPTION_KEY" options:"file"`
	Port                            string     `env:"PORT"`
	Host                            string     `env:"HOST" options:"toLower"`
	UnixSocket                      string     `env:"UNIX_SOCKET"`
	UnixSocketMode                  string     `env:"UNIX_SOCKET_MODE"`
	MaxMindLicenseKey               string     `env:"MAXMIND_LICENSE_KEY" options:"file"`
	GeoLiteDBPath                   string     `env:"GEOLITE_DB_PATH"`
	GeoLiteDBUrl                    string     `env:"GEOLITE_DB_URL"`
	LocalIPv6Ranges                 string     `env:"LOCAL_IPV6_RANGES"`
	UiConfigDisabled                bool       `env:"UI_CONFIG_DISABLED"`
	MetricsEnabled                  bool       `env:"METRICS_ENABLED"`
	TracingEnabled                  bool       `env:"TRACING_ENABLED"`
	LogJSON                         bool       `env:"LOG_JSON"`
	TrustProxy                      bool       `env:"TRUST_PROXY"`
	AuditLogRetentionDays           int        `env:"AUDIT_LOG_RETENTION_DAYS"`
	AnalyticsDisabled               bool       `env:"ANALYTICS_DISABLED"`
	AllowDowngrade                  bool       `env:"ALLOW_DOWNGRADE"`
	InternalAppURL                  string     `env:"INTERNAL_APP_URL"`
	AppEnv             AppEnv `env:"APP_ENV" options:"toLower"`
	LogLevel           string `env:"LOG_LEVEL" options:"toLower"`
	LogJSON            bool   `env:"LOG_JSON"`
	AppURL             string `env:"APP_URL" options:"toLower,trimTrailingSlash"`
	DbProvider         DbProvider
	DbConnectionString string `env:"DB_CONNECTION_STRING" options:"file"`
	EncryptionKey      []byte `env:"ENCRYPTION_KEY" options:"file"`
	Port               string `env:"PORT"`
	Host               string `env:"HOST" options:"toLower"`
	UnixSocket         string `env:"UNIX_SOCKET"`
	UnixSocketMode     string `env:"UNIX_SOCKET_MODE"`
	LocalIPv6Ranges    string `env:"LOCAL_IPV6_RANGES"`
	UiConfigDisabled   bool   `env:"UI_CONFIG_DISABLED"`
	MetricsEnabled     bool   `env:"METRICS_ENABLED"`
	TracingEnabled     bool   `env:"TRACING_ENABLED"`
	TrustProxy         bool   `env:"TRUST_PROXY"`
	AnalyticsDisabled  bool   `env:"ANALYTICS_DISABLED"`
	AllowDowngrade     bool   `env:"ALLOW_DOWNGRADE"`
	InternalAppURL     string `env:"INTERNAL_APP_URL"`

	MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY" options:"file"`
	GeoLiteDBPath     string `env:"GEOLITE_DB_PATH"`
	GeoLiteDBUrl      string `env:"GEOLITE_DB_URL"`

	FileBackend string `env:"FILE_BACKEND" options:"toLower"`
	UploadPath  string `env:"UPLOAD_PATH"`

	S3Bucket                        string `env:"S3_BUCKET"`
	S3Region                        string `env:"S3_REGION"`
	S3Endpoint                      string `env:"S3_ENDPOINT"`
	S3AccessKeyID                   string `env:"S3_ACCESS_KEY_ID"`
	S3SecretAccessKey               string `env:"S3_SECRET_ACCESS_KEY"`
	S3ForcePathStyle                bool   `env:"S3_FORCE_PATH_STYLE"`
	S3DisableDefaultIntegrityChecks bool   `env:"S3_DISABLE_DEFAULT_INTEGRITY_CHECKS"`
}

var EnvConfig = defaultConfig()
@@ -86,17 +86,15 @@ func init() {

func defaultConfig() EnvConfigSchema {
	return EnvConfigSchema{
		AppEnv:                AppEnvProduction,
		LogLevel:              "info",
		DbProvider:            "sqlite",
		FileBackend:           "filesystem",
		KeysPath:              "data/keys",
		AuditLogRetentionDays: 90,
		AppURL:                AppUrl,
		Port:                  "1411",
		Host:                  "0.0.0.0",
		GeoLiteDBPath:         "data/GeoLite2-City.mmdb",
		GeoLiteDBUrl:          MaxMindGeoLiteCityUrl,
		AppEnv:        AppEnvProduction,
		LogLevel:      "info",
		DbProvider:    "sqlite",
		FileBackend:   "filesystem",
		AppURL:        AppUrl,
		Port:          "1411",
		Host:          "0.0.0.0",
		GeoLiteDBPath: "data/GeoLite2-City.mmdb",
		GeoLiteDBUrl:  MaxMindGeoLiteCityUrl,
	}
}
@@ -119,32 +117,28 @@ func parseEnvConfig() error {
		return fmt.Errorf("error preparing env config: %w", err)
	}

	err = validateEnvConfig(&EnvConfig)
	if err != nil {
		return err
	}

	return nil
}

// validateEnvConfig checks the EnvConfig for required fields and valid values
func validateEnvConfig(config *EnvConfigSchema) error {
// ValidateEnvConfig checks the EnvConfig for required fields and valid values
func ValidateEnvConfig(config *EnvConfigSchema) error {
	if _, err := sloggin.ParseLevel(config.LogLevel); err != nil {
		return errors.New("invalid LOG_LEVEL value. Must be 'debug', 'info', 'warn' or 'error'")
	}

	switch config.DbProvider {
	case DbProviderSqlite:
		if config.DbConnectionString == "" {
			config.DbConnectionString = defaultSqliteConnString
		}
	case DbProviderPostgres:
		if config.DbConnectionString == "" {
			return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
		}
	if len(config.EncryptionKey) < 16 {
		return errors.New("ENCRYPTION_KEY must be at least 16 bytes long")
	}

	switch {
	case config.DbConnectionString == "":
		config.DbProvider = DbProviderSqlite
		config.DbConnectionString = defaultSqliteConnString
	case strings.HasPrefix(config.DbConnectionString, "postgres://") || strings.HasPrefix(config.DbConnectionString, "postgresql://"):
		config.DbProvider = DbProviderPostgres
	default:
		return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
		config.DbProvider = DbProviderSqlite
	}

	parsedAppUrl, err := url.Parse(config.AppURL)
@@ -168,27 +162,8 @@ func validateEnvConfig(config *EnvConfigSchema) error {
		}
	}

	switch config.KeysStorage {
	// KeysStorage defaults to "file" if empty
	case "":
		config.KeysStorage = "file"
	case "database":
		if config.EncryptionKey == nil {
			return errors.New("ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
		}
	case "file":
		// All good, these are valid values
	default:
		return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage)
	}

	switch config.FileBackend {
	case "s3":
		if config.KeysStorage == "file" {
			return errors.New("KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
		}
	case "database":
		// All good, these are valid values
	case "s3", "database":
	case "", "filesystem":
		if config.UploadPath == "" {
			config.UploadPath = defaultFsUploadPath
@@ -216,10 +191,6 @@ func validateEnvConfig(config *EnvConfigSchema) error {

	}

	if config.AuditLogRetentionDays <= 0 {
		return errors.New("AUDIT_LOG_RETENTION_DAYS must be greater than 0")
	}

	return nil

}
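The net effect of the rewritten switch is that DB_PROVIDER is no longer taken from the environment at all: an empty connection string selects SQLite with the default connection string, a postgres:// or postgresql:// prefix selects Postgres, and anything else is treated as a SQLite connection string. A self-contained sketch of that inference; the default SQLite path here is an illustrative stand-in for defaultSqliteConnString, whose actual value is not shown in this diff:

package main

import (
	"fmt"
	"strings"
)

// inferProvider mirrors the new switch in ValidateEnvConfig.
func inferProvider(connString string) (provider, conn string) {
	switch {
	case connString == "":
		return "sqlite", "data/pocket-id.db" // stand-in for defaultSqliteConnString
	case strings.HasPrefix(connString, "postgres://"),
		strings.HasPrefix(connString, "postgresql://"):
		return "postgres", connString
	default:
		return "sqlite", connString // any other non-empty string is a SQLite conn string
	}
}

func main() {
	fmt.Println(inferProvider(""))
	fmt.Println(inferProvider("postgres://user:pass@localhost/db"))
	fmt.Println(inferProvider("file:test.db"))
}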
@@ -8,6 +8,20 @@ import (
	"github.com/stretchr/testify/require"
)

func parseAndValidateEnvConfig(t *testing.T) error {
	t.Helper()

	if _, exists := os.LookupEnv("ENCRYPTION_KEY"); !exists {
		t.Setenv("ENCRYPTION_KEY", "0123456789abcdef")
	}

	if err := parseEnvConfig(); err != nil {
		return err
	}

	return ValidateEnvConfig(&EnvConfig)
}

func TestParseEnvConfig(t *testing.T) {
	// Store original config to restore later
	originalConfig := EnvConfig
@@ -17,11 +31,10 @@ func TestParseEnvConfig(t *testing.T) {

	t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "SQLITE") // should be lowercased automatically
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "HTTP://LOCALHOST:3000")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.NoError(t, err)
		assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
		assert.Equal(t, "http://localhost:3000", EnvConfig.AppURL)
@@ -29,147 +42,76 @@ func TestParseEnvConfig(t *testing.T) {

	t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "POSTGRES")
		t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
		t.Setenv("APP_URL", "https://example.com")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.NoError(t, err)
		assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider)
	})

	t.Run("should fail with invalid DB_PROVIDER", func(t *testing.T) {
	t.Run("should fail when ENCRYPTION_KEY is too short", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "invalid")
		t.Setenv("DB_CONNECTION_STRING", "test")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("ENCRYPTION_KEY", "short")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.Error(t, err)
		assert.ErrorContains(t, err, "invalid DB_PROVIDER value")
		assert.ErrorContains(t, err, "ENCRYPTION_KEY must be at least 16 bytes long")
	})

	t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("APP_URL", "http://localhost:3000")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.NoError(t, err)
		assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString)
	})

	t.Run("should fail when Postgres DB_CONNECTION_STRING is missing", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "postgres")
		t.Setenv("APP_URL", "http://localhost:3000")

		err := parseEnvConfig()
		require.Error(t, err)
		assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres")
	})

	t.Run("should fail with invalid APP_URL", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "€://not-a-valid-url")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.Error(t, err)
		assert.ErrorContains(t, err, "APP_URL is not a valid URL")
	})

	t.Run("should fail when APP_URL contains path", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000/path")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.Error(t, err)
		assert.ErrorContains(t, err, "APP_URL must not contain a path")
	})

	t.Run("should fail with invalid INTERNAL_APP_URL", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("INTERNAL_APP_URL", "€://not-a-valid-url")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.Error(t, err)
		assert.ErrorContains(t, err, "INTERNAL_APP_URL is not a valid URL")
	})

	t.Run("should fail when INTERNAL_APP_URL contains path", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("INTERNAL_APP_URL", "http://localhost:3000/path")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.Error(t, err)
		assert.ErrorContains(t, err, "INTERNAL_APP_URL must not contain a path")
	})

	t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")

		err := parseEnvConfig()
		require.NoError(t, err)
		assert.Equal(t, "file", EnvConfig.KeysStorage)
	})

	t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("KEYS_STORAGE", "database")

		err := parseEnvConfig()
		require.Error(t, err)
		assert.ErrorContains(t, err, "ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
	})

	t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) {
		validStorageTypes := []string{"file", "database"}

		for _, storage := range validStorageTypes {
			EnvConfig = defaultConfig()
			t.Setenv("DB_PROVIDER", "sqlite")
			t.Setenv("DB_CONNECTION_STRING", "file:test.db")
			t.Setenv("APP_URL", "http://localhost:3000")
			t.Setenv("KEYS_STORAGE", storage)
			if storage == "database" {
				t.Setenv("ENCRYPTION_KEY", "test-key")
			}

			err := parseEnvConfig()
			require.NoError(t, err)
			assert.Equal(t, storage, EnvConfig.KeysStorage)
		}
	})

	t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("KEYS_STORAGE", "invalid")

		err := parseEnvConfig()
		require.Error(t, err)
		assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE")
	})

	t.Run("should parse boolean environment variables correctly", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("UI_CONFIG_DISABLED", "true")
@@ -178,7 +120,7 @@ func TestParseEnvConfig(t *testing.T) {
		t.Setenv("TRUST_PROXY", "true")
		t.Setenv("ANALYTICS_DISABLED", "false")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.NoError(t, err)
		assert.True(t, EnvConfig.UiConfigDisabled)
		assert.True(t, EnvConfig.MetricsEnabled)
@@ -187,56 +129,19 @@ func TestParseEnvConfig(t *testing.T) {
		assert.False(t, EnvConfig.AnalyticsDisabled)
	})

	t.Run("should default audit log retention days to 90", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")

		err := parseEnvConfig()
		require.NoError(t, err)
		assert.Equal(t, 90, EnvConfig.AuditLogRetentionDays)
	})

	t.Run("should parse audit log retention days override", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("AUDIT_LOG_RETENTION_DAYS", "365")

		err := parseEnvConfig()
		require.NoError(t, err)
		assert.Equal(t, 365, EnvConfig.AuditLogRetentionDays)
	})

	t.Run("should fail when AUDIT_LOG_RETENTION_DAYS is non-positive", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("AUDIT_LOG_RETENTION_DAYS", "0")

		err := parseEnvConfig()
		require.Error(t, err)
		assert.ErrorContains(t, err, "AUDIT_LOG_RETENTION_DAYS must be greater than 0")
	})

	t.Run("should parse string environment variables correctly", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "postgres")
		t.Setenv("DB_CONNECTION_STRING", "postgres://test")
		t.Setenv("APP_URL", "https://prod.example.com")
		t.Setenv("APP_ENV", "PRODUCTION")
		t.Setenv("UPLOAD_PATH", "/custom/uploads")
		t.Setenv("KEYS_PATH", "/custom/keys")
		t.Setenv("PORT", "8080")
		t.Setenv("HOST", "LOCALHOST")
		t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
		t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
		t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.NoError(t, err)
		assert.Equal(t, AppEnvProduction, EnvConfig.AppEnv) // lowercased
		assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
@@ -246,38 +151,24 @@ func TestParseEnvConfig(t *testing.T) {

	t.Run("should normalize file backend and default upload path", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("FILE_BACKEND", "FILESYSTEM")
		t.Setenv("UPLOAD_PATH", "")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.NoError(t, err)
		assert.Equal(t, "filesystem", EnvConfig.FileBackend)
		assert.Equal(t, defaultFsUploadPath, EnvConfig.UploadPath)
	})

	t.Run("should fail when FILE_BACKEND is s3 but keys are stored on filesystem", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("FILE_BACKEND", "s3")

		err := parseEnvConfig()
		require.Error(t, err)
		assert.ErrorContains(t, err, "KEYS_STORAGE cannot be 'file' when FILE_BACKEND is 's3'")
	})

	t.Run("should fail with invalid FILE_BACKEND value", func(t *testing.T) {
		EnvConfig = defaultConfig()
		t.Setenv("DB_PROVIDER", "sqlite")
		t.Setenv("DB_CONNECTION_STRING", "file:test.db")
		t.Setenv("APP_URL", "http://localhost:3000")
		t.Setenv("FILE_BACKEND", "invalid")

		err := parseEnvConfig()
		err := parseAndValidateEnvConfig(t)
		require.Error(t, err)
		assert.ErrorContains(t, err, "invalid FILE_BACKEND value")
	})
@@ -23,13 +23,11 @@ func NewAppImagesController(
	}

	group.GET("/application-images/logo", controller.getLogoHandler)
	group.GET("/application-images/email", controller.getEmailLogoHandler)
	group.GET("/application-images/background", controller.getBackgroundImageHandler)
	group.GET("/application-images/favicon", controller.getFaviconHandler)
	group.GET("/application-images/default-profile-picture", authMiddleware.Add(), controller.getDefaultProfilePicture)

	group.PUT("/application-images/logo", authMiddleware.Add(), controller.updateLogoHandler)
	group.PUT("/application-images/email", authMiddleware.Add(), controller.updateEmailLogoHandler)
	group.PUT("/application-images/background", authMiddleware.Add(), controller.updateBackgroundImageHandler)
	group.PUT("/application-images/favicon", authMiddleware.Add(), controller.updateFaviconHandler)
	group.PUT("/application-images/default-profile-picture", authMiddleware.Add(), controller.updateDefaultProfilePicture)
@@ -61,18 +59,6 @@ func (c *AppImagesController) getLogoHandler(ctx *gin.Context) {
	c.getImage(ctx, imageName)
}

// getEmailLogoHandler godoc
// @Summary Get email logo image
// @Description Get the email logo image for use in emails
// @Tags Application Images
// @Produce image/png
// @Produce image/jpeg
// @Success 200 {file} binary "Email logo image"
// @Router /api/application-images/email [get]
func (c *AppImagesController) getEmailLogoHandler(ctx *gin.Context) {
	c.getImage(ctx, "logoEmail")
}

// getBackgroundImageHandler godoc
// @Summary Get background image
// @Description Get the background image for the application
@@ -138,37 +124,6 @@ func (c *AppImagesController) updateLogoHandler(ctx *gin.Context) {
	ctx.Status(http.StatusNoContent)
}

// updateEmailLogoHandler godoc
// @Summary Update email logo
// @Description Update the email logo for use in emails
// @Tags Application Images
// @Accept multipart/form-data
// @Param file formData file true "Email logo image file"
// @Success 204 "No Content"
// @Router /api/application-images/email [put]
func (c *AppImagesController) updateEmailLogoHandler(ctx *gin.Context) {
	file, err := ctx.FormFile("file")
	if err != nil {
		_ = ctx.Error(err)
		return
	}

	fileType := utils.GetFileExtension(file.Filename)
	mimeType := utils.GetImageMimeType(fileType)

	if mimeType != "image/png" && mimeType != "image/jpeg" {
		_ = ctx.Error(&common.WrongFileTypeError{ExpectedFileType: ".png or .jpg/jpeg"})
		return
	}

	if err := c.appImagesService.UpdateImage(ctx.Request.Context(), file, "logoEmail"); err != nil {
		_ = ctx.Error(err)
		return
	}

	ctx.Status(http.StatusNoContent)
}

// updateBackgroundImageHandler godoc
// @Summary Update background image
// @Description Update the application background image
@@ -40,6 +40,11 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
		return
	}

	if err := tc.TestService.ResetLock(c.Request.Context()); err != nil {
		_ = c.Error(err)
		return
	}

	if err := tc.TestService.ResetApplicationImages(c.Request.Context()); err != nil {
		_ = c.Error(err)
		return
@@ -69,8 +74,6 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
		}
	}

	tc.TestService.SetJWTKeys()

	c.Status(http.StatusNoContent)
}
@@ -545,7 +545,7 @@ func (uc *UserController) createSignupTokenHandler(c *gin.Context) {
		ttl = defaultSignupTokenDuration
	}

	signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), ttl, input.UsageLimit, input.UserGroupIDs)
	signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), ttl, input.UsageLimit)
	if err != nil {
		_ = c.Error(err)
		return
@@ -47,7 +47,7 @@ type AppConfigUpdateDto struct {
	LdapAttributeGroupMember                   string `json:"ldapAttributeGroupMember"`
	LdapAttributeGroupUniqueIdentifier         string `json:"ldapAttributeGroupUniqueIdentifier"`
	LdapAttributeGroupName                     string `json:"ldapAttributeGroupName"`
	LdapAttributeAdminGroup                    string `json:"ldapAttributeAdminGroup"`
	LdapAdminGroupName                         string `json:"ldapAdminGroupName"`
	LdapSoftDeleteUsers                        string `json:"ldapSoftDeleteUsers"`
	EmailOneTimeAccessAsAdminEnabled           string `json:"emailOneTimeAccessAsAdminEnabled" binding:"required"`
	EmailOneTimeAccessAsUnauthenticatedEnabled string `json:"emailOneTimeAccessAsUnauthenticatedEnabled" binding:"required"`
@@ -18,7 +18,6 @@ type OidcClientDto struct {
	IsPublic          bool                     `json:"isPublic"`
	PkceEnabled       bool                     `json:"pkceEnabled"`
	Credentials       OidcClientCredentialsDto `json:"credentials"`
	IsGroupRestricted bool                     `json:"isGroupRestricted"`
}

type OidcClientWithAllowedUserGroupsDto struct {
@@ -44,7 +43,6 @@ type OidcClientUpdateDto struct {
	HasDarkLogo       bool    `json:"hasDarkLogo"`
	LogoURL           *string `json:"logoUrl"`
	DarkLogoURL       *string `json:"darkLogoUrl"`
	IsGroupRestricted bool    `json:"isGroupRestricted"`
}

type OidcClientCreateDto struct {
@@ -6,9 +6,8 @@ import (
)

type SignupTokenCreateDto struct {
	TTL          utils.JSONDuration `json:"ttl" binding:"required,ttl"`
	UsageLimit   int                `json:"usageLimit" binding:"required,min=1,max=100"`
	UserGroupIDs []string           `json:"userGroupIds"`
	TTL        utils.JSONDuration `json:"ttl" binding:"required,ttl"`
	UsageLimit int                `json:"usageLimit" binding:"required,min=1,max=100"`
}

type SignupTokenDto struct {
@@ -17,6 +16,5 @@ type SignupTokenDto struct {
	ExpiresAt  datatype.DateTime `json:"expiresAt"`
	UsageLimit int               `json:"usageLimit"`
	UsageCount int               `json:"usageCount"`
	UserGroups []UserGroupDto    `json:"userGroups"`
	CreatedAt  datatype.DateTime `json:"createdAt"`
}
@@ -23,16 +23,15 @@ type UserDto struct {
|
||||
}
|
||||
|
||||
type UserCreateDto struct {
|
||||
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
||||
Email *string `json:"email" binding:"omitempty,email" unorm:"nfc"`
|
||||
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
||||
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
||||
DisplayName string `json:"displayName" binding:"required,min=1,max=100" unorm:"nfc"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Locale *string `json:"locale"`
|
||||
Disabled bool `json:"disabled"`
|
||||
UserGroupIds []string `json:"userGroupIds"`
|
||||
LdapID string `json:"-"`
|
||||
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
||||
Email *string `json:"email" binding:"omitempty,email" unorm:"nfc"`
|
||||
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
||||
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
||||
DisplayName string `json:"displayName" binding:"required,min=1,max=100" unorm:"nfc"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Locale *string `json:"locale"`
|
||||
Disabled bool `json:"disabled"`
|
||||
LdapID string `json:"-"`
|
||||
}
|
||||
|
||||
func (u UserCreateDto) Validate() error {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
	"github.com/go-co-op/gocron/v2"
	"gorm.io/gorm"

	"github.com/pocket-id/pocket-id/backend/internal/common"
	"github.com/pocket-id/pocket-id/backend/internal/model"
	datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
@@ -120,13 +119,11 @@ func (j *DbCleanupJobs) clearReauthenticationTokens(ctx context.Context) error {
	return nil
}

// clearAuditLogs deletes audit logs older than the configured retention window
// clearAuditLogs deletes audit logs older than 90 days
func (j *DbCleanupJobs) clearAuditLogs(ctx context.Context) error {
	cutoff := time.Now().AddDate(0, 0, -common.EnvConfig.AuditLogRetentionDays)

	st := j.db.
		WithContext(ctx).
		Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(cutoff))
		Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90)))
	if st.Error != nil {
		return fmt.Errorf("failed to delete old audit logs: %w", st.Error)
	}

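As a sanity check on the new cutoff computation: AddDate(0, 0, -n) moves a timestamp n days into the past, so the delete keeps exactly the configured retention window. A minimal illustration (the literal 90 here is only an example standing in for the configured value):

	retentionDays := 90 // stands in for common.EnvConfig.AuditLogRetentionDays
	cutoff := time.Now().AddDate(0, 0, -retentionDays)
	// rows where created_at < cutoff are deleted by clearAuditLogs
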
@@ -77,7 +77,7 @@ type AppConfig struct {
	LdapAttributeGroupMember           AppConfigVariable `key:"ldapAttributeGroupMember"`
	LdapAttributeGroupUniqueIdentifier AppConfigVariable `key:"ldapAttributeGroupUniqueIdentifier"`
	LdapAttributeGroupName             AppConfigVariable `key:"ldapAttributeGroupName"`
	LdapAttributeAdminGroup            AppConfigVariable `key:"ldapAttributeAdminGroup"`
	LdapAdminGroupName                 AppConfigVariable `key:"ldapAdminGroupName"`
	LdapSoftDeleteUsers                AppConfigVariable `key:"ldapSoftDeleteUsers"`
}


@@ -58,7 +58,6 @@ type OidcClient struct {
	RequiresReauthentication bool `sortable:"true" filterable:"true"`
	Credentials              OidcClientCredentials
	LaunchURL                *string
	IsGroupRestricted        bool

	AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
	CreatedByID       *string

@@ -13,7 +13,6 @@ type SignupToken struct {
	ExpiresAt  datatype.DateTime `json:"expiresAt" sortable:"true"`
	UsageLimit int               `json:"usageLimit" sortable:"true"`
	UsageCount int               `json:"usageCount" sortable:"true"`
	UserGroups []UserGroup       `gorm:"many2many:signup_tokens_user_groups;"`
}

func (st *SignupToken) IsExpired() bool {

@@ -11,6 +11,15 @@ import (
// DateTime is a custom type based on time.Time; it is stored as a Unix timestamp in SQLite and as a native date in Postgres
type DateTime time.Time //nolint:recvcheck

func DateTimeFromString(str string) (DateTime, error) {
	t, err := time.Parse(time.RFC3339Nano, str)
	if err != nil {
		return DateTime{}, fmt.Errorf("failed to parse date string: %w", err)
	}

	return DateTime(t), nil
}

func (date *DateTime) Scan(value any) (err error) {
	switch v := value.(type) {
	case time.Time:

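A quick illustration of the new DateTimeFromString constructor. This is a hypothetical caller, not part of the diff; the timestamp literal is arbitrary and imports of fmt, log, and time are assumed:

	d, err := datatype.DateTimeFromString("2025-06-01T12:00:00Z")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(time.Time(d).UTC()) // 2025-06-01 12:00:00 +0000 UTC
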
@@ -102,7 +102,7 @@ func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
		LdapAttributeGroupMember:           model.AppConfigVariable{Value: "member"},
		LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
		LdapAttributeGroupName:             model.AppConfigVariable{},
		LdapAttributeAdminGroup:            model.AppConfigVariable{},
		LdapAdminGroupName:                 model.AppConfigVariable{},
		LdapSoftDeleteUsers:                model.AppConfigVariable{Value: "true"},
	}
}

backend/internal/service/app_lock_service.go (new file, 296 lines)
@@ -0,0 +1,296 @@
package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"time"

	"github.com/google/uuid"
	"github.com/pocket-id/pocket-id/backend/internal/model"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

var (
	ErrLockUnavailable = errors.New("lock is already held by another process")
	ErrLockLost        = errors.New("lock ownership lost")
)

const (
	ttl           = 30 * time.Second
	renewInterval = 20 * time.Second
	renewRetries  = 3
	lockKey       = "application_lock"
)

type AppLockService struct {
	db        *gorm.DB
	lockID    string
	processID int64
	hostID    string
}

func NewAppLockService(db *gorm.DB) *AppLockService {
	host, err := os.Hostname()
	if err != nil || host == "" {
		host = "unknown-host"
	}

	return &AppLockService{
		db:        db,
		processID: int64(os.Getpid()),
		hostID:    host,
		lockID:    uuid.NewString(),
	}
}

type lockValue struct {
	ProcessID int64  `json:"process_id"`
	HostID    string `json:"host_id"`
	LockID    string `json:"lock_id"`
	ExpiresAt int64  `json:"expires_at"`
}

func (lv *lockValue) Marshal() (string, error) {
	data, err := json.Marshal(lv)
	if err != nil {
		return "", err
	}
	return string(data), nil
}

func (lv *lockValue) Unmarshal(raw string) error {
	if raw == "" {
		return nil
	}
	return json.Unmarshal([]byte(raw), lv)
}

// Acquire obtains the lock. When force is true, the lock is stolen from any existing owner.
// If the lock is forcefully stolen from a live owner, the returned waitUntil indicates when the previous lock expires; callers should wait until then before proceeding.
func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil time.Time, err error) {
	tx := s.db.Begin()
	defer func() {
		tx.Rollback()
	}()

	var prevLockRaw string
	err = tx.
		WithContext(ctx).
		Model(&model.KV{}).
		Where("key = ?", lockKey).
		Clauses(clause.Locking{Strength: "UPDATE"}).
		Select("value").
		Scan(&prevLockRaw).
		Error
	if err != nil {
		return time.Time{}, fmt.Errorf("query existing lock: %w", err)
	}

	var prevLock lockValue
	if prevLockRaw != "" {
		if err := prevLock.Unmarshal(prevLockRaw); err != nil {
			return time.Time{}, fmt.Errorf("decode existing lock value: %w", err)
		}
	}

	now := time.Now()
	nowUnix := now.Unix()

	value := lockValue{
		ProcessID: s.processID,
		HostID:    s.hostID,
		LockID:    s.lockID,
		ExpiresAt: now.Add(ttl).Unix(),
	}
	raw, err := value.Marshal()
	if err != nil {
		return time.Time{}, fmt.Errorf("encode lock value: %w", err)
	}

	var query string
	switch s.db.Name() {
	case "sqlite":
		query = `
			INSERT INTO kv (key, value)
			VALUES (?, ?)
			ON CONFLICT(key) DO UPDATE SET
				value = excluded.value
			WHERE (json_extract(kv.value, '$.expires_at') < ?) OR ?
		`
	case "postgres":
		query = `
			INSERT INTO kv (key, value)
			VALUES ($1, $2)
			ON CONFLICT(key) DO UPDATE SET
				value = excluded.value
			WHERE ((kv.value::json->>'expires_at')::bigint < $3) OR ($4::boolean IS TRUE)
		`
	default:
		return time.Time{}, fmt.Errorf("unsupported database dialect: %s", s.db.Name())
	}

	res := tx.WithContext(ctx).Exec(query, lockKey, raw, nowUnix, force)
	if res.Error != nil {
		return time.Time{}, fmt.Errorf("lock acquisition failed: %w", res.Error)
	}

	if err := tx.Commit().Error; err != nil {
		return time.Time{}, fmt.Errorf("commit lock acquisition: %w", err)
	}

	// If there is a lock that is not expired and force is false, no rows will be affected
	if res.RowsAffected == 0 {
		return time.Time{}, ErrLockUnavailable
	}

	if force && prevLock.ExpiresAt > nowUnix && prevLock.LockID != s.lockID {
		waitUntil = time.Unix(prevLock.ExpiresAt, 0)
	}

	attrs := []any{
		slog.Int64("process_id", s.processID),
		slog.String("host_id", s.hostID),
	}
	if wait := time.Until(waitUntil); wait > 0 {
		attrs = append(attrs, slog.Duration("wait_before_proceeding", wait))
	}
	slog.Info("Acquired application lock", attrs...)

	return waitUntil, nil
}

// RunRenewal keeps renewing the lock until the context is canceled.
func (s *AppLockService) RunRenewal(ctx context.Context) error {
	ticker := time.NewTicker(renewInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil
		case <-ticker.C:
			if err := s.renew(ctx); err != nil {
				return fmt.Errorf("renew lock: %w", err)
			}
		}
	}
}

// Release releases the lock if it is held by this process.
func (s *AppLockService) Release(ctx context.Context) error {
	opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	var query string
	switch s.db.Name() {
	case "sqlite":
		query = `
			DELETE FROM kv
			WHERE key = ?
			AND json_extract(value, '$.lock_id') = ?
		`
	case "postgres":
		query = `
			DELETE FROM kv
			WHERE key = $1
			AND value::json->>'lock_id' = $2
		`
	default:
		return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
	}

	res := s.db.WithContext(opCtx).Exec(query, lockKey, s.lockID)
	if res.Error != nil {
		return fmt.Errorf("release lock failed: %w", res.Error)
	}

	if res.RowsAffected == 0 {
		slog.Warn("Application lock not held by this process, cannot release",
			slog.Int64("process_id", s.processID),
			slog.String("host_id", s.hostID),
		)
	}

	slog.Info("Released application lock",
		slog.Int64("process_id", s.processID),
		slog.String("host_id", s.hostID),
	)
	return nil
}

// renew tries to renew the lock, retrying up to renewRetries times (sleeping 1s between attempts).
func (s *AppLockService) renew(ctx context.Context) error {
	var lastErr error
	for attempt := 1; attempt <= renewRetries; attempt++ {
		now := time.Now()
		nowUnix := now.Unix()
		expiresAt := now.Add(ttl).Unix()

		value := lockValue{
			LockID:    s.lockID,
			ProcessID: s.processID,
			HostID:    s.hostID,
			ExpiresAt: expiresAt,
		}
		raw, err := value.Marshal()
		if err != nil {
			return fmt.Errorf("encode lock value: %w", err)
		}

		var query string
		switch s.db.Name() {
		case "sqlite":
			query = `
				UPDATE kv
				SET value = ?
				WHERE key = ?
				AND json_extract(value, '$.lock_id') = ?
				AND json_extract(value, '$.expires_at') > ?
			`
		case "postgres":
			query = `
				UPDATE kv
				SET value = $1
				WHERE key = $2
				AND value::json->>'lock_id' = $3
				AND ((value::json->>'expires_at')::bigint > $4)
			`
		default:
			return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
		}

		opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
		res := s.db.WithContext(opCtx).Exec(query, raw, lockKey, s.lockID, nowUnix)
		cancel()

		switch {
		case res.Error != nil:
			lastErr = fmt.Errorf("lock renewal failed: %w", res.Error)
		case res.RowsAffected == 0:
			// Must be after checking res.Error
			return ErrLockLost
		default:
			slog.Debug("Renewed application lock",
				slog.Int64("process_id", s.processID),
				slog.String("host_id", s.hostID),
			)
			return nil
		}

		// Wait before next attempt or cancel if context is done
		if attempt < renewRetries {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(1 * time.Second):
			}
		}
	}

	return lastErr
}
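How a caller drives this service is not shown in the diff; the sketch below is a minimal, hypothetical lifecycle under the API above (the function runWithAppLock, the db handle, and the error handling are assumptions, not part of this change):

	// Acquire the lock at startup, keep renewing it in the background,
	// and release it on shutdown.
	func runWithAppLock(ctx context.Context, db *gorm.DB) error {
		lock := NewAppLockService(db)

		waitUntil, err := lock.Acquire(ctx, false)
		if err != nil {
			return err // typically ErrLockUnavailable
		}
		defer lock.Release(context.Background())

		// After a forced steal, waitUntil can be in the future.
		if wait := time.Until(waitUntil); wait > 0 {
			time.Sleep(wait)
		}

		// RunRenewal blocks until ctx is canceled or the lock is lost.
		return lock.RunRenewal(ctx)
	}
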
backend/internal/service/app_lock_service_test.go (new file, 189 lines)
@@ -0,0 +1,189 @@
package service

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"gorm.io/gorm"

	"github.com/pocket-id/pocket-id/backend/internal/model"
	testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)

func newTestAppLockService(t *testing.T, db *gorm.DB) *AppLockService {
	t.Helper()

	return &AppLockService{
		db:        db,
		processID: 1,
		hostID:    "test-host",
		lockID:    "a13c7673-c7ae-49f1-9112-2cd2d0d4b0c1",
	}
}

func insertLock(t *testing.T, db *gorm.DB, value lockValue) {
	t.Helper()

	raw, err := value.Marshal()
	require.NoError(t, err)

	err = db.Create(&model.KV{Key: lockKey, Value: &raw}).Error
	require.NoError(t, err)
}

func readLockValue(t *testing.T, db *gorm.DB) lockValue {
	t.Helper()

	var row model.KV
	err := db.Take(&row, "key = ?", lockKey).Error
	require.NoError(t, err)

	require.NotNil(t, row.Value)

	var value lockValue
	err = value.Unmarshal(*row.Value)
	require.NoError(t, err)

	return value
}

func TestAppLockServiceAcquire(t *testing.T) {
	t.Run("creates new lock when none exists", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		_, err := service.Acquire(context.Background(), false)
		require.NoError(t, err)

		stored := readLockValue(t, db)
		require.Equal(t, service.processID, stored.ProcessID)
		require.Equal(t, service.hostID, stored.HostID)
		require.Greater(t, stored.ExpiresAt, time.Now().Unix())
	})

	t.Run("returns ErrLockUnavailable when lock held by another process", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		existing := lockValue{
			ProcessID: 99,
			HostID:    "other-host",
			ExpiresAt: time.Now().Add(ttl).Unix(),
		}
		insertLock(t, db, existing)

		_, err := service.Acquire(context.Background(), false)
		require.ErrorIs(t, err, ErrLockUnavailable)

		current := readLockValue(t, db)
		require.Equal(t, existing, current)
	})

	t.Run("force acquisition steals lock", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		insertLock(t, db, lockValue{
			ProcessID: 99,
			HostID:    "other-host",
			ExpiresAt: time.Now().Unix(),
		})

		_, err := service.Acquire(context.Background(), true)
		require.NoError(t, err)

		stored := readLockValue(t, db)
		require.Equal(t, service.processID, stored.ProcessID)
		require.Equal(t, service.hostID, stored.HostID)
		require.Greater(t, stored.ExpiresAt, time.Now().Unix())
	})
}

func TestAppLockServiceRelease(t *testing.T) {
	t.Run("removes owned lock", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		_, err := service.Acquire(context.Background(), false)
		require.NoError(t, err)

		err = service.Release(context.Background())
		require.NoError(t, err)

		var row model.KV
		err = db.Take(&row, "key = ?", lockKey).Error
		require.ErrorIs(t, err, gorm.ErrRecordNotFound)
	})

	t.Run("ignores lock held by another owner", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		existing := lockValue{
			ProcessID: 2,
			HostID:    "other-host",
			ExpiresAt: time.Now().Add(ttl).Unix(),
		}
		insertLock(t, db, existing)

		err := service.Release(context.Background())
		require.NoError(t, err)

		stored := readLockValue(t, db)
		require.Equal(t, existing, stored)
	})
}

func TestAppLockServiceRenew(t *testing.T) {
	t.Run("extends expiration when lock is still owned", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		_, err := service.Acquire(context.Background(), false)
		require.NoError(t, err)

		before := readLockValue(t, db)

		err = service.renew(context.Background())
		require.NoError(t, err)

		after := readLockValue(t, db)
		require.Equal(t, service.processID, after.ProcessID)
		require.Equal(t, service.hostID, after.HostID)
		require.GreaterOrEqual(t, after.ExpiresAt, before.ExpiresAt)
	})

	t.Run("returns ErrLockLost when lock is missing", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		err := service.renew(context.Background())
		require.ErrorIs(t, err, ErrLockLost)
	})

	t.Run("returns ErrLockLost when ownership changed", func(t *testing.T) {
		db := testutils.NewDatabaseForTest(t)
		service := newTestAppLockService(t, db)

		_, err := service.Acquire(context.Background(), false)
		require.NoError(t, err)

		// Simulate a different process taking the lock.
		newOwner := lockValue{
			ProcessID: 9,
			HostID:    "stolen-host",
			ExpiresAt: time.Now().Add(ttl).Unix(),
		}
		raw, marshalErr := newOwner.Marshal()
		require.NoError(t, marshalErr)
		updateErr := db.Model(&model.KV{}).
			Where("key = ?", lockKey).
			Update("value", raw).Error
		require.NoError(t, updateErr)

		err = service.renew(context.Background())
		require.ErrorIs(t, err, ErrLockLost)
	})
}
@@ -7,14 +7,12 @@ import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"log/slog"
	"path"
	"time"

	"github.com/fxamacker/cbor/v2"
	"github.com/go-webauthn/webauthn/protocol"
	"github.com/lestrrat-go/jwx/v3/jwa"
	"github.com/lestrrat-go/jwx/v3/jwk"
@@ -36,15 +34,17 @@ type TestService struct {
	appConfigService *AppConfigService
	ldapService      *LdapService
	fileStorage      storage.FileStorage
	appLockService   *AppLockService
	externalIdPKey   jwk.Key
}

func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService, fileStorage storage.FileStorage) (*TestService, error) {
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService, appLockService *AppLockService, fileStorage storage.FileStorage) (*TestService, error) {
	s := &TestService{
		db:               db,
		appConfigService: appConfigService,
		jwtService:       jwtService,
		ldapService:      ldapService,
		appLockService:   appLockService,
		fileStorage:      fileStorage,
	}
	err := s.initExternalIdP()
@@ -288,8 +288,8 @@ func (s *TestService) SeedDatabase(baseURL string) error {
	// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
	// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)

	publicKeyPasskey1, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
	publicKeyPasskey2, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
	publicKeyPasskey1, _ := base64.StdEncoding.DecodeString("pQMmIAEhWCDBw6jkpXXr0pHrtAQetxiR5cTcILG/YGDCdKrhVhNDHCJYIIu12YrF6B7Frwl3AUqEpdrYEwj3Fo3XkGgvrBIJEUmGAQI=")
	publicKeyPasskey2, _ := base64.StdEncoding.DecodeString("pSJYIPmc+FlEB0neERqqscxKckGF8yq1AYrANiloshAUAouHAQIDJiABIVggj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbI=")
	webauthnCredentials := []model.WebauthnCredential{
		{
			Name: "Passkey 1",
@@ -318,6 +318,10 @@ func (s *TestService) SeedDatabase(baseURL string) error {
		Challenge:        "challenge",
		ExpiresAt:        datatype.DateTime(time.Now().Add(1 * time.Hour)),
		UserVerification: "preferred",
		CredentialParams: model.CredentialParameters{
			{Type: "public-key", Algorithm: -7},
			{Type: "public-key", Algorithm: -257},
		},
	}
	if err := tx.Create(&webauthnSession).Error; err != nil {
		return err
@@ -327,9 +331,10 @@ func (s *TestService) SeedDatabase(baseURL string) error {
		Base: model.Base{
			ID: "5f1fa856-c164-4295-961e-175a0d22d725",
		},
		Name:   "Test API Key",
		Key:    "6c34966f57ef2bb7857649aff0e7ab3ad67af93c846342ced3f5a07be8706c20",
		UserID: users[0].ID,
		Name:      "Test API Key",
		Key:       "6c34966f57ef2bb7857649aff0e7ab3ad67af93c846342ced3f5a07be8706c20",
		UserID:    users[0].ID,
		ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)),
	}
	if err := tx.Create(&apiKey).Error; err != nil {
		return err
@@ -344,9 +349,6 @@ func (s *TestService) SeedDatabase(baseURL string) error {
			ExpiresAt:  datatype.DateTime(time.Now().Add(24 * time.Hour)),
			UsageLimit: 1,
			UsageCount: 0,
			UserGroups: []model.UserGroup{
				userGroups[0],
			},
		},
		{
			Base: model.Base{
@@ -382,6 +384,20 @@ func (s *TestService) SeedDatabase(baseURL string) error {
		}
	}

	keyValues := []model.KV{
		{
			Key: jwkutils.PrivateKeyDBKey,
			// {"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}
			Value: utils.Ptr("7d/5hl7diJ2rnFL14hEAQf9tzpu29aqXQ8jpJ2iqqKUNFZpdOkEpud0CmRv4H3r8yyk2u/Gqqj9klSy58DJkYXGF5PAYgLyoBIb7L3JXWRbxg4cQ3QJCug13l2OTmpAKoVc+rmX8c3j3h1sNqyJ+7Ql5sS0jSeyiYgIsFNCdnK5alBDyvtcpe/QDpklmP4JCeVpvmf2rLGplk3g5UO5ydJ8UiDXxfDmi+gF6NKJvrGnnah8Ar3G/x88z+tTJtp0DIQFwxXwUM2XZqzEVGm8K2r0w5o9/Keh6bBBaiuH2C78ZOaijGV3DovhR+e9J0cYUYGwT42MZMx9fSWQ/lvWGGnf+Uq3MXJfjWSREfhkp8KTQwR9F7+dnVJWswOEk7jPR8I7hCWTMxJyvaFX3wgAXIVmhrgXZQQbYOqTt56IoqUl0xOJku8dA8opg2UcLlmmuOh6+hfkXKsiiS/H/9c1BVIGj1fCOiT6IePh4wKKSTbwJnPD5EKmdJpgTsUpjcDnXQKY4ReO0UpdRdKxwRDDLeQuG6j+ljGxR9GPudCU9Nmci6rFVI6n5LWYkQxBA1O73RpmXRZPDzntDfpXMEonkmSvOoxaCK2Id7CRKMdqvR0kEouwnhk5WSFtsfi3sA0pkXzPFxwZeWM8vFtbffZOZzXaOhxCOfcj1NClZohlZhyc4jvkxmrpY7PSaAzih0AmHI7y0LYFi6fZu/K4EheVa1+KF55nWZ8ARikHMWKAKkyExkTak7xyN884TDmzURRaPlQg4jzQte5WMNjAG/hlHibdMBNvgwiYd49ZxteJ8ABdbiXVRl+2JGbdjl2ubpQZwOn7bJKlqO56bIwsZ+e4+pXsuOGdBahkHrUjtMEmH3DZbGc6CJLbcmdhdpApLQRRcLAazxJhzAwJ47FRYsHsj57LnYNvmcKdIxw8rxCdLUuzz95uw0T3ankEO5J9sjem+HMEuKdwXK1UcuOn2rjR8Sd/BuvQmeso27dFbPXqXYNS90Ml45YyTvcKSiopD181oZR703TFUSpR7dsiqROMr+p/2jN9h6a8WbQ8xpksyclaQByY/M77AssbXnG6wfhRsntNIINCZLbBnjXOyz6ZHIC5K4tSTdcnWaiYPeRPQmnw9UUvHAcNU2yMWsy0eU377yDS0WstTxOdQutTdkczl8kv5Lo26JiEK7mSIuRK19ffF9Zz8FG8+eKv5zdyIPjyQRDYBysUoDv5huKe2eoxJu/MWS2Pql/ZtUGeD6Ozm3mCvh0vQ9ceagBkY6Ocm3du0ziAKP29Ri0mjg4DizVorbLzsh+EQH/s2Pi9MnjUZDlEmuLl2Xfp7/w4j/8u0N0tVR70VDFuGdKpTjFY3vS8EJrPtyMTM51x1D9rb8gIql8aR/rJw4YF+huxg1mv5n6+tGVqg5msbPmF12eJijP4lkmaRwIpLW5pJTtaDkUj7uOeu1mm4k+Dt5nh0/0jPHzrv6bcTCcbV7UjMHDoTXXqEpFAAJ66rHR7zdAJu+YKsnTIZyLmOpcowq7LL8G9qTvV0OSpyQWUIavRSgbDHFqEqRs+JU94jAzkq8nCY5MTd9m5sIv9InfdT3k+pwpsE/FKge8nghFLtbUrafGkzTky8SE2druvVcIvbfXMfLIKRUYjJgnWc0gQzF5J6pzXM7D2r/RG6JDzASqjlbURq6v9bhNerlOVdMujWKEEVcKWIzlbt4RkihRjM8AUqIZQOyicGQ+4yfIjAHw5viuABONYs3OIWULnFqJxdvS9rNKhfxSjIq9cfqyzevq2xrRoMXEonobh6M3bD2Vang8OAeVeD1OXWPERi4pepCYFS9RJ/Xa/UWxptsqSNuGcb3fAzQSmLpXLGdWRoKXvSe7EYgc0bGcLOjSTu5RURKo+EF9i4KT9EJauf6VXw5dTf/CCIJRXE1bWzXhSCFYntohYhX2ldOCDYpi/jFBC6Vtkw0ud3/xq8Nmhd5gUk+SpngByCZH3Pm3H+jvlbMpiqkDkm1v74hDX13Xhrcw2eWyuqKBVoRCCniUvwpYNbGvBfjC6Hcizv0Aybciwj+4nybt5EPoEUm6S6Gs7fG7QpPdvrzpAxX70MlmdkF/gwyuhbEeJhLK+WL7qAsN5CvHPzVbsIf90x+nGTtMJPgpxVr0tJMj+vprXV4WxutfARBiOnqe58MhA857sd+MzKBgKnoLOBRTiC3qc/0/ULwbG2HCCD7nmwzz7M4nUuMvo8rgS7z0BF68OClT8X3JwSXbL5Wg=="),
		},
	}

	for _, kv := range keyValues {
		if err := tx.Create(&kv).Error; err != nil {
			return err
		}
	}

	return nil
})
@@ -467,47 +483,29 @@ func (s *TestService) ResetAppConfig(ctx context.Context) error {
		return err
	}

	// Manually set instance ID
	err = s.appConfigService.UpdateAppConfigValues(ctx, "instanceId", "test-instance-id")
	if err != nil {
		return err
	}

	// Reload the app config from the database after resetting the values
	return s.appConfigService.LoadDbConfig(ctx)
	err = s.appConfigService.LoadDbConfig(ctx)
	if err != nil {
		return err
	}

	// Reload the JWK
	if err := s.jwtService.LoadOrGenerateKey(); err != nil {
		return err
	}

	return nil
}

func (s *TestService) SetJWTKeys() {
	const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`

	privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
	_ = s.jwtService.SetKey(privateKey)
}

// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
func (s *TestService) getCborPublicKey(base64PublicKey string) ([]byte, error) {
	decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey)
	if err != nil {
		return nil, fmt.Errorf("failed to decode base64 key: %w", err)
	}
	pubKey, err := x509.ParsePKIXPublicKey(decodedKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse public key: %w", err)
	}

	ecdsaPubKey, ok := pubKey.(*ecdsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("not an ECDSA public key")
	}

	coseKey := map[int]interface{}{
		1:  2,                     // Key type: EC2
		3:  -7,                    // Algorithm: ECDSA with SHA-256
		-1: 1,                     // Curve: P-256
		-2: ecdsaPubKey.X.Bytes(), // X coordinate
		-3: ecdsaPubKey.Y.Bytes(), // Y coordinate
	}

	cborPublicKey, err := cbor.Marshal(coseKey)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal COSE key: %w", err)
	}

	return cborPublicKey, nil
func (s *TestService) ResetLock(ctx context.Context) error {
	_, err := s.appLockService.Acquire(ctx, true)
	return err
}

// SyncLdap triggers an LDAP synchronization
@@ -534,7 +532,7 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
		"ldapAttributeGroupUniqueIdentifier": "uuid",
		"ldapAttributeGroupName":             "uid",
		"ldapAttributeGroupMember":           "member",
		"ldapAttributeAdminGroup":            "admin_group",
		"ldapAdminGroupName":                 "admin_group",
		"ldapSoftDeleteUsers":                "true",
		"ldapEnabled":                        "true",
	}
@@ -78,7 +78,7 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr

	data := &email.TemplateData[V]{
		AppName: dbConfig.AppName.Value,
		LogoURL: common.EnvConfig.AppURL + "/api/application-images/email",
		LogoURL: common.EnvConfig.AppURL + "/api/application-images/logo",
		Data:    tData,
	}

backend/internal/service/export_service.go (new file, 217 lines)
@@ -0,0 +1,217 @@
package service

import (
	"archive/zip"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"path/filepath"

	"gorm.io/gorm"

	datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
	"github.com/pocket-id/pocket-id/backend/internal/storage"
	"github.com/pocket-id/pocket-id/backend/internal/utils"
)

// ExportService handles exporting Pocket ID data into a ZIP archive.
type ExportService struct {
	db      *gorm.DB
	storage storage.FileStorage
}

func NewExportService(db *gorm.DB, storage storage.FileStorage) *ExportService {
	return &ExportService{
		db:      db,
		storage: storage,
	}
}

// ExportToZip performs the full export process and writes the ZIP data to the given writer.
func (s *ExportService) ExportToZip(ctx context.Context, w io.Writer) error {
	dbData, err := s.extractDatabase()
	if err != nil {
		return err
	}

	return s.writeExportZipStream(ctx, w, dbData)
}

// extractDatabase reads all tables into a DatabaseExport struct
func (s *ExportService) extractDatabase() (DatabaseExport, error) {
	schema, err := utils.LoadDBSchemaTypes(s.db)
	if err != nil {
		return DatabaseExport{}, fmt.Errorf("failed to load schema types: %w", err)
	}

	version, err := s.schemaVersion()
	if err != nil {
		return DatabaseExport{}, err
	}

	out := DatabaseExport{
		Provider: s.db.Name(),
		Version:  version,
		Tables:   map[string][]map[string]any{},
		// These tables need to be inserted in a specific order because of foreign key constraints
		// Not all tables are listed here, because not all tables are order-dependent
		TableOrder: []string{"users", "user_groups", "oidc_clients"},
	}

	for table := range schema {
		if table == "storage" || table == "schema_migrations" {
			continue
		}
		err = s.dumpTable(table, schema[table], &out)
		if err != nil {
			return DatabaseExport{}, err
		}
	}

	return out, nil
}

func (s *ExportService) schemaVersion() (uint, error) {
	var version uint
	if err := s.db.Raw("SELECT version FROM schema_migrations").Row().Scan(&version); err != nil {
		return 0, fmt.Errorf("failed to query schema version: %w", err)
	}
	return version, nil
}

// dumpTable selects all rows from a table and appends them to out.Tables
func (s *ExportService) dumpTable(table string, types utils.DBSchemaTableTypes, out *DatabaseExport) error {
	rows, err := s.db.Raw("SELECT * FROM " + table).Rows()
	if err != nil {
		return fmt.Errorf("failed to read table %s: %w", table, err)
	}
	defer rows.Close()

	cols, _ := rows.Columns()
	if len(cols) != len(types) {
		// Should never happen...
		return fmt.Errorf("mismatched columns in table (%d) and schema (%d)", len(cols), len(types))
	}

	for rows.Next() {
		vals := s.getScanValuesForTable(cols, types)
		err = rows.Scan(vals...)
		if err != nil {
			return fmt.Errorf("failed to scan row in table %s: %w", table, err)
		}

		rowMap := make(map[string]any, len(cols))
		for i, col := range cols {
			rowMap[col] = vals[i]
		}

		// Skip the app lock row in the kv table
		if table == "kv" {
			if keyPtr, ok := rowMap["key"].(*string); ok && keyPtr != nil && *keyPtr == lockKey {
				continue
			}
		}

		out.Tables[table] = append(out.Tables[table], rowMap)
	}

	return rows.Err()
}

func (s *ExportService) getScanValuesForTable(cols []string, types utils.DBSchemaTableTypes) []any {
	res := make([]any, len(cols))
	for i, col := range cols {
		// Store a pointer
		// Note: don't create a helper function for this switch, because it would return type "any" and mess everything up
		// If the column is nullable, we need a pointer to a pointer!
		switch types[col].Name {
		case "boolean", "bool":
			var x bool
			if types[col].Nullable {
				res[i] = utils.Ptr(utils.Ptr(x))
			} else {
				res[i] = utils.Ptr(x)
			}
		case "blob", "bytea", "jsonb":
			// Treat jsonb columns as binary too
			var x []byte
			if types[col].Nullable {
				res[i] = utils.Ptr(utils.Ptr(x))
			} else {
				res[i] = utils.Ptr(x)
			}
		case "timestamp", "timestamptz", "timestamp with time zone", "datetime":
			var x datatype.DateTime
			if types[col].Nullable {
				res[i] = utils.Ptr(utils.Ptr(x))
			} else {
				res[i] = utils.Ptr(x)
			}
		case "integer", "int", "bigint":
			var x int64
			if types[col].Nullable {
				res[i] = utils.Ptr(utils.Ptr(x))
			} else {
				res[i] = utils.Ptr(x)
			}
		default:
			// Treat everything else as a string (including the "numeric" type)
			var x string
			if types[col].Nullable {
				res[i] = utils.Ptr(utils.Ptr(x))
			} else {
				res[i] = utils.Ptr(x)
			}
		}
	}

	return res
}

func (s *ExportService) writeExportZipStream(ctx context.Context, w io.Writer, dbData DatabaseExport) error {
	zipWriter := zip.NewWriter(w)

	// Add database.json
	jsonWriter, err := zipWriter.Create("database.json")
	if err != nil {
		return fmt.Errorf("failed to create database.json in zip: %w", err)
	}

	jsonEncoder := json.NewEncoder(jsonWriter)
	jsonEncoder.SetEscapeHTML(false)

	if err := jsonEncoder.Encode(dbData); err != nil {
		return fmt.Errorf("failed to encode database.json: %w", err)
	}

	// Add uploaded files
	if err := s.addUploadsToZip(ctx, zipWriter); err != nil {
		return err
	}

	return zipWriter.Close()
}

// addUploadsToZip adds all files from the storage to the ZIP archive under the "uploads/" directory
func (s *ExportService) addUploadsToZip(ctx context.Context, zipWriter *zip.Writer) error {
	return s.storage.Walk(ctx, "/", func(p storage.ObjectInfo) error {
		zipPath := filepath.Join("uploads", p.Path)

		w, err := zipWriter.Create(zipPath)
		if err != nil {
			return fmt.Errorf("failed to create zip entry for %s: %w", zipPath, err)
		}

		f, _, err := s.storage.Open(ctx, p.Path)
		if err != nil {
			return fmt.Errorf("failed to open file %s: %w", zipPath, err)
		}
		defer f.Close()

		if _, err := io.Copy(w, f); err != nil {
			return fmt.Errorf("failed to copy file %s into zip: %w", zipPath, err)
		}
		return nil
	})
}
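A caller for this service is not part of the diff; below is a minimal, hypothetical sketch of streaming a full export to a file on disk (the path, the db and fileStorage variables, and the os import are assumptions):

	// Write the export ZIP directly to a file; ExportToZip streams into any io.Writer.
	f, err := os.Create("/tmp/pocket-id-export.zip")
	if err != nil {
		return err
	}
	defer f.Close()

	exporter := NewExportService(db, fileStorage)
	if err := exporter.ExportToZip(ctx, f); err != nil {
		return err
	}

Because the ZIP is written as a stream, the export never needs to buffer the whole archive in memory; only the database snapshot is held as a struct before encoding.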
backend/internal/service/import_service.go (new file, 264 lines)
@@ -0,0 +1,264 @@
package service

import (
	"archive/zip"
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"slices"
	"strings"

	"gorm.io/gorm"

	datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
	"github.com/pocket-id/pocket-id/backend/internal/storage"
	"github.com/pocket-id/pocket-id/backend/internal/utils"
)

// ImportService handles importing Pocket ID data from an exported ZIP archive.
type ImportService struct {
	db      *gorm.DB
	storage storage.FileStorage
}

type DatabaseExport struct {
	Provider   string                      `json:"provider"`
	Version    uint                        `json:"version"`
	Tables     map[string][]map[string]any `json:"tables"`
	TableOrder []string                    `json:"tableOrder"`
}

func NewImportService(db *gorm.DB, storage storage.FileStorage) *ImportService {
	return &ImportService{
		db:      db,
		storage: storage,
	}
}

// ImportFromZip performs the full import process from the given ZIP reader.
func (s *ImportService) ImportFromZip(ctx context.Context, r *zip.Reader) error {
	dbData, err := processZipDatabaseJson(r.File)
	if err != nil {
		return err
	}

	err = s.ImportDatabase(dbData)
	if err != nil {
		return err
	}

	err = s.importUploads(ctx, r.File)
	if err != nil {
		return err
	}

	return nil
}

// ImportDatabase only imports the database data from the given DatabaseExport struct.
func (s *ImportService) ImportDatabase(dbData DatabaseExport) error {
	err := s.resetSchema(dbData.Version, dbData.Provider)
	if err != nil {
		return err
	}

	err = s.insertData(dbData)
	if err != nil {
		return err
	}

	return nil
}

// processZipDatabaseJson extracts database.json from the ZIP archive
func processZipDatabaseJson(files []*zip.File) (dbData DatabaseExport, err error) {
	for _, f := range files {
		if f.Name == "database.json" {
			return parseDatabaseJsonStream(f)
		}
	}
	return dbData, errors.New("database.json not found in the ZIP file")
}

func parseDatabaseJsonStream(f *zip.File) (dbData DatabaseExport, err error) {
	rc, err := f.Open()
	if err != nil {
		return dbData, fmt.Errorf("failed to open database.json: %w", err)
	}
	defer rc.Close()

	err = json.NewDecoder(rc).Decode(&dbData)
	if err != nil {
		return dbData, fmt.Errorf("failed to decode database.json: %w", err)
	}

	return dbData, nil
}

// importUploads imports files from the uploads/ directory in the ZIP archive
func (s *ImportService) importUploads(ctx context.Context, files []*zip.File) error {
	const maxFileSize = 50 << 20 // 50 MiB
	const uploadsPrefix = "uploads/"

	for _, f := range files {
		if !strings.HasPrefix(f.Name, uploadsPrefix) {
			continue
		}

		if f.UncompressedSize64 > maxFileSize {
			return fmt.Errorf("file %s too large (%d bytes)", f.Name, f.UncompressedSize64)
		}

		targetPath := strings.TrimPrefix(f.Name, uploadsPrefix)
		if strings.HasSuffix(f.Name, "/") || targetPath == "" {
			continue // Skip directories
		}

		err := s.storage.DeleteAll(ctx, targetPath)
		if err != nil {
			return fmt.Errorf("failed to delete existing file %s: %w", targetPath, err)
		}

		rc, err := f.Open()
		if err != nil {
			return err
		}

		buf, err := io.ReadAll(rc)
		rc.Close()
		if err != nil {
			return fmt.Errorf("read file %s: %w", f.Name, err)
		}

		err = s.storage.Save(ctx, targetPath, bytes.NewReader(buf))
		if err != nil {
			return fmt.Errorf("failed to save file %s: %w", targetPath, err)
		}
	}

	return nil
}

// resetSchema drops the existing schema and migrates to the target version
func (s *ImportService) resetSchema(targetVersion uint, exportDbProvider string) error {
	sqlDb, err := s.db.DB()
	if err != nil {
		return fmt.Errorf("failed to get sql.DB: %w", err)
	}

	m, err := utils.GetEmbeddedMigrateInstance(sqlDb)
	if err != nil {
		return fmt.Errorf("failed to get migrate instance: %w", err)
	}

	err = m.Drop()
	if err != nil {
		return fmt.Errorf("failed to drop existing schema: %w", err)
	}

	// Needs to be called again to re-create the schema_migrations table
	m, err = utils.GetEmbeddedMigrateInstance(sqlDb)
	if err != nil {
		return fmt.Errorf("failed to get migrate instance: %w", err)
	}

	err = m.Migrate(targetVersion)
	if err != nil {
		return fmt.Errorf("migration failed: %w", err)
	}

	return nil
}

// insertData populates the DB with the imported data
func (s *ImportService) insertData(dbData DatabaseExport) error {
	schema, err := utils.LoadDBSchemaTypes(s.db)
	if err != nil {
		return fmt.Errorf("failed to load schema types: %w", err)
	}

	return s.db.Transaction(func(tx *gorm.DB) error {
		// Iterate through all tables
		// Some tables need to be processed in order
		tables := make([]string, 0, len(dbData.Tables))
		tables = append(tables, dbData.TableOrder...)

		for t := range dbData.Tables {
			// Skip tables already present where the order matters
			// Also skip the schema_migrations table
			if slices.Contains(dbData.TableOrder, t) || t == "schema_migrations" {
				continue
			}
			tables = append(tables, t)
		}

		// Insert rows
		for _, table := range tables {
			for _, row := range dbData.Tables[table] {
				err = normalizeRowWithSchema(row, table, schema)
				if err != nil {
					return fmt.Errorf("failed to normalize row for table '%s': %w", table, err)
				}
				err = tx.Table(table).Create(row).Error
				if err != nil {
					return fmt.Errorf("failed inserting into table '%s': %w", table, err)
				}
			}
		}
		return nil
	})
}

// normalizeRowWithSchema converts row values based on the DB schema
func normalizeRowWithSchema(row map[string]any, table string, schema utils.DBSchemaTypes) error {
	if schema[table] == nil {
		return fmt.Errorf("schema not found for table '%s'", table)
	}

	for col, val := range row {
		if val == nil {
			// If the value is nil, skip the column
			continue
		}

		colType := schema[table][col]

		switch colType.Name {
		case "timestamp", "timestamptz", "timestamp with time zone", "datetime":
			// Dates are stored as strings
			str, ok := val.(string)
			if !ok {
				return fmt.Errorf("value for column '%s/%s' was expected to be a string, but was '%T'", table, col, val)
			}
			d, err := datatype.DateTimeFromString(str)
			if err != nil {
				return fmt.Errorf("failed to decode value for column '%s/%s' as timestamp: %w", table, col, err)
			}
			row[col] = d

		case "blob", "bytea", "jsonb":
			// Binary and jsonb data are stored in the file as base64-encoded strings
			str, ok := val.(string)
			if !ok {
				return fmt.Errorf("value for column '%s/%s' was expected to be a string, but was '%T'", table, col, val)
			}
			b, err := base64.StdEncoding.DecodeString(str)
			if err != nil {
				return fmt.Errorf("failed to decode value for column '%s/%s' from base64: %w", table, col, err)
			}

			// For jsonb, we additionally cast to json.RawMessage
			if colType.Name == "jsonb" {
				row[col] = json.RawMessage(b)
			} else {
				row[col] = b
			}
		}
	}

	return nil
}
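A matching caller for the import side is likewise not in the diff; a minimal, hypothetical sketch of restoring from an export archive on disk (the path, db, and fileStorage are assumptions) could look like this:

	// zip.OpenReader returns a *zip.ReadCloser whose embedded Reader
	// is what ImportFromZip expects.
	zr, err := zip.OpenReader("/tmp/pocket-id-export.zip")
	if err != nil {
		return err
	}
	defer zr.Close()

	importer := NewImportService(db, fileStorage)
	if err := importer.ImportFromZip(ctx, &zr.Reader); err != nil {
		return err
	}

Note the destructive semantics implied by resetSchema above: the import drops the existing schema before migrating and inserting, so it is a full restore rather than a merge.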
@@ -18,14 +18,6 @@ import (
)

const (
	// PrivateKeyFile is the path in the data/keys folder where the key is stored
	// This is a JSON file containing a key encoded as JWK
	PrivateKeyFile = "jwt_private_key.json"

	// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
	// This is an encrypted JSON file containing a key encoded as JWK
	PrivateKeyFileEncrypted = "jwt_private_key.json.enc"

	// KeyUsageSigning is the usage for the private keys, for the "use" property
	KeyUsageSigning = "sig"

@@ -56,6 +48,7 @@ const (
)

type JwtService struct {
	db               *gorm.DB
	envConfig        *common.EnvConfigSchema
	privateKey       jwk.Key
	keyId            string
@@ -66,7 +59,6 @@ type JwtService struct {
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService, error) {
	service := &JwtService{}

	// Ensure keys are generated or loaded
	err := service.init(db, appConfigService, &common.EnvConfig)
	if err != nil {
		return nil, err
@@ -78,14 +70,15 @@ func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
	s.appConfigService = appConfigService
	s.envConfig = envConfig
	s.db = db

	// Ensure keys are generated or loaded
	return s.loadOrGenerateKey(db)
	return s.LoadOrGenerateKey()
}

func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
func (s *JwtService) LoadOrGenerateKey() error {
	// Get the key provider
	keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
	keyProvider, err := jwkutils.GetKeyProvider(s.db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
	if err != nil {
		return fmt.Errorf("failed to get key provider: %w", err)
	}
@@ -93,7 +86,7 @@ func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
	// Try loading a key
	key, err := keyProvider.LoadKey()
	if err != nil {
		return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
		return fmt.Errorf("failed to load key: %w", err)
	}

	// If we have a key, store it in the object and we're done
@@ -114,7 +107,7 @@ func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
	// Save the newly-generated key
	err = keyProvider.SaveKey(s.privateKey)
	if err != nil {
		return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
		return fmt.Errorf("failed to save private key: %w", err)
	}

	return nil

(File diff suppressed because it is too large)
@@ -371,7 +371,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
	// Check whether the user is an admin by checking if they are a member of the admin group
	isAdmin := false
	for _, group := range value.GetAttributeValues("memberOf") {
		if getDNProperty(dbConfig.LdapAttributeGroupName.Value, group) == dbConfig.LdapAttributeAdminGroup.Value {
		if getDNProperty(dbConfig.LdapAttributeGroupName.Value, group) == dbConfig.LdapAdminGroupName.Value {
			isAdmin = true
			break
		}

@@ -15,7 +15,6 @@ import (
	"net/http"
	"net/url"
	"path"
	"regexp"
	"slices"
	"strings"
	"time"
@@ -226,7 +225,7 @@ func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID,

// IsUserGroupAllowedToAuthorize checks whether the user's groups allow them to authorize the client
func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client model.OidcClient) bool {
	if !client.IsGroupRestricted {
	if len(client.AllowedUserGroups) == 0 {
		return true
	}

@@ -816,7 +815,6 @@ func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClien
	client.PkceEnabled = input.IsPublic || input.PkceEnabled
	client.RequiresReauthentication = input.RequiresReauthentication
	client.LaunchURL = input.LaunchURL
	client.IsGroupRestricted = input.IsGroupRestricted

	// Credentials
	client.Credentials.FederatedIdentities = make([]model.OidcClientFederatedIdentity, len(input.Credentials.FederatedIdentities))
@@ -1197,7 +1195,7 @@ func (s *OidcService) getCallbackURL(client *model.OidcClient, inputCallbackURL

	// If URLs are already configured, validate against them
	if len(client.CallbackURLs) > 0 {
		matched, err := s.getCallbackURLFromList(client.CallbackURLs, inputCallbackURL)
		matched, err := utils.GetCallbackURLFromList(client.CallbackURLs, inputCallbackURL)
		if err != nil {
			return "", err
		} else if matched == "" {
@@ -1220,7 +1218,7 @@ func (s *OidcService) getLogoutCallbackURL(client *model.OidcClient, inputLogout
		return client.LogoutCallbackURLs[0], nil
	}

	matched, err := s.getCallbackURLFromList(client.LogoutCallbackURLs, inputLogoutCallbackURL)
	matched, err := utils.GetCallbackURLFromList(client.LogoutCallbackURLs, inputLogoutCallbackURL)
	if err != nil {
		return "", err
	} else if matched == "" {
@@ -1230,21 +1228,6 @@ func (s *OidcService) getLogoutCallbackURL(client *model.OidcClient, inputLogout
	return matched, nil
}

func (s *OidcService) getCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
	for _, callbackPattern := range urls {
		regexPattern := "^" + strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$"
		matched, err := regexp.MatchString(regexPattern, inputCallbackURL)
		if err != nil {
			return "", err
		}
		if matched {
			return inputCallbackURL, nil
		}
	}

	return "", nil
}

func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.OidcClient, callbackURL string, tx *gorm.DB) error {
	// Add the new callback URL to the existing list
	client.CallbackURLs = append(client.CallbackURLs, callbackURL)

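For context on the helper removed above (this diff moves the same logic to utils.GetCallbackURLFromList), the wildcard matching works by escaping the configured pattern and turning each "*" into ".*" in an anchored regex. A small illustration with a hypothetical URL:

	// "https://app.example.com/*" becomes "^https://app\.example\.com/.*$",
	// so any callback under that host and path prefix matches.
	pattern := regexp.QuoteMeta("https://app.example.com/*") // escapes "." and "*"
	re := "^" + strings.ReplaceAll(pattern, `\*`, ".*") + "$"
	matched, _ := regexp.MatchString(re, "https://app.example.com/callback") // true
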
@@ -148,6 +148,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
	var err error
	// Create a test database
	db := testutils.NewDatabaseForTest(t)
	common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef")

	// Create two JWKs for testing
	privateJWK, jwkSetJSON := generateTestECDSAKey(t)

@@ -253,18 +253,6 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
||||
return model.User{}, &common.UserEmailNotSetError{}
|
||||
}
|
||||
|
||||
var userGroups []model.UserGroup
|
||||
if len(input.UserGroupIds) > 0 {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id IN ?", input.UserGroupIds).
|
||||
Find(&userGroups).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
FirstName: input.FirstName,
|
||||
LastName: input.LastName,
|
||||
@@ -274,7 +262,6 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
||||
IsAdmin: input.IsAdmin,
|
||||
Locale: input.Locale,
|
||||
Disabled: input.Disabled,
|
||||
UserGroups: userGroups,
|
||||
}
|
||||
if input.LdapID != "" {
|
||||
user.LdapID = &input.LdapID
|
||||
@@ -298,13 +285,7 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
||||
|
||||
// Apply default groups and claims for new non-LDAP users
|
||||
if !isLdapSync {
|
||||
if len(input.UserGroupIds) == 0 {
|
||||
if err := s.applyDefaultGroups(ctx, &user, tx); err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.applyDefaultCustomClaims(ctx, &user, tx); err != nil {
|
||||
if err := s.applySignupDefaults(ctx, &user, tx); err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
}
|
||||
@@ -312,9 +293,10 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) applyDefaultGroups(ctx context.Context, user *model.User, tx *gorm.DB) error {
func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User, tx *gorm.DB) error {
	config := s.appConfigService.GetDbConfig()

	// Apply default user groups
	var groupIDs []string
	v := config.SignupDefaultUserGroupIDs.Value
	if v != "" && v != "[]" {
@@ -341,14 +323,10 @@ func (s *UserService) applyDefaultGroups(ctx context.Context, user *model.User,
			}
		}
	}
	return nil
}

func (s *UserService) applyDefaultCustomClaims(ctx context.Context, user *model.User, tx *gorm.DB) error {
	config := s.appConfigService.GetDbConfig()

	// Apply default custom claims
	var claims []dto.CustomClaimCreateDto
	v := config.SignupDefaultCustomClaims.Value
	v = config.SignupDefaultCustomClaims.Value
	if v != "" && v != "[]" {
		err := json.Unmarshal([]byte(v), &claims)
		if err != nil {
@@ -749,22 +727,12 @@ func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, user
		Error
}

func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int, userGroupIDs []string) (model.SignupToken, error) {
func (s *UserService) CreateSignupToken(ctx context.Context, ttl time.Duration, usageLimit int) (model.SignupToken, error) {
	signupToken, err := NewSignupToken(ttl, usageLimit)
	if err != nil {
		return model.SignupToken{}, err
	}

	var userGroups []model.UserGroup
	err = s.db.WithContext(ctx).
		Where("id IN ?", userGroupIDs).
		Find(&userGroups).
		Error
	if err != nil {
		return model.SignupToken{}, err
	}
	signupToken.UserGroups = userGroups

	err = s.db.WithContext(ctx).Create(signupToken).Error
	if err != nil {
		return model.SignupToken{}, err
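
With the narrowed signature, callers no longer attach group IDs to the token; group membership now flows through the signup defaults applied at user creation. A hypothetical call site (the service variable, TTL, and usage limit are illustrative, not from the diff):

	// Mint a signup token valid for 24 hours, redeemable up to 5 times.
	token, err := userService.CreateSignupToken(ctx, 24*time.Hour, 5)
	if err != nil {
		return err
	}
	_ = token // already persisted by the service; hand the token value to the invitee
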
@@ -787,11 +755,9 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
	}

	var signupToken model.SignupToken
	var userGroupIDs []string
	if tokenProvided {
		err := tx.
			WithContext(ctx).
			Preload("UserGroups").
			Where("token = ?", signupData.Token).
			Clauses(clause.Locking{Strength: "UPDATE"}).
			First(&signupToken).
@@ -806,19 +772,14 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
		if !signupToken.IsValid() {
			return model.User{}, "", &common.TokenInvalidOrExpiredError{}
		}

		for _, group := range signupToken.UserGroups {
			userGroupIDs = append(userGroupIDs, group.ID)
		}
	}

	userToCreate := dto.UserCreateDto{
		Username:     signupData.Username,
		Email:        signupData.Email,
		FirstName:    signupData.FirstName,
		LastName:     signupData.LastName,
		DisplayName:  strings.TrimSpace(signupData.FirstName + " " + signupData.LastName),
		UserGroupIds: userGroupIDs,
		Username:    signupData.Username,
		Email:       signupData.Email,
		FirstName:   signupData.FirstName,
		LastName:    signupData.LastName,
		DisplayName: strings.TrimSpace(signupData.FirstName + " " + signupData.LastName),
	}

	user, err := s.createUserInternal(ctx, userToCreate, false, tx)
@@ -859,7 +820,7 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd

func (s *UserService) ListSignupTokens(ctx context.Context, listRequestOptions utils.ListRequestOptions) ([]model.SignupToken, utils.PaginationResponse, error) {
	var tokens []model.SignupToken
	query := s.db.WithContext(ctx).Preload("UserGroups").Model(&model.SignupToken{})
	query := s.db.WithContext(ctx).Model(&model.SignupToken{})

	pagination, err := utils.PaginateFilterAndSort(listRequestOptions, query, &tokens)
	return tokens, pagination, err

199
backend/internal/utils/callback_url_util.go
Normal file
@@ -0,0 +1,199 @@
package utils

import (
	"net"
	"net/url"
	"path"
	"regexp"
	"strings"
)

// GetCallbackURLFromList returns the first callback URL that matches the input callback URL
func GetCallbackURLFromList(urls []string, inputCallbackURL string) (callbackURL string, err error) {
	// Special case for Loopback Interface Redirection. Quoting from RFC 8252 section 7.3:
	// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
	//
	// The authorization server MUST allow any port to be specified at the
	// time of the request for loopback IP redirect URIs, to accommodate
	// clients that obtain an available ephemeral port from the operating
	// system at the time of the request.
	loopbackRedirect := ""
	u, _ := url.Parse(inputCallbackURL)

	if u != nil && u.Scheme == "http" {
		host := u.Hostname()
		ip := net.ParseIP(host)
		if host == "localhost" || (ip != nil && ip.IsLoopback()) {
			loopbackRedirect = u.String()
			u.Host = host
			inputCallbackURL = u.String()
		}
	}

	for _, pattern := range urls {
		matches, err := matchCallbackURL(pattern, inputCallbackURL)
		if err != nil {
			return "", err
		} else if !matches {
			continue
		}

		if loopbackRedirect != "" {
			return loopbackRedirect, nil
		}
		return inputCallbackURL, nil
	}

	return "", nil
}

// matchCallbackURL checks if the input callback URL matches the given pattern.
// It supports wildcard matching for paths and query parameters.
//
// The base URL (scheme, userinfo, host, port) and query parameters support single '*' wildcards only,
// while the path supports both single '*' and double '**' wildcards.
func matchCallbackURL(pattern string, inputCallbackURL string) (matches bool, err error) {
	if pattern == inputCallbackURL || pattern == "*" {
		return true, nil
	}

	// Strip fragment part
	// The endpoint URI MUST NOT include a fragment component.
	// https://datatracker.ietf.org/doc/html/rfc6749#section-3.1.2
	pattern, _, _ = strings.Cut(pattern, "#")
	inputCallbackURL, _, _ = strings.Cut(inputCallbackURL, "#")

	// Store and strip query part
	var patternQuery url.Values
	if i := strings.Index(pattern, "?"); i >= 0 {
		patternQuery, err = url.ParseQuery(pattern[i+1:])
		if err != nil {
			return false, err
		}
		pattern = pattern[:i]
	}
	var inputQuery url.Values
	if i := strings.Index(inputCallbackURL, "?"); i >= 0 {
		inputQuery, err = url.ParseQuery(inputCallbackURL[i+1:])
		if err != nil {
			return false, err
		}
		inputCallbackURL = inputCallbackURL[:i]
	}

	// Split both pattern and input parts
	patternParts, patternPath := splitParts(pattern)
	inputParts, inputPath := splitParts(inputCallbackURL)

	// Verify everything except the path and query parameters
	if len(patternParts) != len(inputParts) {
		return false, nil
	}

	for i, patternPart := range patternParts {
		matched, err := path.Match(patternPart, inputParts[i])
		if err != nil || !matched {
			return false, err
		}
	}

	// Verify path with wildcard support
	matched, err := matchPath(patternPath, inputPath)
	if err != nil || !matched {
		return false, err
	}

	// Verify query parameters
	if len(patternQuery) != len(inputQuery) {
		return false, nil
	}

	for patternKey, patternValues := range patternQuery {
		inputValues, exists := inputQuery[patternKey]
		if !exists {
			return false, nil
		}

		if len(patternValues) != len(inputValues) {
			return false, nil
		}

		for i := range patternValues {
			matched, err := path.Match(patternValues[i], inputValues[i])
			if err != nil || !matched {
				return false, err
			}
		}
	}

	return true, nil
}

// matchPath matches the input path against the pattern with wildcard support
// Supported wildcards:
//
//	'*' matches any sequence of characters except '/'
//	'**' matches any sequence of characters including '/'
func matchPath(pattern string, input string) (matches bool, err error) {
	var regexPattern strings.Builder
	regexPattern.WriteString("^")

	runes := []rune(pattern)
	n := len(runes)

	for i := 0; i < n; {
		switch runes[i] {
		case '*':
			// Check if it's a ** (globstar)
			if i+1 < n && runes[i+1] == '*' {
				// globstar = .* (match slashes too)
				regexPattern.WriteString(".*")
				i += 2
			} else {
				// single * = [^/]* (no slash)
				regexPattern.WriteString(`[^/]*`)
				i++
			}
		default:
			regexPattern.WriteString(regexp.QuoteMeta(string(runes[i])))
			i++
		}
	}

	regexPattern.WriteString("$")

	matched, err := regexp.MatchString(regexPattern.String(), input)
	return matched, err
}

// splitParts splits the URL into parts by special characters and returns the path separately
func splitParts(s string) (parts []string, path string) {
	split := func(r rune) bool {
		return r == ':' || r == '/' || r == '[' || r == ']' || r == '@' || r == '.'
	}

	pathStart := -1

	// Look for scheme:// first
	if i := strings.Index(s, "://"); i >= 0 {
		// Look for the next slash after scheme://
		rest := s[i+3:]
		if j := strings.IndexRune(rest, '/'); j >= 0 {
			pathStart = i + 3 + j
		}
	} else {
		// Otherwise, first slash is path start
		pathStart = strings.IndexRune(s, '/')
	}

	if pathStart >= 0 {
		path = s[pathStart:]
		base := s[:pathStart]
		parts = strings.FieldsFunc(base, split)
	} else {
		parts = strings.FieldsFunc(s, split)
		path = ""
	}

	return parts, path
}
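
A usage sketch for the new matcher, written as a Go example test; the registered patterns are illustrative, not taken from the diff:

package utils_test

import (
	"fmt"

	"github.com/pocket-id/pocket-id/backend/internal/utils"
)

func ExampleGetCallbackURLFromList() {
	registered := []string{"http://127.0.0.1/callback", "https://*.example.com/auth/**"}

	// RFC 8252 loopback rule: the registered URL omits the port, but the
	// request's ephemeral port is accepted and echoed back verbatim.
	got, _ := utils.GetCallbackURLFromList(registered, "http://127.0.0.1:49152/callback")
	fmt.Println(got)

	// '*' matches a single URL component, '**' any number of path segments.
	got, _ = utils.GetCallbackURLFromList(registered, "https://app.example.com/auth/v1/cb")
	fmt.Println(got)

	// No pattern matches: an empty string and a nil error are returned.
	got, _ = utils.GetCallbackURLFromList(registered, "https://evil.test/callback")
	fmt.Printf("%q\n", got)

	// Output:
	// http://127.0.0.1:49152/callback
	// https://app.example.com/auth/v1/cb
	// ""
}
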
784
backend/internal/utils/callback_url_util_test.go
Normal file
@@ -0,0 +1,784 @@
package utils

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestMatchCallbackURL(t *testing.T) {
	tests := []struct {
		name        string
		pattern     string
		input       string
		shouldMatch bool
	}{
		// Basic matching
		{
			"exact match",
			"https://example.com/callback",
			"https://example.com/callback",
			true,
		},
		{
			"no match",
			"https://example.org/callback",
			"https://example.com/callback",
			false,
		},

		// Scheme
		{
			"scheme mismatch",
			"https://example.com/callback",
			"http://example.com/callback",
			false,
		},
		{
			"wildcard scheme",
			"*://example.com/callback",
			"https://example.com/callback",
			true,
		},

		// Hostname
		{
			"hostname mismatch",
			"https://example.com/callback",
			"https://malicious.com/callback",
			false,
		},
		{
			"wildcard subdomain",
			"https://*.example.com/callback",
			"https://subdomain.example.com/callback",
			true,
		},
		{
			"partial wildcard in hostname prefix",
			"https://app*.example.com/callback",
			"https://app1.example.com/callback",
			true,
		},
		{
			"partial wildcard in hostname suffix",
			"https://*-prod.example.com/callback",
			"https://api-prod.example.com/callback",
			true,
		},
		{
			"partial wildcard in hostname middle",
			"https://app-*-server.example.com/callback",
			"https://app-staging-server.example.com/callback",
			true,
		},
		{
			"subdomain wildcard doesn't match domain hijack attempt",
			"https://*.example.com/callback",
			"https://malicious.site?url=abc.example.com/callback",
			false,
		},
		{
			"hostname mismatch with confusable characters",
			"https://example.com/callback",
			"https://examp1e.com/callback",
			false,
		},
		{
			"hostname mismatch with homograph attack",
			"https://example.com/callback",
			"https://еxample.com/callback",
			false,
		},

		// Port
		{
			"port mismatch",
			"https://example.com:8080/callback",
			"https://example.com:9090/callback",
			false,
		},
		{
			"wildcard port",
			"https://example.com:*/callback",
			"https://example.com:8080/callback",
			true,
		},
		{
			"partial wildcard in port prefix",
			"https://example.com:80*/callback",
			"https://example.com:8080/callback",
			true,
		},

		// Path
		{
			"path mismatch",
			"https://example.com/callback",
			"https://example.com/other",
			false,
		},
		{
			"wildcard path segment",
			"https://example.com/api/*/callback",
			"https://example.com/api/v1/callback",
			true,
		},
		{
			"wildcard entire path",
			"https://example.com/*",
			"https://example.com/callback",
			true,
		},
		{
			"partial wildcard in path prefix",
			"https://example.com/test*",
			"https://example.com/testcase",
			true,
		},
		{
			"partial wildcard in path suffix",
			"https://example.com/*-callback",
			"https://example.com/oauth-callback",
			true,
		},
		{
			"partial wildcard in path middle",
			"https://example.com/api-*-v1/callback",
			"https://example.com/api-internal-v1/callback",
			true,
		},
		{
			"multiple partial wildcards in path",
			"https://example.com/*/test*/callback",
			"https://example.com/v1/testing/callback",
			true,
		},
		{
			"multiple wildcard segments in path",
			"https://example.com/**/callback",
			"https://example.com/api/v1/foo/bar/callback",
			true,
		},
		{
			"multiple wildcard segments in path",
			"https://example.com/**/v1/**/callback",
			"https://example.com/api/v1/foo/bar/callback",
			true,
		},
		{
			"partial wildcard matching full path segment",
			"https://example.com/foo-*",
			"https://example.com/foo-bar",
			true,
		},

		// Credentials
		{
			"username mismatch",
			"https://user:pass@example.com/callback",
			"https://admin:pass@example.com/callback",
			false,
		},
		{
			"missing credentials",
			"https://user:pass@example.com/callback",
			"https://example.com/callback",
			false,
		},
		{
			"unexpected credentials",
			"https://example.com/callback",
			"https://user:pass@example.com/callback",
			false,
		},
		{
			"wildcard password",
			"https://user:*@example.com/callback",
			"https://user:secret123@example.com/callback",
			true,
		},
		{
			"partial wildcard in username",
			"https://admin*:pass@example.com/callback",
			"https://admin123:pass@example.com/callback",
			true,
		},
		{
			"partial wildcard in password",
			"https://user:pass*@example.com/callback",
			"https://user:password123@example.com/callback",
			true,
		},
		{
			"wildcard password doesn't allow domain hijack",
			"https://user:*@example.com/callback",
			"https://user:password@malicious.site#example.com/callback",
			false,
		},
		{
			"credentials with @ in password trying to hijack hostname",
			"https://user:pass@example.com/callback",
			"https://user:pass@evil.com@example.com/callback",
			false,
		},

		// Query parameters
		{
			"extra query parameter",
			"https://example.com/callback?code=*",
			"https://example.com/callback?code=abc123&extra=value",
			false,
		},
		{
			"missing query parameter",
			"https://example.com/callback?code=*&state=*",
			"https://example.com/callback?code=abc123",
			false,
		},
		{
			"query parameter after fragment",
			"https://example.com/callback?code=123",
			"https://example.com/callback#section?code=123",
			false,
		},
		{
			"query parameter name mismatch",
			"https://example.com/callback?code=*",
			"https://example.com/callback?token=abc123",
			false,
		},
		{
			"wildcard query parameter",
			"https://example.com/callback?code=*",
			"https://example.com/callback?code=abc123",
			true,
		},
		{
			"multiple query parameters",
			"https://example.com/callback?code=*&state=*",
			"https://example.com/callback?code=abc123&state=xyz789",
			true,
		},
		{
			"query parameters in different order",
			"https://example.com/callback?state=*&code=*",
			"https://example.com/callback?code=abc123&state=xyz789",
			true,
		},
		{
			"exact query parameter value",
			"https://example.com/callback?mode=production",
			"https://example.com/callback?mode=production",
			true,
		},
		{
			"query parameter value mismatch",
			"https://example.com/callback?mode=production",
			"https://example.com/callback?mode=development",
			false,
		},
		{
			"mixed exact and wildcard query parameters",
			"https://example.com/callback?mode=production&code=*",
			"https://example.com/callback?mode=production&code=abc123",
			true,
		},
		{
			"mixed exact and wildcard with wrong exact value",
			"https://example.com/callback?mode=production&code=*",
			"https://example.com/callback?mode=development&code=abc123",
			false,
		},
		{
			"multiple values for same parameter",
			"https://example.com/callback?param=*&param=*",
			"https://example.com/callback?param=value1&param=value2",
			true,
		},
		{
			"unexpected query parameters",
			"https://example.com/callback",
			"https://example.com/callback?extra=value",
			false,
		},
		{
			"query parameter with redirect to external site",
			"https://example.com/callback?code=*",
			"https://example.com/callback?code=123&redirect=https://evil.com",
			false,
		},
		{
			"open redirect via encoded URL in query param",
			"https://example.com/callback?state=*",
			"https://example.com/callback?state=abc&next=//evil.com",
			false,
		},

		// Fragment
		{
			"fragment ignored when both pattern and input have fragment",
			"https://example.com/callback#fragment",
			"https://example.com/callback#fragment",
			true,
		},
		{
			"fragment ignored when pattern has fragment but input doesn't",
			"https://example.com/callback#fragment",
			"https://example.com/callback",
			true,
		},
		{
			"fragment ignored when input has fragment but pattern doesn't",
			"https://example.com/callback",
			"https://example.com/callback#section",
			true,
		},

		// Path traversal and injection attempts
		{
			"path traversal attempt",
			"https://example.com/callback",
			"https://example.com/../admin/callback",
			false,
		},
		{
			"backslash instead of forward slash",
			"https://example.com/callback",
			"https://example.com\\callback",
			false,
		},
		{
			"double slash in hostname (protocol smuggling)",
			"https://example.com/callback",
			"https://example.com//evil.com/callback",
			false,
		},
		{
			"CRLF injection attempt in path",
			"https://example.com/callback",
			"https://example.com/callback%0d%0aLocation:%20https://evil.com",
			false,
		},
		{
			"null byte injection",
			"https://example.com/callback",
			"https://example.com/callback%00.evil.com",
			false,
		},
	}

	for _, tt := range tests {
		matches, err := matchCallbackURL(tt.pattern, tt.input)
		require.NoError(t, err, tt.name)
		assert.Equal(t, tt.shouldMatch, matches, tt.name)
	}
}

func TestGetCallbackURLFromList_LoopbackSpecialHandling(t *testing.T) {
	tests := []struct {
		name             string
		urls             []string
		inputCallbackURL string
		expectedURL      string
		expectMatch      bool
	}{
		{
			name:             "127.0.0.1 with dynamic port - exact match",
			urls:             []string{"http://127.0.0.1/callback"},
			inputCallbackURL: "http://127.0.0.1:8080/callback",
			expectedURL:      "http://127.0.0.1:8080/callback",
			expectMatch:      true,
		},
		{
			name:             "127.0.0.1 with different port",
			urls:             []string{"http://127.0.0.1/callback"},
			inputCallbackURL: "http://127.0.0.1:9999/callback",
			expectedURL:      "http://127.0.0.1:9999/callback",
			expectMatch:      true,
		},
		{
			name:             "IPv6 loopback with dynamic port",
			urls:             []string{"http://[::1]/callback"},
			inputCallbackURL: "http://[::1]:8080/callback",
			expectedURL:      "http://[::1]:8080/callback",
			expectMatch:      true,
		},
		{
			name:             "IPv6 loopback without brackets in input",
			urls:             []string{"http://[::1]/callback"},
			inputCallbackURL: "http://::1:8080/callback",
			expectedURL:      "http://::1:8080/callback",
			expectMatch:      true,
		},
		{
			name:             "localhost with dynamic port",
			urls:             []string{"http://localhost/callback"},
			inputCallbackURL: "http://localhost:8080/callback",
			expectedURL:      "http://localhost:8080/callback",
			expectMatch:      true,
		},
		{
			name:             "https loopback doesn't trigger special handling",
			urls:             []string{"https://127.0.0.1/callback"},
			inputCallbackURL: "https://127.0.0.1:8080/callback",
			expectedURL:      "",
			expectMatch:      false,
		},
		{
			name:             "loopback with path match",
			urls:             []string{"http://127.0.0.1/auth/*"},
			inputCallbackURL: "http://127.0.0.1:3000/auth/callback",
			expectedURL:      "http://127.0.0.1:3000/auth/callback",
			expectMatch:      true,
		},
		{
			name:             "loopback with path mismatch",
			urls:             []string{"http://127.0.0.1/callback"},
			inputCallbackURL: "http://127.0.0.1:8080/different",
			expectedURL:      "",
			expectMatch:      false,
		},
		{
			name:             "non-loopback IP",
			urls:             []string{"http://192.168.1.1/callback"},
			inputCallbackURL: "http://192.168.1.1:8080/callback",
			expectedURL:      "",
			expectMatch:      false,
		},
		{
			name:             "wildcard matches loopback",
			urls:             []string{"*"},
			inputCallbackURL: "http://127.0.0.1:8080/callback",
			expectedURL:      "http://127.0.0.1:8080/callback",
			expectMatch:      true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := GetCallbackURLFromList(tt.urls, tt.inputCallbackURL)
			require.NoError(t, err)
			if tt.expectMatch {
				assert.Equal(t, tt.expectedURL, result)
			} else {
				assert.Empty(t, result)
			}
		})
	}
}

func TestGetCallbackURLFromList_MultiplePatterns(t *testing.T) {
	tests := []struct {
		name             string
		urls             []string
		inputCallbackURL string
		expectedURL      string
		expectMatch      bool
	}{
		{
			name: "matches first pattern",
			urls: []string{
				"https://example.com/callback",
				"https://example.org/callback",
			},
			inputCallbackURL: "https://example.com/callback",
			expectedURL:      "https://example.com/callback",
			expectMatch:      true,
		},
		{
			name: "matches second pattern",
			urls: []string{
				"https://example.com/callback",
				"https://example.org/callback",
			},
			inputCallbackURL: "https://example.org/callback",
			expectedURL:      "https://example.org/callback",
			expectMatch:      true,
		},
		{
			name: "matches none",
			urls: []string{
				"https://example.com/callback",
				"https://example.org/callback",
			},
			inputCallbackURL: "https://malicious.com/callback",
			expectedURL:      "",
			expectMatch:      false,
		},
		{
			name: "matches wildcard pattern",
			urls: []string{
				"https://example.com/callback",
				"https://*.example.org/callback",
			},
			inputCallbackURL: "https://subdomain.example.org/callback",
			expectedURL:      "https://subdomain.example.org/callback",
			expectMatch:      true,
		},
		{
			name:             "empty pattern list",
			urls:             []string{},
			inputCallbackURL: "https://example.com/callback",
			expectedURL:      "",
			expectMatch:      false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := GetCallbackURLFromList(tt.urls, tt.inputCallbackURL)
			require.NoError(t, err)
			if tt.expectMatch {
				assert.Equal(t, tt.expectedURL, result)
			} else {
				assert.Empty(t, result)
			}
		})
	}
}

func TestMatchPath(t *testing.T) {
	tests := []struct {
		name        string
		pattern     string
		input       string
		shouldMatch bool
	}{
		// Exact matches
		{
			name:        "exact match",
			pattern:     "/callback",
			input:       "/callback",
			shouldMatch: true,
		},
		{
			name:        "exact mismatch",
			pattern:     "/callback",
			input:       "/other",
			shouldMatch: false,
		},
		{
			name:        "empty paths",
			pattern:     "",
			input:       "",
			shouldMatch: true,
		},

		// Single wildcard (*)
		{
			name:        "single wildcard matches segment",
			pattern:     "/api/*/callback",
			input:       "/api/v1/callback",
			shouldMatch: true,
		},
		{
			name:        "single wildcard doesn't match multiple segments",
			pattern:     "/api/*/callback",
			input:       "/api/v1/v2/callback",
			shouldMatch: false,
		},
		{
			name:        "single wildcard at end",
			pattern:     "/callback/*",
			input:       "/callback/test",
			shouldMatch: true,
		},
		{
			name:        "single wildcard at start",
			pattern:     "/*/callback",
			input:       "/api/callback",
			shouldMatch: true,
		},
		{
			name:        "multiple single wildcards",
			pattern:     "/*/test/*",
			input:       "/api/test/callback",
			shouldMatch: true,
		},
		{
			name:        "partial wildcard prefix",
			pattern:     "/test*",
			input:       "/testing",
			shouldMatch: true,
		},
		{
			name:        "partial wildcard suffix",
			pattern:     "/*-callback",
			input:       "/oauth-callback",
			shouldMatch: true,
		},
		{
			name:        "partial wildcard middle",
			pattern:     "/api-*-v1",
			input:       "/api-internal-v1",
			shouldMatch: true,
		},

		// Double wildcard (**)
		{
			name:        "double wildcard matches multiple segments",
			pattern:     "/api/**/callback",
			input:       "/api/v1/v2/v3/callback",
			shouldMatch: true,
		},
		{
			name:        "double wildcard matches single segment",
			pattern:     "/api/**/callback",
			input:       "/api/v1/callback",
			shouldMatch: true,
		},
		{
			name:        "double wildcard doesn't match when pattern has extra slashes",
			pattern:     "/api/**/callback",
			input:       "/api/callback",
			shouldMatch: false,
		},
		{
			name:        "double wildcard at end",
			pattern:     "/api/**",
			input:       "/api/v1/v2/callback",
			shouldMatch: true,
		},
		{
			name:        "double wildcard in middle",
			pattern:     "/api/**/v2/**/callback",
			input:       "/api/v1/v2/v3/v4/callback",
			shouldMatch: true,
		},

		// Complex patterns
		{
			name:        "mix of single and double wildcards",
			pattern:     "/*/api/**/callback",
			input:       "/app/api/v1/v2/callback",
			shouldMatch: true,
		},
		{
			name:        "wildcard with special characters",
			pattern:     "/callback-*",
			input:       "/callback-123",
			shouldMatch: true,
		},
		{
			name:        "path with query-like string (no special handling)",
			pattern:     "/callback?code=*",
			input:       "/callback?code=abc",
			shouldMatch: true,
		},

		// Edge cases
		{
			name:        "single wildcard matches empty segment",
			pattern:     "/api/*/callback",
			input:       "/api//callback",
			shouldMatch: true,
		},
		{
			name:        "pattern longer than input",
			pattern:     "/api/v1/callback",
			input:       "/api",
			shouldMatch: false,
		},
		{
			name:        "input longer than pattern",
			pattern:     "/api",
			input:       "/api/v1/callback",
			shouldMatch: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			matches, err := matchPath(tt.pattern, tt.input)
			require.NoError(t, err)
			assert.Equal(t, tt.shouldMatch, matches)
		})
	}
}

func TestSplitParts(t *testing.T) {
	tests := []struct {
		name          string
		input         string
		expectedParts []string
		expectedPath  string
	}{
		{
			name:          "simple https URL",
			input:         "https://example.com/callback",
			expectedParts: []string{"https", "example", "com"},
			expectedPath:  "/callback",
		},
		{
			name:          "URL with port",
			input:         "https://example.com:8080/callback",
			expectedParts: []string{"https", "example", "com", "8080"},
			expectedPath:  "/callback",
		},
		{
			name:          "URL with subdomain",
			input:         "https://api.example.com/callback",
			expectedParts: []string{"https", "api", "example", "com"},
			expectedPath:  "/callback",
		},
		{
			name:          "URL with credentials",
			input:         "https://user:pass@example.com/callback",
			expectedParts: []string{"https", "user", "pass", "example", "com"},
			expectedPath:  "/callback",
		},
		{
			name:          "URL without path",
			input:         "https://example.com",
			expectedParts: []string{"https", "example", "com"},
			expectedPath:  "",
		},
		{
			name:          "URL with deep path",
			input:         "https://example.com/api/v1/callback",
			expectedParts: []string{"https", "example", "com"},
			expectedPath:  "/api/v1/callback",
		},
		{
			name:          "URL with path and query",
			input:         "https://example.com/callback?code=123",
			expectedParts: []string{"https", "example", "com"},
			expectedPath:  "/callback?code=123",
		},
		{
			name:          "URL with trailing slash",
			input:         "https://example.com/",
			expectedParts: []string{"https", "example", "com"},
			expectedPath:  "/",
		},
		{
			name:          "URL with multiple subdomains",
			input:         "https://api.v1.staging.example.com/callback",
			expectedParts: []string{"https", "api", "v1", "staging", "example", "com"},
			expectedPath:  "/callback",
		},
		{
			name:          "URL with port and credentials",
			input:         "https://user:pass@example.com:8080/callback",
			expectedParts: []string{"https", "user", "pass", "example", "com", "8080"},
			expectedPath:  "/callback",
		},
		{
			name:          "scheme with authority separator but no slash",
			input:         "http://example.com",
			expectedParts: []string{"http", "example", "com"},
			expectedPath:  "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parts, path := splitParts(tt.input)
			assert.Equal(t, tt.expectedParts, parts, "parts mismatch")
			assert.Equal(t, tt.expectedPath, path, "path mismatch")
		})
	}
}
130
backend/internal/utils/db_migration_util.go
Normal file
@@ -0,0 +1,130 @@
package utils

import (
	"database/sql"
	"errors"
	"fmt"
	"log/slog"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database"
	postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
	sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
	"github.com/golang-migrate/migrate/v4/source/iofs"

	"github.com/pocket-id/pocket-id/backend/internal/common"
	"github.com/pocket-id/pocket-id/backend/resources"
)

// MigrateDatabase applies database migrations using embedded migration files or fetches them from GitHub if a downgrade is detected.
func MigrateDatabase(sqlDb *sql.DB) error {
	m, err := GetEmbeddedMigrateInstance(sqlDb)
	if err != nil {
		return fmt.Errorf("failed to get migrate instance: %w", err)
	}

	path := "migrations/" + string(common.EnvConfig.DbProvider)
	requiredVersion, err := getRequiredMigrationVersion(path)
	if err != nil {
		return fmt.Errorf("failed to get last migration version: %w", err)
	}

	currentVersion, _, _ := m.Version()
	if currentVersion > requiredVersion {
		slog.Warn("Database version is newer than the application supports, possible downgrade detected", slog.Uint64("db_version", uint64(currentVersion)), slog.Uint64("app_version", uint64(requiredVersion)))
		if !common.EnvConfig.AllowDowngrade {
			return fmt.Errorf("database version (%d) is newer than application version (%d), downgrades are not allowed (set ALLOW_DOWNGRADE=true to enable)", currentVersion, requiredVersion)
		}
		slog.Info("Fetching migrations from GitHub to handle possible downgrades")
		return migrateDatabaseFromGitHub(sqlDb, requiredVersion)
	}

	if err := m.Migrate(requiredVersion); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("failed to apply embedded migrations: %w", err)
	}
	return nil
}

// GetEmbeddedMigrateInstance creates a migrate.Migrate instance using embedded migration files.
func GetEmbeddedMigrateInstance(sqlDb *sql.DB) (*migrate.Migrate, error) {
	path := "migrations/" + string(common.EnvConfig.DbProvider)
	source, err := iofs.New(resources.FS, path)
	if err != nil {
		return nil, fmt.Errorf("failed to create embedded migration source: %w", err)
	}

	driver, err := newMigrationDriver(sqlDb, common.EnvConfig.DbProvider)
	if err != nil {
		return nil, fmt.Errorf("failed to create migration driver: %w", err)
	}

	m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
	if err != nil {
		return nil, fmt.Errorf("failed to create migration instance: %w", err)
	}
	return m, nil
}

// newMigrationDriver creates a database.Driver instance based on the given database provider.
func newMigrationDriver(sqlDb *sql.DB, dbProvider common.DbProvider) (driver database.Driver, err error) {
	switch dbProvider {
	case common.DbProviderSqlite:
		driver, err = sqliteMigrate.WithInstance(sqlDb, &sqliteMigrate.Config{
			NoTxWrap: true,
		})
	case common.DbProviderPostgres:
		driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
	default:
		// Should never happen at this point
		return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to create migration driver: %w", err)
	}

	return driver, nil
}

// migrateDatabaseFromGitHub applies database migrations fetched from GitHub to handle downgrades.
func migrateDatabaseFromGitHub(sqlDb *sql.DB, version uint) error {
	srcURL := "github://pocket-id/pocket-id/backend/resources/migrations/" + string(common.EnvConfig.DbProvider)

	driver, err := newMigrationDriver(sqlDb, common.EnvConfig.DbProvider)
	if err != nil {
		return fmt.Errorf("failed to create migration driver: %w", err)
	}

	m, err := migrate.NewWithDatabaseInstance(srcURL, "pocket-id", driver)
	if err != nil {
		return fmt.Errorf("failed to create GitHub migration instance: %w", err)
	}

	if err := m.Migrate(version); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("failed to apply GitHub migrations: %w", err)
	}
	return nil
}

// getRequiredMigrationVersion reads the embedded migration files and returns the highest version number found.
func getRequiredMigrationVersion(path string) (uint, error) {
	entries, err := resources.FS.ReadDir(path)
	if err != nil {
		return 0, fmt.Errorf("failed to read migration directory: %w", err)
	}

	var maxVersion uint
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		var version uint
		n, err := fmt.Sscanf(name, "%d_", &version)
		if err == nil && n == 1 {
			if version > maxVersion {
				maxVersion = version
			}
		}
	}

	return maxVersion, nil
}
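
A minimal startup sketch showing how this is meant to be wired; the driver name and DSN are assumptions for a SQLite setup (with a blank import of the glebarez/go-sqlite driver), and the project's actual entry point may differ:

	sqlDb, err := sql.Open("sqlite", "data/pocket-id.db")
	if err != nil {
		return err
	}
	// Runs the embedded migrations, or falls back to fetching them from GitHub
	// when the schema is newer than the binary and ALLOW_DOWNGRADE=true.
	if err := utils.MigrateDatabase(sqlDb); err != nil {
		return fmt.Errorf("migrations failed: %w", err)
	}
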
116
backend/internal/utils/db_util.go
Normal file
@@ -0,0 +1,116 @@
package utils

import (
	"fmt"
	"strings"

	"gorm.io/gorm"
)

// DBTableExists checks if a table exists in the database
func DBTableExists(db *gorm.DB, tableName string) (exists bool, err error) {
	switch db.Name() {
	case "postgres":
		query := `SELECT EXISTS (
			SELECT FROM information_schema.tables
			WHERE table_schema = 'public'
			AND table_name = ?
		)`
		err = db.Raw(query, tableName).Scan(&exists).Error
		if err != nil {
			return false, err
		}
	case "sqlite":
		query := `SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name=?`
		err = db.Raw(query, tableName).Scan(&exists).Error
		if err != nil {
			return false, err
		}
	default:
		return false, fmt.Errorf("unsupported database dialect: %s", db.Name())
	}

	return exists, nil
}

type DBSchemaColumn struct {
	Name     string
	Nullable bool
}
type DBSchemaTableTypes = map[string]DBSchemaColumn
type DBSchemaTypes = map[string]DBSchemaTableTypes

// LoadDBSchemaTypes retrieves the column types for all tables in the DB
// Result is a map of "table --> column --> {name: column type name, nullable: boolean}"
func LoadDBSchemaTypes(db *gorm.DB) (result DBSchemaTypes, err error) {
	result = make(DBSchemaTypes)

	switch db.Name() {
	case "postgres":
		var rows []struct {
			TableName  string
			ColumnName string
			DataType   string
			Nullable   bool
		}
		err := db.
			Raw(`
				SELECT table_name, column_name, data_type, is_nullable = 'YES' AS nullable
				FROM information_schema.columns
				WHERE table_schema = 'public';
			`).
			Scan(&rows).
			Error
		if err != nil {
			return nil, err
		}
		for _, r := range rows {
			t := strings.ToLower(r.DataType)
			if result[r.TableName] == nil {
				result[r.TableName] = make(map[string]DBSchemaColumn)
			}
			result[r.TableName][r.ColumnName] = DBSchemaColumn{
				Name:     strings.ToLower(t),
				Nullable: r.Nullable,
			}
		}

	case "sqlite":
		var tables []string
		err = db.
			Raw(`SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';`).
			Scan(&tables).
			Error
		if err != nil {
			return nil, err
		}
		for _, table := range tables {
			var cols []struct {
				Name    string
				Type    string
				Notnull bool
			}
			err := db.
				Raw(`PRAGMA table_info("` + table + `");`).
				Scan(&cols).
				Error
			if err != nil {
				return nil, err
			}
			for _, c := range cols {
				if result[table] == nil {
					result[table] = make(map[string]DBSchemaColumn)
				}
				result[table][c.Name] = DBSchemaColumn{
					Name:     strings.ToLower(c.Type),
					Nullable: !c.Notnull,
				}
			}
		}

	default:
		return nil, fmt.Errorf("unsupported database dialect: %s", db.Name())
	}

	return result, nil
}
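
A usage sketch for the two helpers above; the table and column names are hypothetical:

	exists, err := utils.DBTableExists(db, "users")
	if err != nil {
		return err
	}
	if exists {
		schema, err := utils.LoadDBSchemaTypes(db)
		if err != nil {
			return err
		}
		col := schema["users"]["email"]
		fmt.Printf("users.email: type=%s, nullable=%v\n", col.Name, col.Nullable)
	}
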
@@ -28,22 +28,14 @@ func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID s
		return nil, fmt.Errorf("failed to load encryption key: %w", err)
	}

	// Get the key provider
	switch envConfig.KeysStorage {
	case "file", "":
		keyProvider = &KeyProviderFile{}
	case "database":
		keyProvider = &KeyProviderDatabase{}
	default:
		return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
	}
	keyProvider = &KeyProviderDatabase{}
	err = keyProvider.Init(KeyProviderOpts{
		DB:        db,
		EnvConfig: envConfig,
		Kek:       kek,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err)
		return nil, fmt.Errorf("failed to init key provider: %w", err)
	}

	return keyProvider, nil

@@ -1,202 +0,0 @@
package jwk

import (
	"encoding/base64"
	"fmt"
	"os"
	"path/filepath"

	"github.com/lestrrat-go/jwx/v3/jwk"

	"github.com/pocket-id/pocket-id/backend/internal/common"
	"github.com/pocket-id/pocket-id/backend/internal/utils"
	cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)

const (
	// PrivateKeyFile is the path in the data/keys folder where the key is stored
	// This is a JSON file containing a key encoded as JWK
	PrivateKeyFile = "jwt_private_key.json"

	// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
	// This is an encrypted JSON file containing a key encoded as JWK
	PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
)

type KeyProviderFile struct {
	envConfig *common.EnvConfigSchema
	kek       []byte
}

func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
	f.envConfig = opts.EnvConfig
	f.kek = opts.Kek

	return nil
}

func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
	if len(f.kek) > 0 {
		return f.loadEncryptedKey()
	}
	return f.loadKey()
}

func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
	if len(f.kek) > 0 {
		return f.saveKeyEncrypted(key)
	}
	return f.saveKey(key)
}

func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
	var key jwk.Key

	// First, check if we have a JWK file
	// If we do, then we just load that
	jwkPath := f.jwkPath()
	ok, err := utils.FileExists(jwkPath)
	if err != nil {
		return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
	}
	if !ok {
		// File doesn't exist, no key was loaded
		return nil, nil
	}

	data, err := os.ReadFile(jwkPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
	}

	key, err = jwk.ParseKey(data)
	if err != nil {
		return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
	}

	return key, nil
}

func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
	// First, check if we have an encrypted JWK file
	// If we do, then we just load that
	encJwkPath := f.encJwkPath()
	ok, err := utils.FileExists(encJwkPath)
	if err != nil {
		return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
	}
	if ok {
		encB64, err := os.ReadFile(encJwkPath)
		if err != nil {
			return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
		}

		// Decode from base64
		enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
		n, err := base64.StdEncoding.Decode(enc, encB64)
		if err != nil {
			return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
		}

		// Decrypt the data
		data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
		if err != nil {
			return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
		}

		// Parse the key
		key, err = jwk.ParseKey(data)
		if err != nil {
			return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
		}

		return key, nil
	}

	// Check if we have an un-encrypted JWK file
	key, err = f.loadKey()
	if err != nil {
		return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
	}
	if key == nil {
		// No key exists, encrypted or un-encrypted
		return nil, nil
	}

	// If we are here, we have loaded a key that was un-encrypted
	// We need to replace the plaintext key with the encrypted one before we return
	err = f.saveKeyEncrypted(key)
	if err != nil {
		return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
	}
	jwkPath := f.jwkPath()
	err = os.Remove(jwkPath)
	if err != nil {
		return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
	}

	return key, nil
}

func (f *KeyProviderFile) saveKey(key jwk.Key) error {
	err := os.MkdirAll(f.envConfig.KeysPath, 0700)
	if err != nil {
		return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
	}

	jwkPath := f.jwkPath()
	keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
	}
	defer keyFile.Close()

	// Write the JSON file to disk
	err = EncodeJWK(keyFile, key)
	if err != nil {
		return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
	}

	return nil
}

func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
	err := os.MkdirAll(f.envConfig.KeysPath, 0700)
	if err != nil {
		return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
	}

	// Encode the key to JSON
	data, err := EncodeJWKBytes(key)
	if err != nil {
		return fmt.Errorf("failed to encode key to JSON: %w", err)
	}

	// Encrypt the key then encode to Base64
	enc, err := cryptoutils.Encrypt(f.kek, data, nil)
	if err != nil {
		return fmt.Errorf("failed to encrypt key: %w", err)
	}
	encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
	base64.StdEncoding.Encode(encB64, enc)

	// Write to disk
	encJwkPath := f.encJwkPath()
	err = os.WriteFile(encJwkPath, encB64, 0600)
	if err != nil {
		return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
	}

	return nil
}

func (f *KeyProviderFile) jwkPath() string {
	return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
}

func (f *KeyProviderFile) encJwkPath() string {
	return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
}

// Compile-time interface check
var _ KeyProvider = (*KeyProviderFile)(nil)
@@ -1,320 +0,0 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
func TestKeyProviderFile_LoadKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the key file exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected key file to exist")
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with encrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key (will be encrypted)
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err := utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||
|
||||
// Make sure the unencrypted key file does not exist
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
		require.NoError(t, err)
		assert.False(t, exists, "Expected unencrypted key file to not exist")

		// Load the key
		loadedKey, err := provider.LoadKey()
		require.NoError(t, err)
		assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")

		// Verify the loaded key is the same as the original
		keyBytes, err := EncodeJWKBytes(key)
		require.NoError(t, err)

		loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
		require.NoError(t, err)

		assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
	})

	t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
		tempDir := t.TempDir()

		// First, create an unencrypted key
		providerNoKek := &KeyProviderFile{}
		err := providerNoKek.Init(KeyProviderOpts{
			EnvConfig: &common.EnvConfigSchema{
				KeysPath: tempDir,
			},
		})
		require.NoError(t, err)

		// Save an unencrypted key
		err = providerNoKek.SaveKey(key)
		require.NoError(t, err)

		// Verify unencrypted key exists
		keyPath := filepath.Join(tempDir, PrivateKeyFile)
		exists, err := utils.FileExists(keyPath)
		require.NoError(t, err)
		assert.True(t, exists, "Expected unencrypted key file to exist")

		// Now create a provider with a kek
		kek := make([]byte, 32)
		_, err = rand.Read(kek)
		require.NoError(t, err)

		providerWithKek := &KeyProviderFile{}
		err = providerWithKek.Init(KeyProviderOpts{
			EnvConfig: &common.EnvConfigSchema{
				KeysPath: tempDir,
			},
			Kek: kek,
		})
		require.NoError(t, err)

		// Load the key - this should convert the unencrypted key to encrypted
		loadedKey, err := providerWithKek.LoadKey()
		require.NoError(t, err)
		assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")

		// Verify the unencrypted key no longer exists
		exists, err = utils.FileExists(keyPath)
		require.NoError(t, err)
		assert.False(t, exists, "Expected unencrypted key file to be removed")

		// Verify the encrypted key file exists
		encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
		exists, err = utils.FileExists(encKeyPath)
		require.NoError(t, err)
		assert.True(t, exists, "Expected encrypted key file to exist after conversion")

		// Verify the key data
		keyBytes, err := EncodeJWKBytes(key)
		require.NoError(t, err)

		loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
		require.NoError(t, err)

		assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
	})
}

func TestKeyProviderFile_SaveKey(t *testing.T) {
	// Generate a test key to use in our tests
	pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	key, err := jwk.Import(pk)
	require.NoError(t, err)

	t.Run("SaveKey unencrypted", func(t *testing.T) {
		tempDir := t.TempDir()

		provider := &KeyProviderFile{}
		err := provider.Init(KeyProviderOpts{
			EnvConfig: &common.EnvConfigSchema{
				KeysPath: tempDir,
			},
		})
		require.NoError(t, err)

		// Save the key
		err = provider.SaveKey(key)
		require.NoError(t, err)

		// Verify the key file exists
		keyPath := filepath.Join(tempDir, PrivateKeyFile)
		exists, err := utils.FileExists(keyPath)
		require.NoError(t, err)
		assert.True(t, exists, "Expected key file to exist")

		// Verify the content of the key file
		data, err := os.ReadFile(keyPath)
		require.NoError(t, err)

		parsedKey, err := jwk.ParseKey(data)
		require.NoError(t, err)

		// Compare the saved key with the original
		keyBytes, err := EncodeJWKBytes(key)
		require.NoError(t, err)

		parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
		require.NoError(t, err)

		assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
	})

	t.Run("SaveKey encrypted", func(t *testing.T) {
		tempDir := t.TempDir()

		// Generate a 32-byte kek
		kek := makeKEK(t)

		provider := &KeyProviderFile{}
		err = provider.Init(KeyProviderOpts{
			EnvConfig: &common.EnvConfigSchema{
				KeysPath: tempDir,
			},
			Kek: kek,
		})
		require.NoError(t, err)

		// Save the key (will be encrypted)
		err = provider.SaveKey(key)
		require.NoError(t, err)

		// Verify the encrypted key file exists
		encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
		exists, err := utils.FileExists(encKeyPath)
		require.NoError(t, err)
		assert.True(t, exists, "Expected encrypted key file to exist")

		// Verify the unencrypted key file doesn't exist
		keyPath := filepath.Join(tempDir, PrivateKeyFile)
		exists, err = utils.FileExists(keyPath)
		require.NoError(t, err)
		assert.False(t, exists, "Expected unencrypted key file to not exist")

		// Manually decrypt the encrypted key file to verify it contains the correct key
		encB64, err := os.ReadFile(encKeyPath)
		require.NoError(t, err)

		// Decode from base64
		enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
		n, err := base64.StdEncoding.Decode(enc, encB64)
		require.NoError(t, err)
		enc = enc[:n] // Trim any padding

		// Decrypt the data
		data, err := cryptoutils.Decrypt(kek, enc, nil)
		require.NoError(t, err)

		// Parse the key
		parsedKey, err := jwk.ParseKey(data)
		require.NoError(t, err)

		// Compare the decrypted key with the original
		keyBytes, err := EncodeJWKBytes(key)
		require.NoError(t, err)

		parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
		require.NoError(t, err)

		assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
	})
}

func makeKEK(t *testing.T) []byte {
	t.Helper()

	// Generate a 32-byte kek
	kek := make([]byte, 32)
	_, err := rand.Read(kek)
	require.NoError(t, err)
	return kek
}
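For reference, the save path that the "SaveKey encrypted" test decrypts by hand would look roughly like this. This is a minimal sketch only, reusing identifiers from the test file above; it assumes cryptoutils exposes an Encrypt counterpart to the Decrypt call used in the test, which this diff does not show.

// Sketch, not the actual KeyProviderFile code: the mirror image of the
// manual decryption performed in the test above.
func saveEncryptedKey(path string, kek []byte, key jwk.Key) error {
	data, err := EncodeJWKBytes(key) // serialize the JWK to bytes
	if err != nil {
		return err
	}
	// Assumption: cryptoutils.Encrypt is the counterpart of cryptoutils.Decrypt
	enc, err := cryptoutils.Encrypt(kek, data, nil)
	if err != nil {
		return err
	}
	// Store as base64, matching what the test reads back with base64.StdEncoding.Decode
	return os.WriteFile(path, []byte(base64.StdEncoding.EncodeToString(enc)), 0o600)
}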
@@ -38,6 +38,7 @@ func (r *ServiceRunner) Run(ctx context.Context) error {

			// Ignore context canceled errors here as they generally indicate that the service is stopping
			if rErr != nil && !errors.Is(rErr, context.Canceled) {
				cancel()
				errCh <- rErr
				return
			}

@@ -61,6 +61,26 @@ func TestServiceRunner_Run(t *testing.T) {
		require.ErrorIs(t, err, expectedErr)
	})

	t.Run("service error cancels others", func(t *testing.T) {
		expectedErr := errors.New("boom")
		errorService := func(ctx context.Context) error {
			return expectedErr
		}
		waitingService := func(ctx context.Context) error {
			<-ctx.Done()
			return ctx.Err()
		}

		runner := NewServiceRunner(errorService, waitingService)

		ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
		defer cancel()

		err := runner.Run(ctx)
		require.Error(t, err)
		require.ErrorIs(t, err, expectedErr)
	})

	t.Run("context canceled", func(t *testing.T) {
		// Create a service that waits until context is canceled
		waitingService := func(ctx context.Context) error {

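The hunk above only shows the error branch inside ServiceRunner.Run. As a self-contained sketch — with illustrative names, not the actual implementation — the pattern the new test exercises (one failing service cancelling its siblings through a shared context, errors collected on a buffered channel) can be written like this:

// Illustrative, self-contained sketch of the pattern; not the repo's ServiceRunner.
package main

import (
	"context"
	"errors"
	"sync"
)

func runAll(ctx context.Context, services ...func(context.Context) error) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffered so a failing service never blocks on send
	errCh := make(chan error, len(services))
	var wg sync.WaitGroup
	for _, svc := range services {
		wg.Add(1)
		go func(run func(context.Context) error) {
			defer wg.Done()
			rErr := run(ctx)
			// Ignore context.Canceled: it just means another service failed first
			if rErr != nil && !errors.Is(rErr, context.Canceled) {
				cancel() // stop the remaining services
				errCh <- rErr
			}
		}(svc)
	}
	wg.Wait()
	close(errCh)
	return <-errCh // nil if no service reported a real error
}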
1
backend/resources/e2e-tests/database.json
Symbolic link
@@ -0,0 +1 @@
../../../tests/database.json
72
backend/resources/files_test.go
Normal file
@@ -0,0 +1,72 @@
package resources

import (
	"embed"
	"slices"
	"strconv"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// This test is meant to enforce that for every new migration added, a file with the same migration number exists for all supported databases
// This is necessary to ensure import/export works correctly
// Note: if a migration is not needed for a database, ensure there's a file with an empty (no-op) migration (e.g. even just a comment)
func TestMigrationsMatchingVersions(t *testing.T) {
	// We can ignore migrations with version below 20251115000000
	const ignoreBefore = 20251115000000

	// Scan postgres migrations
	postgresMigrations := scanMigrations(t, FS, "migrations/postgres", ignoreBefore)

	// Scan sqlite migrations
	sqliteMigrations := scanMigrations(t, FS, "migrations/sqlite", ignoreBefore)

	// Sort both lists for consistent comparison
	slices.Sort(postgresMigrations)
	slices.Sort(sqliteMigrations)

	// Compare the lists
	assert.Equal(t, postgresMigrations, sqliteMigrations, "Migration versions must match between Postgres and SQLite")
}

// scanMigrations scans a directory for migration files and returns a list of versions
func scanMigrations(t *testing.T, fs embed.FS, dir string, ignoreBefore int64) []int64 {
	t.Helper()

	entries, err := fs.ReadDir(dir)
	require.NoErrorf(t, err, "Failed to read directory '%s'", dir)

	// Divide by 2 because of up and down files
	versions := make([]int64, 0, len(entries)/2)
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		filename := entry.Name()

		// Only consider .up.sql files
		if !strings.HasSuffix(filename, ".up.sql") {
			continue
		}

		// Extract version from filename (format: <version>_<anything>.up.sql)
		versionString, _, ok := strings.Cut(filename, "_")
		require.Truef(t, ok, "Migration file has unexpected format: %s", filename)

		version, err := strconv.ParseInt(versionString, 10, 64)
		require.NoErrorf(t, err, "Failed to parse version from filename '%s'", filename)

		// Exclude migrations with version below ignoreBefore
		if version < ignoreBefore {
			continue
		}

		versions = append(versions, version)
	}

	return versions
}
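Concretely, the parity this test enforces means the two embedded migration trees must stay pairwise aligned on the version prefix. For a hypothetical migration (version and name are illustrative, not real files from this repo), the expected layout would be:

migrations/postgres/20251116000000_example.up.sql
migrations/postgres/20251116000000_example.down.sql
migrations/sqlite/20251116000000_example.up.sql
migrations/sqlite/20251116000000_example.down.sql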
Binary file not shown.
Before Width: | Height: | Size: 221 KiB After Width: | Height: | Size: 291 KiB
Binary file not shown.
Before Width: | Height: | Size: 566 B
@@ -1 +0,0 @@
DROP TABLE signup_tokens_user_groups;
@@ -1,8 +0,0 @@
CREATE TABLE signup_tokens_user_groups
(
    signup_token_id UUID NOT NULL,
    user_group_id   UUID NOT NULL,
    PRIMARY KEY (signup_token_id, user_group_id),
    FOREIGN KEY (signup_token_id) REFERENCES signup_tokens (id) ON DELETE CASCADE,
    FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE
);
@@ -0,0 +1,3 @@
-- This migration is part of v2

-- No-op in Postgres
@@ -0,0 +1,3 @@
-- This migration is part of v2

-- No-op in Postgres
@@ -1,7 +1 @@
PRAGMA foreign_keys=OFF;
BEGIN;

ALTER TABLE one_time_access_tokens DROP COLUMN device_token;

COMMIT;
PRAGMA foreign_keys=ON;
ALTER TABLE one_time_access_tokens DROP COLUMN device_token;
@@ -1,7 +1 @@
PRAGMA foreign_keys=OFF;
BEGIN;

ALTER TABLE one_time_access_tokens ADD COLUMN device_token TEXT;

COMMIT;
PRAGMA foreign_keys=ON;
ALTER TABLE one_time_access_tokens ADD COLUMN device_token TEXT;
@@ -1,7 +0,0 @@
PRAGMA foreign_keys=OFF;
BEGIN;

DROP TABLE signup_tokens_user_groups;

COMMIT;
PRAGMA foreign_keys=ON;
@@ -1,14 +0,0 @@
PRAGMA foreign_keys=OFF;
BEGIN;

CREATE TABLE signup_tokens_user_groups
(
    signup_token_id TEXT NOT NULL,
    user_group_id   TEXT NOT NULL,
    PRIMARY KEY (signup_token_id, user_group_id),
    FOREIGN KEY (signup_token_id) REFERENCES signup_tokens (id) ON DELETE CASCADE,
    FOREIGN KEY (user_group_id) REFERENCES user_groups (id) ON DELETE CASCADE
);

COMMIT;
PRAGMA foreign_keys=ON;
@@ -0,0 +1,135 @@
-- This migration is part of v2

PRAGMA foreign_keys = OFF;

BEGIN;

CREATE TABLE users_old
(
    id           TEXT NOT NULL PRIMARY KEY,
    created_at   DATETIME,
    username     TEXT COLLATE NOCASE NOT NULL UNIQUE,
    email        TEXT NOT NULL UNIQUE,
    first_name   TEXT,
    last_name    TEXT NOT NULL,
    display_name TEXT NOT NULL,
    is_admin     NUMERIC DEFAULT 0 NOT NULL,
    ldap_id      TEXT,
    locale       TEXT,
    disabled     NUMERIC DEFAULT 0 NOT NULL
);

INSERT INTO users_old (
    id,
    created_at,
    username,
    email,
    first_name,
    last_name,
    display_name,
    is_admin,
    ldap_id,
    locale,
    disabled
)
SELECT
    id,
    created_at,
    username,
    email,
    first_name,
    last_name,
    display_name,
    CASE WHEN is_admin THEN 1 ELSE 0 END,
    ldap_id,
    locale,
    CASE WHEN disabled THEN 1 ELSE 0 END
FROM users;

DROP TABLE users;

ALTER TABLE users_old RENAME TO users;

CREATE UNIQUE INDEX users_ldap_id ON users (ldap_id);



CREATE TABLE webauthn_credentials_old
(
    id               TEXT PRIMARY KEY,
    created_at       DATETIME NOT NULL,
    name             TEXT NOT NULL,
    credential_id    TEXT NOT NULL UNIQUE,
    public_key       BLOB NOT NULL,
    attestation_type TEXT NOT NULL,
    transport        BLOB NOT NULL,
    user_id          TEXT REFERENCES users ON DELETE CASCADE,
    backup_eligible  NUMERIC DEFAULT 0 NOT NULL,
    backup_state     NUMERIC DEFAULT 0 NOT NULL
);

INSERT INTO webauthn_credentials_old (
    id,
    created_at,
    name,
    credential_id,
    public_key,
    attestation_type,
    transport,
    user_id,
    backup_eligible,
    backup_state
)
SELECT
    id,
    created_at,
    name,
    credential_id,
    public_key,
    attestation_type,
    transport,
    user_id,
    CASE WHEN backup_eligible THEN 1 ELSE 0 END,
    CASE WHEN backup_state THEN 1 ELSE 0 END
FROM webauthn_credentials;

DROP TABLE webauthn_credentials;

ALTER TABLE webauthn_credentials_old RENAME TO webauthn_credentials;



CREATE TABLE webauthn_sessions_old
(
    id                TEXT NOT NULL PRIMARY KEY,
    created_at        DATETIME,
    challenge         TEXT NOT NULL UNIQUE,
    expires_at        DATETIME NOT NULL,
    user_verification TEXT NOT NULL,
    credential_params TEXT DEFAULT '[]' NOT NULL
);

INSERT INTO webauthn_sessions_old (
    id,
    created_at,
    challenge,
    expires_at,
    user_verification,
    credential_params
)
SELECT
    id,
    created_at,
    challenge,
    expires_at,
    user_verification,
    credential_params
FROM webauthn_sessions;

DROP TABLE webauthn_sessions;

ALTER TABLE webauthn_sessions_old RENAME TO webauthn_sessions;

COMMIT;

PRAGMA foreign_keys = ON;
@@ -0,0 +1,146 @@
-- This migration is part of v2

PRAGMA foreign_keys = OFF;

BEGIN;

-- 1. Create a new table with BOOLEAN columns
CREATE TABLE users_new
(
    id           TEXT NOT NULL PRIMARY KEY,
    created_at   DATETIME,
    username     TEXT COLLATE NOCASE NOT NULL UNIQUE,
    email        TEXT NOT NULL UNIQUE,
    first_name   TEXT,
    last_name    TEXT NOT NULL,
    display_name TEXT NOT NULL,
    is_admin     BOOLEAN DEFAULT FALSE NOT NULL,
    ldap_id      TEXT,
    locale       TEXT,
    disabled     BOOLEAN DEFAULT FALSE NOT NULL
);

-- 2. Copy all existing data, converting numeric bools to real booleans
INSERT INTO users_new (
    id,
    created_at,
    username,
    email,
    first_name,
    last_name,
    display_name,
    is_admin,
    ldap_id,
    locale,
    disabled
)
SELECT
    id,
    created_at,
    username,
    email,
    first_name,
    last_name,
    display_name,
    CASE WHEN is_admin != 0 THEN TRUE ELSE FALSE END,
    ldap_id,
    locale,
    CASE WHEN disabled != 0 THEN TRUE ELSE FALSE END
FROM users;

-- 3. Drop old table
DROP TABLE users;

-- 4. Rename new table to original name
ALTER TABLE users_new RENAME TO users;

-- 5. Recreate index
CREATE UNIQUE INDEX users_ldap_id ON users (ldap_id);

-- 6. Create temporary table with changed credential_id type to BLOB
CREATE TABLE webauthn_credentials_dg_tmp
(
    id               TEXT PRIMARY KEY,
    created_at       DATETIME NOT NULL,
    name             TEXT NOT NULL,
    credential_id    BLOB NOT NULL UNIQUE,
    public_key       BLOB NOT NULL,
    attestation_type TEXT NOT NULL,
    transport        BLOB NOT NULL,
    user_id          TEXT REFERENCES users ON DELETE CASCADE,
    backup_eligible  BOOLEAN DEFAULT FALSE NOT NULL,
    backup_state     BOOLEAN DEFAULT FALSE NOT NULL
);

-- 7. Copy existing data into the temporary table
INSERT INTO webauthn_credentials_dg_tmp (
    id,
    created_at,
    name,
    credential_id,
    public_key,
    attestation_type,
    transport,
    user_id,
    backup_eligible,
    backup_state
)
SELECT
    id,
    created_at,
    name,
    credential_id,
    public_key,
    attestation_type,
    transport,
    user_id,
    backup_eligible,
    backup_state
FROM webauthn_credentials;

-- 8. Drop old table
DROP TABLE webauthn_credentials;

-- 9. Rename temporary table to original name
ALTER TABLE webauthn_credentials_dg_tmp
    RENAME TO webauthn_credentials;

-- 10. Create temporary table with credential_params type changed to BLOB
CREATE TABLE webauthn_sessions_dg_tmp
(
    id                TEXT NOT NULL PRIMARY KEY,
    created_at        DATETIME,
    challenge         TEXT NOT NULL UNIQUE,
    expires_at        DATETIME NOT NULL,
    user_verification TEXT NOT NULL,
    credential_params BLOB DEFAULT '[]' NOT NULL
);

-- 11. Copy existing data into the temporary sessions table
INSERT INTO webauthn_sessions_dg_tmp (
    id,
    created_at,
    challenge,
    expires_at,
    user_verification,
    credential_params
)
SELECT
    id,
    created_at,
    challenge,
    expires_at,
    user_verification,
    credential_params
FROM webauthn_sessions;

-- 12. Drop old table
DROP TABLE webauthn_sessions;

-- 13. Rename temporary sessions table to original name
ALTER TABLE webauthn_sessions_dg_tmp
    RENAME TO webauthn_sessions;

COMMIT;

PRAGMA foreign_keys = ON;
@@ -1,7 +0,0 @@
PRAGMA foreign_keys=OFF;
BEGIN;

ALTER TABLE oidc_clients DROP COLUMN is_group_restricted;

COMMIT;
PRAGMA foreign_keys=ON;
@@ -1,13 +0,0 @@
PRAGMA foreign_keys= OFF;
BEGIN;

ALTER TABLE oidc_clients
    ADD COLUMN is_group_restricted BOOLEAN NOT NULL DEFAULT 0;

UPDATE oidc_clients
SET is_group_restricted = (SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END
                           FROM oidc_client_user_groups
                           WHERE oidc_client_user_groups.oidc_client_id = oidc_clients.id);

COMMIT;
PRAGMA foreign_keys= ON;
@@ -95,7 +95,7 @@
	"settings": "Settings",
	"update_pocket_id": "Update Pocket ID",
	"powered_by": "Powered by",
	"see_your_recent_account_activities": "See your account activities within the configured retention period.",
	"see_your_account_activities_from_the_last_3_months": "See your account activities from the last 3 months.",
	"time": "Time",
	"event": "Event",
	"approximate_location": "Approximate Location",
@@ -301,20 +301,16 @@
	"are_you_sure_you_want_to_create_a_new_client_secret": "Are you sure you want to create a new client secret? The old one will be invalidated.",
	"generate": "Generate",
	"new_client_secret_created_successfully": "New client secret created successfully",
	"allowed_user_groups_updated_successfully": "Allowed user groups updated successfully",
	"oidc_client_name": "OIDC Client {name}",
	"client_id": "Client ID",
	"client_secret": "Client secret",
	"show_more_details": "Show more details",
	"allowed_user_groups": "Allowed User Groups",
	"allowed_user_groups_description": "Select user groups to restrict signing in to this client to only users in these groups.",
	"allowed_user_groups_status_unrestricted_description": "No user group restrictions are applied. Any user can sign in to this client.",
	"unrestrict": "Unrestrict",
	"restrict": "Restrict",
	"user_groups_restriction_updated_successfully": "User groups restriction updated successfully",
	"add_user_groups_to_this_client_to_restrict_access_to_users_in_these_groups": "Add user groups to this client to restrict access to users in these groups. If no user groups are selected, all users will have access to this client.",
	"favicon": "Favicon",
	"light_mode_logo": "Light Mode Logo",
	"dark_mode_logo": "Dark Mode Logo",
	"email_logo": "Email Logo",
	"background_image": "Background Image",
	"language": "Language",
	"reset_profile_picture_question": "Reset profile picture?",
@@ -331,7 +327,7 @@
	"all_clients": "All Clients",
	"all_locations": "All Locations",
	"global_audit_log": "Global Audit Log",
	"see_all_recent_account_activities": "View the account activities of all users during the set retention period.",
	"see_all_account_activities_from_the_last_3_months": "See all user activity for the last 3 months.",
	"token_sign_in": "Token Sign In",
	"client_authorization": "Client Authorization",
	"new_client_authorization": "New Client Authorization",
@@ -353,8 +349,8 @@
	"login_code_email_success": "The login code has been sent to the user.",
	"send_email": "Send Email",
	"show_code": "Show Code",
	"callback_url_description": "URL(s) provided by your client. Will be automatically added if left blank. Wildcards (*) are supported, but best avoided for better security.",
	"logout_callback_url_description": "URL(s) provided by your client for logout. Wildcards (*) are supported, but best avoided for better security.",
	"callback_url_description": "URL(s) provided by your client. Will be automatically added if left blank. <link href='https://pocket-id.org/docs/advanced/callback-url-wildcards'>Wildcards</link> are supported.",
	"logout_callback_url_description": "URL(s) provided by your client for logout. <link href='https://pocket-id.org/docs/advanced/callback-url-wildcards'>Wildcards</link> are supported.",
	"api_key_expiration": "API Key Expiration",
	"send_an_email_to_the_user_when_their_api_key_is_about_to_expire": "Send an email to the user when their API key is about to expire.",
	"authorize_device": "Authorize Device",
@@ -474,5 +470,5 @@
	"light": "Light",
	"dark": "Dark",
	"system": "System",
	"signup_token_user_groups_description": "Automatically assign these groups to users who sign up using this token."
	"scopes": "Scopes"
}

@@ -232,19 +232,22 @@
	}
}

@keyframes bg-zoom {
@keyframes slide-bg-container {
	0% {
		transform: scale(1.3);
		left: 0;
	}
	100% {
		transform: scale(1);
		left: 650px;
	}
}

.animate-bg-zoom {
	transform-origin: center;
	will-change: transform;
	animation: bg-zoom 0.7s cubic-bezier(0.25, 0.1, 0.25, 1) forwards;
.animate-slide-bg-container {
	position: absolute;
	top: 0;
	bottom: 0;
	left: 0;
	right: 0;
	animation: slide-bg-container 0.6s cubic-bezier(0.33, 1, 0.68, 1) forwards;
}

@keyframes delayed-fade {

@@ -12,8 +12,6 @@
		title,
		description,
		defaultExpanded = false,
		forcedExpanded,
		button,
		icon,
		children
	}: {
@@ -21,9 +19,7 @@
		title: string;
		description?: string;
		defaultExpanded?: boolean;
		forcedExpanded?: boolean;
		icon?: typeof IconType;
		button?: Snippet;
		children: Snippet;
	} = $props();

@@ -51,12 +47,6 @@
		}
		loadExpandedState();
	});

	$effect(() => {
		if (forcedExpanded !== undefined) {
			expanded = forcedExpanded;
		}
	});
</script>

<Card.Root>
@@ -73,18 +63,11 @@
				<Card.Description>{description}</Card.Description>
			{/if}
		</div>
		{#if button}
			{@render button()}
		{:else}
			<Button class="ml-10 h-8 p-3" variant="ghost" aria-label={m.expand_card()}>
				<LucideChevronDown
					class={cn(
						'size-5 transition-transform duration-200',
						expanded && 'rotate-180 transform'
					)}
				/>
			</Button>
		{/if}
		<Button class="ml-10 h-8 p-3" variant="ghost" aria-label={m.expand_card()}>
			<LucideChevronDown
				class={cn('size-5 transition-transform duration-200', expanded && 'rotate-180 transform')}
			/>
		</Button>
	</div>
</Card.Header>
{#if expanded}

@@ -1,6 +1,6 @@
<script lang="ts">
	import { Checkbox } from '$lib/components/ui/checkbox';
	import { Label } from '$lib/components/ui/label';
	import * as Field from '$lib/components/ui/field';

	let {
		id,
@@ -26,14 +26,10 @@
		onCheckedChange={(v) => onCheckedChange && onCheckedChange(v == true)}
		bind:checked
	/>
	<div class="grid gap-1.5 leading-none">
		<Label for={id} class="mb-0 text-sm leading-none font-medium">
			{label}
		</Label>
	<Field.Field class="gap-0">
		<Field.Label for={id}>{label}</Field.Label>
		{#if description}
			<p class="text-muted-foreground text-[0.8rem]">
				{description}
			</p>
			<Field.Description>{description}</Field.Description>
		{/if}
	</div>
	</Field.Field>
</div>

@@ -1,23 +1,13 @@
<script lang="ts">
	import DatePicker from '$lib/components/form/date-picker.svelte';
	import * as Field from '$lib/components/ui/field';
	import { Input, type FormInputEvent } from '$lib/components/ui/input';
	import { Label } from '$lib/components/ui/label';
	import { m } from '$lib/paraglide/messages';
	import type { FormInput } from '$lib/utils/form-util';
	import { LucideExternalLink } from '@lucide/svelte';
	import type { Snippet } from 'svelte';
	import type { HTMLAttributes } from 'svelte/elements';

	type WithoutChildren = {
		children?: undefined;
		input?: FormInput<string | boolean | number | Date | undefined>;
		labelFor?: never;
	};
	type WithChildren = {
		children: Snippet;
		input?: any;
		labelFor?: string;
	};
	import FormattedMessage from '../formatted-message.svelte';

	let {
		input = $bindable(),
@@ -29,29 +19,29 @@
		type = 'text',
		children,
		onInput,
		labelFor,
		...restProps
	}: HTMLAttributes<HTMLDivElement> &
		(WithChildren | WithoutChildren) & {
			label?: string;
			description?: string;
			docsLink?: string;
			placeholder?: string;
			disabled?: boolean;
			type?: 'text' | 'password' | 'email' | 'number' | 'checkbox' | 'date';
			onInput?: (e: FormInputEvent) => void;
		} = $props();
	}: HTMLAttributes<HTMLDivElement> & {
		input?: FormInput<string | boolean | number | Date | undefined>;
		label?: string;
		description?: string;
		docsLink?: string;
		placeholder?: string;
		disabled?: boolean;
		type?: 'text' | 'password' | 'email' | 'number' | 'checkbox' | 'date';
		onInput?: (e: FormInputEvent) => void;
		children?: Snippet;
	} = $props();

	const id = label?.toLowerCase().replace(/ /g, '-');
</script>

<div {...restProps}>
<Field.Field data-disabled={disabled} {...restProps}>
	{#if label}
		<Label required={input?.required} class="mb-0" for={labelFor ?? id}>{label}</Label>
		<Field.Label required={input?.required} for={id}>{label}</Field.Label>
	{/if}
	{#if description}
		<p class="text-muted-foreground mt-1 text-xs">
			{description}
		<Field.Description>
			<FormattedMessage m={description} />
			{#if docsLink}
				<a
					class="relative text-black after:absolute after:bottom-0 after:left-0 after:h-px after:w-full after:translate-y-[-1px] after:bg-white dark:text-white"
@@ -62,28 +52,26 @@
					<LucideExternalLink class="inline size-3 align-text-top" />
				</a>
			{/if}
		</p>
		</Field.Description>
	{/if}
	<div class={label || description ? 'mt-2' : ''}>
		{#if children}
			{@render children()}
		{:else if input}
			{#if type === 'date'}
				<DatePicker {id} bind:value={input.value as Date} />
			{:else}
				<Input
					aria-invalid={!!input.error}
					{id}
					{placeholder}
					{type}
					bind:value={input.value}
					{disabled}
					oninput={(e) => onInput?.(e)}
				/>
			{/if}
	{#if children}
		{@render children()}
	{:else if input}
		{#if type === 'date'}
			<DatePicker {id} bind:value={input.value as Date} />
		{:else}
			<Input
				aria-invalid={!!input.error}
				{id}
				{placeholder}
				{type}
				bind:value={input.value}
				{disabled}
				oninput={(e) => onInput?.(e)}
			/>
		{/if}
		{#if input?.error}
			<p class="text-destructive mt-1 text-start text-xs">{input.error}</p>
		{/if}
	</div>
</div>
	{/if}
	{#if input?.error}
		<Field.Error>{input.error}</Field.Error>
	{/if}
</Field.Field>

@@ -5,7 +5,8 @@
	import { m } from '$lib/paraglide/messages';
	import appConfigStore from '$lib/stores/application-configuration-store';
	import { cachedProfilePicture } from '$lib/utils/cached-image-util';
	import { LucideLoader, LucideRefreshCw, LucideUpload } from '@lucide/svelte';
	import { LucideRefreshCw, LucideUpload } from '@lucide/svelte';
	import { Spinner } from '$lib/components/ui/spinner';
	import { onMount } from 'svelte';
	import { openConfirmDialog } from '../confirm-dialog';

@@ -88,7 +89,7 @@
	</Avatar.Root>
	<div class="absolute inset-0 flex items-center justify-center">
		{#if isLoading}
			<LucideLoader class="size-5 animate-spin" />
			<Spinner class="size-5" />
		{:else}
			<LucideUpload class="size-5 opacity-0 transition-opacity group-hover:opacity-100" />
		{/if}

@@ -3,9 +3,9 @@
	import { Button } from '$lib/components/ui/button';
	import * as Command from '$lib/components/ui/command';
	import * as Popover from '$lib/components/ui/popover';
	import { Spinner } from '$lib/components/ui/spinner';
	import { cn } from '$lib/utils/style';
	import { m } from '$lib/paraglide/messages';
	import { LoaderCircle, LucideCheck, LucideChevronDown } from '@lucide/svelte';
	import type { FormEventHandler } from 'svelte/elements';

	type Item = {
@@ -108,7 +108,7 @@
	<Command.Empty>
		{#if isLoading}
			<div class="flex w-full items-center justify-center py-2">
				<LoaderCircle class="size-4 animate-spin" />
				<Spinner />
			</div>
		{:else}
			{m.no_items_found()}

@@ -2,9 +2,10 @@
	import { Button } from '$lib/components/ui/button';
	import * as Command from '$lib/components/ui/command';
	import * as Popover from '$lib/components/ui/popover';
	import { Spinner } from '$lib/components/ui/spinner';
	import { m } from '$lib/paraglide/messages';
	import { cn } from '$lib/utils/style';
	import { LoaderCircle, LucideCheck, LucideChevronDown } from '@lucide/svelte';
	import { LucideCheck, LucideChevronDown } from '@lucide/svelte';
	import { tick } from 'svelte';
	import type { FormEventHandler, HTMLAttributes } from 'svelte/elements';

@@ -90,7 +91,7 @@
	<Command.Empty>
		{#if isLoading}
			<div class="flex w-full justify-center">
				<LoaderCircle class="size-4 animate-spin" />
				<Spinner />
			</div>
		{:else}
			{m.no_items_found()}

@@ -1,50 +0,0 @@
<script lang="ts">
	import SearchableMultiSelect from '$lib/components/form/searchable-multi-select.svelte';
	import UserGroupService from '$lib/services/user-group-service';
	import { debounced } from '$lib/utils/debounce-util';
	import { onMount } from 'svelte';

	let {
		selectedGroupIds = $bindable()
	}: {
		selectedGroupIds: string[];
	} = $props();

	const userGroupService = new UserGroupService();

	let userGroups = $state<{ value: string; label: string }[]>([]);
	let isLoading = $state(false);

	async function loadUserGroups(search?: string) {
		userGroups = (await userGroupService.list({ search })).data.map((group) => ({
			value: group.id,
			label: group.name
		}));

		// Ensure selected groups are still in the list
		for (const selectedGroupId of selectedGroupIds) {
			if (!userGroups.some((g) => g.value === selectedGroupId)) {
				const group = await userGroupService.get(selectedGroupId);
				userGroups.push({ value: group.id, label: group.name });
			}
		}
	}

	const onUserGroupSearch = debounced(
		async (search: string) => await loadUserGroups(search),
		300,
		(loading) => (isLoading = loading)
	);

	onMount(() => loadUserGroups());
</script>

<SearchableMultiSelect
	id="default-groups"
	items={userGroups}
	oninput={(e) => onUserGroupSearch(e.currentTarget.value)}
	selectedItems={selectedGroupIds}
	onSelect={(selected) => (selectedGroupIds = selected)}
	{isLoading}
	disableInternalSearch
/>
@@ -19,7 +19,7 @@
	);
</script>

<div class=" w-full {isAuthPage ? 'absolute top-0 z-10 mt-3 lg:mt-8 pr-2 lg:pr-3' : 'border-b'}">
<div class=" w-full {isAuthPage ? 'absolute top-0 z-10 mt-4' : 'border-b'}">
	<div
		class="{!isAuthPage
			? 'max-w-[1640px]'

@@ -48,16 +48,20 @@
{#if isDesktop.current}
	<div class="h-screen items-center overflow-hidden text-center">
		<div
			class="relative z-10 flex h-full w-[650px] 2xl:w-[800px] p-16 {cn(
				showAlternativeSignInMethodButton && 'pb-0'
			class="relative z-10 flex h-full w-[650px] p-16 {cn(
				showAlternativeSignInMethodButton && 'pb-0',
				animate && 'animate-delayed-fade'
			)}"
		>
			<div class="flex h-full w-full flex-col overflow-hidden">
				<div class="relative flex grow flex-col items-center justify-center overflow-auto">
				<div class="relative flex flex-grow flex-col items-center justify-center overflow-auto">
					{@render children()}
				</div>
				{#if showAlternativeSignInMethodButton}
					<div class="mb-4 flex items-center justify-center">
					<div
						class="mb-4 flex items-center justify-center"
						style={animate ? 'animation-delay: 500ms;' : ''}
					>
						<a
							href={alternativeSignInButton.href}
							class="text-muted-foreground text-xs transition-colors hover:underline"
@@ -69,13 +73,13 @@
			</div>
		</div>

		<!-- Background image -->
		<div class="absolute top-0 right-0 left-500px bottom-0 z-0 overflow-hidden rounded-[40px] m-6">
		<!-- Background image with slide animation -->
		<div class="{cn(animate && 'animate-slide-bg-container')} absolute top-0 right-0 bottom-0 z-0">
			<img
				src={cachedBackgroundImage.getUrl()}
				class="{cn(
					animate && 'animate-bg-zoom'
				)} h-screen object-cover w-[calc(100vw-650px)] 2xl:w-[calc(100vw-800px)]"
				class="h-screen rounded-l-[60px] object-cover {animate
					? 'w-full'
					: 'w-[calc(100vw-650px)]'}"
				alt={m.login_background()}
			/>
		</div>
@@ -85,7 +89,7 @@
		class="flex h-screen items-center justify-center bg-cover bg-center text-center"
		style="background-image: url({cachedBackgroundImage.getUrl()});"
	>
		<Card.Root class="mx-3 w-full max-w-md">
		<Card.Root class="mx-3 w-full max-w-md" style={animate ? 'animation-delay: 200ms;' : ''}>
			<Card.CardContent
				class="px-4 py-10 sm:p-10 {showAlternativeSignInMethodButton ? 'pb-3 sm:pb-3' : ''}"
			>

@@ -4,7 +4,7 @@
	import Qrcode from '$lib/components/qrcode/qrcode.svelte';
	import { Button } from '$lib/components/ui/button';
	import * as Dialog from '$lib/components/ui/dialog';
	import Label from '$lib/components/ui/label/label.svelte';
	import * as Field from '$lib/components/ui/field';
	import * as Select from '$lib/components/ui/select/index.js';
	import { Separator } from '$lib/components/ui/separator';
	import { m } from '$lib/paraglide/messages';
@@ -78,14 +78,14 @@
	</Dialog.Header>

	{#if oneTimeLink === null}
		<div>
			<Label for="expiration">{m.expiration()}</Label>
		<Field.Field>
			<Field.Label for="expiration">{m.expiration()}</Field.Label>
			<Select.Root
				type="single"
				value={Object.keys(availableExpirations)[0]}
				onValueChange={(v) => (selectedExpiration = v! as keyof typeof availableExpirations)}
			>
				<Select.Trigger id="expiration" class="w-full h-9">
				<Select.Trigger id="expiration" class="h-9 w-full">
					{selectedExpiration}
				</Select.Trigger>
				<Select.Content>
@@ -94,7 +94,7 @@
				{/each}
			</Select.Content>
		</Select.Root>
	</div>
	</Field.Field>
	<Dialog.Footer class="mt-2">
		{#if $appConfigStore.emailOneTimeAccessAsAdminEnabled}
			<Button
@@ -112,10 +112,10 @@
	{:else}
		<div class="flex flex-col items-center gap-2">
			<CopyToClipboard value={code!}>
				<p class="text-3xl font-code">{code}</p>
				<p class="font-code text-3xl">{code}</p>
			</CopyToClipboard>

			<div class="flex items-center justify-center gap-3 my-2 text-muted-foreground">
			<div class="text-muted-foreground my-2 flex items-center justify-center gap-3">
				<Separator />
				<p class="text-xs text-nowrap">{m.or_visit()}</p>
				<Separator />

@@ -1,5 +1,6 @@
<script lang="ts">
	import { Button } from '$lib/components/ui/button';
	import * as Item from '$lib/components/ui/item/index.js';
	import * as Tooltip from '$lib/components/ui/tooltip/index.js';
	import { m } from '$lib/paraglide/messages';
	import { LucideCalendar, LucidePencil, LucideTrash, type Icon as IconType } from '@lucide/svelte';
@@ -19,61 +20,54 @@
	} = $props();
</script>

<div class="bg-card hover:bg-muted/50 group rounded-lg p-3 transition-colors">
	<div class="flex items-center justify-between">
		<div class="flex items-start gap-3">
			<div class="bg-primary/10 text-primary mt-1 rounded-lg p-2">
				{#if icon}{@const Icon = icon}
					<Icon class="size-5" />
				{/if}
			</div>
			<div>
				<div class="flex items-center gap-2">
					<p class="font-medium">{label}</p>
				</div>
				{#if description}
					<div class="text-muted-foreground mt-1 flex items-center text-xs">
						<LucideCalendar class="mr-1 size-3" />
						{description}
					</div>
				{/if}
			</div>
		</div>
<Item.Root variant="transparent" class="hover:bg-muted transition-colors">
	<Item.Media class="bg-primary/10 text-primary rounded-lg p-2">
		{#if icon}{@const Icon = icon}
			<Icon class="size-5" />
		{/if}
	</Item.Media>
	<Item.Content>
		<Item.Title>{label}</Item.Title>
		{#if description}
			<Item.Description class="flex items-center">
				<LucideCalendar class="mr-1 size-3" />
				{description}
			</Item.Description>
		{/if}
	</Item.Content>
	<Item.Actions>
		<Tooltip.Provider>
			<Tooltip.Root>
				<Tooltip.Trigger>
					<Button
						onclick={onRename}
						size="icon"
						variant="ghost"
						class="size-8"
						aria-label={m.rename()}
					>
						<LucidePencil class="size-4" />
					</Button>
				</Tooltip.Trigger>
				<Tooltip.Content>{m.rename()}</Tooltip.Content>
			</Tooltip.Root>
		</Tooltip.Provider>

		<div class="flex items-center gap-2">
			<Tooltip.Provider>
				<Tooltip.Root>
					<Tooltip.Trigger>
						<Button
							onclick={onRename}
							size="icon"
							variant="ghost"
							class="size-8"
							aria-label={m.rename()}
						>
							<LucidePencil class="size-4" />
						</Button>
					</Tooltip.Trigger>
					<Tooltip.Content>{m.rename()}</Tooltip.Content>
				</Tooltip.Root></Tooltip.Provider
			>

			<Tooltip.Provider>
				<Tooltip.Root>
					<Tooltip.Trigger>
						<Button
							onclick={onDelete}
							size="icon"
							variant="ghost"
							class="hover:bg-destructive/10 hover:text-destructive size-8"
							aria-label={m.delete()}
						>
							<LucideTrash class="size-4" />
						</Button>
					</Tooltip.Trigger>
					<Tooltip.Content>{m.delete()}</Tooltip.Content>
				</Tooltip.Root></Tooltip.Provider
			>
		</div>
	</div>
</div>
		<Tooltip.Provider>
			<Tooltip.Root>
				<Tooltip.Trigger>
					<Button
						onclick={onDelete}
						size="icon"
						variant="ghost"
						class="hover:bg-destructive/10 hover:text-destructive size-8"
						aria-label={m.delete()}
					>
						<LucideTrash class="size-4" />
					</Button>
				</Tooltip.Trigger>
				<Tooltip.Content>{m.delete()}</Tooltip.Content>
			</Tooltip.Root>
		</Tooltip.Provider>
	</Item.Actions>
</Item.Root>

@@ -1,5 +1,7 @@
<script lang="ts">
	import * as Item from '$lib/components/ui/item/index.js';
	import type { Icon as IconType } from '@lucide/svelte';

	interface Props {
		icon: typeof IconType;
		name: string;
@@ -11,10 +13,12 @@
	const SvelteComponent = $derived(icon);
</script>

<div class="flex items-center">
	<div class="bg-muted mr-5 rounded-lg p-2"><SvelteComponent /></div>
	<div class="text-start">
		<h3 class="font-semibold">{name}</h3>
		<p class="text-muted-foreground text-sm">{description}</p>
	</div>
</div>
<Item.Root size="sm" class="gap-5">
	<Item.Media class="bg-muted !self-center rounded-lg p-2 !translate-y-0">
		<SvelteComponent class="size-4" />
	</Item.Media>
	<Item.Content class="text-start">
		<Item.Title class="font-semibold">{name}</Item.Title>
		<Item.Description>{description}</Item.Description>
	</Item.Content>
</Item.Root>

@@ -1,4 +1,5 @@
<script lang="ts">
	import * as Item from '$lib/components/ui/item/index.js';
	import { m } from '$lib/paraglide/messages';
	import { LucideMail, LucideUser, LucideUsers } from '@lucide/svelte';
	import ScopeItem from './scope-item.svelte';
@@ -6,7 +7,7 @@
	let { scope }: { scope: string } = $props();
</script>

<div class="flex flex-col gap-3" data-testid="scopes">
<Item.Group data-testid="scopes">
	{#if scope!.includes('email')}
		<ScopeItem icon={LucideMail} name={m.email()} description={m.view_your_email_address()} />
	{/if}
@@ -24,4 +25,4 @@
		description={m.view_the_groups_you_are_a_member_of()}
	/>
	{/if}
</div>
</Item.Group>

@@ -11,7 +11,7 @@
		AdvancedTableColumn,
		CreateAdvancedTableActions
	} from '$lib/types/advanced-table.type';
	import type { SignupToken } from '$lib/types/signup-token.type';
	import type { SignupTokenDto } from '$lib/types/signup-token.type';
	import { axiosErrorToast } from '$lib/utils/error-util';
	import { Copy, Trash2 } from '@lucide/svelte';
	import { toast } from 'svelte-sonner';
@@ -23,14 +23,14 @@
	} = $props();

	const userService = new UserService();
	let tableRef: AdvancedTable<SignupToken>;
	let tableRef: AdvancedTable<SignupTokenDto>;

	function formatDate(dateStr: string | undefined) {
		if (!dateStr) return m.never();
		return new Date(dateStr).toLocaleString();
	}

	async function deleteToken(token: SignupToken) {
	async function deleteToken(token: SignupTokenDto) {
		openConfirmDialog({
			title: m.delete_signup_token(),
			message: m.are_you_sure_you_want_to_delete_this_signup_token(),
@@ -58,11 +58,11 @@
		return new Date(expiresAt) < new Date();
	}

	function isTokenUsedUp(token: SignupToken) {
	function isTokenUsedUp(token: SignupTokenDto) {
		return token.usageCount >= token.usageLimit;
	}

	function getTokenStatus(token: SignupToken) {
	function getTokenStatus(token: SignupTokenDto) {
		if (isTokenExpired(token.expiresAt)) return 'expired';
		if (isTokenUsedUp(token)) return 'used-up';
		return 'active';
@@ -79,7 +79,7 @@
		}
	}

	function copySignupLink(token: SignupToken) {
	function copySignupLink(token: SignupTokenDto) {
		const signupLink = `${page.url.origin}/st/${token.token}`;
		navigator.clipboard
			.writeText(signupLink)
@@ -91,7 +91,7 @@
		});
	}

	const columns: AdvancedTableColumn<SignupToken>[] = [
	const columns: AdvancedTableColumn<SignupTokenDto>[] = [
		{ label: m.token(), column: 'token', cell: TokenCell },
		{ label: m.status(), key: 'status', cell: StatusCell },
		{
@@ -106,12 +106,7 @@
			sortable: true,
			value: (item) => formatDate(item.expiresAt)
		},
		{
			key: 'userGroups',
			label: m.user_groups(),
			value: (item) => item.userGroups.map((g) => g.name).join(', '),
			hidden: true
		},
		{ label: 'Usage Limit', column: 'usageLimit' },
		{
			label: m.created(),
			column: 'createdAt',
@@ -121,7 +116,7 @@
		}
	];

	const actions: CreateAdvancedTableActions<SignupToken> = (_) => [
	const actions: CreateAdvancedTableActions<SignupTokenDto> = (_) => [
		{
			label: m.copy(),
			icon: Copy,
@@ -136,13 +131,13 @@
	];
</script>

{#snippet TokenCell({ item }: { item: SignupToken })}
{#snippet TokenCell({ item }: { item: SignupTokenDto })}
	<span class="font-mono text-xs">
		{item.token.substring(0, 3)}...{item.token.substring(Math.max(item.token.length - 4, 0))}
	</span>
{/snippet}

{#snippet StatusCell({ item }: { item: SignupToken })}
{#snippet StatusCell({ item }: { item: SignupTokenDto })}
	{@const status = getTokenStatus(item)}
	{@const statusBadge = getStatusBadge(status)}
	<Badge class="rounded-full" variant={statusBadge.variant}>
@@ -150,7 +145,7 @@
	</Badge>
{/snippet}

{#snippet UsageCell({ item }: { item: SignupToken })}
{#snippet UsageCell({ item }: { item: SignupTokenDto })}
	<div class="flex items-center gap-1">
		{item.usageCount}
		{m.of()}

@@ -1,22 +1,16 @@
<script lang="ts">
	import { page } from '$app/state';
	import CopyToClipboard from '$lib/components/copy-to-clipboard.svelte';
	import FormInput from '$lib/components/form/form-input.svelte';
	import UserGroupInput from '$lib/components/form/user-group-input.svelte';
	import Qrcode from '$lib/components/qrcode/qrcode.svelte';
	import { Button } from '$lib/components/ui/button';
	import * as Dialog from '$lib/components/ui/dialog';
	import * as Field from '$lib/components/ui/field';
	import { Input } from '$lib/components/ui/input';
	import * as Select from '$lib/components/ui/select/index.js';
	import { m } from '$lib/paraglide/messages';
	import AppConfigService from '$lib/services/app-config-service';
	import UserService from '$lib/services/user-service';
	import { axiosErrorToast } from '$lib/utils/error-util';
	import { preventDefault } from '$lib/utils/event-util';
	import { createForm } from '$lib/utils/form-util';
	import { mode } from 'mode-watcher';
	import { onMount } from 'svelte';
	import { z } from 'zod/v4';

	let {
		open = $bindable()
@@ -25,74 +19,29 @@
	} = $props();

	const userService = new UserService();
	const appConfigService = new AppConfigService();

	const DEFAULT_TTL_SECONDS = 60 * 60 * 24;
	const availableExpirations = [
		{ label: m.one_hour(), value: 60 * 60 },
		{ label: m.twelve_hours(), value: 60 * 60 * 12 },
		{ label: m.one_day(), value: DEFAULT_TTL_SECONDS },
		{ label: m.one_week(), value: DEFAULT_TTL_SECONDS * 7 },
		{ label: m.one_month(), value: DEFAULT_TTL_SECONDS * 30 }
	] as const;

	const defaultExpiration =
		availableExpirations.find((exp) => exp.value === DEFAULT_TTL_SECONDS)?.value ??
		availableExpirations[0].value;

	type SignupTokenForm = {
		ttl: number;
		usageLimit: number;
		userGroupIds: string[];
	};

	const initialFormValues: SignupTokenForm = {
		ttl: defaultExpiration,
		usageLimit: 1,
		userGroupIds: []
	};

	const formSchema = z.object({
		ttl: z.number(),
		usageLimit: z.number().min(1).max(100),
		userGroupIds: z.array(z.string()).default([])
	});

	const { inputs, ...form } = createForm<typeof formSchema>(formSchema, initialFormValues);

	let signupToken: string | null = $state(null);
	let signupLink: string | null = $state(null);
	let createdSignupData: SignupTokenForm | null = $state(null);
	let isLoading = $state(false);
	let selectedExpiration: keyof typeof availableExpirations = $state(m.one_day());
	let usageLimit: number = $state(1);

	let defaultUserGroupIds: string[] = [];

	function getExpirationLabel(ttl: number) {
		return availableExpirations.find((exp) => exp.value === ttl)?.label ?? '';
	}

	function resetForm() {
		form.reset();
		form.setValue('userGroupIds', defaultUserGroupIds);
	}
	let availableExpirations = {
		[m.one_hour()]: 60 * 60,
		[m.twelve_hours()]: 60 * 60 * 12,
		[m.one_day()]: 60 * 60 * 24,
		[m.one_week()]: 60 * 60 * 24 * 7,
		[m.one_month()]: 60 * 60 * 24 * 30
	};

	async function createSignupToken() {
		const data = form.validate();
		if (!data) return;

		isLoading = true;
		try {
			signupToken = await userService.createSignupToken(
				data.ttl,
				data.usageLimit,
				data.userGroupIds
				availableExpirations[selectedExpiration],
				usageLimit
			);
			signupLink = `${page.url.origin}/st/${signupToken}`;
			createdSignupData = data;
		} catch (e) {
			axiosErrorToast(e);
		} finally {
			isLoading = false;
		}
	}

@@ -101,22 +50,10 @@
		if (!isOpen) {
			signupToken = null;
			signupLink = null;
			createdSignupData = null;
			resetForm();
			selectedExpiration = m.one_day();
			usageLimit = 1;
		}
	}

	onMount(() => {
		appConfigService
			.list(true)
			.then((response) => {
				const responseGroupIds = response.signupDefaultUserGroupIDs || [];
				defaultUserGroupIds = responseGroupIds;
				initialFormValues.userGroupIds = responseGroupIds;
				form.setValue('userGroupIds', responseGroupIds);
			})
			.catch(axiosErrorToast);
	});
</script>

<Dialog.Root {open} {onOpenChange}>
@@ -129,57 +66,49 @@
	</Dialog.Header>

	{#if signupToken === null}
		<form class="space-y-4" onsubmit={preventDefault(createSignupToken)}>
			<FormInput labelFor="expiration" label={m.expiration()} input={$inputs.ttl}>
		<div class="space-y-4">
			<Field.Field>
				<Field.Label for="expiration">{m.expiration()}</Field.Label>
				<Select.Root
					type="single"
					value={$inputs.ttl.value.toString()}
					onValueChange={(v) => v && form.setValue('ttl', Number(v))}
					value={Object.keys(availableExpirations)[0]}
					onValueChange={(v) => (selectedExpiration = v! as keyof typeof availableExpirations)}
				>
					<Select.Trigger id="expiration" class="h-9 w-full">
						{getExpirationLabel($inputs.ttl.value)}
						{selectedExpiration}
					</Select.Trigger>
					<Select.Content>
						{#each availableExpirations as expiration}
							<Select.Item value={expiration.value.toString()}>
								{expiration.label}
							</Select.Item>
						{#each Object.keys(availableExpirations) as key}
							<Select.Item value={key}>{key}</Select.Item>
						{/each}
					</Select.Content>
				</Select.Root>
				{#if $inputs.ttl.error}
					<p class="text-destructive mt-1 text-xs">{$inputs.ttl.error}</p>
				{/if}
			</FormInput>
			<FormInput
				labelFor="usage-limit"
				label={m.usage_limit()}
				description={m.number_of_times_token_can_be_used()}
				input={$inputs.usageLimit}
			>
			</Field.Field>

			<Field.Field>
				<Field.Label for="usage-limit">{m.usage_limit()}</Field.Label>
				<Field.Description>
					{m.number_of_times_token_can_be_used()}
				</Field.Description>
				<Input
					id="usage-limit"
					type="number"
					bind:value={$inputs.usageLimit.value}
					aria-invalid={$inputs.usageLimit.error ? 'true' : undefined}
					min="1"
					max="100"
					bind:value={usageLimit}
					class="h-9"
				/>
			</FormInput>
			<FormInput
				labelFor="default-groups"
				label={m.user_groups()}
				description={m.signup_token_user_groups_description()}
				input={$inputs.userGroupIds}
			>
				<UserGroupInput bind:selectedGroupIds={$inputs.userGroupIds.value} />
			</FormInput>
			</Field.Field>
		</div>

		<Dialog.Footer class="mt-4">
			<Button type="submit" {isLoading}>
				{m.create()}
			</Button>
		</Dialog.Footer>
		</form>
		<Dialog.Footer class="mt-4">
			<Button
				onclick={() => createSignupToken()}
				disabled={!selectedExpiration || usageLimit < 1}
			>
				{m.create()}
			</Button>
		</Dialog.Footer>
	{:else}
		<div class="flex flex-col items-center gap-2">
			<Qrcode
@@ -196,8 +125,8 @@
			</CopyToClipboard>

			<div class="text-muted-foreground mt-2 text-center text-sm">
				<p>{m.usage_limit()}: {createdSignupData?.usageLimit}</p>
				<p>{m.expiration()}: {getExpirationLabel(createdSignupData?.ttl ?? 0)}</p>
				<p>{m.usage_limit()}: {usageLimit}</p>
				<p>{m.expiration()}: {selectedExpiration}</p>
			</div>
		</div>
	{/if}

@@ -42,7 +42,7 @@
</script>

<script lang="ts">
	import LoaderCircle from '@lucide/svelte/icons/loader-circle';
	import { Spinner } from '$lib/components/ui/spinner';
	import { onMount } from 'svelte';

	let {
@@ -97,7 +97,7 @@
		{...restProps}
	>
		{#if isLoading}
			<LoaderCircle class="size-4 animate-spin" />
			<Spinner />
		{/if}
		{@render children?.()}
	</a>
@@ -112,7 +112,7 @@
		{...restProps}
	>
		{#if isLoading}
			<LoaderCircle class="size-4 animate-spin" />
			<Spinner />
		{/if}
		{@render children?.()}
	</button>

20
frontend/src/lib/components/ui/field/field-content.svelte
Normal file
@@ -0,0 +1,20 @@
<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLDivElement>> = $props();
</script>

<div
	bind:this={ref}
	data-slot="field-content"
	class={cn("group/field-content flex flex-1 flex-col gap-1.5 leading-snug", className)}
	{...restProps}
>
	{@render children?.()}
</div>
@@ -0,0 +1,24 @@
<script lang="ts">
	import { cn, type WithElementRef } from '$lib/utils/style.js';
	import type { HTMLAttributes } from 'svelte/elements';

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLParagraphElement>> = $props();
</script>

<p
	bind:this={ref}
	data-slot="field-description"
	class={cn(
		'text-muted-foreground -mt-1 mb-0 text-xs leading-normal font-normal group-has-[[data-orientation=horizontal]]/field:text-balance',
		'[&>a:hover]:text-primary [&>a]:underline [&>a]:underline-offset-4',
		className
	)}
	{...restProps}
>
	{@render children?.()}
</p>
58
frontend/src/lib/components/ui/field/field-error.svelte
Normal file
@@ -0,0 +1,58 @@
|
||||
<script lang="ts">
|
||||
import { cn, type WithElementRef } from "$lib/utils/style.js";
|
||||
import type { HTMLAttributes } from "svelte/elements";
|
||||
import type { Snippet } from "svelte";
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
errors,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLDivElement>> & {
|
||||
children?: Snippet;
|
||||
errors?: { message?: string }[];
|
||||
} = $props();
|
||||
|
||||
const hasContent = $derived.by(() => {
|
||||
// has slotted error
|
||||
if (children) return true;
|
||||
|
||||
// no errors
|
||||
if (!errors) return false;
|
||||
|
||||
// has an error but no message
|
||||
if (errors.length === 1 && !errors[0]?.message) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
const isMultipleErrors = $derived(errors && errors.length > 1);
|
||||
const singleErrorMessage = $derived(errors && errors.length === 1 && errors[0]?.message);
|
||||
</script>
|
||||
|
||||
{#if hasContent}
|
||||
<div
|
||||
bind:this={ref}
|
||||
role="alert"
|
||||
data-slot="field-error"
|
||||
class={cn("text-destructive text-sm font-normal", className)}
|
||||
{...restProps}
|
||||
>
|
||||
{#if children}
|
||||
{@render children()}
|
||||
{:else if singleErrorMessage}
|
||||
{singleErrorMessage}
|
||||
{:else if isMultipleErrors}
|
||||
<ul class="ms-4 flex list-disc flex-col gap-1">
|
||||
{#each errors ?? [] as error, index (index)}
|
||||
{#if error?.message}
|
||||
<li>{error.message}</li>
|
||||
{/if}
|
||||
{/each}
|
||||
</ul>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
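FieldError renders nothing unless it has slotted children or at least one error with a message; a single message is printed inline, while multiple messages become a bulleted list. A short sketch of all three cases, using the FieldError alias exported from the field index added later in this diff:

<script lang="ts">
	import { FieldError } from '$lib/components/ui/field';

	// Shape matches the `errors` prop above: objects with an optional message.
	const errors = [
		{ message: 'Username is required.' },
		{ message: 'Username must be at least 3 characters.' }
	];
</script>

<!-- Two messages: rendered as a bulleted list. -->
<FieldError {errors} />

<!-- One message: rendered as plain text, no list. -->
<FieldError errors={[{ message: 'Username is required.' }]} />

<!-- One error without a message: hasContent is false, nothing renders. -->
<FieldError errors={[{}]} />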
23
frontend/src/lib/components/ui/field/field-group.svelte
Normal file
@@ -0,0 +1,23 @@
<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLDivElement>> = $props();
</script>

<div
	bind:this={ref}
	data-slot="field-group"
	class={cn(
		"group/field-group @container/field-group flex w-full flex-col gap-7 data-[slot=checkbox-group]:gap-3 [&>[data-slot=field-group]]:gap-4",
		className
	)}
	{...restProps}
>
	{@render children?.()}
</div>
30
frontend/src/lib/components/ui/field/field-label.svelte
Normal file
@@ -0,0 +1,30 @@
<script lang="ts">
	import { Label } from '$lib/components/ui/label/index.js';
	import { cn } from '$lib/utils/style.js';
	import type { ComponentProps } from 'svelte';

	let {
		ref = $bindable(null),
		class: className,
		required = false,
		children,
		...restProps
	}: ComponentProps<typeof Label> & {
		required?: boolean;
	} = $props();
</script>

<Label
	bind:ref
	data-slot="field-label"
	{required}
	class={cn(
		'group/field-label peer/field-label mt-1 mb-0 flex w-fit gap-2 leading-snug group-data-[disabled=true]/field:opacity-50',
		'has-[>[data-slot=field]]:w-full has-[>[data-slot=field]]:flex-col has-[>[data-slot=field]]:rounded-md has-[>[data-slot=field]]:border [&>*]:data-[slot=field]:p-4',
		'has-data-[state=checked]:bg-primary/5 has-data-[state=checked]:border-primary dark:has-data-[state=checked]:bg-primary/10',
		className
	)}
	{...restProps}
>
	{@render children?.()}
</Label>
29
frontend/src/lib/components/ui/field/field-legend.svelte
Normal file
@@ -0,0 +1,29 @@
<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		variant = "legend",
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLLegendElement>> & {
		variant?: "legend" | "label";
	} = $props();
</script>

<legend
	bind:this={ref}
	data-slot="field-legend"
	data-variant={variant}
	class={cn(
		"mb-3 font-medium",
		"data-[variant=legend]:text-base",
		"data-[variant=label]:text-sm",
		className
	)}
	{...restProps}
>
	{@render children?.()}
</legend>
38
frontend/src/lib/components/ui/field/field-separator.svelte
Normal file
@@ -0,0 +1,38 @@
<script lang="ts">
	import { Separator } from "$lib/components/ui/separator/index.js";
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";
	import type { Snippet } from "svelte";

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLDivElement>> & {
		children?: Snippet;
	} = $props();

	const hasContent = $derived(!!children);
</script>

<div
	bind:this={ref}
	data-slot="field-separator"
	data-content={hasContent}
	class={cn(
		"relative -my-2 h-5 text-sm group-data-[variant=outline]/field-group:-mb-2",
		className
	)}
	{...restProps}
>
	<Separator class="absolute inset-0 top-1/2" />
	{#if children}
		<span
			class="bg-background text-muted-foreground relative mx-auto block w-fit px-2"
			data-slot="field-separator-content"
		>
			{@render children()}
		</span>
	{/if}
</div>
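FieldSeparator always draws a centered rule; when it receives children it sets data-content="true" and overlays them as a label on the rule. A minimal sketch, assuming consumption through the field index exports added later in this diff:

<script lang="ts">
	import { FieldGroup, FieldSeparator } from '$lib/components/ui/field';
</script>

<FieldGroup>
	<!-- Without children: just the horizontal rule. -->
	<FieldSeparator />

	<!-- With children: the text floats over the rule on a background-colored chip. -->
	<FieldSeparator>or continue with</FieldSeparator>
</FieldGroup>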
24
frontend/src/lib/components/ui/field/field-set.svelte
Normal file
@@ -0,0 +1,24 @@
<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLFieldsetAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLFieldsetAttributes> = $props();
</script>

<fieldset
	bind:this={ref}
	data-slot="field-set"
	class={cn(
		"flex flex-col gap-6",
		"has-[>[data-slot=checkbox-group]]:gap-3 has-[>[data-slot=radio-group]]:gap-3",
		className
	)}
	{...restProps}
>
	{@render children?.()}
</fieldset>
23
frontend/src/lib/components/ui/field/field-title.svelte
Normal file
@@ -0,0 +1,23 @@
<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLDivElement>> = $props();
</script>

<div
	bind:this={ref}
	data-slot="field-title"
	class={cn(
		"flex w-fit items-center gap-2 text-sm font-medium leading-snug group-data-[disabled=true]/field:opacity-50",
		className
	)}
	{...restProps}
>
	{@render children?.()}
</div>
53
frontend/src/lib/components/ui/field/field.svelte
Normal file
@@ -0,0 +1,53 @@
<script lang="ts" module>
	import { tv, type VariantProps } from "tailwind-variants";

	export const fieldVariants = tv({
		base: "group/field data-[invalid=true]:text-destructive flex w-full gap-3",
		variants: {
			orientation: {
				vertical: "flex-col [&>*]:w-full [&>.sr-only]:w-auto",
				horizontal: [
					"flex-row items-center",
					"[&>[data-slot=field-label]]:flex-auto",
					"has-[>[data-slot=field-content]]:[&>[role=checkbox],[role=radio]]:mt-px has-[>[data-slot=field-content]]:items-start",
				],
				responsive: [
					"@md/field-group:flex-row @md/field-group:items-center @md/field-group:[&>*]:w-auto flex-col [&>*]:w-full [&>.sr-only]:w-auto",
					"@md/field-group:[&>[data-slot=field-label]]:flex-auto",
					"@md/field-group:has-[>[data-slot=field-content]]:items-start @md/field-group:has-[>[data-slot=field-content]]:[&>[role=checkbox],[role=radio]]:mt-px",
				],
			},
		},
		defaultVariants: {
			orientation: "vertical",
		},
	});

	export type FieldOrientation = VariantProps<typeof fieldVariants>["orientation"];
</script>

<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		orientation = "vertical",
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLDivElement>> & {
		orientation?: FieldOrientation;
	} = $props();
</script>

<div
	bind:this={ref}
	role="group"
	data-slot="field"
	data-orientation={orientation}
	class={cn(fieldVariants({ orientation }), className)}
	{...restProps}
>
	{@render children?.()}
</div>
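fieldVariants gives Field three orientations: vertical (the default, stacked children), horizontal (a single row), and responsive (a column that becomes a row at the @md container breakpoint of the enclosing field-group). A brief sketch of the first two; the native input here stands in for whatever control a caller supplies:

<script lang="ts">
	import { Field, FieldLabel } from '$lib/components/ui/field';
</script>

<!-- Default vertical: label stacked above the control, both full width. -->
<Field>
	<FieldLabel for="name">Name</FieldLabel>
	<input id="name" type="text" />
</Field>

<!-- Horizontal: control and label share one centered row. -->
<Field orientation="horizontal">
	<input id="tos" type="checkbox" />
	<FieldLabel for="tos">Accept the terms</FieldLabel>
</Field>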
33
frontend/src/lib/components/ui/field/index.ts
Normal file
@@ -0,0 +1,33 @@
import Field from "./field.svelte";
import Set from "./field-set.svelte";
import Legend from "./field-legend.svelte";
import Group from "./field-group.svelte";
import Content from "./field-content.svelte";
import Label from "./field-label.svelte";
import Title from "./field-title.svelte";
import Description from "./field-description.svelte";
import Separator from "./field-separator.svelte";
import Error from "./field-error.svelte";

export {
	Field,
	Set,
	Legend,
	Group,
	Content,
	Label,
	Title,
	Description,
	Separator,
	Error,
	//
	Set as FieldSet,
	Legend as FieldLegend,
	Group as FieldGroup,
	Content as FieldContent,
	Label as FieldLabel,
	Title as FieldTitle,
	Description as FieldDescription,
	Separator as FieldSeparator,
	Error as FieldError,
};
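The index exports every part both under a short name and a Field-prefixed alias. A composed sketch using the prefixed aliases; the data-invalid attribute flows through restProps onto the Field div, where the base variant styles data-[invalid=true] in the destructive color:

<script lang="ts">
	import {
		FieldSet,
		FieldLegend,
		FieldGroup,
		Field,
		FieldLabel,
		FieldDescription,
		FieldError
	} from '$lib/components/ui/field';

	let email = $state('');
	const errors = $derived(email.includes('@') ? [] : [{ message: 'Enter a valid email.' }]);
</script>

<FieldSet>
	<FieldLegend>Contact</FieldLegend>
	<FieldGroup>
		<Field data-invalid={errors.length > 0}>
			<FieldLabel for="email" required>Email</FieldLabel>
			<input id="email" type="email" bind:value={email} />
			<FieldDescription>Used for signup notifications.</FieldDescription>
			<FieldError {errors} />
		</Field>
	</FieldGroup>
</FieldSet>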
34
frontend/src/lib/components/ui/item/index.ts
Normal file
@@ -0,0 +1,34 @@
import Root from "./item.svelte";
import Group from "./item-group.svelte";
import Separator from "./item-separator.svelte";
import Header from "./item-header.svelte";
import Footer from "./item-footer.svelte";
import Content from "./item-content.svelte";
import Title from "./item-title.svelte";
import Description from "./item-description.svelte";
import Actions from "./item-actions.svelte";
import Media from "./item-media.svelte";

export {
	Root,
	Group,
	Separator,
	Header,
	Footer,
	Content,
	Title,
	Description,
	Actions,
	Media,
	//
	Root as Item,
	Group as ItemGroup,
	Separator as ItemSeparator,
	Header as ItemHeader,
	Footer as ItemFooter,
	Content as ItemContent,
	Title as ItemTitle,
	Description as ItemDescription,
	Actions as ItemActions,
	Media as ItemMedia,
};
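The item index follows the same short-name/alias pattern as the field index. Only item-actions and item-content appear below; item.svelte, item-title.svelte, and item-description.svelte are among the files this page truncated, so their props are assumed to mirror the other parts (children passed through, class merged). A hedged composition sketch:

<script lang="ts">
	import {
		Item,
		ItemContent,
		ItemTitle,
		ItemDescription,
		ItemActions
	} from '$lib/components/ui/item';
	// Assumed import path for the project's Button; not shown in this diff.
	import { Button } from '$lib/components/ui/button';
</script>

<Item>
	<ItemContent>
		<!-- ItemTitle/ItemDescription are not shown here; assumed to render children like the parts below. -->
		<ItemTitle>Passkey</ItemTitle>
		<ItemDescription>Added 2 days ago</ItemDescription>
	</ItemContent>
	<ItemActions>
		<Button>Rename</Button>
	</ItemActions>
</Item>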
20
frontend/src/lib/components/ui/item/item-actions.svelte
Normal file
@@ -0,0 +1,20 @@
<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLDivElement>> = $props();
</script>

<div
	bind:this={ref}
	data-slot="item-actions"
	class={cn("flex items-center gap-2", className)}
	{...restProps}
>
	{@render children?.()}
</div>
20
frontend/src/lib/components/ui/item/item-content.svelte
Normal file
@@ -0,0 +1,20 @@
<script lang="ts">
	import { cn, type WithElementRef } from "$lib/utils/style.js";
	import type { HTMLAttributes } from "svelte/elements";

	let {
		ref = $bindable(null),
		class: className,
		children,
		...restProps
	}: WithElementRef<HTMLAttributes<HTMLDivElement>> = $props();
</script>

<div
	bind:this={ref}
	data-slot="item-content"
	class={cn("flex flex-1 flex-col gap-1 [&+[data-slot=item-content]]:flex-none", className)}
	{...restProps}
>
	{@render children?.()}
</div>
Some files were not shown because too many files have changed in this diff.