mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-17 01:11:38 +03:00
fix: define token type as claim for better client compatibility
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
@@ -11,11 +12,8 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jws"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
@@ -40,11 +38,17 @@ const (
|
|||||||
// This may be omitted on non-admin tokens
|
// This may be omitted on non-admin tokens
|
||||||
IsAdminClaim = "isAdmin"
|
IsAdminClaim = "isAdmin"
|
||||||
|
|
||||||
// AccessTokenJWTType is the media type for access tokens
|
// TokenTypeClaim is the claim used to identify the type of token
|
||||||
AccessTokenJWTType = "AT+JWT"
|
TokenTypeClaim = "type"
|
||||||
|
|
||||||
// IDTokenJWTType is the media type for ID tokens
|
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
|
||||||
IDTokenJWTType = "ID+JWT"
|
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
|
||||||
|
|
||||||
|
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
|
||||||
|
AccessTokenJWTType = "access-token"
|
||||||
|
|
||||||
|
// IDTokenJWTType identifies a JWT as an ID token used by Pocket ID
|
||||||
|
IDTokenJWTType = "id-token"
|
||||||
|
|
||||||
// Acceptable clock skew for verifying tokens
|
// Acceptable clock skew for verifying tokens
|
||||||
clockSkew = time.Minute
|
clockSkew = time.Minute
|
||||||
@@ -195,6 +199,11 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
|||||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = SetTokenType(token, AccessTokenJWTType)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = SetIsAdmin(token, user.IsAdmin)
|
err = SetIsAdmin(token, user.IsAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
|
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
|
||||||
@@ -218,6 +227,7 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
|||||||
jwt.WithAcceptableSkew(clockSkew),
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
jwt.WithAudience(common.EnvConfig.AppURL),
|
jwt.WithAudience(common.EnvConfig.AppURL),
|
||||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||||
|
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||||
@@ -242,6 +252,11 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
|
|||||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = SetTokenType(token, IDTokenJWTType)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for k, v := range userClaims {
|
for k, v := range userClaims {
|
||||||
err = token.Set(k, v)
|
err = token.Set(k, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -256,13 +271,8 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
headers, err := CreateTokenTypeHeader(IDTokenJWTType)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to set token type: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
alg, _ := s.privateKey.Algorithm()
|
alg, _ := s.privateKey.Algorithm()
|
||||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
|
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -281,6 +291,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
|||||||
jwt.WithKey(alg, s.privateKey),
|
jwt.WithKey(alg, s.privateKey),
|
||||||
jwt.WithAcceptableSkew(clockSkew),
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||||
|
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
|
||||||
)
|
)
|
||||||
|
|
||||||
// By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp"
|
// By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp"
|
||||||
@@ -299,11 +310,6 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
|||||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = VerifyTokenTypeHeader(tokenString, IDTokenJWTType)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to verify token type: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,13 +330,13 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
|||||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
headers, err := CreateTokenTypeHeader(AccessTokenJWTType)
|
err = SetTokenType(token, OAuthAccessTokenJWTType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to set token type: %w", err)
|
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
alg, _ := s.privateKey.Algorithm()
|
alg, _ := s.privateKey.Algorithm()
|
||||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey, jws.WithProtectedHeaders(headers)))
|
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||||
}
|
}
|
||||||
@@ -346,16 +352,12 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
|
|||||||
jwt.WithKey(alg, s.privateKey),
|
jwt.WithKey(alg, s.privateKey),
|
||||||
jwt.WithAcceptableSkew(clockSkew),
|
jwt.WithAcceptableSkew(clockSkew),
|
||||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||||
|
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = VerifyTokenTypeHeader(tokenString, AccessTokenJWTType)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to verify token type: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -510,15 +512,12 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
|
|||||||
return isAdmin, err
|
return isAdmin, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTokenTypeHeader creates a new JWS header with the given token type
|
// SetTokenType sets the "type" claim in the token
|
||||||
func CreateTokenTypeHeader(tokenType string) (jws.Headers, error) {
|
func SetTokenType(token jwt.Token, tokenType string) error {
|
||||||
headers := jws.NewHeaders()
|
if tokenType == "" {
|
||||||
err := headers.Set(jws.TypeKey, tokenType)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set token type: %w", err)
|
|
||||||
}
|
}
|
||||||
|
return token.Set(TokenTypeClaim, tokenType)
|
||||||
return headers, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetIsAdmin sets the "isAdmin" claim in the token
|
// SetIsAdmin sets the "isAdmin" claim in the token
|
||||||
@@ -536,36 +535,17 @@ func SetAudienceString(token jwt.Token, audience string) error {
|
|||||||
return token.Set(jwt.AudienceKey, audience)
|
return token.Set(jwt.AudienceKey, audience)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyTokenTypeHeader verifies that the "typ" header in the token matches the expected type
|
// TokenTypeValidator is a validator function that checks the "type" claim in the token
|
||||||
func VerifyTokenTypeHeader(tokenBytes string, expectedTokenType string) error {
|
func TokenTypeValidator(expectedTokenType string) jwt.ValidatorFunc {
|
||||||
// Parse the raw token string purely as a JWS message structure
|
return func(_ context.Context, t jwt.Token) error {
|
||||||
// We don't need to verify the signature at this stage, just inspect headers.
|
var tokenType string
|
||||||
msg, err := jws.Parse([]byte(tokenBytes))
|
err := t.Get(TokenTypeClaim, &tokenType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse token as JWS message: %w", err)
|
return fmt.Errorf("failed to get token type claim: %w", err)
|
||||||
|
}
|
||||||
|
if tokenType != expectedTokenType {
|
||||||
|
return fmt.Errorf("invalid token type: expected %s, got %s", expectedTokenType, tokenType)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the list of signatures attached to the message. Usually just one for JWT.
|
|
||||||
signatures := msg.Signatures()
|
|
||||||
if len(signatures) == 0 {
|
|
||||||
return errors.New("JWS message contains no signatures")
|
|
||||||
}
|
|
||||||
|
|
||||||
protectedHeaders := signatures[0].ProtectedHeaders()
|
|
||||||
if protectedHeaders == nil {
|
|
||||||
return fmt.Errorf("JWS signature has no protected headers")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve the 'typ' header value from the PROTECTED headers.
|
|
||||||
var typHeaderValue string
|
|
||||||
err = protectedHeaders.Get(jws.TypeKey, &typHeaderValue)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("token is missing required protected header '%s'", jws.TypeKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.EqualFold(typHeaderValue, expectedTokenType) {
|
|
||||||
return fmt.Errorf("'%s' header mismatch: expected '%s', got '%s'", jws.TypeKey, expectedTokenType, typHeaderValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,18 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jws"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
@@ -636,6 +634,9 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
Build()
|
Build()
|
||||||
require.NoError(t, err, "Failed to build token")
|
require.NoError(t, err, "Failed to build token")
|
||||||
|
|
||||||
|
err = SetTokenType(token, IDTokenJWTType)
|
||||||
|
require.NoError(t, err, "Failed to set token type")
|
||||||
|
|
||||||
// Add custom claims
|
// Add custom claims
|
||||||
for k, v := range userClaims {
|
for k, v := range userClaims {
|
||||||
if k != "sub" { // Already set above
|
if k != "sub" { // Already set above
|
||||||
@@ -644,13 +645,8 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create headers with the specified type
|
|
||||||
hdrs := jws.NewHeaders()
|
|
||||||
err = hdrs.Set(jws.TypeKey, "ID+JWT")
|
|
||||||
require.NoError(t, err, "Failed to set header type")
|
|
||||||
|
|
||||||
// Sign the token
|
// Sign the token
|
||||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
|
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||||
require.NoError(t, err, "Failed to sign token")
|
require.NoError(t, err, "Failed to sign token")
|
||||||
tokenString := string(signed)
|
tokenString := string(signed)
|
||||||
|
|
||||||
@@ -968,6 +964,9 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
|||||||
Build()
|
Build()
|
||||||
require.NoError(t, err, "Failed to build token")
|
require.NoError(t, err, "Failed to build token")
|
||||||
|
|
||||||
|
err = SetTokenType(token, OAuthAccessTokenJWTType)
|
||||||
|
require.NoError(t, err, "Failed to set token type")
|
||||||
|
|
||||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||||
require.NoError(t, err, "Failed to sign token")
|
require.NoError(t, err, "Failed to sign token")
|
||||||
|
|
||||||
@@ -1168,59 +1167,50 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyTokenTypeHeader(t *testing.T) {
|
func TestTokenTypeValidator(t *testing.T) {
|
||||||
mockConfig := &AppConfigService{}
|
// Create a context for the validator function
|
||||||
tempDir := t.TempDir()
|
ctx := context.Background()
|
||||||
// Helper function to create a token with a specific type header
|
|
||||||
createTokenWithType := func(tokenType string) (string, error) {
|
|
||||||
// Create a simple JWT token
|
|
||||||
token := jwt.New()
|
|
||||||
err := token.Set("test_claim", "test_value")
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to set claim: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create headers with the specified type
|
|
||||||
hdrs := jws.NewHeaders()
|
|
||||||
if tokenType != "" {
|
|
||||||
err = hdrs.Set(jws.TypeKey, tokenType)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to set type header: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign the token with the headers
|
|
||||||
service := &JwtService{}
|
|
||||||
err = service.init(mockConfig, tempDir)
|
|
||||||
require.NoError(t, err, "Failed to initialize JWT service")
|
|
||||||
|
|
||||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(signed), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("succeeds when token type matches expected type", func(t *testing.T) {
|
t.Run("succeeds when token type matches expected type", func(t *testing.T) {
|
||||||
// Create a token with "JWT" type
|
// Create a token with the expected type
|
||||||
tokenString, err := createTokenWithType("JWT")
|
token := jwt.New()
|
||||||
require.NoError(t, err, "Failed to create test token")
|
err := token.Set(TokenTypeClaim, AccessTokenJWTType)
|
||||||
|
require.NoError(t, err, "Failed to set token type claim")
|
||||||
|
|
||||||
// Verify the token type
|
// Create a validator function for the expected type
|
||||||
err = VerifyTokenTypeHeader(tokenString, "JWT")
|
validator := TokenTypeValidator(AccessTokenJWTType)
|
||||||
assert.NoError(t, err, "Should accept token with matching type")
|
|
||||||
|
// Validate the token
|
||||||
|
err = validator(ctx, token)
|
||||||
|
assert.NoError(t, err, "Validator should accept token with matching type")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("fails when token type doesn't match expected type", func(t *testing.T) {
|
t.Run("fails when token type doesn't match expected type", func(t *testing.T) {
|
||||||
// Create a token with "AT+JWT" type
|
// Create a token with a different type
|
||||||
tokenString, err := createTokenWithType("AT+JWT")
|
token := jwt.New()
|
||||||
require.NoError(t, err, "Failed to create test token")
|
err := token.Set(TokenTypeClaim, OAuthAccessTokenJWTType)
|
||||||
|
require.NoError(t, err, "Failed to set token type claim")
|
||||||
|
|
||||||
// Verify the token with different expected type
|
// Create a validator function for a different expected type
|
||||||
err = VerifyTokenTypeHeader(tokenString, "JWT")
|
validator := TokenTypeValidator(IDTokenJWTType)
|
||||||
require.Error(t, err, "Should reject token with non-matching type")
|
|
||||||
assert.Contains(t, err.Error(), "header mismatch: expected 'JWT', got 'AT+JWT'")
|
// Validate the token
|
||||||
|
err = validator(ctx, token)
|
||||||
|
require.Error(t, err, "Validator should reject token with non-matching type")
|
||||||
|
assert.Contains(t, err.Error(), "invalid token type: expected id-token, got oauth-access-token")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("fails when token type claim is missing", func(t *testing.T) {
|
||||||
|
// Create a token without a type claim
|
||||||
|
token := jwt.New()
|
||||||
|
|
||||||
|
// Create a validator function
|
||||||
|
validator := TokenTypeValidator(AccessTokenJWTType)
|
||||||
|
|
||||||
|
// Validate the token
|
||||||
|
err := validator(ctx, token)
|
||||||
|
require.Error(t, err, "Validator should reject token without type claim")
|
||||||
|
assert.Contains(t, err.Error(), "failed to get token type claim")
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user