mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-14 01:10:54 +03:00
feat: allow introspection and device code endpoints to use Federated Client Credentials (#640)
This commit is contained in:
committed by
GitHub
parent
df5c1ed1f8
commit
b62b61fb01
@@ -14,6 +14,7 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
|
||||
testController := &TestController{TestService: testService}
|
||||
|
||||
group.POST("/test/reset", testController.resetAndSeedHandler)
|
||||
group.POST("/test/refreshtoken", testController.signRefreshToken)
|
||||
|
||||
group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
|
||||
group.POST("/externalidp/sign", testController.externalIdPSignToken)
|
||||
@@ -100,3 +101,24 @@ func (tc *TestController) externalIdPSignToken(c *gin.Context) {
|
||||
|
||||
c.Writer.WriteString(token)
|
||||
}
|
||||
|
||||
func (tc *TestController) signRefreshToken(c *gin.Context) {
|
||||
var input struct {
|
||||
UserID string `json:"user"`
|
||||
ClientID string `json:"client"`
|
||||
RefreshToken string `json:"rt"`
|
||||
}
|
||||
err := c.ShouldBindJSON(&input)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := tc.TestService.SignRefreshToken(input.UserID, input.ClientID, input.RefreshToken)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.WriteString(token)
|
||||
}
|
||||
|
||||
@@ -200,7 +200,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := oc.jwtService.VerifyOauthAccessToken(authToken)
|
||||
token, err := oc.jwtService.VerifyOAuthAccessToken(authToken)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -308,9 +308,21 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
|
||||
// find valid tokens) while still allowing it to be used by an application that is
|
||||
// supposed to interact with our IdP (since that needs to have a client_id
|
||||
// and client_secret anyway).
|
||||
clientID, clientSecret, _ := c.Request.BasicAuth()
|
||||
var (
|
||||
creds service.ClientAuthCredentials
|
||||
ok bool
|
||||
)
|
||||
creds.ClientID, creds.ClientSecret, ok = c.Request.BasicAuth()
|
||||
if !ok {
|
||||
// If there's no basic auth, check if we have a bearer token
|
||||
bearer, ok := utils.BearerAuth(c.Request)
|
||||
if ok {
|
||||
creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer
|
||||
creds.ClientAssertion = bearer
|
||||
}
|
||||
}
|
||||
|
||||
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token)
|
||||
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), creds, input.Token)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -116,6 +116,8 @@ type OidcDeviceAuthorizationRequestDto struct {
|
||||
ClientID string `form:"client_id" binding:"required"`
|
||||
Scope string `form:"scope" binding:"required"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
ClientAssertion string `form:"client_assertion"`
|
||||
ClientAssertionType string `form:"client_assertion_type"`
|
||||
}
|
||||
|
||||
type OidcDeviceAuthorizationResponseDto struct {
|
||||
|
||||
@@ -479,6 +479,10 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TestService) SignRefreshToken(userID, clientID, refreshToken string) (string, error) {
|
||||
return s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
|
||||
}
|
||||
|
||||
// GetExternalIdPJWKS returns the JWKS for the "external IdP".
|
||||
func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
|
||||
pubKey, err := s.externalIdPKey.PublicKey()
|
||||
|
||||
@@ -39,9 +39,15 @@ const (
|
||||
// TokenTypeClaim is the claim used to identify the type of token
|
||||
TokenTypeClaim = "type"
|
||||
|
||||
// RefreshTokenClaim is the claim used for the refresh token's value
|
||||
RefreshTokenClaim = "rt"
|
||||
|
||||
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
|
||||
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
|
||||
|
||||
// OAuthRefreshTokenJWTType identifies a JWT as an OAuth refresh token
|
||||
OAuthRefreshTokenJWTType = "refresh-token"
|
||||
|
||||
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
|
||||
AccessTokenJWTType = "access-token"
|
||||
|
||||
@@ -322,8 +328,8 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// BuildOauthAccessToken creates an OAuth access token with all claims
|
||||
func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jwt.Token, error) {
|
||||
// BuildOAuthAccessToken creates an OAuth access token with all claims
|
||||
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(user.ID).
|
||||
@@ -348,9 +354,9 @@ func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jw
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GenerateOauthAccessToken creates and signs an OAuth access token
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
token, err := s.BuildOauthAccessToken(user, clientID)
|
||||
// GenerateOAuthAccessToken creates and signs an OAuth access token
|
||||
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) {
|
||||
token, err := s.BuildOAuthAccessToken(user, clientID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -364,7 +370,7 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
|
||||
func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, error) {
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
@@ -381,6 +387,96 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, refreshToken string) (string, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(userID).
|
||||
Expiration(now.Add(RefreshTokenDuration)).
|
||||
IssuedAt(now).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = token.Set(RefreshTokenClaim, refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'rt' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetTokenType(token, OAuthRefreshTokenJWTType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, clientID, rt string, err error) {
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(alg, s.privateKey),
|
||||
jwt.WithAcceptableSkew(clockSkew),
|
||||
jwt.WithIssuer(common.EnvConfig.AppURL),
|
||||
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
err = token.Get(RefreshTokenClaim, &rt)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to get '%s' claim from token: %w", RefreshTokenClaim, err)
|
||||
}
|
||||
|
||||
audiences, ok := token.Audience()
|
||||
if !ok || len(audiences) != 1 || audiences[0] == "" {
|
||||
return "", "", "", errors.New("failed to get 'aud' claim from token")
|
||||
}
|
||||
clientID = audiences[0]
|
||||
|
||||
userID, ok = token.Subject()
|
||||
if !ok {
|
||||
return "", "", "", errors.New("failed to get 'sub' claim from token")
|
||||
}
|
||||
|
||||
return userID, clientID, rt, nil
|
||||
}
|
||||
|
||||
// GetTokenType returns the type of the JWT token issued by Pocket ID, but **does not validate it**.
|
||||
func (s *JwtService) GetTokenType(tokenString string) (string, jwt.Token, error) {
|
||||
// Disable validation and verification to parse the token without checking it
|
||||
token, err := jwt.ParseString(
|
||||
tokenString,
|
||||
jwt.WithValidate(false),
|
||||
jwt.WithVerify(false),
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
var tokenType string
|
||||
err = token.Get(TokenTypeClaim, &tokenType)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to get token type claim: %w", err)
|
||||
}
|
||||
|
||||
return tokenType, token, nil
|
||||
}
|
||||
|
||||
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
|
||||
func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
|
||||
if s.privateKey == nil {
|
||||
@@ -478,7 +574,10 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
}
|
||||
var isAdmin bool
|
||||
err := token.Get(IsAdminClaim, &isAdmin)
|
||||
return isAdmin, err
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get 'isAdmin' claim from token: %w", err)
|
||||
}
|
||||
return isAdmin, nil
|
||||
}
|
||||
|
||||
// SetTokenType sets the "type" claim in the token
|
||||
|
||||
@@ -882,7 +882,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
@@ -914,12 +914,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "test-client-123"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token")
|
||||
|
||||
// Check the claims
|
||||
@@ -972,7 +972,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
|
||||
// Verify should fail due to expiration
|
||||
_, err = service.VerifyOauthAccessToken(string(signed))
|
||||
_, err = service.VerifyOAuthAccessToken(string(signed))
|
||||
require.Error(t, err, "Verification should fail with expired token")
|
||||
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
|
||||
})
|
||||
@@ -996,11 +996,11 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "test-client-789"
|
||||
|
||||
// Generate a token with the first service
|
||||
tokenString, err := service1.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token")
|
||||
|
||||
// Verify with the second service should fail due to different keys
|
||||
_, err = service2.VerifyOauthAccessToken(tokenString)
|
||||
_, err = service2.VerifyOAuthAccessToken(tokenString)
|
||||
require.Error(t, err, "Verification should fail with invalid signature")
|
||||
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
|
||||
})
|
||||
@@ -1032,12 +1032,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "eddsa-oauth-client"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||
|
||||
// Check the claims
|
||||
@@ -1086,12 +1086,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "ecdsa-oauth-client"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||
|
||||
// Check the claims
|
||||
@@ -1140,12 +1140,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
const clientID = "rsa-oauth-client"
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
claims, err := service.VerifyOauthAccessToken(tokenString)
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
|
||||
|
||||
// Check the claims
|
||||
@@ -1168,6 +1168,92 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Initialize the JWT service with a mock AppConfigService
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||
|
||||
// Setup the environment variable required by the token verification
|
||||
originalAppURL := common.EnvConfig.AppURL
|
||||
common.EnvConfig.AppURL = "https://test.example.com"
|
||||
defer func() {
|
||||
common.EnvConfig.AppURL = originalAppURL
|
||||
}()
|
||||
|
||||
t.Run("generates and verifies refresh token", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Create a test user
|
||||
const (
|
||||
userID = "user123"
|
||||
clientID = "client123"
|
||||
refreshToken = "rt-123"
|
||||
)
|
||||
|
||||
// Generate a token
|
||||
tokenString, err := service.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
|
||||
require.NoError(t, err, "Failed to generate refresh token")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
// Verify the token
|
||||
resUser, resClient, resRT, err := service.VerifyOAuthRefreshToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated token")
|
||||
assert.Equal(t, userID, resUser, "Should return correct user ID")
|
||||
assert.Equal(t, clientID, resClient, "Should return correct client ID")
|
||||
assert.Equal(t, refreshToken, resRT, "Should return correct refresh token")
|
||||
})
|
||||
|
||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||
// Create a JWT service
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
// Generate a token using JWT directly to create an expired token
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject("user789").
|
||||
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
|
||||
IssuedAt(time.Now().Add(-2 * time.Hour)).
|
||||
Audience([]string{"client123"}).
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
|
||||
// Verify should fail due to expiration
|
||||
_, _, _, err = service.VerifyOAuthRefreshToken(string(signed))
|
||||
require.Error(t, err, "Verification should fail with expired token")
|
||||
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
|
||||
})
|
||||
|
||||
t.Run("fails verification with invalid signature", func(t *testing.T) {
|
||||
// Create two JWT services with different keys
|
||||
service1 := &JwtService{}
|
||||
err := service1.init(mockConfig, t.TempDir())
|
||||
require.NoError(t, err, "Failed to initialize first JWT service")
|
||||
|
||||
service2 := &JwtService{}
|
||||
err = service2.init(mockConfig, t.TempDir())
|
||||
require.NoError(t, err, "Failed to initialize second JWT service")
|
||||
|
||||
// Generate a token with the first service
|
||||
tokenString, err := service1.GenerateOAuthRefreshToken("user789", "client123", "my-rt-123")
|
||||
require.NoError(t, err, "Failed to generate refresh token")
|
||||
|
||||
// Verify with the second service should fail due to different keys
|
||||
_, _, _, err = service2.VerifyOAuthRefreshToken(tokenString)
|
||||
require.Error(t, err, "Verification should fail with invalid signature")
|
||||
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenTypeValidator(t *testing.T) {
|
||||
// Create a context for the validator function
|
||||
ctx := context.Background()
|
||||
@@ -1213,7 +1299,104 @@ func TestTokenTypeValidator(t *testing.T) {
|
||||
require.Error(t, err, "Validator should reject token without type claim")
|
||||
assert.Contains(t, err.Error(), "failed to get token type claim")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetTokenType(t *testing.T) {
|
||||
// Create a temporary directory for the test
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Initialize the JWT service
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{})
|
||||
service := &JwtService{}
|
||||
err := service.init(mockConfig, tempDir)
|
||||
require.NoError(t, err, "Failed to initialize JWT service")
|
||||
|
||||
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
|
||||
t.Helper()
|
||||
|
||||
b := jwt.NewBuilder()
|
||||
b.Subject("user123")
|
||||
if setClaimsFn != nil {
|
||||
setClaimsFn(b)
|
||||
}
|
||||
|
||||
token, err := b.Build()
|
||||
require.NoError(t, err, "Failed to build token")
|
||||
|
||||
err = SetTokenType(token, typ)
|
||||
require.NoError(t, err, "Failed to set token type")
|
||||
|
||||
alg, _ := service.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, service.privateKey))
|
||||
require.NoError(t, err, "Failed to sign token")
|
||||
|
||||
return string(signed)
|
||||
}
|
||||
|
||||
t.Run("correctly identifies access tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, AccessTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified as access token")
|
||||
})
|
||||
|
||||
t.Run("correctly identifies ID tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, IDTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, IDTokenJWTType, tokenType, "Token type should be correctly identified as ID token")
|
||||
})
|
||||
|
||||
t.Run("correctly identifies OAuth access tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, OAuthAccessTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, OAuthAccessTokenJWTType, tokenType, "Token type should be correctly identified as OAuth access token")
|
||||
})
|
||||
|
||||
t.Run("correctly identifies refresh tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, OAuthRefreshTokenJWTType, nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error")
|
||||
assert.Equal(t, OAuthRefreshTokenJWTType, tokenType, "Token type should be correctly identified as refresh token")
|
||||
})
|
||||
|
||||
t.Run("works with expired tokens", func(t *testing.T) {
|
||||
tokenString := buildTokenForType(t, AccessTokenJWTType, func(b *jwt.Builder) {
|
||||
b.Expiration(time.Now().Add(-1 * time.Hour)) // Expired 1 hour ago
|
||||
})
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.NoError(t, err, "GetTokenType should not return an error for expired tokens")
|
||||
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified even for expired tokens")
|
||||
})
|
||||
|
||||
t.Run("returns error for malformed tokens", func(t *testing.T) {
|
||||
// Try to get the token type of a malformed token
|
||||
tokenType, _, err := service.GetTokenType("not.a.valid.jwt.token")
|
||||
require.Error(t, err, "GetTokenType should return an error for malformed tokens")
|
||||
assert.Empty(t, tokenType, "Token type should be empty for malformed tokens")
|
||||
})
|
||||
|
||||
t.Run("returns error for tokens without type claim", func(t *testing.T) {
|
||||
// Create a token without type claim
|
||||
tokenString := buildTokenForType(t, "", nil)
|
||||
|
||||
// Get the token type without validating
|
||||
tokenType, _, err := service.GetTokenType(tokenString)
|
||||
require.Error(t, err, "GetTokenType should return an error for tokens without type claim")
|
||||
assert.Empty(t, tokenType, "Token type should be empty when type claim is missing")
|
||||
assert.Contains(t, err.Error(), "failed to get token type claim", "Error message should indicate missing token type claim")
|
||||
})
|
||||
}
|
||||
|
||||
func importKey(t *testing.T, privateKeyRaw any, path string) string {
|
||||
|
||||
@@ -40,6 +40,9 @@ const (
|
||||
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
|
||||
|
||||
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
|
||||
DeviceCodeDuration = 15 * time.Minute
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
@@ -252,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -303,7 +306,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -333,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -375,7 +378,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -406,22 +409,39 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
|
||||
userID, clientID, rt, err := s.jwtService.VerifyOAuthRefreshToken(input.RefreshToken)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// The ID of the client that made the call must match the client ID in the token
|
||||
if client.ID != clientID {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Verify refresh token
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
|
||||
Where(
|
||||
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
||||
utils.CreateSha256Hash(rt),
|
||||
datatype.DateTime(time.Now()),
|
||||
userID,
|
||||
input.ClientID,
|
||||
).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -437,7 +457,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -469,33 +489,69 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
// Get the type of the token and the client ID
|
||||
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
|
||||
if err != nil {
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// If we don't have a client ID, get it from the token
|
||||
// Otherwise, we need to make sure that the client ID passed as credential matches
|
||||
tokenAudiences, _ := token.Audience()
|
||||
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
if creds.ClientID == "" {
|
||||
creds.ClientID = tokenAudiences[0]
|
||||
} else if creds.ClientID != tokenAudiences[0] {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
})
|
||||
// Verify the credentials for the call
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
|
||||
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ParseError()) {
|
||||
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
|
||||
return s.introspectRefreshToken(ctx, tokenString)
|
||||
}
|
||||
|
||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||
// Introspect the token
|
||||
switch tokenType {
|
||||
case OAuthAccessTokenJWTType:
|
||||
return s.introspectAccessToken(client.ID, tokenString)
|
||||
case OAuthRefreshTokenJWTType:
|
||||
return s.introspectRefreshToken(ctx, client.ID, tokenString)
|
||||
default:
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) introspectAccessToken(clientID string, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
token, err := s.jwtService.VerifyOAuthAccessToken(tokenString)
|
||||
if err != nil {
|
||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// The ID of the client that made the request must match the client ID in the token
|
||||
audience, ok := token.Audience()
|
||||
if !ok || len(audience) != 1 || audience[0] == "" {
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
if audience[0] != clientID {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
introspectDto.Active = true
|
||||
introspectDto.TokenType = "access_token"
|
||||
introspectDto.Audience = audience
|
||||
if token.Has("scope") {
|
||||
var (
|
||||
asString string
|
||||
@@ -519,9 +575,6 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
if subject, ok := token.Subject(); ok {
|
||||
introspectDto.Subject = subject
|
||||
}
|
||||
if audience, ok := token.Audience(); ok {
|
||||
introspectDto.Audience = audience
|
||||
}
|
||||
if issuer, ok := token.Issuer(); ok {
|
||||
introspectDto.Issuer = issuer
|
||||
}
|
||||
@@ -532,12 +585,29 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
return introspectDto, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID string, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
|
||||
tokenUserID, tokenClientID, tokenRT, err := s.jwtService.VerifyOAuthRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return introspectDto, fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
// The ID of the client that made the call must match the client ID in the token
|
||||
if tokenClientID != clientID {
|
||||
return introspectDto, errors.New("invalid refresh token: client ID does not match")
|
||||
}
|
||||
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||
Where(
|
||||
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
||||
utils.CreateSha256Hash(tokenRT),
|
||||
datatype.DateTime(time.Now()),
|
||||
tokenUserID,
|
||||
tokenClientID,
|
||||
).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -1062,9 +1132,11 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, ClientAuthCredentials{
|
||||
ClientID: input.ClientID,
|
||||
ClientSecret: input.ClientSecret,
|
||||
ClientAssertionType: input.ClientAssertionType,
|
||||
ClientAssertion: input.ClientAssertion,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1085,7 +1157,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
Scope: input.Scope,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(DeviceCodeDuration)),
|
||||
IsAuthorized: false,
|
||||
ClientID: client.ID,
|
||||
}
|
||||
@@ -1099,7 +1171,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
UserCode: userCode,
|
||||
VerificationURI: common.EnvConfig.AppURL + "/device",
|
||||
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
|
||||
ExpiresIn: 900, // 15 minutes
|
||||
ExpiresIn: int(DeviceCodeDuration.Seconds()),
|
||||
Interval: 5,
|
||||
}, nil
|
||||
}
|
||||
@@ -1255,7 +1327,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
|
||||
|
||||
m := model.OidcRefreshToken{
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
|
||||
Token: refreshTokenHash,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
@@ -1270,7 +1342,13 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
return "", err
|
||||
}
|
||||
|
||||
return refreshToken, nil
|
||||
// Sign the refresh token
|
||||
signed, err := s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign refresh token: %w", err)
|
||||
}
|
||||
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
|
||||
@@ -1291,7 +1369,23 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
|
||||
type ClientAuthCredentials struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ClientAssertion string
|
||||
ClientAssertionType string
|
||||
}
|
||||
|
||||
func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials {
|
||||
return ClientAuthCredentials{
|
||||
ClientID: d.ClientID,
|
||||
ClientSecret: d.ClientSecret,
|
||||
ClientAssertion: d.ClientAssertion,
|
||||
ClientAssertionType: d.ClientAssertionType,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
|
||||
// First, ensure we have a valid client ID
|
||||
if input.ClientID == "" {
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
@@ -1366,7 +1460,7 @@ func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
|
||||
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input ClientAuthCredentials) error {
|
||||
// First, parse the assertion JWT, without validating it, to check the issuer
|
||||
assertion := []byte(input.ClientAssertion)
|
||||
insecureToken, err := jwt.ParseInsecure(assertion)
|
||||
@@ -1466,12 +1560,18 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Commit the transaction before signing tokens to avoid locking the database for longer
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID)
|
||||
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1486,11 +1586,6 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.OidcClientPreviewDto{
|
||||
IdToken: idTokenPayload,
|
||||
AccessToken: accessTokenPayload,
|
||||
@@ -1498,26 +1593,11 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]any, error) {
|
||||
return s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]any, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
@@ -1532,14 +1612,13 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
|
||||
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
|
||||
user := authorizedClient.User
|
||||
scopes := strings.Split(authorizedClient.Scope, " ")
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
claims := make(map[string]any, 10)
|
||||
|
||||
claims["sub"] = user.ID
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
|
||||
@@ -1553,19 +1632,13 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
|
||||
claims["groups"] = userGroups
|
||||
}
|
||||
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"name": user.FullName(),
|
||||
"preferred_username": user.Username,
|
||||
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "profile") {
|
||||
// Add profile claims
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
claims["given_name"] = user.FirstName
|
||||
claims["family_name"] = user.LastName
|
||||
claims["name"] = user.FullName()
|
||||
claims["preferred_username"] = user.Username
|
||||
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
|
||||
@@ -1575,7 +1648,7 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
|
||||
|
||||
for _, customClaim := range customClaims {
|
||||
// The value of the custom claim can be a JSON object or a string
|
||||
var jsonValue interface{}
|
||||
var jsonValue any
|
||||
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
|
||||
if err == nil {
|
||||
// It's JSON, so we store it as an object
|
||||
|
||||
@@ -210,7 +210,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
t.Run("Confidential client", func(t *testing.T) {
|
||||
t.Run("Succeeds with valid secret", func(t *testing.T) {
|
||||
// Test with valid client credentials
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: confidentialSecret,
|
||||
})
|
||||
@@ -221,7 +221,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
|
||||
t.Run("Fails with invalid secret", func(t *testing.T) {
|
||||
// Test with invalid client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
ClientSecret: "invalid-secret",
|
||||
})
|
||||
@@ -232,7 +232,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
|
||||
t.Run("Fails with missing secret", func(t *testing.T) {
|
||||
// Test with missing client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: confidentialClient.ID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
@@ -245,7 +245,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
t.Run("Public client", func(t *testing.T) {
|
||||
t.Run("Succeeds with no credentials", func(t *testing.T) {
|
||||
// Public clients don't require client secret
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: publicClient.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -270,7 +270,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
@@ -282,7 +282,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
|
||||
t.Run("Fails with malformed JWT", func(t *testing.T) {
|
||||
// Test with invalid JWT assertion (just a random string)
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: "invalid.jwt.token",
|
||||
@@ -311,7 +311,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with invalid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
@@ -352,7 +352,7 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with valid JWT assertion
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, dto.OidcCreateTokensDto{
|
||||
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
|
||||
ClientID: federatedClient.ID,
|
||||
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||
ClientAssertion: string(signedToken),
|
||||
|
||||
18
backend/internal/utils/http_util.go
Normal file
18
backend/internal/utils/http_util.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BearerAuth returns the value of the bearer token in the Authorization header if present
|
||||
func BearerAuth(r *http.Request) (string, bool) {
|
||||
const prefix = "bearer "
|
||||
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if len(authHeader) >= len(prefix) && strings.ToLower(authHeader[:len(prefix)]) == prefix {
|
||||
return authHeader[len(prefix):], true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
65
backend/internal/utils/http_util_test.go
Normal file
65
backend/internal/utils/http_util_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBearerAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectedToken string
|
||||
expectedFound bool
|
||||
}{
|
||||
{
|
||||
name: "Valid bearer token",
|
||||
authHeader: "Bearer token123",
|
||||
expectedToken: "token123",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "Valid bearer token with mixed case",
|
||||
authHeader: "beARer token456",
|
||||
expectedToken: "token456",
|
||||
expectedFound: true,
|
||||
},
|
||||
{
|
||||
name: "No bearer prefix",
|
||||
authHeader: "Basic dXNlcjpwYXNz",
|
||||
expectedToken: "",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "Empty auth header",
|
||||
authHeader: "",
|
||||
expectedToken: "",
|
||||
expectedFound: false,
|
||||
},
|
||||
{
|
||||
name: "Bearer prefix only",
|
||||
authHeader: "Bearer ",
|
||||
expectedToken: "",
|
||||
expectedFound: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", nil)
|
||||
require.NoError(t, err, "Failed to create request")
|
||||
|
||||
if tt.authHeader != "" {
|
||||
req.Header.Set("Authorization", tt.authHeader)
|
||||
}
|
||||
|
||||
token, found := BearerAuth(req)
|
||||
|
||||
assert.Equal(t, tt.expectedFound, found)
|
||||
assert.Equal(t, tt.expectedToken, token)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
)
|
||||
|
||||
func GetClaimsFromToken(token jwt.Token) (map[string]any, error) {
|
||||
claims := make(map[string]any)
|
||||
for _, key := range token.Keys() {
|
||||
keys := token.Keys()
|
||||
claims := make(map[string]any, len(keys))
|
||||
for _, key := range keys {
|
||||
var value any
|
||||
if err := token.Get(key, &value); err != nil {
|
||||
return nil, fmt.Errorf("failed to get claim %s: %w", key, err)
|
||||
|
||||
@@ -87,11 +87,13 @@ export const refreshTokens = [
|
||||
{
|
||||
token: 'ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo',
|
||||
clientId: oidcClients.nextcloud.id,
|
||||
userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e',
|
||||
expired: false
|
||||
},
|
||||
{
|
||||
token: 'X4vqwtRyCUaq51UafHea4Fsg8Km6CAns6vp3tuX4',
|
||||
clientId: oidcClients.nextcloud.id,
|
||||
userId: 'f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e',
|
||||
expired: true
|
||||
}
|
||||
];
|
||||
|
||||
@@ -156,11 +156,21 @@ test("End session with id token hint redirects to callback URL", async ({
|
||||
test("Successfully refresh tokens with valid refresh token", async ({
|
||||
request,
|
||||
}) => {
|
||||
const { token, clientId } = refreshTokens.filter(
|
||||
const { token, clientId, userId } = refreshTokens.filter(
|
||||
(token) => !token.expired
|
||||
)[0];
|
||||
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
|
||||
|
||||
// Sign the refresh token
|
||||
const refreshToken = await request.post("/api/test/refreshtoken", {
|
||||
data: {
|
||||
rt: token,
|
||||
client: clientId,
|
||||
user: userId,
|
||||
}
|
||||
}).then((r) => r.text())
|
||||
|
||||
// Perform the exchange
|
||||
const refreshResponse = await request.post("/api/oidc/token", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
@@ -168,7 +178,7 @@ test("Successfully refresh tokens with valid refresh token", async ({
|
||||
form: {
|
||||
grant_type: "refresh_token",
|
||||
client_id: clientId,
|
||||
refresh_token: token,
|
||||
refresh_token: refreshToken,
|
||||
client_secret: clientSecret,
|
||||
},
|
||||
});
|
||||
@@ -184,26 +194,25 @@ test("Successfully refresh tokens with valid refresh token", async ({
|
||||
expect(tokenData.refresh_token).not.toBe(token);
|
||||
});
|
||||
|
||||
test("Using refresh token invalidates it for future use", async ({
|
||||
|
||||
test("Refresh token fails when used for the wrong client", async ({
|
||||
request,
|
||||
}) => {
|
||||
const { token, clientId } = refreshTokens.filter(
|
||||
const { token, clientId, userId } = refreshTokens.filter(
|
||||
(token) => !token.expired
|
||||
)[0];
|
||||
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
|
||||
|
||||
await request.post("/api/oidc/token", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
form: {
|
||||
grant_type: "refresh_token",
|
||||
client_id: clientId,
|
||||
refresh_token: token,
|
||||
client_secret: clientSecret,
|
||||
},
|
||||
});
|
||||
// Sign the refresh token
|
||||
const refreshToken = await request.post("/api/test/refreshtoken", {
|
||||
data: {
|
||||
rt: token,
|
||||
client: 'bad-client',
|
||||
user: userId,
|
||||
}
|
||||
}).then((r) => r.text())
|
||||
|
||||
// Perform the exchange
|
||||
const refreshResponse = await request.post("/api/oidc/token", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
@@ -211,7 +220,86 @@ test("Using refresh token invalidates it for future use", async ({
|
||||
form: {
|
||||
grant_type: "refresh_token",
|
||||
client_id: clientId,
|
||||
refresh_token: token,
|
||||
refresh_token: refreshToken,
|
||||
client_secret: clientSecret,
|
||||
},
|
||||
});
|
||||
|
||||
expect(refreshResponse.status()).toBe(400);
|
||||
});
|
||||
|
||||
test("Refresh token fails when used for the wrong user", async ({
|
||||
request,
|
||||
}) => {
|
||||
const { token, clientId } = refreshTokens.filter(
|
||||
(token) => !token.expired
|
||||
)[0];
|
||||
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
|
||||
|
||||
// Sign the refresh token
|
||||
const refreshToken = await request.post("/api/test/refreshtoken", {
|
||||
data: {
|
||||
rt: token,
|
||||
client: clientId,
|
||||
user: 'bad-user',
|
||||
}
|
||||
}).then((r) => r.text())
|
||||
|
||||
// Perform the exchange
|
||||
const refreshResponse = await request.post("/api/oidc/token", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
form: {
|
||||
grant_type: "refresh_token",
|
||||
client_id: clientId,
|
||||
refresh_token: refreshToken,
|
||||
client_secret: clientSecret,
|
||||
},
|
||||
});
|
||||
|
||||
expect(refreshResponse.status()).toBe(400);
|
||||
});
|
||||
|
||||
test("Using refresh token invalidates it for future use", async ({
|
||||
request,
|
||||
}) => {
|
||||
const { token, clientId, userId } = refreshTokens.filter(
|
||||
(token) => !token.expired
|
||||
)[0];
|
||||
const clientSecret = "w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY";
|
||||
|
||||
// Sign the refresh token
|
||||
const refreshToken = await request.post("/api/test/refreshtoken", {
|
||||
data: {
|
||||
rt: token,
|
||||
client: clientId,
|
||||
user: userId,
|
||||
}
|
||||
}).then((r) => r.text())
|
||||
|
||||
// Perform the exchange
|
||||
await request.post("/api/oidc/token", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
form: {
|
||||
grant_type: "refresh_token",
|
||||
client_id: clientId,
|
||||
refresh_token: refreshToken,
|
||||
client_secret: clientSecret,
|
||||
},
|
||||
});
|
||||
|
||||
// Try again
|
||||
const refreshResponse = await request.post("/api/oidc/token", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
form: {
|
||||
grant_type: "refresh_token",
|
||||
client_id: clientId,
|
||||
refresh_token: refreshToken,
|
||||
client_secret: clientSecret,
|
||||
},
|
||||
});
|
||||
@@ -219,11 +307,10 @@ test("Using refresh token invalidates it for future use", async ({
|
||||
});
|
||||
|
||||
test.describe("Introspection endpoint", () => {
|
||||
const client = oidcClients.nextcloud;
|
||||
test("without client_id and client_secret fails", async ({ request }) => {
|
||||
test("fails without client credentials", async ({ request }) => {
|
||||
const validAccessToken = await generateOauthAccessToken(
|
||||
users.tim,
|
||||
client.id
|
||||
oidcClients.nextcloud.id
|
||||
);
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
headers: {
|
||||
@@ -237,20 +324,20 @@ test.describe("Introspection endpoint", () => {
|
||||
expect(introspectionResponse.status()).toBe(400);
|
||||
});
|
||||
|
||||
test("with client_id and client_secret succeeds", async ({
|
||||
test("succeeds with client credentials", async ({
|
||||
request,
|
||||
baseURL,
|
||||
}) => {
|
||||
const validAccessToken = await generateOauthAccessToken(
|
||||
users.tim,
|
||||
client.id
|
||||
oidcClients.nextcloud.id
|
||||
);
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
Authorization:
|
||||
"Basic " +
|
||||
Buffer.from(`${client.id}:${client.secret}`).toString("base64"),
|
||||
Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
|
||||
},
|
||||
form: {
|
||||
token: validAccessToken,
|
||||
@@ -266,18 +353,102 @@ test.describe("Introspection endpoint", () => {
|
||||
expect(introspectionBody.aud).toStrictEqual([oidcClients.nextcloud.id]);
|
||||
});
|
||||
|
||||
test("succeeds with federated client credentials", async ({
|
||||
page,
|
||||
request,
|
||||
baseURL,
|
||||
}) => {
|
||||
const validAccessToken = await generateOauthAccessToken(
|
||||
users.tim,
|
||||
oidcClients.federated.id
|
||||
);
|
||||
const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT);
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
Authorization: "Bearer " + clientAssertion,
|
||||
},
|
||||
form: {
|
||||
token: validAccessToken,
|
||||
},
|
||||
});
|
||||
|
||||
expect(introspectionResponse.status()).toBe(200);
|
||||
const introspectionBody = await introspectionResponse.json();
|
||||
expect(introspectionBody.active).toBe(true);
|
||||
expect(introspectionBody.token_type).toBe("access_token");
|
||||
expect(introspectionBody.iss).toBe(baseURL);
|
||||
expect(introspectionBody.sub).toBe(users.tim.id);
|
||||
expect(introspectionBody.aud).toStrictEqual([oidcClients.federated.id]);
|
||||
});
|
||||
|
||||
test("fails with client credentials for wrong app", async ({
|
||||
request,
|
||||
}) => {
|
||||
const validAccessToken = await generateOauthAccessToken(
|
||||
users.tim,
|
||||
oidcClients.nextcloud.id
|
||||
);
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
Authorization:
|
||||
"Basic " +
|
||||
Buffer.from(`${oidcClients.immich.id}:${oidcClients.immich.secret}`).toString("base64"),
|
||||
},
|
||||
form: {
|
||||
token: validAccessToken,
|
||||
},
|
||||
});
|
||||
|
||||
expect(introspectionResponse.status()).toBe(400);
|
||||
});
|
||||
|
||||
test("fails with federated credentials for wrong app", async ({
|
||||
page,
|
||||
request,
|
||||
}) => {
|
||||
const validAccessToken = await generateOauthAccessToken(
|
||||
users.tim,
|
||||
oidcClients.nextcloud.id
|
||||
);
|
||||
const clientAssertion = await oidcUtil.getClientAssertion(page, oidcClients.federated.federatedJWT);
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
Authorization: "Bearer " + clientAssertion,
|
||||
},
|
||||
form: {
|
||||
token: validAccessToken,
|
||||
},
|
||||
});
|
||||
|
||||
expect(introspectionResponse.status()).toBe(400);
|
||||
});
|
||||
|
||||
test("non-expired refresh_token can be verified", async ({ request }) => {
|
||||
const { token } = refreshTokens.filter((token) => !token.expired)[0];
|
||||
const { token, clientId, userId } = refreshTokens.filter(
|
||||
(token) => !token.expired
|
||||
)[0];
|
||||
|
||||
// Sign the refresh token
|
||||
const refreshToken = await request.post("/api/test/refreshtoken", {
|
||||
data: {
|
||||
rt: token,
|
||||
client: clientId,
|
||||
user: userId,
|
||||
}
|
||||
}).then((r) => r.text())
|
||||
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
Authorization:
|
||||
"Basic " +
|
||||
Buffer.from(`${client.id}:${client.secret}`).toString("base64"),
|
||||
Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
|
||||
},
|
||||
form: {
|
||||
token: token,
|
||||
token: refreshToken,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -288,17 +459,28 @@ test.describe("Introspection endpoint", () => {
|
||||
});
|
||||
|
||||
test("expired refresh_token can be verified", async ({ request }) => {
|
||||
const { token } = refreshTokens.filter((token) => token.expired)[0];
|
||||
const { token, clientId, userId } = refreshTokens.filter(
|
||||
(token) => token.expired
|
||||
)[0];
|
||||
|
||||
// Sign the refresh token
|
||||
const refreshToken = await request.post("/api/test/refreshtoken", {
|
||||
data: {
|
||||
rt: token,
|
||||
client: clientId,
|
||||
user: userId,
|
||||
}
|
||||
}).then((r) => r.text())
|
||||
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
Authorization:
|
||||
"Basic " +
|
||||
Buffer.from(`${client.id}:${client.secret}`).toString("base64"),
|
||||
Buffer.from(`${oidcClients.nextcloud.id}:${oidcClients.nextcloud.secret}`).toString("base64"),
|
||||
},
|
||||
form: {
|
||||
token: token,
|
||||
token: refreshToken,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -310,7 +492,7 @@ test.describe("Introspection endpoint", () => {
|
||||
test("expired access_token can't be verified", async ({ request }) => {
|
||||
const expiredAccessToken = await generateOauthAccessToken(
|
||||
users.tim,
|
||||
client.id,
|
||||
oidcClients.nextcloud.id,
|
||||
true
|
||||
);
|
||||
const introspectionResponse = await request.post("/api/oidc/introspect", {
|
||||
|
||||
Reference in New Issue
Block a user