feat: allow introspection and device code endpoints to use Federated Client Credentials (#640)

This commit is contained in:
Alessandro (Ale) Segala
2025-06-09 12:17:55 -07:00
committed by GitHub
parent df5c1ed1f8
commit b62b61fb01
13 changed files with 802 additions and 139 deletions

View File

@@ -14,6 +14,7 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
testController := &TestController{TestService: testService}
group.POST("/test/reset", testController.resetAndSeedHandler)
group.POST("/test/refreshtoken", testController.signRefreshToken)
group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
group.POST("/externalidp/sign", testController.externalIdPSignToken)
@@ -100,3 +101,24 @@ func (tc *TestController) externalIdPSignToken(c *gin.Context) {
c.Writer.WriteString(token)
}
func (tc *TestController) signRefreshToken(c *gin.Context) {
var input struct {
UserID string `json:"user"`
ClientID string `json:"client"`
RefreshToken string `json:"rt"`
}
err := c.ShouldBindJSON(&input)
if err != nil {
_ = c.Error(err)
return
}
token, err := tc.TestService.SignRefreshToken(input.UserID, input.ClientID, input.RefreshToken)
if err != nil {
_ = c.Error(err)
return
}
c.Writer.WriteString(token)
}

View File

@@ -200,7 +200,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
return
}
token, err := oc.jwtService.VerifyOauthAccessToken(authToken)
token, err := oc.jwtService.VerifyOAuthAccessToken(authToken)
if err != nil {
_ = c.Error(err)
return
@@ -308,9 +308,21 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
// find valid tokens) while still allowing it to be used by an application that is
// supposed to interact with our IdP (since that needs to have a client_id
// and client_secret anyway).
clientID, clientSecret, _ := c.Request.BasicAuth()
var (
creds service.ClientAuthCredentials
ok bool
)
creds.ClientID, creds.ClientSecret, ok = c.Request.BasicAuth()
if !ok {
// If there's no basic auth, check if we have a bearer token
bearer, ok := utils.BearerAuth(c.Request)
if ok {
creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer
creds.ClientAssertion = bearer
}
}
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token)
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), creds, input.Token)
if err != nil {
_ = c.Error(err)
return

View File

@@ -113,9 +113,11 @@ type OidcIntrospectionResponseDto struct {
}
type OidcDeviceAuthorizationRequestDto struct {
ClientID string `form:"client_id" binding:"required"`
Scope string `form:"scope" binding:"required"`
ClientSecret string `form:"client_secret"`
ClientID string `form:"client_id" binding:"required"`
Scope string `form:"scope" binding:"required"`
ClientSecret string `form:"client_secret"`
ClientAssertion string `form:"client_assertion"`
ClientAssertionType string `form:"client_assertion_type"`
}
type OidcDeviceAuthorizationResponseDto struct {

View File

@@ -479,6 +479,10 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
return nil
}
func (s *TestService) SignRefreshToken(userID, clientID, refreshToken string) (string, error) {
return s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
}
// GetExternalIdPJWKS returns the JWKS for the "external IdP".
func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
pubKey, err := s.externalIdPKey.PublicKey()

View File

@@ -39,9 +39,15 @@ const (
// TokenTypeClaim is the claim used to identify the type of token
TokenTypeClaim = "type"
// RefreshTokenClaim is the claim used for the refresh token's value
RefreshTokenClaim = "rt"
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
// OAuthRefreshTokenJWTType identifies a JWT as an OAuth refresh token
OAuthRefreshTokenJWTType = "refresh-token"
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
AccessTokenJWTType = "access-token"
@@ -322,8 +328,8 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
return token, nil
}
// BuildOauthAccessToken creates an OAuth access token with all claims
func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jwt.Token, error) {
// BuildOAuthAccessToken creates an OAuth access token with all claims
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
@@ -348,9 +354,9 @@ func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jw
return token, nil
}
// GenerateOauthAccessToken creates and signs an OAuth access token
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
token, err := s.BuildOauthAccessToken(user, clientID)
// GenerateOAuthAccessToken creates and signs an OAuth access token
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) {
token, err := s.BuildOAuthAccessToken(user, clientID)
if err != nil {
return "", err
}
@@ -364,7 +370,7 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
return string(signed), nil
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, error) {
alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString(
tokenString,
@@ -381,6 +387,96 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
return token, nil
}
func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, refreshToken string) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(userID).
Expiration(now.Add(RefreshTokenDuration)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = token.Set(RefreshTokenClaim, refreshToken)
if err != nil {
return "", fmt.Errorf("failed to set 'rt' claim in token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, OAuthRefreshTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
}
func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, clientID, rt string, err error) {
alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
)
if err != nil {
return "", "", "", fmt.Errorf("failed to parse token: %w", err)
}
err = token.Get(RefreshTokenClaim, &rt)
if err != nil {
return "", "", "", fmt.Errorf("failed to get '%s' claim from token: %w", RefreshTokenClaim, err)
}
audiences, ok := token.Audience()
if !ok || len(audiences) != 1 || audiences[0] == "" {
return "", "", "", errors.New("failed to get 'aud' claim from token")
}
clientID = audiences[0]
userID, ok = token.Subject()
if !ok {
return "", "", "", errors.New("failed to get 'sub' claim from token")
}
return userID, clientID, rt, nil
}
// GetTokenType returns the type of the JWT token issued by Pocket ID, but **does not validate it**.
func (s *JwtService) GetTokenType(tokenString string) (string, jwt.Token, error) {
// Disable validation and verification to parse the token without checking it
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(false),
jwt.WithVerify(false),
)
if err != nil {
return "", nil, fmt.Errorf("failed to parse token: %w", err)
}
var tokenType string
err = token.Get(TokenTypeClaim, &tokenType)
if err != nil {
return "", nil, fmt.Errorf("failed to get token type claim: %w", err)
}
return tokenType, token, nil
}
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
if s.privateKey == nil {
@@ -478,7 +574,10 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
}
var isAdmin bool
err := token.Get(IsAdminClaim, &isAdmin)
return isAdmin, err
if err != nil {
return false, fmt.Errorf("failed to get 'isAdmin' claim from token: %w", err)
}
return isAdmin, nil
}
// SetTokenType sets the "type" claim in the token

