mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-09 14:52:57 +03:00
feat: allow setting path where keys are stored (#327)
This commit is contained in:
committed by
GitHub
parent
8c1c04db1d
commit
7b654c6bd1
@@ -23,6 +23,7 @@ type EnvConfigSchema struct {
|
||||
SqliteDBPath string `env:"SQLITE_DB_PATH"`
|
||||
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
KeysPath string `env:"KEYS_PATH"`
|
||||
Port string `env:"BACKEND_PORT"`
|
||||
Host string `env:"HOST"`
|
||||
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
|
||||
@@ -37,6 +38,7 @@ var EnvConfig = &EnvConfigSchema{
|
||||
SqliteDBPath: "data/pocket-id.db",
|
||||
PostgresConnectionString: "",
|
||||
UploadPath: "data/uploads",
|
||||
KeysPath: "data/keys",
|
||||
AppURL: "http://localhost",
|
||||
Port: "8080",
|
||||
Host: "0.0.0.0",
|
||||
@@ -50,19 +52,21 @@ func init() {
|
||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Validate the environment variables
|
||||
if EnvConfig.DbProvider != DbProviderSqlite && EnvConfig.DbProvider != DbProviderPostgres {
|
||||
switch EnvConfig.DbProvider {
|
||||
case DbProviderSqlite:
|
||||
if EnvConfig.SqliteDBPath == "" {
|
||||
log.Fatal("Missing SQLITE_DB_PATH environment variable")
|
||||
}
|
||||
case DbProviderPostgres:
|
||||
if EnvConfig.PostgresConnectionString == "" {
|
||||
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable")
|
||||
}
|
||||
default:
|
||||
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
||||
}
|
||||
|
||||
if EnvConfig.DbProvider == DbProviderPostgres && EnvConfig.PostgresConnectionString == "" {
|
||||
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable")
|
||||
}
|
||||
|
||||
if EnvConfig.DbProvider == DbProviderSqlite && EnvConfig.SqliteDBPath == "" {
|
||||
log.Fatal("Missing SQLITE_DB_PATH environment variable")
|
||||
}
|
||||
|
||||
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
|
||||
if err != nil {
|
||||
log.Fatal("PUBLIC_APP_URL is not a valid URL")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"os"
|
||||
@@ -22,13 +23,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
privateKeyPath = "data/keys/jwt_private_key.pem"
|
||||
publicKeyPath = "data/keys/jwt_public_key.pem"
|
||||
privateKeyFile = "jwt_private_key.pem"
|
||||
)
|
||||
|
||||
type JwtService struct {
|
||||
PublicKey *rsa.PublicKey
|
||||
PrivateKey *rsa.PrivateKey
|
||||
privateKey *rsa.PrivateKey
|
||||
keyId string
|
||||
appConfigService *AppConfigService
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func NewJwtService(appConfigService *AppConfigService) *JwtService {
|
||||
}
|
||||
|
||||
// Ensure keys are generated or loaded
|
||||
if err := service.loadOrGenerateKeys(); err != nil {
|
||||
if err := service.loadOrGenerateKey(common.EnvConfig.KeysPath); err != nil {
|
||||
log.Fatalf("Failed to initialize jwt service: %v", err)
|
||||
}
|
||||
|
||||
@@ -59,30 +59,39 @@ type JWK struct {
|
||||
E string `json:"e"`
|
||||
}
|
||||
|
||||
// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist.
|
||||
func (s *JwtService) loadOrGenerateKeys() error {
|
||||
// loadOrGenerateKey loads RSA keys from the given paths or generates them if they do not exist.
|
||||
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
|
||||
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
|
||||
|
||||
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||
if err := s.generateKeys(); err != nil {
|
||||
return err
|
||||
if err := s.generateKey(keysPath); err != nil {
|
||||
return fmt.Errorf("can't generate key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return errors.New("can't read jwt private key: " + err.Error())
|
||||
return fmt.Errorf("can't read jwt private key: %w", err)
|
||||
}
|
||||
s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
if err != nil {
|
||||
return errors.New("can't parse jwt private key: " + err.Error())
|
||||
return fmt.Errorf("can't parse jwt private key: %w", err)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||
err = s.SetKey(privateKey)
|
||||
if err != nil {
|
||||
return errors.New("can't read jwt public key: " + err.Error())
|
||||
return fmt.Errorf("failed to set private key: %w", err)
|
||||
}
|
||||
s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *JwtService) SetKey(privateKey *rsa.PrivateKey) (err error) {
|
||||
s.privateKey = privateKey
|
||||
|
||||
s.keyId, err = s.generateKeyID()
|
||||
if err != nil {
|
||||
return errors.New("can't parse jwt public key: " + err.Error())
|
||||
return fmt.Errorf("can't generate key ID: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -100,20 +109,15 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
IsAdmin: user.IsAdmin,
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
token.Header["kid"] = kid
|
||||
token.Header["kid"] = s.keyId
|
||||
|
||||
return token.SignedString(s.PrivateKey)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.PublicKey, nil
|
||||
return &s.privateKey.PublicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
@@ -146,15 +150,10 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = kid
|
||||
token.Header["kid"] = s.keyId
|
||||
|
||||
return token.SignedString(s.PrivateKey)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
@@ -166,20 +165,15 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
||||
Issuer: common.EnvConfig.AppURL,
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to generate key ID: " + err.Error())
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
token.Header["kid"] = kid
|
||||
token.Header["kid"] = s.keyId
|
||||
|
||||
return token.SignedString(s.PrivateKey)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.PublicKey, nil
|
||||
return &s.privateKey.PublicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
@@ -195,7 +189,7 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.Registered
|
||||
|
||||
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.PublicKey, nil
|
||||
return &s.privateKey.PublicKey, nil
|
||||
}, jwt.WithIssuer(common.EnvConfig.AppURL))
|
||||
|
||||
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
|
||||
@@ -212,32 +206,27 @@ func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, e
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
||||
func (s *JwtService) GetJWK() (JWK, error) {
|
||||
if s.PublicKey == nil {
|
||||
if s.privateKey == nil {
|
||||
return JWK{}, errors.New("public key is not initialized")
|
||||
}
|
||||
|
||||
kid, err := s.generateKeyID(s.PublicKey)
|
||||
if err != nil {
|
||||
return JWK{}, err
|
||||
}
|
||||
|
||||
jwk := JWK{
|
||||
Kid: kid,
|
||||
Kid: s.keyId,
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()),
|
||||
N: base64.RawURLEncoding.EncodeToString(s.privateKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()),
|
||||
}
|
||||
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key.
|
||||
func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) {
|
||||
pubASN1, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
func (s *JwtService) generateKeyID() (string, error) {
|
||||
pubASN1, err := x509.MarshalPKIXPublicKey(&s.privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return "", errors.New("failed to marshal public key: " + err.Error())
|
||||
return "", fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
// Compute SHA-256 hash of the public key
|
||||
@@ -252,29 +241,22 @@ func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(shortHash), nil
|
||||
}
|
||||
|
||||
// generateKeys generates a new RSA key pair and saves them to the specified paths.
|
||||
func (s *JwtService) generateKeys() error {
|
||||
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
||||
return errors.New("failed to create directories for keys: " + err.Error())
|
||||
// generateKey generates a new RSA key and saves it to the specified path.
|
||||
func (s *JwtService) generateKey(keysPath string) error {
|
||||
if err := os.MkdirAll(keysPath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directories for keys: %w", err)
|
||||
}
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return errors.New("failed to generate private key: " + err.Error())
|
||||
return fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
s.PrivateKey = privateKey
|
||||
|
||||
privateKeyPath := filepath.Join(keysPath, privateKeyFile)
|
||||
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
s.PublicKey = publicKey
|
||||
|
||||
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -282,7 +264,7 @@ func (s *JwtService) generateKeys() error {
|
||||
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
|
||||
keyFile, err := os.Create(path)
|
||||
if err != nil {
|
||||
return errors.New("failed to create key file: " + err.Error())
|
||||
return fmt.Errorf("failed to create key file: %w", err)
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
@@ -292,7 +274,7 @@ func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) er
|
||||
})
|
||||
|
||||
if _, err := keyFile.Write(keyPEM); err != nil {
|
||||
return errors.New("failed to write key file: " + err.Error())
|
||||
return fmt.Errorf("failed to write key file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -336,8 +336,7 @@ wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI=
|
||||
block, _ := pem.Decode([]byte(privateKeyString))
|
||||
privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
|
||||
s.jwtService.PrivateKey = privateKey
|
||||
s.jwtService.PublicKey = &privateKey.PublicKey
|
||||
s.jwtService.SetKey(privateKey)
|
||||
}
|
||||
|
||||
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
|
||||
|
||||
Reference in New Issue
Block a user