fix: updating scopes of an authorized client fails with Postgres

This commit is contained in:
Elias Schneider
2025-04-28 09:29:18 +02:00
parent 02cacba5c5
commit 0a24ab8001

View File

@@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"gorm.io/gorm/clause"
"log" "log"
"mime/multipart" "mime/multipart"
"os" "os"
@@ -94,24 +95,8 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie
// If the user has not authorized the client, create a new authorization in the database // If the user has not authorized the client, create a new authorization in the database
if !hasAuthorizedClient { if !hasAuthorizedClient {
userAuthorizedClient := model.UserAuthorizedOidcClient{ err := s.createAuthorizedClientInternal(ctx, userID, input.ClientID, input.Scope, tx)
UserID: userID, if err != nil {
ClientID: input.ClientID,
Scope: input.Scope,
}
err = tx.
WithContext(ctx).
Create(&userAuthorizedClient).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
// The client has already been authorized but with a different scope so we need to update the scope
if err := tx.
WithContext(ctx).
Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
return "", "", err
}
} else if err != nil {
return "", "", err return "", "", err
} }
} }
@@ -201,7 +186,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode,
tx.Rollback() tx.Rollback()
}() }()
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx) _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return "", "", "", 0, err
} }
@@ -269,7 +254,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code
tx.Rollback() tx.Rollback()
}() }()
client, err := s.VerifyClientCredentials(ctx, clientID, clientSecret, tx) client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return "", "", "", 0, err
} }
@@ -342,7 +327,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
tx.Rollback() tx.Rollback()
}() }()
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, tx) _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
if err != nil { if err != nil {
return "", "", 0, err return "", "", 0, err
} }
@@ -401,7 +386,7 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
return introspectDto, &common.OidcMissingClientCredentialsError{} return introspectDto, &common.OidcMissingClientCredentialsError{}
} }
_, err = s.VerifyClientCredentials(ctx, clientID, clientSecret, s.db) _, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, s.db)
if err != nil { if err != nil {
return introspectDto, err return introspectDto, err
} }
@@ -999,7 +984,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
} }
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) { func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
client, err := s.VerifyClientCredentials(ctx, input.ClientID, input.ClientSecret, s.db) client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, s.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1095,23 +1080,11 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
} }
if !hasAuthorizedClient { if !hasAuthorizedClient {
userAuthorizedClient := model.UserAuthorizedOidcClient{ err := s.createAuthorizedClientInternal(ctx, deviceAuth.ClientID, userID, deviceAuth.Scope, tx)
UserID: userID, if err != nil {
ClientID: deviceAuth.ClientID, return err
Scope: deviceAuth.Scope,
} }
if err := tx.WithContext(ctx).Create(&userAuthorizedClient).Error; err != nil {
if !errors.Is(err, gorm.ErrDuplicatedKey) {
return err
}
// If duplicate, update scope
if err := tx.WithContext(ctx).Model(&model.UserAuthorizedOidcClient{}).
Where("user_id = ? AND client_id = ?", userID, deviceAuth.ClientID).
Update("scope", deviceAuth.Scope).Error; err != nil {
return err
}
}
s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx) s.auditLogService.Create(ctx, model.AuditLogEventNewDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
} else { } else {
s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx) s.auditLogService.Create(ctx, model.AuditLogEventDeviceCodeAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": deviceAuth.Client.Name}, tx)
@@ -1188,7 +1161,25 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
return refreshToken, nil return refreshToken, nil
} }
func (s *OidcService) VerifyClientCredentials(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) { func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
userAuthorizedClient := model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: clientID,
Scope: scope,
}
err := tx.WithContext(ctx).
Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}, {Name: "client_id"}},
DoUpdates: clause.AssignmentColumns([]string{"scope"}),
}).
Create(&userAuthorizedClient).
Error
return err
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
if clientID == "" { if clientID == "" {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{} return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
} }