feat: allow introspection and device code endpoints to use Federated Client Credentials (#640)

This commit is contained in:
Alessandro (Ale) Segala
2025-06-09 12:17:55 -07:00
committed by GitHub
parent df5c1ed1f8
commit b62b61fb01
13 changed files with 802 additions and 139 deletions

View File

@@ -14,6 +14,7 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
testController := &TestController{TestService: testService} testController := &TestController{TestService: testService}
group.POST("/test/reset", testController.resetAndSeedHandler) group.POST("/test/reset", testController.resetAndSeedHandler)
group.POST("/test/refreshtoken", testController.signRefreshToken)
group.GET("/externalidp/jwks.json", testController.externalIdPJWKS) group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
group.POST("/externalidp/sign", testController.externalIdPSignToken) group.POST("/externalidp/sign", testController.externalIdPSignToken)
@@ -100,3 +101,24 @@ func (tc *TestController) externalIdPSignToken(c *gin.Context) {
c.Writer.WriteString(token) c.Writer.WriteString(token)
} }
func (tc *TestController) signRefreshToken(c *gin.Context) {
var input struct {
UserID string `json:"user"`
ClientID string `json:"client"`
RefreshToken string `json:"rt"`
}
err := c.ShouldBindJSON(&input)
if err != nil {
_ = c.Error(err)
return
}
token, err := tc.TestService.SignRefreshToken(input.UserID, input.ClientID, input.RefreshToken)
if err != nil {
_ = c.Error(err)
return
}
c.Writer.WriteString(token)
}

View File

@@ -200,7 +200,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
return return
} }
token, err := oc.jwtService.VerifyOauthAccessToken(authToken) token, err := oc.jwtService.VerifyOAuthAccessToken(authToken)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -308,9 +308,21 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
// find valid tokens) while still allowing it to be used by an application that is // find valid tokens) while still allowing it to be used by an application that is
// supposed to interact with our IdP (since that needs to have a client_id // supposed to interact with our IdP (since that needs to have a client_id
// and client_secret anyway). // and client_secret anyway).
clientID, clientSecret, _ := c.Request.BasicAuth() var (
creds service.ClientAuthCredentials
ok bool
)
creds.ClientID, creds.ClientSecret, ok = c.Request.BasicAuth()
if !ok {
// If there's no basic auth, check if we have a bearer token
bearer, ok := utils.BearerAuth(c.Request)
if ok {
creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer
creds.ClientAssertion = bearer
}
}
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token) response, err := oc.oidcService.IntrospectToken(c.Request.Context(), creds, input.Token)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -116,6 +116,8 @@ type OidcDeviceAuthorizationRequestDto struct {
ClientID string `form:"client_id" binding:"required"` ClientID string `form:"client_id" binding:"required"`
Scope string `form:"scope" binding:"required"` Scope string `form:"scope" binding:"required"`
ClientSecret string `form:"client_secret"` ClientSecret string `form:"client_secret"`
ClientAssertion string `form:"client_assertion"`
ClientAssertionType string `form:"client_assertion_type"`
} }
type OidcDeviceAuthorizationResponseDto struct { type OidcDeviceAuthorizationResponseDto struct {

View File

@@ -479,6 +479,10 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
return nil return nil
} }
func (s *TestService) SignRefreshToken(userID, clientID, refreshToken string) (string, error) {
return s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
}
// GetExternalIdPJWKS returns the JWKS for the "external IdP". // GetExternalIdPJWKS returns the JWKS for the "external IdP".
func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) { func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
pubKey, err := s.externalIdPKey.PublicKey() pubKey, err := s.externalIdPKey.PublicKey()

View File

@@ -39,9 +39,15 @@ const (
// TokenTypeClaim is the claim used to identify the type of token // TokenTypeClaim is the claim used to identify the type of token
TokenTypeClaim = "type" TokenTypeClaim = "type"
// RefreshTokenClaim is the claim used for the refresh token's value
RefreshTokenClaim = "rt"
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token // OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
// OAuthRefreshTokenJWTType identifies a JWT as an OAuth refresh token
OAuthRefreshTokenJWTType = "refresh-token"
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID // AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
AccessTokenJWTType = "access-token" AccessTokenJWTType = "access-token"
@@ -322,8 +328,8 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
return token, nil return token, nil
} }
// BuildOauthAccessToken creates an OAuth access token with all claims // BuildOAuthAccessToken creates an OAuth access token with all claims
func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jwt.Token, error) { func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) {
now := time.Now() now := time.Now()
token, err := jwt.NewBuilder(). token, err := jwt.NewBuilder().
Subject(user.ID). Subject(user.ID).
@@ -348,9 +354,9 @@ func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jw
return token, nil return token, nil
} }
// GenerateOauthAccessToken creates and signs an OAuth access token // GenerateOAuthAccessToken creates and signs an OAuth access token
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) {
token, err := s.BuildOauthAccessToken(user, clientID) token, err := s.BuildOAuthAccessToken(user, clientID)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -364,7 +370,7 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
return string(signed), nil return string(signed), nil
} }
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) { func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, error) {
alg, _ := s.privateKey.Algorithm() alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString( token, err := jwt.ParseString(
tokenString, tokenString,
@@ -381,6 +387,96 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
return token, nil return token, nil
} }
func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, refreshToken string) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(userID).
Expiration(now.Add(RefreshTokenDuration)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = token.Set(RefreshTokenClaim, refreshToken)
if err != nil {
return "", fmt.Errorf("failed to set 'rt' claim in token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, OAuthRefreshTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
}
func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, clientID, rt string, err error) {
alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
)
if err != nil {
return "", "", "", fmt.Errorf("failed to parse token: %w", err)
}
err = token.Get(RefreshTokenClaim, &rt)
if err != nil {
return "", "", "", fmt.Errorf("failed to get '%s' claim from token: %w", RefreshTokenClaim, err)
}
audiences, ok := token.Audience()
if !ok || len(audiences) != 1 || audiences[0] == "" {
return "", "", "", errors.New("failed to get 'aud' claim from token")
}
clientID = audiences[0]
userID, ok = token.Subject()
if !ok {
return "", "", "", errors.New("failed to get 'sub' claim from token")
}
return userID, clientID, rt, nil
}
// GetTokenType returns the type of the JWT token issued by Pocket ID, but **does not validate it**.
func (s *JwtService) GetTokenType(tokenString string) (string, jwt.Token, error) {
// Disable validation and verification to parse the token without checking it
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(false),
jwt.WithVerify(false),
)
if err != nil {
return "", nil, fmt.Errorf("failed to parse token: %w", err)
}
var tokenType string
err = token.Get(TokenTypeClaim, &tokenType)
if err != nil {
return "", nil, fmt.Errorf("failed to get token type claim: %w", err)
}
return tokenType, token, nil
}
// GetPublicJWK returns the JSON Web Key (JWK) for the public key. // GetPublicJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetPublicJWK() (jwk.Key, error) { func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
if s.privateKey == nil { if s.privateKey == nil {
@@ -478,7 +574,10 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
} }
var isAdmin bool var isAdmin bool
err := token.Get(IsAdminClaim, &isAdmin) err := token.Get(IsAdminClaim, &isAdmin)
return isAdmin, err if err != nil {
return false, fmt.Errorf("failed to get 'isAdmin' claim from token: %w", err)
}
return isAdmin, nil
} }
// SetTokenType sets the "type" claim in the token // SetTokenType sets the "type" claim in the token

