mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-18 17:23:26 +03:00
feat: return new id_token when using refresh token (#925)
This commit is contained in:
committed by
GitHub
parent
6c696b46c8
commit
307caaa3ef
@@ -828,7 +828,7 @@ func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, scopes)
|
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, strings.Split(scopes, " "))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
@@ -21,6 +22,14 @@ type UserAuthorizedOidcClient struct {
|
|||||||
Client OidcClient
|
Client OidcClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c UserAuthorizedOidcClient) Scopes() []string {
|
||||||
|
if len(c.Scope) == 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Split(c.Scope, " ")
|
||||||
|
}
|
||||||
|
|
||||||
type OidcAuthorizationCode struct {
|
type OidcAuthorizationCode struct {
|
||||||
Base
|
Base
|
||||||
|
|
||||||
@@ -72,6 +81,14 @@ type OidcRefreshToken struct {
|
|||||||
Client OidcClient
|
Client OidcClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c OidcRefreshToken) Scopes() []string {
|
||||||
|
if len(c.Scope) == 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Split(c.Scope, " ")
|
||||||
|
}
|
||||||
|
|
||||||
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||||
// Compute HasLogo field
|
// Compute HasLogo field
|
||||||
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
||||||
|
|||||||
@@ -479,10 +479,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
).
|
).
|
||||||
First(&storedRefreshToken).
|
First(&storedRefreshToken).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
} else if err != nil {
|
||||||
}
|
|
||||||
return CreatedTokens{}, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +496,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
return CreatedTokens{}, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load the profile, which we need for the ID token
|
||||||
|
userClaims, err := s.getUserClaims(ctx, &storedRefreshToken.User, storedRefreshToken.Scopes(), tx)
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new ID token
|
||||||
|
// There's no nonce here because we don't have one with the refresh token, but that's not required
|
||||||
|
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
// Generate a new refresh token and invalidate the old one
|
// Generate a new refresh token and invalidate the old one
|
||||||
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -520,6 +532,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
return CreatedTokens{
|
return CreatedTokens{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: newRefreshToken,
|
RefreshToken: newRefreshToken,
|
||||||
|
IdToken: idToken,
|
||||||
ExpiresIn: AccessTokenDuration,
|
ExpiresIn: AccessTokenDuration,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -1726,7 +1739,7 @@ func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, er
|
|||||||
return sub, nil
|
return sub, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
|
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
defer func() {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
@@ -1751,14 +1764,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
|||||||
return nil, &common.OidcAccessDeniedError{}
|
return nil, &common.OidcAccessDeniedError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
dummyAuthorizedClient := model.UserAuthorizedOidcClient{
|
userClaims, err := s.getUserClaims(ctx, &user, scopes, tx)
|
||||||
UserID: userID,
|
|
||||||
ClientID: clientID,
|
|
||||||
Scope: scopes,
|
|
||||||
User: user,
|
|
||||||
}
|
|
||||||
|
|
||||||
userClaims, err := s.getUserClaimsFromAuthorizedClient(ctx, &dummyAuthorizedClient, tx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1811,14 +1817,10 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.getUserClaimsFromAuthorizedClient(ctx, &authorizedOidcClient, tx)
|
return s.getUserClaims(ctx, &authorizedOidcClient.User, authorizedOidcClient.Scopes(), tx)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
|
func (s *OidcService) getUserClaims(ctx context.Context, user *model.User, scopes []string, tx *gorm.DB) (map[string]any, error) {
|
||||||
user := authorizedClient.User
|
|
||||||
scopes := strings.Split(authorizedClient.Scope, " ")
|
|
||||||
|
|
||||||
claims := make(map[string]any, 10)
|
claims := make(map[string]any, 10)
|
||||||
|
|
||||||
claims["sub"] = user.ID
|
claims["sub"] = user.ID
|
||||||
|
|||||||
@@ -167,6 +167,7 @@ test('Successfully refresh tokens with valid refresh token', async ({ request })
|
|||||||
const tokenData = await refreshResponse.json();
|
const tokenData = await refreshResponse.json();
|
||||||
expect(tokenData.access_token).toBeDefined();
|
expect(tokenData.access_token).toBeDefined();
|
||||||
expect(tokenData.refresh_token).toBeDefined();
|
expect(tokenData.refresh_token).toBeDefined();
|
||||||
|
expect(tokenData.id_token).toBeDefined();
|
||||||
expect(tokenData.token_type).toBe('Bearer');
|
expect(tokenData.token_type).toBe('Bearer');
|
||||||
expect(tokenData.expires_in).toBe(3600);
|
expect(tokenData.expires_in).toBe(3600);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user