package service import ( "context" "encoding/json" "errors" "fmt" "time" "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwt" "gorm.io/gorm" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/model" jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk" ) const ( // PrivateKeyFile is the path in the data/keys folder where the key is stored // This is a JSON file containing a key encoded as JWK PrivateKeyFile = "jwt_private_key.json" // PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored // This is a encrypted JSON file containing a key encoded as JWK PrivateKeyFileEncrypted = "jwt_private_key.json.enc" // KeyUsageSigning is the usage for the private keys, for the "use" property KeyUsageSigning = "sig" // IsAdminClaim is a boolean claim used in access tokens for admin users // This may be omitted on non-admin tokens IsAdminClaim = "isAdmin" // TokenTypeClaim is the claim used to identify the type of token TokenTypeClaim = "type" // RefreshTokenClaim is the claim used for the refresh token's value RefreshTokenClaim = "rt" // OAuthAccessTokenJWTType identifies a JWT as an OAuth access token OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec // OAuthRefreshTokenJWTType identifies a JWT as an OAuth refresh token OAuthRefreshTokenJWTType = "refresh-token" // AccessTokenJWTType identifies a JWT as an access token used by Pocket ID AccessTokenJWTType = "access-token" // IDTokenJWTType identifies a JWT as an ID token used by Pocket ID IDTokenJWTType = "id-token" // Acceptable clock skew for verifying tokens clockSkew = time.Minute ) type JwtService struct { envConfig *common.EnvConfigSchema privateKey jwk.Key keyId string appConfigService *AppConfigService jwksEncoded []byte } func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService, error) { service := &JwtService{} // Ensure keys are generated or loaded err := service.init(db, appConfigService, &common.EnvConfig) if err != nil { return nil, err } return service, nil } func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) { s.appConfigService = appConfigService s.envConfig = envConfig // Ensure keys are generated or loaded return s.loadOrGenerateKey(db) } func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error { // Get the key provider keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value) if err != nil { return fmt.Errorf("failed to get key provider: %w", err) } // Try loading a key key, err := keyProvider.LoadKey() if err != nil { return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err) } // If we have a key, store it in the object and we're done if key != nil { err = s.SetKey(key) if err != nil { return fmt.Errorf("failed to set private key: %w", err) } return nil } // If we are here, we need to generate a new key err = s.generateKey() if err != nil { return fmt.Errorf("failed to generate key: %w", err) } // Save the newly-generated key err = keyProvider.SaveKey(s.privateKey) if err != nil { return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err) } return nil } // generateKey generates a new key and stores it in the object func (s *JwtService) generateKey() error { // Default is to generate RS256 (RSA-2048) keys key, err := jwkutils.GenerateKey(jwa.RS256().String(), "") if err != nil { return fmt.Errorf("failed to generate new private key: %w", err) } // Set the key in the object, which also validates it err = s.SetKey(key) if err != nil { return fmt.Errorf("failed to set private key: %w", err) } return nil } func ValidateKey(privateKey jwk.Key) error { // Validate the loaded key err := privateKey.Validate() if err != nil { return fmt.Errorf("key object is invalid: %w", err) } keyID, ok := privateKey.KeyID() if !ok || keyID == "" { return errors.New("key object does not contain a key ID") } usage, ok := privateKey.KeyUsage() if !ok || usage != KeyUsageSigning { return errors.New("key object is not valid for signing") } ok, err = jwk.IsPrivateKey(privateKey) if err != nil || !ok { return errors.New("key object is not a private key") } return nil } func (s *JwtService) SetKey(privateKey jwk.Key) error { // Validate the loaded key err := ValidateKey(privateKey) if err != nil { return fmt.Errorf("private key is not valid: %w", err) } // Set the private key and key id in the object s.privateKey = privateKey keyId, ok := privateKey.KeyID() if !ok { return errors.New("key object does not contain a key ID") } s.keyId = keyId // Create and encode a JWKS containing the public key publicKey, err := s.GetPublicJWK() if err != nil { return fmt.Errorf("failed to get public JWK: %w", err) } jwks := jwk.NewSet() err = jwks.AddKey(publicKey) if err != nil { return fmt.Errorf("failed to add public key to JWKS: %w", err) } s.jwksEncoded, err = json.Marshal(jwks) if err != nil { return fmt.Errorf("failed to encode JWKS to JSON: %w", err) } return nil } func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { now := time.Now() token, err := jwt.NewBuilder(). Subject(user.ID). Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())). IssuedAt(now). Issuer(s.envConfig.AppURL). Build() if err != nil { return "", fmt.Errorf("failed to build token: %w", err) } err = SetAudienceString(token, s.envConfig.AppURL) if err != nil { return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) } err = SetTokenType(token, AccessTokenJWTType) if err != nil { return "", fmt.Errorf("failed to set 'type' claim in token: %w", err) } err = SetIsAdmin(token, user.IsAdmin) if err != nil { return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err) } alg, _ := s.privateKey.Algorithm() signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) if err != nil { return "", fmt.Errorf("failed to sign token: %w", err) } return string(signed), nil } func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) { alg, _ := s.privateKey.Algorithm() token, err := jwt.ParseString( tokenString, jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), jwt.WithAudience(s.envConfig.AppURL), jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)), ) if err != nil { return nil, fmt.Errorf("failed to parse token: %w", err) } return token, nil } // BuildIDToken creates an ID token with all claims func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) { now := time.Now() token, err := jwt.NewBuilder(). Expiration(now.Add(1 * time.Hour)). IssuedAt(now). Issuer(s.envConfig.AppURL). Build() if err != nil { return nil, fmt.Errorf("failed to build token: %w", err) } err = SetAudienceString(token, clientID) if err != nil { return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err) } err = SetTokenType(token, IDTokenJWTType) if err != nil { return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err) } for k, v := range userClaims { err = token.Set(k, v) if err != nil { return nil, fmt.Errorf("failed to set claim '%s': %w", k, err) } } if nonce != "" { err = token.Set("nonce", nonce) if err != nil { return nil, fmt.Errorf("failed to set claim 'nonce': %w", err) } } return token, nil } // GenerateIDToken creates and signs an ID token func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) { token, err := s.BuildIDToken(userClaims, clientID, nonce) if err != nil { return "", err } alg, _ := s.privateKey.Algorithm() signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) if err != nil { return "", fmt.Errorf("failed to sign token: %w", err) } return string(signed), nil } func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) (jwt.Token, error) { alg, _ := s.privateKey.Algorithm() opts := make([]jwt.ParseOption, 0) // These options are always present opts = append(opts, jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)), ) // By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp" // In case we want to accept expired tokens (during logout), we need to set the validators explicitly without validating "exp" if acceptExpiredTokens { // This is equivalent to the default validators except it doesn't validate "exp" opts = append(opts, jwt.WithResetValidators(true), jwt.WithValidator(jwt.IsIssuedAtValid()), jwt.WithValidator(jwt.IsNbfValid()), ) } token, err := jwt.ParseString(tokenString, opts...) if err != nil { return nil, fmt.Errorf("failed to parse token: %w", err) } return token, nil } // BuildOAuthAccessToken creates an OAuth access token with all claims func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) { now := time.Now() token, err := jwt.NewBuilder(). Subject(user.ID). Expiration(now.Add(1 * time.Hour)). IssuedAt(now). Issuer(s.envConfig.AppURL). Build() if err != nil { return nil, fmt.Errorf("failed to build token: %w", err) } err = SetAudienceString(token, clientID) if err != nil { return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err) } err = SetTokenType(token, OAuthAccessTokenJWTType) if err != nil { return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err) } return token, nil } // GenerateOAuthAccessToken creates and signs an OAuth access token func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) { token, err := s.BuildOAuthAccessToken(user, clientID) if err != nil { return "", err } alg, _ := s.privateKey.Algorithm() signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) if err != nil { return "", fmt.Errorf("failed to sign token: %w", err) } return string(signed), nil } func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, error) { alg, _ := s.privateKey.Algorithm() token, err := jwt.ParseString( tokenString, jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)), ) if err != nil { return nil, fmt.Errorf("failed to parse token: %w", err) } return token, nil } func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, refreshToken string) (string, error) { now := time.Now() token, err := jwt.NewBuilder(). Subject(userID). Expiration(now.Add(RefreshTokenDuration)). IssuedAt(now). Issuer(s.envConfig.AppURL). Build() if err != nil { return "", fmt.Errorf("failed to build token: %w", err) } err = token.Set(RefreshTokenClaim, refreshToken) if err != nil { return "", fmt.Errorf("failed to set 'rt' claim in token: %w", err) } err = SetAudienceString(token, clientID) if err != nil { return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) } err = SetTokenType(token, OAuthRefreshTokenJWTType) if err != nil { return "", fmt.Errorf("failed to set 'type' claim in token: %w", err) } alg, _ := s.privateKey.Algorithm() signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) if err != nil { return "", fmt.Errorf("failed to sign token: %w", err) } return string(signed), nil } func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, clientID, rt string, err error) { alg, _ := s.privateKey.Algorithm() token, err := jwt.ParseString( tokenString, jwt.WithValidate(true), jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), jwt.WithIssuer(s.envConfig.AppURL), jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)), ) if err != nil { return "", "", "", fmt.Errorf("failed to parse token: %w", err) } err = token.Get(RefreshTokenClaim, &rt) if err != nil { return "", "", "", fmt.Errorf("failed to get '%s' claim from token: %w", RefreshTokenClaim, err) } audiences, ok := token.Audience() if !ok || len(audiences) != 1 || audiences[0] == "" { return "", "", "", errors.New("failed to get 'aud' claim from token") } clientID = audiences[0] userID, ok = token.Subject() if !ok { return "", "", "", errors.New("failed to get 'sub' claim from token") } return userID, clientID, rt, nil } // GetTokenType returns the type of the JWT token issued by Pocket ID, but **does not validate it**. func (s *JwtService) GetTokenType(tokenString string) (string, jwt.Token, error) { // Disable validation and verification to parse the token without checking it token, err := jwt.ParseString( tokenString, jwt.WithValidate(false), jwt.WithVerify(false), ) if err != nil { return "", nil, fmt.Errorf("failed to parse token: %w", err) } var tokenType string err = token.Get(TokenTypeClaim, &tokenType) if err != nil { return "", nil, fmt.Errorf("failed to get token type claim: %w", err) } return tokenType, token, nil } // GetPublicJWK returns the JSON Web Key (JWK) for the public key. func (s *JwtService) GetPublicJWK() (jwk.Key, error) { if s.privateKey == nil { return nil, errors.New("key is not initialized") } pubKey, err := s.privateKey.PublicKey() if err != nil { return nil, fmt.Errorf("failed to get public key: %w", err) } jwkutils.EnsureAlgInKey(pubKey, "", "") return pubKey, nil } // GetPublicJWKSAsJSON returns the JSON Web Key Set (JWKS) for the public key, encoded as JSON. // The value is cached since the key is static. func (s *JwtService) GetPublicJWKSAsJSON() ([]byte, error) { if len(s.jwksEncoded) == 0 { return nil, errors.New("key is not initialized") } return s.jwksEncoded, nil } // GetKeyAlg returns the algorithm of the key func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) { if len(s.jwksEncoded) == 0 { return nil, errors.New("key is not initialized") } alg, ok := s.privateKey.Algorithm() if !ok || alg == nil { return nil, errors.New("failed to retrieve algorithm for key") } return alg, nil } // GetIsAdmin returns the value of the "isAdmin" claim in the token func GetIsAdmin(token jwt.Token) (bool, error) { if !token.Has(IsAdminClaim) { return false, nil } var isAdmin bool err := token.Get(IsAdminClaim, &isAdmin) if err != nil { return false, fmt.Errorf("failed to get 'isAdmin' claim from token: %w", err) } return isAdmin, nil } // SetTokenType sets the "type" claim in the token func SetTokenType(token jwt.Token, tokenType string) error { if tokenType == "" { return nil } return token.Set(TokenTypeClaim, tokenType) } // SetIsAdmin sets the "isAdmin" claim in the token func SetIsAdmin(token jwt.Token, isAdmin bool) error { // Only set if true if !isAdmin { return nil } return token.Set(IsAdminClaim, isAdmin) } // SetAudienceString sets the "aud" claim with a value that is a string, and not an array // This is permitted by RFC 7519, and it's done here for backwards-compatibility func SetAudienceString(token jwt.Token, audience string) error { return token.Set(jwt.AudienceKey, audience) } // TokenTypeValidator is a validator function that checks the "type" claim in the token func TokenTypeValidator(expectedTokenType string) jwt.ValidatorFunc { return func(_ context.Context, t jwt.Token) error { var tokenType string err := t.Get(TokenTypeClaim, &tokenType) if err != nil { return fmt.Errorf("failed to get token type claim: %w", err) } if tokenType != expectedTokenType { return fmt.Errorf("invalid token type: expected %s, got %s", expectedTokenType, tokenType) } return nil } }