feat: encrypt private keys saved on disk and in database (#682)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-07-03 11:34:34 -07:00
committed by GitHub
parent 9872608d61
commit 5550729120
25 changed files with 2311 additions and 328 deletions

View File

@@ -38,7 +38,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
svc.geoLiteService = service.NewGeoLiteService(httpClient)
svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService)
svc.jwtService = service.NewJwtService(svc.appConfigService)
svc.jwtService = service.NewJwtService(db, svc.appConfigService)
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
svc.customClaimService = service.NewCustomClaimService(db)

View File

@@ -1,6 +1,8 @@
package common
import (
"errors"
"fmt"
"log"
"net/url"
@@ -18,9 +20,10 @@ const (
)
const (
DbProviderSqlite DbProvider = "sqlite"
DbProviderPostgres DbProvider = "postgres"
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
DbProviderSqlite DbProvider = "sqlite"
DbProviderPostgres DbProvider = "postgres"
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
defaultSqliteConnString string = "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate"
)
type EnvConfigSchema struct {
@@ -30,6 +33,9 @@ type EnvConfigSchema struct {
DbConnectionString string `env:"DB_CONNECTION_STRING"`
UploadPath string `env:"UPLOAD_PATH"`
KeysPath string `env:"KEYS_PATH"`
KeysStorage string `env:"KEYS_STORAGE"`
EncryptionKey string `env:"ENCRYPTION_KEY"`
EncryptionKeyFile string `env:"ENCRYPTION_KEY_FILE"`
Port string `env:"PORT"`
Host string `env:"HOST"`
UnixSocket string `env:"UNIX_SOCKET"`
@@ -45,52 +51,83 @@ type EnvConfigSchema struct {
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
}
var EnvConfig = &EnvConfigSchema{
AppEnv: "production",
DbProvider: "sqlite",
DbConnectionString: "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate",
UploadPath: "data/uploads",
KeysPath: "data/keys",
AppURL: "http://localhost:1411",
Port: "1411",
Host: "0.0.0.0",
UnixSocket: "",
UnixSocketMode: "",
MaxMindLicenseKey: "",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
LocalIPv6Ranges: "",
UiConfigDisabled: false,
MetricsEnabled: false,
TracingEnabled: false,
TrustProxy: false,
AnalyticsDisabled: false,
}
var EnvConfig = defaultConfig()
func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err)
err := parseEnvConfig()
if err != nil {
log.Fatalf("Configuration error: %v", err)
}
}
func defaultConfig() EnvConfigSchema {
return EnvConfigSchema{
AppEnv: "production",
DbProvider: "sqlite",
DbConnectionString: "",
UploadPath: "data/uploads",
KeysPath: "data/keys",
KeysStorage: "", // "database" or "file"
EncryptionKey: "",
AppURL: "http://localhost:1411",
Port: "1411",
Host: "0.0.0.0",
UnixSocket: "",
UnixSocketMode: "",
MaxMindLicenseKey: "",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
LocalIPv6Ranges: "",
UiConfigDisabled: false,
MetricsEnabled: false,
TracingEnabled: false,
TrustProxy: false,
AnalyticsDisabled: false,
}
}
func parseEnvConfig() error {
err := env.ParseWithOptions(&EnvConfig, env.Options{})
if err != nil {
return fmt.Errorf("error parsing env config: %w", err)
}
// Validate the environment variables
switch EnvConfig.DbProvider {
case DbProviderSqlite:
if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database")
EnvConfig.DbConnectionString = defaultSqliteConnString
}
case DbProviderPostgres:
if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database")
return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
}
default:
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
}
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
if err != nil {
log.Fatal("APP_URL is not a valid URL")
return errors.New("APP_URL is not a valid URL")
}
if parsedAppUrl.Path != "" {
log.Fatal("APP_URL must not contain a path")
return errors.New("APP_URL must not contain a path")
}
switch EnvConfig.KeysStorage {
// KeysStorage defaults to "file" if empty
case "":
EnvConfig.KeysStorage = "file"
case "database":
// If KeysStorage is "database", a key must be specified
if EnvConfig.EncryptionKey == "" && EnvConfig.EncryptionKeyFile == "" {
return errors.New("ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty when KEYS_STORAGE is database")
}
case "file":
// All good, these are valid values
default:
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", EnvConfig.KeysStorage)
}
return nil
}

View File

@@ -0,0 +1,188 @@
package common
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseEnvConfig(t *testing.T) {
// Store original config to restore later
originalConfig := EnvConfig
t.Cleanup(func() {
EnvConfig = originalConfig
})
t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
})
t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
t.Setenv("APP_URL", "https://example.com")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider)
})
t.Run("should fail with invalid DB_PROVIDER", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "invalid")
t.Setenv("DB_CONNECTION_STRING", "test")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "invalid DB_PROVIDER value")
})
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "") // Explicitly empty
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString)
})
t.Run("should fail when Postgres DB_CONNECTION_STRING is missing", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres")
})
t.Run("should fail with invalid APP_URL", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "€://not-a-valid-url")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "APP_URL is not a valid URL")
})
t.Run("should fail when APP_URL contains path", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000/path")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "APP_URL must not contain a path")
})
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, "file", EnvConfig.KeysStorage)
})
t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", "database")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "ENCRYPTION_KEY or ENCRYPTION_KEY_FILE must be non-empty")
})
t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) {
validStorageTypes := []string{"file", "database"}
for _, storage := range validStorageTypes {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", storage)
if storage == "database" {
t.Setenv("ENCRYPTION_KEY", "test-key")
}
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, storage, EnvConfig.KeysStorage)
}
})
t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", "invalid")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE")
})
t.Run("should parse boolean environment variables correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("UI_CONFIG_DISABLED", "true")
t.Setenv("METRICS_ENABLED", "true")
t.Setenv("TRACING_ENABLED", "false")
t.Setenv("TRUST_PROXY", "true")
t.Setenv("ANALYTICS_DISABLED", "false")
err := parseEnvConfig()
require.NoError(t, err)
assert.True(t, EnvConfig.UiConfigDisabled)
assert.True(t, EnvConfig.MetricsEnabled)
assert.False(t, EnvConfig.TracingEnabled)
assert.True(t, EnvConfig.TrustProxy)
assert.False(t, EnvConfig.AnalyticsDisabled)
})
t.Run("should parse string environment variables correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
t.Setenv("APP_URL", "https://prod.example.com")
t.Setenv("APP_ENV", "staging")
t.Setenv("UPLOAD_PATH", "/custom/uploads")
t.Setenv("KEYS_PATH", "/custom/keys")
t.Setenv("PORT", "8080")
t.Setenv("HOST", "127.0.0.1")
t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, "staging", EnvConfig.AppEnv)
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
assert.Equal(t, "8080", EnvConfig.Port)
assert.Equal(t, "127.0.0.1", EnvConfig.Host)
})
}

