feat: allow setting path where keys are stored (#327)

This commit is contained in:
Alessandro (Ale) Segala
2025-03-13 09:01:15 -07:00
committed by GitHub
parent 8c1c04db1d
commit 7b654c6bd1
3 changed files with 63 additions and 78 deletions

View File

@@ -23,6 +23,7 @@ type EnvConfigSchema struct {
SqliteDBPath string `env:"SQLITE_DB_PATH"`
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"`
UploadPath string `env:"UPLOAD_PATH"`
KeysPath string `env:"KEYS_PATH"`
Port string `env:"BACKEND_PORT"`
Host string `env:"HOST"`
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
@@ -37,6 +38,7 @@ var EnvConfig = &EnvConfigSchema{
SqliteDBPath: "data/pocket-id.db",
PostgresConnectionString: "",
UploadPath: "data/uploads",
KeysPath: "data/keys",
AppURL: "http://localhost",
Port: "8080",
Host: "0.0.0.0",
@@ -50,19 +52,21 @@ func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err)
}
// Validate the environment variables
if EnvConfig.DbProvider != DbProviderSqlite && EnvConfig.DbProvider != DbProviderPostgres {
switch EnvConfig.DbProvider {
case DbProviderSqlite:
if EnvConfig.SqliteDBPath == "" {
log.Fatal("Missing SQLITE_DB_PATH environment variable")
}
case DbProviderPostgres:
if EnvConfig.PostgresConnectionString == "" {
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable")
}
default:
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
}
if EnvConfig.DbProvider == DbProviderPostgres && EnvConfig.PostgresConnectionString == "" {
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable")
}
if EnvConfig.DbProvider == DbProviderSqlite && EnvConfig.SqliteDBPath == "" {
log.Fatal("Missing SQLITE_DB_PATH environment variable")
}
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
if err != nil {
log.Fatal("PUBLIC_APP_URL is not a valid URL")

View File

@@ -8,6 +8,7 @@ import (
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"log"
"math/big"
"os"
@@ -22,13 +23,12 @@ import (
)
const (
privateKeyPath = "data/keys/jwt_private_key.pem"
publicKeyPath = "data/keys/jwt_public_key.pem"
privateKeyFile = "jwt_private_key.pem"
)
type JwtService struct {
PublicKey *rsa.PublicKey
PrivateKey *rsa.PrivateKey
privateKey *rsa.PrivateKey
keyId string
appConfigService *AppConfigService
}
@@ -38,7 +38,7 @@ func NewJwtService(appConfigService *AppConfigService) *JwtService {
}
// Ensure keys are generated or loaded
if err := service.loadOrGenerateKeys(); err != nil {
if err := service.loadOrGenerateKey(common.EnvConfig.KeysPath); err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
}
@@ -59,30 +59,39 @@ type JWK struct {
E string `json:"e"`
}
// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist.
func (s *JwtService) loadOrGenerateKeys() error {
// loadOrGenerateKey loads RSA keys from the given paths or generates them if they do not exist.
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKeys(); err != nil {
return err
if err := s.generateKey(keysPath); err != nil {
return fmt.Errorf("can't generate key: %w", err)
}
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
return errors.New("can't read jwt private key: " + err.Error())
return fmt.Errorf("can't read jwt private key: %w", err)
}
s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return errors.New("can't parse jwt private key: " + err.Error())
return fmt.Errorf("can't parse jwt private key: %w", err)
}
publicKeyBytes, err := os.ReadFile(publicKeyPath)
err = s.SetKey(privateKey)
if err != nil {
return errors.New("can't read jwt public key: " + err.Error())
return fmt.Errorf("failed to set private key: %w", err)
}
s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
return nil
}
func (s *JwtService) SetKey(privateKey *rsa.PrivateKey) (err error) {
s.privateKey = privateKey
s.keyId, err = s.generateKeyID()
if err != nil {
return errors.New("can't parse jwt public key: " + err.Error())
return fmt.Errorf("can't generate key ID: %w", err)
}
return nil
@@ -100,20 +109,15 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
IsAdmin: user.IsAdmin,
}
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid
token.Header["kid"] = s.keyId
return token.SignedString(s.PrivateKey)
return token.SignedString(s.privateKey)
}
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil
return &s.privateKey.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
@@ -146,15 +150,10 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
claims["nonce"] = nonce
}
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = kid
token.Header["kid"] = s.keyId
return token.SignedString(s.PrivateKey)
return token.SignedString(s.privateKey)
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
@@ -166,20 +165,15 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
Issuer: common.EnvConfig.AppURL,
}
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid
token.Header["kid"] = s.keyId
return token.SignedString(s.PrivateKey)
return token.SignedString(s.privateKey)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil
return &s.privateKey.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
@@ -195,7 +189,7 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.Registered
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil
return &s.privateKey.PublicKey, nil
}, jwt.WithIssuer(common.EnvConfig.AppURL))
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
@@ -212,32 +206,27 @@ func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, e
// GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) {
if s.PublicKey == nil {
if s.privateKey == nil {
return JWK{}, errors.New("public key is not initialized")
}
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return JWK{}, err
}
jwk := JWK{
Kid: kid,
Kid: s.keyId,
Kty: "RSA",
Use: "sig",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()),
N: base64.RawURLEncoding.EncodeToString(s.privateKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()),
}
return jwk, nil
}
// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key.
func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) {
pubASN1, err := x509.MarshalPKIXPublicKey(publicKey)
func (s *JwtService) generateKeyID() (string, error) {
pubASN1, err := x509.MarshalPKIXPublicKey(&s.privateKey.PublicKey)
if err != nil {
return "", errors.New("failed to marshal public key: " + err.Error())
return "", fmt.Errorf("failed to marshal public key: %w", err)
}
// Compute SHA-256 hash of the public key
@@ -252,29 +241,22 @@ func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) {
return base64.RawURLEncoding.EncodeToString(shortHash), nil
}
// generateKeys generates a new RSA key pair and saves them to the specified paths.
func (s *JwtService) generateKeys() error {
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
return errors.New("failed to create directories for keys: " + err.Error())
// generateKey generates a new RSA key and saves it to the specified path.
func (s *JwtService) generateKey(keysPath string) error {
if err := os.MkdirAll(keysPath, 0700); err != nil {
return fmt.Errorf("failed to create directories for keys: %w", err)
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return errors.New("failed to generate private key: " + err.Error())
return fmt.Errorf("failed to generate private key: %w", err)
}
s.PrivateKey = privateKey
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
return err
}
publicKey := &privateKey.PublicKey
s.PublicKey = publicKey
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
return err
}
return nil
}
@@ -282,7 +264,7 @@ func (s *JwtService) generateKeys() error {
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
keyFile, err := os.Create(path)
if err != nil {
return errors.New("failed to create key file: " + err.Error())
return fmt.Errorf("failed to create key file: %w", err)
}
defer keyFile.Close()
@@ -292,7 +274,7 @@ func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) er
})
if _, err := keyFile.Write(keyPEM); err != nil {
return errors.New("failed to write key file: " + err.Error())
return fmt.Errorf("failed to write key file: %w", err)
}
return nil

View File

@@ -336,8 +336,7 @@ wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI=
block, _ := pem.Decode([]byte(privateKeyString))
privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes)
s.jwtService.PrivateKey = privateKey
s.jwtService.PublicKey = &privateKey.PublicKey
s.jwtService.SetKey(privateKey)
}
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key