mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-10 23:32:56 +03:00
refactor: some clean-up in OIDC service and controller (#550)
This commit is contained in:
committed by
Elias Schneider
parent
3896b7bb3b
commit
b71c84c355
@@ -136,13 +136,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Validate that code is provided for authorization_code grant type
|
||||
if input.GrantType == "authorization_code" && input.Code == "" {
|
||||
if input.GrantType == service.GrantTypeAuthorizationCode && input.Code == "" {
|
||||
_ = c.Error(&common.OidcMissingAuthorizationCodeError{})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that refresh_token is provided for refresh_token grant type
|
||||
if input.GrantType == "refresh_token" && input.RefreshToken == "" {
|
||||
if input.GrantType == service.GrantTypeRefreshToken && input.RefreshToken == "" {
|
||||
_ = c.Error(&common.OidcMissingRefreshTokenError{})
|
||||
return
|
||||
}
|
||||
@@ -152,8 +152,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
|
||||
}
|
||||
|
||||
idToken, accessToken, refreshToken, expiresIn, err :=
|
||||
oc.oidcService.CreateTokens(c.Request.Context(), input)
|
||||
tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
|
||||
@@ -171,23 +170,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response := dto.OidcTokenResponseDto{
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: expiresIn,
|
||||
}
|
||||
|
||||
// Include ID token only for authorization_code grant
|
||||
if idToken != "" {
|
||||
response.IdToken = idToken
|
||||
}
|
||||
|
||||
// Include refresh token if generated
|
||||
if refreshToken != "" {
|
||||
response.RefreshToken = refreshToken
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
c.JSON(http.StatusOK, dto.OidcTokenResponseDto{
|
||||
AccessToken: tokens.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int(tokens.ExpiresIn.Seconds()),
|
||||
IdToken: tokens.IdToken, // May be empty
|
||||
RefreshToken: tokens.RefreshToken, // May be empty
|
||||
})
|
||||
}
|
||||
|
||||
// userInfoHandler godoc
|
||||
|
||||
@@ -77,7 +77,7 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
||||
"introspection_endpoint": appUrl + "/api/oidc/introspect",
|
||||
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"},
|
||||
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
||||
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
||||
"response_types_supported": []string{"code", "id_token"},
|
||||
|
||||
@@ -15,17 +15,22 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
GrantTypeAuthorizationCode = "authorization_code"
|
||||
GrantTypeRefreshToken = "refresh_token"
|
||||
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
@@ -167,139 +172,158 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
|
||||
return isAllowedToAuthorize
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
|
||||
type CreatedTokens struct {
|
||||
IdToken string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresIn time.Duration
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||
switch input.GrantType {
|
||||
case "authorization_code":
|
||||
return s.createTokenFromAuthorizationCode(ctx, input.Code, input.ClientID, input.ClientSecret, input.CodeVerifier)
|
||||
case "refresh_token":
|
||||
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, input.RefreshToken, input.ClientID, input.ClientSecret)
|
||||
return "", accessToken, newRefreshToken, exp, err
|
||||
case "urn:ietf:params:oauth:grant-type:device_code":
|
||||
return s.createTokenFromDeviceCode(ctx, input.DeviceCode, input.ClientID, input.ClientSecret)
|
||||
case GrantTypeAuthorizationCode:
|
||||
return s.createTokenFromAuthorizationCode(ctx, input)
|
||||
case GrantTypeRefreshToken:
|
||||
return s.createTokenFromRefreshToken(ctx, input)
|
||||
case GrantTypeDeviceCode:
|
||||
return s.createTokenFromDeviceCode(ctx, input)
|
||||
default:
|
||||
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
|
||||
return CreatedTokens{}, &common.OidcGrantTypeNotSupportedError{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, clientID string, clientSecret string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
||||
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Get the device authorization from database with explicit query conditions
|
||||
var deviceAuth model.OidcDeviceCode
|
||||
if err := tx.WithContext(ctx).Preload("User").Where("device_code = ? AND client_id = ?", deviceCode, clientID).First(&deviceAuth).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("device_code = ? AND client_id = ?", input.DeviceCode, input.ClientID).
|
||||
First(&deviceAuth).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", "", "", 0, &common.OidcInvalidDeviceCodeError{}
|
||||
return CreatedTokens{}, &common.OidcInvalidDeviceCodeError{}
|
||||
}
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Check if device code has expired
|
||||
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
|
||||
return "", "", "", 0, &common.OidcDeviceCodeExpiredError{}
|
||||
return CreatedTokens{}, &common.OidcDeviceCodeExpiredError{}
|
||||
}
|
||||
|
||||
// Check if device code has been authorized
|
||||
if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil {
|
||||
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
|
||||
return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
|
||||
}
|
||||
|
||||
// Get user claims for the ID token - ensure UserID is not nil
|
||||
if deviceAuth.UserID == nil {
|
||||
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
|
||||
return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
|
||||
}
|
||||
|
||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, clientID, tx)
|
||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, input.ClientID, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Explicitly use the input clientID for the audience claim to ensure consistency
|
||||
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, "")
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
refreshToken, err = s.createRefreshToken(ctx, clientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
|
||||
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(deviceAuth.User, clientID)
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Delete the used device code
|
||||
if err := tx.WithContext(ctx).Delete(&deviceAuth).Error; err != nil {
|
||||
return "", "", "", 0, err
|
||||
err = tx.WithContext(ctx).Delete(&deviceAuth).Error
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return "", "", "", 0, err
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
return idToken, accessToken, refreshToken, 3600, nil
|
||||
return CreatedTokens{
|
||||
IdToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: time.Hour,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
|
||||
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
First(&authorizationCodeMetaData, "code = ?", code).
|
||||
First(&authorizationCodeMetaData, "code = ?", input.Code).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||
return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
|
||||
if client.IsPublic || client.PkceEnabled {
|
||||
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
||||
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
|
||||
if !s.validateCodeVerifier(input.CodeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
||||
return CreatedTokens{}, &common.OidcInvalidCodeVerifierError{}
|
||||
}
|
||||
}
|
||||
|
||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||
if authorizationCodeMetaData.ClientID != input.ClientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||
return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx)
|
||||
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, input.ClientID, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Generate a refresh token
|
||||
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
|
||||
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
err = tx.
|
||||
@@ -307,20 +331,25 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code
|
||||
Delete(&authorizationCodeMetaData).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
return idToken, accessToken, refreshToken, 3600, nil
|
||||
return CreatedTokens{
|
||||
IdToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresIn: time.Hour,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
|
||||
if refreshToken == "" {
|
||||
return "", "", 0, &common.OidcMissingRefreshTokenError{}
|
||||
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||
if input.RefreshToken == "" {
|
||||
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
@@ -328,9 +357,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Verify refresh token
|
||||
@@ -338,31 +367,31 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
return "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Verify that the refresh token belongs to the provided client
|
||||
if storedRefreshToken.ClientID != clientID {
|
||||
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
||||
if storedRefreshToken.ClientID != input.ClientID {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Generate a new refresh token and invalidate the old one
|
||||
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
||||
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Delete the used refresh token
|
||||
@@ -371,15 +400,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
|
||||
Delete(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
return accessToken, newRefreshToken, 3600, nil
|
||||
return CreatedTokens{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
ExpiresIn: time.Hour,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
@@ -1181,9 +1214,12 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
// First, ensure we have a valid client ID
|
||||
if clientID == "" {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
// Load the OIDC client's configuration
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
@@ -1193,10 +1229,16 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clien
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
if !client.IsPublic {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
|
||||
// If we have a client secret, we validate it
|
||||
// Otherwise, we require the client to be public
|
||||
if clientSecret != "" {
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
return client, nil
|
||||
} else if !client.IsPublic {
|
||||
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
|
||||
Reference in New Issue
Block a user