feat: add end session endpoint (#232)

This commit is contained in:
Elias Schneider
2025-02-14 17:09:27 +01:00
committed by GitHub
parent 4d0fff821e
commit 7550333fe2
25 changed files with 352 additions and 111 deletions

View File

@@ -8,7 +8,6 @@ import (
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"log"
"math/big"
"os"
@@ -28,8 +27,8 @@ const (
)
type JwtService struct {
publicKey *rsa.PublicKey
privateKey *rsa.PrivateKey
PublicKey *rsa.PublicKey
PrivateKey *rsa.PrivateKey
appConfigService *AppConfigService
}
@@ -72,7 +71,7 @@ func (s *JwtService) loadOrGenerateKeys() error {
if err != nil {
return errors.New("can't read jwt private key: " + err.Error())
}
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return errors.New("can't parse jwt private key: " + err.Error())
}
@@ -81,7 +80,7 @@ func (s *JwtService) loadOrGenerateKeys() error {
if err != nil {
return errors.New("can't read jwt public key: " + err.Error())
}
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
if err != nil {
return errors.New("can't parse jwt public key: " + err.Error())
}
@@ -101,7 +100,7 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
IsAdmin: user.IsAdmin,
}
kid, err := s.generateKeyID(s.publicKey)
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
@@ -109,12 +108,12 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid
return token.SignedString(s.privateKey)
return token.SignedString(s.PrivateKey)
}
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.publicKey, nil
return s.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
@@ -147,7 +146,7 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
claims["nonce"] = nonce
}
kid, err := s.generateKeyID(s.publicKey)
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
@@ -155,7 +154,7 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = kid
return token.SignedString(s.privateKey)
return token.SignedString(s.PrivateKey)
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
@@ -167,7 +166,7 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
Issuer: common.EnvConfig.AppURL,
}
kid, err := s.generateKeyID(s.publicKey)
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
}
@@ -175,12 +174,12 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid
return token.SignedString(s.privateKey)
return token.SignedString(s.PrivateKey)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.publicKey, nil
return s.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
@@ -194,13 +193,30 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.Registered
return claims, nil
}
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil
}, jwt.WithIssuer(common.EnvConfig.AppURL))
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
}
// GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) {
if s.publicKey == nil {
if s.PublicKey == nil {
return JWK{}, errors.New("public key is not initialized")
}
kid, err := s.generateKeyID(s.publicKey)
kid, err := s.generateKeyID(s.PublicKey)
if err != nil {
return JWK{}, err
}
@@ -210,8 +226,8 @@ func (s *JwtService) GetJWK() (JWK, error) {
Kty: "RSA",
Use: "sig",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()),
N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()),
}
return jwk, nil
@@ -246,14 +262,14 @@ func (s *JwtService) generateKeys() error {
if err != nil {
return errors.New("failed to generate private key: " + err.Error())
}
s.privateKey = privateKey
s.PrivateKey = privateKey
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
return err
}
publicKey := &privateKey.PublicKey
s.publicKey = publicKey
s.PublicKey = publicKey
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
return err
@@ -281,32 +297,3 @@ func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) er
return nil
}
// loadKeys loads RSA keys from the given paths.
func (s *JwtService) loadKeys() error {
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKeys(); err != nil {
return err
}
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
return fmt.Errorf("can't read jwt private key: %w", err)
}
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return fmt.Errorf("can't parse jwt private key: %w", err)
}
publicKeyBytes, err := os.ReadFile(publicKeyPath)
if err != nil {
return fmt.Errorf("can't read jwt public key: %w", err)
}
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
if err != nil {
return fmt.Errorf("can't parse jwt public key: %w", err)
}
return nil
}