From b62b61fb017dba31a6fc612c138bebf370d3956c Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Mon, 9 Jun 2025 12:17:55 -0700 Subject: [PATCH] feat: allow introspection and device code endpoints to use Federated Client Credentials (#640) --- .../internal/controller/e2etest_controller.go | 22 ++ .../internal/controller/oidc_controller.go | 18 +- backend/internal/dto/oidc_dto.go | 8 +- backend/internal/service/e2etest_service.go | 4 + backend/internal/service/jwt_service.go | 113 ++++++++- backend/internal/service/jwt_service_test.go | 207 ++++++++++++++- backend/internal/service/oidc_service.go | 223 ++++++++++------ backend/internal/service/oidc_service_test.go | 16 +- backend/internal/utils/http_util.go | 18 ++ backend/internal/utils/http_util_test.go | 65 +++++ backend/internal/utils/jwt_util.go | 5 +- tests/data.ts | 2 + tests/specs/oidc.spec.ts | 240 +++++++++++++++--- 13 files changed, 802 insertions(+), 139 deletions(-) create mode 100644 backend/internal/utils/http_util.go create mode 100644 backend/internal/utils/http_util_test.go diff --git a/backend/internal/controller/e2etest_controller.go b/backend/internal/controller/e2etest_controller.go index d5ecc989..801f6c40 100644 --- a/backend/internal/controller/e2etest_controller.go +++ b/backend/internal/controller/e2etest_controller.go @@ -14,6 +14,7 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService) testController := &TestController{TestService: testService} group.POST("/test/reset", testController.resetAndSeedHandler) + group.POST("/test/refreshtoken", testController.signRefreshToken) group.GET("/externalidp/jwks.json", testController.externalIdPJWKS) group.POST("/externalidp/sign", testController.externalIdPSignToken) @@ -100,3 +101,24 @@ func (tc *TestController) externalIdPSignToken(c *gin.Context) { c.Writer.WriteString(token) } + +func (tc *TestController) signRefreshToken(c *gin.Context) { + var input struct { + UserID string `json:"user"` + ClientID string `json:"client"` + RefreshToken string `json:"rt"` + } + err := c.ShouldBindJSON(&input) + if err != nil { + _ = c.Error(err) + return + } + + token, err := tc.TestService.SignRefreshToken(input.UserID, input.ClientID, input.RefreshToken) + if err != nil { + _ = c.Error(err) + return + } + + c.Writer.WriteString(token) +} diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 754a7955..ff1110a7 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -200,7 +200,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) { return } - token, err := oc.jwtService.VerifyOauthAccessToken(authToken) + token, err := oc.jwtService.VerifyOAuthAccessToken(authToken) if err != nil { _ = c.Error(err) return @@ -308,9 +308,21 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) { // find valid tokens) while still allowing it to be used by an application that is // supposed to interact with our IdP (since that needs to have a client_id // and client_secret anyway). - clientID, clientSecret, _ := c.Request.BasicAuth() + var ( + creds service.ClientAuthCredentials + ok bool + ) + creds.ClientID, creds.ClientSecret, ok = c.Request.BasicAuth() + if !ok { + // If there's no basic auth, check if we have a bearer token + bearer, ok := utils.BearerAuth(c.Request) + if ok { + creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer + creds.ClientAssertion = bearer + } + } - response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token) + response, err := oc.oidcService.IntrospectToken(c.Request.Context(), creds, input.Token) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index 9e9aaf7c..f9cab715 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -113,9 +113,11 @@ type OidcIntrospectionResponseDto struct { } type OidcDeviceAuthorizationRequestDto struct { - ClientID string `form:"client_id" binding:"required"` - Scope string `form:"scope" binding:"required"` - ClientSecret string `form:"client_secret"` + ClientID string `form:"client_id" binding:"required"` + Scope string `form:"scope" binding:"required"` + ClientSecret string `form:"client_secret"` + ClientAssertion string `form:"client_assertion"` + ClientAssertionType string `form:"client_assertion_type"` } type OidcDeviceAuthorizationResponseDto struct { diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go index 96234a99..f7684870 100644 --- a/backend/internal/service/e2etest_service.go +++ b/backend/internal/service/e2etest_service.go @@ -479,6 +479,10 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error { return nil } +func (s *TestService) SignRefreshToken(userID, clientID, refreshToken string) (string, error) { + return s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken) +} + // GetExternalIdPJWKS returns the JWKS for the "external IdP". func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) { pubKey, err := s.externalIdPKey.PublicKey() diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index bf3061de..a74653d4 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -39,9 +39,15 @@ const ( // TokenTypeClaim is the claim used to identify the type of token TokenTypeClaim = "type" + // RefreshTokenClaim is the claim used for the refresh token's value + RefreshTokenClaim = "rt" + // OAuthAccessTokenJWTType identifies a JWT as an OAuth access token OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec + // OAuthRefreshTokenJWTType identifies a JWT as an OAuth refresh token + OAuthRefreshTokenJWTType = "refresh-token" + // AccessTokenJWTType identifies a JWT as an access token used by Pocket ID AccessTokenJWTType = "access-token" @@ -322,8 +328,8 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) return token, nil } -// BuildOauthAccessToken creates an OAuth access token with all claims -func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jwt.Token, error) { +// BuildOAuthAccessToken creates an OAuth access token with all claims +func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) { now := time.Now() token, err := jwt.NewBuilder(). Subject(user.ID). @@ -348,9 +354,9 @@ func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jw return token, nil } -// GenerateOauthAccessToken creates and signs an OAuth access token -func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { - token, err := s.BuildOauthAccessToken(user, clientID) +// GenerateOAuthAccessToken creates and signs an OAuth access token +func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) { + token, err := s.BuildOAuthAccessToken(user, clientID) if err != nil { return "", err } @@ -364,7 +370,7 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) return string(signed), nil } -func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) { +func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, error) { alg, _ := s.privateKey.Algorithm() token, err := jwt.ParseString( tokenString, @@ -381,6 +387,96 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro return token, nil } +func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, refreshToken string) (string, error) { + now := time.Now() + token, err := jwt.NewBuilder(). + Subject(userID). + Expiration(now.Add(RefreshTokenDuration)). + IssuedAt(now). + Issuer(common.EnvConfig.AppURL). + Build() + if err != nil { + return "", fmt.Errorf("failed to build token: %w", err) + } + + err = token.Set(RefreshTokenClaim, refreshToken) + if err != nil { + return "", fmt.Errorf("failed to set 'rt' claim in token: %w", err) + } + + err = SetAudienceString(token, clientID) + if err != nil { + return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) + } + + err = SetTokenType(token, OAuthRefreshTokenJWTType) + if err != nil { + return "", fmt.Errorf("failed to set 'type' claim in token: %w", err) + } + + alg, _ := s.privateKey.Algorithm() + signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + return string(signed), nil +} + +func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, clientID, rt string, err error) { + alg, _ := s.privateKey.Algorithm() + token, err := jwt.ParseString( + tokenString, + jwt.WithValidate(true), + jwt.WithKey(alg, s.privateKey), + jwt.WithAcceptableSkew(clockSkew), + jwt.WithIssuer(common.EnvConfig.AppURL), + jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)), + ) + if err != nil { + return "", "", "", fmt.Errorf("failed to parse token: %w", err) + } + + err = token.Get(RefreshTokenClaim, &rt) + if err != nil { + return "", "", "", fmt.Errorf("failed to get '%s' claim from token: %w", RefreshTokenClaim, err) + } + + audiences, ok := token.Audience() + if !ok || len(audiences) != 1 || audiences[0] == "" { + return "", "", "", errors.New("failed to get 'aud' claim from token") + } + clientID = audiences[0] + + userID, ok = token.Subject() + if !ok { + return "", "", "", errors.New("failed to get 'sub' claim from token") + } + + return userID, clientID, rt, nil +} + +// GetTokenType returns the type of the JWT token issued by Pocket ID, but **does not validate it**. +func (s *JwtService) GetTokenType(tokenString string) (string, jwt.Token, error) { + // Disable validation and verification to parse the token without checking it + token, err := jwt.ParseString( + tokenString, + jwt.WithValidate(false), + jwt.WithVerify(false), + ) + if err != nil { + return "", nil, fmt.Errorf("failed to parse token: %w", err) + } + + var tokenType string + err = token.Get(TokenTypeClaim, &tokenType) + if err != nil { + return "", nil, fmt.Errorf("failed to get token type claim: %w", err) + } + + return tokenType, token, nil +} + // GetPublicJWK returns the JSON Web Key (JWK) for the public key. func (s *JwtService) GetPublicJWK() (jwk.Key, error) { if s.privateKey == nil { @@ -478,7 +574,10 @@ func GetIsAdmin(token jwt.Token) (bool, error) { } var isAdmin bool err := token.Get(IsAdminClaim, &isAdmin) - return isAdmin, err + if err != nil { + return false, fmt.Errorf("failed to get 'isAdmin' claim from token: %w", err) + } + return isAdmin, nil } // SetTokenType sets the "type" claim in the token diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go index 43c7520c..0a00f2fe 100644 --- a/backend/internal/service/jwt_service_test.go +++ b/backend/internal/service/jwt_service_test.go @@ -882,7 +882,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { }) } -func TestGenerateVerifyOauthAccessToken(t *testing.T) { +func TestGenerateVerifyOAuthAccessToken(t *testing.T) { // Create a temporary directory for the test tempDir := t.TempDir() @@ -914,12 +914,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { const clientID = "test-client-123" // Generate a token - tokenString, err := service.GenerateOauthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token") assert.NotEmpty(t, tokenString, "Token should not be empty") // Verify the token - claims, err := service.VerifyOauthAccessToken(tokenString) + claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token") // Check the claims @@ -972,7 +972,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { require.NoError(t, err, "Failed to sign token") // Verify should fail due to expiration - _, err = service.VerifyOauthAccessToken(string(signed)) + _, err = service.VerifyOAuthAccessToken(string(signed)) require.Error(t, err, "Verification should fail with expired token") assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure") }) @@ -996,11 +996,11 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { const clientID = "test-client-789" // Generate a token with the first service - tokenString, err := service1.GenerateOauthAccessToken(user, clientID) + tokenString, err := service1.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token") // Verify with the second service should fail due to different keys - _, err = service2.VerifyOauthAccessToken(tokenString) + _, err = service2.VerifyOAuthAccessToken(tokenString) require.Error(t, err, "Verification should fail with invalid signature") assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure") }) @@ -1032,12 +1032,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { const clientID = "eddsa-oauth-client" // Generate a token - tokenString, err := service.GenerateOauthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") // Verify the token - claims, err := service.VerifyOauthAccessToken(tokenString) + claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token with key") // Check the claims @@ -1086,12 +1086,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { const clientID = "ecdsa-oauth-client" // Generate a token - tokenString, err := service.GenerateOauthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") // Verify the token - claims, err := service.VerifyOauthAccessToken(tokenString) + claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token with key") // Check the claims @@ -1140,12 +1140,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { const clientID = "rsa-oauth-client" // Generate a token - tokenString, err := service.GenerateOauthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID) require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") // Verify the token - claims, err := service.VerifyOauthAccessToken(tokenString) + claims, err := service.VerifyOAuthAccessToken(tokenString) require.NoError(t, err, "Failed to verify generated OAuth access token with key") // Check the claims @@ -1168,6 +1168,92 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) { }) } +func TestGenerateVerifyOAuthRefreshToken(t *testing.T) { + // Create a temporary directory for the test + tempDir := t.TempDir() + + // Initialize the JWT service with a mock AppConfigService + mockConfig := NewTestAppConfigService(&model.AppConfig{}) + + // Setup the environment variable required by the token verification + originalAppURL := common.EnvConfig.AppURL + common.EnvConfig.AppURL = "https://test.example.com" + defer func() { + common.EnvConfig.AppURL = originalAppURL + }() + + t.Run("generates and verifies refresh token", func(t *testing.T) { + // Create a JWT service + service := &JwtService{} + err := service.init(mockConfig, tempDir) + require.NoError(t, err, "Failed to initialize JWT service") + + // Create a test user + const ( + userID = "user123" + clientID = "client123" + refreshToken = "rt-123" + ) + + // Generate a token + tokenString, err := service.GenerateOAuthRefreshToken(userID, clientID, refreshToken) + require.NoError(t, err, "Failed to generate refresh token") + assert.NotEmpty(t, tokenString, "Token should not be empty") + + // Verify the token + resUser, resClient, resRT, err := service.VerifyOAuthRefreshToken(tokenString) + require.NoError(t, err, "Failed to verify generated token") + assert.Equal(t, userID, resUser, "Should return correct user ID") + assert.Equal(t, clientID, resClient, "Should return correct client ID") + assert.Equal(t, refreshToken, resRT, "Should return correct refresh token") + }) + + t.Run("fails verification for expired token", func(t *testing.T) { + // Create a JWT service + service := &JwtService{} + err := service.init(mockConfig, tempDir) + require.NoError(t, err, "Failed to initialize JWT service") + + // Generate a token using JWT directly to create an expired token + token, err := jwt.NewBuilder(). + Subject("user789"). + Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago + IssuedAt(time.Now().Add(-2 * time.Hour)). + Audience([]string{"client123"}). + Issuer(common.EnvConfig.AppURL). + Build() + require.NoError(t, err, "Failed to build token") + + signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey)) + require.NoError(t, err, "Failed to sign token") + + // Verify should fail due to expiration + _, _, _, err = service.VerifyOAuthRefreshToken(string(signed)) + require.Error(t, err, "Verification should fail with expired token") + assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure") + }) + + t.Run("fails verification with invalid signature", func(t *testing.T) { + // Create two JWT services with different keys + service1 := &JwtService{} + err := service1.init(mockConfig, t.TempDir()) + require.NoError(t, err, "Failed to initialize first JWT service") + + service2 := &JwtService{} + err = service2.init(mockConfig, t.TempDir()) + require.NoError(t, err, "Failed to initialize second JWT service") + + // Generate a token with the first service + tokenString, err := service1.GenerateOAuthRefreshToken("user789", "client123", "my-rt-123") + require.NoError(t, err, "Failed to generate refresh token") + + // Verify with the second service should fail due to different keys + _, _, _, err = service2.VerifyOAuthRefreshToken(tokenString) + require.Error(t, err, "Verification should fail with invalid signature") + assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure") + }) +} + func TestTokenTypeValidator(t *testing.T) { // Create a context for the validator function ctx := context.Background() @@ -1213,7 +1299,104 @@ func TestTokenTypeValidator(t *testing.T) { require.Error(t, err, "Validator should reject token without type claim") assert.Contains(t, err.Error(), "failed to get token type claim") }) +} +func TestGetTokenType(t *testing.T) { + // Create a temporary directory for the test + tempDir := t.TempDir() + + // Initialize the JWT service + mockConfig := NewTestAppConfigService(&model.AppConfig{}) + service := &JwtService{} + err := service.init(mockConfig, tempDir) + require.NoError(t, err, "Failed to initialize JWT service") + + buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string { + t.Helper() + + b := jwt.NewBuilder() + b.Subject("user123") + if setClaimsFn != nil { + setClaimsFn(b) + } + + token, err := b.Build() + require.NoError(t, err, "Failed to build token") + + err = SetTokenType(token, typ) + require.NoError(t, err, "Failed to set token type") + + alg, _ := service.privateKey.Algorithm() + signed, err := jwt.Sign(token, jwt.WithKey(alg, service.privateKey)) + require.NoError(t, err, "Failed to sign token") + + return string(signed) + } + + t.Run("correctly identifies access tokens", func(t *testing.T) { + tokenString := buildTokenForType(t, AccessTokenJWTType, nil) + + // Get the token type without validating + tokenType, _, err := service.GetTokenType(tokenString) + require.NoError(t, err, "GetTokenType should not return an error") + assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified as access token") + }) + + t.Run("correctly identifies ID tokens", func(t *testing.T) { + tokenString := buildTokenForType(t, IDTokenJWTType, nil) + + // Get the token type without validating + tokenType, _, err := service.GetTokenType(tokenString) + require.NoError(t, err, "GetTokenType should not return an error") + assert.Equal(t, IDTokenJWTType, tokenType, "Token type should be correctly identified as ID token") + }) + + t.Run("correctly identifies OAuth access tokens", func(t *testing.T) { + tokenString := buildTokenForType(t, OAuthAccessTokenJWTType, nil) + + // Get the token type without validating + tokenType, _, err := service.GetTokenType(tokenString) + require.NoError(t, err, "GetTokenType should not return an error") + assert.Equal(t, OAuthAccessTokenJWTType, tokenType, "Token type should be correctly identified as OAuth access token") + }) + + t.Run("correctly identifies refresh tokens", func(t *testing.T) { + tokenString := buildTokenForType(t, OAuthRefreshTokenJWTType, nil) + + // Get the token type without validating + tokenType, _, err := service.GetTokenType(tokenString) + require.NoError(t, err, "GetTokenType should not return an error") + assert.Equal(t, OAuthRefreshTokenJWTType, tokenType, "Token type should be correctly identified as refresh token") + }) + + t.Run("works with expired tokens", func(t *testing.T) { + tokenString := buildTokenForType(t, AccessTokenJWTType, func(b *jwt.Builder) { + b.Expiration(time.Now().Add(-1 * time.Hour)) // Expired 1 hour ago + }) + + // Get the token type without validating + tokenType, _, err := service.GetTokenType(tokenString) + require.NoError(t, err, "GetTokenType should not return an error for expired tokens") + assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified even for expired tokens") + }) + + t.Run("returns error for malformed tokens", func(t *testing.T) { + // Try to get the token type of a malformed token + tokenType, _, err := service.GetTokenType("not.a.valid.jwt.token") + require.Error(t, err, "GetTokenType should return an error for malformed tokens") + assert.Empty(t, tokenType, "Token type should be empty for malformed tokens") + }) + + t.Run("returns error for tokens without type claim", func(t *testing.T) { + // Create a token without type claim + tokenString := buildTokenForType(t, "", nil) + + // Get the token type without validating + tokenType, _, err := service.GetTokenType(tokenString) + require.Error(t, err, "GetTokenType should return an error for tokens without type claim") + assert.Empty(t, tokenType, "Token type should be empty when type claim is missing") + assert.Contains(t, err.Error(), "failed to get token type claim", "Error message should indicate missing token type claim") + }) } func importKey(t *testing.T, privateKeyRaw any, path string) string { diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index b8d1e8ce..b4e26706 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -40,6 +40,9 @@ const ( GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec + + RefreshTokenDuration = 30 * 24 * time.Hour // 30 days + DeviceCodeDuration = 15 * time.Minute ) type OidcService struct { @@ -252,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O tx.Rollback() }() - _, err := s.verifyClientCredentialsInternal(ctx, tx, input) + _, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) if err != nil { return CreatedTokens{}, err } @@ -303,7 +306,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O return CreatedTokens{}, err } - accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID) + accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID) if err != nil { return CreatedTokens{}, err } @@ -333,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu tx.Rollback() }() - client, err := s.verifyClientCredentialsInternal(ctx, tx, input) + client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) if err != nil { return CreatedTokens{}, err } @@ -375,7 +378,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu return CreatedTokens{}, err } - accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID) + accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID) if err != nil { return CreatedTokens{}, err } @@ -406,22 +409,39 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto return CreatedTokens{}, &common.OidcMissingRefreshTokenError{} } + // Validate the signed refresh token and extract the actual token (which is a claim in the signed one) + userID, clientID, rt, err := s.jwtService.VerifyOAuthRefreshToken(input.RefreshToken) + if err != nil { + return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} + } + tx := s.db.Begin() defer func() { tx.Rollback() }() - _, err := s.verifyClientCredentialsInternal(ctx, tx, input) + client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input)) if err != nil { return CreatedTokens{}, err } + // The ID of the client that made the call must match the client ID in the token + if client.ID != clientID { + return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} + } + // Verify refresh token var storedRefreshToken model.OidcRefreshToken err = tx. WithContext(ctx). Preload("User"). - Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())). + Where( + "token = ? AND expires_at > ? AND user_id = ? AND client_id = ?", + utils.CreateSha256Hash(rt), + datatype.DateTime(time.Now()), + userID, + input.ClientID, + ). First(&storedRefreshToken). Error if err != nil { @@ -437,7 +457,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto } // Generate a new access token - accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID) + accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID) if err != nil { return CreatedTokens{}, err } @@ -469,33 +489,69 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto }, nil } -func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { - if clientID == "" || clientSecret == "" { +func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { + // Get the type of the token and the client ID + tokenType, token, err := s.jwtService.GetTokenType(tokenString) + if err != nil { + // We just treat the token as invalid + introspectDto.Active = false + return introspectDto, nil //nolint:nilerr + } + + // If we don't have a client ID, get it from the token + // Otherwise, we need to make sure that the client ID passed as credential matches + tokenAudiences, _ := token.Audience() + if len(tokenAudiences) != 1 || tokenAudiences[0] == "" { + // We just treat the token as invalid + introspectDto.Active = false + return introspectDto, nil + } + if creds.ClientID == "" { + creds.ClientID = tokenAudiences[0] + } else if creds.ClientID != tokenAudiences[0] { return introspectDto, &common.OidcMissingClientCredentialsError{} } - _, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{ - ClientID: clientID, - ClientSecret: clientSecret, - }) + // Verify the credentials for the call + client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds) if err != nil { return introspectDto, err } - token, err := s.jwtService.VerifyOauthAccessToken(tokenString) - if err != nil { - if errors.Is(err, jwt.ParseError()) { - // It's apparently not a valid JWT token, so we check if it's a valid refresh_token. - return s.introspectRefreshToken(ctx, tokenString) - } - - // Every failure we get means the token is invalid. Nothing more to do with the error. + // Introspect the token + switch tokenType { + case OAuthAccessTokenJWTType: + return s.introspectAccessToken(client.ID, tokenString) + case OAuthRefreshTokenJWTType: + return s.introspectRefreshToken(ctx, client.ID, tokenString) + default: + // We just treat the token as invalid introspectDto.Active = false return introspectDto, nil } +} + +func (s *OidcService) introspectAccessToken(clientID string, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { + token, err := s.jwtService.VerifyOAuthAccessToken(tokenString) + if err != nil { + // Every failure we get means the token is invalid. Nothing more to do with the error. + introspectDto.Active = false + return introspectDto, nil //nolint:nilerr + } + + // The ID of the client that made the request must match the client ID in the token + audience, ok := token.Audience() + if !ok || len(audience) != 1 || audience[0] == "" { + introspectDto.Active = false + return introspectDto, nil + } + if audience[0] != clientID { + return introspectDto, &common.OidcMissingClientCredentialsError{} + } introspectDto.Active = true introspectDto.TokenType = "access_token" + introspectDto.Audience = audience if token.Has("scope") { var ( asString string @@ -519,9 +575,6 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre if subject, ok := token.Subject(); ok { introspectDto.Subject = subject } - if audience, ok := token.Audience(); ok { - introspectDto.Audience = audience - } if issuer, ok := token.Issuer(); ok { introspectDto.Issuer = issuer } @@ -532,12 +585,29 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre return introspectDto, nil } -func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { +func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID string, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { + // Validate the signed refresh token and extract the actual token (which is a claim in the signed one) + tokenUserID, tokenClientID, tokenRT, err := s.jwtService.VerifyOAuthRefreshToken(refreshToken) + if err != nil { + return introspectDto, fmt.Errorf("invalid refresh token: %w", err) + } + + // The ID of the client that made the call must match the client ID in the token + if tokenClientID != clientID { + return introspectDto, errors.New("invalid refresh token: client ID does not match") + } + var storedRefreshToken model.OidcRefreshToken err = s.db. WithContext(ctx). Preload("User"). - Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())). + Where( + "token = ? AND expires_at > ? AND user_id = ? AND client_id = ?", + utils.CreateSha256Hash(tokenRT), + datatype.DateTime(time.Now()), + tokenUserID, + tokenClientID, + ). First(&storedRefreshToken). Error if err != nil { @@ -1062,9 +1132,11 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model. } func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { - client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{ - ClientID: input.ClientID, - ClientSecret: input.ClientSecret, + client, err := s.verifyClientCredentialsInternal(ctx, s.db, ClientAuthCredentials{ + ClientID: input.ClientID, + ClientSecret: input.ClientSecret, + ClientAssertionType: input.ClientAssertionType, + ClientAssertion: input.ClientAssertion, }) if err != nil { return nil, err @@ -1085,7 +1157,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O DeviceCode: deviceCode, UserCode: userCode, Scope: input.Scope, - ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)), + ExpiresAt: datatype.DateTime(time.Now().Add(DeviceCodeDuration)), IsAuthorized: false, ClientID: client.ID, } @@ -1099,7 +1171,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O UserCode: userCode, VerificationURI: common.EnvConfig.AppURL + "/device", VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode, - ExpiresIn: 900, // 15 minutes + ExpiresIn: int(DeviceCodeDuration.Seconds()), Interval: 5, }, nil } @@ -1255,7 +1327,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u refreshTokenHash := utils.CreateSha256Hash(refreshToken) m := model.OidcRefreshToken{ - ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days + ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)), Token: refreshTokenHash, ClientID: clientID, UserID: userID, @@ -1270,7 +1342,13 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u return "", err } - return refreshToken, nil + // Sign the refresh token + signed, err := s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken) + if err != nil { + return "", fmt.Errorf("failed to sign refresh token: %w", err) + } + + return signed, nil } func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error { @@ -1291,7 +1369,23 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID return err } -func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) { +type ClientAuthCredentials struct { + ClientID string + ClientSecret string + ClientAssertion string + ClientAssertionType string +} + +func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials { + return ClientAuthCredentials{ + ClientID: d.ClientID, + ClientSecret: d.ClientSecret, + ClientAssertion: d.ClientAssertion, + ClientAssertionType: d.ClientAssertionType, + } +} + +func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) { // First, ensure we have a valid client ID if input.ClientID == "" { return nil, &common.OidcMissingClientCredentialsError{} @@ -1366,7 +1460,7 @@ func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set return jwks, nil } -func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error { +func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input ClientAuthCredentials) error { // First, parse the assertion JWT, without validating it, to check the issuer assertion := []byte(input.ClientAssertion) insecureToken, err := jwt.ParseInsecure(assertion) @@ -1466,12 +1560,18 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use return nil, err } + // Commit the transaction before signing tokens to avoid locking the database for longer + err = tx.Commit().Error + if err != nil { + return nil, err + } + idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "") if err != nil { return nil, err } - accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID) + accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID) if err != nil { return nil, err } @@ -1486,11 +1586,6 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use return nil, err } - err = tx.Commit().Error - if err != nil { - return nil, err - } - return &dto.OidcClientPreviewDto{ IdToken: idTokenPayload, AccessToken: accessTokenPayload, @@ -1498,26 +1593,11 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use }, nil } -func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) { - tx := s.db.Begin() - defer func() { - tx.Rollback() - }() - - claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db) - if err != nil { - return nil, err - } - - err = tx.Commit().Error - if err != nil { - return nil, err - } - - return claims, nil +func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]any, error) { + return s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db) } -func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) { +func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]any, error) { var authorizedOidcClient model.UserAuthorizedOidcClient err := tx. WithContext(ctx). @@ -1532,14 +1612,13 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID } -func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) { +func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) { user := authorizedClient.User scopes := strings.Split(authorizedClient.Scope, " ") - claims := map[string]interface{}{ - "sub": user.ID, - } + claims := make(map[string]any, 10) + claims["sub"] = user.ID if slices.Contains(scopes, "email") { claims["email"] = user.Email claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue() @@ -1553,19 +1632,13 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut claims["groups"] = userGroups } - profileClaims := map[string]interface{}{ - "given_name": user.FirstName, - "family_name": user.LastName, - "name": user.FullName(), - "preferred_username": user.Username, - "picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png", - } - if slices.Contains(scopes, "profile") { // Add profile claims - for k, v := range profileClaims { - claims[k] = v - } + claims["given_name"] = user.FirstName + claims["family_name"] = user.LastName + claims["name"] = user.FullName() + claims["preferred_username"] = user.Username + claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png" // Add custom claims customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx) @@ -1575,7 +1648,7 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut for _, customClaim := range customClaims { // The value of the custom claim can be a JSON object or a string - var jsonValue interface{} + var jsonValue any err := json.Unmarshal([]byte(customClaim.Value), &jsonValue) if err == nil { // It's JSON, so we store it as an object diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index 3a230f49..2d2abdda 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -210,7 +210,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { t.Run("Confidential client", func(t *testing.T) { t.Run("Succeeds with valid secret", func(t *testing.T) { // Test with valid client credentials - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: confidentialClient.ID, ClientSecret: confidentialSecret, }) @@ -221,7 +221,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { t.Run("Fails with invalid secret", func(t *testing.T) { // Test with invalid client secret - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: confidentialClient.ID, ClientSecret: "invalid-secret", }) @@ -232,7 +232,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { t.Run("Fails with missing secret", func(t *testing.T) { // Test with missing client secret - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: confidentialClient.ID, }) require.Error(t, err) @@ -245,7 +245,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { t.Run("Public client", func(t *testing.T) { t.Run("Succeeds with no credentials", func(t *testing.T) { // Public clients don't require client secret - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: publicClient.ID, }) require.NoError(t, err) @@ -270,7 +270,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { require.NoError(t, err) // Test with valid JWT assertion - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: string(signedToken), @@ -282,7 +282,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { t.Run("Fails with malformed JWT", func(t *testing.T) { // Test with invalid JWT assertion (just a random string) - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: "invalid.jwt.token", @@ -311,7 +311,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { require.NoError(t, err) // Test with invalid JWT assertion - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: string(signedToken), @@ -352,7 +352,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { require.NoError(t, err) // Test with valid JWT assertion - client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ + client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{ ClientID: federatedClient.ID, ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertion: string(signedToken), diff --git a/backend/internal/utils/http_util.go b/backend/internal/utils/http_util.go new file mode 100644 index 00000000..b8c81b3f --- /dev/null +++ b/backend/internal/utils/http_util.go @@ -0,0 +1,18 @@ +package utils + +import ( + "net/http" + "strings" +) + +// BearerAuth returns the value of the bearer token in the Authorization header if present +func BearerAuth(r *http.Request) (string, bool) { + const prefix = "bearer " + + authHeader := r.Header.Get("Authorization") + if len(authHeader) >= len(prefix) && strings.ToLower(authHeader[:len(prefix)]) == prefix { + return authHeader[len(prefix):], true + } + + return "", false +} diff --git a/backend/internal/utils/http_util_test.go b/backend/internal/utils/http_util_test.go new file mode 100644 index 00000000..c754c878 --- /dev/null +++ b/backend/internal/utils/http_util_test.go @@ -0,0 +1,65 @@ +package utils + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBearerAuth(t *testing.T) { + tests := []struct { + name string + authHeader string + expectedToken string + expectedFound bool + }{ + { + name: "Valid bearer token", + authHeader: "Bearer token123", + expectedToken: "token123", + expectedFound: true, + }, + { + name: "Valid bearer token with mixed case", + authHeader: "beARer token456", + expectedToken: "token456", + expectedFound: true, + }, + { + name: "No bearer prefix", + authHeader: "Basic dXNlcjpwYXNz", + expectedToken: "", + expectedFound: false, + }, + { + name: "Empty auth header", + authHeader: "", + expectedToken: "", + expectedFound: false, + }, + { + name: "Bearer prefix only", + authHeader: "Bearer ", + expectedToken: "", + expectedFound: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", nil) + require.NoError(t, err, "Failed to create request") + + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + + token, found := BearerAuth(req) + + assert.Equal(t, tt.expectedFound, found) + assert.Equal(t, tt.expectedToken, token) + }) + } +} diff --git a/backend/internal/utils/jwt_util.go b/backend/internal/utils/jwt_util.go index b02b9f5a..950801e1 100644 --- a/backend/internal/utils/jwt_util.go +++ b/backend/internal/utils/jwt_util.go @@ -7,8 +7,9 @@ import ( ) func GetClaimsFromToken(token jwt.Token) (map[string]any, error) { - claims := make(map[string]any) - for _, key := range token.Keys() { + keys := token.Keys() + claims := make(map[string]any, len(keys)) + for _, key := range keys { var value any if err := token.Get(key, &value); err != nil { return nil, fmt.Errorf("failed to get claim %s: %w", key, err) diff --git a/tests/data.ts b/tests/data.ts index 51e57d13..4bd350e0 100644 --- a/tests/data.ts +++ b/tests/data.ts @@ -87,11 +87,13 @@ export const refreshTokens = [ { token: 'ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo', clientId: oidcClients.nextcloud.id, + userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e', expired: false }, { token: 'X4vqwtRyCUaq51UafHea4Fsg8Km6CAns6vp3tuX4', clientId: oidcClients.nextcloud.id, + userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e', expired: true } ]; diff --git a/tests/specs/oidc.spec.ts b/tests/specs/oidc.spec.ts index 52f5e35d..9c1c78a9 100644 --- a/tests/specs/oidc.spec.ts +++ b/tests/specs/oidc.spec.ts @@ -156,11 +156,21 @@ test("End session with id token hint redirects to callback URL", async ({ test("Successfully refresh tokens with valid refresh token", async ({ request, }) => { - const { token, clientId } = refreshTokens.filter( + const { token, clientId, userId } = refreshTokens.filter( (token) => !token.expired )[0]; const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY"; + // Sign the refresh token + const refreshToken = await request.post("/api/test/refreshtoken", { + data: { + rt: token, + client: clientId, + user: userId, + } + }).then((r) => r.text()) + + // Perform the exchange const refreshResponse = await request.post("/api/oidc/token", { headers: { "Content-Type": "application/x-www-form-urlencoded", @@ -168,7 +178,7 @@ test("Successfully refresh tokens with valid refresh token", async ({ form: { grant_type: "refresh_token", client_id: clientId, - refresh_token: token, + refresh_token: refreshToken, client_secret: clientSecret, }, }); @@ -184,26 +194,25 @@ test("Successfully refresh tokens with valid refresh token", async ({ expect(tokenData.refresh_token).not.toBe(token); }); -test("Using refresh token invalidates it for future use", async ({ + +test("Refresh token fails when used for the wrong client", async ({ request, }) => { - const { token, clientId } = refreshTokens.filter( + const { token, clientId, userId } = refreshTokens.filter( (token) => !token.expired )[0]; const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY"; - await request.post("/api/oidc/token", { - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, - form: { - grant_type: "refresh_token", - client_id: clientId, - refresh_token: token, - client_secret: clientSecret, - }, - }); + // Sign the refresh token + const refreshToken = await request.post("/api/test/refreshtoken", { + data: { + rt: token, + client: 'bad-client', + user: userId, + } + }).then((r) => r.text()) + // Perform the exchange const refreshResponse = await request.post("/api/oidc/token", { headers: { "Content-Type": "application/x-www-form-urlencoded", @@ -211,7 +220,86 @@ test("Using refresh token invalidates it for future use", async ({ form: { grant_type: "refresh_token", client_id: clientId, - refresh_token: token, + refresh_token: refreshToken, + client_secret: clientSecret, + }, + }); + + expect(refreshResponse.status()).toBe(400); +}); + +test("Refresh token fails when used for the wrong user", async ({ + request, +}) => { + const { token, clientId } = refreshTokens.filter( + (token) => !token.expired + )[0]; + const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY"; + + // Sign the refresh token + const refreshToken = await request.post("/api/test/refreshtoken", { + data: { + rt: token, + client: clientId, + user: 'bad-user', + } + }).then((r) => r.text()) + + // Perform the exchange + const refreshResponse = await request.post("/api/oidc/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + form: { + grant_type: "refresh_token", + client_id: clientId, + refresh_token: refreshToken, + client_secret: clientSecret, + }, + }); + + expect(refreshResponse.status()).toBe(400); +}); + +test("Using refresh token invalidates it for future use", async ({ + request, +}) => { + const { token, clientId, userId } = refreshTokens.filter( + (token) => !token.expired + )[0]; + const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY"; + + // Sign the refresh token + const refreshToken = await request.post("/api/test/refreshtoken", { + data: { + rt: token, + client: clientId, + user: userId, + } + }).then((r) => r.text()) + + // Perform the exchange + await request.post("/api/oidc/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + form: { + grant_type: "refresh_token", + client_id: clientId, + refresh_token: refreshToken, + client_secret: clientSecret, + }, + }); + + // Try again + const refreshResponse = await request.post("/api/oidc/token", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + form: { + grant_type: "refresh_token", + client_id: clientId, + refresh_token: refreshToken, client_secret: clientSecret, }, }); @@ -219,11 +307,10 @@ test("Using refresh token invalidates it for future use", async ({ }); test.describe("Introspection endpoint", () => { - const client = oidcClients.nextcloud; - test("without client_id and client_secret fails", async ({ request }) => { + test("fails without client credentials", async ({ request }) => { const validAccessToken = await generateOauthAccessToken( users.tim, - client.id + oidcClients.nextcloud.id ); const introspectionResponse = await request.post("/api/oidc/introspect", { headers: { @@ -237,20 +324,20 @@ test.describe("Introspection endpoint", () => { expect(introspectionResponse.status()).toBe(400); }); - test("with client_id and client_secret succeeds", async ({ + test("succeeds with client credentials", async ({ request, baseURL, }) => { const validAccessToken = await generateOauthAccessToken( users.tim, - client.id + oidcClients.nextcloud.id ); const introspectionResponse = await request.post("/api/oidc/introspect", { headers: { "Content-Type": "application/x-www-form-urlencoded", Authorization: "Basic " + - Buffer.from(`${client.id}:${client.secret}`).toString("base64"), + Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"), }, form: { token: validAccessToken, @@ -266,18 +353,102 @@ test.describe("Introspection endpoint", () => { expect(introspectionBody.aud).toStrictEqual([oidcClients.nextcloud.id]); }); + test("succeeds with federated client credentials", async ({ + page, + request, + baseURL, + }) => { + const validAccessToken = await generateOauthAccessToken( + users.tim, + oidcClients.federated.id + ); + const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT); + const introspectionResponse = await request.post("/api/oidc/introspect", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Authorization: "Bearer " + clientAssertion, + }, + form: { + token: validAccessToken, + }, + }); + + expect(introspectionResponse.status()).toBe(200); + const introspectionBody = await introspectionResponse.json(); + expect(introspectionBody.active).toBe(true); + expect(introspectionBody.token_type).toBe("access_token"); + expect(introspectionBody.iss).toBe(baseURL); + expect(introspectionBody.sub).toBe(users.tim.id); + expect(introspectionBody.aud).toStrictEqual([oidcClients.federated.id]); + }); + + test("fails with client credentials for wrong app", async ({ + request, + }) => { + const validAccessToken = await generateOauthAccessToken( + users.tim, + oidcClients.nextcloud.id + ); + const introspectionResponse = await request.post("/api/oidc/introspect", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Authorization: + "Basic " + + Buffer.from(`${oidcClients.immich.id}:${oidcClients.immich.secret}`).toString("base64"), + }, + form: { + token: validAccessToken, + }, + }); + + expect(introspectionResponse.status()).toBe(400); + }); + + test("fails with federated credentials for wrong app", async ({ + page, + request, + }) => { + const validAccessToken = await generateOauthAccessToken( + users.tim, + oidcClients.nextcloud.id + ); + const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT); + const introspectionResponse = await request.post("/api/oidc/introspect", { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Authorization: "Bearer " + clientAssertion, + }, + form: { + token: validAccessToken, + }, + }); + + expect(introspectionResponse.status()).toBe(400); + }); + test("non-expired refresh_token can be verified", async ({ request }) => { - const { token } = refreshTokens.filter((token) => !token.expired)[0]; + const { token, clientId, userId } = refreshTokens.filter( + (token) => !token.expired + )[0]; + + // Sign the refresh token + const refreshToken = await request.post("/api/test/refreshtoken", { + data: { + rt: token, + client: clientId, + user: userId, + } + }).then((r) => r.text()) const introspectionResponse = await request.post("/api/oidc/introspect", { headers: { "Content-Type": "application/x-www-form-urlencoded", Authorization: "Basic " + - Buffer.from(`${client.id}:${client.secret}`).toString("base64"), + Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"), }, form: { - token: token, + token: refreshToken, }, }); @@ -288,17 +459,28 @@ test.describe("Introspection endpoint", () => { }); test("expired refresh_token can be verified", async ({ request }) => { - const { token } = refreshTokens.filter((token) => token.expired)[0]; + const { token, clientId, userId } = refreshTokens.filter( + (token) => token.expired + )[0]; + + // Sign the refresh token + const refreshToken = await request.post("/api/test/refreshtoken", { + data: { + rt: token, + client: clientId, + user: userId, + } + }).then((r) => r.text()) const introspectionResponse = await request.post("/api/oidc/introspect", { headers: { "Content-Type": "application/x-www-form-urlencoded", Authorization: "Basic " + - Buffer.from(`${client.id}:${client.secret}`).toString("base64"), + Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"), }, form: { - token: token, + token: refreshToken, }, }); @@ -310,7 +492,7 @@ test.describe("Introspection endpoint", () => { test("expired access_token can't be verified", async ({ request }) => { const expiredAccessToken = await generateOauthAccessToken( users.tim, - client.id, + oidcClients.nextcloud.id, true ); const introspectionResponse = await request.post("/api/oidc/introspect", {