fix: token introspection authentication not handled correctly (#704)

Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Elias Schneider
2025-07-01 23:14:07 +02:00
committed by GitHub
parent 031181ad2a
commit aefb308536
2 changed files with 87 additions and 43 deletions

View File

@@ -255,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback() tx.Rollback()
}() }()
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) _, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -336,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback() tx.Rollback()
}() }()
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -420,7 +420,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
tx.Rollback() tx.Rollback()
}() }()
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -490,6 +490,11 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
} }
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds, false)
if err != nil {
return introspectDto, err
}
// Get the type of the token and the client ID // Get the type of the token and the client ID
tokenType, token, err := s.jwtService.GetTokenType(tokenString) tokenType, token, err := s.jwtService.GetTokenType(tokenString)
if err != nil { if err != nil {
@@ -498,24 +503,16 @@ func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCrede
return introspectDto, nil //nolint:nilerr return introspectDto, nil //nolint:nilerr
} }
// If we don't have a client ID, get it from the token // Get the audience from the token
// Otherwise, we need to make sure that the client ID passed as credential matches
tokenAudiences, _ := token.Audience() tokenAudiences, _ := token.Audience()
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" { if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
// We just treat the token as invalid
introspectDto.Active = false introspectDto.Active = false
return introspectDto, nil return introspectDto, nil
} }
if creds.ClientID == "" {
creds.ClientID = tokenAudiences[0]
} else if creds.ClientID != tokenAudiences[0] {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
// Verify the credentials for the call // Audience must match the client ID
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds) if client.ID != tokenAudiences[0] {
if err != nil { return introspectDto, &common.OidcMissingClientCredentialsError{}
return introspectDto, err
} }
// Introspect the token // Introspect the token
@@ -1137,7 +1134,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
ClientSecret: input.ClientSecret, ClientSecret: input.ClientSecret,
ClientAssertionType: input.ClientAssertionType, ClientAssertionType: input.ClientAssertionType,
ClientAssertion: input.ClientAssertion, ClientAssertion: input.ClientAssertion,
}) }, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1385,24 +1382,39 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client
} }
} }
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) { func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
// First, ensure we have a valid client ID isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
if input.ClientID == "" {
// Determine the client ID based on the authentication method
var clientID string
switch {
case isClientAssertion:
// Extract client ID from the JWT assertion's 'sub' claim
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
if err != nil {
slog.Error("Failed to extract client ID from assertion", "error", err)
return nil, &common.OidcClientAssertionInvalidError{}
}
case input.ClientID != "":
// Use the provided client ID for other authentication methods
clientID = input.ClientID
default:
return nil, &common.OidcMissingClientCredentialsError{} return nil, &common.OidcMissingClientCredentialsError{}
} }
// Load the OIDC client's configuration // Load the OIDC client's configuration
var client model.OidcClient err = tx.
err := tx.
WithContext(ctx). WithContext(ctx).
First(&client, "id = ?", input.ClientID). First(&client, "id = ?", clientID).
Error Error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
return nil, &common.OidcClientAssertionInvalidError{}
}
return nil, err return nil, err
} }
// We have 3 options // Validate credentials based on the authentication method
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
switch { switch {
// First, if we have a client secret, we validate it // First, if we have a client secret, we validate it
case input.ClientSecret != "": case input.ClientSecret != "":
@@ -1410,21 +1422,21 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
if err != nil { if err != nil {
return nil, &common.OidcClientSecretInvalidError{} return nil, &common.OidcClientSecretInvalidError{}
} }
return &client, nil return client, nil
// Next, check if we want to use client assertions from federated identities // Next, check if we want to use client assertions from federated identities
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "": case isClientAssertion:
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input) err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
if err != nil { if err != nil {
log.Printf("Invalid assertion for client '%s': %v", client.ID, err) log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
return nil, &common.OidcClientAssertionInvalidError{} return nil, &common.OidcClientAssertionInvalidError{}
} }
return &client, nil return client, nil
// There's no credentials // There's no credentials
// This is allowed only if the client is public // This is allowed only if the client is public
case client.IsPublic: case client.IsPublic && allowPublicClientsWithoutAuth:
return &client, nil return client, nil
// If we're here, we have no credentials AND the client is not public, so credentials are required // If we're here, we have no credentials AND the client is not public, so credentials are required
default: default:
@@ -1523,6 +1535,23 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
return nil return nil
} }
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
// Parse the JWT without verification first to get the claims
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
if err != nil {
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
}
// Extract the subject claim which must be the client_id according to RFC 7523
sub, ok := insecureToken.Subject()
if !ok || sub == "" {
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
}
return sub, nil
}
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) { func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
tx := s.db.Begin() tx := s.db.Begin()
defer func() { defer func() {

View File

@@ -134,7 +134,6 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
const ( const (
federatedClientIssuer = "https://external-idp.com" federatedClientIssuer = "https://external-idp.com"
federatedClientAudience = "https://pocket-id.com" federatedClientAudience = "https://pocket-id.com"
federatedClientSubject = "123456abcdef"
federatedClientIssuerDefaults = "https://external-idp-defaults.com/" federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
) )
@@ -192,18 +191,24 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{ federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Federated Client", Name: "Federated Client",
CallbackURLs: []string{"https://example.com/callback"}, CallbackURLs: []string{"https://example.com/callback"},
}, "test-user-id")
require.NoError(t, err)
federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientCreateDto{
Name: federatedClient.Name,
CallbackURLs: federatedClient.CallbackURLs,
Credentials: dto.OidcClientCredentialsDto{ Credentials: dto.OidcClientCredentialsDto{
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{ FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
{ {
Issuer: federatedClientIssuer, Issuer: federatedClientIssuer,
Audience: federatedClientAudience, Audience: federatedClientAudience,
Subject: federatedClientSubject, Subject: federatedClient.ID,
JWKS: federatedClientIssuer + "/jwks.json", JWKS: federatedClientIssuer + "/jwks.json",
}, },
{Issuer: federatedClientIssuerDefaults}, {Issuer: federatedClientIssuerDefaults},
}, },
}, },
}, "test-user-id") })
require.NoError(t, err) require.NoError(t, err)
// Test cases for confidential client (using client secret) // Test cases for confidential client (using client secret)
@@ -213,7 +218,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID, ClientID: confidentialClient.ID,
ClientSecret: confidentialSecret, ClientSecret: confidentialSecret,
}) }, true)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, client) require.NotNil(t, client)
assert.Equal(t, confidentialClient.ID, client.ID) assert.Equal(t, confidentialClient.ID, client.ID)
@@ -224,7 +229,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID, ClientID: confidentialClient.ID,
ClientSecret: "invalid-secret", ClientSecret: "invalid-secret",
}) }, true)
require.Error(t, err) require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{}) require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
assert.Nil(t, client) assert.Nil(t, client)
@@ -234,7 +239,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
// Test with missing client secret // Test with missing client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID, ClientID: confidentialClient.ID,
}) }, true)
require.Error(t, err) require.Error(t, err)
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{}) require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
assert.Nil(t, client) assert.Nil(t, client)
@@ -247,11 +252,21 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
// Public clients don't require client secret // Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID, ClientID: publicClient.ID,
}) }, true)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, client) require.NotNil(t, client)
assert.Equal(t, publicClient.ID, client.ID) assert.Equal(t, publicClient.ID, client.ID)
}) })
t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) {
// Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID,
}, false)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
assert.Nil(t, client)
})
}) })
// Test cases for federated client using JWT assertion // Test cases for federated client using JWT assertion
@@ -261,7 +276,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
token, err := jwt.NewBuilder(). token, err := jwt.NewBuilder().
Issuer(federatedClientIssuer). Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}). Audience([]string{federatedClientAudience}).
Subject(federatedClientSubject). Subject(federatedClient.ID).
IssuedAt(time.Now()). IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)). Expiration(time.Now().Add(10 * time.Minute)).
Build() Build()
@@ -274,7 +289,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken), ClientAssertion: string(signedToken),
}) }, true)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, client) require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID) assert.Equal(t, federatedClient.ID, client.ID)
@@ -286,7 +301,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: "invalid.jwt.token", ClientAssertion: "invalid.jwt.token",
}) }, true)
require.Error(t, err) require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{}) require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
assert.Nil(t, client) assert.Nil(t, client)
@@ -298,7 +313,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
builder := jwt.NewBuilder(). builder := jwt.NewBuilder().
Issuer(federatedClientIssuer). Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}). Audience([]string{federatedClientAudience}).
Subject(federatedClientSubject). Subject(federatedClient.ID).
IssuedAt(time.Now()). IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)) Expiration(time.Now().Add(10 * time.Minute))
@@ -315,7 +330,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken), ClientAssertion: string(signedToken),
}) }, true)
require.Error(t, err) require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{}) require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
require.Nil(t, client) require.Nil(t, client)
@@ -356,7 +371,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken), ClientAssertion: string(signedToken),
}) }, true)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, client) require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID) assert.Equal(t, federatedClient.ID, client.ID)