diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index 5f20fbb7..35dfd848 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/rand" "crypto/rsa" "encoding/base64" @@ -11,11 +12,8 @@ import ( "log" "os" "path/filepath" - "strings" "time" - "github.com/lestrrat-go/jwx/v3/jws" - "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwt" @@ -40,11 +38,17 @@ const ( // This may be omitted on non-admin tokens IsAdminClaim = "isAdmin" - // AccessTokenJWTType is the media type for access tokens - AccessTokenJWTType = "AT+JWT" + // TokenTypeClaim is the claim used to identify the type of token + TokenTypeClaim = "type" - // IDTokenJWTType is the media type for ID tokens - IDTokenJWTType = "ID+JWT" + // OAuthAccessTokenJWTType identifies a JWT as an OAuth access token + OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec + + // AccessTokenJWTType identifies a JWT as an access token used by Pocket ID + AccessTokenJWTType = "access-token" + + // IDTokenJWTType identifies a JWT as an ID token used by Pocket ID + IDTokenJWTType = "id-token" // Acceptable clock skew for verifying tokens clockSkew = time.Minute @@ -195,6 +199,11 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) } + err = SetTokenType(token, AccessTokenJWTType) + if err != nil { + return "", fmt.Errorf("failed to set 'type' claim in token: %w", err) + } + err = SetIsAdmin(token, user.IsAdmin) if err != nil { return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err) @@ -218,6 +227,7 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) { jwt.WithAcceptableSkew(clockSkew), jwt.WithAudience(common.EnvConfig.AppURL), jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)), ) if err != nil { return nil, fmt.Errorf("failed to parse token: %w", err) @@ -242,6 +252,11 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) } + err = SetTokenType(token, IDTokenJWTType) + if err != nil { + return "", fmt.Errorf("failed to set 'type' claim in token: %w", err) + } + for k, v := range userClaims { err = token.Set(k, v) if err != nil { @@ -256,13 +271,8 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, } } - headers, err := CreateTokenTypeHeader(IDTokenJWTType) - if err != nil { - return "", fmt.Errorf("failed to set token type: %w", err) - } - alg, _ := s.privateKey.Algorithm() - signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers))) + signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) if err != nil { return "", fmt.Errorf("failed to sign token: %w", err) } @@ -281,6 +291,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)), ) // By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp" @@ -299,11 +310,6 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) return nil, fmt.Errorf("failed to parse token: %w", err) } - err = VerifyTokenTypeHeader(tokenString, IDTokenJWTType) - if err != nil { - return nil, fmt.Errorf("failed to verify token type: %w", err) - } - return token, nil } @@ -324,13 +330,13 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) } - headers, err := CreateTokenTypeHeader(AccessTokenJWTType) + err = SetTokenType(token, OAuthAccessTokenJWTType) if err != nil { - return "", fmt.Errorf("failed to set token type: %w", err) + return "", fmt.Errorf("failed to set 'type' claim in token: %w", err) } alg, _ := s.privateKey.Algorithm() - signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers))) + signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) if err != nil { return "", fmt.Errorf("failed to sign token: %w", err) } @@ -346,16 +352,12 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro jwt.WithKey(alg, s.privateKey), jwt.WithAcceptableSkew(clockSkew), jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)), ) if err != nil { return nil, fmt.Errorf("failed to parse token: %w", err) } - err = VerifyTokenTypeHeader(tokenString, AccessTokenJWTType) - if err != nil { - return nil, fmt.Errorf("failed to verify token type: %w", err) - } - return token, nil } @@ -510,15 +512,12 @@ func GetIsAdmin(token jwt.Token) (bool, error) { return isAdmin, err } -// CreateTokenTypeHeader creates a new JWS header with the given token type -func CreateTokenTypeHeader(tokenType string) (jws.Headers, error) { - headers := jws.NewHeaders() - err := headers.Set(jws.TypeKey, tokenType) - if err != nil { - return nil, fmt.Errorf("failed to set token type: %w", err) +// SetTokenType sets the "type" claim in the token +func SetTokenType(token jwt.Token, tokenType string) error { + if tokenType == "" { + return nil } - - return headers, nil + return token.Set(TokenTypeClaim, tokenType) } // SetIsAdmin sets the "isAdmin" claim in the token @@ -536,36 +535,17 @@ func SetAudienceString(token jwt.Token, audience string) error { return token.Set(jwt.AudienceKey, audience) } -// VerifyTokenTypeHeader verifies that the "typ" header in the token matches the expected type -func VerifyTokenTypeHeader(tokenBytes string, expectedTokenType string) error { - // Parse the raw token string purely as a JWS message structure - // We don't need to verify the signature at this stage, just inspect headers. - msg, err := jws.Parse([]byte(tokenBytes)) - if err != nil { - return fmt.Errorf("failed to parse token as JWS message: %w", err) +// TokenTypeValidator is a validator function that checks the "type" claim in the token +func TokenTypeValidator(expectedTokenType string) jwt.ValidatorFunc { + return func(_ context.Context, t jwt.Token) error { + var tokenType string + err := t.Get(TokenTypeClaim, &tokenType) + if err != nil { + return fmt.Errorf("failed to get token type claim: %w", err) + } + if tokenType != expectedTokenType { + return fmt.Errorf("invalid token type: expected %s, got %s", expectedTokenType, tokenType) + } + return nil } - - // Get the list of signatures attached to the message. Usually just one for JWT. - signatures := msg.Signatures() - if len(signatures) == 0 { - return errors.New("JWS message contains no signatures") - } - - protectedHeaders := signatures[0].ProtectedHeaders() - if protectedHeaders == nil { - return fmt.Errorf("JWS signature has no protected headers") - } - - // Retrieve the 'typ' header value from the PROTECTED headers. - var typHeaderValue string - err = protectedHeaders.Get(jws.TypeKey, &typHeaderValue) - if err != nil { - return fmt.Errorf("token is missing required protected header '%s'", jws.TypeKey) - } - - if !strings.EqualFold(typHeaderValue, expectedTokenType) { - return fmt.Errorf("'%s' header mismatch: expected '%s', got '%s'", jws.TypeKey, expectedTokenType, typHeaderValue) - } - - return nil } diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go index 3ff7a0bc..e4f7babf 100644 --- a/backend/internal/service/jwt_service_test.go +++ b/backend/internal/service/jwt_service_test.go @@ -1,20 +1,18 @@ package service import ( + "context" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/rsa" - "fmt" "os" "path/filepath" "sync" "testing" "time" - "github.com/lestrrat-go/jwx/v3/jws" - "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwt" @@ -636,6 +634,9 @@ func TestGenerateVerifyIdToken(t *testing.T) { Build() require.NoError(t, err, "Failed to build token") + err = SetTokenType(token, IDTokenJWTType) + require.NoError(t, err, "Failed to set token type") + // Add custom claims for k, v := range userClaims { if k != "sub" { // Already set above @@ -644,13 +645,8 @@ func TestGenerateVerifyIdToken(t *testing.T) { } } - // Create headers with the specified type - hdrs := jws.NewHeaders() - err = hdrs.Set(jws.TypeKey, "ID+JWT") - require.NoError(t, err, "Failed to set header type") - // Sign the token - signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs))) + signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey)) require.NoError(t, err, "Failed to sign token") tokenString := string(signed) @@ -968,6 +964,9 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { Build() require.NoError(t, err, "Failed to build token") + err = SetTokenType(token, OAuthAccessTokenJWTType) + require.NoError(t, err, "Failed to set token type") + signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey)) require.NoError(t, err, "Failed to sign token") @@ -1168,59 +1167,50 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { }) } -func TestVerifyTokenTypeHeader(t *testing.T) { - mockConfig := &AppConfigService{} - tempDir := t.TempDir() - // Helper function to create a token with a specific type header - createTokenWithType := func(tokenType string) (string, error) { - // Create a simple JWT token - token := jwt.New() - err := token.Set("test_claim", "test_value") - if err != nil { - return "", fmt.Errorf("failed to set claim: %w", err) - } - - // Create headers with the specified type - hdrs := jws.NewHeaders() - if tokenType != "" { - err = hdrs.Set(jws.TypeKey, tokenType) - if err != nil { - return "", fmt.Errorf("failed to set type header: %w", err) - } - } - - // Sign the token with the headers - service := &JwtService{} - err = service.init(mockConfig, tempDir) - require.NoError(t, err, "Failed to initialize JWT service") - - signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs))) - if err != nil { - return "", fmt.Errorf("failed to sign token: %w", err) - } - - return string(signed), nil - } +func TestTokenTypeValidator(t *testing.T) { + // Create a context for the validator function + ctx := context.Background() t.Run("succeeds when token type matches expected type", func(t *testing.T) { - // Create a token with "JWT" type - tokenString, err := createTokenWithType("JWT") - require.NoError(t, err, "Failed to create test token") + // Create a token with the expected type + token := jwt.New() + err := token.Set(TokenTypeClaim, AccessTokenJWTType) + require.NoError(t, err, "Failed to set token type claim") - // Verify the token type - err = VerifyTokenTypeHeader(tokenString, "JWT") - assert.NoError(t, err, "Should accept token with matching type") + // Create a validator function for the expected type + validator := TokenTypeValidator(AccessTokenJWTType) + + // Validate the token + err = validator(ctx, token) + assert.NoError(t, err, "Validator should accept token with matching type") }) t.Run("fails when token type doesn't match expected type", func(t *testing.T) { - // Create a token with "AT+JWT" type - tokenString, err := createTokenWithType("AT+JWT") - require.NoError(t, err, "Failed to create test token") + // Create a token with a different type + token := jwt.New() + err := token.Set(TokenTypeClaim, OAuthAccessTokenJWTType) + require.NoError(t, err, "Failed to set token type claim") - // Verify the token with different expected type - err = VerifyTokenTypeHeader(tokenString, "JWT") - require.Error(t, err, "Should reject token with non-matching type") - assert.Contains(t, err.Error(), "header mismatch: expected 'JWT', got 'AT+JWT'") + // Create a validator function for a different expected type + validator := TokenTypeValidator(IDTokenJWTType) + + // Validate the token + err = validator(ctx, token) + require.Error(t, err, "Validator should reject token with non-matching type") + assert.Contains(t, err.Error(), "invalid token type: expected id-token, got oauth-access-token") + }) + + t.Run("fails when token type claim is missing", func(t *testing.T) { + // Create a token without a type claim + token := jwt.New() + + // Create a validator function + validator := TokenTypeValidator(AccessTokenJWTType) + + // Validate the token + err := validator(ctx, token) + require.Error(t, err, "Validator should reject token without type claim") + assert.Contains(t, err.Error(), "failed to get token type claim") }) }