mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 02:02:58 +03:00
feat: JWT bearer assertions for client authentication (#566)
Co-authored-by: Kyle Mendell <ksm@ofkm.us> Co-authored-by: Kyle Mendell <kmendell@ofkm.us> Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
035b2c022b
commit
05bfe00924
365
backend/internal/service/oidc_service_test.go
Normal file
365
backend/internal/service/oidc_service_test.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
)
|
||||
|
||||
// generateTestECDSAKey creates an ECDSA key for testing
|
||||
func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
privateJwk, err := jwk.Import(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = privateJwk.Set(jwk.KeyIDKey, "test-key-1")
|
||||
require.NoError(t, err)
|
||||
err = privateJwk.Set(jwk.AlgorithmKey, "ES256")
|
||||
require.NoError(t, err)
|
||||
err = privateJwk.Set("use", "sig")
|
||||
require.NoError(t, err)
|
||||
|
||||
publicJwk, err := jwk.PublicKeyOf(privateJwk)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a JWK Set with the public key
|
||||
jwkSet := jwk.NewSet()
|
||||
err = jwkSet.AddKey(publicJwk)
|
||||
require.NoError(t, err)
|
||||
jwkSetJSON, err := json.Marshal(jwkSet)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateJwk, jwkSetJSON
|
||||
}
|
||||
|
||||
func TestOidcService_jwkSetForURL(t *testing.T) {
|
||||
// Generate a test key for JWKS
|
||||
_, jwkSetJSON1 := generateTestECDSAKey(t)
|
||||
_, jwkSetJSON2 := generateTestECDSAKey(t)
|
||||
|
||||
// Create a mock HTTP client with responses for different URLs
|
||||
const (
|
||||
url1 = "https://example.com/.well-known/jwks.json"
|
||||
url2 = "https://other-issuer.com/jwks"
|
||||
)
|
||||
mockResponses := map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
url1: NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
||||
//nolint:bodyclose
|
||||
url2: NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: &MockRoundTripper{
|
||||
Responses: mockResponses,
|
||||
},
|
||||
}
|
||||
|
||||
// Create the OidcService with our mock client
|
||||
s := &OidcService{
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
var err error
|
||||
s.jwkCache, err = s.getJWKCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Fetches and caches JWK set", func(t *testing.T) {
|
||||
jwks, err := s.jwkSetForURL(t.Context(), url1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, jwks)
|
||||
|
||||
// Verify the JWK set contains our key
|
||||
require.Equal(t, 1, jwks.Len())
|
||||
})
|
||||
|
||||
t.Run("Fails with invalid URL", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, err := s.jwkSetForURL(ctx, "https://bad-url.com")
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
})
|
||||
|
||||
t.Run("Safe for concurrent use", func(t *testing.T) {
|
||||
const concurrency = 20
|
||||
|
||||
// Channel to collect errors
|
||||
errChan := make(chan error, concurrency)
|
||||
|
||||
// Start concurrent requests
|
||||
for range concurrency {
|
||||
go func() {
|
||||
jwks, err := s.jwkSetForURL(t.Context(), url2)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the JWK set is valid
|
||||
if jwks == nil || jwks.Len() != 1 {
|
||||
errChan <- assert.AnError
|
||||
return
|
||||
}
|
||||
|
||||
errChan <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
for range concurrency {
|
||||
assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
const (
|
||||
federatedClientIssuer = "https://external-idp.com"
|
||||
federatedClientAudience = "https://pocket-id.com"
|
||||
federatedClientSubject = "123456abcdef"
|
||||
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
|
||||
)
|
||||
|
||||
var err error
|
||||
// Create a test database
|
||||
db := newDatabaseForTest(t)
|
||||
|
||||
// Create two JWKs for testing
|
||||
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
|
||||
require.NoError(t, err)
|
||||
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a mock HTTP client with custom transport to return the JWKS
|
||||
httpClient := &http.Client{
|
||||
Transport: &MockRoundTripper{
|
||||
Responses: map[string]*http.Response{
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuer + "/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
||||
//nolint:bodyclose
|
||||
federatedClientIssuerDefaults + ".well-known/jwks.json": NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Init the OidcService
|
||||
s := &OidcService{
|
||||
db: db,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
s.jwkCache, err = s.getJWKCache(t.Context())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create the test clients
|
||||
// 1. Confidential client
|
||||
confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Confidential Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a client secret for the confidential client
|
||||
confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Public client
|
||||
publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Public Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
IsPublic: true,
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Confidential client with federated identity
|
||||
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
Name: "Federated Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
Credentials: dto.OidcClientCredentialsDto{
|
||||
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
|
||||
{
|
||||
Issuer: federatedClientIssuer,
|
||||
Audience: federatedClientAudience,
|
||||
Subject: federatedClientSubject,
|
||||
JWKS: federatedClientIssuer + "/jwks.json",
|
||||
},
|
||||
{Issuer: federatedClientIssuerDefaults},
|
||||
},
|
||||
},
|
||||
}, "test-user-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test cases for confidential client (using client secret)
|
||||
t.Run("Confidential client", func(t *testing.T) {
|
||||
t.Run("Succeeds with valid secret", func(t *testing.T) {
|
||||
// Test with valid client credentials
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: confidentialSecret,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, confidentialClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with invalid secret", func(t *testing.T) {
|
||||
// Test with invalid client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: "invalid-secret",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
|
||||
t.Run("Fails with missing secret", func(t *testing.T) {
|
||||
// Test with missing client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: confidentialClient.ID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
})
|
||||
|
||||
// Test cases for public client
|
||||
t.Run("Public client", func(t *testing.T) {
|
||||
t.Run("Succeeds with no credentials", func(t *testing.T) {
|
||||
// Public clients don't require client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: publicClient.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, publicClient.ID, client.ID)
|
||||
})
|
||||
})
|
||||
|
||||
// Test cases for federated client using JWT assertion
|
||||
t.Run("Federated client", func(t *testing.T) {
|
||||
t.Run("Succeeds with valid JWT", func(t *testing.T) {
|
||||
// Create JWT for federated identity
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClientSubject).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
})
|
||||
|
||||
t.Run("Fails with malformed JWT", func(t *testing.T) {
|
||||
// Test with invalid JWT assertion (just a random string)
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: "invalid.jwt.token",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
assert.Nil(t, client)
|
||||
})
|
||||
|
||||
testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
// Populate all claims with valid values
|
||||
builder := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuer).
|
||||
Audience([]string{federatedClientAudience}).
|
||||
Subject(federatedClientSubject).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute))
|
||||
|
||||
// Call builderFn to override the claims
|
||||
builderFn(builder)
|
||||
|
||||
token, err := builder.Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with invalid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||
require.Nil(t, client)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Expiration(time.Now().Add(-30 * time.Minute))
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Issuer("https://bad-issuer.com")
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Audience([]string{"bad-audience"})
|
||||
}))
|
||||
|
||||
t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) {
|
||||
builder.Subject("bad-subject")
|
||||
}))
|
||||
|
||||
t.Run("Uses default values for audience and subject", func(t *testing.T) {
|
||||
// Create JWT for federated identity
|
||||
token, err := jwt.NewBuilder().
|
||||
Issuer(federatedClientIssuerDefaults).
|
||||
Audience([]string{common.EnvConfig.AppURL}).
|
||||
Subject(federatedClient.ID).
|
||||
IssuedAt(time.Now()).
|
||||
Expiration(time.Now().Add(10 * time.Minute)).
|
||||
Build()
|
||||
require.NoError(t, err)
|
||||
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, federatedClient.ID, client.ID)
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user