mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-17 01:11:38 +03:00
feat: encrypt private keys saved on disk and in database (#682)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
9872608d61
commit
5550729120
@@ -38,7 +38,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
|
|||||||
|
|
||||||
svc.geoLiteService = service.NewGeoLiteService(httpClient)
|
svc.geoLiteService = service.NewGeoLiteService(httpClient)
|
||||||
svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService)
|
svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService)
|
||||||
svc.jwtService = service.NewJwtService(svc.appConfigService)
|
svc.jwtService = service.NewJwtService(db, svc.appConfigService)
|
||||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
|
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
|
||||||
svc.customClaimService = service.NewCustomClaimService(db)
|
svc.customClaimService = service.NewCustomClaimService(db)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
@@ -18,9 +20,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DbProviderSqlite DbProvider = "sqlite"
|
DbProviderSqlite DbProvider = "sqlite"
|
||||||
DbProviderPostgres DbProvider = "postgres"
|
DbProviderPostgres DbProvider = "postgres"
|
||||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||||
|
defaultSqliteConnString string = "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EnvConfigSchema struct {
|
type EnvConfigSchema struct {
|
||||||
@@ -30,6 +33,9 @@ type EnvConfigSchema struct {
|
|||||||
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
||||||
UploadPath string `env:"UPLOAD_PATH"`
|
UploadPath string `env:"UPLOAD_PATH"`
|
||||||
KeysPath string `env:"KEYS_PATH"`
|
KeysPath string `env:"KEYS_PATH"`
|
||||||
|
KeysStorage string `env:"KEYS_STORAGE"`
|
||||||
|
EncryptionKey string `env:"ENCRYPTION_KEY"`
|
||||||
|
EncryptionKeyFile string `env:"ENCRYPTION_KEY_FILE"`
|
||||||
Port string `env:"PORT"`
|
Port string `env:"PORT"`
|
||||||
Host string `env:"HOST"`
|
Host string `env:"HOST"`
|
||||||
UnixSocket string `env:"UNIX_SOCKET"`
|
UnixSocket string `env:"UNIX_SOCKET"`
|
||||||
@@ -45,52 +51,83 @@ type EnvConfigSchema struct {
|
|||||||
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
|
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnvConfig = &EnvConfigSchema{
|
var EnvConfig = defaultConfig()
|
||||||
AppEnv: "production",
|
|
||||||
DbProvider: "sqlite",
|
|
||||||
DbConnectionString: "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate",
|
|
||||||
UploadPath: "data/uploads",
|
|
||||||
KeysPath: "data/keys",
|
|
||||||
AppURL: "http://localhost:1411",
|
|
||||||
Port: "1411",
|
|
||||||
Host: "0.0.0.0",
|
|
||||||
UnixSocket: "",
|
|
||||||
UnixSocketMode: "",
|
|
||||||
MaxMindLicenseKey: "",
|
|
||||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
|
||||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
|
||||||
LocalIPv6Ranges: "",
|
|
||||||
UiConfigDisabled: false,
|
|
||||||
MetricsEnabled: false,
|
|
||||||
TracingEnabled: false,
|
|
||||||
TrustProxy: false,
|
|
||||||
AnalyticsDisabled: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
err := parseEnvConfig()
|
||||||
log.Fatal(err)
|
if err != nil {
|
||||||
|
log.Fatalf("Configuration error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultConfig() EnvConfigSchema {
|
||||||
|
return EnvConfigSchema{
|
||||||
|
AppEnv: "production",
|
||||||
|
DbProvider: "sqlite",
|
||||||
|
DbConnectionString: "",
|
||||||
|
UploadPath: "data/uploads",
|
||||||
|
KeysPath: "data/keys",
|
||||||
|
KeysStorage: "", // "database" or "file"
|
||||||
|
EncryptionKey: "",
|
||||||
|
AppURL: "http://localhost:1411",
|
||||||
|
Port: "1411",
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
UnixSocket: "",
|
||||||
|
UnixSocketMode: "",
|
||||||
|
MaxMindLicenseKey: "",
|
||||||
|
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||||
|
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||||
|
LocalIPv6Ranges: "",
|
||||||
|
UiConfigDisabled: false,
|
||||||
|
MetricsEnabled: false,
|
||||||
|
TracingEnabled: false,
|
||||||
|
TrustProxy: false,
|
||||||
|
AnalyticsDisabled: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEnvConfig() error {
|
||||||
|
err := env.ParseWithOptions(&EnvConfig, env.Options{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error parsing env config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the environment variables
|
// Validate the environment variables
|
||||||
switch EnvConfig.DbProvider {
|
switch EnvConfig.DbProvider {
|
||||||
case DbProviderSqlite:
|
case DbProviderSqlite:
|
||||||
if EnvConfig.DbConnectionString == "" {
|
if EnvConfig.DbConnectionString == "" {
|
||||||
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database")
|
EnvConfig.DbConnectionString = defaultSqliteConnString
|
||||||
}
|
}
|
||||||
case DbProviderPostgres:
|
case DbProviderPostgres:
|
||||||
if EnvConfig.DbConnectionString == "" {
|
if EnvConfig.DbConnectionString == "" {
|
||||||
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
|
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("APP_URL is not a valid URL")
|
return errors.New("APP_URL is not a valid URL")
|
||||||
}
|
}
|
||||||
if parsedAppUrl.Path != "" {
|
if parsedAppUrl.Path != "" {
|
||||||
log.Fatal("APP_URL must not contain a path")
|
return errors.New("APP_URL must not contain a path")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch EnvConfig.KeysStorage {
|
||||||
|
// KeysStorage defaults to "file" if empty
|
||||||
|
case "":
|
||||||
|
EnvConfig.KeysStorage = "file"
|
||||||
|
case "database":
|
||||||
|
// If KeysStorage is "database", a key must be specified
|
||||||
|
if EnvConfig.EncryptionKey == "" && EnvConfig.EncryptionKeyFile == "" {
|
||||||
|
return errors.New("ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty when KEYS_STORAGE is database")
|
||||||
|
}
|
||||||
|
case "file":
|
||||||
|
// All good, these are valid values
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", EnvConfig.KeysStorage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
188
backend/internal/common/env_config_test.go
Normal file
188
backend/internal/common/env_config_test.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseEnvConfig(t *testing.T) {
|
||||||
|
// Store original config to restore later
|
||||||
|
originalConfig := EnvConfig
|
||||||
|
t.Cleanup(func() {
|
||||||
|
EnvConfig = originalConfig
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "postgres")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
|
||||||
|
t.Setenv("APP_URL", "https://example.com")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should fail with invalid DB_PROVIDER", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "invalid")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "test")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "invalid DB_PROVIDER value")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "") // Explicitly empty
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should fail when Postgres DB_CONNECTION_STRING is missing", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "postgres")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should fail with invalid APP_URL", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "€://not-a-valid-url")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "APP_URL is not a valid URL")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should fail when APP_URL contains path", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000/path")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "APP_URL must not contain a path")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "file", EnvConfig.KeysStorage)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
t.Setenv("KEYS_STORAGE", "database")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) {
|
||||||
|
validStorageTypes := []string{"file", "database"}
|
||||||
|
|
||||||
|
for _, storage := range validStorageTypes {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
t.Setenv("KEYS_STORAGE", storage)
|
||||||
|
if storage == "database" {
|
||||||
|
t.Setenv("ENCRYPTION_KEY", "test-key")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, storage, EnvConfig.KeysStorage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
t.Setenv("KEYS_STORAGE", "invalid")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should parse boolean environment variables correctly", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
t.Setenv("UI_CONFIG_DISABLED", "true")
|
||||||
|
t.Setenv("METRICS_ENABLED", "true")
|
||||||
|
t.Setenv("TRACING_ENABLED", "false")
|
||||||
|
t.Setenv("TRUST_PROXY", "true")
|
||||||
|
t.Setenv("ANALYTICS_DISABLED", "false")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, EnvConfig.UiConfigDisabled)
|
||||||
|
assert.True(t, EnvConfig.MetricsEnabled)
|
||||||
|
assert.False(t, EnvConfig.TracingEnabled)
|
||||||
|
assert.True(t, EnvConfig.TrustProxy)
|
||||||
|
assert.False(t, EnvConfig.AnalyticsDisabled)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should parse string environment variables correctly", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "postgres")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
|
||||||
|
t.Setenv("APP_URL", "https://prod.example.com")
|
||||||
|
t.Setenv("APP_ENV", "staging")
|
||||||
|
t.Setenv("UPLOAD_PATH", "/custom/uploads")
|
||||||
|
t.Setenv("KEYS_PATH", "/custom/keys")
|
||||||
|
t.Setenv("PORT", "8080")
|
||||||
|
t.Setenv("HOST", "127.0.0.1")
|
||||||
|
t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
|
||||||
|
t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
|
||||||
|
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "staging", EnvConfig.AppEnv)
|
||||||
|
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
|
||||||
|
assert.Equal(t, "8080", EnvConfig.Port)
|
||||||
|
assert.Equal(t, "127.0.0.1", EnvConfig.Host)
|
||||||
|
})
|
||||||
|
}
|
||||||
11
backend/internal/model/kv.go
Normal file
11
backend/internal/model/kv.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
type KV struct {
|
||||||
|
Key string `gorm:"primaryKey;not null"`
|
||||||
|
Value *string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName overrides the table name used by KV to `kv`
|
||||||
|
func (KV) TableName() string {
|
||||||
|
return "kv"
|
||||||
|
}
|
||||||
@@ -4,10 +4,12 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
"github.com/stretchr/testify/require"
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
|
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
|
||||||
@@ -22,7 +24,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
|
|||||||
|
|
||||||
func TestLoadDbConfig(t *testing.T) {
|
func TestLoadDbConfig(t *testing.T) {
|
||||||
t.Run("empty config table", func(t *testing.T) {
|
t.Run("empty config table", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
@@ -36,7 +38,7 @@ func TestLoadDbConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("loads value from config table", func(t *testing.T) {
|
t.Run("loads value from config table", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Populate the config table with some initial values
|
// Populate the config table with some initial values
|
||||||
err := db.
|
err := db.
|
||||||
@@ -66,7 +68,7 @@ func TestLoadDbConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ignores unknown config keys", func(t *testing.T) {
|
t.Run("ignores unknown config keys", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Add an entry with a key that doesn't exist in the config struct
|
// Add an entry with a key that doesn't exist in the config struct
|
||||||
err := db.Create([]model.AppConfigVariable{
|
err := db.Create([]model.AppConfigVariable{
|
||||||
@@ -87,7 +89,7 @@ func TestLoadDbConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("loading config multiple times", func(t *testing.T) {
|
t.Run("loading config multiple times", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Initial state
|
// Initial state
|
||||||
err := db.Create([]model.AppConfigVariable{
|
err := db.Create([]model.AppConfigVariable{
|
||||||
@@ -129,7 +131,7 @@ func TestLoadDbConfig(t *testing.T) {
|
|||||||
common.EnvConfig.UiConfigDisabled = true
|
common.EnvConfig.UiConfigDisabled = true
|
||||||
|
|
||||||
// Create database with config that should be ignored
|
// Create database with config that should be ignored
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
err := db.Create([]model.AppConfigVariable{
|
err := db.Create([]model.AppConfigVariable{
|
||||||
{Key: "appName", Value: "DB App"},
|
{Key: "appName", Value: "DB App"},
|
||||||
{Key: "sessionDuration", Value: "120"},
|
{Key: "sessionDuration", Value: "120"},
|
||||||
@@ -165,7 +167,7 @@ func TestLoadDbConfig(t *testing.T) {
|
|||||||
common.EnvConfig.UiConfigDisabled = false
|
common.EnvConfig.UiConfigDisabled = false
|
||||||
|
|
||||||
// Create database with config values that should take precedence
|
// Create database with config values that should take precedence
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
err := db.Create([]model.AppConfigVariable{
|
err := db.Create([]model.AppConfigVariable{
|
||||||
{Key: "appName", Value: "DB App"},
|
{Key: "appName", Value: "DB App"},
|
||||||
{Key: "sessionDuration", Value: "120"},
|
{Key: "sessionDuration", Value: "120"},
|
||||||
@@ -189,7 +191,7 @@ func TestLoadDbConfig(t *testing.T) {
|
|||||||
|
|
||||||
func TestUpdateAppConfigValues(t *testing.T) {
|
func TestUpdateAppConfigValues(t *testing.T) {
|
||||||
t.Run("update single value", func(t *testing.T) {
|
t.Run("update single value", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create a service with default config
|
// Create a service with default config
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
@@ -214,7 +216,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update multiple values", func(t *testing.T) {
|
t.Run("update multiple values", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create a service with default config
|
// Create a service with default config
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
@@ -258,7 +260,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("empty value resets to default", func(t *testing.T) {
|
t.Run("empty value resets to default", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create a service with default config
|
// Create a service with default config
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
@@ -279,7 +281,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("error with odd number of arguments", func(t *testing.T) {
|
t.Run("error with odd number of arguments", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create a service with default config
|
// Create a service with default config
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
@@ -295,7 +297,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("error with invalid key", func(t *testing.T) {
|
t.Run("error with invalid key", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create a service with default config
|
// Create a service with default config
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
@@ -313,7 +315,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
|||||||
|
|
||||||
func TestUpdateAppConfig(t *testing.T) {
|
func TestUpdateAppConfig(t *testing.T) {
|
||||||
t.Run("updates configuration values from DTO", func(t *testing.T) {
|
t.Run("updates configuration values from DTO", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create a service with default config
|
// Create a service with default config
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
@@ -386,7 +388,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("empty values reset to defaults", func(t *testing.T) {
|
t.Run("empty values reset to defaults", func(t *testing.T) {
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create a service with default config and modify some values
|
// Create a service with default config and modify some values
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
@@ -451,7 +453,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
|||||||
// Disable UI config
|
// Disable UI config
|
||||||
common.EnvConfig.UiConfigDisabled = true
|
common.EnvConfig.UiConfigDisabled = true
|
||||||
|
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
service := &AppConfigService{
|
service := &AppConfigService{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fxamacker/cbor/v2"
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/go-webauthn/webauthn/protocol"
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -25,6 +26,7 @@ import (
|
|||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||||
"github.com/pocket-id/pocket-id/backend/resources"
|
"github.com/pocket-id/pocket-id/backend/resources"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,7 +62,7 @@ func (s *TestService) initExternalIdP() error {
|
|||||||
return fmt.Errorf("failed to generate private key: %w", err)
|
return fmt.Errorf("failed to generate private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.externalIdPKey, err = utils.ImportRawKey(rawKey)
|
s.externalIdPKey, err = jwkutils.ImportRawKey(rawKey, jwa.ES256().String(), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to import private key: %w", err)
|
return fmt.Errorf("failed to import private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,23 +2,20 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -26,8 +23,9 @@ const (
|
|||||||
// This is a JSON file containing a key encoded as JWK
|
// This is a JSON file containing a key encoded as JWK
|
||||||
PrivateKeyFile = "jwt_private_key.json"
|
PrivateKeyFile = "jwt_private_key.json"
|
||||||
|
|
||||||
// RsaKeySize is the size, in bits, of the RSA key to generate if none is found
|
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
|
||||||
RsaKeySize = 2048
|
// This is a encrypted JSON file containing a key encoded as JWK
|
||||||
|
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
|
||||||
|
|
||||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||||
KeyUsageSigning = "sig"
|
KeyUsageSigning = "sig"
|
||||||
@@ -59,58 +57,74 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type JwtService struct {
|
type JwtService struct {
|
||||||
|
envConfig *common.EnvConfigSchema
|
||||||
privateKey jwk.Key
|
privateKey jwk.Key
|
||||||
keyId string
|
keyId string
|
||||||
appConfigService *AppConfigService
|
appConfigService *AppConfigService
|
||||||
jwksEncoded []byte
|
jwksEncoded []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJwtService(appConfigService *AppConfigService) *JwtService {
|
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) *JwtService {
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
|
|
||||||
// Ensure keys are generated or loaded
|
// Ensure keys are generated or loaded
|
||||||
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
|
err := service.init(db, appConfigService, &common.EnvConfig)
|
||||||
|
if err != nil {
|
||||||
log.Fatalf("Failed to initialize jwt service: %v", err)
|
log.Fatalf("Failed to initialize jwt service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return service
|
return service
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
|
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
|
||||||
s.appConfigService = appConfigService
|
s.appConfigService = appConfigService
|
||||||
|
s.envConfig = envConfig
|
||||||
|
|
||||||
// Ensure keys are generated or loaded
|
// Ensure keys are generated or loaded
|
||||||
return s.loadOrGenerateKey(keysPath)
|
return s.loadOrGenerateKey(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
|
func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
|
||||||
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
// Get the key provider
|
||||||
var key jwk.Key
|
keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
|
||||||
|
|
||||||
// First, check if we have a JWK file
|
|
||||||
// If we do, then we just load that
|
|
||||||
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
|
|
||||||
ok, err := utils.FileExists(jwkPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
|
return fmt.Errorf("failed to get key provider: %w", err)
|
||||||
}
|
}
|
||||||
if ok {
|
|
||||||
key, err = s.loadKeyJWK(jwkPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the key, and we are done
|
// Try loading a key
|
||||||
|
key, err := keyProvider.LoadKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a key, store it in the object and we're done
|
||||||
|
if key != nil {
|
||||||
err = s.SetKey(key)
|
err = s.SetKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set private key: %w", err)
|
return fmt.Errorf("failed to set private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we are here, we need to generate a new key
|
// If we are here, we need to generate a new key
|
||||||
key, err = s.generateNewRSAKey()
|
err = s.generateKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the newly-generated key
|
||||||
|
err = keyProvider.SaveKey(s.privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKey generates a new key and stores it in the object
|
||||||
|
func (s *JwtService) generateKey() error {
|
||||||
|
// Default is to generate RS256 (RSA-2048) keys
|
||||||
|
key, err := jwkutils.GenerateKey(jwa.RS256().String(), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to generate new private key: %w", err)
|
return fmt.Errorf("failed to generate new private key: %w", err)
|
||||||
}
|
}
|
||||||
@@ -121,12 +135,6 @@ func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
|||||||
return fmt.Errorf("failed to set private key: %w", err)
|
return fmt.Errorf("failed to set private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the key as JWK
|
|
||||||
err = SaveKeyJWK(s.privateKey, jwkPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,13 +200,13 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
|||||||
Subject(user.ID).
|
Subject(user.ID).
|
||||||
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
|
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
|
||||||
IssuedAt(now).
|
IssuedAt(now).
|
||||||
Issuer(common.EnvConfig.AppURL).
|
Issuer(s.envConfig.AppURL).
|
||||||
Build()
|
Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to build token: %w", err)
|
return "", fmt.Errorf("failed to build token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = SetAudienceString(token, common.EnvConfig.AppURL)
|
err = SetAudienceString(token, s.envConfig.AppURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -229,8 +237,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
|||||||
jwt.WithValidate(true),
|
jwt.WithValidate(true),
|
||||||
jwt.WithKey(alg, s.privateKey),
|
jwt.WithKey(alg, s.privateKey),
|
||||||
jwt.WithAcceptableSkew(clockSkew),
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
jwt.WithAudience(common.EnvConfig.AppURL),
|
jwt.WithAudience(s.envConfig.AppURL),
|
||||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
jwt.WithIssuer(s.envConfig.AppURL),
|
||||||
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
|
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -246,7 +254,7 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
|
|||||||
token, err := jwt.NewBuilder().
|
token, err := jwt.NewBuilder().
|
||||||
Expiration(now.Add(1 * time.Hour)).
|
Expiration(now.Add(1 * time.Hour)).
|
||||||
IssuedAt(now).
|
IssuedAt(now).
|
||||||
Issuer(common.EnvConfig.AppURL).
|
Issuer(s.envConfig.AppURL).
|
||||||
Build()
|
Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||||
@@ -305,7 +313,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
|||||||
jwt.WithValidate(true),
|
jwt.WithValidate(true),
|
||||||
jwt.WithKey(alg, s.privateKey),
|
jwt.WithKey(alg, s.privateKey),
|
||||||
jwt.WithAcceptableSkew(clockSkew),
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
jwt.WithIssuer(s.envConfig.AppURL),
|
||||||
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
|
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -335,7 +343,7 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw
|
|||||||
Subject(user.ID).
|
Subject(user.ID).
|
||||||
Expiration(now.Add(1 * time.Hour)).
|
Expiration(now.Add(1 * time.Hour)).
|
||||||
IssuedAt(now).
|
IssuedAt(now).
|
||||||
Issuer(common.EnvConfig.AppURL).
|
Issuer(s.envConfig.AppURL).
|
||||||
Build()
|
Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||||
@@ -377,7 +385,7 @@ func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, erro
|
|||||||
jwt.WithValidate(true),
|
jwt.WithValidate(true),
|
||||||
jwt.WithKey(alg, s.privateKey),
|
jwt.WithKey(alg, s.privateKey),
|
||||||
jwt.WithAcceptableSkew(clockSkew),
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
jwt.WithIssuer(s.envConfig.AppURL),
|
||||||
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
|
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -393,7 +401,7 @@ func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, r
|
|||||||
Subject(userID).
|
Subject(userID).
|
||||||
Expiration(now.Add(RefreshTokenDuration)).
|
Expiration(now.Add(RefreshTokenDuration)).
|
||||||
IssuedAt(now).
|
IssuedAt(now).
|
||||||
Issuer(common.EnvConfig.AppURL).
|
Issuer(s.envConfig.AppURL).
|
||||||
Build()
|
Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to build token: %w", err)
|
return "", fmt.Errorf("failed to build token: %w", err)
|
||||||
@@ -430,7 +438,7 @@ func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, client
|
|||||||
jwt.WithValidate(true),
|
jwt.WithValidate(true),
|
||||||
jwt.WithKey(alg, s.privateKey),
|
jwt.WithKey(alg, s.privateKey),
|
||||||
jwt.WithAcceptableSkew(clockSkew),
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
jwt.WithIssuer(s.envConfig.AppURL),
|
||||||
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
|
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -488,7 +496,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
|
|||||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.EnsureAlgInKey(pubKey)
|
jwkutils.EnsureAlgInKey(pubKey, "", "")
|
||||||
|
|
||||||
return pubKey, nil
|
return pubKey, nil
|
||||||
}
|
}
|
||||||
@@ -517,56 +525,6 @@ func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
|
|||||||
return alg, nil
|
return alg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read key data: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := jwk.ParseKey(data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
|
|
||||||
// We generate RSA keys only
|
|
||||||
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Import the raw key
|
|
||||||
return utils.ImportRawKey(rawKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveKeyJWK saves a JWK to a file
|
|
||||||
func SaveKeyJWK(key jwk.Key, path string) error {
|
|
||||||
dir := filepath.Dir(path)
|
|
||||||
err := os.MkdirAll(dir, 0700)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create key file: %w", err)
|
|
||||||
}
|
|
||||||
defer keyFile.Close()
|
|
||||||
|
|
||||||
// Write the JSON file to disk
|
|
||||||
enc := json.NewEncoder(keyFile)
|
|
||||||
enc.SetEscapeHTML(false)
|
|
||||||
err = enc.Encode(key)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write key file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetIsAdmin returns the value of the "isAdmin" claim in the token
|
// GetIsAdmin returns the value of the "isAdmin" claim in the token
|
||||||
func GetIsAdmin(token jwt.Token) (bool, error) {
|
func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||||
if !token.Has(IsAdminClaim) {
|
if !token.Has(IsAdminClaim) {
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJwtService_Init(t *testing.T) {
|
func TestJwtService_Init(t *testing.T) {
|
||||||
@@ -33,9 +33,16 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup the environment variable required by the token verification
|
||||||
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
|
AppURL: "https://test.example.com",
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the JWT service
|
// Initialize the JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify the private key was set
|
// Verify the private key was set
|
||||||
@@ -66,9 +73,16 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup the environment variable required by the token verification
|
||||||
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
|
AppURL: "https://test.example.com",
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
}
|
||||||
|
|
||||||
// First create a service to generate a key
|
// First create a service to generate a key
|
||||||
firstService := &JwtService{}
|
firstService := &JwtService{}
|
||||||
err := firstService.init(mockConfig, tempDir)
|
err := firstService.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get the key ID of the first service
|
// Get the key ID of the first service
|
||||||
@@ -77,7 +91,7 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
|
|
||||||
// Now create a new service that should load the existing key
|
// Now create a new service that should load the existing key
|
||||||
secondService := &JwtService{}
|
secondService := &JwtService{}
|
||||||
err = secondService.init(mockConfig, tempDir)
|
err = secondService.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the loaded key has the same ID as the original
|
// Verify the loaded key has the same ID as the original
|
||||||
@@ -90,12 +104,19 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup the environment variable required by the token verification
|
||||||
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
|
AppURL: "https://test.example.com",
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new JWK and save it to disk
|
// Create a new JWK and save it to disk
|
||||||
origKeyID := createECDSAKeyJWK(t, tempDir)
|
origKeyID := createECDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
// Now create a new service that should load the existing key
|
// Now create a new service that should load the existing key
|
||||||
svc := &JwtService{}
|
svc := &JwtService{}
|
||||||
err := svc.init(mockConfig, tempDir)
|
err := svc.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Ensure loaded key has the right algorithm
|
// Ensure loaded key has the right algorithm
|
||||||
@@ -113,12 +134,19 @@ func TestJwtService_Init(t *testing.T) {
|
|||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup the environment variable required by the token verification
|
||||||
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
|
AppURL: "https://test.example.com",
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new JWK and save it to disk
|
// Create a new JWK and save it to disk
|
||||||
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
// Now create a new service that should load the existing key
|
// Now create a new service that should load the existing key
|
||||||
svc := &JwtService{}
|
svc := &JwtService{}
|
||||||
err := svc.init(mockConfig, tempDir)
|
err := svc.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Ensure loaded key has the right algorithm and curve
|
// Ensure loaded key has the right algorithm and curve
|
||||||
@@ -147,9 +175,16 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
|||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup the environment variable required by the token verification
|
||||||
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
|
AppURL: "https://test.example.com",
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
}
|
||||||
|
|
||||||
// Create a JWT service with initialized key
|
// Create a JWT service with initialized key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Get the JWK (public key)
|
// Get the JWK (public key)
|
||||||
@@ -178,12 +213,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
|||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup the environment variable required by the token verification
|
||||||
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
|
AppURL: "https://test.example.com",
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
}
|
||||||
|
|
||||||
// Create an ECDSA key and save it as JWK
|
// Create an ECDSA key and save it as JWK
|
||||||
originalKeyID := createECDSAKeyJWK(t, tempDir)
|
originalKeyID := createECDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
// Create a JWT service that loads the ECDSA key
|
// Create a JWT service that loads the ECDSA key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Get the JWK (public key)
|
// Get the JWK (public key)
|
||||||
@@ -216,12 +258,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
|||||||
// Create a temporary directory for the test
|
// Create a temporary directory for the test
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup the environment variable required by the token verification
|
||||||
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
|
AppURL: "https://test.example.com",
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
}
|
||||||
|
|
||||||
// Create an EdDSA key and save it as JWK
|
// Create an EdDSA key and save it as JWK
|
||||||
originalKeyID := createEdDSAKeyJWK(t, tempDir)
|
originalKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||||
|
|
||||||
// Create a JWT service that loads the EdDSA key
|
// Create a JWT service that loads the EdDSA key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Get the JWK (public key)
|
// Get the JWK (public key)
|
||||||
@@ -276,16 +325,16 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Setup the environment variable required by the token verification
|
// Setup the environment variable required by the token verification
|
||||||
originalAppURL := common.EnvConfig.AppURL
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
common.EnvConfig.AppURL = "https://test.example.com"
|
AppURL: "https://test.example.com",
|
||||||
defer func() {
|
KeysStorage: "file",
|
||||||
common.EnvConfig.AppURL = originalAppURL
|
KeysPath: tempDir,
|
||||||
}()
|
}
|
||||||
|
|
||||||
t.Run("generates token for regular user", func(t *testing.T) {
|
t.Run("generates token for regular user", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create a test user
|
// Create a test user
|
||||||
@@ -328,7 +377,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
t.Run("generates token for admin user", func(t *testing.T) {
|
t.Run("generates token for admin user", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create a test admin user
|
// Create a test admin user
|
||||||
@@ -364,7 +413,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(customMockConfig, tempDir)
|
err := service.init(nil, customMockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create a test user
|
// Create a test user
|
||||||
@@ -399,7 +448,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -453,7 +505,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -507,7 +562,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -563,16 +621,16 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Setup the environment variable required by the token verification
|
// Setup the environment variable required by the token verification
|
||||||
originalAppURL := common.EnvConfig.AppURL
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
common.EnvConfig.AppURL = "https://test.example.com"
|
AppURL: "https://test.example.com",
|
||||||
defer func() {
|
KeysStorage: "file",
|
||||||
common.EnvConfig.AppURL = originalAppURL
|
KeysPath: tempDir,
|
||||||
}()
|
}
|
||||||
|
|
||||||
t.Run("generates and verifies ID token with standard claims", func(t *testing.T) {
|
t.Run("generates and verifies ID token with standard claims", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create test claims
|
// Create test claims
|
||||||
@@ -601,7 +659,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||||
issuer, ok := claims.Issuer()
|
issuer, ok := claims.Issuer()
|
||||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
// Check token expiration time is approximately 1 hour from now
|
// Check token expiration time is approximately 1 hour from now
|
||||||
expectedExp := time.Now().Add(1 * time.Hour)
|
expectedExp := time.Now().Add(1 * time.Hour)
|
||||||
@@ -614,7 +672,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
t.Run("can accept expired tokens if told so", func(t *testing.T) {
|
t.Run("can accept expired tokens if told so", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create test claims
|
// Create test claims
|
||||||
@@ -628,7 +686,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
// Create a token that's already expired
|
// Create a token that's already expired
|
||||||
token, err := jwt.NewBuilder().
|
token, err := jwt.NewBuilder().
|
||||||
Subject(userClaims["sub"].(string)).
|
Subject(userClaims["sub"].(string)).
|
||||||
Issuer(common.EnvConfig.AppURL).
|
Issuer(service.envConfig.AppURL).
|
||||||
Audience([]string{clientID}).
|
Audience([]string{clientID}).
|
||||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||||
@@ -666,13 +724,13 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID")
|
assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID")
|
||||||
issuer, ok := claims.Issuer()
|
issuer, ok := claims.Issuer()
|
||||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
|
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create test claims with nonce
|
// Create test claims with nonce
|
||||||
@@ -703,7 +761,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
t.Run("fails verification with incorrect issuer", func(t *testing.T) {
|
t.Run("fails verification with incorrect issuer", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Generate a token with standard claims
|
// Generate a token with standard claims
|
||||||
@@ -714,7 +772,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to generate ID token")
|
require.NoError(t, err, "Failed to generate ID token")
|
||||||
|
|
||||||
// Temporarily change the app URL to simulate wrong issuer
|
// Temporarily change the app URL to simulate wrong issuer
|
||||||
common.EnvConfig.AppURL = "https://wrong-issuer.com"
|
service.envConfig.AppURL = "https://wrong-issuer.com"
|
||||||
|
|
||||||
// Verify should fail due to issuer mismatch
|
// Verify should fail due to issuer mismatch
|
||||||
_, err = service.VerifyIdToken(tokenString, false)
|
_, err = service.VerifyIdToken(tokenString, false)
|
||||||
@@ -731,7 +789,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -762,7 +823,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID")
|
assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID")
|
||||||
issuer, ok := claims.Issuer()
|
issuer, ok := claims.Issuer()
|
||||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
// Verify the key type is OKP
|
// Verify the key type is OKP
|
||||||
publicKey, err := service.GetPublicJWK()
|
publicKey, err := service.GetPublicJWK()
|
||||||
@@ -784,7 +845,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -795,7 +859,6 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
// Create test claims
|
// Create test claims
|
||||||
userClaims := map[string]interface{}{
|
userClaims := map[string]interface{}{
|
||||||
"sub": "ecdsauser456",
|
"sub": "ecdsauser456",
|
||||||
"name": "ECDSA User",
|
|
||||||
"email": "ecdsauser@example.com",
|
"email": "ecdsauser@example.com",
|
||||||
}
|
}
|
||||||
const clientID = "ecdsa-client-123"
|
const clientID = "ecdsa-client-123"
|
||||||
@@ -815,7 +878,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID")
|
assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID")
|
||||||
issuer, ok := claims.Issuer()
|
issuer, ok := claims.Issuer()
|
||||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
// Verify the key type is EC
|
// Verify the key type is EC
|
||||||
publicKey, err := service.GetPublicJWK()
|
publicKey, err := service.GetPublicJWK()
|
||||||
@@ -837,7 +900,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -868,17 +934,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
assert.Equal(t, "rsauser456", subject, "Token subject should match user ID")
|
assert.Equal(t, "rsauser456", subject, "Token subject should match user ID")
|
||||||
issuer, ok := claims.Issuer()
|
issuer, ok := claims.Issuer()
|
||||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
// Verify the key type is RSA
|
|
||||||
publicKey, err := service.GetPublicJWK()
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
|
|
||||||
|
|
||||||
// Verify the algorithm is RS256
|
|
||||||
alg, ok := publicKey.Algorithm()
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -892,16 +948,16 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Setup the environment variable required by the token verification
|
// Setup the environment variable required by the token verification
|
||||||
originalAppURL := common.EnvConfig.AppURL
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
common.EnvConfig.AppURL = "https://test.example.com"
|
AppURL: "https://test.example.com",
|
||||||
defer func() {
|
KeysStorage: "file",
|
||||||
common.EnvConfig.AppURL = originalAppURL
|
KeysPath: tempDir,
|
||||||
}()
|
}
|
||||||
|
|
||||||
t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) {
|
t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create a test user
|
// Create a test user
|
||||||
@@ -931,7 +987,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||||
issuer, ok := claims.Issuer()
|
issuer, ok := claims.Issuer()
|
||||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||||
|
|
||||||
// Check token expiration time is approximately 1 hour from now
|
// Check token expiration time is approximately 1 hour from now
|
||||||
expectedExp := time.Now().Add(1 * time.Hour)
|
expectedExp := time.Now().Add(1 * time.Hour)
|
||||||
@@ -944,7 +1000,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||||
// Create a JWT service with a mock function to generate an expired token
|
// Create a JWT service with a mock function to generate an expired token
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create a test user
|
// Create a test user
|
||||||
@@ -961,7 +1017,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||||
Audience([]string{clientID}).
|
Audience([]string{clientID}).
|
||||||
Issuer(common.EnvConfig.AppURL).
|
Issuer(service.envConfig.AppURL).
|
||||||
Build()
|
Build()
|
||||||
require.NoError(t, err, "Failed to build token")
|
require.NoError(t, err, "Failed to build token")
|
||||||
|
|
||||||
@@ -980,11 +1036,17 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||||
// Create two JWT services with different keys
|
// Create two JWT services with different keys
|
||||||
service1 := &JwtService{}
|
service1 := &JwtService{}
|
||||||
err := service1.init(mockConfig, t.TempDir()) // Use a different temp dir
|
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: t.TempDir(), // Use a different temp dir
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize first JWT service")
|
require.NoError(t, err, "Failed to initialize first JWT service")
|
||||||
|
|
||||||
service2 := &JwtService{}
|
service2 := &JwtService{}
|
||||||
err = service2.init(mockConfig, t.TempDir()) // Use a different temp dir
|
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: t.TempDir(), // Use a different temp dir
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize second JWT service")
|
require.NoError(t, err, "Failed to initialize second JWT service")
|
||||||
|
|
||||||
// Create a test user
|
// Create a test user
|
||||||
@@ -1014,7 +1076,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -1068,7 +1133,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -1122,7 +1190,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
|
|
||||||
// Create a JWT service that loads the key
|
// Create a JWT service that loads the key
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Verify it loaded the right key
|
// Verify it loaded the right key
|
||||||
@@ -1176,16 +1247,16 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
|||||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||||
|
|
||||||
// Setup the environment variable required by the token verification
|
// Setup the environment variable required by the token verification
|
||||||
originalAppURL := common.EnvConfig.AppURL
|
mockEnvConfig := &common.EnvConfigSchema{
|
||||||
common.EnvConfig.AppURL = "https://test.example.com"
|
AppURL: "https://test.example.com",
|
||||||
defer func() {
|
KeysStorage: "file",
|
||||||
common.EnvConfig.AppURL = originalAppURL
|
KeysPath: tempDir,
|
||||||
}()
|
}
|
||||||
|
|
||||||
t.Run("generates and verifies refresh token", func(t *testing.T) {
|
t.Run("generates and verifies refresh token", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Create a test user
|
// Create a test user
|
||||||
@@ -1211,7 +1282,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
|||||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||||
// Create a JWT service
|
// Create a JWT service
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
// Generate a token using JWT directly to create an expired token
|
// Generate a token using JWT directly to create an expired token
|
||||||
@@ -1220,7 +1291,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
|||||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||||
Audience([]string{"client123"}).
|
Audience([]string{"client123"}).
|
||||||
Issuer(common.EnvConfig.AppURL).
|
Issuer(service.envConfig.AppURL).
|
||||||
Build()
|
Build()
|
||||||
require.NoError(t, err, "Failed to build token")
|
require.NoError(t, err, "Failed to build token")
|
||||||
|
|
||||||
@@ -1236,11 +1307,17 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
|||||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||||
// Create two JWT services with different keys
|
// Create two JWT services with different keys
|
||||||
service1 := &JwtService{}
|
service1 := &JwtService{}
|
||||||
err := service1.init(mockConfig, t.TempDir())
|
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: t.TempDir(), // Use a different temp dir
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize first JWT service")
|
require.NoError(t, err, "Failed to initialize first JWT service")
|
||||||
|
|
||||||
service2 := &JwtService{}
|
service2 := &JwtService{}
|
||||||
err = service2.init(mockConfig, t.TempDir())
|
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: t.TempDir(), // Use a different temp dir
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize second JWT service")
|
require.NoError(t, err, "Failed to initialize second JWT service")
|
||||||
|
|
||||||
// Generate a token with the first service
|
// Generate a token with the first service
|
||||||
@@ -1308,7 +1385,10 @@ func TestGetTokenType(t *testing.T) {
|
|||||||
// Initialize the JWT service
|
// Initialize the JWT service
|
||||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||||
service := &JwtService{}
|
service := &JwtService{}
|
||||||
err := service.init(mockConfig, tempDir)
|
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: tempDir,
|
||||||
|
})
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
require.NoError(t, err, "Failed to initialize JWT service")
|
||||||
|
|
||||||
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
|
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
|
||||||
@@ -1402,10 +1482,19 @@ func TestGetTokenType(t *testing.T) {
|
|||||||
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
privateKey, err := utils.ImportRawKey(privateKeyRaw)
|
privateKey, err := jwkutils.ImportRawKey(privateKeyRaw, "", "")
|
||||||
require.NoError(t, err, "Failed to import private key")
|
require.NoError(t, err, "Failed to import private key")
|
||||||
|
|
||||||
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
|
keyProvider := &jwkutils.KeyProviderFile{}
|
||||||
|
err = keyProvider.Init(jwkutils.KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysStorage: "file",
|
||||||
|
KeysPath: path,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "Failed to init file key provider")
|
||||||
|
|
||||||
|
err = keyProvider.SaveKey(privateKey)
|
||||||
require.NoError(t, err, "Failed to save key")
|
require.NoError(t, err, "Failed to save key")
|
||||||
|
|
||||||
kid, _ := privateKey.KeyID()
|
kid, _ := privateKey.KeyID()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// generateTestECDSAKey creates an ECDSA key for testing
|
// generateTestECDSAKey creates an ECDSA key for testing
|
||||||
@@ -62,12 +63,12 @@ func TestOidcService_jwkSetForURL(t *testing.T) {
|
|||||||
)
|
)
|
||||||
mockResponses := map[string]*http.Response{
|
mockResponses := map[string]*http.Response{
|
||||||
//nolint:bodyclose
|
//nolint:bodyclose
|
||||||
url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
url1: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
||||||
//nolint:bodyclose
|
//nolint:bodyclose
|
||||||
url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
url2: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
||||||
}
|
}
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Transport: &MockRoundTripper{
|
Transport: &testutils.MockRoundTripper{
|
||||||
Responses: mockResponses,
|
Responses: mockResponses,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -139,7 +140,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
// Create a test database
|
// Create a test database
|
||||||
db := newDatabaseForTest(t)
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
|
||||||
// Create two JWKs for testing
|
// Create two JWKs for testing
|
||||||
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
|
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
|
||||||
@@ -149,12 +150,12 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
|
|
||||||
// Create a mock HTTP client with custom transport to return the JWKS
|
// Create a mock HTTP client with custom transport to return the JWKS
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Transport: &MockRoundTripper{
|
Transport: &testutils.MockRoundTripper{
|
||||||
Responses: map[string]*http.Response{
|
Responses: map[string]*http.Response{
|
||||||
//nolint:bodyclose
|
//nolint:bodyclose
|
||||||
federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
federatedClientIssuer + "/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
||||||
//nolint:bodyclose
|
//nolint:bodyclose
|
||||||
federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
federatedClientIssuerDefaults + ".well-known/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
69
backend/internal/utils/crypto/crypto.go
Normal file
69
backend/internal/utils/crypto/crypto.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrDecrypt is returned by Decrypt when the operation failed for any reason
|
||||||
|
var ErrDecrypt = errors.New("failed to decrypt data")
|
||||||
|
|
||||||
|
// Encrypt a byte slice using AES-GCM and a random nonce
|
||||||
|
// Important: do not encrypt more than ~4 billion messages with the same key!
|
||||||
|
func Encrypt(key []byte, plaintext []byte, associatedData []byte) (ciphertext []byte, err error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||||
|
}
|
||||||
|
aead, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random nonce
|
||||||
|
nonce := make([]byte, aead.NonceSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the slice for the result, with additional space for the nonce and overhead
|
||||||
|
ciphertext = make([]byte, 0, len(plaintext)+aead.NonceSize()+aead.Overhead())
|
||||||
|
ciphertext = append(ciphertext, nonce...)
|
||||||
|
|
||||||
|
// Encrypt the plaintext
|
||||||
|
// Tag is automatically added at the end
|
||||||
|
ciphertext = aead.Seal(ciphertext, nonce, plaintext, associatedData)
|
||||||
|
|
||||||
|
return ciphertext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt a byte slice using AES-GCM
|
||||||
|
func Decrypt(key []byte, ciphertext []byte, associatedData []byte) (plaintext []byte, err error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||||
|
}
|
||||||
|
aead, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the nonce
|
||||||
|
if len(ciphertext) < (aead.NonceSize() + aead.Overhead()) {
|
||||||
|
return nil, ErrDecrypt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt the data
|
||||||
|
plaintext, err = aead.Open(nil, ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():], associatedData)
|
||||||
|
if err != nil {
|
||||||
|
// Note: we do not return the exact error here, to avoid disclosing information
|
||||||
|
return nil, ErrDecrypt
|
||||||
|
}
|
||||||
|
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
208
backend/internal/utils/crypto/crypto_test.go
Normal file
208
backend/internal/utils/crypto/crypto_test.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncryptDecrypt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keySize int
|
||||||
|
plaintext string
|
||||||
|
associatedData []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "AES-128 with short plaintext",
|
||||||
|
keySize: 16,
|
||||||
|
plaintext: "Hello, World!",
|
||||||
|
associatedData: []byte("test-aad"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AES-192 with medium plaintext",
|
||||||
|
keySize: 24,
|
||||||
|
plaintext: "This is a longer message to test encryption and decryption",
|
||||||
|
associatedData: []byte("associated-data-192"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AES-256 with unicode",
|
||||||
|
keySize: 32,
|
||||||
|
plaintext: "Hello 世界! 🌍 Testing unicode characters", //nolint:gosmopolitan
|
||||||
|
associatedData: []byte("unicode-test"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No associated data",
|
||||||
|
keySize: 32,
|
||||||
|
plaintext: "Testing without associated data",
|
||||||
|
associatedData: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Generate random key
|
||||||
|
key := make([]byte, tt.keySize)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err, "Failed to generate random key")
|
||||||
|
|
||||||
|
plaintext := []byte(tt.plaintext)
|
||||||
|
|
||||||
|
// Test encryption
|
||||||
|
ciphertext, err := Encrypt(key, plaintext, tt.associatedData)
|
||||||
|
require.NoError(t, err, "Encrypt should succeed")
|
||||||
|
|
||||||
|
// Verify ciphertext is different from plaintext (unless empty)
|
||||||
|
if len(plaintext) > 0 {
|
||||||
|
assert.NotEqual(t, plaintext, ciphertext)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test decryption
|
||||||
|
decrypted, err := Decrypt(key, ciphertext, tt.associatedData)
|
||||||
|
require.NoError(t, err, "Decrypt should succeed")
|
||||||
|
|
||||||
|
// Verify decrypted text matches original
|
||||||
|
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptWithInvalidKeySize(t *testing.T) {
|
||||||
|
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||||
|
|
||||||
|
for _, keySize := range invalidKeySizes {
|
||||||
|
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||||
|
key := make([]byte, keySize)
|
||||||
|
plaintext := []byte("test message")
|
||||||
|
|
||||||
|
_, err := Encrypt(key, plaintext, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "invalid key size")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptWithInvalidKeySize(t *testing.T) {
|
||||||
|
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||||
|
|
||||||
|
for _, keySize := range invalidKeySizes {
|
||||||
|
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||||
|
key := make([]byte, keySize)
|
||||||
|
ciphertext := []byte("fake ciphertext")
|
||||||
|
|
||||||
|
_, err := Decrypt(key, ciphertext, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "invalid key size")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptWithInvalidCiphertext(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err, "Failed to generate random key")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ciphertext []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty ciphertext",
|
||||||
|
ciphertext: []byte{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "too short ciphertext",
|
||||||
|
ciphertext: []byte("short"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "random invalid data",
|
||||||
|
ciphertext: []byte("this is not valid encrypted data"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := Decrypt(key, tt.ciphertext, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, ErrDecrypt)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptWithWrongKey(t *testing.T) {
|
||||||
|
// Generate two different keys
|
||||||
|
key1 := make([]byte, 32)
|
||||||
|
key2 := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = rand.Read(key2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
plaintext := []byte("secret message")
|
||||||
|
|
||||||
|
// Encrypt with key1
|
||||||
|
ciphertext, err := Encrypt(key1, plaintext, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Try to decrypt with key2
|
||||||
|
_, err = Decrypt(key2, ciphertext, nil)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, ErrDecrypt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptWithWrongAssociatedData(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err, "Failed to generate random key")
|
||||||
|
|
||||||
|
plaintext := []byte("secret message")
|
||||||
|
correctAAD := []byte("correct-aad")
|
||||||
|
wrongAAD := []byte("wrong-aad")
|
||||||
|
|
||||||
|
// Encrypt with correct AAD
|
||||||
|
ciphertext, err := Encrypt(key, plaintext, correctAAD)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Try to decrypt with wrong AAD
|
||||||
|
_, err = Decrypt(key, ciphertext, wrongAAD)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, ErrDecrypt)
|
||||||
|
|
||||||
|
// Verify correct AAD works
|
||||||
|
decrypted, err := Decrypt(key, ciphertext, correctAAD)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original when using correct AAD")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecryptConsistency(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
plaintext := []byte("consistency test message")
|
||||||
|
associatedData := []byte("test-aad")
|
||||||
|
|
||||||
|
// Encrypt multiple times and verify we get different ciphertexts (due to random IV)
|
||||||
|
ciphertext1, err := Encrypt(key, plaintext, associatedData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ciphertext2, err := Encrypt(key, plaintext, associatedData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ciphertexts should be different (due to random IV)
|
||||||
|
assert.NotEqual(t, ciphertext1, ciphertext2, "Multiple encryptions of same plaintext should produce different ciphertexts")
|
||||||
|
|
||||||
|
// Both should decrypt to the same plaintext
|
||||||
|
decrypted1, err := Decrypt(key, ciphertext1, associatedData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
decrypted2, err := Decrypt(key, ciphertext2, associatedData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, plaintext, decrypted1, "First decrypted text should match original")
|
||||||
|
assert.Equal(t, plaintext, decrypted2, "Second decrypted text should match original")
|
||||||
|
assert.Equal(t, decrypted1, decrypted2, "Both decrypted texts should be identical")
|
||||||
|
}
|
||||||
50
backend/internal/utils/jwk/key_provider.go
Normal file
50
backend/internal/utils/jwk/key_provider.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package jwk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyProviderOpts struct {
|
||||||
|
EnvConfig *common.EnvConfigSchema
|
||||||
|
DB *gorm.DB
|
||||||
|
Kek []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyProvider interface {
|
||||||
|
Init(opts KeyProviderOpts) error
|
||||||
|
LoadKey() (jwk.Key, error)
|
||||||
|
SaveKey(key jwk.Key) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {
|
||||||
|
// Load the encryption key (KEK) if present
|
||||||
|
kek, err := LoadKeyEncryptionKey(envConfig, instanceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load encryption key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the key provider
|
||||||
|
switch envConfig.KeysStorage {
|
||||||
|
case "file", "":
|
||||||
|
keyProvider = &KeyProviderFile{}
|
||||||
|
case "database":
|
||||||
|
keyProvider = &KeyProviderDatabase{}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
|
||||||
|
}
|
||||||
|
err = keyProvider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
EnvConfig: envConfig,
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return keyProvider, nil
|
||||||
|
}
|
||||||
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package jwk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
|
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PrivateKeyDBKey = "jwt_private_key.json"
|
||||||
|
|
||||||
|
type KeyProviderDatabase struct {
|
||||||
|
db *gorm.DB
|
||||||
|
kek []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
|
||||||
|
if len(opts.Kek) == 0 {
|
||||||
|
return errors.New("an encryption key is required when using the 'database' key provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
f.db = opts.DB
|
||||||
|
f.kek = opts.Kek
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
|
||||||
|
row := model.KV{
|
||||||
|
Key: PrivateKeyDBKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err = f.db.WithContext(ctx).First(&row).Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
// Key not present in the database - return nil so a new one can be generated
|
||||||
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to retrieve private key from the database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if row.Value == nil || *row.Value == "" {
|
||||||
|
// Key not present in the database - return nil so a new one can be generated
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode from base64
|
||||||
|
enc, err := base64.StdEncoding.DecodeString(*row.Value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read encrypted private key: not a valid base64-encoded value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt the data
|
||||||
|
data, err := cryptoutils.Decrypt(f.kek, enc, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the key
|
||||||
|
key, err = jwk.ParseKey(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse encrypted private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
|
||||||
|
// Encode the key to JSON
|
||||||
|
data, err := EncodeJWKBytes(key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the key then encode to Base64
|
||||||
|
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||||
|
}
|
||||||
|
encB64 := base64.StdEncoding.EncodeToString(enc)
|
||||||
|
|
||||||
|
// Save to database
|
||||||
|
row := model.KV{
|
||||||
|
Key: PrivateKeyDBKey,
|
||||||
|
Value: &encB64,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
err = f.db.WithContext(ctx).Create(&row).Error
|
||||||
|
if err != nil {
|
||||||
|
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
|
||||||
|
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
|
||||||
|
return fmt.Errorf("failed to store private key in database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time interface check
|
||||||
|
var _ KeyProvider = (*KeyProviderDatabase)(nil)
|
||||||
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
package jwk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
|
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||||
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyProviderDatabase_Init(t *testing.T) {
|
||||||
|
t.Run("Init fails when KEK is not provided", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: nil, // No KEK
|
||||||
|
})
|
||||||
|
require.Error(t, err, "Expected error when KEK is not provided")
|
||||||
|
require.ErrorContains(t, err, "encryption key is required")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Init succeeds with KEK", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: generateTestKEK(t),
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "Expected no error when KEK is provided")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||||
|
// Generate a test key to use in our tests
|
||||||
|
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := jwk.Import(pk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
kek := generateTestKEK(t)
|
||||||
|
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Load key when none exists
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with existing key", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
kek := generateTestKEK(t)
|
||||||
|
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Save a key
|
||||||
|
err = provider.SaveKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Load the key
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
|
||||||
|
|
||||||
|
// Verify the loaded key is the same as the original
|
||||||
|
keyBytes, err := EncodeJWKBytes(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with invalid base64", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
kek := generateTestKEK(t)
|
||||||
|
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Insert invalid base64 data
|
||||||
|
invalidBase64 := "not-valid-base64"
|
||||||
|
err = db.Create(&model.KV{
|
||||||
|
Key: PrivateKeyDBKey,
|
||||||
|
Value: &invalidBase64,
|
||||||
|
}).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Attempt to load the key
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.Error(t, err, "Expected error when loading key with invalid base64")
|
||||||
|
require.ErrorContains(t, err, "not a valid base64-encoded value")
|
||||||
|
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with invalid encrypted data", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
kek := generateTestKEK(t)
|
||||||
|
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Insert valid base64 but invalid encrypted data
|
||||||
|
invalidData := base64.StdEncoding.EncodeToString([]byte("not-valid-encrypted-data"))
|
||||||
|
err = db.Create(&model.KV{
|
||||||
|
Key: PrivateKeyDBKey,
|
||||||
|
Value: &invalidData,
|
||||||
|
}).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Attempt to load the key
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.Error(t, err, "Expected error when loading key with invalid encrypted data")
|
||||||
|
require.ErrorContains(t, err, "failed to decrypt")
|
||||||
|
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with valid encrypted data but wrong KEK", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
originalKek := generateTestKEK(t)
|
||||||
|
|
||||||
|
// Save a key with the original KEK
|
||||||
|
originalProvider := &KeyProviderDatabase{}
|
||||||
|
err := originalProvider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: originalKek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = originalProvider.SaveKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now try to load with a different KEK
|
||||||
|
differentKek := generateTestKEK(t)
|
||||||
|
differentProvider := &KeyProviderDatabase{}
|
||||||
|
err = differentProvider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: differentKek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Attempt to load the key with the wrong KEK
|
||||||
|
loadedKey, err := differentProvider.LoadKey()
|
||||||
|
require.Error(t, err, "Expected error when loading key with wrong KEK")
|
||||||
|
require.ErrorContains(t, err, "failed to decrypt")
|
||||||
|
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with invalid key data", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
kek := generateTestKEK(t)
|
||||||
|
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create invalid key data (valid JSON but not a valid JWK)
|
||||||
|
invalidKeyData := []byte(`{"not": "a valid jwk"}`)
|
||||||
|
|
||||||
|
// Encrypt the invalid key data
|
||||||
|
encryptedData, err := cryptoutils.Encrypt(kek, invalidKeyData, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Base64 encode the encrypted data
|
||||||
|
encodedData := base64.StdEncoding.EncodeToString(encryptedData)
|
||||||
|
|
||||||
|
// Save to database
|
||||||
|
err = db.Create(&model.KV{
|
||||||
|
Key: PrivateKeyDBKey,
|
||||||
|
Value: &encodedData,
|
||||||
|
}).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Attempt to load the key
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.Error(t, err, "Expected error when loading invalid key data")
|
||||||
|
require.ErrorContains(t, err, "failed to parse")
|
||||||
|
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyProviderDatabase_SaveKey(t *testing.T) {
|
||||||
|
// Generate a test key to use in our tests
|
||||||
|
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := jwk.Import(pk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("SaveKey and verify database record", func(t *testing.T) {
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
kek := generateTestKEK(t)
|
||||||
|
|
||||||
|
provider := &KeyProviderDatabase{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
DB: db,
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Save the key
|
||||||
|
err = provider.SaveKey(key)
|
||||||
|
require.NoError(t, err, "Expected no error when saving key")
|
||||||
|
|
||||||
|
// Verify record exists in database
|
||||||
|
var kv model.KV
|
||||||
|
err = db.Where("key = ?", PrivateKeyDBKey).First(&kv).Error
|
||||||
|
require.NoError(t, err, "Expected to find key in database")
|
||||||
|
require.NotNil(t, kv.Value, "Expected non-nil value in database")
|
||||||
|
assert.NotEmpty(t, *kv.Value, "Expected non-empty value in database")
|
||||||
|
|
||||||
|
// Decode and decrypt to verify content
|
||||||
|
encBytes, err := base64.StdEncoding.DecodeString(*kv.Value)
|
||||||
|
require.NoError(t, err, "Expected valid base64 encoding")
|
||||||
|
|
||||||
|
decBytes, err := cryptoutils.Decrypt(kek, encBytes, nil)
|
||||||
|
require.NoError(t, err, "Expected valid encrypted data")
|
||||||
|
|
||||||
|
parsedKey, err := jwk.ParseKey(decBytes)
|
||||||
|
require.NoError(t, err, "Expected valid JWK data")
|
||||||
|
|
||||||
|
// Compare keys
|
||||||
|
keyBytes, err := EncodeJWKBytes(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestKEK(t *testing.T) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate a 32-byte kek
|
||||||
|
kek := make([]byte, 32)
|
||||||
|
_, err := rand.Read(kek)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return kek
|
||||||
|
}
|
||||||
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
package jwk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PrivateKeyFile is the path in the data/keys folder where the key is stored
|
||||||
|
// This is a JSON file containing a key encoded as JWK
|
||||||
|
PrivateKeyFile = "jwt_private_key.json"
|
||||||
|
|
||||||
|
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
|
||||||
|
// This is a encrypted JSON file containing a key encoded as JWK
|
||||||
|
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyProviderFile struct {
|
||||||
|
envConfig *common.EnvConfigSchema
|
||||||
|
kek []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
|
||||||
|
f.envConfig = opts.EnvConfig
|
||||||
|
f.kek = opts.Kek
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
|
||||||
|
if len(f.kek) > 0 {
|
||||||
|
return f.loadEncryptedKey()
|
||||||
|
}
|
||||||
|
return f.loadKey()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
|
||||||
|
if len(f.kek) > 0 {
|
||||||
|
return f.saveKeyEncrypted(key)
|
||||||
|
}
|
||||||
|
return f.saveKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
|
||||||
|
var key jwk.Key
|
||||||
|
|
||||||
|
// First, check if we have a JWK file
|
||||||
|
// If we do, then we just load that
|
||||||
|
jwkPath := f.jwkPath()
|
||||||
|
ok, err := utils.FileExists(jwkPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
// File doesn't exist, no key was loaded
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(jwkPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err = jwk.ParseKey(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
|
||||||
|
// First, check if we have an encrypted JWK file
|
||||||
|
// If we do, then we just load that
|
||||||
|
encJwkPath := f.encJwkPath()
|
||||||
|
ok, err := utils.FileExists(encJwkPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
encB64, err := os.ReadFile(encJwkPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode from base64
|
||||||
|
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||||
|
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt the data
|
||||||
|
data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the key
|
||||||
|
key, err = jwk.ParseKey(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have an un-encrypted JWK file
|
||||||
|
key, err = f.loadKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
|
||||||
|
}
|
||||||
|
if key == nil {
|
||||||
|
// No key exists, encrypted or un-encrypted
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we are here, we have loaded a key that was un-encrypted
|
||||||
|
// We need to replace the plaintext key with the encrypted one before we return
|
||||||
|
err = f.saveKeyEncrypted(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
|
||||||
|
}
|
||||||
|
jwkPath := f.jwkPath()
|
||||||
|
err = os.Remove(jwkPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) saveKey(key jwk.Key) error {
|
||||||
|
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwkPath := f.jwkPath()
|
||||||
|
keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
|
||||||
|
}
|
||||||
|
defer keyFile.Close()
|
||||||
|
|
||||||
|
// Write the JSON file to disk
|
||||||
|
err = EncodeJWK(keyFile, key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
|
||||||
|
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode the key to JSON
|
||||||
|
data, err := EncodeJWKBytes(key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the key then encode to Base64
|
||||||
|
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||||
|
}
|
||||||
|
encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
|
||||||
|
base64.StdEncoding.Encode(encB64, enc)
|
||||||
|
|
||||||
|
// Write to disk
|
||||||
|
encJwkPath := f.encJwkPath()
|
||||||
|
err = os.WriteFile(encJwkPath, encB64, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) jwkPath() string {
|
||||||
|
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *KeyProviderFile) encJwkPath() string {
|
||||||
|
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile-time interface check
|
||||||
|
var _ KeyProvider = (*KeyProviderFile)(nil)
|
||||||
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
package jwk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyProviderFile_LoadKey(t *testing.T) {
|
||||||
|
// Generate a test key to use in our tests
|
||||||
|
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := jwk.Import(pk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
provider := &KeyProviderFile{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Load key when none exists
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
provider := &KeyProviderFile{}
|
||||||
|
err = provider.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
Kek: makeKEK(t),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Load key when none exists
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
provider := &KeyProviderFile{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Save a key
|
||||||
|
err = provider.SaveKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Make sure the key file exists
|
||||||
|
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||||
|
exists, err := utils.FileExists(keyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Expected key file to exist")
|
||||||
|
|
||||||
|
// Load the key
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
|
||||||
|
|
||||||
|
// Verify the loaded key is the same as the original
|
||||||
|
keyBytes, err := EncodeJWKBytes(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey with encrypted key", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
provider := &KeyProviderFile{}
|
||||||
|
err = provider.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
Kek: makeKEK(t),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Save a key (will be encrypted)
|
||||||
|
err = provider.SaveKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Make sure the encrypted key file exists
|
||||||
|
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||||
|
exists, err := utils.FileExists(encKeyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||||
|
|
||||||
|
// Make sure the unencrypted key file does not exist
|
||||||
|
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||||
|
exists, err = utils.FileExists(keyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||||
|
|
||||||
|
// Load the key
|
||||||
|
loadedKey, err := provider.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")
|
||||||
|
|
||||||
|
// Verify the loaded key is the same as the original
|
||||||
|
keyBytes, err := EncodeJWKBytes(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// First, create an unencrypted key
|
||||||
|
providerNoKek := &KeyProviderFile{}
|
||||||
|
err := providerNoKek.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Save an unencrypted key
|
||||||
|
err = providerNoKek.SaveKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify unencrypted key exists
|
||||||
|
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||||
|
exists, err := utils.FileExists(keyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Expected unencrypted key file to exist")
|
||||||
|
|
||||||
|
// Now create a provider with a kek
|
||||||
|
kek := make([]byte, 32)
|
||||||
|
_, err = rand.Read(kek)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
providerWithKek := &KeyProviderFile{}
|
||||||
|
err = providerWithKek.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Load the key - this should convert the unencrypted key to encrypted
|
||||||
|
loadedKey, err := providerWithKek.LoadKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")
|
||||||
|
|
||||||
|
// Verify the unencrypted key no longer exists
|
||||||
|
exists, err = utils.FileExists(keyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Expected unencrypted key file to be removed")
|
||||||
|
|
||||||
|
// Verify the encrypted key file exists
|
||||||
|
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||||
|
exists, err = utils.FileExists(encKeyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Expected encrypted key file to exist after conversion")
|
||||||
|
|
||||||
|
// Verify the key data
|
||||||
|
keyBytes, err := EncodeJWKBytes(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyProviderFile_SaveKey(t *testing.T) {
|
||||||
|
// Generate a test key to use in our tests
|
||||||
|
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := jwk.Import(pk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("SaveKey unencrypted", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
provider := &KeyProviderFile{}
|
||||||
|
err := provider.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Save the key
|
||||||
|
err = provider.SaveKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the key file exists
|
||||||
|
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||||
|
exists, err := utils.FileExists(keyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Expected key file to exist")
|
||||||
|
|
||||||
|
// Verify the content of the key file
|
||||||
|
data, err := os.ReadFile(keyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedKey, err := jwk.ParseKey(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Compare the saved key with the original
|
||||||
|
keyBytes, err := EncodeJWKBytes(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SaveKey encrypted", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Generate a 64-byte kek
|
||||||
|
kek := makeKEK(t)
|
||||||
|
|
||||||
|
provider := &KeyProviderFile{}
|
||||||
|
err = provider.Init(KeyProviderOpts{
|
||||||
|
EnvConfig: &common.EnvConfigSchema{
|
||||||
|
KeysPath: tempDir,
|
||||||
|
},
|
||||||
|
Kek: kek,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Save the key (will be encrypted)
|
||||||
|
err = provider.SaveKey(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the encrypted key file exists
|
||||||
|
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||||
|
exists, err := utils.FileExists(encKeyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||||
|
|
||||||
|
// Verify the unencrypted key file doesn't exist
|
||||||
|
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||||
|
exists, err = utils.FileExists(keyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||||
|
|
||||||
|
// Manually decrypt the encrypted key file to verify it contains the correct key
|
||||||
|
encB64, err := os.ReadFile(encKeyPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decode from base64
|
||||||
|
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||||
|
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
enc = enc[:n] // Trim any padding
|
||||||
|
|
||||||
|
// Decrypt the data
|
||||||
|
data, err := cryptoutils.Decrypt(kek, enc, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Parse the key
|
||||||
|
parsedKey, err := jwk.ParseKey(data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Compare the decrypted key with the original
|
||||||
|
keyBytes, err := EncodeJWKBytes(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeKEK(t *testing.T) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate a 32-byte kek
|
||||||
|
kek := make([]byte, 32)
|
||||||
|
_, err := rand.Read(kek)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return kek
|
||||||
|
}
|
||||||
180
backend/internal/utils/jwk/utils.go
Normal file
180
backend/internal/utils/jwk/utils.go
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
package jwk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha3"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||||
|
KeyUsageSigning = "sig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncodeJWK encodes a jwk.Key to a writable stream.
|
||||||
|
func EncodeJWK(w io.Writer, key jwk.Key) error {
|
||||||
|
enc := json.NewEncoder(w)
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
return enc.Encode(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeJWKBytes encodes a jwk.Key to a byte slice.
|
||||||
|
func EncodeJWKBytes(key jwk.Key) ([]byte, error) {
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
err := EncodeJWK(b, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return b.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadKeyEncryptionKey loads the key encryption key for JWKs
|
||||||
|
func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) {
|
||||||
|
// Try getting the key from the env var as string
|
||||||
|
kekInput := []byte(envConfig.EncryptionKey)
|
||||||
|
|
||||||
|
// If there's nothing in the env, try loading from file
|
||||||
|
if len(kekInput) == 0 && envConfig.EncryptionKeyFile != "" {
|
||||||
|
kekInput, err = os.ReadFile(envConfig.EncryptionKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read key file '%s': %w", envConfig.EncryptionKeyFile, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there's still no key, return
|
||||||
|
if len(kekInput) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need a 256-bit key for encryption with AES-GCM-256
|
||||||
|
// We use HMAC with SHA3-256 here to derive the key from the one passed as input
|
||||||
|
// The key is tied to a specific instance of Pocket ID
|
||||||
|
h := hmac.New(func() hash.Hash { return sha3.New256() }, kekInput)
|
||||||
|
fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek")
|
||||||
|
kek = h.Sum(nil)
|
||||||
|
|
||||||
|
return kek, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
||||||
|
// It also populates additional fields such as the key ID, usage, and alg.
|
||||||
|
func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) {
|
||||||
|
key, err := jwk.Import(rawKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the key ID
|
||||||
|
kid, err := generateRandomKeyID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||||
|
}
|
||||||
|
_ = key.Set(jwk.KeyIDKey, kid)
|
||||||
|
|
||||||
|
// Set other required fields
|
||||||
|
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||||
|
EnsureAlgInKey(key, alg, crv)
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomKeyID generates a random key ID.
|
||||||
|
func generateRandomKeyID() (string, error) {
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
_, err := io.ReadFull(rand.Reader, buf)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type
|
||||||
|
func EnsureAlgInKey(key jwk.Key, alg string, crv string) {
|
||||||
|
_, ok := key.Algorithm()
|
||||||
|
if ok {
|
||||||
|
// Algorithm is already set
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if alg != "" {
|
||||||
|
_ = key.Set(jwk.AlgorithmKey, alg)
|
||||||
|
if crv != "" {
|
||||||
|
eca, ok := jwa.LookupEllipticCurveAlgorithm(crv)
|
||||||
|
if ok {
|
||||||
|
switch key.KeyType() {
|
||||||
|
case jwa.EC():
|
||||||
|
_ = key.Set(jwk.ECDSACrvKey, eca)
|
||||||
|
case jwa.OKP():
|
||||||
|
_ = key.Set(jwk.OKPCrvKey, eca)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we don't have an algorithm, set the default for the key type
|
||||||
|
switch key.KeyType() {
|
||||||
|
case jwa.RSA():
|
||||||
|
// Default to RS256 for RSA keys
|
||||||
|
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||||
|
case jwa.EC():
|
||||||
|
// Default to ES256 for ECDSA keys
|
||||||
|
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||||
|
_ = key.Set(jwk.ECDSACrvKey, jwa.P256())
|
||||||
|
case jwa.OKP():
|
||||||
|
// Default to EdDSA and Ed25519 for OKP keys
|
||||||
|
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||||
|
_ = key.Set(jwk.OKPCrvKey, jwa.Ed25519())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateKey generates a new jwk.Key
|
||||||
|
func GenerateKey(alg string, crv string) (key jwk.Key, err error) {
|
||||||
|
var rawKey any
|
||||||
|
switch alg {
|
||||||
|
case jwa.RS256().String():
|
||||||
|
rawKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
case jwa.RS384().String():
|
||||||
|
rawKey, err = rsa.GenerateKey(rand.Reader, 3072)
|
||||||
|
case jwa.RS512().String():
|
||||||
|
rawKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
case jwa.ES256().String():
|
||||||
|
rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
case jwa.ES384().String():
|
||||||
|
rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||||
|
case jwa.ES512().String():
|
||||||
|
rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||||
|
case jwa.EdDSA().String():
|
||||||
|
switch crv {
|
||||||
|
case jwa.Ed25519().String():
|
||||||
|
_, rawKey, err = ed25519.GenerateKey(rand.Reader)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unsupported curve for EdDSA algorithm")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unsupported key algorithm")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import the raw key
|
||||||
|
return ImportRawKey(rawKey, alg, crv)
|
||||||
|
}
|
||||||
324
backend/internal/utils/jwk/utils_test.go
Normal file
324
backend/internal/utils/jwk/utils_test.go
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
package jwk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
alg string
|
||||||
|
crv string
|
||||||
|
expectError bool
|
||||||
|
expectedAlg jwa.SignatureAlgorithm
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RS256",
|
||||||
|
alg: jwa.RS256().String(),
|
||||||
|
crv: "",
|
||||||
|
expectError: false,
|
||||||
|
expectedAlg: jwa.RS256(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RS384",
|
||||||
|
alg: jwa.RS384().String(),
|
||||||
|
crv: "",
|
||||||
|
expectError: false,
|
||||||
|
expectedAlg: jwa.RS384(),
|
||||||
|
},
|
||||||
|
// Skip the RS512 test as generating a RSA-4096 key can take some time
|
||||||
|
/* {
|
||||||
|
name: "RS512",
|
||||||
|
alg: jwa.RS512().String(),
|
||||||
|
crv: "",
|
||||||
|
expectError: false,
|
||||||
|
expectedAlg: jwa.RS512(),
|
||||||
|
}, */
|
||||||
|
{
|
||||||
|
name: "ES256",
|
||||||
|
alg: jwa.ES256().String(),
|
||||||
|
crv: jwa.P256().String(),
|
||||||
|
expectError: false,
|
||||||
|
expectedAlg: jwa.ES256(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ES384",
|
||||||
|
alg: jwa.ES384().String(),
|
||||||
|
crv: jwa.P384().String(),
|
||||||
|
expectError: false,
|
||||||
|
expectedAlg: jwa.ES384(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ES512",
|
||||||
|
alg: jwa.ES512().String(),
|
||||||
|
crv: jwa.P521().String(),
|
||||||
|
expectError: false,
|
||||||
|
expectedAlg: jwa.ES512(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EdDSA with Ed25519",
|
||||||
|
alg: jwa.EdDSA().String(),
|
||||||
|
crv: jwa.Ed25519().String(),
|
||||||
|
expectError: false,
|
||||||
|
expectedAlg: jwa.EdDSA(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "EdDSA with unsupported curve",
|
||||||
|
alg: jwa.EdDSA().String(),
|
||||||
|
crv: "unsupported",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unsupported algorithm",
|
||||||
|
alg: "UNSUPPORTED",
|
||||||
|
crv: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
key, err := GenerateKey(tt.alg, tt.crv)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, key)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, key)
|
||||||
|
|
||||||
|
// Verify the algorithm is set correctly
|
||||||
|
alg, ok := key.Algorithm()
|
||||||
|
require.True(t, ok, "algorithm should be set in the key")
|
||||||
|
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||||
|
|
||||||
|
// Verify other required fields are set
|
||||||
|
kid, ok := key.KeyID()
|
||||||
|
assert.True(t, ok, "key ID should be set")
|
||||||
|
assert.NotEmpty(t, kid, "key ID should not be empty")
|
||||||
|
|
||||||
|
usage, ok := key.KeyUsage()
|
||||||
|
assert.True(t, ok, "key usage should be set")
|
||||||
|
assert.Equal(t, KeyUsageSigning, usage)
|
||||||
|
|
||||||
|
var crv any
|
||||||
|
_ = key.Get("crv", &crv)
|
||||||
|
|
||||||
|
// Verify key type matches expected algorithm
|
||||||
|
switch tt.expectedAlg {
|
||||||
|
case jwa.RS256(), jwa.RS384(), jwa.RS512():
|
||||||
|
assert.Equal(t, jwa.RSA(), key.KeyType())
|
||||||
|
assert.Nil(t, crv)
|
||||||
|
case jwa.ES256(), jwa.ES384(), jwa.ES512():
|
||||||
|
assert.Equal(t, jwa.EC(), key.KeyType())
|
||||||
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||||
|
_ = assert.NotNil(t, crv) &&
|
||||||
|
assert.True(t, ok) &&
|
||||||
|
assert.Equal(t, tt.crv, eca.String())
|
||||||
|
case jwa.EdDSA():
|
||||||
|
assert.Equal(t, jwa.OKP(), key.KeyType())
|
||||||
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||||
|
_ = assert.NotNil(t, crv) &&
|
||||||
|
assert.True(t, ok) &&
|
||||||
|
assert.Equal(t, tt.crv, eca.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureAlgInKey(t *testing.T) {
|
||||||
|
// Generate an RSA-2048 key
|
||||||
|
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("does not change alg already set", func(t *testing.T) {
|
||||||
|
// Import the RSA key
|
||||||
|
key, err := jwk.Import(rsaKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Pre-set the algorithm
|
||||||
|
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||||
|
|
||||||
|
// Call EnsureAlgInKey with a different algorithm
|
||||||
|
EnsureAlgInKey(key, jwa.RS384().String(), "")
|
||||||
|
|
||||||
|
// Verify the algorithm wasn't changed
|
||||||
|
alg, ok := key.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("set algorithm to explicitly-provided value", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keyGen func() (any, error)
|
||||||
|
alg string
|
||||||
|
crv string
|
||||||
|
expectedAlg jwa.SignatureAlgorithm
|
||||||
|
expectedCrv string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RSA key with RS384",
|
||||||
|
keyGen: func() (any, error) {
|
||||||
|
return rsaKey, nil
|
||||||
|
},
|
||||||
|
alg: jwa.RS384().String(),
|
||||||
|
crv: "",
|
||||||
|
expectedAlg: jwa.RS384(),
|
||||||
|
expectedCrv: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ECDSA key with ES384",
|
||||||
|
keyGen: func() (any, error) {
|
||||||
|
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
},
|
||||||
|
alg: jwa.ES384().String(),
|
||||||
|
crv: jwa.P384().String(),
|
||||||
|
expectedAlg: jwa.ES384(),
|
||||||
|
expectedCrv: jwa.P384().String(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Ed25519 key with EdDSA",
|
||||||
|
keyGen: func() (any, error) {
|
||||||
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
return priv, err
|
||||||
|
},
|
||||||
|
alg: jwa.EdDSA().String(),
|
||||||
|
crv: jwa.Ed25519().String(),
|
||||||
|
expectedAlg: jwa.EdDSA(),
|
||||||
|
expectedCrv: jwa.Ed25519().String(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rawKey, err := tt.keyGen()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := jwk.Import(rawKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure no algorithm is set initially
|
||||||
|
_, ok := key.Algorithm()
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
// Call EnsureAlgInKey
|
||||||
|
EnsureAlgInKey(key, tt.alg, tt.crv)
|
||||||
|
|
||||||
|
// Verify the algorithm was set correctly
|
||||||
|
alg, ok := key.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||||
|
|
||||||
|
// Verify curve if expected
|
||||||
|
if tt.expectedCrv != "" {
|
||||||
|
var crv any
|
||||||
|
_ = key.Get("crv", &crv)
|
||||||
|
require.NotNil(t, crv)
|
||||||
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("set default algorithms if not present", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keyGen func() (any, error)
|
||||||
|
expectedAlg jwa.SignatureAlgorithm
|
||||||
|
expectedCrv string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RSA key defaults to RS256",
|
||||||
|
keyGen: func() (any, error) {
|
||||||
|
return rsaKey, nil
|
||||||
|
},
|
||||||
|
expectedAlg: jwa.RS256(),
|
||||||
|
expectedCrv: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ECDSA key defaults to ES256 with P256",
|
||||||
|
keyGen: func() (any, error) {
|
||||||
|
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
},
|
||||||
|
expectedAlg: jwa.ES256(),
|
||||||
|
expectedCrv: jwa.P256().String(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Ed25519 key defaults to EdDSA with Ed25519",
|
||||||
|
keyGen: func() (any, error) {
|
||||||
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
return priv, err
|
||||||
|
},
|
||||||
|
expectedAlg: jwa.EdDSA(),
|
||||||
|
expectedCrv: jwa.Ed25519().String(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rawKey, err := tt.keyGen()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := jwk.Import(rawKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ensure no algorithm is set initially
|
||||||
|
_, ok := key.Algorithm()
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
// Call EnsureAlgInKey with empty parameters
|
||||||
|
EnsureAlgInKey(key, "", "")
|
||||||
|
|
||||||
|
// Verify the default algorithm was set
|
||||||
|
alg, ok := key.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||||
|
|
||||||
|
// Verify curve if expected
|
||||||
|
if tt.expectedCrv != "" {
|
||||||
|
var crv any
|
||||||
|
_ = key.Get("crv", &crv)
|
||||||
|
require.NotNil(t, crv)
|
||||||
|
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid curve should not set curve parameter", func(t *testing.T) {
|
||||||
|
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
key, err := jwk.Import(rsaKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Call EnsureAlgInKey with invalid curve
|
||||||
|
EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve")
|
||||||
|
|
||||||
|
// Verify algorithm was set but curve was not
|
||||||
|
alg, ok := key.Algorithm()
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||||
|
|
||||||
|
var crv any
|
||||||
|
_ = key.Get("crv", &crv)
|
||||||
|
assert.Nil(t, crv)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
|
||||||
KeyUsageSigning = "sig"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
|
||||||
// It also populates additional fields such as the key ID, usage, and alg.
|
|
||||||
func ImportRawKey(rawKey any) (jwk.Key, error) {
|
|
||||||
key, err := jwk.Import(rawKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the key ID
|
|
||||||
kid, err := generateRandomKeyID()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
|
||||||
}
|
|
||||||
_ = key.Set(jwk.KeyIDKey, kid)
|
|
||||||
|
|
||||||
// Set other required fields
|
|
||||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
|
||||||
EnsureAlgInKey(key)
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateRandomKeyID generates a random key ID.
|
|
||||||
func generateRandomKeyID() (string, error) {
|
|
||||||
buf := make([]byte, 8)
|
|
||||||
_, err := io.ReadFull(rand.Reader, buf)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
|
|
||||||
func EnsureAlgInKey(key jwk.Key) {
|
|
||||||
_, ok := key.Algorithm()
|
|
||||||
if ok {
|
|
||||||
// Algorithm is already set
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch key.KeyType() {
|
|
||||||
case jwa.RSA():
|
|
||||||
// Default to RS256 for RSA keys
|
|
||||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
|
||||||
case jwa.EC():
|
|
||||||
// Default to ES256 for ECDSA keys
|
|
||||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
|
||||||
case jwa.OKP():
|
|
||||||
// Default to EdDSA for OKP keys
|
|
||||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
package service
|
// This file is only imported by unit tests
|
||||||
|
|
||||||
|
package testing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,7 +20,10 @@ import (
|
|||||||
"github.com/pocket-id/pocket-id/backend/resources"
|
"github.com/pocket-id/pocket-id/backend/resources"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newDatabaseForTest(t *testing.T) *gorm.DB {
|
// NewDatabaseForTest returns a new instance of GORM connected to an in-memory SQLite database.
|
||||||
|
// Each database connection is unique for the test.
|
||||||
|
// All migrations are automatically performed.
|
||||||
|
func NewDatabaseForTest(t *testing.T) *gorm.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Get a name for this in-memory database that is specific to the test
|
// Get a name for this in-memory database that is specific to the test
|
||||||
@@ -68,30 +70,3 @@ type testLoggerAdapter struct {
|
|||||||
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
||||||
l.t.Logf(format, args...)
|
l.t.Logf(format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
|
|
||||||
type MockRoundTripper struct {
|
|
||||||
Err error
|
|
||||||
Responses map[string]*http.Response
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoundTrip implements the http.RoundTripper interface
|
|
||||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
// Check if we have a specific response for this URL
|
|
||||||
for url, resp := range m.Responses {
|
|
||||||
if req.URL.String() == url {
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewMockResponse(http.StatusNotFound, ""), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMockResponse creates an http.Response with the given status code and body
|
|
||||||
func NewMockResponse(statusCode int, body string) *http.Response {
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: statusCode,
|
|
||||||
Body: io.NopCloser(strings.NewReader(body)),
|
|
||||||
Header: make(http.Header),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
38
backend/internal/utils/testing/round_tripper.go
Normal file
38
backend/internal/utils/testing/round_tripper.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// This file is only imported by unit tests
|
||||||
|
|
||||||
|
package testing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
|
||||||
|
type MockRoundTripper struct {
|
||||||
|
Err error
|
||||||
|
Responses map[string]*http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip implements the http.RoundTripper interface
|
||||||
|
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// Check if we have a specific response for this URL
|
||||||
|
for url, resp := range m.Responses {
|
||||||
|
if req.URL.String() == url {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMockResponse(http.StatusNotFound, ""), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockResponse creates an http.Response with the given status code and body
|
||||||
|
func NewMockResponse(statusCode int, body string) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP TABLE kv;
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
-- The "kv" tables contains miscellaneous key-value pairs
|
||||||
|
CREATE TABLE kv
|
||||||
|
(
|
||||||
|
"key" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"value" TEXT
|
||||||
|
);
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
DROP TABLE kv;
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
-- The "kv" tables contains miscellaneous key-value pairs
|
||||||
|
CREATE TABLE kv
|
||||||
|
(
|
||||||
|
"key" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"value" TEXT NOT NULL
|
||||||
|
);
|
||||||
Reference in New Issue
Block a user