feat: allow setting path where keys are stored (#327)

This commit is contained in:
Alessandro (Ale) Segala
2025-03-13 09:01:15 -07:00
committed by GitHub
parent 8c1c04db1d
commit 7b654c6bd1
3 changed files with 63 additions and 78 deletions

View File

@@ -23,6 +23,7 @@ type EnvConfigSchema struct {
SqliteDBPath string `env:"SQLITE_DB_PATH"` SqliteDBPath string `env:"SQLITE_DB_PATH"`
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"` PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"`
UploadPath string `env:"UPLOAD_PATH"` UploadPath string `env:"UPLOAD_PATH"`
KeysPath string `env:"KEYS_PATH"`
Port string `env:"BACKEND_PORT"` Port string `env:"BACKEND_PORT"`
Host string `env:"HOST"` Host string `env:"HOST"`
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"` MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
@@ -37,6 +38,7 @@ var EnvConfig = &EnvConfigSchema{
SqliteDBPath: "data/pocket-id.db", SqliteDBPath: "data/pocket-id.db",
PostgresConnectionString: "", PostgresConnectionString: "",
UploadPath: "data/uploads", UploadPath: "data/uploads",
KeysPath: "data/keys",
AppURL: "http://localhost", AppURL: "http://localhost",
Port: "8080", Port: "8080",
Host: "0.0.0.0", Host: "0.0.0.0",
@@ -50,17 +52,19 @@ func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil { if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err) log.Fatal(err)
} }
// Validate the environment variables
if EnvConfig.DbProvider != DbProviderSqlite && EnvConfig.DbProvider != DbProviderPostgres {
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
}
if EnvConfig.DbProvider == DbProviderPostgres && EnvConfig.PostgresConnectionString == "" { // Validate the environment variables
switch EnvConfig.DbProvider {
case DbProviderSqlite:
if EnvConfig.SqliteDBPath == "" {
log.Fatal("Missing SQLITE_DB_PATH environment variable")
}
case DbProviderPostgres:
if EnvConfig.PostgresConnectionString == "" {
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable") log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable")
} }
default:
if EnvConfig.DbProvider == DbProviderSqlite && EnvConfig.SqliteDBPath == "" { log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
log.Fatal("Missing SQLITE_DB_PATH environment variable")
} }
parsedAppUrl, err := url.Parse(EnvConfig.AppURL) parsedAppUrl, err := url.Parse(EnvConfig.AppURL)

View File