View File

@@ -0,0 +1,11 @@
package model
type KV struct {
Key string `gorm:"primaryKey;not null"`
Value *string
}
// TableName overrides the table name used by KV to `kv`
func (KV) TableName() string {
return "kv"
}

View File

@@ -4,10 +4,12 @@ import (
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/stretchr/testify/require"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
@@ -22,7 +24,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
@@ -36,7 +38,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loads value from config table", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
@@ -66,7 +68,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("ignores unknown config keys", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
@@ -87,7 +89,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loading config multiple times", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
@@ -129,7 +131,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -165,7 +167,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -189,7 +191,7 @@ func TestLoadDbConfig(t *testing.T) {
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -214,7 +216,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("update multiple values", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -258,7 +260,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("empty value resets to default", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -279,7 +281,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with odd number of arguments", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -295,7 +297,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with invalid key", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -313,7 +315,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -386,7 +388,7 @@ func TestUpdateAppConfig(t *testing.T) {
})
t.Run("empty values reset to defaults", func(t *testing.T) {
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
@@ -451,7 +453,7 @@ func TestUpdateAppConfig(t *testing.T) {
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
service := &AppConfigService{
db: db,
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/fxamacker/cbor/v2"
"github.com/go-webauthn/webauthn/protocol"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"gorm.io/gorm"
@@ -25,6 +26,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
"github.com/pocket-id/pocket-id/backend/resources"
)
@@ -60,7 +62,7 @@ func (s *TestService) initExternalIdP() error {
return fmt.Errorf("failed to generate private key: %w", err)
}
s.externalIdPKey, err = utils.ImportRawKey(rawKey)
s.externalIdPKey, err = jwkutils.ImportRawKey(rawKey, jwa.ES256().String(), "")
if err != nil {
return fmt.Errorf("failed to import private key: %w", err)
}

View File

@@ -2,23 +2,20 @@ package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
const (
@@ -26,8 +23,9 @@ const (
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// RsaKeySize is the size, in bits, of the RSA key to generate if none is found
RsaKeySize = 2048
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
// This is a encrypted JSON file containing a key encoded as JWK
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
@@ -59,58 +57,74 @@ const (
)
type JwtService struct {
envConfig *common.EnvConfigSchema
privateKey jwk.Key
keyId string
appConfigService *AppConfigService
jwksEncoded []byte
}
func NewJwtService(appConfigService *AppConfigService) *JwtService {
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) *JwtService {
service := &JwtService{}
// Ensure keys are generated or loaded
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
err := service.init(db, appConfigService, &common.EnvConfig)
if err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
}
return service
}
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
s.appConfigService = appConfigService
s.envConfig = envConfig
// Ensure keys are generated or loaded
return s.loadOrGenerateKey(keysPath)
return s.loadOrGenerateKey(db)
}
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
var key jwk.Key
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
// Get the key provider
keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
return fmt.Errorf("failed to get key provider: %w", err)
}
if ok {
key, err = s.loadKeyJWK(jwkPath)
if err != nil {
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
}
// Set the key, and we are done
// Try loading a key
key, err := keyProvider.LoadKey()
if err != nil {
return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
}
// If we have a key, store it in the object and we're done
if key != nil {
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
return nil
}
// If we are here, we need to generate a new key
key, err = s.generateNewRSAKey()
err = s.generateKey()
if err != nil {
return fmt.Errorf("failed to generate key: %w", err)
}
// Save the newly-generated key
err = keyProvider.SaveKey(s.privateKey)
if err != nil {
return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
}
return nil
}
// generateKey generates a new key and stores it in the object
func (s *JwtService) generateKey() error {
// Default is to generate RS256 (RSA-2048) keys
key, err := jwkutils.GenerateKey(jwa.RS256().String(), "")
if err != nil {
return fmt.Errorf("failed to generate new private key: %w", err)
}
@@ -121,12 +135,6 @@ func (s *JwtService) loadOrGenerateKey(keysPath string) error {
return fmt.Errorf("failed to set private key: %w", err)
}
// Save the key as JWK
err = SaveKeyJWK(s.privateKey, jwkPath)
if err != nil {
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
return nil
}
@@ -192,13 +200,13 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
Subject(user.ID).
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, common.EnvConfig.AppURL)
err = SetAudienceString(token, s.envConfig.AppURL)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
@@ -229,8 +237,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithAudience(common.EnvConfig.AppURL),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithAudience(s.envConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
)
if err != nil {
@@ -246,7 +254,7 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
token, err := jwt.NewBuilder().
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return nil, fmt.Errorf("failed to build token: %w", err)
@@ -305,7 +313,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
)
@@ -335,7 +343,7 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw
Subject(user.ID).
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return nil, fmt.Errorf("failed to build token: %w", err)
@@ -377,7 +385,7 @@ func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, erro
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
)
if err != nil {
@@ -393,7 +401,7 @@ func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, r
Subject(userID).
Expiration(now.Add(RefreshTokenDuration)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
@@ -430,7 +438,7 @@ func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, client
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
)
if err != nil {
@@ -488,7 +496,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
utils.EnsureAlgInKey(pubKey)
jwkutils.EnsureAlgInKey(pubKey, "", "")
return pubKey, nil
}
@@ -517,56 +525,6 @@ func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
return alg, nil
}
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
return key, nil
}
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
}
// Import the raw key
return utils.ImportRawKey(rawKey)
}
// SaveKeyJWK saves a JWK to a file
func SaveKeyJWK(key jwk.Key, path string) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
}
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file: %w", err)
}
defer keyFile.Close()
// Write the JSON file to disk
enc := json.NewEncoder(keyFile)
enc.SetEscapeHTML(false)
err = enc.Encode(key)
if err != nil {
return fmt.Errorf("failed to write key file: %w", err)
}
return nil
}
// GetIsAdmin returns the value of the "isAdmin" claim in the token
func GetIsAdmin(token jwt.Token) (bool, error) {
if !token.Has(IsAdminClaim) {

View File

@@ -21,7 +21,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
func TestJwtService_Init(t *testing.T) {
@@ -33,9 +33,16 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Initialize the JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Verify the private key was set
@@ -66,9 +73,16 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// First create a service to generate a key
firstService := &JwtService{}
err := firstService.init(mockConfig, tempDir)
err := firstService.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Get the key ID of the first service
@@ -77,7 +91,7 @@ func TestJwtService_Init(t *testing.T) {
// Now create a new service that should load the existing key
secondService := &JwtService{}
err = secondService.init(mockConfig, tempDir)
err = secondService.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Verify the loaded key has the same ID as the original
@@ -90,12 +104,19 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create a new JWK and save it to disk
origKeyID := createECDSAKeyJWK(t, tempDir)
// Now create a new service that should load the existing key
svc := &JwtService{}
err := svc.init(mockConfig, tempDir)
err := svc.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Ensure loaded key has the right algorithm
@@ -113,12 +134,19 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create a new JWK and save it to disk
origKeyID := createEdDSAKeyJWK(t, tempDir)
// Now create a new service that should load the existing key
svc := &JwtService{}
err := svc.init(mockConfig, tempDir)
err := svc.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Ensure loaded key has the right algorithm and curve
@@ -147,9 +175,16 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create a JWT service with initialized key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
@@ -178,12 +213,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create an ECDSA key and save it as JWK
originalKeyID := createECDSAKeyJWK(t, tempDir)
// Create a JWT service that loads the ECDSA key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
@@ -216,12 +258,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create an EdDSA key and save it as JWK
originalKeyID := createEdDSAKeyJWK(t, tempDir)
// Create a JWT service that loads the EdDSA key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
@@ -276,16 +325,16 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates token for regular user", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -328,7 +377,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
t.Run("generates token for admin user", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test admin user
@@ -364,7 +413,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
})
service := &JwtService{}
err := service.init(customMockConfig, tempDir)
err := service.init(nil, customMockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -399,7 +448,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -453,7 +505,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -507,7 +562,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -563,16 +621,16 @@ func TestGenerateVerifyIdToken(t *testing.T) {
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates and verifies ID token with standard claims", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims
@@ -601,7 +659,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Check token expiration time is approximately 1 hour from now
expectedExp := time.Now().Add(1 * time.Hour)
@@ -614,7 +672,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
t.Run("can accept expired tokens if told so", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims
@@ -628,7 +686,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a token that's already expired
token, err := jwt.NewBuilder().
Subject(userClaims["sub"].(string)).
Issuer(common.EnvConfig.AppURL).
Issuer(service.envConfig.AppURL).
Audience([]string{clientID}).
IssuedAt(time.Now().Add(-2 * time.Hour)).
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
@@ -666,13 +724,13 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
})
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims with nonce
@@ -703,7 +761,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
t.Run("fails verification with incorrect issuer", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Generate a token with standard claims
@@ -714,7 +772,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
require.NoError(t, err, "Failed to generate ID token")
// Temporarily change the app URL to simulate wrong issuer
common.EnvConfig.AppURL = "https://wrong-issuer.com"
service.envConfig.AppURL = "https://wrong-issuer.com"
// Verify should fail due to issuer mismatch
_, err = service.VerifyIdToken(tokenString, false)
@@ -731,7 +789,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -762,7 +823,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Verify the key type is OKP
publicKey, err := service.GetPublicJWK()
@@ -784,7 +845,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -795,7 +859,6 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create test claims
userClaims := map[string]interface{}{
"sub": "ecdsauser456",
"name": "ECDSA User",
"email": "ecdsauser@example.com",
}
const clientID = "ecdsa-client-123"
@@ -815,7 +878,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Verify the key type is EC
publicKey, err := service.GetPublicJWK()
@@ -837,7 +900,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -868,17 +934,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, "rsauser456", subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
// Verify the key type is RSA
publicKey, err := service.GetPublicJWK()
require.NoError(t, err)
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
// Verify the algorithm is RS256
alg, ok := publicKey.Algorithm()
require.True(t, ok)
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
})
}
@@ -892,16 +948,16 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -931,7 +987,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Check token expiration time is approximately 1 hour from now
expectedExp := time.Now().Add(1 * time.Hour)
@@ -944,7 +1000,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
t.Run("fails verification for expired token", func(t *testing.T) {
// Create a JWT service with a mock function to generate an expired token
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -961,7 +1017,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
IssuedAt(time.Now().Add(-2 * time.Hour)).
Audience([]string{clientID}).
Issuer(common.EnvConfig.AppURL).
Issuer(service.envConfig.AppURL).
Build()
require.NoError(t, err, "Failed to build token")
@@ -980,11 +1036,17 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
t.Run("fails verification with invalid signature", func(t *testing.T) {
// Create two JWT services with different keys
service1 := &JwtService{}
err := service1.init(mockConfig, t.TempDir()) // Use a different temp dir
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize first JWT service")
service2 := &JwtService{}
err = service2.init(mockConfig, t.TempDir()) // Use a different temp dir
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize second JWT service")
// Create a test user
@@ -1014,7 +1076,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -1068,7 +1133,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -1122,7 +1190,10 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -1176,16 +1247,16 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
mockConfig := NewTestAppConfigService(&model.AppConfig{})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates and verifies refresh token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -1211,7 +1282,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
t.Run("fails verification for expired token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Generate a token using JWT directly to create an expired token
@@ -1220,7 +1291,7 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
IssuedAt(time.Now().Add(-2 * time.Hour)).
Audience([]string{"client123"}).
Issuer(common.EnvConfig.AppURL).
Issuer(service.envConfig.AppURL).
Build()
require.NoError(t, err, "Failed to build token")
@@ -1236,11 +1307,17 @@ func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
t.Run("fails verification with invalid signature", func(t *testing.T) {
// Create two JWT services with different keys
service1 := &JwtService{}
err := service1.init(mockConfig, t.TempDir())
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize first JWT service")
service2 := &JwtService{}
err = service2.init(mockConfig, t.TempDir())
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize second JWT service")
// Generate a token with the first service
@@ -1308,7 +1385,10 @@ func TestGetTokenType(t *testing.T) {
// Initialize the JWT service
mockConfig := NewTestAppConfigService(&model.AppConfig{})
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
@@ -1402,10 +1482,19 @@ func TestGetTokenType(t *testing.T) {
func importKey(t *testing.T, privateKeyRaw any, path string) string {
t.Helper()
privateKey, err := utils.ImportRawKey(privateKeyRaw)
privateKey, err := jwkutils.ImportRawKey(privateKeyRaw, "", "")
require.NoError(t, err, "Failed to import private key")
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
keyProvider := &jwkutils.KeyProviderFile{}
err = keyProvider.Init(jwkutils.KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: path,
},
})
require.NoError(t, err, "Failed to init file key provider")
err = keyProvider.SaveKey(privateKey)
require.NoError(t, err, "Failed to save key")
kid, _ := privateKey.KeyID()

View File

@@ -18,6 +18,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
// generateTestECDSAKey creates an ECDSA key for testing
@@ -62,12 +63,12 @@ func TestOidcService_jwkSetForURL(t *testing.T) {
)
mockResponses := map[string]*http.Response{
//nolint:bodyclose
url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
url1: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
//nolint:bodyclose
url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
url2: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
}
httpClient := &http.Client{
Transport: &MockRoundTripper{
Transport: &testutils.MockRoundTripper{
Responses: mockResponses,
},
}
@@ -139,7 +140,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
var err error
// Create a test database
db := newDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create two JWKs for testing
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
@@ -149,12 +150,12 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
// Create a mock HTTP client with custom transport to return the JWKS
httpClient := &http.Client{
Transport: &MockRoundTripper{
Transport: &testutils.MockRoundTripper{
Responses: map[string]*http.Response{
//nolint:bodyclose
federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)),
federatedClientIssuer + "/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON)),
//nolint:bodyclose
federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
federatedClientIssuerDefaults + ".well-known/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
},
},
}

View File

@@ -0,0 +1,69 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"
)
// ErrDecrypt is returned by Decrypt when the operation failed for any reason
var ErrDecrypt = errors.New("failed to decrypt data")
// Encrypt a byte slice using AES-GCM and a random nonce
// Important: do not encrypt more than ~4 billion messages with the same key!
func Encrypt(key []byte, plaintext []byte, associatedData []byte) (ciphertext []byte, err error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create block cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
_, err = io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
}
// Allocate the slice for the result, with additional space for the nonce and overhead
ciphertext = make([]byte, 0, len(plaintext)+aead.NonceSize()+aead.Overhead())
ciphertext = append(ciphertext, nonce...)
// Encrypt the plaintext
// Tag is automatically added at the end
ciphertext = aead.Seal(ciphertext, nonce, plaintext, associatedData)
return ciphertext, nil
}
// Decrypt a byte slice using AES-GCM
func Decrypt(key []byte, ciphertext []byte, associatedData []byte) (plaintext []byte, err error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create block cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
}
// Extract the nonce
if len(ciphertext) < (aead.NonceSize() + aead.Overhead()) {
return nil, ErrDecrypt
}
// Decrypt the data
plaintext, err = aead.Open(nil, ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():], associatedData)
if err != nil {
// Note: we do not return the exact error here, to avoid disclosing information
return nil, ErrDecrypt
}
return plaintext, nil
}

View File

@@ -0,0 +1,208 @@
package crypto
import (
"crypto/rand"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEncryptDecrypt(t *testing.T) {
tests := []struct {
name string
keySize int
plaintext string
associatedData []byte
}{
{
name: "AES-128 with short plaintext",
keySize: 16,
plaintext: "Hello, World!",
associatedData: []byte("test-aad"),
},
{
name: "AES-192 with medium plaintext",
keySize: 24,
plaintext: "This is a longer message to test encryption and decryption",
associatedData: []byte("associated-data-192"),
},
{
name: "AES-256 with unicode",
keySize: 32,
plaintext: "Hello 世界! 🌍 Testing unicode characters", //nolint:gosmopolitan
associatedData: []byte("unicode-test"),
},
{
name: "No associated data",
keySize: 32,
plaintext: "Testing without associated data",
associatedData: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate random key
key := make([]byte, tt.keySize)
_, err := rand.Read(key)
require.NoError(t, err, "Failed to generate random key")
plaintext := []byte(tt.plaintext)
// Test encryption
ciphertext, err := Encrypt(key, plaintext, tt.associatedData)
require.NoError(t, err, "Encrypt should succeed")
// Verify ciphertext is different from plaintext (unless empty)
if len(plaintext) > 0 {
assert.NotEqual(t, plaintext, ciphertext)
}
// Test decryption
decrypted, err := Decrypt(key, ciphertext, tt.associatedData)
require.NoError(t, err, "Decrypt should succeed")
// Verify decrypted text matches original
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original")
})
}
}
func TestEncryptWithInvalidKeySize(t *testing.T) {
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
for _, keySize := range invalidKeySizes {
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
key := make([]byte, keySize)
plaintext := []byte("test message")
_, err := Encrypt(key, plaintext, nil)
require.Error(t, err)
assert.ErrorContains(t, err, "invalid key size")
})
}
}
func TestDecryptWithInvalidKeySize(t *testing.T) {
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
for _, keySize := range invalidKeySizes {
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
key := make([]byte, keySize)
ciphertext := []byte("fake ciphertext")
_, err := Decrypt(key, ciphertext, nil)
require.Error(t, err)
assert.ErrorContains(t, err, "invalid key size")
})
}
}
func TestDecryptWithInvalidCiphertext(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err, "Failed to generate random key")
tests := []struct {
name string
ciphertext []byte
}{
{
name: "empty ciphertext",
ciphertext: []byte{},
},
{
name: "too short ciphertext",
ciphertext: []byte("short"),
},
{
name: "random invalid data",
ciphertext: []byte("this is not valid encrypted data"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Decrypt(key, tt.ciphertext, nil)
require.Error(t, err)
require.ErrorIs(t, err, ErrDecrypt)
})
}
}
func TestDecryptWithWrongKey(t *testing.T) {
// Generate two different keys
key1 := make([]byte, 32)
key2 := make([]byte, 32)
_, err := rand.Read(key1)
require.NoError(t, err)
_, err = rand.Read(key2)
require.NoError(t, err)
plaintext := []byte("secret message")
// Encrypt with key1
ciphertext, err := Encrypt(key1, plaintext, nil)
require.NoError(t, err)
// Try to decrypt with key2
_, err = Decrypt(key2, ciphertext, nil)
require.Error(t, err)
require.ErrorIs(t, err, ErrDecrypt)
}
func TestDecryptWithWrongAssociatedData(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err, "Failed to generate random key")
plaintext := []byte("secret message")
correctAAD := []byte("correct-aad")
wrongAAD := []byte("wrong-aad")
// Encrypt with correct AAD
ciphertext, err := Encrypt(key, plaintext, correctAAD)
require.NoError(t, err)
// Try to decrypt with wrong AAD
_, err = Decrypt(key, ciphertext, wrongAAD)
require.Error(t, err)
require.ErrorIs(t, err, ErrDecrypt)
// Verify correct AAD works
decrypted, err := Decrypt(key, ciphertext, correctAAD)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original when using correct AAD")
}
func TestEncryptDecryptConsistency(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
plaintext := []byte("consistency test message")
associatedData := []byte("test-aad")
// Encrypt multiple times and verify we get different ciphertexts (due to random IV)
ciphertext1, err := Encrypt(key, plaintext, associatedData)
require.NoError(t, err)
ciphertext2, err := Encrypt(key, plaintext, associatedData)
require.NoError(t, err)
// Ciphertexts should be different (due to random IV)
assert.NotEqual(t, ciphertext1, ciphertext2, "Multiple encryptions of same plaintext should produce different ciphertexts")
// Both should decrypt to the same plaintext
decrypted1, err := Decrypt(key, ciphertext1, associatedData)
require.NoError(t, err)
decrypted2, err := Decrypt(key, ciphertext2, associatedData)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted1, "First decrypted text should match original")
assert.Equal(t, plaintext, decrypted2, "Second decrypted text should match original")
assert.Equal(t, decrypted1, decrypted2, "Both decrypted texts should be identical")
}

