feat: allow introspection and device code endpoints to use Federated Client Credentials (#640)

This commit is contained in:
Alessandro (Ale) Segala
2025-06-09 12:17:55 -07:00
committed by GitHub
parent df5c1ed1f8
commit b62b61fb01
13 changed files with 802 additions and 139 deletions

View File

@@ -40,6 +40,9 @@ const (
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
DeviceCodeDuration = 15 * time.Minute
)
type OidcService struct {
@@ -252,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil {
return CreatedTokens{}, err
}
@@ -303,7 +306,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
return CreatedTokens{}, err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
if err != nil {
return CreatedTokens{}, err
}
@@ -333,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
tx.Rollback()
}()
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil {
return CreatedTokens{}, err
}
@@ -375,7 +378,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
return CreatedTokens{}, err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
if err != nil {
return CreatedTokens{}, err
}
@@ -406,22 +409,39 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
}
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
userID, clientID, rt, err := s.jwtService.VerifyOAuthRefreshToken(input.RefreshToken)
if err != nil {
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
if err != nil {
return CreatedTokens{}, err
}
// The ID of the client that made the call must match the client ID in the token
if client.ID != clientID {
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
err = tx.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
Where(
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
utils.CreateSha256Hash(rt),
datatype.DateTime(time.Now()),
userID,
input.ClientID,
).
First(&storedRefreshToken).
Error
if err != nil {
@@ -437,7 +457,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
}
// Generate a new access token
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
if err != nil {
return CreatedTokens{}, err
}
@@ -469,33 +489,69 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
}, nil
}
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
if clientID == "" || clientSecret == "" {
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
// Get the type of the token and the client ID
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
if err != nil {
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil //nolint:nilerr
}
// If we don't have a client ID, get it from the token
// Otherwise, we need to make sure that the client ID passed as credential matches
tokenAudiences, _ := token.Audience()
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil
}
if creds.ClientID == "" {
creds.ClientID = tokenAudiences[0]
} else if creds.ClientID != tokenAudiences[0] {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: clientID,
ClientSecret: clientSecret,
})
// Verify the credentials for the call
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
if err != nil {
return introspectDto, err
}
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
if err != nil {
if errors.Is(err, jwt.ParseError()) {
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
return s.introspectRefreshToken(ctx, tokenString)
}
// Every failure we get means the token is invalid. Nothing more to do with the error.
// Introspect the token
switch tokenType {
case OAuthAccessTokenJWTType:
return s.introspectAccessToken(client.ID, tokenString)
case OAuthRefreshTokenJWTType:
return s.introspectRefreshToken(ctx, client.ID, tokenString)
default:
// We just treat the token as invalid
introspectDto.Active = false
return introspectDto, nil
}
}
func (s *OidcService) introspectAccessToken(clientID string, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
token, err := s.jwtService.VerifyOAuthAccessToken(tokenString)
if err != nil {
// Every failure we get means the token is invalid. Nothing more to do with the error.
introspectDto.Active = false
return introspectDto, nil //nolint:nilerr
}
// The ID of the client that made the request must match the client ID in the token
audience, ok := token.Audience()
if !ok || len(audience) != 1 || audience[0] == "" {
introspectDto.Active = false
return introspectDto, nil
}
if audience[0] != clientID {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
introspectDto.Active = true
introspectDto.TokenType = "access_token"
introspectDto.Audience = audience
if token.Has("scope") {
var (
asString string
@@ -519,9 +575,6 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
if subject, ok := token.Subject(); ok {
introspectDto.Subject = subject
}
if audience, ok := token.Audience(); ok {
introspectDto.Audience = audience
}
if issuer, ok := token.Issuer(); ok {
introspectDto.Issuer = issuer
}
@@ -532,12 +585,29 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
return introspectDto, nil
}
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID string, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
tokenUserID, tokenClientID, tokenRT, err := s.jwtService.VerifyOAuthRefreshToken(refreshToken)
if err != nil {
return introspectDto, fmt.Errorf("invalid refresh token: %w", err)
}
// The ID of the client that made the call must match the client ID in the token
if tokenClientID != clientID {
return introspectDto, errors.New("invalid refresh token: client ID does not match")
}
var storedRefreshToken model.OidcRefreshToken
err = s.db.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
Where(
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
utils.CreateSha256Hash(tokenRT),
datatype.DateTime(time.Now()),
tokenUserID,
tokenClientID,
).
First(&storedRefreshToken).
Error
if err != nil {
@@ -1062,9 +1132,11 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
}
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
ClientID: input.ClientID,
ClientSecret: input.ClientSecret,
client, err := s.verifyClientCredentialsInternal(ctx, s.db, ClientAuthCredentials{
ClientID: input.ClientID,
ClientSecret: input.ClientSecret,
ClientAssertionType: input.ClientAssertionType,
ClientAssertion: input.ClientAssertion,
})
if err != nil {
return nil, err
@@ -1085,7 +1157,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
DeviceCode: deviceCode,
UserCode: userCode,
Scope: input.Scope,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
ExpiresAt: datatype.DateTime(time.Now().Add(DeviceCodeDuration)),
IsAuthorized: false,
ClientID: client.ID,
}
@@ -1099,7 +1171,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
UserCode: userCode,
VerificationURI: common.EnvConfig.AppURL + "/device",
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
ExpiresIn: 900, // 15 minutes
ExpiresIn: int(DeviceCodeDuration.Seconds()),
Interval: 5,
}, nil
}
@@ -1255,7 +1327,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
m := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
Token: refreshTokenHash,
ClientID: clientID,
UserID: userID,
@@ -1270,7 +1342,13 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
return "", err
}
return refreshToken, nil
// Sign the refresh token
signed, err := s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
if err != nil {
return "", fmt.Errorf("failed to sign refresh token: %w", err)
}
return signed, nil
}
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
@@ -1291,7 +1369,23 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
return err
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
type ClientAuthCredentials struct {
ClientID string
ClientSecret string
ClientAssertion string
ClientAssertionType string
}
func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials {
return ClientAuthCredentials{
ClientID: d.ClientID,
ClientSecret: d.ClientSecret,
ClientAssertion: d.ClientAssertion,
ClientAssertionType: d.ClientAssertionType,
}
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
// First, ensure we have a valid client ID
if input.ClientID == "" {
return nil, &common.OidcMissingClientCredentialsError{}
@@ -1366,7 +1460,7 @@ func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set
return jwks, nil
}
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input ClientAuthCredentials) error {
// First, parse the assertion JWT, without validating it, to check the issuer
assertion := []byte(input.ClientAssertion)
insecureToken, err := jwt.ParseInsecure(assertion)
@@ -1466,12 +1560,18 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err
}
// Commit the transaction before signing tokens to avoid locking the database for longer
err = tx.Commit().Error
if err != nil {
return nil, err
}
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
if err != nil {
return nil, err
}
accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID)
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
if err != nil {
return nil, err
}
@@ -1486,11 +1586,6 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return &dto.OidcClientPreviewDto{
IdToken: idTokenPayload,
AccessToken: accessTokenPayload,
@@ -1498,26 +1593,11 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
}, nil
}
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]any, error) {
return s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
}
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]any, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient
err := tx.
WithContext(ctx).
@@ -1532,14 +1612,13 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
}
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) {
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
user := authorizedClient.User
scopes := strings.Split(authorizedClient.Scope, " ")
claims := map[string]interface{}{
"sub": user.ID,
}
claims := make(map[string]any, 10)
claims["sub"] = user.ID
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
@@ -1553,19 +1632,13 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
claims["groups"] = userGroups
}
profileClaims := map[string]interface{}{
"given_name": user.FirstName,
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
}
if slices.Contains(scopes, "profile") {
// Add profile claims
for k, v := range profileClaims {
claims[k] = v
}
claims["given_name"] = user.FirstName
claims["family_name"] = user.LastName
claims["name"] = user.FullName()
claims["preferred_username"] = user.Username
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
// Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
@@ -1575,7 +1648,7 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
for _, customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string
var jsonValue interface{}
var jsonValue any
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if err == nil {
// It's JSON, so we store it as an object