View File

@@ -882,7 +882,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}) })
} }
func TestGenerateVerifyOauthAccessToken(t *testing.T) { func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
// Create a temporary directory for the test // Create a temporary directory for the test
tempDir := t.TempDir() tempDir := t.TempDir()
@@ -914,12 +914,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "test-client-123" const clientID = "test-client-123"
// Generate a token // Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID) tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token") require.NoError(t, err, "Failed to generate OAuth access token")
assert.NotEmpty(t, tokenString, "Token should not be empty") assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token // Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString) claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token") require.NoError(t, err, "Failed to verify generated OAuth access token")
// Check the claims // Check the claims
@@ -972,7 +972,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
require.NoError(t, err, "Failed to sign token") require.NoError(t, err, "Failed to sign token")
// Verify should fail due to expiration // Verify should fail due to expiration
_, err = service.VerifyOauthAccessToken(string(signed)) _, err = service.VerifyOAuthAccessToken(string(signed))
require.Error(t, err, "Verification should fail with expired token") require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure") assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
}) })
@@ -996,11 +996,11 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "test-client-789" const clientID = "test-client-789"
// Generate a token with the first service // Generate a token with the first service
tokenString, err := service1.GenerateOauthAccessToken(user, clientID) tokenString, err := service1.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token") require.NoError(t, err, "Failed to generate OAuth access token")
// Verify with the second service should fail due to different keys // Verify with the second service should fail due to different keys
_, err = service2.VerifyOauthAccessToken(tokenString) _, err = service2.VerifyOAuthAccessToken(tokenString)
require.Error(t, err, "Verification should fail with invalid signature") require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure") assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
}) })
@@ -1032,12 +1032,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "eddsa-oauth-client" const clientID = "eddsa-oauth-client"
// Generate a token // Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID) tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key") require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty") assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token // Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString) claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key") require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims // Check the claims
@@ -1086,12 +1086,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "ecdsa-oauth-client" const clientID = "ecdsa-oauth-client"
// Generate a token // Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID) tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key") require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty") assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token // Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString) claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key") require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims // Check the claims
@@ -1140,12 +1140,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "rsa-oauth-client" const clientID = "rsa-oauth-client"
// Generate a token // Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID) tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key") require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty") assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token // Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString) claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key") require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims // Check the claims
@@ -1168,6 +1168,92 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
}) })
} }
func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := NewTestAppConfigService(&model.AppConfig{})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
t.Run("generates and verifies refresh token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
const (
userID = "user123"
clientID = "client123"
refreshToken = "rt-123"
)
// Generate a token
tokenString, err := service.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
require.NoError(t, err, "Failed to generate refresh token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
resUser, resClient, resRT, err := service.VerifyOAuthRefreshToken(tokenString)
require.NoError(t, err, "Failed to verify generated token")
assert.Equal(t, userID, resUser, "Should return correct user ID")
assert.Equal(t, clientID, resClient, "Should return correct client ID")
assert.Equal(t, refreshToken, resRT, "Should return correct refresh token")
})
t.Run("fails verification for expired token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
// Generate a token using JWT directly to create an expired token
token, err := jwt.NewBuilder().
Subject("user789").
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
IssuedAt(time.Now().Add(-2 * time.Hour)).
Audience([]string{"client123"}).
Issuer(common.EnvConfig.AppURL).
Build()
require.NoError(t, err, "Failed to build token")
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
require.NoError(t, err, "Failed to sign token")
// Verify should fail due to expiration
_, _, _, err = service.VerifyOAuthRefreshToken(string(signed))
require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
})
t.Run("fails verification with invalid signature", func(t *testing.T) {
// Create two JWT services with different keys
service1 := &JwtService{}
err := service1.init(mockConfig, t.TempDir())
require.NoError(t, err, "Failed to initialize first JWT service")
service2 := &JwtService{}
err = service2.init(mockConfig, t.TempDir())
require.NoError(t, err, "Failed to initialize second JWT service")
// Generate a token with the first service
tokenString, err := service1.GenerateOAuthRefreshToken("user789", "client123", "my-rt-123")
require.NoError(t, err, "Failed to generate refresh token")
// Verify with the second service should fail due to different keys
_, _, _, err = service2.VerifyOAuthRefreshToken(tokenString)
require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
})
}
func TestTokenTypeValidator(t *testing.T) { func TestTokenTypeValidator(t *testing.T) {
// Create a context for the validator function // Create a context for the validator function
ctx := context.Background() ctx := context.Background()
@@ -1213,7 +1299,104 @@ func TestTokenTypeValidator(t *testing.T) {
require.Error(t, err, "Validator should reject token without type claim") require.Error(t, err, "Validator should reject token without type claim")
assert.Contains(t, err.Error(), "failed to get token type claim") assert.Contains(t, err.Error(), "failed to get token type claim")
}) })
}
func TestGetTokenType(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service
mockConfig := NewTestAppConfigService(&model.AppConfig{})
service := &JwtService{}
err := service.init(mockConfig, tempDir)
require.NoError(t, err, "Failed to initialize JWT service")
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
t.Helper()
b := jwt.NewBuilder()
b.Subject("user123")
if setClaimsFn != nil {
setClaimsFn(b)
}
token, err := b.Build()
require.NoError(t, err, "Failed to build token")
err = SetTokenType(token, typ)
require.NoError(t, err, "Failed to set token type")
alg, _ := service.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, service.privateKey))
require.NoError(t, err, "Failed to sign token")
return string(signed)
}
t.Run("correctly identifies access tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, AccessTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified as access token")
})
t.Run("correctly identifies ID tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, IDTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, IDTokenJWTType, tokenType, "Token type should be correctly identified as ID token")
})
t.Run("correctly identifies OAuth access tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, OAuthAccessTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, OAuthAccessTokenJWTType, tokenType, "Token type should be correctly identified as OAuth access token")
})
t.Run("correctly identifies refresh tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, OAuthRefreshTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, OAuthRefreshTokenJWTType, tokenType, "Token type should be correctly identified as refresh token")
})
t.Run("works with expired tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, AccessTokenJWTType, func(b *jwt.Builder) {
b.Expiration(time.Now().Add(-1 * time.Hour)) // Expired 1 hour ago
})
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error for expired tokens")
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified even for expired tokens")
})
t.Run("returns error for malformed tokens", func(t *testing.T) {
// Try to get the token type of a malformed token
tokenType, _, err := service.GetTokenType("not.a.valid.jwt.token")
require.Error(t, err, "GetTokenType should return an error for malformed tokens")
assert.Empty(t, tokenType, "Token type should be empty for malformed tokens")
})
t.Run("returns error for tokens without type claim", func(t *testing.T) {
// Create a token without type claim
tokenString := buildTokenForType(t, "", nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.Error(t, err, "GetTokenType should return an error for tokens without type claim")
assert.Empty(t, tokenType, "Token type should be empty when type claim is missing")
assert.Contains(t, err.Error(), "failed to get token type claim", "Error message should indicate missing token type claim")
})
} }
func importKey(t *testing.T, privateKeyRaw any, path string) string { func importKey(t *testing.T, privateKeyRaw any, path string) string {

View File

@@ -40,6 +40,9 @@ const (
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
DeviceCodeDuration = 15 * time.Minute
) )
type OidcService struct { type OidcService struct {
@@ -252,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback() tx.Rollback()
}() }()
_, err := s.verifyClientCredentialsInternal(ctx, tx, input) _, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -303,7 +306,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
return CreatedTokens{}, err return CreatedTokens{}, err
} }
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID) accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -333,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback() tx.Rollback()
}() }()
client, err := s.verifyClientCredentialsInternal(ctx, tx, input) client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -375,7 +378,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
return CreatedTokens{}, err return CreatedTokens{}, err
} }
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID) accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -406,22 +409,39 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{} return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
} }
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
userID, clientID, rt, err := s.jwtService.VerifyOAuthRefreshToken(input.RefreshToken)
if err != nil {
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
tx := s.db.Begin() tx := s.db.Begin()
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()
_, err := s.verifyClientCredentialsInternal(ctx, tx, input) client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
// The ID of the client that made the call must match the client ID in the token
if client.ID != clientID {
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
// Verify refresh token // Verify refresh token
var storedRefreshToken model.OidcRefreshToken var storedRefreshToken model.OidcRefreshToken
err = tx. err = tx.
WithContext(ctx). WithContext(ctx).
Preload("User"). Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())). Where(
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
utils.CreateSha256Hash(rt),
datatype.DateTime(time.Now()),
userID,
input.ClientID,
).
First(&storedRefreshToken). First(&storedRefreshToken).
Error Error
if err != nil { if err != nil {
@@ -437,7 +457,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
} }
// Generate a new access token // Generate a new access token
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID) accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
if err != nil { if err != nil {
return CreatedTokens{}, err return CreatedTokens{}, err
} }
@@ -469,33 +489,69 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
}, nil }, nil
} }
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
if clientID == "" || clientSecret == "" { // Get the type of the token and the client ID
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
if err != nil {
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil //nolint:nilerr
}
// If we don't have a client ID, get it from the token
// Otherwise, we need to make sure that the client ID passed as credential matches
tokenAudiences, _ := token.Audience()
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil
}
if creds.ClientID == "" {
creds.ClientID = tokenAudiences[0]
} else if creds.ClientID != tokenAudiences[0] {
return introspectDto, &common.OidcMissingClientCredentialsError{} return introspectDto, &common.OidcMissingClientCredentialsError{}
} }
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{ // Verify the credentials for the call
ClientID: clientID, client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
ClientSecret: clientSecret,
})
if err != nil { if err != nil {
return introspectDto, err return introspectDto, err
} }
token, err := s.jwtService.VerifyOauthAccessToken(tokenString) // Introspect the token
if err != nil { switch tokenType {
if errors.Is(err, jwt.ParseError()) { case OAuthAccessTokenJWTType:
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token. return s.introspectAccessToken(client.ID, tokenString)
return s.introspectRefreshToken(ctx, tokenString) case OAuthRefreshTokenJWTType:
} return s.introspectRefreshToken(ctx, client.ID, tokenString)
default:
// Every failure we get means the token is invalid. Nothing more to do with the error. // We just treat the token as invalid
introspectDto.Active = false introspectDto.Active = false
return introspectDto, nil return introspectDto, nil
} }
}
func (s *OidcService) introspectAccessToken(clientID string, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
token, err := s.jwtService.VerifyOAuthAccessToken(tokenString)
if err != nil {
// Every failure we get means the token is invalid. Nothing more to do with the error.
introspectDto.Active = false
return introspectDto, nil //nolint:nilerr
}
// The ID of the client that made the request must match the client ID in the token
audience, ok := token.Audience()
if !ok || len(audience) != 1 || audience[0] == "" {
introspectDto.Active = false
return introspectDto, nil
}
if audience[0] != clientID {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
introspectDto.Active = true introspectDto.Active = true
introspectDto.TokenType = "access_token" introspectDto.TokenType = "access_token"
introspectDto.Audience = audience
if token.Has("scope") { if token.Has("scope") {
var ( var (
asString string asString string
@@ -519,9 +575,6 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
if subject, ok := token.Subject(); ok { if subject, ok := token.Subject(); ok {
introspectDto.Subject = subject introspectDto.Subject = subject
} }
if audience, ok := token.Audience(); ok {
introspectDto.Audience = audience
}
if issuer, ok := token.Issuer(); ok { if issuer, ok := token.Issuer(); ok {
introspectDto.Issuer = issuer introspectDto.Issuer = issuer
} }
@@ -532,12 +585,29 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
return introspectDto, nil return introspectDto, nil
} }
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID string, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
tokenUserID, tokenClientID, tokenRT, err := s.jwtService.VerifyOAuthRefreshToken(refreshToken)
if err != nil {
return introspectDto, fmt.Errorf("invalid refresh token: %w", err)
}
// The ID of the client that made the call must match the client ID in the token
if tokenClientID != clientID {
return introspectDto, errors.New("invalid refresh token: client ID does not match")
}
var storedRefreshToken model.OidcRefreshToken var storedRefreshToken model.OidcRefreshToken
err = s.db. err = s.db.
WithContext(ctx). WithContext(ctx).
Preload("User"). Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())). Where(
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
utils.CreateSha256Hash(tokenRT),
datatype.DateTime(time.Now()),
tokenUserID,
tokenClientID,
).
First(&storedRefreshToken). First(&storedRefreshToken).
Error Error
if err != nil { if err != nil {
@@ -1062,9 +1132,11 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
} }
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(ctx, s.db, ClientAuthCredentials{
ClientID: input.ClientID, ClientID: input.ClientID,
ClientSecret: input.ClientSecret, ClientSecret: input.ClientSecret,
ClientAssertionType: input.ClientAssertionType,
ClientAssertion: input.ClientAssertion,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1085,7 +1157,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
DeviceCode: deviceCode, DeviceCode: deviceCode,
UserCode: userCode, UserCode: userCode,
Scope: input.Scope, Scope: input.Scope,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)), ExpiresAt: datatype.DateTime(time.Now().Add(DeviceCodeDuration)),
IsAuthorized: false, IsAuthorized: false,
ClientID: client.ID, ClientID: client.ID,
} }
@@ -1099,7 +1171,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
UserCode: userCode, UserCode: userCode,
VerificationURI: common.EnvConfig.AppURL + "/device", VerificationURI: common.EnvConfig.AppURL + "/device",
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode, VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
ExpiresIn: 900, // 15 minutes ExpiresIn: int(DeviceCodeDuration.Seconds()),
Interval: 5, Interval: 5,
}, nil }, nil
} }
@@ -1255,7 +1327,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
refreshTokenHash := utils.CreateSha256Hash(refreshToken) refreshTokenHash := utils.CreateSha256Hash(refreshToken)
m := model.OidcRefreshToken{ m := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
Token: refreshTokenHash, Token: refreshTokenHash,
ClientID: clientID, ClientID: clientID,
UserID: userID, UserID: userID,
@@ -1270,7 +1342,13 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
return "", err return "", err
} }
return refreshToken, nil // Sign the refresh token
signed, err := s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
if err != nil {
return "", fmt.Errorf("failed to sign refresh token: %w", err)
}
return signed, nil
} }
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error { func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
@@ -1291,7 +1369,23 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
return err return err
} }
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) { type ClientAuthCredentials struct {
ClientID string
ClientSecret string
ClientAssertion string
ClientAssertionType string
}
func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials {
return ClientAuthCredentials{
ClientID: d.ClientID,
ClientSecret: d.ClientSecret,
ClientAssertion: d.ClientAssertion,
ClientAssertionType: d.ClientAssertionType,
}
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
// First, ensure we have a valid client ID // First, ensure we have a valid client ID
if input.ClientID == "" { if input.ClientID == "" {
return nil, &common.OidcMissingClientCredentialsError{} return nil, &common.OidcMissingClientCredentialsError{}
@@ -1366,7 +1460,7 @@ func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set
return jwks, nil return jwks, nil
} }
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error { func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input ClientAuthCredentials) error {
// First, parse the assertion JWT, without validating it, to check the issuer // First, parse the assertion JWT, without validating it, to check the issuer
assertion := []byte(input.ClientAssertion) assertion := []byte(input.ClientAssertion)
insecureToken, err := jwt.ParseInsecure(assertion) insecureToken, err := jwt.ParseInsecure(assertion)
@@ -1466,12 +1560,18 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err return nil, err
} }
// Commit the transaction before signing tokens to avoid locking the database for longer
err = tx.Commit().Error
if err != nil {
return nil, err
}
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "") idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID) accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1486,11 +1586,6 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err return nil, err
} }
err = tx.Commit().Error
if err != nil {
return nil, err
}
return &dto.OidcClientPreviewDto{ return &dto.OidcClientPreviewDto{
IdToken: idTokenPayload, IdToken: idTokenPayload,
AccessToken: accessTokenPayload, AccessToken: accessTokenPayload,
@@ -1498,26 +1593,11 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
}, nil }, nil
} }
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) { func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]any, error) {
tx := s.db.Begin() return s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
defer func() {
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
} }
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) { func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]any, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient var authorizedOidcClient model.UserAuthorizedOidcClient
err := tx. err := tx.
WithContext(ctx). WithContext(ctx).
@@ -1532,14 +1612,13 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
} }
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) { func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
user := authorizedClient.User user := authorizedClient.User
scopes := strings.Split(authorizedClient.Scope, " ") scopes := strings.Split(authorizedClient.Scope, " ")
claims := map[string]interface{}{ claims := make(map[string]any, 10)
"sub": user.ID,
}
claims["sub"] = user.ID
if slices.Contains(scopes, "email") { if slices.Contains(scopes, "email") {
claims["email"] = user.Email claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue() claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
@@ -1553,19 +1632,13 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
claims["groups"] = userGroups claims["groups"] = userGroups
} }
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
}
if slices.Contains(scopes, "profile") { if slices.Contains(scopes, "profile") {
// Add profile claims // Add profile claims
for k, v := range profileClaims { claims["given_name"] = user.FirstName
claims[k] = v claims["family_name"] = user.LastName
} claims["name"] = user.FullName()
claims["preferred_username"] = user.Username
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
// Add custom claims // Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx) customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
@@ -1575,7 +1648,7 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
for _, customClaim := range customClaims { for _, customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string // The value of the custom claim can be a JSON object or a string
var jsonValue interface{} var jsonValue any
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue) err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if err == nil { if err == nil {
// It's JSON, so we store it as an object // It's JSON, so we store it as an object

View File

@@ -210,7 +210,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Confidential client", func(t *testing.T) { t.Run("Confidential client", func(t *testing.T) {
t.Run("Succeeds with valid secret", func(t *testing.T) { t.Run("Succeeds with valid secret", func(t *testing.T) {
// Test with valid client credentials // Test with valid client credentials
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID, ClientID: confidentialClient.ID,
ClientSecret: confidentialSecret, ClientSecret: confidentialSecret,
}) })
@@ -221,7 +221,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with invalid secret", func(t *testing.T) { t.Run("Fails with invalid secret", func(t *testing.T) {
// Test with invalid client secret // Test with invalid client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID, ClientID: confidentialClient.ID,
ClientSecret: "invalid-secret", ClientSecret: "invalid-secret",
}) })
@@ -232,7 +232,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with missing secret", func(t *testing.T) { t.Run("Fails with missing secret", func(t *testing.T) {
// Test with missing client secret // Test with missing client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID, ClientID: confidentialClient.ID,
}) })
require.Error(t, err) require.Error(t, err)
@@ -245,7 +245,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Public client", func(t *testing.T) { t.Run("Public client", func(t *testing.T) {
t.Run("Succeeds with no credentials", func(t *testing.T) { t.Run("Succeeds with no credentials", func(t *testing.T) {
// Public clients don't require client secret // Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID, ClientID: publicClient.ID,
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -270,7 +270,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test with valid JWT assertion // Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken), ClientAssertion: string(signedToken),
@@ -282,7 +282,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
t.Run("Fails with malformed JWT", func(t *testing.T) { t.Run("Fails with malformed JWT", func(t *testing.T) {
// Test with invalid JWT assertion (just a random string) // Test with invalid JWT assertion (just a random string)
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: "invalid.jwt.token", ClientAssertion: "invalid.jwt.token",
@@ -311,7 +311,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test with invalid JWT assertion // Test with invalid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken), ClientAssertion: string(signedToken),
@@ -352,7 +352,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test with valid JWT assertion // Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{ client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID, ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer, ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken), ClientAssertion: string(signedToken),

View File

@@ -0,0 +1,18 @@
package utils
import (
"net/http"
"strings"
)
// BearerAuth returns the value of the bearer token in the Authorization header if present
func BearerAuth(r *http.Request) (string, bool) {
const prefix = "bearer "
authHeader := r.Header.Get("Authorization")
if len(authHeader) >= len(prefix) && strings.ToLower(authHeader[:len(prefix)]) == prefix {
return authHeader[len(prefix):], true
}
return "", false
}

View File

@@ -0,0 +1,65 @@
package utils
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBearerAuth(t *testing.T) {
tests := []struct {
name string
authHeader string
expectedToken string
expectedFound bool
}{
{
name: "Valid bearer token",
authHeader: "Bearer token123",
expectedToken: "token123",
expectedFound: true,
},
{
name: "Valid bearer token with mixed case",
authHeader: "beARer token456",
expectedToken: "token456",
expectedFound: true,
},
{
name: "No bearer prefix",
authHeader: "Basic dXNlcjpwYXNz",
expectedToken: "",
expectedFound: false,
},
{
name: "Empty auth header",
authHeader: "",
expectedToken: "",
expectedFound: false,
},
{
name: "Bearer prefix only",
authHeader: "Bearer ",
expectedToken: "",
expectedFound: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", nil)
require.NoError(t, err, "Failed to create request")
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
token, found := BearerAuth(req)
assert.Equal(t, tt.expectedFound, found)
assert.Equal(t, tt.expectedToken, token)
})
}
}

View File

@@ -7,8 +7,9 @@ import (
) )
func GetClaimsFromToken(token jwt.Token) (map[string]any, error) { func GetClaimsFromToken(token jwt.Token) (map[string]any, error) {
claims := make(map[string]any) keys := token.Keys()
for _, key := range token.Keys() { claims := make(map[string]any, len(keys))
for _, key := range keys {
var value any var value any
if err := token.Get(key, &value); err != nil { if err := token.Get(key, &value); err != nil {
return nil, fmt.Errorf("failed to get claim %s: %w", key, err) return nil, fmt.Errorf("failed to get claim %s: %w", key, err)

View File

@@ -87,11 +87,13 @@ export const refreshTokens = [
{ {
token: 'ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo', token: 'ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo',
clientId: oidcClients.nextcloud.id, clientId: oidcClients.nextcloud.id,
userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e',
expired: false expired: false
}, },
{ {
token: 'X4vqwtRyCUaq51UafHea4Fsg8Km6CAns6vp3tuX4', token: 'X4vqwtRyCUaq51UafHea4Fsg8Km6CAns6vp3tuX4',
clientId: oidcClients.nextcloud.id, clientId: oidcClients.nextcloud.id,
userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e',
expired: true expired: true
} }
]; ];

View File

@@ -156,11 +156,21 @@ test("End session with id token hint redirects to callback URL", async ({
test("Successfully refresh tokens with valid refresh token", async ({ test("Successfully refresh tokens with valid refresh token", async ({
request, request,
}) => { }) => {
const { token, clientId } = refreshTokens.filter( const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired (token) => !token.expired
)[0]; )[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY"; const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
// Perform the exchange
const refreshResponse = await request.post("/api/oidc/token", { const refreshResponse = await request.post("/api/oidc/token", {
headers: { headers: {
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
@@ -168,7 +178,7 @@ test("Successfully refresh tokens with valid refresh token", async ({
form: { form: {
grant_type: "refresh_token", grant_type: "refresh_token",
client_id: clientId, client_id: clientId,
refresh_token: token, refresh_token: refreshToken,
client_secret: clientSecret, client_secret: clientSecret,
}, },
}); });
@@ -184,26 +194,25 @@ test("Successfully refresh tokens with valid refresh token", async ({
expect(tokenData.refresh_token).not.toBe(token); expect(tokenData.refresh_token).not.toBe(token);
}); });
test("Using refresh token invalidates it for future use", async ({
test("Refresh token fails when used for the wrong client", async ({
request, request,
}) => { }) => {
const { token, clientId } = refreshTokens.filter( const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired (token) => !token.expired
)[0]; )[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY"; const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
await request.post("/api/oidc/token", { // Sign the refresh token
headers: { const refreshToken = await request.post("/api/test/refreshtoken", {
"Content-Type": "application/x-www-form-urlencoded", data: {
}, rt: token,
form: { client: 'bad-client',
grant_type: "refresh_token", user: userId,
client_id: clientId, }
refresh_token: token, }).then((r) => r.text())
client_secret: clientSecret,
},
});
// Perform the exchange
const refreshResponse = await request.post("/api/oidc/token", { const refreshResponse = await request.post("/api/oidc/token", {
headers: { headers: {
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
@@ -211,7 +220,86 @@ test("Using refresh token invalidates it for future use", async ({
form: { form: {
grant_type: "refresh_token", grant_type: "refresh_token",
client_id: clientId, client_id: clientId,
refresh_token: token, refresh_token: refreshToken,
client_secret: clientSecret,
},
});
expect(refreshResponse.status()).toBe(400);
});
test("Refresh token fails when used for the wrong user", async ({
request,
}) => {
const { token, clientId } = refreshTokens.filter(
(token) => !token.expired
)[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: 'bad-user',
}
}).then((r) => r.text())
// Perform the exchange
const refreshResponse = await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: refreshToken,
client_secret: clientSecret,
},
});
expect(refreshResponse.status()).toBe(400);
});
test("Using refresh token invalidates it for future use", async ({
request,
}) => {
const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired
)[0];
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
// Perform the exchange
await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: refreshToken,
client_secret: clientSecret,
},
});
// Try again
const refreshResponse = await request.post("/api/oidc/token", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
form: {
grant_type: "refresh_token",
client_id: clientId,
refresh_token: refreshToken,
client_secret: clientSecret, client_secret: clientSecret,
}, },
}); });
@@ -219,11 +307,10 @@ test("Using refresh token invalidates it for future use", async ({
}); });
test.describe("Introspection endpoint", () => { test.describe("Introspection endpoint", () => {
const client = oidcClients.nextcloud; test("fails without client credentials", async ({ request }) => {
test("without client_id and client_secret fails", async ({ request }) => {
const validAccessToken = await generateOauthAccessToken( const validAccessToken = await generateOauthAccessToken(
users.tim, users.tim,
client.id oidcClients.nextcloud.id
); );
const introspectionResponse = await request.post("/api/oidc/introspect", { const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: { headers: {
@@ -237,20 +324,20 @@ test.describe("Introspection endpoint", () => {
expect(introspectionResponse.status()).toBe(400); expect(introspectionResponse.status()).toBe(400);
}); });
test("with client_id and client_secret succeeds", async ({ test("succeeds with client credentials", async ({
request, request,
baseURL, baseURL,
}) => { }) => {
const validAccessToken = await generateOauthAccessToken( const validAccessToken = await generateOauthAccessToken(
users.tim, users.tim,
client.id oidcClients.nextcloud.id
); );
const introspectionResponse = await request.post("/api/oidc/introspect", { const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: { headers: {
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
Authorization: Authorization:
"Basic " + "Basic " +
Buffer.from(`${client.id}:${client.secret}`).toString("base64"), Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
}, },
form: { form: {
token: validAccessToken, token: validAccessToken,
@@ -266,18 +353,102 @@ test.describe("Introspection endpoint", () => {
expect(introspectionBody.aud).toStrictEqual([oidcClients.nextcloud.id]); expect(introspectionBody.aud).toStrictEqual([oidcClients.nextcloud.id]);
}); });
test("succeeds with federated client credentials", async ({
page,
request,
baseURL,
}) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
oidcClients.federated.id
);
const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization: "Bearer " + clientAssertion,
},
form: {
token: validAccessToken,
},
});
expect(introspectionResponse.status()).toBe(200);
const introspectionBody = await introspectionResponse.json();
expect(introspectionBody.active).toBe(true);
expect(introspectionBody.token_type).toBe("access_token");
expect(introspectionBody.iss).toBe(baseURL);
expect(introspectionBody.sub).toBe(users.tim.id);
expect(introspectionBody.aud).toStrictEqual([oidcClients.federated.id]);
});
test("fails with client credentials for wrong app", async ({
request,
}) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
oidcClients.nextcloud.id
);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization:
"Basic " +
Buffer.from(`${oidcClients.immich.id}:${oidcClients.immich.secret}`).toString("base64"),
},
form: {
token: validAccessToken,
},
});
expect(introspectionResponse.status()).toBe(400);
});
test("fails with federated credentials for wrong app", async ({
page,
request,
}) => {
const validAccessToken = await generateOauthAccessToken(
users.tim,
oidcClients.nextcloud.id
);
const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT);
const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization: "Bearer " + clientAssertion,
},
form: {
token: validAccessToken,
},
});
expect(introspectionResponse.status()).toBe(400);
});
test("non-expired refresh_token can be verified", async ({ request }) => { test("non-expired refresh_token can be verified", async ({ request }) => {
const { token } = refreshTokens.filter((token) => !token.expired)[0]; const { token, clientId, userId } = refreshTokens.filter(
(token) => !token.expired
)[0];
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
const introspectionResponse = await request.post("/api/oidc/introspect", { const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: { headers: {
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
Authorization: Authorization:
"Basic " + "Basic " +
Buffer.from(`${client.id}:${client.secret}`).toString("base64"), Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
}, },
form: { form: {
token: token, token: refreshToken,
}, },
}); });
@@ -288,17 +459,28 @@ test.describe("Introspection endpoint", () => {
}); });
test("expired refresh_token can be verified", async ({ request }) => { test("expired refresh_token can be verified", async ({ request }) => {
const { token } = refreshTokens.filter((token) => token.expired)[0]; const { token, clientId, userId } = refreshTokens.filter(
(token) => token.expired
)[0];
// Sign the refresh token
const refreshToken = await request.post("/api/test/refreshtoken", {
data: {
rt: token,
client: clientId,
user: userId,
}
}).then((r) => r.text())
const introspectionResponse = await request.post("/api/oidc/introspect", { const introspectionResponse = await request.post("/api/oidc/introspect", {
headers: { headers: {
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
Authorization: Authorization:
"Basic " + "Basic " +
Buffer.from(`${client.id}:${client.secret}`).toString("base64"), Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
}, },
form: { form: {
token: token, token: refreshToken,
}, },
}); });
@@ -310,7 +492,7 @@ test.describe("Introspection endpoint", () => {
test("expired access_token can't be verified", async ({ request }) => { test("expired access_token can't be verified", async ({ request }) => {
const expiredAccessToken = await generateOauthAccessToken( const expiredAccessToken = await generateOauthAccessToken(
users.tim, users.tim,
client.id, oidcClients.nextcloud.id,
true true
); );
const introspectionResponse = await request.post("/api/oidc/introspect", { const introspectionResponse = await request.post("/api/oidc/introspect", {