View File

@@ -0,0 +1,50 @@
package jwk
import (
"fmt"
"github.com/lestrrat-go/jwx/v3/jwk"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
type KeyProviderOpts struct {
EnvConfig *common.EnvConfigSchema
DB *gorm.DB
Kek []byte
}
type KeyProvider interface {
Init(opts KeyProviderOpts) error
LoadKey() (jwk.Key, error)
SaveKey(key jwk.Key) error
}
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {
// Load the encryption key (KEK) if present
kek, err := LoadKeyEncryptionKey(envConfig, instanceID)
if err != nil {
return nil, fmt.Errorf("failed to load encryption key: %w", err)
}
// Get the key provider
switch envConfig.KeysStorage {
case "file", "":
keyProvider = &KeyProviderFile{}
case "database":
keyProvider = &KeyProviderDatabase{}
default:
return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
}
err = keyProvider.Init(KeyProviderOpts{
DB: db,
EnvConfig: envConfig,
Kek: kek,
})
if err != nil {
return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err)
}
return keyProvider, nil
}

View File

@@ -0,0 +1,109 @@
package jwk
import (
"context"
"encoding/base64"
"errors"
"fmt"
"time"
"github.com/lestrrat-go/jwx/v3/jwk"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/model"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
const PrivateKeyDBKey = "jwt_private_key.json"
type KeyProviderDatabase struct {
db *gorm.DB
kek []byte
}
func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
if len(opts.Kek) == 0 {
return errors.New("an encryption key is required when using the 'database' key provider")
}
f.db = opts.DB
f.kek = opts.Kek
return nil
}
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
row := model.KV{
Key: PrivateKeyDBKey,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err = f.db.WithContext(ctx).First(&row).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// Key not present in the database - return nil so a new one can be generated
return nil, nil
} else if err != nil {
return nil, fmt.Errorf("failed to retrieve private key from the database: %w", err)
}
if row.Value == nil || *row.Value == "" {
// Key not present in the database - return nil so a new one can be generated
return nil, nil
}
// Decode from base64
enc, err := base64.StdEncoding.DecodeString(*row.Value)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key: not a valid base64-encoded value: %w", err)
}
// Decrypt the data
data, err := cryptoutils.Decrypt(f.kek, enc, nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
}
// Parse the key
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse encrypted private key: %w", err)
}
return key, nil
}
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
// Encode the key to JSON
data, err := EncodeJWKBytes(key)
if err != nil {
return fmt.Errorf("failed to encode key to JSON: %w", err)
}
// Encrypt the key then encode to Base64
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
if err != nil {
return fmt.Errorf("failed to encrypt key: %w", err)
}
encB64 := base64.StdEncoding.EncodeToString(enc)
// Save to database
row := model.KV{
Key: PrivateKeyDBKey,
Value: &encB64,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err = f.db.WithContext(ctx).Create(&row).Error
if err != nil {
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
return fmt.Errorf("failed to store private key in database: %w", err)
}
return nil
}
// Compile-time interface check
var _ KeyProvider = (*KeyProviderDatabase)(nil)

