refactor: some clean-up in OIDC service and controller (#550)

This commit is contained in:
Alessandro (Ale) Segala
2025-05-19 08:10:33 -07:00
committed by Elias Schneider
parent 3896b7bb3b
commit b71c84c355
3 changed files with 127 additions and 96 deletions

View File

@@ -136,13 +136,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
}
// Validate that code is provided for authorization_code grant type
if input.GrantType == "authorization_code" && input.Code == "" {
if input.GrantType == service.GrantTypeAuthorizationCode && input.Code == "" {
_ = c.Error(&common.OidcMissingAuthorizationCodeError{})
return
}
// Validate that refresh_token is provided for refresh_token grant type
if input.GrantType == "refresh_token" && input.RefreshToken == "" {
if input.GrantType == service.GrantTypeRefreshToken && input.RefreshToken == "" {
_ = c.Error(&common.OidcMissingRefreshTokenError{})
return
}
@@ -152,8 +152,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
}
idToken, accessToken, refreshToken, expiresIn, err :=
oc.oidcService.CreateTokens(c.Request.Context(), input)
tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input)
switch {
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
@@ -171,23 +170,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
return
}
response := dto.OidcTokenResponseDto{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: expiresIn,
}
// Include ID token only for authorization_code grant
if idToken != "" {
response.IdToken = idToken
}
// Include refresh token if generated
if refreshToken != "" {
response.RefreshToken = refreshToken
}
c.JSON(http.StatusOK, response)
c.JSON(http.StatusOK, dto.OidcTokenResponseDto{
AccessToken: tokens.AccessToken,
TokenType: "Bearer",
ExpiresIn: int(tokens.ExpiresIn.Seconds()),
IdToken: tokens.IdToken, // May be empty
RefreshToken: tokens.RefreshToken, // May be empty
})
}
// userInfoHandler godoc

View File

@@ -77,7 +77,7 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
"introspection_endpoint": appUrl + "/api/oidc/introspect",
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
"jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"},
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
"scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"},

View File

