mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 17:23:24 +03:00
fix: define token type as claim for better client compatibility
This commit is contained in:
@@ -1,20 +1,18 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
@@ -636,6 +634,9 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
err = SetTokenType(token, IDTokenJWTType)
|
||||
require.NoError(t, err, "Failed to set token type")
|
||||
|
||||
// Add custom claims
|
||||
for k, v := range userClaims {
|
||||
if k != "sub" { // Already set above
|
||||
@@ -644,13 +645,8 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create headers with the specified type
|
||||
hdrs := jws.NewHeaders()
|
||||
err = hdrs.Set(jws.TypeKey, "ID+JWT")
|
||||
require.NoError(t, err, "Failed to set header type")
|
||||
|
||||
// Sign the token
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
tokenString := string(signed)
|
||||
|
||||
@@ -968,6 +964,9 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
err = SetTokenType(token, OAuthAccessTokenJWTType)
|
||||
require.NoError(t, err, "Failed to set token type")
|
||||
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
|
||||
@@ -1168,59 +1167,50 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifyTokenTypeHeader(t *testing.T) {
|
||||
mockConfig := &AppConfigService{}
|
||||
tempDir := t.TempDir()
|
||||
// Helper function to create a token with a specific type header
|
||||
createTokenWithType := func(tokenType string) (string, error) {
|
||||
// Create a simple JWT token
|
||||
token := jwt.New()
|
||||
err := token.Set("test_claim", "test_value")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim: %w", err)
|
||||
}
|
||||
|
||||
// Create headers with the specified type
|
||||
hdrs := jws.NewHeaders()
|
||||
if tokenType != "" {
|
||||
err = hdrs.Set(jws.TypeKey, tokenType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set type header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sign the token with the headers
|
||||
service := &JwtService{}
|
||||
err = service.init(mockConfig, tempDir)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey, jws.WithProtectedHeaders(hdrs)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
func TestTokenTypeValidator(t *testing.T) {
|
||||
// Create a context for the validator function
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("succeeds when token type matches expected type", func(t *testing.T) {
|
||||
// Create a token with "JWT" type
|
||||
tokenString, err := createTokenWithType("JWT")
|
||||
require.NoError(t, err, "Failed to create test token")
|
||||
// Create a token with the expected type
|
||||
token := jwt.New()
|
||||
err := token.Set(TokenTypeClaim, AccessTokenJWTType)
|
||||
require.NoError(t, err, "Failed to set token type claim")
|
||||
|
||||
// Verify the token type
|
||||
err = VerifyTokenTypeHeader(tokenString, "JWT")
|
||||
assert.NoError(t, err, "Should accept token with matching type")
|
||||
// Create a validator function for the expected type
|
||||
validator := TokenTypeValidator(AccessTokenJWTType)
|
||||
|
||||
// Validate the token
|
||||
err = validator(ctx, token)
|
||||
assert.NoError(t, err, "Validator should accept token with matching type")
|
||||
})
|
||||
|
||||
t.Run("fails when token type doesn't match expected type", func(t *testing.T) {
|
||||
// Create a token with "AT+JWT" type
|
||||
tokenString, err := createTokenWithType("AT+JWT")
|
||||
require.NoError(t, err, "Failed to create test token")
|
||||
// Create a token with a different type
|
||||
token := jwt.New()
|
||||
err := token.Set(TokenTypeClaim, OAuthAccessTokenJWTType)
|
||||
require.NoError(t, err, "Failed to set token type claim")
|
||||
|
||||
// Verify the token with different expected type
|
||||
err = VerifyTokenTypeHeader(tokenString, "JWT")
|
||||
require.Error(t, err, "Should reject token with non-matching type")
|
||||
assert.Contains(t, err.Error(), "header mismatch: expected 'JWT', got 'AT+JWT'")
|
||||
// Create a validator function for a different expected type
|
||||
validator := TokenTypeValidator(IDTokenJWTType)
|
||||
|
||||
// Validate the token
|
||||
err = validator(ctx, token)
|
||||
require.Error(t, err, "Validator should reject token with non-matching type")
|
||||
assert.Contains(t, err.Error(), "invalid token type: expected id-token, got oauth-access-token")
|
||||
})
|
||||
|
||||
t.Run("fails when token type claim is missing", func(t *testing.T) {
|
||||
// Create a token without a type claim
|
||||
token := jwt.New()
|
||||
|
||||
// Create a validator function
|
||||
validator := TokenTypeValidator(AccessTokenJWTType)
|
||||
|
||||
// Validate the token
|
||||
err := validator(ctx, token)
|
||||
require.Error(t, err, "Validator should reject token without type claim")
|
||||
assert.Contains(t, err.Error(), "failed to get token type claim")
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user