View File

@@ -0,0 +1,275 @@
package jwk
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"testing"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/model"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
func TestKeyProviderDatabase_Init(t *testing.T) {
t.Run("Init fails when KEK is not provided", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: nil, // No KEK
})
require.Error(t, err, "Expected error when KEK is not provided")
require.ErrorContains(t, err, "encryption key is required")
})
t.Run("Init succeeds with KEK", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: generateTestKEK(t),
})
require.NoError(t, err, "Expected no error when KEK is provided")
})
}
func TestKeyProviderDatabase_LoadKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("LoadKey with no existing key", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
})
t.Run("LoadKey with existing key", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Save a key
err = provider.SaveKey(key)
require.NoError(t, err)
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey with invalid base64", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Insert invalid base64 data
invalidBase64 := "not-valid-base64"
err = db.Create(&model.KV{
Key: PrivateKeyDBKey,
Value: &invalidBase64,
}).Error
require.NoError(t, err)
// Attempt to load the key
loadedKey, err := provider.LoadKey()
require.Error(t, err, "Expected error when loading key with invalid base64")
require.ErrorContains(t, err, "not a valid base64-encoded value")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
t.Run("LoadKey with invalid encrypted data", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Insert valid base64 but invalid encrypted data
invalidData := base64.StdEncoding.EncodeToString([]byte("not-valid-encrypted-data"))
err = db.Create(&model.KV{
Key: PrivateKeyDBKey,
Value: &invalidData,
}).Error
require.NoError(t, err)
// Attempt to load the key
loadedKey, err := provider.LoadKey()
require.Error(t, err, "Expected error when loading key with invalid encrypted data")
require.ErrorContains(t, err, "failed to decrypt")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
t.Run("LoadKey with valid encrypted data but wrong KEK", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
originalKek := generateTestKEK(t)
// Save a key with the original KEK
originalProvider := &KeyProviderDatabase{}
err := originalProvider.Init(KeyProviderOpts{
DB: db,
Kek: originalKek,
})
require.NoError(t, err)
err = originalProvider.SaveKey(key)
require.NoError(t, err)
// Now try to load with a different KEK
differentKek := generateTestKEK(t)
differentProvider := &KeyProviderDatabase{}
err = differentProvider.Init(KeyProviderOpts{
DB: db,
Kek: differentKek,
})
require.NoError(t, err)
// Attempt to load the key with the wrong KEK
loadedKey, err := differentProvider.LoadKey()
require.Error(t, err, "Expected error when loading key with wrong KEK")
require.ErrorContains(t, err, "failed to decrypt")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
t.Run("LoadKey with invalid key data", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Create invalid key data (valid JSON but not a valid JWK)
invalidKeyData := []byte(`{"not": "a valid jwk"}`)
// Encrypt the invalid key data
encryptedData, err := cryptoutils.Encrypt(kek, invalidKeyData, nil)
require.NoError(t, err)
// Base64 encode the encrypted data
encodedData := base64.StdEncoding.EncodeToString(encryptedData)
// Save to database
err = db.Create(&model.KV{
Key: PrivateKeyDBKey,
Value: &encodedData,
}).Error
require.NoError(t, err)
// Attempt to load the key
loadedKey, err := provider.LoadKey()
require.Error(t, err, "Expected error when loading invalid key data")
require.ErrorContains(t, err, "failed to parse")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
}
func TestKeyProviderDatabase_SaveKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("SaveKey and verify database record", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Save the key
err = provider.SaveKey(key)
require.NoError(t, err, "Expected no error when saving key")
// Verify record exists in database
var kv model.KV
err = db.Where("key = ?", PrivateKeyDBKey).First(&kv).Error
require.NoError(t, err, "Expected to find key in database")
require.NotNil(t, kv.Value, "Expected non-nil value in database")
assert.NotEmpty(t, *kv.Value, "Expected non-empty value in database")
// Decode and decrypt to verify content
encBytes, err := base64.StdEncoding.DecodeString(*kv.Value)
require.NoError(t, err, "Expected valid base64 encoding")
decBytes, err := cryptoutils.Decrypt(kek, encBytes, nil)
require.NoError(t, err, "Expected valid encrypted data")
parsedKey, err := jwk.ParseKey(decBytes)
require.NoError(t, err, "Expected valid JWK data")
// Compare keys
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
})
}
func generateTestKEK(t *testing.T) []byte {
t.Helper()
// Generate a 32-byte kek
kek := make([]byte, 32)
_, err := rand.Read(kek)
require.NoError(t, err)
return kek
}

View File

@@ -0,0 +1,202 @@
package jwk
import (
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
const (
// PrivateKeyFile is the path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
// This is a encrypted JSON file containing a key encoded as JWK
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
)
type KeyProviderFile struct {
envConfig *common.EnvConfigSchema
kek []byte
}
func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
f.envConfig = opts.EnvConfig
f.kek = opts.Kek
return nil
}
func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
if len(f.kek) > 0 {
return f.loadEncryptedKey()
}
return f.loadKey()
}
func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
if len(f.kek) > 0 {
return f.saveKeyEncrypted(key)
}
return f.saveKey(key)
}
func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
var key jwk.Key
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := f.jwkPath()
ok, err := utils.FileExists(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
}
if !ok {
// File doesn't exist, no key was loaded
return nil, nil
}
data, err := os.ReadFile(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
}
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
}
return key, nil
}
func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
// First, check if we have an encrypted JWK file
// If we do, then we just load that
encJwkPath := f.encJwkPath()
ok, err := utils.FileExists(encJwkPath)
if err != nil {
return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
}
if ok {
encB64, err := os.ReadFile(encJwkPath)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
}
// Decode from base64
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
n, err := base64.StdEncoding.Decode(enc, encB64)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
}
// Decrypt the data
data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
}
// Parse the key
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
}
return key, nil
}
// Check if we have an un-encrypted JWK file
key, err = f.loadKey()
if err != nil {
return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
}
if key == nil {
// No key exists, encrypted or un-encrypted
return nil, nil
}
// If we are here, we have loaded a key that was un-encrypted
// We need to replace the plaintext key with the encrypted one before we return
err = f.saveKeyEncrypted(key)
if err != nil {
return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
}
jwkPath := f.jwkPath()
err = os.Remove(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
}
return key, nil
}
func (f *KeyProviderFile) saveKey(key jwk.Key) error {
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
}
jwkPath := f.jwkPath()
keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
}
defer keyFile.Close()
// Write the JSON file to disk
err = EncodeJWK(keyFile, key)
if err != nil {
return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
}
return nil
}
func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
}
// Encode the key to JSON
data, err := EncodeJWKBytes(key)
if err != nil {
return fmt.Errorf("failed to encode key to JSON: %w", err)
}
// Encrypt the key then encode to Base64
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
if err != nil {
return fmt.Errorf("failed to encrypt key: %w", err)
}
encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
base64.StdEncoding.Encode(encB64, enc)
// Write to disk
encJwkPath := f.encJwkPath()
err = os.WriteFile(encJwkPath, encB64, 0600)
if err != nil {
return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
}
return nil
}
func (f *KeyProviderFile) jwkPath() string {
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
}
func (f *KeyProviderFile) encJwkPath() string {
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
}
// Compile-time interface check
var _ KeyProvider = (*KeyProviderFile)(nil)

View File

@@ -0,0 +1,320 @@
package jwk
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"os"
"path/filepath"
"testing"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
func TestKeyProviderFile_LoadKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("LoadKey with no existing key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
})
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: makeKEK(t),
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
})
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save a key
err = provider.SaveKey(key)
require.NoError(t, err)
// Make sure the key file exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected key file to exist")
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey with encrypted key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: makeKEK(t),
})
require.NoError(t, err)
// Save a key (will be encrypted)
err = provider.SaveKey(key)
require.NoError(t, err)
// Make sure the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err := utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist")
// Make sure the unencrypted key file does not exist
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to not exist")
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
tempDir := t.TempDir()
// First, create an unencrypted key
providerNoKek := &KeyProviderFile{}
err := providerNoKek.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save an unencrypted key
err = providerNoKek.SaveKey(key)
require.NoError(t, err)
// Verify unencrypted key exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected unencrypted key file to exist")
// Now create a provider with a kek
kek := make([]byte, 32)
_, err = rand.Read(kek)
require.NoError(t, err)
providerWithKek := &KeyProviderFile{}
err = providerWithKek.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: kek,
})
require.NoError(t, err)
// Load the key - this should convert the unencrypted key to encrypted
loadedKey, err := providerWithKek.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")
// Verify the unencrypted key no longer exists
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to be removed")
// Verify the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err = utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist after conversion")
// Verify the key data
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
})
}
func TestKeyProviderFile_SaveKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("SaveKey unencrypted", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save the key
err = provider.SaveKey(key)
require.NoError(t, err)
// Verify the key file exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected key file to exist")
// Verify the content of the key file
data, err := os.ReadFile(keyPath)
require.NoError(t, err)
parsedKey, err := jwk.ParseKey(data)
require.NoError(t, err)
// Compare the saved key with the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
})
t.Run("SaveKey encrypted", func(t *testing.T) {
tempDir := t.TempDir()
// Generate a 64-byte kek
kek := makeKEK(t)
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: kek,
})
require.NoError(t, err)
// Save the key (will be encrypted)
err = provider.SaveKey(key)
require.NoError(t, err)
// Verify the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err := utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist")
// Verify the unencrypted key file doesn't exist
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to not exist")
// Manually decrypt the encrypted key file to verify it contains the correct key
encB64, err := os.ReadFile(encKeyPath)
require.NoError(t, err)
// Decode from base64
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
n, err := base64.StdEncoding.Decode(enc, encB64)
require.NoError(t, err)
enc = enc[:n] // Trim any padding
// Decrypt the data
data, err := cryptoutils.Decrypt(kek, enc, nil)
require.NoError(t, err)
// Parse the key
parsedKey, err := jwk.ParseKey(data)
require.NoError(t, err)
// Compare the decrypted key with the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
})
}
func makeKEK(t *testing.T) []byte {
t.Helper()
// Generate a 32-byte kek
kek := make([]byte, 32)
_, err := rand.Read(kek)
require.NoError(t, err)
return kek
}

