feat: store keys as JWK on disk (#339)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-03-18 13:08:33 -07:00
committed by GitHub
parent e9b2d981b7
commit a7c9741802
10 changed files with 1207 additions and 174 deletions

View File

@@ -3,14 +3,12 @@ package service
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/big"
"os"
"path/filepath"
"slices"
@@ -18,80 +16,148 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
const (
privateKeyFile = "jwt_private_key.pem"
// Path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// Size, in bits, of the RSA key to generate if none is found
RsaKeySize = 2048
// Usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
)
type JwtService struct {
privateKey *rsa.PrivateKey
privateKey jwk.Key
keyId string
appConfigService *AppConfigService
jwksEncoded []byte
}
func NewJwtService(appConfigService *AppConfigService) *JwtService {
service := &JwtService{
appConfigService: appConfigService,
}
service := &JwtService{}
// Ensure keys are generated or loaded
if err := service.loadOrGenerateKey(common.EnvConfig.KeysPath); err != nil {
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
}
return service
}
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
s.appConfigService = appConfigService
// Ensure keys are generated or loaded
return s.loadOrGenerateKey(keysPath)
}
type AccessTokenJWTClaims struct {
jwt.RegisteredClaims
IsAdmin bool `json:"isAdmin,omitempty"`
}
type JWK struct {
Kid string `json:"kid"`
Kty string `json:"kty"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
}
// loadOrGenerateKey loads RSA keys from the given paths or generates them if they do not exist.
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
var key jwk.Key
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKey(keysPath); err != nil {
return fmt.Errorf("can't generate key: %w", err)
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
}
if ok {
key, err = s.loadKeyJWK(jwkPath)
if err != nil {
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
}
// Set the key, and we are done
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
return nil
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
// If we are here, we need to generate a new key
key, err = s.generateNewRSAKey()
if err != nil {
return fmt.Errorf("can't read jwt private key: %w", err)
}
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return fmt.Errorf("can't parse jwt private key: %w", err)
return fmt.Errorf("failed to generate new private key: %w", err)
}
err = s.SetKey(privateKey)
// Set the key in the object, which also validates it
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
// Save the key as JWK
err = SaveKeyJWK(s.privateKey, jwkPath)
if err != nil {
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
return nil
}
func (s *JwtService) SetKey(privateKey *rsa.PrivateKey) (err error) {
func ValidateKey(privateKey jwk.Key) error {
// Validate the loaded key
err := privateKey.Validate()
if err != nil {
return fmt.Errorf("key object is invalid: %w", err)
}
keyID, ok := privateKey.KeyID()
if !ok || keyID == "" {
return errors.New("key object does not contain a key ID")
}
usage, ok := privateKey.KeyUsage()
if !ok || usage != KeyUsageSigning {
return errors.New("key object is not valid for signing")
}
ok, err = jwk.IsPrivateKey(privateKey)
if err != nil || !ok {
return errors.New("key object is not a private key")
}
return nil
}
func (s *JwtService) SetKey(privateKey jwk.Key) error {
// Validate the loaded key
err := ValidateKey(privateKey)
if err != nil {
return fmt.Errorf("private key is not valid: %w", err)
}
// Set the private key in the object
s.privateKey = privateKey
s.keyId, err = s.generateKeyID()
// Create and encode a JWKS containing the public key
publicKey, err := s.GetPublicJWK()
if err != nil {
return fmt.Errorf("can't generate key ID: %w", err)
return fmt.Errorf("failed to get public JWK: %w", err)
}
jwks := jwk.NewSet()
err = jwks.AddKey(publicKey)
if err != nil {
return fmt.Errorf("failed to add public key to JWKS: %w", err)
}
s.jwksEncoded, err = json.Marshal(jwks)
if err != nil {
return fmt.Errorf("failed to encode JWKS to JSON: %w", err)
}
return nil
@@ -112,12 +178,23 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
return token.SignedString(s.privateKey)
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err)
}
signed, err := token.SignedString(privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return signed, nil
}
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return &s.privateKey.PublicKey, nil
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (any, error) {
return s.getPublicKeyRaw()
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
@@ -135,12 +212,12 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaim
}
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
claims := jwt.MapClaims{
"aud": clientID,
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
"iat": jwt.NewNumericDate(time.Now()),
"iss": common.EnvConfig.AppURL,
}
// Initialize with capacity for userClaims, + 4 fixed claims, + 2 claims which may be set in some cases, to avoid re-allocations
claims := make(jwt.MapClaims, len(userClaims)+6)
claims["aud"] = clientID
claims["exp"] = jwt.NewNumericDate(time.Now().Add(1 * time.Hour))
claims["iat"] = jwt.NewNumericDate(time.Now())
claims["iss"] = common.EnvConfig.AppURL
for k, v := range userClaims {
claims[k] = v
@@ -153,43 +230,18 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = s.keyId
return token.SignedString(s.privateKey)
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
claim := jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{clientID},
Issuer: common.EnvConfig.AppURL,
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
return token.SignedString(s.privateKey)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return &s.privateKey.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
return token.SignedString(privateKeyRaw)
}
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return &s.privateKey.PublicKey, nil
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
return s.getPublicKeyRaw()
}, jwt.WithIssuer(common.EnvConfig.AppURL))
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
@@ -204,78 +256,180 @@ func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, e
return claims, nil
}
// GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) {
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
claim := jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{clientID},
Issuer: common.EnvConfig.AppURL,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err)
}
return token.SignedString(privateKeyRaw)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
return s.getPublicKeyRaw()
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
}
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
if s.privateKey == nil {
return JWK{}, errors.New("public key is not initialized")
return nil, errors.New("key is not initialized")
}
jwk := JWK{
Kid: s.keyId,
Kty: "RSA",
Use: "sig",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.privateKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()),
}
return jwk, nil
}
// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key.
func (s *JwtService) generateKeyID() (string, error) {
pubASN1, err := x509.MarshalPKIXPublicKey(&s.privateKey.PublicKey)
pubKey, err := s.privateKey.PublicKey()
if err != nil {
return "", fmt.Errorf("failed to marshal public key: %w", err)
return nil, fmt.Errorf("failed to get public key: %w", err)
}
// Compute SHA-256 hash of the public key
hash := sha256.New()
hash.Write(pubASN1)
hashed := hash.Sum(nil)
EnsureAlgInKey(pubKey)
// Truncate the hash to the first 8 bytes for a shorter Key ID
shortHash := hashed[:8]
// Return Base64 encoded truncated hash as Key ID
return base64.RawURLEncoding.EncodeToString(shortHash), nil
return pubKey, nil
}
// generateKey generates a new RSA key and saves it to the specified path.
func (s *JwtService) generateKey(keysPath string) error {
if err := os.MkdirAll(keysPath, 0700); err != nil {
return fmt.Errorf("failed to create directories for keys: %w", err)
// GetPublicJWKSAsJSON returns the JSON Web Key Set (JWKS) for the public key, encoded as JSON.
// The value is cached since the key is static.
func (s *JwtService) GetPublicJWKSAsJSON() ([]byte, error) {
if len(s.jwksEncoded) == 0 {
return nil, errors.New("key is not initialized")
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
return s.jwksEncoded, nil
}
func (s *JwtService) getPublicKeyRaw() (any, error) {
pubKey, err := s.privateKey.PublicKey()
if err != nil {
return fmt.Errorf("failed to generate private key: %w", err)
return nil, fmt.Errorf("failed to get public key: %w", err)
}
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
return err
var pubKeyRaw any
err = jwk.Export(pubKey, &pubKeyRaw)
if err != nil {
return nil, fmt.Errorf("failed to export raw public key: %w", err)
}
return nil
return pubKeyRaw, nil
}
// savePEMKey saves a PEM encoded key to a file.
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
keyFile, err := os.Create(path)
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
return key, nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
func EnsureAlgInKey(key jwk.Key) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
case jwa.OKP():
// Default to EdDSA for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
}
}
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
}
// Import the raw key
return importRawKey(rawKey)
}
func importRawKey(rawKey any) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key)
return key, err
}
// SaveKeyJWK saves a JWK to a file
func SaveKeyJWK(key jwk.Key, path string) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
}
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file: %w", err)
}
defer keyFile.Close()
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: keyType,
Bytes: keyBytes,
})
if _, err := keyFile.Write(keyPEM); err != nil {
// Write the JSON file to disk
enc := json.NewEncoder(keyFile)
enc.SetEscapeHTML(false)
err = enc.Encode(key)
if err != nil {
return fmt.Errorf("failed to write key file: %w", err)
}
return nil
}
// generateRandomKeyID generates a random key ID.
// It is used for newly-generated keys
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}