@@ -8,6 +8,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt"
"log" "log"
"math/big" "math/big"
"os" "os"
@@ -22,13 +23,12 @@ import (
) )
const ( const (
privateKeyPath = "data/keys/jwt_private_key.pem" privateKeyFile = "jwt_private_key.pem"
publicKeyPath = "data/keys/jwt_public_key.pem"
) )
type JwtService struct { type JwtService struct {
PublicKey *rsa.PublicKey privateKey *rsa.PrivateKey
PrivateKey *rsa.PrivateKey keyId string
appConfigService *AppConfigService appConfigService *AppConfigService
} }
@@ -38,7 +38,7 @@ func NewJwtService(appConfigService *AppConfigService) *JwtService {
} }
// Ensure keys are generated or loaded // Ensure keys are generated or loaded
if err := service.loadOrGenerateKeys(); err != nil { if err := service.loadOrGenerateKey(common.EnvConfig.KeysPath); err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err) log.Fatalf("Failed to initialize jwt service: %v", err)
} }
@@ -59,30 +59,39 @@ type JWK struct {
E string `json:"e"` E string `json:"e"`
} }
// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist. // loadOrGenerateKey loads RSA keys from the given paths or generates them if they do not exist.
func (s *JwtService) loadOrGenerateKeys() error { func (s *JwtService) loadOrGenerateKey(keysPath string) error {
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) { if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKeys(); err != nil { if err := s.generateKey(keysPath); err != nil {
return err return fmt.Errorf("can't generate key: %w", err)
} }
} }
privateKeyBytes, err := os.ReadFile(privateKeyPath) privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil { if err != nil {
return errors.New("can't read jwt private key: " + err.Error()) return fmt.Errorf("can't read jwt private key: %w", err)
} }
s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil { if err != nil {
return errors.New("can't parse jwt private key: " + err.Error()) return fmt.Errorf("can't parse jwt private key: %w", err)
} }
publicKeyBytes, err := os.ReadFile(publicKeyPath) err = s.SetKey(privateKey)
if err != nil { if err != nil {
return errors.New("can't read jwt public key: " + err.Error()) return fmt.Errorf("failed to set private key: %w", err)
} }
s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
return nil
}
func (s *JwtService) SetKey(privateKey *rsa.PrivateKey) (err error) {
s.privateKey = privateKey
s.keyId, err = s.generateKeyID()
if err != nil { if err != nil {
return errors.New("can't parse jwt public key: " + err.Error()) return fmt.Errorf("can't generate key ID: %w", err)
} }
return nil return nil
@@ -100,20 +109,15 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
IsAdmin: user.IsAdmin, IsAdmin: user.IsAdmin,
} }
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid token.Header["kid"] = s.keyId
return token.SignedString(s.PrivateKey) return token.SignedString(s.privateKey)
} }
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) { func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil return &s.privateKey.PublicKey, nil
}) })
if err != nil || !token.Valid { if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token") return nil, errors.New("couldn't handle this token")
@@ -146,15 +150,10 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
claims["nonce"] = nonce claims["nonce"] = nonce
} }
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = kid token.Header["kid"] = s.keyId
return token.SignedString(s.PrivateKey) return token.SignedString(s.privateKey)
} }
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
@@ -166,20 +165,15 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
Issuer: common.EnvConfig.AppURL, Issuer: common.EnvConfig.AppURL,
} }
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid token.Header["kid"] = s.keyId
return token.SignedString(s.PrivateKey) return token.SignedString(s.privateKey)
} }
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) { func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil return &s.privateKey.PublicKey, nil
}) })
if err != nil || !token.Valid { if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token") return nil, errors.New("couldn't handle this token")
@@ -195,7 +189,7 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.Registered
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) { func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil return &s.privateKey.PublicKey, nil
}, jwt.WithIssuer(common.EnvConfig.AppURL)) }, jwt.WithIssuer(common.EnvConfig.AppURL))
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) { if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
@@ -212,32 +206,27 @@ func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, e
// GetJWK returns the JSON Web Key (JWK) for the public key. // GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) { func (s *JwtService) GetJWK() (JWK, error) {
if s.PublicKey == nil { if s.privateKey == nil {
return JWK{}, errors.New("public key is not initialized") return JWK{}, errors.New("public key is not initialized")
} }
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return JWK{}, err
}
jwk := JWK{ jwk := JWK{
Kid: kid, Kid: s.keyId,
Kty: "RSA", Kty: "RSA",
Use: "sig", Use: "sig",
Alg: "RS256", Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()), N: base64.RawURLEncoding.EncodeToString(s.privateKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()), E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()),
} }
return jwk, nil return jwk, nil
} }
// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key. // GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key.
func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) { func (s *JwtService) generateKeyID() (string, error) {
pubASN1, err := x509.MarshalPKIXPublicKey(publicKey) pubASN1, err := x509.MarshalPKIXPublicKey(&s.privateKey.PublicKey)
if err != nil { if err != nil {
return "", errors.New("failed to marshal public key: " + err.Error()) return "", fmt.Errorf("failed to marshal public key: %w", err)
} }
// Compute SHA-256 hash of the public key // Compute SHA-256 hash of the public key
@@ -252,29 +241,22 @@ func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) {
return base64.RawURLEncoding.EncodeToString(shortHash), nil return base64.RawURLEncoding.EncodeToString(shortHash), nil
} }
// generateKeys generates a new RSA key pair and saves them to the specified paths. // generateKey generates a new RSA key and saves it to the specified path.
func (s *JwtService) generateKeys() error { func (s *JwtService) generateKey(keysPath string) error {
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil { if err := os.MkdirAll(keysPath, 0700); err != nil {
return errors.New("failed to create directories for keys: " + err.Error()) return fmt.Errorf("failed to create directories for keys: %w", err)
} }
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return errors.New("failed to generate private key: " + err.Error()) return fmt.Errorf("failed to generate private key: %w", err)
} }
s.PrivateKey = privateKey
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil { if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
return err return err
} }
publicKey := &privateKey.PublicKey
s.PublicKey = publicKey
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
return err
}
return nil return nil
} }
@@ -282,7 +264,7 @@ func (s *JwtService) generateKeys() error {
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error { func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
keyFile, err := os.Create(path) keyFile, err := os.Create(path)
if err != nil { if err != nil {
return errors.New("failed to create key file: " + err.Error()) return fmt.Errorf("failed to create key file: %w", err)
} }
defer keyFile.Close() defer keyFile.Close()
@@ -292,7 +274,7 @@ func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) er
}) })
if _, err := keyFile.Write(keyPEM); err != nil { if _, err := keyFile.Write(keyPEM); err != nil {
return errors.New("failed to write key file: " + err.Error()) return fmt.Errorf("failed to write key file: %w", err)
} }
return nil return nil

View File

@@ -336,8 +336,7 @@ wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI=
block, _ := pem.Decode([]byte(privateKeyString)) block, _ := pem.Decode([]byte(privateKeyString))
privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes)
s.jwtService.PrivateKey = privateKey s.jwtService.SetKey(privateKey)
s.jwtService.PublicKey = &privateKey.PublicKey
} }
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key // getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key