mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-21 17:25:44 +03:00
fix: token introspection authentication not handled correctly (#704)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
@@ -255,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return CreatedTokens{}, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
@@ -336,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return CreatedTokens{}, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
@@ -420,7 +420,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return CreatedTokens{}, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
@@ -490,6 +490,11 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||||
|
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds, false)
|
||||||
|
if err != nil {
|
||||||
|
return introspectDto, err
|
||||||
|
}
|
||||||
|
|
||||||
// Get the type of the token and the client ID
|
// Get the type of the token and the client ID
|
||||||
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
|
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -498,24 +503,16 @@ func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCrede
|
|||||||
return introspectDto, nil //nolint:nilerr
|
return introspectDto, nil //nolint:nilerr
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we don't have a client ID, get it from the token
|
// Get the audience from the token
|
||||||
// Otherwise, we need to make sure that the client ID passed as credential matches
|
|
||||||
tokenAudiences, _ := token.Audience()
|
tokenAudiences, _ := token.Audience()
|
||||||
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
|
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
|
||||||
// We just treat the token as invalid
|
|
||||||
introspectDto.Active = false
|
introspectDto.Active = false
|
||||||
return introspectDto, nil
|
return introspectDto, nil
|
||||||
}
|
}
|
||||||
if creds.ClientID == "" {
|
|
||||||
creds.ClientID = tokenAudiences[0]
|
|
||||||
} else if creds.ClientID != tokenAudiences[0] {
|
|
||||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the credentials for the call
|
// Audience must match the client ID
|
||||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
|
if client.ID != tokenAudiences[0] {
|
||||||
if err != nil {
|
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||||
return introspectDto, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Introspect the token
|
// Introspect the token
|
||||||
@@ -1137,7 +1134,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
|||||||
ClientSecret: input.ClientSecret,
|
ClientSecret: input.ClientSecret,
|
||||||
ClientAssertionType: input.ClientAssertionType,
|
ClientAssertionType: input.ClientAssertionType,
|
||||||
ClientAssertion: input.ClientAssertion,
|
ClientAssertion: input.ClientAssertion,
|
||||||
})
|
}, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1385,24 +1382,39 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
|
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) {
|
||||||
// First, ensure we have a valid client ID
|
isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != ""
|
||||||
if input.ClientID == "" {
|
|
||||||
|
// Determine the client ID based on the authentication method
|
||||||
|
var clientID string
|
||||||
|
switch {
|
||||||
|
case isClientAssertion:
|
||||||
|
// Extract client ID from the JWT assertion's 'sub' claim
|
||||||
|
clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Failed to extract client ID from assertion", "error", err)
|
||||||
|
return nil, &common.OidcClientAssertionInvalidError{}
|
||||||
|
}
|
||||||
|
case input.ClientID != "":
|
||||||
|
// Use the provided client ID for other authentication methods
|
||||||
|
clientID = input.ClientID
|
||||||
|
default:
|
||||||
return nil, &common.OidcMissingClientCredentialsError{}
|
return nil, &common.OidcMissingClientCredentialsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the OIDC client's configuration
|
// Load the OIDC client's configuration
|
||||||
var client model.OidcClient
|
err = tx.
|
||||||
err := tx.
|
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
First(&client, "id = ?", input.ClientID).
|
First(&client, "id = ?", clientID).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion {
|
||||||
|
return nil, &common.OidcClientAssertionInvalidError{}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// We have 3 options
|
// Validate credentials based on the authentication method
|
||||||
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
|
|
||||||
switch {
|
switch {
|
||||||
// First, if we have a client secret, we validate it
|
// First, if we have a client secret, we validate it
|
||||||
case input.ClientSecret != "":
|
case input.ClientSecret != "":
|
||||||
@@ -1410,21 +1422,21 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &common.OidcClientSecretInvalidError{}
|
return nil, &common.OidcClientSecretInvalidError{}
|
||||||
}
|
}
|
||||||
return &client, nil
|
return client, nil
|
||||||
|
|
||||||
// Next, check if we want to use client assertions from federated identities
|
// Next, check if we want to use client assertions from federated identities
|
||||||
case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "":
|
case isClientAssertion:
|
||||||
err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input)
|
err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
|
log.Printf("Invalid assertion for client '%s': %v", client.ID, err)
|
||||||
return nil, &common.OidcClientAssertionInvalidError{}
|
return nil, &common.OidcClientAssertionInvalidError{}
|
||||||
}
|
}
|
||||||
return &client, nil
|
return client, nil
|
||||||
|
|
||||||
// There's no credentials
|
// There's no credentials
|
||||||
// This is allowed only if the client is public
|
// This is allowed only if the client is public
|
||||||
case client.IsPublic:
|
case client.IsPublic && allowPublicClientsWithoutAuth:
|
||||||
return &client, nil
|
return client, nil
|
||||||
|
|
||||||
// If we're here, we have no credentials AND the client is not public, so credentials are required
|
// If we're here, we have no credentials AND the client is not public, so credentials are required
|
||||||
default:
|
default:
|
||||||
@@ -1523,6 +1535,23 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim
|
||||||
|
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
|
||||||
|
// Parse the JWT without verification first to get the claims
|
||||||
|
insecureToken, err := jwt.ParseInsecure([]byte(assertion))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the subject claim which must be the client_id according to RFC 7523
|
||||||
|
sub, ok := insecureToken.Subject()
|
||||||
|
if !ok || sub == "" {
|
||||||
|
return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sub, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
|
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -134,7 +134,6 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
const (
|
const (
|
||||||
federatedClientIssuer = "https://external-idp.com"
|
federatedClientIssuer = "https://external-idp.com"
|
||||||
federatedClientAudience = "https://pocket-id.com"
|
federatedClientAudience = "https://pocket-id.com"
|
||||||
federatedClientSubject = "123456abcdef"
|
|
||||||
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
|
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -192,18 +191,24 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||||
Name: "Federated Client",
|
Name: "Federated Client",
|
||||||
CallbackURLs: []string{"https://example.com/callback"},
|
CallbackURLs: []string{"https://example.com/callback"},
|
||||||
|
}, "test-user-id")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientCreateDto{
|
||||||
|
Name: federatedClient.Name,
|
||||||
|
CallbackURLs: federatedClient.CallbackURLs,
|
||||||
Credentials: dto.OidcClientCredentialsDto{
|
Credentials: dto.OidcClientCredentialsDto{
|
||||||
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
|
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
|
||||||
{
|
{
|
||||||
Issuer: federatedClientIssuer,
|
Issuer: federatedClientIssuer,
|
||||||
Audience: federatedClientAudience,
|
Audience: federatedClientAudience,
|
||||||
Subject: federatedClientSubject,
|
Subject: federatedClient.ID,
|
||||||
JWKS: federatedClientIssuer + "/jwks.json",
|
JWKS: federatedClientIssuer + "/jwks.json",
|
||||||
},
|
},
|
||||||
{Issuer: federatedClientIssuerDefaults},
|
{Issuer: federatedClientIssuerDefaults},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, "test-user-id")
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test cases for confidential client (using client secret)
|
// Test cases for confidential client (using client secret)
|
||||||
@@ -213,7 +218,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||||
ClientID: confidentialClient.ID,
|
ClientID: confidentialClient.ID,
|
||||||
ClientSecret: confidentialSecret,
|
ClientSecret: confidentialSecret,
|
||||||
})
|
}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, client)
|
require.NotNil(t, client)
|
||||||
assert.Equal(t, confidentialClient.ID, client.ID)
|
assert.Equal(t, confidentialClient.ID, client.ID)
|
||||||
@@ -224,7 +229,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||||
ClientID: confidentialClient.ID,
|
ClientID: confidentialClient.ID,
|
||||||
ClientSecret: "invalid-secret",
|
ClientSecret: "invalid-secret",
|
||||||
})
|
}, true)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
||||||
assert.Nil(t, client)
|
assert.Nil(t, client)
|
||||||
@@ -234,7 +239,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
// Test with missing client secret
|
// Test with missing client secret
|
||||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||||
ClientID: confidentialClient.ID,
|
ClientID: confidentialClient.ID,
|
||||||
})
|
}, true)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||||
assert.Nil(t, client)
|
assert.Nil(t, client)
|
||||||
@@ -247,11 +252,21 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
// Public clients don't require client secret
|
// Public clients don't require client secret
|
||||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||||
ClientID: publicClient.ID,
|
ClientID: publicClient.ID,
|
||||||
})
|
}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, client)
|
require.NotNil(t, client)
|
||||||
assert.Equal(t, publicClient.ID, client.ID)
|
assert.Equal(t, publicClient.ID, client.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) {
|
||||||
|
// Public clients don't require client secret
|
||||||
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||||
|
ClientID: publicClient.ID,
|
||||||
|
}, false)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||||
|
assert.Nil(t, client)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test cases for federated client using JWT assertion
|
// Test cases for federated client using JWT assertion
|
||||||
@@ -261,7 +276,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
token, err := jwt.NewBuilder().
|
token, err := jwt.NewBuilder().
|
||||||
Issuer(federatedClientIssuer).
|
Issuer(federatedClientIssuer).
|
||||||
Audience([]string{federatedClientAudience}).
|
Audience([]string{federatedClientAudience}).
|
||||||
Subject(federatedClientSubject).
|
Subject(federatedClient.ID).
|
||||||
IssuedAt(time.Now()).
|
IssuedAt(time.Now()).
|
||||||
Expiration(time.Now().Add(10 * time.Minute)).
|
Expiration(time.Now().Add(10 * time.Minute)).
|
||||||
Build()
|
Build()
|
||||||
@@ -274,7 +289,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
ClientID: federatedClient.ID,
|
ClientID: federatedClient.ID,
|
||||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||||
ClientAssertion: string(signedToken),
|
ClientAssertion: string(signedToken),
|
||||||
})
|
}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, client)
|
require.NotNil(t, client)
|
||||||
assert.Equal(t, federatedClient.ID, client.ID)
|
assert.Equal(t, federatedClient.ID, client.ID)
|
||||||
@@ -286,7 +301,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
ClientID: federatedClient.ID,
|
ClientID: federatedClient.ID,
|
||||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||||
ClientAssertion: "invalid.jwt.token",
|
ClientAssertion: "invalid.jwt.token",
|
||||||
})
|
}, true)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||||
assert.Nil(t, client)
|
assert.Nil(t, client)
|
||||||
@@ -298,7 +313,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
builder := jwt.NewBuilder().
|
builder := jwt.NewBuilder().
|
||||||
Issuer(federatedClientIssuer).
|
Issuer(federatedClientIssuer).
|
||||||
Audience([]string{federatedClientAudience}).
|
Audience([]string{federatedClientAudience}).
|
||||||
Subject(federatedClientSubject).
|
Subject(federatedClient.ID).
|
||||||
IssuedAt(time.Now()).
|
IssuedAt(time.Now()).
|
||||||
Expiration(time.Now().Add(10 * time.Minute))
|
Expiration(time.Now().Add(10 * time.Minute))
|
||||||
|
|
||||||
@@ -315,7 +330,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
ClientID: federatedClient.ID,
|
ClientID: federatedClient.ID,
|
||||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||||
ClientAssertion: string(signedToken),
|
ClientAssertion: string(signedToken),
|
||||||
})
|
}, true)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||||
require.Nil(t, client)
|
require.Nil(t, client)
|
||||||
@@ -356,7 +371,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
ClientID: federatedClient.ID,
|
ClientID: federatedClient.ID,
|
||||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||||
ClientAssertion: string(signedToken),
|
ClientAssertion: string(signedToken),
|
||||||
})
|
}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, client)
|
require.NotNil(t, client)
|
||||||
assert.Equal(t, federatedClient.ID, client.ID)
|
assert.Equal(t, federatedClient.ID, client.ID)
|
||||||
|
|||||||
Reference in New Issue
Block a user