View File

@@ -882,7 +882,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
})
}
func TestGenerateVerifyOauthAccessToken(t *testing.T) {
func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
@@ -914,12 +914,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "test-client-123"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token")
// Check the claims
@@ -972,7 +972,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
require.NoError(t, err, "Failed to sign token")
// Verify should fail due to expiration
_, err = service.VerifyOauthAccessToken(string(signed))
_, err = service.VerifyOAuthAccessToken(string(signed))
require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
})
@@ -996,11 +996,11 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "test-client-789"
// Generate a token with the first service
tokenString, err := service1.GenerateOauthAccessToken(user, clientID)
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token")
// Verify with the second service should fail due to different keys
_, err = service2.VerifyOauthAccessToken(tokenString)
_, err = service2.VerifyOAuthAccessToken(tokenString)
require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
})
@@ -1032,12 +1032,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "eddsa-oauth-client"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims
@@ -1086,12 +1086,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "ecdsa-oauth-client"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims
@@ -1140,12 +1140,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "rsa-oauth-client"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims
@@ -1168,6 +1168,92 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
})
}
func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := NewTestAppConfigService(&model.AppConfig{})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
t.Run("generates and verifies refresh token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
const (
userID = "user123"
clientID = "client123"
refreshToken = "rt-123"
)
// Generate a token
tokenString, err := service.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
require.NoError(t, err, "Failed to generate refresh token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
resUser, resClient, resRT, err := service.VerifyOAuthRefreshToken(tokenString)
require.NoError(t, err, "Failed to verify generated token")
assert.Equal(t, userID, resUser, "Should return correct user ID")
assert.Equal(t, clientID, resClient, "Should return correct client ID")
assert.Equal(t, refreshToken, resRT, "Should return correct refresh token")
})
t.Run("fails verification for expired token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Generate a token using JWT directly to create an expired token
token, err := jwt.NewBuilder().
Subject("user789").
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
IssuedAt(time.Now().Add(-2 * time.Hour)).
Audience([]string{"client123"}).
Issuer(common.EnvConfig.AppURL).
Build()
require.NoError(t, err, "Failed to build token")
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
require.NoError(t, err, "Failed to sign token")
// Verify should fail due to expiration
_, _, _, err = service.VerifyOAuthRefreshToken(string(signed))
require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
})
t.Run("fails verification with invalid signature", func(t *testing.T) {
// Create two JWT services with different keys
service1 := &JwtService{}
err := service1.init(mockConfig, t.TempDir())
require.NoError(t, err, "Failed to initialize first JWT service")
service2 := &JwtService{}
err = service2.init(mockConfig, t.TempDir())
require.NoError(t, err, "Failed to initialize second JWT service")
// Generate a token with the first service
tokenString, err := service1.GenerateOAuthRefreshToken("user789", "client123", "my-rt-123")
require.NoError(t, err, "Failed to generate refresh token")
// Verify with the second service should fail due to different keys
_, _, _, err = service2.VerifyOAuthRefreshToken(tokenString)
require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
})
}
func TestTokenTypeValidator(t *testing.T) {
// Create a context for the validator function
ctx := context.Background()
@@ -1213,7 +1299,104 @@ func TestTokenTypeValidator(t *testing.T) {
require.Error(t, err, "Validator should reject token without type claim")
assert.Contains(t, err.Error(), "failed to get token type claim")
})
}
func TestGetTokenType(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service
mockConfig := NewTestAppConfigService(&model.AppConfig{})
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
t.Helper()
b := jwt.NewBuilder()
b.Subject("user123")
if setClaimsFn != nil {
setClaimsFn(b)
}
token, err := b.Build()
require.NoError(t, err, "Failed to build token")
err = SetTokenType(token, typ)
require.NoError(t, err, "Failed to set token type")
alg, _ := service.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, service.privateKey))
require.NoError(t, err, "Failed to sign token")
return string(signed)
}
t.Run("correctly identifies access tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, AccessTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified as access token")
})
t.Run("correctly identifies ID tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, IDTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, IDTokenJWTType, tokenType, "Token type should be correctly identified as ID token")
})
t.Run("correctly identifies OAuth access tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, OAuthAccessTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, OAuthAccessTokenJWTType, tokenType, "Token type should be correctly identified as OAuth access token")
})
t.Run("correctly identifies refresh tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, OAuthRefreshTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, OAuthRefreshTokenJWTType, tokenType, "Token type should be correctly identified as refresh token")
})
t.Run("works with expired tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, AccessTokenJWTType, func(b *jwt.Builder) {
b.Expiration(time.Now().Add(-1 * time.Hour)) // Expired 1 hour ago
})
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error for expired tokens")
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified even for expired tokens")
})
t.Run("returns error for malformed tokens", func(t *testing.T) {
// Try to get the token type of a malformed token
tokenType, _, err := service.GetTokenType("not.a.valid.jwt.token")
require.Error(t, err, "GetTokenType should return an error for malformed tokens")
assert.Empty(t, tokenType, "Token type should be empty for malformed tokens")
})
t.Run("returns error for tokens without type claim", func(t *testing.T) {
// Create a token without type claim
tokenString := buildTokenForType(t, "", nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.Error(t, err, "GetTokenType should return an error for tokens without type claim")
assert.Empty(t, tokenType, "Token type should be empty when type claim is missing")
assert.Contains(t, err.Error(), "failed to get token type claim", "Error message should indicate missing token type claim")
})
}
func importKey(t *testing.T, privateKeyRaw any, path string) string {

View File

@@ -40,6 +40,9 @@ const (
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
DeviceCodeDuration = 15 * time.Minute
)
type OidcService struct {
@@ -252,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil {
return CreatedTokens{}, err
}
@@ -303,7 +306,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
return CreatedTokens{}, err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
if err != nil {
return CreatedTokens{}, err
}
@@ -333,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback()
}()
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil {
return CreatedTokens{}, err
}
@@ -375,7 +378,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
return CreatedTokens{}, err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
if err != nil {
return CreatedTokens{}, err
}
@@ -406,22 +409,39 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
}
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
userID, clientID, rt, err := s.jwtService.VerifyOAuthRefreshToken(input.RefreshToken)
if err != nil {
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil {
return CreatedTokens{}, err
}
// The ID of the client that made the call must match the client ID in the token
if client.ID != clientID {
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
err = tx.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
Where(
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
utils.CreateSha256Hash(rt),
datatype.DateTime(time.Now()),
userID,
input.ClientID,
).
First(&storedRefreshToken).
Error
if err != nil {
@@ -437,7 +457,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
}
// Generate a new access token
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
if err != nil {
return CreatedTokens{}, err
}
@@ -469,33 +489,69 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
}, nil
}
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
if clientID == "" || clientSecret == "" {
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
// Get the type of the token and the client ID
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
if err != nil {
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil //nolint:nilerr
}
// If we don't have a client ID, get it from the token
// Otherwise, we need to make sure that the client ID passed as credential matches
tokenAudiences, _ := token.Audience()
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil
}
if creds.ClientID == "" {
creds.ClientID = tokenAudiences[0]
} else if creds.ClientID != tokenAudiences[0] {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: clientID,
ClientSecret: clientSecret,
})
// Verify the credentials for the call
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
if err != nil {
return introspectDto, err
}
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
if err != nil {
if errors.Is(err, jwt.ParseError()) {
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
return s.introspectRefreshToken(ctx, tokenString)
}
// Every failure we get means the token is invalid. Nothing more to do with the error.
// Introspect the token
switch tokenType {
case OAuthAccessTokenJWTType:
return s.introspectAccessToken(client.ID, tokenString)
case OAuthRefreshTokenJWTType:
return s.introspectRefreshToken(ctx, client.ID, tokenString)
default:
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil
}
}
func (s *OidcService) introspectAccessToken(clientID string, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
token, err := s.jwtService.VerifyOAuthAccessToken(tokenString)
if err != nil {
// Every failure we get means the token is invalid. Nothing more to do with the error.
introspectDto.Active = false
return introspectDto, nil //nolint:nilerr
}
// The ID of the client that made the request must match the client ID in the token
audience, ok := token.Audience()
if !ok || len(audience) != 1 || audience[0] == "" {
introspectDto.Active = false
return introspectDto, nil
}
if audience[0] != clientID {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
introspectDto.Active = true
introspectDto.TokenType = "access_token"
introspectDto.Audience = audience
if token.Has("scope") {
var (
asString string
@@ -519,9 +575,6 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
if subject, ok := token.Subject(); ok {
introspectDto.Subject = subject
}
if audience, ok := token.Audience(); ok {
introspectDto.Audience = audience
}
if issuer, ok := token.Issuer(); ok {
introspectDto.Issuer = issuer
}
@@ -532,12 +585,29 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
return introspectDto, nil
}
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID string, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
tokenUserID, tokenClientID, tokenRT, err := s.jwtService.VerifyOAuthRefreshToken(refreshToken)
if err != nil {
return introspectDto, fmt.Errorf("invalid refresh token: %w", err)
}
// The ID of the client that made the call must match the client ID in the token
if tokenClientID != clientID {
return introspectDto, errors.New("invalid refresh token: client ID does not match")
}
var storedRefreshToken model.OidcRefreshToken
err = s.db.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
Where(
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
utils.CreateSha256Hash(tokenRT),
datatype.DateTime(time.Now()),
tokenUserID,
tokenClientID,
).
First(&storedRefreshToken).
Error
if err != nil {
@@ -1062,9 +1132,11 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
}
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: input.ClientID,
ClientSecret: input.ClientSecret,
client, err := s.verifyClientCredentialsInternal(ctx, s.db, ClientAuthCredentials{
ClientID: input.ClientID,
ClientSecret: input.ClientSecret,
ClientAssertionType: input.ClientAssertionType,
ClientAssertion: input.ClientAssertion,
})
if err != nil {
return nil, err
@@ -1085,7 +1157,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
DeviceCode: deviceCode,
UserCode: userCode,
Scope: input.Scope,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
ExpiresAt: datatype.DateTime(time.Now().Add(DeviceCodeDuration)),
IsAuthorized: false,
ClientID: client.ID,
}
@@ -1099,7 +1171,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
UserCode: userCode,
VerificationURI: common.EnvConfig.AppURL + "/device",
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
ExpiresIn: 900, // 15 minutes
ExpiresIn: int(DeviceCodeDuration.Seconds()),
Interval: 5,
}, nil
}
@@ -1255,7 +1327,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
m := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
Token: refreshTokenHash,
ClientID: clientID,
UserID: userID,
@@ -1270,7 +1342,13 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
return "", err
}
return refreshToken, nil
// Sign the refresh token
signed, err := s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
if err != nil {
return "", fmt.Errorf("failed to sign refresh token: %w", err)
}
return signed, nil
}
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
@@ -1291,7 +1369,23 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
return err
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
type ClientAuthCredentials struct {
ClientID string
ClientSecret string
ClientAssertion string
ClientAssertionType string
}
func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials {
return ClientAuthCredentials{
ClientID: d.ClientID,
ClientSecret: d.ClientSecret,
ClientAssertion: d.ClientAssertion,
ClientAssertionType: d.ClientAssertionType,
}
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
// First, ensure we have a valid client ID
if input.ClientID == "" {
return nil, &common.OidcMissingClientCredentialsError{}
@@ -1366,7 +1460,7 @@ func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set
return jwks, nil
}
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input ClientAuthCredentials) error {
// First, parse the assertion JWT, without validating it, to check the issuer
assertion := []byte(input.ClientAssertion)
insecureToken, err := jwt.ParseInsecure(assertion)
@@ -1466,12 +1560,18 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err
}
// Commit the transaction before signing tokens to avoid locking the database for longer
err = tx.Commit().Error
if err != nil {
return nil, err
}
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
if err != nil {
return nil, err
}
accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID)
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
if err != nil {
return nil, err
}
@@ -1486,11 +1586,6 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return &dto.OidcClientPreviewDto{
IdToken: idTokenPayload,
AccessToken: accessTokenPayload,
@@ -1498,26 +1593,11 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
}, nil
}
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]any, error) {
return s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
}
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]any, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
err := tx.
WithContext(ctx).
@@ -1532,14 +1612,13 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
}
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) {
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
user := authorizedClient.User
scopes := strings.Split(authorizedClient.Scope, " ")
claims := map[string]interface{}{
"sub": user.ID,
}
claims := make(map[string]any, 10)
claims["sub"] = user.ID
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
@@ -1553,19 +1632,13 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
claims["groups"] = userGroups
}
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
}
if slices.Contains(scopes, "profile") {
// Add profile claims
for k, v := range profileClaims {
claims[k] = v
}
claims["given_name"] = user.FirstName
claims["family_name"] = user.LastName
claims["name"] = user.FullName()
claims["preferred_username"] = user.Username
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
@@ -1575,7 +1648,7 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
for _, customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string
var jsonValue interface{}
var jsonValue any
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if err == nil {
// It's JSON, so we store it as an object

View File

@@ -210,7 +210,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Confidential client", func(t *testing.T) {
t.Run("Succeeds with valid secret", func(t *testing.T) {
// Test with valid client credentials
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
ClientSecret: confidentialSecret,
})
@@ -221,7 +221,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with invalid secret", func(t *testing.T) {
// Test with invalid client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
ClientSecret: "invalid-secret",
})
@@ -232,7 +232,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with missing secret", func(t *testing.T) {
// Test with missing client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
})
require.Error(t, err)
@@ -245,7 +245,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Public client", func(t *testing.T) {
t.Run("Succeeds with no credentials", func(t *testing.T) {
// Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID,
})
require.NoError(t, err)
@@ -270,7 +270,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
require.NoError(t, err)
// Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
@@ -282,7 +282,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with malformed JWT", func(t *testing.T) {
// Test with invalid JWT assertion (just a random string)
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: "invalid.jwt.token",
@@ -311,7 +311,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
require.NoError(t, err)
// Test with invalid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
@@ -352,7 +352,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
require.NoError(t, err)
// Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),

