diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 82317b6c..f9b9dece 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -828,7 +828,7 @@ func (oc *OidcController) getClientPreviewHandler(c *gin.Context) { return } - preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, scopes) + preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, strings.Split(scopes, " ")) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index d9731d62..3521e7cf 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "strings" "gorm.io/gorm" @@ -21,6 +22,14 @@ type UserAuthorizedOidcClient struct { Client OidcClient } +func (c UserAuthorizedOidcClient) Scopes() []string { + if len(c.Scope) == 0 { + return []string{} + } + + return strings.Split(c.Scope, " ") +} + type OidcAuthorizationCode struct { Base @@ -72,6 +81,14 @@ type OidcRefreshToken struct { Client OidcClient } +func (c OidcRefreshToken) Scopes() []string { + if len(c.Scope) == 0 { + return []string{} + } + + return strings.Split(c.Scope, " ") +} + func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { // Compute HasLogo field c.HasLogo = c.ImageType != nil && *c.ImageType != "" diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index f8b0c638..eaf8c474 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -479,10 +479,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto ). First(&storedRefreshToken). Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} - } + if errors.Is(err, gorm.ErrRecordNotFound) { + return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} + } else if err != nil { return CreatedTokens{}, err } @@ -497,6 +496,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto return CreatedTokens{}, err } + // Load the profile, which we need for the ID token + userClaims, err := s.getUserClaims(ctx, &storedRefreshToken.User, storedRefreshToken.Scopes(), tx) + if err != nil { + return CreatedTokens{}, err + } + + // Generate a new ID token + // There's no nonce here because we don't have one with the refresh token, but that's not required + idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "") + if err != nil { + return CreatedTokens{}, err + } + // Generate a new refresh token and invalidate the old one newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx) if err != nil { @@ -520,6 +532,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto return CreatedTokens{ AccessToken: accessToken, RefreshToken: newRefreshToken, + IdToken: idToken, ExpiresIn: AccessTokenDuration, }, nil } @@ -1726,7 +1739,7 @@ func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, er return sub, nil } -func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) { +func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) { tx := s.db.Begin() defer func() { tx.Rollback() @@ -1751,14 +1764,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use return nil, &common.OidcAccessDeniedError{} } - dummyAuthorizedClient := model.UserAuthorizedOidcClient{ - UserID: userID, - ClientID: clientID, - Scope: scopes, - User: user, - } - - userClaims, err := s.getUserClaimsFromAuthorizedClient(ctx, &dummyAuthorizedClient, tx) + userClaims, err := s.getUserClaims(ctx, &user, scopes, tx) if err != nil { return nil, err } @@ -1811,14 +1817,10 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID return nil, err } - return s.getUserClaimsFromAuthorizedClient(ctx, &authorizedOidcClient, tx) - + return s.getUserClaims(ctx, &authorizedOidcClient.User, authorizedOidcClient.Scopes(), tx) } -func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) { - user := authorizedClient.User - scopes := strings.Split(authorizedClient.Scope, " ") - +func (s *OidcService) getUserClaims(ctx context.Context, user *model.User, scopes []string, tx *gorm.DB) (map[string]any, error) { claims := make(map[string]any, 10) claims["sub"] = user.ID diff --git a/tests/specs/oidc.spec.ts b/tests/specs/oidc.spec.ts index ab24f5e6..26fb244b 100644 --- a/tests/specs/oidc.spec.ts +++ b/tests/specs/oidc.spec.ts @@ -167,6 +167,7 @@ test('Successfully refresh tokens with valid refresh token', async ({ request }) const tokenData = await refreshResponse.json(); expect(tokenData.access_token).toBeDefined(); expect(tokenData.refresh_token).toBeDefined(); + expect(tokenData.id_token).toBeDefined(); expect(tokenData.token_type).toBe('Bearer'); expect(tokenData.expires_in).toBe(3600);