mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-13 16:52:58 +03:00
feat: encrypt private keys saved on disk and in database (#682)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
9872608d61
commit
5550729120
@@ -38,7 +38,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
|
||||
|
||||
svc.geoLiteService = service.NewGeoLiteService(httpClient)
|
||||
svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService)
|
||||
svc.jwtService = service.NewJwtService(svc.appConfigService)
|
||||
svc.jwtService = service.NewJwtService(db, svc.appConfigService)
|
||||
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
|
||||
svc.customClaimService = service.NewCustomClaimService(db)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
|
||||
@@ -18,9 +20,10 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
DbProviderSqlite DbProvider = "sqlite"
|
||||
DbProviderPostgres DbProvider = "postgres"
|
||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||
DbProviderSqlite DbProvider = "sqlite"
|
||||
DbProviderPostgres DbProvider = "postgres"
|
||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||
defaultSqliteConnString string = "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate"
|
||||
)
|
||||
|
||||
type EnvConfigSchema struct {
|
||||
@@ -30,6 +33,9 @@ type EnvConfigSchema struct {
|
||||
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
KeysPath string `env:"KEYS_PATH"`
|
||||
KeysStorage string `env:"KEYS_STORAGE"`
|
||||
EncryptionKey string `env:"ENCRYPTION_KEY"`
|
||||
EncryptionKeyFile string `env:"ENCRYPTION_KEY_FILE"`
|
||||
Port string `env:"PORT"`
|
||||
Host string `env:"HOST"`
|
||||
UnixSocket string `env:"UNIX_SOCKET"`
|
||||
@@ -45,52 +51,83 @@ type EnvConfigSchema struct {
|
||||
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
|
||||
}
|
||||
|
||||
var EnvConfig = &EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DbProvider: "sqlite",
|
||||
DbConnectionString: "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
AppURL: "http://localhost:1411",
|
||||
Port: "1411",
|
||||
Host: "0.0.0.0",
|
||||
UnixSocket: "",
|
||||
UnixSocketMode: "",
|
||||
MaxMindLicenseKey: "",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
LocalIPv6Ranges: "",
|
||||
UiConfigDisabled: false,
|
||||
MetricsEnabled: false,
|
||||
TracingEnabled: false,
|
||||
TrustProxy: false,
|
||||
AnalyticsDisabled: false,
|
||||
}
|
||||
var EnvConfig = defaultConfig()
|
||||
|
||||
func init() {
|
||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||
log.Fatal(err)
|
||||
err := parseEnvConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("Configuration error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func defaultConfig() EnvConfigSchema {
|
||||
return EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DbProvider: "sqlite",
|
||||
DbConnectionString: "",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
KeysStorage: "", // "database" or "file"
|
||||
EncryptionKey: "",
|
||||
AppURL: "http://localhost:1411",
|
||||
Port: "1411",
|
||||
Host: "0.0.0.0",
|
||||
UnixSocket: "",
|
||||
UnixSocketMode: "",
|
||||
MaxMindLicenseKey: "",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
LocalIPv6Ranges: "",
|
||||
UiConfigDisabled: false,
|
||||
MetricsEnabled: false,
|
||||
TracingEnabled: false,
|
||||
TrustProxy: false,
|
||||
AnalyticsDisabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
func parseEnvConfig() error {
|
||||
err := env.ParseWithOptions(&EnvConfig, env.Options{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing env config: %w", err)
|
||||
}
|
||||
|
||||
// Validate the environment variables
|
||||
switch EnvConfig.DbProvider {
|
||||
case DbProviderSqlite:
|
||||
if EnvConfig.DbConnectionString == "" {
|
||||
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database")
|
||||
EnvConfig.DbConnectionString = defaultSqliteConnString
|
||||
}
|
||||
case DbProviderPostgres:
|
||||
if EnvConfig.DbConnectionString == "" {
|
||||
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
||||
return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
||||
}
|
||||
default:
|
||||
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
||||
return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
||||
}
|
||||
|
||||
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
|
||||
if err != nil {
|
||||
log.Fatal("APP_URL is not a valid URL")
|
||||
return errors.New("APP_URL is not a valid URL")
|
||||
}
|
||||
if parsedAppUrl.Path != "" {
|
||||
log.Fatal("APP_URL must not contain a path")
|
||||
return errors.New("APP_URL must not contain a path")
|
||||
}
|
||||
|
||||
switch EnvConfig.KeysStorage {
|
||||
// KeysStorage defaults to "file" if empty
|
||||
case "":
|
||||
EnvConfig.KeysStorage = "file"
|
||||
case "database":
|
||||
// If KeysStorage is "database", a key must be specified
|
||||
if EnvConfig.EncryptionKey == "" && EnvConfig.EncryptionKeyFile == "" {
|
||||
return errors.New("ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty when KEYS_STORAGE is database")
|
||||
}
|
||||
case "file":
|
||||
// All good, these are valid values
|
||||
default:
|
||||
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", EnvConfig.KeysStorage)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
188
backend/internal/common/env_config_test.go
Normal file
188
backend/internal/common/env_config_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseEnvConfig(t *testing.T) {
|
||||
// Store original config to restore later
|
||||
originalConfig := EnvConfig
|
||||
t.Cleanup(func() {
|
||||
EnvConfig = originalConfig
|
||||
})
|
||||
|
||||
t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
|
||||
})
|
||||
|
||||
t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "postgres")
|
||||
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
|
||||
t.Setenv("APP_URL", "https://example.com")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider)
|
||||
})
|
||||
|
||||
t.Run("should fail with invalid DB_PROVIDER", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "invalid")
|
||||
t.Setenv("DB_CONNECTION_STRING", "test")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid DB_PROVIDER value")
|
||||
})
|
||||
|
||||
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "") // Explicitly empty
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString)
|
||||
})
|
||||
|
||||
t.Run("should fail when Postgres DB_CONNECTION_STRING is missing", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "postgres")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres")
|
||||
})
|
||||
|
||||
t.Run("should fail with invalid APP_URL", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "€://not-a-valid-url")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "APP_URL is not a valid URL")
|
||||
})
|
||||
|
||||
t.Run("should fail when APP_URL contains path", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000/path")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "APP_URL must not contain a path")
|
||||
})
|
||||
|
||||
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "file", EnvConfig.KeysStorage)
|
||||
})
|
||||
|
||||
t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("KEYS_STORAGE", "database")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty")
|
||||
})
|
||||
|
||||
t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) {
|
||||
validStorageTypes := []string{"file", "database"}
|
||||
|
||||
for _, storage := range validStorageTypes {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("KEYS_STORAGE", storage)
|
||||
if storage == "database" {
|
||||
t.Setenv("ENCRYPTION_KEY", "test-key")
|
||||
}
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, storage, EnvConfig.KeysStorage)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("KEYS_STORAGE", "invalid")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE")
|
||||
})
|
||||
|
||||
t.Run("should parse boolean environment variables correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "sqlite")
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("UI_CONFIG_DISABLED", "true")
|
||||
t.Setenv("METRICS_ENABLED", "true")
|
||||
t.Setenv("TRACING_ENABLED", "false")
|
||||
t.Setenv("TRUST_PROXY", "true")
|
||||
t.Setenv("ANALYTICS_DISABLED", "false")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, EnvConfig.UiConfigDisabled)
|
||||
assert.True(t, EnvConfig.MetricsEnabled)
|
||||
assert.False(t, EnvConfig.TracingEnabled)
|
||||
assert.True(t, EnvConfig.TrustProxy)
|
||||
assert.False(t, EnvConfig.AnalyticsDisabled)
|
||||
})
|
||||
|
||||
t.Run("should parse string environment variables correctly", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_PROVIDER", "postgres")
|
||||
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
|
||||
t.Setenv("APP_URL", "https://prod.example.com")
|
||||
t.Setenv("APP_ENV", "staging")
|
||||
t.Setenv("UPLOAD_PATH", "/custom/uploads")
|
||||
t.Setenv("KEYS_PATH", "/custom/keys")
|
||||
t.Setenv("PORT", "8080")
|
||||
t.Setenv("HOST", "127.0.0.1")
|
||||
t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
|
||||
t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
|
||||
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
|
||||
|
||||
err := parseEnvConfig()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "staging", EnvConfig.AppEnv)
|
||||
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
|
||||
assert.Equal(t, "8080", EnvConfig.Port)
|
||||
assert.Equal(t, "127.0.0.1", EnvConfig.Host)
|
||||
})
|
||||
}
|
||||
11
backend/internal/model/kv.go
Normal file
11
backend/internal/model/kv.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package model
|
||||
|
||||
type KV struct {
|
||||
Key string `gorm:"primaryKey;not null"`
|
||||
Value *string
|
||||
}
|
||||
|
||||
// TableName overrides the table name used by KV to `kv`
|
||||
func (KV) TableName() string {
|
||||
return "kv"
|
||||
}
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/stretchr/testify/require"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
|
||||
@@ -22,7 +24,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
|
||||
|
||||
func TestLoadDbConfig(t *testing.T) {
|
||||
t.Run("empty config table", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
@@ -36,7 +38,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("loads value from config table", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Populate the config table with some initial values
|
||||
err := db.
|
||||
@@ -66,7 +68,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("ignores unknown config keys", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Add an entry with a key that doesn't exist in the config struct
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
@@ -87,7 +89,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("loading config multiple times", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Initial state
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
@@ -129,7 +131,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
common.EnvConfig.UiConfigDisabled = true
|
||||
|
||||
// Create database with config that should be ignored
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
{Key: "appName", Value: "DB App"},
|
||||
{Key: "sessionDuration", Value: "120"},
|
||||
@@ -165,7 +167,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
common.EnvConfig.UiConfigDisabled = false
|
||||
|
||||
// Create database with config values that should take precedence
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
err := db.Create([]model.AppConfigVariable{
|
||||
{Key: "appName", Value: "DB App"},
|
||||
{Key: "sessionDuration", Value: "120"},
|
||||
@@ -189,7 +191,7 @@ func TestLoadDbConfig(t *testing.T) {
|
||||
|
||||
func TestUpdateAppConfigValues(t *testing.T) {
|
||||
t.Run("update single value", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -214,7 +216,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("update multiple values", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -258,7 +260,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty value resets to default", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -279,7 +281,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error with odd number of arguments", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -295,7 +297,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("error with invalid key", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -313,7 +315,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
|
||||
|
||||
func TestUpdateAppConfig(t *testing.T) {
|
||||
t.Run("updates configuration values from DTO", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config
|
||||
service := &AppConfigService{
|
||||
@@ -386,7 +388,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty values reset to defaults", func(t *testing.T) {
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create a service with default config and modify some values
|
||||
service := &AppConfigService{
|
||||
@@ -451,7 +453,7 @@ func TestUpdateAppConfig(t *testing.T) {
|
||||
// Disable UI config
|
||||
common.EnvConfig.UiConfigDisabled = true
|
||||
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := &AppConfigService{
|
||||
db: db,
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"gorm.io/gorm"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
@@ -60,7 +62,7 @@ func (s *TestService) initExternalIdP() error {
|
||||
return fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
s.externalIdPKey, err = utils.ImportRawKey(rawKey)
|
||||
s.externalIdPKey, err = jwkutils.ImportRawKey(rawKey, jwa.ES256().String(), "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import private key: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,23 +2,20 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -26,8 +23,9 @@ const (
|
||||
// This is a JSON file containing a key encoded as JWK
|
||||
PrivateKeyFile = "jwt_private_key.json"
|
||||
|
||||
// RsaKeySize is the size, in bits, of the RSA key to generate if none is found
|
||||
RsaKeySize = 2048
|
||||
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
|
||||
// This is a encrypted JSON file containing a key encoded as JWK
|
||||
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
|
||||
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
@@ -59,58 +57,74 @@ const (
|
||||
)
|
||||
|
||||
type JwtService struct {
|
||||
envConfig *common.EnvConfigSchema
|
||||
privateKey jwk.Key
|
||||
keyId string
|
||||
appConfigService *AppConfigService
|
||||
jwksEncoded []byte
|
||||
}
|
||||
|
||||
func NewJwtService(appConfigService *AppConfigService) *JwtService {
|
||||
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) *JwtService {
|
||||
service := &JwtService{}
|
||||
|
||||
// Ensure keys are generated or loaded
|
||||
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
|
||||
err := service.init(db, appConfigService, &common.EnvConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize jwt service: %v", err)
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
|
||||
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
|
||||
s.appConfigService = appConfigService
|
||||
s.envConfig = envConfig
|
||||
|
||||
// Ensure keys are generated or loaded
|
||||
return s.loadOrGenerateKey(keysPath)
|
||||
return s.loadOrGenerateKey(db)
|
||||
}
|
||||
|
||||
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
|
||||
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
||||
var key jwk.Key
|
||||
|
||||
// First, check if we have a JWK file
|
||||
// If we do, then we just load that
|
||||
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
|
||||
ok, err := utils.FileExists(jwkPath)
|
||||
func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
|
||||
// Get the key provider
|
||||
keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
|
||||
return fmt.Errorf("failed to get key provider: %w", err)
|
||||
}
|
||||
if ok {
|
||||
key, err = s.loadKeyJWK(jwkPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
// Set the key, and we are done
|
||||
// Try loading a key
|
||||
key, err := keyProvider.LoadKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
|
||||
}
|
||||
|
||||
// If we have a key, store it in the object and we're done
|
||||
if key != nil {
|
||||
err = s.SetKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set private key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we are here, we need to generate a new key
|
||||
key, err = s.generateNewRSAKey()
|
||||
err = s.generateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
// Save the newly-generated key
|
||||
err = keyProvider.SaveKey(s.privateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateKey generates a new key and stores it in the object
|
||||
func (s *JwtService) generateKey() error {
|
||||
// Default is to generate RS256 (RSA-2048) keys
|
||||
key, err := jwkutils.GenerateKey(jwa.RS256().String(), "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate new private key: %w", err)
|
||||
}
|
||||
@@ -121,12 +135,6 @@ func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
||||
return fmt.Errorf("failed to set private key: %w", err)
|
||||
}
|
||||
|
||||
// Save the key as JWK
|
||||
err = SaveKeyJWK(s.privateKey, jwkPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -192,13 +200,13 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
Subject(user.ID).
|
||||
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, common.EnvConfig.AppURL)
|
||||
err = SetAudienceString(token, s.envConfig.AppURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
@@ -229,8 +237,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithAudience(common.EnvConfig.AppURL),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithAudience(s.envConfig.AppURL),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -246,7 +254,7 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
|
||||
token, err := jwt.NewBuilder().
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
@@ -305,7 +313,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
|
||||
)
|
||||
|
||||
@@ -335,7 +343,7 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw
|
||||
Subject(user.ID).
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
@@ -377,7 +385,7 @@ func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, erro
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -393,7 +401,7 @@ func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, r
|
||||
Subject(userID).
|
||||
Expiration(now.Add(RefreshTokenDuration)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(s.envConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
@@ -430,7 +438,7 @@ func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, client
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithIssuer(s.envConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -488,7 +496,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
utils.EnsureAlgInKey(pubKey)
|
||||
jwkutils.EnsureAlgInKey(pubKey, "", "")
|
||||
|
||||
return pubKey, nil
|
||||
}
|
||||
@@ -517,56 +525,6 @@ func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
|
||||
return alg, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key data: %w", err)
|
||||
}
|
||||
|
||||
key, err := jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
|
||||
// We generate RSA keys only
|
||||
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
|
||||
}
|
||||
|
||||
// Import the raw key
|
||||
return utils.ImportRawKey(rawKey)
|
||||
}
|
||||
|
||||
// SaveKeyJWK saves a JWK to a file
|
||||
func SaveKeyJWK(key jwk.Key, path string) error {
|
||||
dir := filepath.Dir(path)
|
||||
err := os.MkdirAll(dir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
|
||||
}
|
||||
|
||||
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create key file: %w", err)
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
// Write the JSON file to disk
|
||||
enc := json.NewEncoder(keyFile)
|
||||
enc.SetEscapeHTML(false)
|
||||
err = enc.Encode(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write key file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIsAdmin returns the value of the "isAdmin" claim in the token
|
||||
func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
if !token.Has(IsAdminClaim) {
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
|
||||
)
|
||||
|
||||
func TestJwtService_Init(t *testing.T) {
|
||||
@@ -33,9 +33,16 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Initialize the JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify the private key was set
|
||||
@@ -66,9 +73,16 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// First create a service to generate a key
|
||||
firstService := &JwtService{}
|
||||
err := firstService.init(mockConfig, tempDir)
|
||||
err := firstService.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the key ID of the first service
|
||||
@@ -77,7 +91,7 @@ func TestJwtService_Init(t *testing.T) {
|
||||
|
||||
// Now create a new service that should load the existing key
|
||||
secondService := &JwtService{}
|
||||
err = secondService.init(mockConfig, tempDir)
|
||||
err = secondService.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the loaded key has the same ID as the original
|
||||
@@ -90,12 +104,19 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create a new JWK and save it to disk
|
||||
origKeyID := createECDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Now create a new service that should load the existing key
|
||||
svc := &JwtService{}
|
||||
err := svc.init(mockConfig, tempDir)
|
||||
err := svc.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure loaded key has the right algorithm
|
||||
@@ -113,12 +134,19 @@ func TestJwtService_Init(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create a new JWK and save it to disk
|
||||
origKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Now create a new service that should load the existing key
|
||||
svc := &JwtService{}
|
||||
err := svc.init(mockConfig, tempDir)
|
||||
err := svc.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure loaded key has the right algorithm and curve
|
||||
@@ -147,9 +175,16 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create a JWT service with initialized key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Get the JWK (public key)
|
||||
@@ -178,12 +213,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create an ECDSA key and save it as JWK
|
||||
originalKeyID := createECDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Create a JWT service that loads the ECDSA key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Get the JWK (public key)
|
||||
@@ -216,12 +258,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
// Create an EdDSA key and save it as JWK
|
||||
originalKeyID := createEdDSAKeyJWK(t, tempDir)
|
||||
|
||||
// Create a JWT service that loads the EdDSA key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Get the JWK (public key)
|
||||
@@ -276,16 +325,16 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates token for regular user", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -328,7 +377,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
t.Run("generates token for admin user", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test admin user
|
||||
@@ -364,7 +413,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
service := &JwtService{}
|
||||
err := service.init(customMockConfig, tempDir)
|
||||
err := service.init(nil, customMockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -399,7 +448,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -453,7 +505,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -507,7 +562,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -563,16 +621,16 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates and verifies ID token with standard claims", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create test claims
|
||||
@@ -601,7 +659,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Check token expiration time is approximately 1 hour from now
|
||||
expectedExp := time.Now().Add(1 * time.Hour)
|
||||
@@ -614,7 +672,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
t.Run("can accept expired tokens if told so", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create test claims
|
||||
@@ -628,7 +686,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
// Create a token that's already expired
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(userClaims["sub"].(string)).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(service.envConfig.AppURL).
|
||||
Audience([]string{clientID}).
|
||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||
@@ -666,13 +724,13 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
})
|
||||
|
||||
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create test claims with nonce
|
||||
@@ -703,7 +761,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
t.Run("fails verification with incorrect issuer", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Generate a token with standard claims
|
||||
@@ -714,7 +772,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to generate ID token")
|
||||
|
||||
// Temporarily change the app URL to simulate wrong issuer
|
||||
common.EnvConfig.AppURL = "https://wrong-issuer.com"
|
||||
service.envConfig.AppURL = "https://wrong-issuer.com"
|
||||
|
||||
// Verify should fail due to issuer mismatch
|
||||
_, err = service.VerifyIdToken(tokenString, false)
|
||||
@@ -731,7 +789,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -762,7 +823,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Verify the key type is OKP
|
||||
publicKey, err := service.GetPublicJWK()
|
||||
@@ -784,7 +845,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -795,7 +859,6 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
// Create test claims
|
||||
userClaims := map[string]interface{}{
|
||||
"sub": "ecdsauser456",
|
||||
"name": "ECDSA User",
|
||||
"email": "ecdsauser@example.com",
|
||||
}
|
||||
const clientID = "ecdsa-client-123"
|
||||
@@ -815,7 +878,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Verify the key type is EC
|
||||
publicKey, err := service.GetPublicJWK()
|
||||
@@ -837,7 +900,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -868,17 +934,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
assert.Equal(t, "rsauser456", subject, "Token subject should match user ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Verify the key type is RSA
|
||||
publicKey, err := service.GetPublicJWK()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
|
||||
|
||||
// Verify the algorithm is RS256
|
||||
alg, ok := publicKey.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -892,16 +948,16 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -931,7 +987,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
|
||||
issuer, ok := claims.Issuer()
|
||||
_ = assert.True(t, ok, "Issuer not found in token") &&
|
||||
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
|
||||
|
||||
// Check token expiration time is approximately 1 hour from now
|
||||
expectedExp := time.Now().Add(1 * time.Hour)
|
||||
@@ -944,7 +1000,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||
// Create a JWT service with a mock function to generate an expired token
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -961,7 +1017,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||
Audience([]string{clientID}).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(service.envConfig.AppURL).
|
||||
Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
@@ -980,11 +1036,17 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||
// Create two JWT services with different keys
|
||||
service1 := &JwtService{}
|
||||
err := service1.init(mockConfig, t.TempDir()) // Use a different temp dir
|
||||
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize first JWT service")
|
||||
|
||||
service2 := &JwtService{}
|
||||
err = service2.init(mockConfig, t.TempDir()) // Use a different temp dir
|
||||
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize second JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -1014,7 +1076,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -1068,7 +1133,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -1122,7 +1190,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
|
||||
// Create a JWT service that loads the key
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Verify it loaded the right key
|
||||
@@ -1176,16 +1247,16 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
mockEnvConfig := &common.EnvConfigSchema{
|
||||
AppURL: "https://test.example.com",
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
}
|
||||
|
||||
t.Run("generates and verifies refresh token", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
@@ -1211,7 +1282,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, mockEnvConfig)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Generate a token using JWT directly to create an expired token
|
||||
@@ -1220,7 +1291,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||
Audience([]string{"client123"}).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Issuer(service.envConfig.AppURL).
|
||||
Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
@@ -1236,11 +1307,17 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||
// Create two JWT services with different keys
|
||||
service1 := &JwtService{}
|
||||
err := service1.init(mockConfig, t.TempDir())
|
||||
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize first JWT service")
|
||||
|
||||
service2 := &JwtService{}
|
||||
err = service2.init(mockConfig, t.TempDir())
|
||||
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: t.TempDir(), // Use a different temp dir
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize second JWT service")
|
||||
|
||||
// Generate a token with the first service
|
||||
@@ -1308,7 +1385,10 @@ func TestGetTokenType(t *testing.T) {
|
||||
// Initialize the JWT service
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: tempDir,
|
||||
})
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
|
||||
@@ -1402,10 +1482,19 @@ func TestGetTokenType(t *testing.T) {
|
||||
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := utils.ImportRawKey(privateKeyRaw)
|
||||
privateKey, err := jwkutils.ImportRawKey(privateKeyRaw, "", "")
|
||||
require.NoError(t, err, "Failed to import private key")
|
||||
|
||||
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
|
||||
keyProvider := &jwkutils.KeyProviderFile{}
|
||||
err = keyProvider.Init(jwkutils.KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysStorage: "file",
|
||||
KeysPath: path,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "Failed to init file key provider")
|
||||
|
||||
err = keyProvider.SaveKey(privateKey)
|
||||
require.NoError(t, err, "Failed to save key")
|
||||
|
||||
kid, _ := privateKey.KeyID()
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
// generateTestECDSAKey creates an ECDSA key for testing
|
||||
@@ -62,12 +63,12 @@ func TestOidcService_jwkSetForURL(t *testing.T) {
|
||||
)
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
||||
url1: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
||||
//nolint:bodyclose
|
||||
url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
||||
url2: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &MockRoundTripper{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
@@ -139,7 +140,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
|
||||
var err error
|
||||
// Create a test database
|
||||
db := newDatabaseForTest(t)
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
// Create two JWKs for testing
|
||||
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
|
||||
@@ -149,12 +150,12 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
|
||||
// Create a mock HTTP client with custom transport to return the JWKS
|
||||
httpClient := &http.Client{
|
||||
Transport: &MockRoundTripper{
|
||||
Transport: &testutils.MockRoundTripper{
|
||||
Responses: map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
||||
federatedClientIssuer + "/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
||||
federatedClientIssuerDefaults + ".well-known/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
69
backend/internal/utils/crypto/crypto.go
Normal file
69
backend/internal/utils/crypto/crypto.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrDecrypt is returned by Decrypt when the operation failed for any reason
|
||||
var ErrDecrypt = errors.New("failed to decrypt data")
|
||||
|
||||
// Encrypt a byte slice using AES-GCM and a random nonce
|
||||
// Important: do not encrypt more than ~4 billion messages with the same key!
|
||||
func Encrypt(key []byte, plaintext []byte, associatedData []byte) (ciphertext []byte, err error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
_, err = io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
|
||||
}
|
||||
|
||||
// Allocate the slice for the result, with additional space for the nonce and overhead
|
||||
ciphertext = make([]byte, 0, len(plaintext)+aead.NonceSize()+aead.Overhead())
|
||||
ciphertext = append(ciphertext, nonce...)
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Tag is automatically added at the end
|
||||
ciphertext = aead.Seal(ciphertext, nonce, plaintext, associatedData)
|
||||
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt a byte slice using AES-GCM
|
||||
func Decrypt(key []byte, ciphertext []byte, associatedData []byte) (plaintext []byte, err error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create block cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
|
||||
}
|
||||
|
||||
// Extract the nonce
|
||||
if len(ciphertext) < (aead.NonceSize() + aead.Overhead()) {
|
||||
return nil, ErrDecrypt
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
plaintext, err = aead.Open(nil, ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():], associatedData)
|
||||
if err != nil {
|
||||
// Note: we do not return the exact error here, to avoid disclosing information
|
||||
return nil, ErrDecrypt
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
208
backend/internal/utils/crypto/crypto_test.go
Normal file
208
backend/internal/utils/crypto/crypto_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keySize int
|
||||
plaintext string
|
||||
associatedData []byte
|
||||
}{
|
||||
{
|
||||
name: "AES-128 with short plaintext",
|
||||
keySize: 16,
|
||||
plaintext: "Hello, World!",
|
||||
associatedData: []byte("test-aad"),
|
||||
},
|
||||
{
|
||||
name: "AES-192 with medium plaintext",
|
||||
keySize: 24,
|
||||
plaintext: "This is a longer message to test encryption and decryption",
|
||||
associatedData: []byte("associated-data-192"),
|
||||
},
|
||||
{
|
||||
name: "AES-256 with unicode",
|
||||
keySize: 32,
|
||||
plaintext: "Hello 世界! 🌍 Testing unicode characters", //nolint:gosmopolitan
|
||||
associatedData: []byte("unicode-test"),
|
||||
},
|
||||
{
|
||||
name: "No associated data",
|
||||
keySize: 32,
|
||||
plaintext: "Testing without associated data",
|
||||
associatedData: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Generate random key
|
||||
key := make([]byte, tt.keySize)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
plaintext := []byte(tt.plaintext)
|
||||
|
||||
// Test encryption
|
||||
ciphertext, err := Encrypt(key, plaintext, tt.associatedData)
|
||||
require.NoError(t, err, "Encrypt should succeed")
|
||||
|
||||
// Verify ciphertext is different from plaintext (unless empty)
|
||||
if len(plaintext) > 0 {
|
||||
assert.NotEqual(t, plaintext, ciphertext)
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decrypted, err := Decrypt(key, ciphertext, tt.associatedData)
|
||||
require.NoError(t, err, "Decrypt should succeed")
|
||||
|
||||
// Verify decrypted text matches original
|
||||
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptWithInvalidKeySize(t *testing.T) {
|
||||
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||
|
||||
for _, keySize := range invalidKeySizes {
|
||||
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||
key := make([]byte, keySize)
|
||||
plaintext := []byte("test message")
|
||||
|
||||
_, err := Encrypt(key, plaintext, nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid key size")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithInvalidKeySize(t *testing.T) {
|
||||
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
|
||||
|
||||
for _, keySize := range invalidKeySizes {
|
||||
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
|
||||
key := make([]byte, keySize)
|
||||
ciphertext := []byte("fake ciphertext")
|
||||
|
||||
_, err := Decrypt(key, ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid key size")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithInvalidCiphertext(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ciphertext []byte
|
||||
}{
|
||||
{
|
||||
name: "empty ciphertext",
|
||||
ciphertext: []byte{},
|
||||
},
|
||||
{
|
||||
name: "too short ciphertext",
|
||||
ciphertext: []byte("short"),
|
||||
},
|
||||
{
|
||||
name: "random invalid data",
|
||||
ciphertext: []byte("this is not valid encrypted data"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Decrypt(key, tt.ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongKey(t *testing.T) {
|
||||
// Generate two different keys
|
||||
key1 := make([]byte, 32)
|
||||
key2 := make([]byte, 32)
|
||||
_, err := rand.Read(key1)
|
||||
require.NoError(t, err)
|
||||
_, err = rand.Read(key2)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("secret message")
|
||||
|
||||
// Encrypt with key1
|
||||
ciphertext, err := Encrypt(key1, plaintext, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to decrypt with key2
|
||||
_, err = Decrypt(key2, ciphertext, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
}
|
||||
|
||||
func TestDecryptWithWrongAssociatedData(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err, "Failed to generate random key")
|
||||
|
||||
plaintext := []byte("secret message")
|
||||
correctAAD := []byte("correct-aad")
|
||||
wrongAAD := []byte("wrong-aad")
|
||||
|
||||
// Encrypt with correct AAD
|
||||
ciphertext, err := Encrypt(key, plaintext, correctAAD)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to decrypt with wrong AAD
|
||||
_, err = Decrypt(key, ciphertext, wrongAAD)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrDecrypt)
|
||||
|
||||
// Verify correct AAD works
|
||||
decrypted, err := Decrypt(key, ciphertext, correctAAD)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original when using correct AAD")
|
||||
}
|
||||
|
||||
func TestEncryptDecryptConsistency(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
plaintext := []byte("consistency test message")
|
||||
associatedData := []byte("test-aad")
|
||||
|
||||
// Encrypt multiple times and verify we get different ciphertexts (due to random IV)
|
||||
ciphertext1, err := Encrypt(key, plaintext, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
ciphertext2, err := Encrypt(key, plaintext, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ciphertexts should be different (due to random IV)
|
||||
assert.NotEqual(t, ciphertext1, ciphertext2, "Multiple encryptions of same plaintext should produce different ciphertexts")
|
||||
|
||||
// Both should decrypt to the same plaintext
|
||||
decrypted1, err := Decrypt(key, ciphertext1, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted2, err := Decrypt(key, ciphertext2, associatedData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, plaintext, decrypted1, "First decrypted text should match original")
|
||||
assert.Equal(t, plaintext, decrypted2, "Second decrypted text should match original")
|
||||
assert.Equal(t, decrypted1, decrypted2, "Both decrypted texts should be identical")
|
||||
}
|
||||
50
backend/internal/utils/jwk/key_provider.go
Normal file
50
backend/internal/utils/jwk/key_provider.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
type KeyProviderOpts struct {
|
||||
EnvConfig *common.EnvConfigSchema
|
||||
DB *gorm.DB
|
||||
Kek []byte
|
||||
}
|
||||
|
||||
type KeyProvider interface {
|
||||
Init(opts KeyProviderOpts) error
|
||||
LoadKey() (jwk.Key, error)
|
||||
SaveKey(key jwk.Key) error
|
||||
}
|
||||
|
||||
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {
|
||||
// Load the encryption key (KEK) if present
|
||||
kek, err := LoadKeyEncryptionKey(envConfig, instanceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load encryption key: %w", err)
|
||||
}
|
||||
|
||||
// Get the key provider
|
||||
switch envConfig.KeysStorage {
|
||||
case "file", "":
|
||||
keyProvider = &KeyProviderFile{}
|
||||
case "database":
|
||||
keyProvider = &KeyProviderDatabase{}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
|
||||
}
|
||||
err = keyProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
EnvConfig: envConfig,
|
||||
Kek: kek,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err)
|
||||
}
|
||||
|
||||
return keyProvider, nil
|
||||
}
|
||||
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
109
backend/internal/utils/jwk/key_provider_database.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
const PrivateKeyDBKey = "jwt_private_key.json"
|
||||
|
||||
type KeyProviderDatabase struct {
|
||||
db *gorm.DB
|
||||
kek []byte
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
|
||||
if len(opts.Kek) == 0 {
|
||||
return errors.New("an encryption key is required when using the 'database' key provider")
|
||||
}
|
||||
|
||||
f.db = opts.DB
|
||||
f.kek = opts.Kek
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
|
||||
row := model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.WithContext(ctx).First(&row).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Key not present in the database - return nil so a new one can be generated
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve private key from the database: %w", err)
|
||||
}
|
||||
|
||||
if row.Value == nil || *row.Value == "" {
|
||||
// Key not present in the database - return nil so a new one can be generated
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
enc, err := base64.StdEncoding.DecodeString(*row.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key: not a valid base64-encoded value: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(f.kek, enc, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse encrypted private key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
|
||||
// Encode the key to JSON
|
||||
data, err := EncodeJWKBytes(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the key then encode to Base64
|
||||
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||
}
|
||||
encB64 := base64.StdEncoding.EncodeToString(enc)
|
||||
|
||||
// Save to database
|
||||
row := model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &encB64,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.WithContext(ctx).Create(&row).Error
|
||||
if err != nil {
|
||||
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
|
||||
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
|
||||
return fmt.Errorf("failed to store private key in database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compile-time interface check
|
||||
var _ KeyProvider = (*KeyProviderDatabase)(nil)
|
||||
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
275
backend/internal/utils/jwk/key_provider_database_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestKeyProviderDatabase_Init(t *testing.T) {
|
||||
t.Run("Init fails when KEK is not provided", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: nil, // No KEK
|
||||
})
|
||||
require.Error(t, err, "Expected error when KEK is not provided")
|
||||
require.ErrorContains(t, err, "encryption key is required")
|
||||
})
|
||||
|
||||
t.Run("Init succeeds with KEK", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: generateTestKEK(t),
|
||||
})
|
||||
require.NoError(t, err, "Expected no error when KEK is provided")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with existing key", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid base64", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert invalid base64 data
|
||||
invalidBase64 := "not-valid-base64"
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &invalidBase64,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with invalid base64")
|
||||
require.ErrorContains(t, err, "not a valid base64-encoded value")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid encrypted data", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert valid base64 but invalid encrypted data
|
||||
invalidData := base64.StdEncoding.EncodeToString([]byte("not-valid-encrypted-data"))
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &invalidData,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with invalid encrypted data")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with valid encrypted data but wrong KEK", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
originalKek := generateTestKEK(t)
|
||||
|
||||
// Save a key with the original KEK
|
||||
originalProvider := &KeyProviderDatabase{}
|
||||
err := originalProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: originalKek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = originalProvider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to load with a different KEK
|
||||
differentKek := generateTestKEK(t)
|
||||
differentProvider := &KeyProviderDatabase{}
|
||||
err = differentProvider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: differentKek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key with the wrong KEK
|
||||
loadedKey, err := differentProvider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading key with wrong KEK")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with invalid key data", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create invalid key data (valid JSON but not a valid JWK)
|
||||
invalidKeyData := []byte(`{"not": "a valid jwk"}`)
|
||||
|
||||
// Encrypt the invalid key data
|
||||
encryptedData, err := cryptoutils.Encrypt(kek, invalidKeyData, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Base64 encode the encrypted data
|
||||
encodedData := base64.StdEncoding.EncodeToString(encryptedData)
|
||||
|
||||
// Save to database
|
||||
err = db.Create(&model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
Value: &encodedData,
|
||||
}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.Error(t, err, "Expected error when loading invalid key data")
|
||||
require.ErrorContains(t, err, "failed to parse")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderDatabase_SaveKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("SaveKey and verify database record", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
kek := generateTestKEK(t)
|
||||
|
||||
provider := &KeyProviderDatabase{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
DB: db,
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err, "Expected no error when saving key")
|
||||
|
||||
// Verify record exists in database
|
||||
var kv model.KV
|
||||
err = db.Where("key = ?", PrivateKeyDBKey).First(&kv).Error
|
||||
require.NoError(t, err, "Expected to find key in database")
|
||||
require.NotNil(t, kv.Value, "Expected non-nil value in database")
|
||||
assert.NotEmpty(t, *kv.Value, "Expected non-empty value in database")
|
||||
|
||||
// Decode and decrypt to verify content
|
||||
encBytes, err := base64.StdEncoding.DecodeString(*kv.Value)
|
||||
require.NoError(t, err, "Expected valid base64 encoding")
|
||||
|
||||
decBytes, err := cryptoutils.Decrypt(kek, encBytes, nil)
|
||||
require.NoError(t, err, "Expected valid encrypted data")
|
||||
|
||||
parsedKey, err := jwk.ParseKey(decBytes)
|
||||
require.NoError(t, err, "Expected valid JWK data")
|
||||
|
||||
// Compare keys
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||
})
|
||||
}
|
||||
|
||||
func generateTestKEK(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Generate a 32-byte kek
|
||||
kek := make([]byte, 32)
|
||||
_, err := rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
return kek
|
||||
}
|
||||
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
202
backend/internal/utils/jwk/key_provider_file.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrivateKeyFile is the path in the data/keys folder where the key is stored
|
||||
// This is a JSON file containing a key encoded as JWK
|
||||
PrivateKeyFile = "jwt_private_key.json"
|
||||
|
||||
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
|
||||
// This is a encrypted JSON file containing a key encoded as JWK
|
||||
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
|
||||
)
|
||||
|
||||
type KeyProviderFile struct {
|
||||
envConfig *common.EnvConfigSchema
|
||||
kek []byte
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
|
||||
f.envConfig = opts.EnvConfig
|
||||
f.kek = opts.Kek
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
|
||||
if len(f.kek) > 0 {
|
||||
return f.loadEncryptedKey()
|
||||
}
|
||||
return f.loadKey()
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
|
||||
if len(f.kek) > 0 {
|
||||
return f.saveKeyEncrypted(key)
|
||||
}
|
||||
return f.saveKey(key)
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
|
||||
var key jwk.Key
|
||||
|
||||
// First, check if we have a JWK file
|
||||
// If we do, then we just load that
|
||||
jwkPath := f.jwkPath()
|
||||
ok, err := utils.FileExists(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
if !ok {
|
||||
// File doesn't exist, no key was loaded
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
|
||||
// First, check if we have an encrypted JWK file
|
||||
// If we do, then we just load that
|
||||
encJwkPath := f.encJwkPath()
|
||||
ok, err := utils.FileExists(encJwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
if ok {
|
||||
encB64, err := os.ReadFile(encJwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
// Parse the key
|
||||
key, err = jwk.ParseKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Check if we have an un-encrypted JWK file
|
||||
key, err = f.loadKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
|
||||
}
|
||||
if key == nil {
|
||||
// No key exists, encrypted or un-encrypted
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If we are here, we have loaded a key that was un-encrypted
|
||||
// We need to replace the plaintext key with the encrypted one before we return
|
||||
err = f.saveKeyEncrypted(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
|
||||
}
|
||||
jwkPath := f.jwkPath()
|
||||
err = os.Remove(jwkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) saveKey(key jwk.Key) error {
|
||||
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
|
||||
}
|
||||
|
||||
jwkPath := f.jwkPath()
|
||||
keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
// Write the JSON file to disk
|
||||
err = EncodeJWK(keyFile, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
|
||||
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
|
||||
}
|
||||
|
||||
// Encode the key to JSON
|
||||
data, err := EncodeJWKBytes(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode key to JSON: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the key then encode to Base64
|
||||
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key: %w", err)
|
||||
}
|
||||
encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
|
||||
base64.StdEncoding.Encode(encB64, enc)
|
||||
|
||||
// Write to disk
|
||||
encJwkPath := f.encJwkPath()
|
||||
err = os.WriteFile(encJwkPath, encB64, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) jwkPath() string {
|
||||
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
|
||||
}
|
||||
|
||||
func (f *KeyProviderFile) encJwkPath() string {
|
||||
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
|
||||
}
|
||||
|
||||
// Compile-time interface check
|
||||
var _ KeyProvider = (*KeyProviderFile)(nil)
|
||||
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
320
backend/internal/utils/jwk/key_provider_file_test.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||
)
|
||||
|
||||
func TestKeyProviderFile_LoadKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("LoadKey with no existing key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the key file exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected key file to exist")
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey with encrypted key", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: makeKEK(t),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key (will be encrypted)
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err := utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||
|
||||
// Make sure the unencrypted key file does not exist
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")
|
||||
|
||||
// Verify the loaded key is the same as the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
|
||||
})
|
||||
|
||||
t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// First, create an unencrypted key
|
||||
providerNoKek := &KeyProviderFile{}
|
||||
err := providerNoKek.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save an unencrypted key
|
||||
err = providerNoKek.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify unencrypted key exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected unencrypted key file to exist")
|
||||
|
||||
// Now create a provider with a kek
|
||||
kek := make([]byte, 32)
|
||||
_, err = rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
|
||||
providerWithKek := &KeyProviderFile{}
|
||||
err = providerWithKek.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the key - this should convert the unencrypted key to encrypted
|
||||
loadedKey, err := providerWithKek.LoadKey()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")
|
||||
|
||||
// Verify the unencrypted key no longer exists
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to be removed")
|
||||
|
||||
// Verify the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err = utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist after conversion")
|
||||
|
||||
// Verify the key data
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyProviderFile_SaveKey(t *testing.T) {
|
||||
// Generate a test key to use in our tests
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(pk)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("SaveKey unencrypted", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err := provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the key file exists
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err := utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected key file to exist")
|
||||
|
||||
// Verify the content of the key file
|
||||
data, err := os.ReadFile(keyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKey, err := jwk.ParseKey(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the saved key with the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
|
||||
})
|
||||
|
||||
t.Run("SaveKey encrypted", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Generate a 64-byte kek
|
||||
kek := makeKEK(t)
|
||||
|
||||
provider := &KeyProviderFile{}
|
||||
err = provider.Init(KeyProviderOpts{
|
||||
EnvConfig: &common.EnvConfigSchema{
|
||||
KeysPath: tempDir,
|
||||
},
|
||||
Kek: kek,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key (will be encrypted)
|
||||
err = provider.SaveKey(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the encrypted key file exists
|
||||
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
|
||||
exists, err := utils.FileExists(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Expected encrypted key file to exist")
|
||||
|
||||
// Verify the unencrypted key file doesn't exist
|
||||
keyPath := filepath.Join(tempDir, PrivateKeyFile)
|
||||
exists, err = utils.FileExists(keyPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expected unencrypted key file to not exist")
|
||||
|
||||
// Manually decrypt the encrypted key file to verify it contains the correct key
|
||||
encB64, err := os.ReadFile(encKeyPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode from base64
|
||||
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
|
||||
n, err := base64.StdEncoding.Decode(enc, encB64)
|
||||
require.NoError(t, err)
|
||||
enc = enc[:n] // Trim any padding
|
||||
|
||||
// Decrypt the data
|
||||
data, err := cryptoutils.Decrypt(kek, enc, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the key
|
||||
parsedKey, err := jwk.ParseKey(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the decrypted key with the original
|
||||
keyBytes, err := EncodeJWKBytes(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
|
||||
})
|
||||
}
|
||||
|
||||
func makeKEK(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Generate a 32-byte kek
|
||||
kek := make([]byte, 32)
|
||||
_, err := rand.Read(kek)
|
||||
require.NoError(t, err)
|
||||
return kek
|
||||
}
|
||||
180
backend/internal/utils/jwk/utils.go
Normal file
180
backend/internal/utils/jwk/utils.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha3"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
)
|
||||
|
||||
// EncodeJWK encodes a jwk.Key to a writable stream.
|
||||
func EncodeJWK(w io.Writer, key jwk.Key) error {
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetEscapeHTML(false)
|
||||
return enc.Encode(key)
|
||||
}
|
||||
|
||||
// EncodeJWKBytes encodes a jwk.Key to a byte slice.
|
||||
func EncodeJWKBytes(key jwk.Key) ([]byte, error) {
|
||||
b := &bytes.Buffer{}
|
||||
err := EncodeJWK(b, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// LoadKeyEncryptionKey loads the key encryption key for JWKs
|
||||
func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) {
|
||||
// Try getting the key from the env var as string
|
||||
kekInput := []byte(envConfig.EncryptionKey)
|
||||
|
||||
// If there's nothing in the env, try loading from file
|
||||
if len(kekInput) == 0 && envConfig.EncryptionKeyFile != "" {
|
||||
kekInput, err = os.ReadFile(envConfig.EncryptionKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key file '%s': %w", envConfig.EncryptionKeyFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If there's still no key, return
|
||||
if len(kekInput) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// We need a 256-bit key for encryption with AES-GCM-256
|
||||
// We use HMAC with SHA3-256 here to derive the key from the one passed as input
|
||||
// The key is tied to a specific instance of Pocket ID
|
||||
h := hmac.New(func() hash.Hash { return sha3.New256() }, kekInput)
|
||||
fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek")
|
||||
kek = h.Sum(nil)
|
||||
|
||||
return kek, nil
|
||||
}
|
||||
|
||||
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
||||
// It also populates additional fields such as the key ID, usage, and alg.
|
||||
func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key, alg, crv)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key, alg string, crv string) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
if alg != "" {
|
||||
_ = key.Set(jwk.AlgorithmKey, alg)
|
||||
if crv != "" {
|
||||
eca, ok := jwa.LookupEllipticCurveAlgorithm(crv)
|
||||
if ok {
|
||||
switch key.KeyType() {
|
||||
case jwa.EC():
|
||||
_ = key.Set(jwk.ECDSACrvKey, eca)
|
||||
case jwa.OKP():
|
||||
_ = key.Set(jwk.OKPCrvKey, eca)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't have an algorithm, set the default for the key type
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
_ = key.Set(jwk.ECDSACrvKey, jwa.P256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA and Ed25519 for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
_ = key.Set(jwk.OKPCrvKey, jwa.Ed25519())
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateKey generates a new jwk.Key
|
||||
func GenerateKey(alg string, crv string) (key jwk.Key, err error) {
|
||||
var rawKey any
|
||||
switch alg {
|
||||
case jwa.RS256().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
case jwa.RS384().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 3072)
|
||||
case jwa.RS512().String():
|
||||
rawKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
case jwa.ES256().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case jwa.ES384().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case jwa.ES512().String():
|
||||
rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
case jwa.EdDSA().String():
|
||||
switch crv {
|
||||
case jwa.Ed25519().String():
|
||||
_, rawKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
default:
|
||||
return nil, errors.New("unsupported curve for EdDSA algorithm")
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unsupported key algorithm")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
// Import the raw key
|
||||
return ImportRawKey(rawKey, alg, crv)
|
||||
}
|
||||
324
backend/internal/utils/jwk/utils_test.go
Normal file
324
backend/internal/utils/jwk/utils_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
alg string
|
||||
crv string
|
||||
expectError bool
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
}{
|
||||
{
|
||||
name: "RS256",
|
||||
alg: jwa.RS256().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS256(),
|
||||
},
|
||||
{
|
||||
name: "RS384",
|
||||
alg: jwa.RS384().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS384(),
|
||||
},
|
||||
// Skip the RS512 test as generating a RSA-4096 key can take some time
|
||||
/* {
|
||||
name: "RS512",
|
||||
alg: jwa.RS512().String(),
|
||||
crv: "",
|
||||
expectError: false,
|
||||
expectedAlg: jwa.RS512(),
|
||||
}, */
|
||||
{
|
||||
name: "ES256",
|
||||
alg: jwa.ES256().String(),
|
||||
crv: jwa.P256().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES256(),
|
||||
},
|
||||
{
|
||||
name: "ES384",
|
||||
alg: jwa.ES384().String(),
|
||||
crv: jwa.P384().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES384(),
|
||||
},
|
||||
{
|
||||
name: "ES512",
|
||||
alg: jwa.ES512().String(),
|
||||
crv: jwa.P521().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.ES512(),
|
||||
},
|
||||
{
|
||||
name: "EdDSA with Ed25519",
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: jwa.Ed25519().String(),
|
||||
expectError: false,
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
},
|
||||
{
|
||||
name: "EdDSA with unsupported curve",
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: "unsupported",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Unsupported algorithm",
|
||||
alg: "UNSUPPORTED",
|
||||
crv: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
key, err := GenerateKey(tt.alg, tt.crv)
|
||||
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, key)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
|
||||
// Verify the algorithm is set correctly
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok, "algorithm should be set in the key")
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify other required fields are set
|
||||
kid, ok := key.KeyID()
|
||||
assert.True(t, ok, "key ID should be set")
|
||||
assert.NotEmpty(t, kid, "key ID should not be empty")
|
||||
|
||||
usage, ok := key.KeyUsage()
|
||||
assert.True(t, ok, "key usage should be set")
|
||||
assert.Equal(t, KeyUsageSigning, usage)
|
||||
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
|
||||
// Verify key type matches expected algorithm
|
||||
switch tt.expectedAlg {
|
||||
case jwa.RS256(), jwa.RS384(), jwa.RS512():
|
||||
assert.Equal(t, jwa.RSA(), key.KeyType())
|
||||
assert.Nil(t, crv)
|
||||
case jwa.ES256(), jwa.ES384(), jwa.ES512():
|
||||
assert.Equal(t, jwa.EC(), key.KeyType())
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
_ = assert.NotNil(t, crv) &&
|
||||
assert.True(t, ok) &&
|
||||
assert.Equal(t, tt.crv, eca.String())
|
||||
case jwa.EdDSA():
|
||||
assert.Equal(t, jwa.OKP(), key.KeyType())
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
_ = assert.NotNil(t, crv) &&
|
||||
assert.True(t, ok) &&
|
||||
assert.Equal(t, tt.crv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAlgInKey(t *testing.T) {
|
||||
// Generate an RSA-2048 key
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("does not change alg already set", func(t *testing.T) {
|
||||
// Import the RSA key
|
||||
key, err := jwk.Import(rsaKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Pre-set the algorithm
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
|
||||
// Call EnsureAlgInKey with a different algorithm
|
||||
EnsureAlgInKey(key, jwa.RS384().String(), "")
|
||||
|
||||
// Verify the algorithm wasn't changed
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||
})
|
||||
|
||||
t.Run("set algorithm to explicitly-provided value", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyGen func() (any, error)
|
||||
alg string
|
||||
crv string
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
expectedCrv string
|
||||
}{
|
||||
{
|
||||
name: "RSA key with RS384",
|
||||
keyGen: func() (any, error) {
|
||||
return rsaKey, nil
|
||||
},
|
||||
alg: jwa.RS384().String(),
|
||||
crv: "",
|
||||
expectedAlg: jwa.RS384(),
|
||||
expectedCrv: "",
|
||||
},
|
||||
{
|
||||
name: "ECDSA key with ES384",
|
||||
keyGen: func() (any, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
},
|
||||
alg: jwa.ES384().String(),
|
||||
crv: jwa.P384().String(),
|
||||
expectedAlg: jwa.ES384(),
|
||||
expectedCrv: jwa.P384().String(),
|
||||
},
|
||||
{
|
||||
name: "Ed25519 key with EdDSA",
|
||||
keyGen: func() (any, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
return priv, err
|
||||
},
|
||||
alg: jwa.EdDSA().String(),
|
||||
crv: jwa.Ed25519().String(),
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
expectedCrv: jwa.Ed25519().String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rawKey, err := tt.keyGen()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rawKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure no algorithm is set initially
|
||||
_, ok := key.Algorithm()
|
||||
assert.False(t, ok)
|
||||
|
||||
// Call EnsureAlgInKey
|
||||
EnsureAlgInKey(key, tt.alg, tt.crv)
|
||||
|
||||
// Verify the algorithm was set correctly
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify curve if expected
|
||||
if tt.expectedCrv != "" {
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
require.NotNil(t, crv)
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set default algorithms if not present", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyGen func() (any, error)
|
||||
expectedAlg jwa.SignatureAlgorithm
|
||||
expectedCrv string
|
||||
}{
|
||||
{
|
||||
name: "RSA key defaults to RS256",
|
||||
keyGen: func() (any, error) {
|
||||
return rsaKey, nil
|
||||
},
|
||||
expectedAlg: jwa.RS256(),
|
||||
expectedCrv: "",
|
||||
},
|
||||
{
|
||||
name: "ECDSA key defaults to ES256 with P256",
|
||||
keyGen: func() (any, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
},
|
||||
expectedAlg: jwa.ES256(),
|
||||
expectedCrv: jwa.P256().String(),
|
||||
},
|
||||
{
|
||||
name: "Ed25519 key defaults to EdDSA with Ed25519",
|
||||
keyGen: func() (any, error) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
return priv, err
|
||||
},
|
||||
expectedAlg: jwa.EdDSA(),
|
||||
expectedCrv: jwa.Ed25519().String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rawKey, err := tt.keyGen()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rawKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure no algorithm is set initially
|
||||
_, ok := key.Algorithm()
|
||||
assert.False(t, ok)
|
||||
|
||||
// Call EnsureAlgInKey with empty parameters
|
||||
EnsureAlgInKey(key, "", "")
|
||||
|
||||
// Verify the default algorithm was set
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedAlg.String(), alg.String())
|
||||
|
||||
// Verify curve if expected
|
||||
if tt.expectedCrv != "" {
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
require.NotNil(t, crv)
|
||||
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.expectedCrv, eca.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid curve should not set curve parameter", func(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := jwk.Import(rsaKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Call EnsureAlgInKey with invalid curve
|
||||
EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve")
|
||||
|
||||
// Verify algorithm was set but curve was not
|
||||
alg, ok := key.Algorithm()
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jwa.RS256().String(), alg.String())
|
||||
|
||||
var crv any
|
||||
_ = key.Get("crv", &crv)
|
||||
assert.Nil(t, crv)
|
||||
})
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeyUsageSigning is the usage for the private keys, for the "use" property
|
||||
KeyUsageSigning = "sig"
|
||||
)
|
||||
|
||||
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
|
||||
// It also populates additional fields such as the key ID, usage, and alg.
|
||||
func ImportRawKey(rawKey any) (jwk.Key, error) {
|
||||
key, err := jwk.Import(rawKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to import generated private key: %w", err)
|
||||
}
|
||||
|
||||
// Generate the key ID
|
||||
kid, err := generateRandomKeyID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
_ = key.Set(jwk.KeyIDKey, kid)
|
||||
|
||||
// Set other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
|
||||
EnsureAlgInKey(key)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateRandomKeyID generates a random key ID.
|
||||
func generateRandomKeyID() (string, error) {
|
||||
buf := make([]byte, 8)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read random bytes: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
|
||||
func EnsureAlgInKey(key jwk.Key) {
|
||||
_, ok := key.Algorithm()
|
||||
if ok {
|
||||
// Algorithm is already set
|
||||
return
|
||||
}
|
||||
|
||||
switch key.KeyType() {
|
||||
case jwa.RSA():
|
||||
// Default to RS256 for RSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
|
||||
case jwa.EC():
|
||||
// Default to ES256 for ECDSA keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
|
||||
case jwa.OKP():
|
||||
// Default to EdDSA for OKP keys
|
||||
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
package service
|
||||
// This file is only imported by unit tests
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -21,7 +20,10 @@ import (
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
|
||||
func newDatabaseForTest(t *testing.T) *gorm.DB {
|
||||
// NewDatabaseForTest returns a new instance of GORM connected to an in-memory SQLite database.
|
||||
// Each database connection is unique for the test.
|
||||
// All migrations are automatically performed.
|
||||
func NewDatabaseForTest(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
// Get a name for this in-memory database that is specific to the test
|
||||
@@ -68,30 +70,3 @@ type testLoggerAdapter struct {
|
||||
func (l testLoggerAdapter) Printf(format string, args ...any) {
|
||||
l.t.Logf(format, args...)
|
||||
}
|
||||
|
||||
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
|
||||
type MockRoundTripper struct {
|
||||
Err error
|
||||
Responses map[string]*http.Response
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Check if we have a specific response for this URL
|
||||
for url, resp := range m.Responses {
|
||||
if req.URL.String() == url {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewMockResponse(http.StatusNotFound, ""), nil
|
||||
}
|
||||
|
||||
// NewMockResponse creates an http.Response with the given status code and body
|
||||
func NewMockResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
38
backend/internal/utils/testing/round_tripper.go
Normal file
38
backend/internal/utils/testing/round_tripper.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// This file is only imported by unit tests
|
||||
|
||||
package testing
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
|
||||
type MockRoundTripper struct {
|
||||
Err error
|
||||
Responses map[string]*http.Response
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Check if we have a specific response for this URL
|
||||
for url, resp := range m.Responses {
|
||||
if req.URL.String() == url {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewMockResponse(http.StatusNotFound, ""), nil
|
||||
}
|
||||
|
||||
// NewMockResponse creates an http.Response with the given status code and body
|
||||
func NewMockResponse(statusCode int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE kv;
|
||||
@@ -0,0 +1,6 @@
|
||||
-- The "kv" tables contains miscellaneous key-value pairs
|
||||
CREATE TABLE kv
|
||||
(
|
||||
"key" TEXT NOT NULL PRIMARY KEY,
|
||||
"value" TEXT
|
||||
);
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE kv;
|
||||
@@ -0,0 +1,6 @@
|
||||
-- The "kv" tables contains miscellaneous key-value pairs
|
||||
CREATE TABLE kv
|
||||
(
|
||||
"key" TEXT NOT NULL PRIMARY KEY,
|
||||
"value" TEXT NOT NULL
|
||||
);
|
||||
Reference in New Issue
Block a user