View File

@@ -0,0 +1,18 @@
package utils
import (
"net/http"
"strings"
)
// BearerAuth returns the value of the bearer token in the Authorization header if present
func BearerAuth(r *http.Request) (string, bool) {
const prefix = "bearer "
authHeader := r.Header.Get("Authorization")
if len(authHeader) >= len(prefix) && strings.ToLower(authHeader[:len(prefix)]) == prefix {
return authHeader[len(prefix):], true
}
return "", false
}

View File

@@ -0,0 +1,65 @@
package utils
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBearerAuth(t *testing.T) {
tests := []struct {
name string
authHeader string
expectedToken string
expectedFound bool
}{
{
name: "Valid bearer token",
authHeader: "Bearer token123",
expectedToken: "token123",
expectedFound: true,
},
{
name: "Valid bearer token with mixed case",
authHeader: "beARer token456",
expectedToken: "token456",
expectedFound: true,
},
{
name: "No bearer prefix",
authHeader: "Basic dXNlcjpwYXNz",
expectedToken: "",
expectedFound: false,
},
{
name: "Empty auth header",
authHeader: "",
expectedToken: "",
expectedFound: false,
},
{
name: "Bearer prefix only",
authHeader: "Bearer ",
expectedToken: "",
expectedFound: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", nil)
require.NoError(t, err, "Failed to create request")
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
token, found := BearerAuth(req)
assert.Equal(t, tt.expectedFound, found)
assert.Equal(t, tt.expectedToken, token)
})
}
}

