mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-13 16:53:01 +03:00
fix: updating scopes of an authorized client fails with Postgres
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm/clause"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
@@ -94,24 +95,8 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie
|
||||
|
||||
// If the user has not authorized the client, create a new authorization in the database
|
||||
if !hasAuthorizedClient {
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: input.ClientID,
|
||||
Scope: input.Scope,
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&userAuthorizedClient).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
// The client has already been authorized but with a different scope so we need to update the scope
|
||||
if err := tx.
|
||||
WithContext(ctx).
|
||||
Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
} else if err != nil {
|
||||
err := s.createAuthorizedClientInternal(ctx, userID, input.ClientID, input.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
@@ -201,7 +186,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode,
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
@@ -269,7 +254,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
@@ -342,7 +327,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx)
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
@@ -401,7 +386,7 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, s.db)
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
@@ -999,7 +984,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||
client, err := s.VerifyClientCredentials(ctx, input.ClientID, input.ClientSecret, s.db)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1095,23 +1080,11 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
||||
}
|
||||
|
||||
if !hasAuthorizedClient {
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: deviceAuth.ClientID,
|
||||
Scope: deviceAuth.Scope,
|
||||
err := s.createAuthorizedClientInternal(ctx, deviceAuth.ClientID, userID, deviceAuth.Scope, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.WithContext(ctx).Create(&userAuthorizedClient).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return err
|
||||
}
|
||||
// If duplicate, update scope
|
||||
if err := tx.WithContext(ctx).Model(&model.UserAuthorizedOidcClient{}).
|
||||
Where("user_id = ? AND client_id = ?", userID, deviceAuth.ClientID).
|
||||
Update("scope", deviceAuth.Scope).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
|
||||
} else {
|
||||
s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
|
||||
@@ -1188,7 +1161,25 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
return refreshToken, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) VerifyClientCredentials(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: clientID,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
err := tx.WithContext(ctx).
|
||||
Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "user_id"}, {Name: "client_id"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"scope"}),
|
||||
}).
|
||||
Create(&userAuthorizedClient).
|
||||
Error
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
if clientID == "" {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user