fix: token introspection authentication not handled correctly (#704)

Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Elias Schneider
2025-07-01 23:14:07 +02:00
committed by GitHub
parent 031181ad2a
commit aefb308536
2 changed files with 87 additions and 43 deletions

View File

@@ -255,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
if err != nil {
return CreatedTokens{}, err
}
@@ -336,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback()
}()
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
if err != nil {
return CreatedTokens{}, err
}
@@ -420,7 +420,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
tx.Rollback()
}()
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
if err != nil {
return CreatedTokens{}, err
}
@@ -490,6 +490,11 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
}
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds, false)
if err != nil {
return introspectDto, err
}
// Get the type of the token and the client ID
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
if err != nil {
@@ -498,24 +503,16 @@ func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCrede
return introspectDto, nil //nolint:nilerr
}
// If we don't have a client ID, get it from the token
// Otherwise, we need to make sure that the client ID passed as credential matches
// Get the audience from the token
tokenAudiences, _ := token.Audience()
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil
}
if creds.ClientID == "" {
creds.ClientID = tokenAudiences[0]
} else if creds.ClientID != tokenAudiences[0] {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
// Verify the credentials for the call
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
if err != nil {
return introspectDto, err
// Audience must match the client ID
if client.ID != tokenAudiences[0] {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
// Introspect the token
@@ -1137,7 +1134,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
ClientSecret: input.ClientSecret,
ClientAssertionType: input.ClientAssertionType,
ClientAssertion: input.ClientAssertion,
})
}, true)
if err != nil {
return nil, err
}
@@ -1385,24 +1382,39 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client
}
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
// First, ensure we have a valid client ID
if input.ClientID == "" {
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
// Determine the client ID based on the authentication method
var clientID string
switch {
case isClientAssertion:
// Extract client ID from the JWT assertion's 'sub' claim
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
if err != nil {
slog.Error("Failed to extract client ID from assertion", "error", err)
return nil, &common.OidcClientAssertionInvalidError{}
}
case input.ClientID != "":
// Use the provided client ID for other authentication methods
clientID = input.ClientID
default:
return nil, &common.OidcMissingClientCredentialsError{}
}
// Load the OIDC client's configuration
var client model.OidcClient
err := tx.
err = tx.
WithContext(ctx).
First(&client, "id = ?", input.ClientID).
First(&client, "id = ?", clientID).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
return nil, &common.OidcClientAssertionInvalidError{}
}
return nil, err
}
// We have 3 options
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
// Validate credentials based on the authentication method
switch {
// First, if we have a client secret, we validate it
case input.ClientSecret != "":
@@ -1410,21 +1422,21 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
if err != nil {
return nil, &common.OidcClientSecretInvalidError{}
}
return &client, nil
return client, nil
// Next, check if we want to use client assertions from federated identities
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
case isClientAssertion:
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
if err != nil {
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
return nil, &common.OidcClientAssertionInvalidError{}
}
return &client, nil
return client, nil
// There's no credentials
// This is allowed only if the client is public
case client.IsPublic:
return &client, nil
case client.IsPublic && allowPublicClientsWithoutAuth:
return client, nil
// If we're here, we have no credentials AND the client is not public, so credentials are required
default:
@@ -1523,6 +1535,23 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
return nil
}
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
// Parse the JWT without verification first to get the claims
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
if err != nil {
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
}
// Extract the subject claim which must be the client_id according to RFC 7523
sub, ok := insecureToken.Subject()
if !ok || sub == "" {
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
}
return sub, nil
}
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
tx := s.db.Begin()
defer func() {

View File

@@ -134,7 +134,6 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
const (
federatedClientIssuer = "https://external-idp.com"
federatedClientAudience = "https://pocket-id.com"
federatedClientSubject = "123456abcdef"
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
)
@@ -192,18 +191,24 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Federated Client",
CallbackURLs: []string{"https://example.com/callback"},
}, "test-user-id")
require.NoError(t, err)
federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientCreateDto{
Name: federatedClient.Name,
CallbackURLs: federatedClient.CallbackURLs,
Credentials: dto.OidcClientCredentialsDto{
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
{
Issuer: federatedClientIssuer,
Audience: federatedClientAudience,
Subject: federatedClientSubject,
Subject: federatedClient.ID,
JWKS: federatedClientIssuer + "/jwks.json",
},
{Issuer: federatedClientIssuerDefaults},
},
},
}, "test-user-id")
})
require.NoError(t, err)
// Test cases for confidential client (using client secret)
@@ -213,7 +218,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
ClientSecret: confidentialSecret,
})
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, confidentialClient.ID, client.ID)
@@ -224,7 +229,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
ClientSecret: "invalid-secret",
})
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
assert.Nil(t, client)
@@ -234,7 +239,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
// Test with missing client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
})
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
assert.Nil(t, client)
@@ -247,11 +252,21 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
// Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID,
})
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, publicClient.ID, client.ID)
})
t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) {
// Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID,
}, false)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
assert.Nil(t, client)
})
})
// Test cases for federated client using JWT assertion
@@ -261,7 +276,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
token, err := jwt.NewBuilder().
Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}).
Subject(federatedClientSubject).
Subject(federatedClient.ID).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)).
Build()
@@ -274,7 +289,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
})
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID)
@@ -286,7 +301,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: "invalid.jwt.token",
})
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
assert.Nil(t, client)
@@ -298,7 +313,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
builder := jwt.NewBuilder().
Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}).
Subject(federatedClientSubject).
Subject(federatedClient.ID).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute))
@@ -315,7 +330,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
})
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
require.Nil(t, client)
@@ -356,7 +371,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
})
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID)