mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-11 15:53:00 +03:00
977 lines
32 KiB
Go
977 lines
32 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
|
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
|
)
|
|
|
|
// generateTestECDSAKey creates an ECDSA key for testing
|
|
func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) {
|
|
t.Helper()
|
|
|
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
require.NoError(t, err)
|
|
|
|
privateJwk, err := jwk.Import(privateKey)
|
|
require.NoError(t, err)
|
|
|
|
err = privateJwk.Set(jwk.KeyIDKey, "test-key-1")
|
|
require.NoError(t, err)
|
|
err = privateJwk.Set(jwk.AlgorithmKey, "ES256")
|
|
require.NoError(t, err)
|
|
err = privateJwk.Set("use", "sig")
|
|
require.NoError(t, err)
|
|
|
|
publicJwk, err := jwk.PublicKeyOf(privateJwk)
|
|
require.NoError(t, err)
|
|
|
|
// Create a JWK Set with the public key
|
|
jwkSet := jwk.NewSet()
|
|
err = jwkSet.AddKey(publicJwk)
|
|
require.NoError(t, err)
|
|
jwkSetJSON, err := json.Marshal(jwkSet)
|
|
require.NoError(t, err)
|
|
|
|
return privateJwk, jwkSetJSON
|
|
}
|
|
|
|
func TestOidcService_jwkSetForURL(t *testing.T) {
|
|
// Generate a test key for JWKS
|
|
_, jwkSetJSON1 := generateTestECDSAKey(t)
|
|
_, jwkSetJSON2 := generateTestECDSAKey(t)
|
|
|
|
// Create a mock HTTP client with responses for different URLs
|
|
const (
|
|
url1 = "https://example.com/.well-known/jwks.json"
|
|
url2 = "https://other-issuer.com/jwks"
|
|
)
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
url1: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
|
|
//nolint:bodyclose
|
|
url2: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
// Create the OidcService with our mock client
|
|
s := &OidcService{
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
var err error
|
|
s.jwkCache, err = s.getJWKCache(t.Context())
|
|
require.NoError(t, err)
|
|
|
|
t.Run("Fetches and caches JWK set", func(t *testing.T) {
|
|
jwks, err := s.jwkSetForURL(t.Context(), url1)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, jwks)
|
|
|
|
// Verify the JWK set contains our key
|
|
require.Equal(t, 1, jwks.Len())
|
|
})
|
|
|
|
t.Run("Fails with invalid URL", func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
|
|
defer cancel()
|
|
_, err := s.jwkSetForURL(ctx, "https://bad-url.com")
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
})
|
|
|
|
t.Run("Safe for concurrent use", func(t *testing.T) {
|
|
const concurrency = 20
|
|
|
|
// Channel to collect errors
|
|
errChan := make(chan error, concurrency)
|
|
|
|
// Start concurrent requests
|
|
for range concurrency {
|
|
go func() {
|
|
jwks, err := s.jwkSetForURL(t.Context(), url2)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
|
|
// Verify the JWK set is valid
|
|
if jwks == nil || jwks.Len() != 1 {
|
|
errChan <- assert.AnError
|
|
return
|
|
}
|
|
|
|
errChan <- nil
|
|
}()
|
|
}
|
|
|
|
// Check for errors
|
|
for range concurrency {
|
|
assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|
const (
|
|
federatedClientIssuer = "https://external-idp.com"
|
|
federatedClientAudience = "https://pocket-id.com"
|
|
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
|
|
)
|
|
|
|
var err error
|
|
// Create a test database
|
|
db := testutils.NewDatabaseForTest(t)
|
|
common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef")
|
|
|
|
// Create two JWKs for testing
|
|
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
|
|
require.NoError(t, err)
|
|
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
|
|
require.NoError(t, err)
|
|
|
|
// Create a mock config and JwtService to test complete a token creation process
|
|
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
|
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
|
|
})
|
|
mockJwtService, err := NewJwtService(db, mockConfig)
|
|
require.NoError(t, err)
|
|
|
|
// Create a mock HTTP client with custom transport to return the JWKS
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
federatedClientIssuer + "/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON)),
|
|
//nolint:bodyclose
|
|
federatedClientIssuerDefaults + ".well-known/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
|
|
},
|
|
},
|
|
}
|
|
|
|
// Init the OidcService
|
|
s := &OidcService{
|
|
db: db,
|
|
jwtService: mockJwtService,
|
|
appConfigService: mockConfig,
|
|
httpClient: httpClient,
|
|
}
|
|
s.jwkCache, err = s.getJWKCache(t.Context())
|
|
require.NoError(t, err)
|
|
|
|
// Create the test clients
|
|
// 1. Confidential client
|
|
confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
|
OidcClientUpdateDto: dto.OidcClientUpdateDto{
|
|
Name: "Confidential Client",
|
|
CallbackURLs: []string{"https://example.com/callback"},
|
|
},
|
|
}, "test-user-id")
|
|
require.NoError(t, err)
|
|
|
|
// Create a client secret for the confidential client
|
|
confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID)
|
|
require.NoError(t, err)
|
|
|
|
// 2. Public client
|
|
publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
|
OidcClientUpdateDto: dto.OidcClientUpdateDto{
|
|
Name: "Public Client",
|
|
CallbackURLs: []string{"https://example.com/callback"},
|
|
IsPublic: true,
|
|
},
|
|
}, "test-user-id")
|
|
require.NoError(t, err)
|
|
|
|
// 3. Confidential client with federated identity
|
|
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
|
OidcClientUpdateDto: dto.OidcClientUpdateDto{
|
|
Name: "Federated Client",
|
|
CallbackURLs: []string{"https://example.com/callback"},
|
|
},
|
|
}, "test-user-id")
|
|
require.NoError(t, err)
|
|
|
|
federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientUpdateDto{
|
|
Name: federatedClient.Name,
|
|
CallbackURLs: federatedClient.CallbackURLs,
|
|
Credentials: dto.OidcClientCredentialsDto{
|
|
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
|
|
{
|
|
Issuer: federatedClientIssuer,
|
|
Audience: federatedClientAudience,
|
|
Subject: federatedClient.ID,
|
|
JWKS: federatedClientIssuer + "/jwks.json",
|
|
},
|
|
{Issuer: federatedClientIssuerDefaults},
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// Test cases for confidential client (using client secret)
|
|
t.Run("Confidential client", func(t *testing.T) {
|
|
t.Run("Succeeds with valid secret", func(t *testing.T) {
|
|
// Test with valid client credentials
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: confidentialClient.ID,
|
|
ClientSecret: confidentialSecret,
|
|
}, true)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, client)
|
|
assert.Equal(t, confidentialClient.ID, client.ID)
|
|
})
|
|
|
|
t.Run("Fails with invalid secret", func(t *testing.T) {
|
|
// Test with invalid client secret
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: confidentialClient.ID,
|
|
ClientSecret: "invalid-secret",
|
|
}, true)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
|
assert.Nil(t, client)
|
|
})
|
|
|
|
t.Run("Fails with missing secret", func(t *testing.T) {
|
|
// Test with missing client secret
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: confidentialClient.ID,
|
|
}, true)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
|
assert.Nil(t, client)
|
|
})
|
|
})
|
|
|
|
// Test cases for public client
|
|
t.Run("Public client", func(t *testing.T) {
|
|
t.Run("Succeeds with no credentials", func(t *testing.T) {
|
|
// Public clients don't require client secret
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: publicClient.ID,
|
|
}, true)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, client)
|
|
assert.Equal(t, publicClient.ID, client.ID)
|
|
})
|
|
|
|
t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) {
|
|
// Public clients don't require client secret
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: publicClient.ID,
|
|
}, false)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
|
assert.Nil(t, client)
|
|
})
|
|
})
|
|
|
|
// Test cases for federated client using JWT assertion
|
|
t.Run("Federated client", func(t *testing.T) {
|
|
t.Run("Succeeds with valid JWT", func(t *testing.T) {
|
|
// Create JWT for federated identity
|
|
token, err := jwt.NewBuilder().
|
|
Issuer(federatedClientIssuer).
|
|
Audience([]string{federatedClientAudience}).
|
|
Subject(federatedClient.ID).
|
|
IssuedAt(time.Now()).
|
|
Expiration(time.Now().Add(10 * time.Minute)).
|
|
Build()
|
|
require.NoError(t, err)
|
|
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
|
require.NoError(t, err)
|
|
|
|
// Test with valid JWT assertion
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: federatedClient.ID,
|
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
|
ClientAssertion: string(signedToken),
|
|
}, true)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, client)
|
|
assert.Equal(t, federatedClient.ID, client.ID)
|
|
})
|
|
|
|
t.Run("Fails with malformed JWT", func(t *testing.T) {
|
|
// Test with invalid JWT assertion (just a random string)
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: federatedClient.ID,
|
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
|
ClientAssertion: "invalid.jwt.token",
|
|
}, true)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
|
assert.Nil(t, client)
|
|
})
|
|
|
|
testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) {
|
|
return func(t *testing.T) {
|
|
// Populate all claims with valid values
|
|
builder := jwt.NewBuilder().
|
|
Issuer(federatedClientIssuer).
|
|
Audience([]string{federatedClientAudience}).
|
|
Subject(federatedClient.ID).
|
|
IssuedAt(time.Now()).
|
|
Expiration(time.Now().Add(10 * time.Minute))
|
|
|
|
// Call builderFn to override the claims
|
|
builderFn(builder)
|
|
|
|
token, err := builder.Build()
|
|
require.NoError(t, err)
|
|
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
|
require.NoError(t, err)
|
|
|
|
// Test with invalid JWT assertion
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: federatedClient.ID,
|
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
|
ClientAssertion: string(signedToken),
|
|
}, true)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
|
require.Nil(t, client)
|
|
}
|
|
}
|
|
|
|
t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) {
|
|
builder.Expiration(time.Now().Add(-30 * time.Minute))
|
|
}))
|
|
|
|
t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) {
|
|
builder.Issuer("https://bad-issuer.com")
|
|
}))
|
|
|
|
t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) {
|
|
builder.Audience([]string{"bad-audience"})
|
|
}))
|
|
|
|
t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) {
|
|
builder.Subject("bad-subject")
|
|
}))
|
|
|
|
t.Run("Uses default values for audience and subject", func(t *testing.T) {
|
|
// Create JWT for federated identity
|
|
token, err := jwt.NewBuilder().
|
|
Issuer(federatedClientIssuerDefaults).
|
|
Audience([]string{common.EnvConfig.AppURL}).
|
|
Subject(federatedClient.ID).
|
|
IssuedAt(time.Now()).
|
|
Expiration(time.Now().Add(10 * time.Minute)).
|
|
Build()
|
|
require.NoError(t, err)
|
|
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults))
|
|
require.NoError(t, err)
|
|
|
|
// Test with valid JWT assertion
|
|
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
|
ClientID: federatedClient.ID,
|
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
|
ClientAssertion: string(signedToken),
|
|
}, true)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, client)
|
|
assert.Equal(t, federatedClient.ID, client.ID)
|
|
})
|
|
})
|
|
|
|
t.Run("Complete token creation flow", func(t *testing.T) {
|
|
t.Run("Client Credentials flow", func(t *testing.T) {
|
|
t.Run("Succeeds with valid secret", func(t *testing.T) {
|
|
// Generate a token
|
|
input := dto.OidcCreateTokensDto{
|
|
ClientID: confidentialClient.ID,
|
|
ClientSecret: confidentialSecret,
|
|
}
|
|
token, err := s.createTokenFromClientCredentials(t.Context(), input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, token)
|
|
|
|
// Verify the token
|
|
claims, err := s.jwtService.VerifyOAuthAccessToken(token.AccessToken)
|
|
require.NoError(t, err, "Failed to verify generated token")
|
|
|
|
// Check the claims
|
|
subject, ok := claims.Subject()
|
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
|
assert.Equal(t, "client-"+confidentialClient.ID, subject, "Token subject should match confidential client ID with prefix")
|
|
audience, ok := claims.Audience()
|
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
|
assert.Equal(t, []string{confidentialClient.ID}, audience, "Audience should contain confidential client ID")
|
|
})
|
|
|
|
t.Run("Fails with invalid secret", func(t *testing.T) {
|
|
input := dto.OidcCreateTokensDto{
|
|
ClientID: confidentialClient.ID,
|
|
ClientSecret: "invalid-secret",
|
|
}
|
|
_, err := s.createTokenFromClientCredentials(t.Context(), input)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
|
})
|
|
|
|
t.Run("Fails without client secret for public clients", func(t *testing.T) {
|
|
input := dto.OidcCreateTokensDto{
|
|
ClientID: publicClient.ID,
|
|
}
|
|
_, err := s.createTokenFromClientCredentials(t.Context(), input)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
|
})
|
|
|
|
t.Run("Succeeds with valid assertion", func(t *testing.T) {
|
|
// Create JWT for federated identity
|
|
token, err := jwt.NewBuilder().
|
|
Issuer(federatedClientIssuer).
|
|
Audience([]string{federatedClientAudience}).
|
|
Subject(federatedClient.ID).
|
|
IssuedAt(time.Now()).
|
|
Expiration(time.Now().Add(10 * time.Minute)).
|
|
Build()
|
|
require.NoError(t, err)
|
|
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
|
require.NoError(t, err)
|
|
|
|
// Generate a token
|
|
input := dto.OidcCreateTokensDto{
|
|
ClientAssertion: string(signedToken),
|
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
|
}
|
|
createdToken, err := s.createTokenFromClientCredentials(t.Context(), input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, token)
|
|
|
|
// Verify the token
|
|
claims, err := s.jwtService.VerifyOAuthAccessToken(createdToken.AccessToken)
|
|
require.NoError(t, err, "Failed to verify generated token")
|
|
|
|
// Check the claims
|
|
subject, ok := claims.Subject()
|
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
|
assert.Equal(t, "client-"+federatedClient.ID, subject, "Token subject should match federated client ID with prefix")
|
|
audience, ok := claims.Audience()
|
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
|
assert.Equal(t, []string{federatedClient.ID}, audience, "Audience should contain the federated client ID")
|
|
})
|
|
|
|
t.Run("Fails with invalid assertion", func(t *testing.T) {
|
|
input := dto.OidcCreateTokensDto{
|
|
ClientAssertion: "invalid.jwt.token",
|
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
|
}
|
|
_, err := s.createTokenFromClientCredentials(t.Context(), input)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
|
})
|
|
|
|
t.Run("Succeeds with custom resource", func(t *testing.T) {
|
|
// Generate a token
|
|
input := dto.OidcCreateTokensDto{
|
|
ClientID: confidentialClient.ID,
|
|
ClientSecret: confidentialSecret,
|
|
Resource: "https://example.com/",
|
|
}
|
|
token, err := s.createTokenFromClientCredentials(t.Context(), input)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, token)
|
|
|
|
// Verify the token
|
|
claims, err := s.jwtService.VerifyOAuthAccessToken(token.AccessToken)
|
|
require.NoError(t, err, "Failed to verify generated token")
|
|
|
|
// Check the claims
|
|
subject, ok := claims.Subject()
|
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
|
assert.Equal(t, "client-"+confidentialClient.ID, subject, "Token subject should match confidential client ID with prefix")
|
|
audience, ok := claims.Audience()
|
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
|
assert.Equal(t, []string{input.Resource}, audience, "Audience should contain the resource provided in request")
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
func TestValidateCodeVerifier_Plain(t *testing.T) {
|
|
require.False(t, validateCodeVerifier("", "", false))
|
|
require.False(t, validateCodeVerifier("", "", true))
|
|
|
|
t.Run("plain", func(t *testing.T) {
|
|
require.False(t, validateCodeVerifier("", "challenge", false))
|
|
require.False(t, validateCodeVerifier("verifier", "", false))
|
|
require.True(t, validateCodeVerifier("plainVerifier", "plainVerifier", false))
|
|
require.False(t, validateCodeVerifier("plainVerifier", "otherVerifier", false))
|
|
})
|
|
|
|
t.Run("SHA 256", func(t *testing.T) {
|
|
codeVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
|
hash := sha256.Sum256([]byte(codeVerifier))
|
|
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
|
|
require.True(t, validateCodeVerifier(codeVerifier, codeChallenge, true))
|
|
require.False(t, validateCodeVerifier("wrongVerifier", codeChallenge, true))
|
|
require.False(t, validateCodeVerifier(codeVerifier, "!", true))
|
|
|
|
// Invalid base64
|
|
require.False(t, validateCodeVerifier("NOT!VALID", codeChallenge, true))
|
|
})
|
|
}
|
|
|
|
func TestOidcService_updateClientLogoType(t *testing.T) {
|
|
// Create a test database
|
|
db := testutils.NewDatabaseForTest(t)
|
|
|
|
// Create database storage
|
|
dbStorage, err := storage.NewDatabaseStorage(db)
|
|
require.NoError(t, err)
|
|
|
|
// Init the OidcService
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
}
|
|
|
|
// Create a test client
|
|
client := model.OidcClient{
|
|
Name: "Test Client",
|
|
CallbackURLs: model.UrlList{"https://example.com/callback"},
|
|
}
|
|
err = db.Create(&client).Error
|
|
require.NoError(t, err)
|
|
|
|
// Helper function to check if a file exists in storage
|
|
fileExists := func(t *testing.T, path string) bool {
|
|
t.Helper()
|
|
_, _, err := dbStorage.Open(t.Context(), path)
|
|
return err == nil
|
|
}
|
|
|
|
// Helper function to create a dummy file in storage
|
|
createDummyFile := func(t *testing.T, path string) {
|
|
t.Helper()
|
|
err := dbStorage.Save(t.Context(), path, strings.NewReader("dummy content"))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
t.Run("Updates light logo type for client without previous logo", func(t *testing.T) {
|
|
// Update the logo type
|
|
err := s.updateClientLogoType(t.Context(), client.ID, "png", true)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the client was updated
|
|
var updatedClient model.OidcClient
|
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updatedClient.ImageType)
|
|
assert.Equal(t, "png", *updatedClient.ImageType)
|
|
})
|
|
|
|
t.Run("Updates dark logo type for client without previous dark logo", func(t *testing.T) {
|
|
// Update the dark logo type
|
|
err := s.updateClientLogoType(t.Context(), client.ID, "jpg", false)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the client was updated
|
|
var updatedClient model.OidcClient
|
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updatedClient.DarkImageType)
|
|
assert.Equal(t, "jpg", *updatedClient.DarkImageType)
|
|
})
|
|
|
|
t.Run("Updates light logo type and deletes old file when type changes", func(t *testing.T) {
|
|
// Create the old PNG file in storage
|
|
oldPath := "oidc-client-images/" + client.ID + ".png"
|
|
createDummyFile(t, oldPath)
|
|
require.True(t, fileExists(t, oldPath), "Old file should exist before update")
|
|
|
|
// Client currently has a PNG logo, update to WEBP
|
|
err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the client was updated
|
|
var updatedClient model.OidcClient
|
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updatedClient.ImageType)
|
|
assert.Equal(t, "webp", *updatedClient.ImageType)
|
|
|
|
// Old PNG file should be deleted
|
|
assert.False(t, fileExists(t, oldPath), "Old PNG file should have been deleted")
|
|
})
|
|
|
|
t.Run("Updates dark logo type and deletes old file when type changes", func(t *testing.T) {
|
|
// Create the old JPG dark file in storage
|
|
oldPath := "oidc-client-images/" + client.ID + "-dark.jpg"
|
|
createDummyFile(t, oldPath)
|
|
require.True(t, fileExists(t, oldPath), "Old dark file should exist before update")
|
|
|
|
// Client currently has a JPG dark logo, update to WEBP
|
|
err := s.updateClientLogoType(t.Context(), client.ID, "webp", false)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the client was updated
|
|
var updatedClient model.OidcClient
|
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updatedClient.DarkImageType)
|
|
assert.Equal(t, "webp", *updatedClient.DarkImageType)
|
|
|
|
// Old JPG dark file should be deleted
|
|
assert.False(t, fileExists(t, oldPath), "Old JPG dark file should have been deleted")
|
|
})
|
|
|
|
t.Run("Does not delete file when type remains the same", func(t *testing.T) {
|
|
// Create the WEBP file in storage
|
|
webpPath := "oidc-client-images/" + client.ID + ".webp"
|
|
createDummyFile(t, webpPath)
|
|
require.True(t, fileExists(t, webpPath), "WEBP file should exist before update")
|
|
|
|
// Update to the same type (WEBP)
|
|
err := s.updateClientLogoType(t.Context(), client.ID, "webp", true)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the client still has WEBP
|
|
var updatedClient model.OidcClient
|
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updatedClient.ImageType)
|
|
assert.Equal(t, "webp", *updatedClient.ImageType)
|
|
|
|
// WEBP file should still exist since type didn't change
|
|
assert.True(t, fileExists(t, webpPath), "WEBP file should still exist")
|
|
})
|
|
|
|
t.Run("Returns error for non-existent client", func(t *testing.T) {
|
|
err := s.updateClientLogoType(t.Context(), "non-existent-client-id", "png", true)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "failed to look up client")
|
|
})
|
|
}
|
|
|
|
func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) {
|
|
// Create a test database
|
|
db := testutils.NewDatabaseForTest(t)
|
|
|
|
// Create database storage
|
|
dbStorage, err := storage.NewDatabaseStorage(db)
|
|
require.NoError(t, err)
|
|
|
|
// Create a test client
|
|
client := model.OidcClient{
|
|
Name: "Test Client",
|
|
CallbackURLs: model.UrlList{"https://example.com/callback"},
|
|
}
|
|
err = db.Create(&client).Error
|
|
require.NoError(t, err)
|
|
|
|
// Helper function to check if a file exists in storage
|
|
fileExists := func(t *testing.T, path string) bool {
|
|
t.Helper()
|
|
_, _, err := dbStorage.Open(t.Context(), path)
|
|
return err == nil
|
|
}
|
|
|
|
// Helper function to get file content from storage
|
|
getFileContent := func(t *testing.T, path string) []byte {
|
|
t.Helper()
|
|
reader, _, err := dbStorage.Open(t.Context(), path)
|
|
require.NoError(t, err)
|
|
defer reader.Close()
|
|
content, err := io.ReadAll(reader)
|
|
require.NoError(t, err)
|
|
return content
|
|
}
|
|
|
|
t.Run("Successfully downloads and saves PNG logo from URL", func(t *testing.T) {
|
|
// Create mock PNG content
|
|
pngContent := []byte("fake-png-content")
|
|
|
|
// Create a mock HTTP response with headers
|
|
//nolint:bodyclose
|
|
pngResponse := testutils.NewMockResponse(http.StatusOK, string(pngContent))
|
|
pngResponse.Header.Set("Content-Type", "image/png")
|
|
|
|
// Create a mock HTTP client with responses
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/logo.png": pngResponse,
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
// Init the OidcService with mock HTTP client
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
// Download and save the logo
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo.png", true)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the file was saved
|
|
logoPath := "oidc-client-images/" + client.ID + ".png"
|
|
require.True(t, fileExists(t, logoPath), "Logo file should exist in storage")
|
|
|
|
// Verify the content
|
|
savedContent := getFileContent(t, logoPath)
|
|
assert.Equal(t, pngContent, savedContent)
|
|
|
|
// Verify the client was updated
|
|
var updatedClient model.OidcClient
|
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updatedClient.ImageType)
|
|
assert.Equal(t, "png", *updatedClient.ImageType)
|
|
})
|
|
|
|
t.Run("Successfully downloads and saves dark logo", func(t *testing.T) {
|
|
// Create mock WEBP content
|
|
webpContent := []byte("fake-webp-content")
|
|
|
|
//nolint:bodyclose
|
|
webpResponse := testutils.NewMockResponse(http.StatusOK, string(webpContent))
|
|
webpResponse.Header.Set("Content-Type", "image/webp")
|
|
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/dark-logo.webp": webpResponse,
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
// Download and save the dark logo
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/dark-logo.webp", false)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the dark logo file was saved
|
|
darkLogoPath := "oidc-client-images/" + client.ID + "-dark.webp"
|
|
require.True(t, fileExists(t, darkLogoPath), "Dark logo file should exist in storage")
|
|
|
|
// Verify the content
|
|
savedContent := getFileContent(t, darkLogoPath)
|
|
assert.Equal(t, webpContent, savedContent)
|
|
|
|
// Verify the client was updated
|
|
var updatedClient model.OidcClient
|
|
err = db.First(&updatedClient, "id = ?", client.ID).Error
|
|
require.NoError(t, err)
|
|
require.NotNil(t, updatedClient.DarkImageType)
|
|
assert.Equal(t, "webp", *updatedClient.DarkImageType)
|
|
})
|
|
|
|
t.Run("Detects extension from URL path", func(t *testing.T) {
|
|
svgContent := []byte("<svg></svg>")
|
|
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/icon.svg": testutils.NewMockResponse(http.StatusOK, string(svgContent)),
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/icon.svg", true)
|
|
require.NoError(t, err)
|
|
|
|
// Verify SVG file was saved
|
|
logoPath := "oidc-client-images/" + client.ID + ".svg"
|
|
require.True(t, fileExists(t, logoPath), "SVG logo should exist")
|
|
})
|
|
|
|
t.Run("Detects extension from Content-Type when path has no extension", func(t *testing.T) {
|
|
jpgContent := []byte("fake-jpg-content")
|
|
|
|
//nolint:bodyclose
|
|
jpgResponse := testutils.NewMockResponse(http.StatusOK, string(jpgContent))
|
|
jpgResponse.Header.Set("Content-Type", "image/jpeg")
|
|
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/logo": jpgResponse,
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo", true)
|
|
require.NoError(t, err)
|
|
|
|
// Verify JPG file was saved (jpeg extension is normalized to jpg)
|
|
logoPath := "oidc-client-images/" + client.ID + ".jpg"
|
|
require.True(t, fileExists(t, logoPath), "JPG logo should exist")
|
|
})
|
|
|
|
t.Run("Returns error for invalid URL", func(t *testing.T) {
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: &http.Client{},
|
|
}
|
|
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "://invalid-url", true)
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("Returns error for non-200 status code", func(t *testing.T) {
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/not-found.png": testutils.NewMockResponse(http.StatusNotFound, "Not Found"),
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/not-found.png", true)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "failed to fetch logo")
|
|
})
|
|
|
|
t.Run("Returns error for too large content", func(t *testing.T) {
|
|
// Create content larger than 2MB (maxLogoSize)
|
|
largeContent := strings.Repeat("x", 2<<20+100) // 2.1MB
|
|
|
|
//nolint:bodyclose
|
|
largeResponse := testutils.NewMockResponse(http.StatusOK, largeContent)
|
|
largeResponse.Header.Set("Content-Type", "image/png")
|
|
largeResponse.Header.Set("Content-Length", strconv.Itoa(len(largeContent)))
|
|
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/large.png": largeResponse,
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/large.png", true)
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, errLogoTooLarge)
|
|
})
|
|
|
|
t.Run("Returns error for unsupported file type", func(t *testing.T) {
|
|
//nolint:bodyclose
|
|
textResponse := testutils.NewMockResponse(http.StatusOK, "text content")
|
|
textResponse.Header.Set("Content-Type", "text/plain")
|
|
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/file.txt": textResponse,
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/file.txt", true)
|
|
require.Error(t, err)
|
|
var fileTypeErr *common.FileTypeNotSupportedError
|
|
require.ErrorAs(t, err, &fileTypeErr)
|
|
})
|
|
|
|
t.Run("Returns error for non-existent client", func(t *testing.T) {
|
|
//nolint:bodyclose
|
|
pngResponse := testutils.NewMockResponse(http.StatusOK, "content")
|
|
pngResponse.Header.Set("Content-Type", "image/png")
|
|
|
|
mockResponses := map[string]*http.Response{
|
|
//nolint:bodyclose
|
|
"https://example.com/logo.png": pngResponse,
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: &testutils.MockRoundTripper{
|
|
Responses: mockResponses,
|
|
},
|
|
}
|
|
|
|
s := &OidcService{
|
|
db: db,
|
|
fileStorage: dbStorage,
|
|
httpClient: httpClient,
|
|
}
|
|
|
|
err := s.downloadAndSaveLogoFromURL(t.Context(), "non-existent-client-id", "https://example.com/logo.png", true)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "failed to look up client")
|
|
})
|
|
}
|