mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-21 17:25:44 +03:00
fix: token introspection authentication not handled correctly (#704)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
@@ -255,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -336,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -420,7 +420,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -490,6 +490,11 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}
|
||||
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds, false)
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
|
||||
// Get the type of the token and the client ID
|
||||
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
|
||||
if err != nil {
|
||||
@@ -498,24 +503,16 @@ func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCrede
|
||||
return introspectDto, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// If we don't have a client ID, get it from the token
|
||||
// Otherwise, we need to make sure that the client ID passed as credential matches
|
||||
// Get the audience from the token
|
||||
tokenAudiences, _ := token.Audience()
|
||||
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
if creds.ClientID == "" {
|
||||
creds.ClientID = tokenAudiences[0]
|
||||
} else if creds.ClientID != tokenAudiences[0] {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Verify the credentials for the call
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
// Audience must match the client ID
|
||||
if client.ID != tokenAudiences[0] {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Introspect the token
|
||||
@@ -1137,7 +1134,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
ClientSecret: input.ClientSecret,
|
||||
ClientAssertionType: input.ClientAssertionType,
|
||||
ClientAssertion: input.ClientAssertion,
|
||||
})
|
||||
}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1385,24 +1382,39 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
|
||||
// First, ensure we have a valid client ID
|
||||
if input.ClientID == "" {
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
|
||||
isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
|
||||
|
||||
// Determine the client ID based on the authentication method
|
||||
var clientID string
|
||||
switch {
|
||||
case isClientAssertion:
|
||||
// Extract client ID from the JWT assertion's 'sub' claim
|
||||
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
|
||||
if err != nil {
|
||||
slog.Error("Failed to extract client ID from assertion", "error", err)
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
case input.ClientID != "":
|
||||
// Use the provided client ID for other authentication methods
|
||||
clientID = input.ClientID
|
||||
default:
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Load the OIDC client's configuration
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", input.ClientID).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We have 3 options
|
||||
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
|
||||
// Validate credentials based on the authentication method
|
||||
switch {
|
||||
// First, if we have a client secret, we validate it
|
||||
case input.ClientSecret != "":
|
||||
@@ -1410,21 +1422,21 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
|
||||
if err != nil {
|
||||
return nil, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
return &client, nil
|
||||
return client, nil
|
||||
|
||||
// Next, check if we want to use client assertions from federated identities
|
||||
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
|
||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
|
||||
case isClientAssertion:
|
||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
|
||||
if err != nil {
|
||||
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
|
||||
return nil, &common.OidcClientAssertionInvalidError{}
|
||||
}
|
||||
return &client, nil
|
||||
return client, nil
|
||||
|
||||
// There's no credentials
|
||||
// This is allowed only if the client is public
|
||||
case client.IsPublic:
|
||||
return &client, nil
|
||||
case client.IsPublic && allowPublicClientsWithoutAuth:
|
||||
return client, nil
|
||||
|
||||
// If we're here, we have no credentials AND the client is not public, so credentials are required
|
||||
default:
|
||||
@@ -1523,6 +1535,23 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
|
||||
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
|
||||
// Parse the JWT without verification first to get the claims
|
||||
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
|
||||
}
|
||||
|
||||
// Extract the subject claim which must be the client_id according to RFC 7523
|
||||
sub, ok := insecureToken.Subject()
|
||||
if !ok || sub == "" {
|
||||
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
|
||||
@@ -134,7 +134,6 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
const (
|
||||
federatedClientIssuer = "https://external-idp.com"
|
||||
federatedClientAudience = "https://pocket-id.com"
|
||||
federatedClientSubject = "123456abcdef"
|
||||
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
|
||||
)
|
||||
|
||||
@@ -192,18 +191,24 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Federated Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientCreateDto{
|
||||
Name: federatedClient.Name,
|
||||
CallbackURLs: federatedClient.CallbackURLs,
|
||||
Credentials: dto.OidcClientCredentialsDto{
|
||||
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
|
||||
{
|
||||
Issuer: federatedClientIssuer,
|
||||
Audience: federatedClientAudience,
|
||||
Subject: federatedClientSubject,
|
||||
Subject: federatedClient.ID,
|
||||
JWKS: federatedClientIssuer + "/jwks.json",
|
||||
},
|
||||
{Issuer: federatedClientIssuerDefaults},
|
||||
},
|
||||
},
|
||||
}, "test-user-id")
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test cases for confidential client (using client secret)
|
||||
@@ -213,7 +218,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: confidentialSecret,
|
||||
})
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, confidentialClient.ID, client.ID)
|
||||
@@ -224,7 +229,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: "invalid-secret",
|
||||
})
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
@@ -234,7 +239,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
// Test with missing client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
})
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||
assert.Nil(t, client)
|
||||
@@ -247,11 +252,21 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
// Public clients don't require client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: publicClient.ID,
|
||||
})
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, publicClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) {
|
||||
// Public clients don't require client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: publicClient.ID,
|
||||
}, false)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
})
|
||||
|
||||
// Test cases for federated client using JWT assertion
|
||||
@@ -261,7 +276,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClientSubject).
|
||||
Subject(federatedClient.ID).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
@@ -274,7 +289,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
@@ -286,7 +301,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: "invalid.jwt.token",
|
||||
})
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
@@ -298,7 +313,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
builder := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClientSubject).
|
||||
Subject(federatedClient.ID).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute))
|
||||
|
||||
@@ -315,7 +330,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
}, true)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
require.Nil(t, client)
|
||||
@@ -356,7 +371,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
|
||||
Reference in New Issue
Block a user