View File

@@ -0,0 +1,180 @@
package jwk
import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/sha3"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"os"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
const (
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
)
// EncodeJWK encodes a jwk.Key to a writable stream.
func EncodeJWK(w io.Writer, key jwk.Key) error {
enc := json.NewEncoder(w)
enc.SetEscapeHTML(false)
return enc.Encode(key)
}
// EncodeJWKBytes encodes a jwk.Key to a byte slice.
func EncodeJWKBytes(key jwk.Key) ([]byte, error) {
b := &bytes.Buffer{}
err := EncodeJWK(b, key)
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
// LoadKeyEncryptionKey loads the key encryption key for JWKs
func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) {
// Try getting the key from the env var as string
kekInput := []byte(envConfig.EncryptionKey)
// If there's nothing in the env, try loading from file
if len(kekInput) == 0 && envConfig.EncryptionKeyFile != "" {
kekInput, err = os.ReadFile(envConfig.EncryptionKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to read key file '%s': %w", envConfig.EncryptionKeyFile, err)
}
}
// If there's still no key, return
if len(kekInput) == 0 {
return nil, nil
}
// We need a 256-bit key for encryption with AES-GCM-256
// We use HMAC with SHA3-256 here to derive the key from the one passed as input
// The key is tied to a specific instance of Pocket ID
h := hmac.New(func() hash.Hash { return sha3.New256() }, kekInput)
fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek")
kek = h.Sum(nil)
return kek, nil
}
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
// It also populates additional fields such as the key ID, usage, and alg.
func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key, alg, crv)
return key, nil
}
// generateRandomKeyID generates a random key ID.
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type
func EnsureAlgInKey(key jwk.Key, alg string, crv string) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
if alg != "" {
_ = key.Set(jwk.AlgorithmKey, alg)
if crv != "" {
eca, ok := jwa.LookupEllipticCurveAlgorithm(crv)
if ok {
switch key.KeyType() {
case jwa.EC():
_ = key.Set(jwk.ECDSACrvKey, eca)
case jwa.OKP():
_ = key.Set(jwk.OKPCrvKey, eca)
}
}
}
return
}
// If we don't have an algorithm, set the default for the key type
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
_ = key.Set(jwk.ECDSACrvKey, jwa.P256())
case jwa.OKP():
// Default to EdDSA and Ed25519 for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
_ = key.Set(jwk.OKPCrvKey, jwa.Ed25519())
}
}
// GenerateKey generates a new jwk.Key
func GenerateKey(alg string, crv string) (key jwk.Key, err error) {
var rawKey any
switch alg {
case jwa.RS256().String():
rawKey, err = rsa.GenerateKey(rand.Reader, 2048)
case jwa.RS384().String():
rawKey, err = rsa.GenerateKey(rand.Reader, 3072)
case jwa.RS512().String():
rawKey, err = rsa.GenerateKey(rand.Reader, 4096)
case jwa.ES256().String():
rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case jwa.ES384().String():
rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case jwa.ES512().String():
rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
case jwa.EdDSA().String():
switch crv {
case jwa.Ed25519().String():
_, rawKey, err = ed25519.GenerateKey(rand.Reader)
default:
return nil, errors.New("unsupported curve for EdDSA algorithm")
}
default:
return nil, errors.New("unsupported key algorithm")
}
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
// Import the raw key
return ImportRawKey(rawKey, alg, crv)
}

