mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-12 00:03:00 +03:00
refactor!: remove old DB env variables, and jwk migrations logic (#529)
This commit is contained in:
@@ -20,10 +20,6 @@ func Bootstrap() error {
|
||||
|
||||
initApplicationImages()
|
||||
|
||||
// Perform migrations for changes
|
||||
migrateConfigDBConnstring()
|
||||
migrateKey()
|
||||
|
||||
// Initialize the tracer and metrics exporter
|
||||
shutdownFns, httpClient, err := initOtel(ctx, common.EnvConfig.MetricsEnabled, common.EnvConfig.TracingEnabled)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
)
|
||||
|
||||
// Performs the migration of the database connection string
|
||||
// See: https://github.com/pocket-id/pocket-id/pull/388
|
||||
func migrateConfigDBConnstring() {
|
||||
switch common.EnvConfig.DbProvider {
|
||||
case common.DbProviderSqlite:
|
||||
// Check if we're using the deprecated SqliteDBPath env var
|
||||
if common.EnvConfig.SqliteDBPath != "" {
|
||||
connString := "file:" + common.EnvConfig.SqliteDBPath + "?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate"
|
||||
common.EnvConfig.DbConnectionString = connString
|
||||
common.EnvConfig.SqliteDBPath = ""
|
||||
|
||||
log.Printf("[WARN] Env var 'SQLITE_DB_PATH' is deprecated - use 'DB_CONNECTION_STRING' instead with the value: '%s'", connString)
|
||||
}
|
||||
case common.DbProviderPostgres:
|
||||
// Check if we're using the deprecated PostgresConnectionString alias
|
||||
if common.EnvConfig.PostgresConnectionString != "" {
|
||||
common.EnvConfig.DbConnectionString = common.EnvConfig.PostgresConnectionString
|
||||
common.EnvConfig.PostgresConnectionString = ""
|
||||
|
||||
log.Print("[WARN] Env var 'POSTGRES_CONNECTION_STRING' is deprecated - use 'DB_CONNECTION_STRING' instead with the same value")
|
||||
}
|
||||
default:
|
||||
// We don't do anything here in the default case
|
||||
// This is an error, but will be handled later on
|
||||
}
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
privateKeyFilePem = "jwt_private_key.pem"
|
||||
)
|
||||
|
||||
func migrateKey() {
|
||||
err := migrateKeyInternal(common.EnvConfig.KeysPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to perform migration of keys: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func migrateKeyInternal(basePath string) error {
|
||||
// First, check if there's already a JWK stored
|
||||
jwkPath := filepath.Join(basePath, service.PrivateKeyFile)
|
||||
ok, err := utils.FileExists(jwkPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
if ok {
|
||||
// There's already a key as JWK, so we don't do anything else here
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if there's a PEM file
|
||||
pemPath := filepath.Join(basePath, privateKeyFilePem)
|
||||
ok, err = utils.FileExists(pemPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if private key file (PEM) exists at path '%s': %w", pemPath, err)
|
||||
}
|
||||
if !ok {
|
||||
// No file to migrate, return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load and validate the key
|
||||
key, err := loadKeyPEM(pemPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load private key file (PEM) at path '%s': %w", pemPath, err)
|
||||
}
|
||||
err = service.ValidateKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("key object is invalid: %w", err)
|
||||
}
|
||||
|
||||
// Save the key as JWK
|
||||
err = service.SaveKeyJWK(key, jwkPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
|
||||
}
|
||||
|
||||
// Finally, delete the PEM file
|
||||
err = os.Remove(pemPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove migrated key at path '%s': %w", pemPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadKeyPEM(path string) (jwk.Key, error) {
|
||||
// Load the key from disk and parse it
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read key data: %w", err)
|
||||
}
|
||||
|
||||
key, err := jwk.ParseKey(data, jwk.WithPEM(true))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse key: %w", err)
|
||||
}
|
||||
|
||||
// Populate the key ID using the "legacy" algorithm
|
||||
keyId, err := generateKeyID(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
||||
}
|
||||
err = key.Set(jwk.KeyIDKey, keyId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set key ID: %w", err)
|
||||
}
|
||||
|
||||
// Populate other required fields
|
||||
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
|
||||
service.EnsureAlgInKey(key)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// generateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key's PKIX-serialized structure.
|
||||
// This is used for legacy keys, imported from PEM.
|
||||
func generateKeyID(key jwk.Key) (string, error) {
|
||||
// Export the public key and serialize it to PKIX (not in a PEM block)
|
||||
// This is for backwards-compatibility with the algorithm used before the switch to JWK
|
||||
pubKey, err := key.PublicKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
var pubKeyRaw any
|
||||
err = jwk.Export(pubKey, &pubKeyRaw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to export public key: %w", err)
|
||||
}
|
||||
pubASN1, err := x509.MarshalPKIXPublicKey(pubKeyRaw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
// Compute SHA-256 hash of the public key
|
||||
hash := sha256.New()
|
||||
hash.Write(pubASN1)
|
||||
hashed := hash.Sum(nil)
|
||||
|
||||
// Truncate the hash to the first 8 bytes for a shorter Key ID
|
||||
shortHash := hashed[:8]
|
||||
|
||||
// Return Base64 encoded truncated hash as Key ID
|
||||
return base64.RawURLEncoding.EncodeToString(shortHash), nil
|
||||
}
|
||||
@@ -1,190 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
func TestMigrateKey(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir := t.TempDir()
|
||||
|
||||
t.Run("no keys exist", func(t *testing.T) {
|
||||
// Test when no keys exist
|
||||
err := migrateKeyInternal(tempDir)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("jwk already exists", func(t *testing.T) {
|
||||
// Create a JWK file
|
||||
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
|
||||
key, err := createTestRSAKey()
|
||||
require.NoError(t, err)
|
||||
err = service.SaveKeyJWK(key, jwkPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run migration - should do nothing
|
||||
err = migrateKeyInternal(tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the file still exists
|
||||
exists, err := utils.FileExists(jwkPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete for next test
|
||||
err = os.Remove(jwkPath)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("migrate pem to jwk", func(t *testing.T) {
|
||||
// Create a PEM file
|
||||
pemPath := filepath.Join(tempDir, privateKeyFilePem)
|
||||
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
|
||||
|
||||
// Generate RSA key and save as PEM
|
||||
createRSAPrivateKeyPEM(t, pemPath)
|
||||
|
||||
// Run migration
|
||||
err := migrateKeyInternal(tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check PEM file is gone
|
||||
exists, err := utils.FileExists(pemPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Check JWK file exists
|
||||
exists, err = utils.FileExists(jwkPath)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Verify the JWK can be loaded
|
||||
data, err := os.ReadFile(jwkPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = jwk.ParseKey(data)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadKeyPEM(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tempDir := t.TempDir()
|
||||
|
||||
t.Run("successfully load PEM key", func(t *testing.T) {
|
||||
pemPath := filepath.Join(tempDir, "test_key.pem")
|
||||
|
||||
// Generate RSA key and save as PEM
|
||||
createRSAPrivateKeyPEM(t, pemPath)
|
||||
|
||||
// Load the key
|
||||
key, err := loadKeyPEM(pemPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify key properties
|
||||
assert.NotEmpty(t, key)
|
||||
|
||||
// Check key ID is set
|
||||
var keyID string
|
||||
err = key.Get(jwk.KeyIDKey, &keyID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, keyID)
|
||||
|
||||
// Check algorithm is set
|
||||
var alg jwa.SignatureAlgorithm
|
||||
err = key.Get(jwk.AlgorithmKey, &alg)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, alg)
|
||||
|
||||
// Check key usage is set
|
||||
var keyUsage string
|
||||
err = key.Get(jwk.KeyUsageKey, &keyUsage)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, service.KeyUsageSigning, keyUsage)
|
||||
})
|
||||
|
||||
t.Run("file not found", func(t *testing.T) {
|
||||
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, key)
|
||||
})
|
||||
|
||||
t.Run("invalid file content", func(t *testing.T) {
|
||||
invalidPath := filepath.Join(tempDir, "invalid.pem")
|
||||
err := os.WriteFile(invalidPath, []byte("not a valid PEM"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := loadKeyPEM(invalidPath)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateKeyID(t *testing.T) {
|
||||
key, err := createTestRSAKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
keyID, err := generateKeyID(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Key ID should be non-empty
|
||||
assert.NotEmpty(t, keyID)
|
||||
|
||||
// Generate another key ID to prove it depends on the key
|
||||
key2, err := createTestRSAKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
keyID2, err := generateKeyID(key2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The two key IDs should be different
|
||||
assert.NotEqual(t, keyID, keyID2)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestRSAKey() (jwk.Key, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := jwk.Import(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// createRSAPrivateKeyPEM generates an RSA private key and returns its PEM-encoded form
|
||||
func createRSAPrivateKeyPEM(t *testing.T, pemPath string) ([]byte, *rsa.PrivateKey) {
|
||||
// Generate RSA key
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encode to PEM format
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
|
||||
})
|
||||
|
||||
err = os.WriteFile(pemPath, pemData, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
return pemData, privKey
|
||||
}
|
||||
@@ -24,41 +24,37 @@ const (
|
||||
)
|
||||
|
||||
type EnvConfigSchema struct {
|
||||
AppEnv string `env:"APP_ENV"`
|
||||
AppURL string `env:"PUBLIC_APP_URL"`
|
||||
DbProvider DbProvider `env:"DB_PROVIDER"`
|
||||
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
||||
SqliteDBPath string `env:"SQLITE_DB_PATH"` // Deprecated: use "DB_CONNECTION_STRING" instead
|
||||
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"` // Deprecated: use "DB_CONNECTION_STRING" instead
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
KeysPath string `env:"KEYS_PATH"`
|
||||
Port string `env:"BACKEND_PORT"`
|
||||
Host string `env:"HOST"`
|
||||
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
|
||||
GeoLiteDBPath string `env:"GEOLITE_DB_PATH"`
|
||||
GeoLiteDBUrl string `env:"GEOLITE_DB_URL"`
|
||||
UiConfigDisabled bool `env:"PUBLIC_UI_CONFIG_DISABLED"`
|
||||
MetricsEnabled bool `env:"METRICS_ENABLED"`
|
||||
TracingEnabled bool `env:"TRACING_ENABLED"`
|
||||
AppEnv string `env:"APP_ENV"`
|
||||
AppURL string `env:"PUBLIC_APP_URL"`
|
||||
DbProvider DbProvider `env:"DB_PROVIDER"`
|
||||
DbConnectionString string `env:"DB_CONNECTION_STRING"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
KeysPath string `env:"KEYS_PATH"`
|
||||
Port string `env:"BACKEND_PORT"`
|
||||
Host string `env:"HOST"`
|
||||
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
|
||||
GeoLiteDBPath string `env:"GEOLITE_DB_PATH"`
|
||||
GeoLiteDBUrl string `env:"GEOLITE_DB_URL"`
|
||||
UiConfigDisabled bool `env:"PUBLIC_UI_CONFIG_DISABLED"`
|
||||
MetricsEnabled bool `env:"METRICS_ENABLED"`
|
||||
TracingEnabled bool `env:"TRACING_ENABLED"`
|
||||
}
|
||||
|
||||
var EnvConfig = &EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DbProvider: "sqlite",
|
||||
DbConnectionString: "file:data/pocket-id.db?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate",
|
||||
SqliteDBPath: "",
|
||||
PostgresConnectionString: "",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
AppURL: "http://localhost",
|
||||
Port: "8080",
|
||||
Host: "0.0.0.0",
|
||||
MaxMindLicenseKey: "",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
UiConfigDisabled: false,
|
||||
MetricsEnabled: false,
|
||||
TracingEnabled: false,
|
||||
AppEnv: "production",
|
||||
DbProvider: "sqlite",
|
||||
DbConnectionString: "file:data/pocket-id.db?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
AppURL: "http://localhost",
|
||||
Port: "8080",
|
||||
Host: "0.0.0.0",
|
||||
MaxMindLicenseKey: "",
|
||||
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
|
||||
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
|
||||
UiConfigDisabled: false,
|
||||
MetricsEnabled: false,
|
||||
TracingEnabled: false,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
Reference in New Issue
Block a user