mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-12 08:12:57 +03:00
137 lines
3.7 KiB
Go
137 lines
3.7 KiB
Go
package bootstrap
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
)
|
|
|
|
// privateKeyFilePem is the file name of the legacy PEM-encoded private key.
// migrateKeyInternal looks for it under the keys base path and converts it
// to JWK format.
const privateKeyFilePem = "jwt_private_key.pem"
|
|
|
|
func migrateKey() {
|
|
err := migrateKeyInternal(common.EnvConfig.KeysPath)
|
|
if err != nil {
|
|
log.Fatalf("failed to perform migration of keys: %v", err)
|
|
}
|
|
}
|
|
|
|
func migrateKeyInternal(basePath string) error {
|
|
// First, check if there's already a JWK stored
|
|
jwkPath := filepath.Join(basePath, service.PrivateKeyFile)
|
|
ok, err := utils.FileExists(jwkPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
|
|
}
|
|
if ok {
|
|
// There's already a key as JWK, so we don't do anything else here
|
|
return nil
|
|
}
|
|
|
|
// Check if there's a PEM file
|
|
pemPath := filepath.Join(basePath, privateKeyFilePem)
|
|
ok, err = utils.FileExists(pemPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check if private key file (PEM) exists at path '%s': %w", pemPath, err)
|
|
}
|
|
if !ok {
|
|
// No file to migrate, return
|
|
return nil
|
|
}
|
|
|
|
// Load and validate the key
|
|
key, err := loadKeyPEM(pemPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load private key file (PEM) at path '%s': %w", pemPath, err)
|
|
}
|
|
err = service.ValidateKey(key)
|
|
if err != nil {
|
|
return fmt.Errorf("key object is invalid: %w", err)
|
|
}
|
|
|
|
// Save the key as JWK
|
|
err = service.SaveKeyJWK(key, jwkPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
|
|
}
|
|
|
|
// Finally, delete the PEM file
|
|
err = os.Remove(pemPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to remove migrated key at path '%s': %w", pemPath, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func loadKeyPEM(path string) (jwk.Key, error) {
|
|
// Load the key from disk and parse it
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read key data: %w", err)
|
|
}
|
|
|
|
key, err := jwk.ParseKey(data, jwk.WithPEM(true))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse key: %w", err)
|
|
}
|
|
|
|
// Populate the key ID using the "legacy" algorithm
|
|
keyId, err := generateKeyID(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate key ID: %w", err)
|
|
}
|
|
err = key.Set(jwk.KeyIDKey, keyId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to set key ID: %w", err)
|
|
}
|
|
|
|
// Populate other required fields
|
|
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
|
|
service.EnsureAlgInKey(key)
|
|
|
|
return key, nil
|
|
}
|
|
|
|
// generateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key's PKIX-serialized structure.
|
|
// This is used for legacy keys, imported from PEM.
|
|
func generateKeyID(key jwk.Key) (string, error) {
|
|
// Export the public key and serialize it to PKIX (not in a PEM block)
|
|
// This is for backwards-compatibility with the algorithm used before the switch to JWK
|
|
pubKey, err := key.PublicKey()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to get public key: %w", err)
|
|
}
|
|
var pubKeyRaw any
|
|
err = jwk.Export(pubKey, &pubKeyRaw)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to export public key: %w", err)
|
|
}
|
|
pubASN1, err := x509.MarshalPKIXPublicKey(pubKeyRaw)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to marshal public key: %w", err)
|
|
}
|
|
|
|
// Compute SHA-256 hash of the public key
|
|
hash := sha256.New()
|
|
hash.Write(pubASN1)
|
|
hashed := hash.Sum(nil)
|
|
|
|
// Truncate the hash to the first 8 bytes for a shorter Key ID
|
|
shortHash := hashed[:8]
|
|
|
|
// Return Base64 encoded truncated hash as Key ID
|
|
return base64.RawURLEncoding.EncodeToString(shortHash), nil
|
|
}
|