View File

@@ -0,0 +1,324 @@
package jwk
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"testing"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGenerateKey(t *testing.T) {
tests := []struct {
name string
alg string
crv string
expectError bool
expectedAlg jwa.SignatureAlgorithm
}{
{
name: "RS256",
alg: jwa.RS256().String(),
crv: "",
expectError: false,
expectedAlg: jwa.RS256(),
},
{
name: "RS384",
alg: jwa.RS384().String(),
crv: "",
expectError: false,
expectedAlg: jwa.RS384(),
},
// Skip the RS512 test as generating a RSA-4096 key can take some time
/* {
name: "RS512",
alg: jwa.RS512().String(),
crv: "",
expectError: false,
expectedAlg: jwa.RS512(),
}, */
{
name: "ES256",
alg: jwa.ES256().String(),
crv: jwa.P256().String(),
expectError: false,
expectedAlg: jwa.ES256(),
},
{
name: "ES384",
alg: jwa.ES384().String(),
crv: jwa.P384().String(),
expectError: false,
expectedAlg: jwa.ES384(),
},
{
name: "ES512",
alg: jwa.ES512().String(),
crv: jwa.P521().String(),
expectError: false,
expectedAlg: jwa.ES512(),
},
{
name: "EdDSA with Ed25519",
alg: jwa.EdDSA().String(),
crv: jwa.Ed25519().String(),
expectError: false,
expectedAlg: jwa.EdDSA(),
},
{
name: "EdDSA with unsupported curve",
alg: jwa.EdDSA().String(),
crv: "unsupported",
expectError: true,
},
{
name: "Unsupported algorithm",
alg: "UNSUPPORTED",
crv: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
key, err := GenerateKey(tt.alg, tt.crv)
if tt.expectError {
require.Error(t, err)
assert.Nil(t, key)
return
}
require.NoError(t, err)
require.NotNil(t, key)
// Verify the algorithm is set correctly
alg, ok := key.Algorithm()
require.True(t, ok, "algorithm should be set in the key")
assert.Equal(t, tt.expectedAlg.String(), alg.String())
// Verify other required fields are set
kid, ok := key.KeyID()
assert.True(t, ok, "key ID should be set")
assert.NotEmpty(t, kid, "key ID should not be empty")
usage, ok := key.KeyUsage()
assert.True(t, ok, "key usage should be set")
assert.Equal(t, KeyUsageSigning, usage)
var crv any
_ = key.Get("crv", &crv)
// Verify key type matches expected algorithm
switch tt.expectedAlg {
case jwa.RS256(), jwa.RS384(), jwa.RS512():
assert.Equal(t, jwa.RSA(), key.KeyType())
assert.Nil(t, crv)
case jwa.ES256(), jwa.ES384(), jwa.ES512():
assert.Equal(t, jwa.EC(), key.KeyType())
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
_ = assert.NotNil(t, crv) &&
assert.True(t, ok) &&
assert.Equal(t, tt.crv, eca.String())
case jwa.EdDSA():
assert.Equal(t, jwa.OKP(), key.KeyType())
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
_ = assert.NotNil(t, crv) &&
assert.True(t, ok) &&
assert.Equal(t, tt.crv, eca.String())
}
})
}
}
func TestEnsureAlgInKey(t *testing.T) {
// Generate an RSA-2048 key
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
t.Run("does not change alg already set", func(t *testing.T) {
// Import the RSA key
key, err := jwk.Import(rsaKey)
require.NoError(t, err)
// Pre-set the algorithm
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
// Call EnsureAlgInKey with a different algorithm
EnsureAlgInKey(key, jwa.RS384().String(), "")
// Verify the algorithm wasn't changed
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, jwa.RS256().String(), alg.String())
})
t.Run("set algorithm to explicitly-provided value", func(t *testing.T) {
tests := []struct {
name string
keyGen func() (any, error)
alg string
crv string
expectedAlg jwa.SignatureAlgorithm
expectedCrv string
}{
{
name: "RSA key with RS384",
keyGen: func() (any, error) {
return rsaKey, nil
},
alg: jwa.RS384().String(),
crv: "",
expectedAlg: jwa.RS384(),
expectedCrv: "",
},
{
name: "ECDSA key with ES384",
keyGen: func() (any, error) {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
},
alg: jwa.ES384().String(),
crv: jwa.P384().String(),
expectedAlg: jwa.ES384(),
expectedCrv: jwa.P384().String(),
},
{
name: "Ed25519 key with EdDSA",
keyGen: func() (any, error) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
return priv, err
},
alg: jwa.EdDSA().String(),
crv: jwa.Ed25519().String(),
expectedAlg: jwa.EdDSA(),
expectedCrv: jwa.Ed25519().String(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rawKey, err := tt.keyGen()
require.NoError(t, err)
key, err := jwk.Import(rawKey)
require.NoError(t, err)
// Ensure no algorithm is set initially
_, ok := key.Algorithm()
assert.False(t, ok)
// Call EnsureAlgInKey
EnsureAlgInKey(key, tt.alg, tt.crv)
// Verify the algorithm was set correctly
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, tt.expectedAlg.String(), alg.String())
// Verify curve if expected
if tt.expectedCrv != "" {
var crv any
_ = key.Get("crv", &crv)
require.NotNil(t, crv)
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
require.True(t, ok)
assert.Equal(t, tt.expectedCrv, eca.String())
}
})
}
})
t.Run("set default algorithms if not present", func(t *testing.T) {
tests := []struct {
name string
keyGen func() (any, error)
expectedAlg jwa.SignatureAlgorithm
expectedCrv string
}{
{
name: "RSA key defaults to RS256",
keyGen: func() (any, error) {
return rsaKey, nil
},
expectedAlg: jwa.RS256(),
expectedCrv: "",
},
{
name: "ECDSA key defaults to ES256 with P256",
keyGen: func() (any, error) {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
},
expectedAlg: jwa.ES256(),
expectedCrv: jwa.P256().String(),
},
{
name: "Ed25519 key defaults to EdDSA with Ed25519",
keyGen: func() (any, error) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
return priv, err
},
expectedAlg: jwa.EdDSA(),
expectedCrv: jwa.Ed25519().String(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rawKey, err := tt.keyGen()
require.NoError(t, err)
key, err := jwk.Import(rawKey)
require.NoError(t, err)
// Ensure no algorithm is set initially
_, ok := key.Algorithm()
assert.False(t, ok)
// Call EnsureAlgInKey with empty parameters
EnsureAlgInKey(key, "", "")
// Verify the default algorithm was set
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, tt.expectedAlg.String(), alg.String())
// Verify curve if expected
if tt.expectedCrv != "" {
var crv any
_ = key.Get("crv", &crv)
require.NotNil(t, crv)
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
require.True(t, ok)
assert.Equal(t, tt.expectedCrv, eca.String())
}
})
}
})
t.Run("invalid curve should not set curve parameter", func(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
key, err := jwk.Import(rsaKey)
require.NoError(t, err)
// Call EnsureAlgInKey with invalid curve
EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve")
// Verify algorithm was set but curve was not
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, jwa.RS256().String(), alg.String())
var crv any
_ = key.Get("crv", &crv)
assert.Nil(t, crv)
})
}

