mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-17 17:23:30 +03:00
refactor: some clean-up in OIDC service and controller (#550)
This commit is contained in:
committed by
Elias Schneider
parent
3896b7bb3b
commit
b71c84c355
@@ -136,13 +136,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate that code is provided for authorization_code grant type
|
// Validate that code is provided for authorization_code grant type
|
||||||
if input.GrantType == "authorization_code" && input.Code == "" {
|
if input.GrantType == service.GrantTypeAuthorizationCode && input.Code == "" {
|
||||||
_ = c.Error(&common.OidcMissingAuthorizationCodeError{})
|
_ = c.Error(&common.OidcMissingAuthorizationCodeError{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate that refresh_token is provided for refresh_token grant type
|
// Validate that refresh_token is provided for refresh_token grant type
|
||||||
if input.GrantType == "refresh_token" && input.RefreshToken == "" {
|
if input.GrantType == service.GrantTypeRefreshToken && input.RefreshToken == "" {
|
||||||
_ = c.Error(&common.OidcMissingRefreshTokenError{})
|
_ = c.Error(&common.OidcMissingRefreshTokenError{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -152,8 +152,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
|||||||
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, accessToken, refreshToken, expiresIn, err :=
|
tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input)
|
||||||
oc.oidcService.CreateTokens(c.Request.Context(), input)
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
||||||
@@ -171,23 +170,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := dto.OidcTokenResponseDto{
|
c.JSON(http.StatusOK, dto.OidcTokenResponseDto{
|
||||||
AccessToken: accessToken,
|
AccessToken: tokens.AccessToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: expiresIn,
|
ExpiresIn: int(tokens.ExpiresIn.Seconds()),
|
||||||
}
|
IdToken: tokens.IdToken, // May be empty
|
||||||
|
RefreshToken: tokens.RefreshToken, // May be empty
|
||||||
// Include ID token only for authorization_code grant
|
})
|
||||||
if idToken != "" {
|
|
||||||
response.IdToken = idToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// Include refresh token if generated
|
|
||||||
if refreshToken != "" {
|
|
||||||
response.RefreshToken = refreshToken
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// userInfoHandler godoc
|
// userInfoHandler godoc
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
|||||||
"introspection_endpoint": appUrl + "/api/oidc/introspect",
|
"introspection_endpoint": appUrl + "/api/oidc/introspect",
|
||||||
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
||||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
||||||
"grant_types_supported": []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"},
|
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
|
||||||
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
||||||
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
||||||
"response_types_supported": []string{"code", "id_token"},
|
"response_types_supported": []string{"code", "id_token"},
|
||||||
|
|||||||
@@ -15,17 +15,22 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
"golang.org/x/crypto/bcrypt"
|
)
|
||||||
"gorm.io/gorm"
|
|
||||||
|
const (
|
||||||
|
GrantTypeAuthorizationCode = "authorization_code"
|
||||||
|
GrantTypeRefreshToken = "refresh_token"
|
||||||
|
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OidcService struct {
|
type OidcService struct {
|
||||||
@@ -167,139 +172,158 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
|
|||||||
return isAllowedToAuthorize
|
return isAllowedToAuthorize
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
|
type CreatedTokens struct {
|
||||||
|
IdToken string
|
||||||
|
AccessToken string
|
||||||
|
RefreshToken string
|
||||||
|
ExpiresIn time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||||
switch input.GrantType {
|
switch input.GrantType {
|
||||||
case "authorization_code":
|
case GrantTypeAuthorizationCode:
|
||||||
return s.createTokenFromAuthorizationCode(ctx, input.Code, input.ClientID, input.ClientSecret, input.CodeVerifier)
|
return s.createTokenFromAuthorizationCode(ctx, input)
|
||||||
case "refresh_token":
|
case GrantTypeRefreshToken:
|
||||||
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, input.RefreshToken, input.ClientID, input.ClientSecret)
|
return s.createTokenFromRefreshToken(ctx, input)
|
||||||
return "", accessToken, newRefreshToken, exp, err
|
case GrantTypeDeviceCode:
|
||||||
case "urn:ietf:params:oauth:grant-type:device_code":
|
return s.createTokenFromDeviceCode(ctx, input)
|
||||||
return s.createTokenFromDeviceCode(ctx, input.DeviceCode, input.ClientID, input.ClientSecret)
|
|
||||||
default:
|
default:
|
||||||
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
|
return CreatedTokens{}, &common.OidcGrantTypeNotSupportedError{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, clientID string, clientSecret string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
defer func() {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the device authorization from database with explicit query conditions
|
// Get the device authorization from database with explicit query conditions
|
||||||
var deviceAuth model.OidcDeviceCode
|
var deviceAuth model.OidcDeviceCode
|
||||||
if err := tx.WithContext(ctx).Preload("User").Where("device_code = ? AND client_id = ?", deviceCode, clientID).First(&deviceAuth).Error; err != nil {
|
err = tx.
|
||||||
|
WithContext(ctx).
|
||||||
|
Preload("User").
|
||||||
|
Where("device_code = ? AND client_id = ?", input.DeviceCode, input.ClientID).
|
||||||
|
First(&deviceAuth).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return "", "", "", 0, &common.OidcInvalidDeviceCodeError{}
|
return CreatedTokens{}, &common.OidcInvalidDeviceCodeError{}
|
||||||
}
|
}
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if device code has expired
|
// Check if device code has expired
|
||||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||||
return "", "", "", 0, &common.OidcDeviceCodeExpiredError{}
|
return CreatedTokens{}, &common.OidcDeviceCodeExpiredError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if device code has been authorized
|
// Check if device code has been authorized
|
||||||
if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil {
|
if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil {
|
||||||
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
|
return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user claims for the ID token - ensure UserID is not nil
|
// Get user claims for the ID token - ensure UserID is not nil
|
||||||
if deviceAuth.UserID == nil {
|
if deviceAuth.UserID == nil {
|
||||||
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
|
return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, clientID, tx)
|
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, input.ClientID, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Explicitly use the input clientID for the audience claim to ensure consistency
|
// Explicitly use the input clientID for the audience claim to ensure consistency
|
||||||
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, "")
|
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
refreshToken, err = s.createRefreshToken(ctx, clientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
|
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(deviceAuth.User, clientID)
|
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the used device code
|
// Delete the used device code
|
||||||
if err := tx.WithContext(ctx).Delete(&deviceAuth).Error; err != nil {
|
err = tx.WithContext(ctx).Delete(&deviceAuth).Error
|
||||||
return "", "", "", 0, err
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit().Error; err != nil {
|
err = tx.Commit().Error
|
||||||
return "", "", "", 0, err
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return idToken, accessToken, refreshToken, 3600, nil
|
return CreatedTokens{
|
||||||
|
IdToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
ExpiresIn: time.Hour,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
defer func() {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Preload("User").
|
Preload("User").
|
||||||
First(&authorizationCodeMetaData, "code = ?", code).
|
First(&authorizationCodeMetaData, "code = ?", input.Code).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
|
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
|
||||||
if client.IsPublic || client.PkceEnabled {
|
if client.IsPublic || client.PkceEnabled {
|
||||||
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
if !s.validateCodeVerifier(input.CodeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
||||||
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
|
return CreatedTokens{}, &common.OidcInvalidCodeVerifierError{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
if authorizationCodeMetaData.ClientID != input.ClientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx)
|
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, input.ClientID, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a refresh token
|
// Generate a refresh token
|
||||||
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
|
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.
|
err = tx.
|
||||||
@@ -307,20 +331,25 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code
|
|||||||
Delete(&authorizationCodeMetaData).
|
Delete(&authorizationCodeMetaData).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit().Error
|
err = tx.Commit().Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return idToken, accessToken, refreshToken, 3600, nil
|
return CreatedTokens{
|
||||||
|
IdToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
ExpiresIn: time.Hour,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
|
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||||
if refreshToken == "" {
|
if input.RefreshToken == "" {
|
||||||
return "", "", 0, &common.OidcMissingRefreshTokenError{}
|
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
@@ -328,9 +357,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify refresh token
|
// Verify refresh token
|
||||||
@@ -338,31 +367,31 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
|||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Preload("User").
|
Preload("User").
|
||||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
|
||||||
First(&storedRefreshToken).
|
First(&storedRefreshToken).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||||
}
|
}
|
||||||
return "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the refresh token belongs to the provided client
|
// Verify that the refresh token belongs to the provided client
|
||||||
if storedRefreshToken.ClientID != clientID {
|
if storedRefreshToken.ClientID != input.ClientID {
|
||||||
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new access token
|
// Generate a new access token
|
||||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
|
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate a new refresh token and invalidate the old one
|
// Generate a new refresh token and invalidate the old one
|
||||||
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete the used refresh token
|
// Delete the used refresh token
|
||||||
@@ -371,15 +400,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
|||||||
Delete(&storedRefreshToken).
|
Delete(&storedRefreshToken).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit().Error
|
err = tx.Commit().Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return accessToken, newRefreshToken, 3600, nil
|
return CreatedTokens{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: newRefreshToken,
|
||||||
|
ExpiresIn: time.Hour,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||||
@@ -1181,9 +1214,12 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||||
|
// First, ensure we have a valid client ID
|
||||||
if clientID == "" {
|
if clientID == "" {
|
||||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load the OIDC client's configuration
|
||||||
var client model.OidcClient
|
var client model.OidcClient
|
||||||
err := tx.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
@@ -1193,10 +1229,16 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clien
|
|||||||
return model.OidcClient{}, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !client.IsPublic {
|
// If we have a client secret, we validate it
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
|
// Otherwise, we require the client to be public
|
||||||
|
if clientSecret != "" {
|
||||||
|
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||||
|
if err != nil {
|
||||||
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
|
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
|
||||||
}
|
}
|
||||||
|
return client, nil
|
||||||
|
} else if !client.IsPublic {
|
||||||
|
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user