@@ -15,17 +15,22 @@ import (
"strings"
"time"
"gorm.io/gorm/clause"
"github.com/lestrrat-go/jwx/v3/jwt"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
const (
GrantTypeAuthorizationCode = "authorization_code"
GrantTypeRefreshToken = "refresh_token"
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
)
type OidcService struct {
@@ -167,139 +172,158 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize
}
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
type CreatedTokens struct {
IdToken string
AccessToken string
RefreshToken string
ExpiresIn time.Duration
}
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
switch input.GrantType {
case "authorization_code":
return s.createTokenFromAuthorizationCode(ctx, input.Code, input.ClientID, input.ClientSecret, input.CodeVerifier)
case "refresh_token":
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, input.RefreshToken, input.ClientID, input.ClientSecret)
return "", accessToken, newRefreshToken, exp, err
case "urn:ietf:params:oauth:grant-type:device_code":
return s.createTokenFromDeviceCode(ctx, input.DeviceCode, input.ClientID, input.ClientSecret)
case GrantTypeAuthorizationCode:
return s.createTokenFromAuthorizationCode(ctx, input)
case GrantTypeRefreshToken:
return s.createTokenFromRefreshToken(ctx, input)
case GrantTypeDeviceCode:
return s.createTokenFromDeviceCode(ctx, input)
default:
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
return CreatedTokens{}, &common.OidcGrantTypeNotSupportedError{}
}
}
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, clientID string, clientSecret string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
// Get the device authorization from database with explicit query conditions
var deviceAuth model.OidcDeviceCode
if err := tx.WithContext(ctx).Preload("User").Where("device_code = ? AND client_id = ?", deviceCode, clientID).First(&deviceAuth).Error; err != nil {
err = tx.
WithContext(ctx).
Preload("User").
Where("device_code = ? AND client_id = ?", input.DeviceCode, input.ClientID).
First(&deviceAuth).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", "", 0, &common.OidcInvalidDeviceCodeError{}
return CreatedTokens{}, &common.OidcInvalidDeviceCodeError{}
}
return "", "", "", 0, err
return CreatedTokens{}, err
}
// Check if device code has expired
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
return "", "", "", 0, &common.OidcDeviceCodeExpiredError{}
return CreatedTokens{}, &common.OidcDeviceCodeExpiredError{}
}
// Check if device code has been authorized
if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil {
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
}
// Get user claims for the ID token - ensure UserID is not nil
if deviceAuth.UserID == nil {
return "", "", "", 0, &common.OidcAuthorizationPendingError{}
return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
}
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, clientID, tx)
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, input.ClientID, tx)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
// Explicitly use the input clientID for the audience claim to ensure consistency
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, "")
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
refreshToken, err = s.createRefreshToken(ctx, clientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
accessToken, err = s.jwtService.GenerateOauthAccessToken(deviceAuth.User, clientID)
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
// Delete the used device code
if err := tx.WithContext(ctx).Delete(&deviceAuth).Error; err != nil {
return "", "", "", 0, err
err = tx.WithContext(ctx).Delete(&deviceAuth).Error
if err != nil {
return CreatedTokens{}, err
}
if err := tx.Commit().Error; err != nil {
return "", "", "", 0, err
err = tx.Commit().Error
if err != nil {
return CreatedTokens{}, err
}
return idToken, accessToken, refreshToken, 3600, nil
return CreatedTokens{
IdToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: time.Hour,
}, nil
}
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err = tx.
WithContext(ctx).
Preload("User").
First(&authorizationCodeMetaData, "code = ?", code).
First(&authorizationCodeMetaData, "code = ?", input.Code).
Error
if err != nil {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
}
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client.IsPublic || client.PkceEnabled {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
if !s.validateCodeVerifier(input.CodeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return CreatedTokens{}, &common.OidcInvalidCodeVerifierError{}
}
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
if authorizationCodeMetaData.ClientID != input.ClientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
}
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx)
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, input.ClientID, tx)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
// Generate a refresh token
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
err = tx.
@@ -307,20 +331,25 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code
Delete(&authorizationCodeMetaData).
Error
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
err = tx.Commit().Error
if err != nil {
return "", "", "", 0, err
return CreatedTokens{}, err
}
return idToken, accessToken, refreshToken, 3600, nil
return CreatedTokens{
IdToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: time.Hour,
}, nil
}
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
if refreshToken == "" {
return "", "", 0, &common.OidcMissingRefreshTokenError{}
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
if input.RefreshToken == "" {
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
}
tx := s.db.Begin()
@@ -328,9 +357,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
tx.Rollback()
}()
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx)
_, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
if err != nil {
return "", "", 0, err
return CreatedTokens{}, err
}
// Verify refresh token
@@ -338,31 +367,31 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
err = tx.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
return "", "", 0, err
return CreatedTokens{}, err
}
// Verify that the refresh token belongs to the provided client
if storedRefreshToken.ClientID != clientID {
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
if storedRefreshToken.ClientID != input.ClientID {
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
}
// Generate a new access token
accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
if err != nil {
return "", "", 0, err
return CreatedTokens{}, err
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
if err != nil {
return "", "", 0, err
return CreatedTokens{}, err
}
// Delete the used refresh token
@@ -371,15 +400,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
Delete(&storedRefreshToken).
Error
if err != nil {
return "", "", 0, err
return CreatedTokens{}, err
}
err = tx.Commit().Error
if err != nil {
return "", "", 0, err
return CreatedTokens{}, err
}
return accessToken, newRefreshToken, 3600, nil
return CreatedTokens{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
ExpiresIn: time.Hour,
}, nil
}
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
@@ -1181,9 +1214,12 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
}
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
// First, ensure we have a valid client ID
if clientID == "" {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
}
// Load the OIDC client's configuration
var client model.OidcClient
err := tx.
WithContext(ctx).
@@ -1193,10 +1229,16 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clien
return model.OidcClient{}, err
}
if !client.IsPublic {
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
// If we have a client secret, we validate it
// Otherwise, we require the client to be public
if clientSecret != "" {
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
}
return client, nil
} else if !client.IsPublic {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
}
return client, nil