View File

@@ -7,8 +7,9 @@ import (
)
func GetClaimsFromToken(token jwt.Token) (map[string]any, error) {
claims := make(map[string]any)
for _, key := range token.Keys() {
keys := token.Keys()
claims := make(map[string]any, len(keys))
for _, key := range keys {
var value any
if err := token.Get(key, &value); err != nil {
return nil, fmt.Errorf("failed to get claim %s: %w", key, err)

View File

@@ -87,11 +87,13 @@ export const refreshTokens = [
{
token: 'ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo',
clientId: oidcClients.nextcloud.id,
userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e',
expired: false
},
{
token: 'X4vqwtRyCUaq51UafHea4Fsg8Km6CAns6vp3tuX4',
clientId: oidcClients.nextcloud.id,
userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e',
expired: true
}
];

View File

@@ -156,11 +156,21 @@ test("End session with id token hint redirects to callback URL", async ({
test("Successfully refresh tokens with valid refresh token", async ({
request,
}) => {
const { token, clientId } = refreshTokens.filter(
const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired
)[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
// Perform the exchange
const refreshResponse = await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
@@ -168,7 +178,7 @@ test("Successfully refresh tokens with valid refresh token", async ({
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: token,
refresh_token: refreshToken,
client_secret: clientSecret,
},
});
@@ -184,26 +194,25 @@ test("Successfully refresh tokens with valid refresh token", async ({
expect(tokenData.refresh_token).not.toBe(token);
});
test("Using refresh token invalidates it for future use", async ({
test("Refresh token fails when used for the wrong client", async ({
request,
}) => {
const { token, clientId } = refreshTokens.filter(
const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired
)[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: token,
client_secret: clientSecret,
},
});
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: 'bad-client',
user: userId,
}
}).then((r) => r.text())
// Perform the exchange
const refreshResponse = await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
@@ -211,7 +220,86 @@ test("Using refresh token invalidates it for future use", async ({
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: token,
refresh_token: refreshToken,
client_secret: clientSecret,
},
});
expect(refreshResponse.status()).toBe(400);
});
test("Refresh token fails when used for the wrong user", async ({
request,
}) => {
const { token, clientId } = refreshTokens.filter(
(token) => !token.expired
)[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: 'bad-user',
}
}).then((r) => r.text())
// Perform the exchange
const refreshResponse = await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: refreshToken,
client_secret: clientSecret,
},
});
expect(refreshResponse.status()).toBe(400);
});
test("Using refresh token invalidates it for future use", async ({
request,
}) => {
const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired
)[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
// Perform the exchange
await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: refreshToken,
client_secret: clientSecret,
},
});
// Try again
const refreshResponse = await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: refreshToken,
client_secret: clientSecret,
},
});
@@ -219,11 +307,10 @@ test("Using refresh token invalidates it for future use", async ({
});
test.describe("Introspection endpoint", () => {
const client = oidcClients.nextcloud;
test("without client_id and client_secret fails", async ({ request }) => {
test("fails without client credentials", async ({ request }) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
client.id
oidcClients.nextcloud.id
);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
@@ -237,20 +324,20 @@ test.describe("Introspection endpoint", () => {
expect(introspectionResponse.status()).toBe(400);
});
test("with client_id and client_secret succeeds", async ({
test("succeeds with client credentials", async ({
request,
baseURL,
}) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
client.id
oidcClients.nextcloud.id
);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization:
"Basic " +
Buffer.from(`${client.id}:${client.secret}`).toString("base64"),
Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
},
form: {
token: validAccessToken,
@@ -266,18 +353,102 @@ test.describe("Introspection endpoint", () => {
expect(introspectionBody.aud).toStrictEqual([oidcClients.nextcloud.id]);
});
test("succeeds with federated client credentials", async ({
page,
request,
baseURL,
}) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
oidcClients.federated.id
);
const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization: "Bearer " + clientAssertion,
},
form: {
token: validAccessToken,
},
});
expect(introspectionResponse.status()).toBe(200);
const introspectionBody = await introspectionResponse.json();
expect(introspectionBody.active).toBe(true);
expect(introspectionBody.token_type).toBe("access_token");
expect(introspectionBody.iss).toBe(baseURL);
expect(introspectionBody.sub).toBe(users.tim.id);
expect(introspectionBody.aud).toStrictEqual([oidcClients.federated.id]);
});
test("fails with client credentials for wrong app", async ({
request,
}) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
oidcClients.nextcloud.id
);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization:
"Basic " +
Buffer.from(`${oidcClients.immich.id}:${oidcClients.immich.secret}`).toString("base64"),
},
form: {
token: validAccessToken,
},
});
expect(introspectionResponse.status()).toBe(400);
});
test("fails with federated credentials for wrong app", async ({
page,
request,
}) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
oidcClients.nextcloud.id
);
const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization: "Bearer " + clientAssertion,
},
form: {
token: validAccessToken,
},
});
expect(introspectionResponse.status()).toBe(400);
});
test("non-expired refresh_token can be verified", async ({ request }) => {
const { token } = refreshTokens.filter((token) => !token.expired)[0];
const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired
)[0];
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization:
"Basic " +
Buffer.from(`${client.id}:${client.secret}`).toString("base64"),
Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
},
form: {
token: token,
token: refreshToken,
},
});
@@ -288,17 +459,28 @@ test.describe("Introspection endpoint", () => {
});
test("expired refresh_token can be verified", async ({ request }) => {
const { token } = refreshTokens.filter((token) => token.expired)[0];
const { token, clientId, userId } = refreshTokens.filter(
(token) => token.expired
)[0];
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization:
"Basic " +
Buffer.from(`${client.id}:${client.secret}`).toString("base64"),
Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
},
form: {
token: token,
token: refreshToken,
},
});
@@ -310,7 +492,7 @@ test.describe("Introspection endpoint", () => {
test("expired access_token can't be verified", async ({ request }) => {
const expiredAccessToken = await generateOauthAccessToken(
users.tim,
client.id,
oidcClients.nextcloud.id,
true
);
const introspectionResponse = await request.post("/api/oidc/introspect", {