diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 534863cc..fadfc68a 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -255,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O tx.Rollback() }() - _, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) + _, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true) if err != nil { return CreatedTokens{}, err } @@ -336,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu tx.Rollback() }() - client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) + client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true) if err != nil { return CreatedTokens{}, err } @@ -420,7 +420,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto tx.Rollback() }() - client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) + client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input), true) if err != nil { return CreatedTokens{}, err } @@ -490,6 +490,11 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto } func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { + client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds, false) + if err != nil { + return introspectDto, err + } + // Get the type of the token and the client ID tokenType, token, err := s.jwtService.GetTokenType(tokenString) if err != nil { @@ -498,24 +503,16 @@ func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCrede return introspectDto, nil //nolint:nilerr } - // If we don't have a client ID, get it from the token - // Otherwise, we need to make sure that the client ID passed as credential matches + // Get the audience from the token tokenAudiences, _ := token.Audience() if len(tokenAudiences) != 1 || tokenAudiences[0] == "" { - // We just treat the token as invalid introspectDto.Active = false return introspectDto, nil } - if creds.ClientID == "" { - creds.ClientID = tokenAudiences[0] - } else if creds.ClientID != tokenAudiences[0] { - return introspectDto, &common.OidcMissingClientCredentialsError{} - } - // Verify the credentials for the call - client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds) - if err != nil { - return introspectDto, err + // Audience must match the client ID + if client.ID != tokenAudiences[0] { + return introspectDto, &common.OidcMissingClientCredentialsError{} } // Introspect the token @@ -1137,7 +1134,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O ClientSecret: input.ClientSecret, ClientAssertionType: input.ClientAssertionType, ClientAssertion: input.ClientAssertion, - }) + }, true) if err != nil { return nil, err } @@ -1385,24 +1382,39 @@ func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) Client } } -func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) { - // First, ensure we have a valid client ID - if input.ClientID == "" { +func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials, allowPublicClientsWithoutAuth bool) (client *model.OidcClient, err error) { + isClientAssertion := input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "" + + // Determine the client ID based on the authentication method + var clientID string + switch { + case isClientAssertion: + // Extract client ID from the JWT assertion's 'sub' claim + clientID, err = s.extractClientIDFromAssertion(input.ClientAssertion) + if err != nil { + slog.Error("Failed to extract client ID from assertion", "error", err) + return nil, &common.OidcClientAssertionInvalidError{} + } + case input.ClientID != "": + // Use the provided client ID for other authentication methods + clientID = input.ClientID + default: return nil, &common.OidcMissingClientCredentialsError{} } // Load the OIDC client's configuration - var client model.OidcClient - err := tx. + err = tx. WithContext(ctx). - First(&client, "id = ?", input.ClientID). + First(&client, "id = ?", clientID). Error if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) && isClientAssertion { + return nil, &common.OidcClientAssertionInvalidError{} + } return nil, err } - // We have 3 options - // If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only + // Validate credentials based on the authentication method switch { // First, if we have a client secret, we validate it case input.ClientSecret != "": @@ -1410,21 +1422,21 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *g if err != nil { return nil, &common.OidcClientSecretInvalidError{} } - return &client, nil + return client, nil // Next, check if we want to use client assertions from federated identities - case input.ClientAssertionType == ClientAssertionTypeJWTBearer && input.ClientAssertion != "": - err = s.verifyClientAssertionFromFederatedIdentities(ctx, &client, input) + case isClientAssertion: + err = s.verifyClientAssertionFromFederatedIdentities(ctx, client, input) if err != nil { log.Printf("Invalid assertion for client '%s': %v", client.ID, err) return nil, &common.OidcClientAssertionInvalidError{} } - return &client, nil + return client, nil // There's no credentials // This is allowed only if the client is public - case client.IsPublic: - return &client, nil + case client.IsPublic && allowPublicClientsWithoutAuth: + return client, nil // If we're here, we have no credentials AND the client is not public, so credentials are required default: @@ -1523,6 +1535,23 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C return nil } +// extractClientIDFromAssertion extracts the client_id from the JWT assertion's 'sub' claim +func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) { + // Parse the JWT without verification first to get the claims + insecureToken, err := jwt.ParseInsecure([]byte(assertion)) + if err != nil { + return "", fmt.Errorf("failed to parse JWT assertion: %w", err) + } + + // Extract the subject claim which must be the client_id according to RFC 7523 + sub, ok := insecureToken.Subject() + if !ok || sub == "" { + return "", fmt.Errorf("missing or invalid 'sub' claim in JWT assertion") + } + + return sub, nil +} + func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) { tx := s.db.Begin() defer func() { diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index 2d2abdda..c642a4e9 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -134,7 +134,6 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { const ( federatedClientIssuer = "https://external-idp.com" federatedClientAudience = "https://pocket-id.com" - federatedClientSubject = "123456abcdef" federatedClientIssuerDefaults = "https://external-idp-defaults.com/" ) @@ -192,18 +191,24 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{ Name: "Federated Client", CallbackURLs: []string{"https://example.com/callback"}, + }, "test-user-id") + require.NoError(t, err) + + federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientCreateDto{ + Name: federatedClient.Name, + CallbackURLs: federatedClient.CallbackURLs, Credentials: dto.OidcClientCredentialsDto{ FederatedIdentities: []dto.OidcClientFederatedIdentityDto{ { Issuer: federatedClientIssuer, Audience: federatedClientAudience, - Subject: federatedClientSubject, + Subject: federatedClient.ID, JWKS: federatedClientIssuer + "/jwks.json", }, {Issuer: federatedClientIssuerDefaults}, }, }, - }, "test-user-id") + }) require.NoError(t, err) // Test cases for confidential client (using client secret) @@ -213,7 +218,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: confidentialClient.ID, ClientSecret: confidentialSecret, - }) + }, true) require.NoError(t, err) require.NotNil(t, client) assert.Equal(t, confidentialClient.ID, client.ID) @@ -224,7 +229,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: confidentialClient.ID, ClientSecret: "invalid-secret", - }) + }, true) require.Error(t, err) require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{}) assert.Nil(t, client) @@ -234,7 +239,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { // Test with missing client secret client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: confidentialClient.ID, - }) + }, true) require.Error(t, err) require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{}) assert.Nil(t, client) @@ -247,11 +252,21 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { // Public clients don't require client secret client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: publicClient.ID, - }) + }, true) require.NoError(t, err) require.NotNil(t, client) assert.Equal(t, publicClient.ID, client.ID) }) + + t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) { + // Public clients don't require client secret + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ + ClientID: publicClient.ID, + }, false) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{}) + assert.Nil(t, client) + }) }) // Test cases for federated client using JWT assertion @@ -261,7 +276,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { token, err := jwt.NewBuilder(). Issuer(federatedClientIssuer). Audience([]string{federatedClientAudience}). - Subject(federatedClientSubject). + Subject(federatedClient.ID). IssuedAt(time.Now()). Expiration(time.Now().Add(10 * time.Minute)). Build() @@ -274,7 +289,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: string(signedToken), - }) + }, true) require.NoError(t, err) require.NotNil(t, client) assert.Equal(t, federatedClient.ID, client.ID) @@ -286,7 +301,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: "invalid.jwt.token", - }) + }, true) require.Error(t, err) require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{}) assert.Nil(t, client) @@ -298,7 +313,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { builder := jwt.NewBuilder(). Issuer(federatedClientIssuer). Audience([]string{federatedClientAudience}). - Subject(federatedClientSubject). + Subject(federatedClient.ID). IssuedAt(time.Now()). Expiration(time.Now().Add(10 * time.Minute)) @@ -315,7 +330,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: string(signedToken), - }) + }, true) require.Error(t, err) require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{}) require.Nil(t, client) @@ -356,7 +371,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: string(signedToken), - }) + }, true) require.NoError(t, err) require.NotNil(t, client) assert.Equal(t, federatedClient.ID, client.ID)