View File

@@ -1,69 +0,0 @@
package utils
import (
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
)
const (
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
)
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
// It also populates additional fields such as the key ID, usage, and alg.
func ImportRawKey(rawKey any) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key)
return key, nil
}
// generateRandomKeyID generates a random key ID.
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
func EnsureAlgInKey(key jwk.Key) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
case jwa.OKP():
// Default to EdDSA for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
}
}

View File

@@ -1,9 +1,8 @@
package service
// This file is only imported by unit tests
package testing
import (
"io"
"net/http"
"strings"
"testing"
"time"
@@ -21,7 +20,10 @@ import (
"github.com/pocket-id/pocket-id/backend/resources"
)
func newDatabaseForTest(t *testing.T) *gorm.DB {
// NewDatabaseForTest returns a new instance of GORM connected to an in-memory SQLite database.
// Each database connection is unique for the test.
// All migrations are automatically performed.
func NewDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
@@ -68,30 +70,3 @@ type testLoggerAdapter struct {
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
type MockRoundTripper struct {
Err error
Responses map[string]*http.Response
}
// RoundTrip implements the http.RoundTripper interface
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Check if we have a specific response for this URL
for url, resp := range m.Responses {
if req.URL.String() == url {
return resp, nil
}
}
return NewMockResponse(http.StatusNotFound, ""), nil
}
// NewMockResponse creates an http.Response with the given status code and body
func NewMockResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}
}

View File

@@ -0,0 +1,38 @@
// This file is only imported by unit tests
package testing
import (
"io"
"net/http"
"strings"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
// MockRoundTripper is a custom http.RoundTripper that returns responses based on the URL
type MockRoundTripper struct {
Err error
Responses map[string]*http.Response
}
// RoundTrip implements the http.RoundTripper interface
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Check if we have a specific response for this URL
for url, resp := range m.Responses {
if req.URL.String() == url {
return resp, nil
}
}
return NewMockResponse(http.StatusNotFound, ""), nil
}
// NewMockResponse creates an http.Response with the given status code and body
func NewMockResponse(statusCode int, body string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}
}

View File

@@ -0,0 +1 @@
DROP TABLE kv;

View File

@@ -0,0 +1,6 @@
-- The "kv" tables contains miscellaneous key-value pairs
CREATE TABLE kv
(
"key" TEXT NOT NULL PRIMARY KEY,
"value" TEXT
);

View File

@@ -0,0 +1 @@
DROP TABLE kv;

View File

@@ -0,0 +1,6 @@
-- The "kv" tables contains miscellaneous key-value pairs
CREATE TABLE kv
(
"key" TEXT NOT NULL PRIMARY KEY,
"value" TEXT NOT NULL
);