refactor: some clean-up in OIDC service and controller (#550)

This commit is contained in:
Alessandro (Ale) Segala
2025-05-19 08:10:33 -07:00
committed by Elias Schneider
parent 3896b7bb3b
commit b71c84c355
3 changed files with 127 additions and 96 deletions

View File

@@ -136,13 +136,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
} }
// Validate that code is provided for authorization_code grant type // Validate that code is provided for authorization_code grant type
if input.GrantType == "authorization_code" && input.Code == "" { if input.GrantType == service.GrantTypeAuthorizationCode && input.Code == "" {
_ = c.Error(&common.OidcMissingAuthorizationCodeError{}) _ = c.Error(&common.OidcMissingAuthorizationCodeError{})
return return
} }
// Validate that refresh_token is provided for refresh_token grant type // Validate that refresh_token is provided for refresh_token grant type
if input.GrantType == "refresh_token" && input.RefreshToken == "" { if input.GrantType == service.GrantTypeRefreshToken && input.RefreshToken == "" {
_ = c.Error(&common.OidcMissingRefreshTokenError{}) _ = c.Error(&common.OidcMissingRefreshTokenError{})
return return
} }
@@ -152,8 +152,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth() input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
} }
idToken, accessToken, refreshToken, expiresIn, err := tokens, err := oc.oidcService.CreateTokens(c.Request.Context(), input)
oc.oidcService.CreateTokens(c.Request.Context(), input)
switch { switch {
case errors.Is(err, &common.OidcAuthorizationPendingError{}): case errors.Is(err, &common.OidcAuthorizationPendingError{}):
@@ -171,23 +170,13 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
return return
} }
response := dto.OidcTokenResponseDto{ c.JSON(http.StatusOK, dto.OidcTokenResponseDto{
AccessToken: accessToken, AccessToken: tokens.AccessToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: expiresIn, ExpiresIn: int(tokens.ExpiresIn.Seconds()),
} IdToken: tokens.IdToken, // May be empty
RefreshToken: tokens.RefreshToken, // May be empty
// Include ID token only for authorization_code grant })
if idToken != "" {
response.IdToken = idToken
}
// Include refresh token if generated
if refreshToken != "" {
response.RefreshToken = refreshToken
}
c.JSON(http.StatusOK, response)
} }
// userInfoHandler godoc // userInfoHandler godoc

View File

@@ -77,7 +77,7 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
"introspection_endpoint": appUrl + "/api/oidc/introspect", "introspection_endpoint": appUrl + "/api/oidc/introspect",
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize", "device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
"jwks_uri": appUrl + "/.well-known/jwks.json", "jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"}, "grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
"scopes_supported": []string{"openid", "profile", "email", "groups"}, "scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"}, "claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"}, "response_types_supported": []string{"code", "id_token"},

View File

@@ -15,17 +15,22 @@ import (
"strings" "strings"
"time" "time"
"gorm.io/gorm/clause"
"github.com/lestrrat-go/jwx/v3/jwt" "github.com/lestrrat-go/jwx/v3/jwt"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"golang.org/x/crypto/bcrypt" )
"gorm.io/gorm"
const (
GrantTypeAuthorizationCode = "authorization_code"
GrantTypeRefreshToken = "refresh_token"
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
) )
type OidcService struct { type OidcService struct {
@@ -167,139 +172,158 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize return isAllowedToAuthorize
} }
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (idToken string, accessToken string, newRefreshToken string, exp int, err error) { type CreatedTokens struct {
IdToken string
AccessToken string
RefreshToken string
ExpiresIn time.Duration
}
func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
switch input.GrantType { switch input.GrantType {
case "authorization_code": case GrantTypeAuthorizationCode:
return s.createTokenFromAuthorizationCode(ctx, input.Code, input.ClientID, input.ClientSecret, input.CodeVerifier) return s.createTokenFromAuthorizationCode(ctx, input)
case "refresh_token": case GrantTypeRefreshToken:
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, input.RefreshToken, input.ClientID, input.ClientSecret) return s.createTokenFromRefreshToken(ctx, input)
return "", accessToken, newRefreshToken, exp, err case GrantTypeDeviceCode:
case "urn:ietf:params:oauth:grant-type:device_code": return s.createTokenFromDeviceCode(ctx, input)
return s.createTokenFromDeviceCode(ctx, input.DeviceCode, input.ClientID, input.ClientSecret)
default: default:
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{} return CreatedTokens{}, &common.OidcGrantTypeNotSupportedError{}
} }
} }
func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, deviceCode, clientID string, clientSecret string) (idToken string, accessToken string, refreshToken string, exp int, err error) { func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
tx := s.db.Begin() tx := s.db.Begin()
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
// Get the device authorization from database with explicit query conditions // Get the device authorization from database with explicit query conditions
var deviceAuth model.OidcDeviceCode var deviceAuth model.OidcDeviceCode
if err := tx.WithContext(ctx).Preload("User").Where("device_code = ? AND client_id = ?", deviceCode, clientID).First(&deviceAuth).Error; err != nil { err = tx.
WithContext(ctx).
Preload("User").
Where("device_code = ? AND client_id = ?", input.DeviceCode, input.ClientID).
First(&deviceAuth).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", "", 0, &common.OidcInvalidDeviceCodeError{} return CreatedTokens{}, &common.OidcInvalidDeviceCodeError{}
} }
return "", "", "", 0, err return CreatedTokens{}, err
} }
// Check if device code has expired // Check if device code has expired
if time.Now().After(deviceAuth.ExpiresAt.ToTime()) { if time.Now().After(deviceAuth.ExpiresAt.ToTime()) {
return "", "", "", 0, &common.OidcDeviceCodeExpiredError{} return CreatedTokens{}, &common.OidcDeviceCodeExpiredError{}
} }
// Check if device code has been authorized // Check if device code has been authorized
if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil { if !deviceAuth.IsAuthorized || deviceAuth.UserID == nil {
return "", "", "", 0, &common.OidcAuthorizationPendingError{} return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
} }
// Get user claims for the ID token - ensure UserID is not nil // Get user claims for the ID token - ensure UserID is not nil
if deviceAuth.UserID == nil { if deviceAuth.UserID == nil {
return "", "", "", 0, &common.OidcAuthorizationPendingError{} return CreatedTokens{}, &common.OidcAuthorizationPendingError{}
} }
userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, clientID, tx) userClaims, err := s.getUserClaimsForClientInternal(ctx, *deviceAuth.UserID, input.ClientID, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
// Explicitly use the input clientID for the audience claim to ensure consistency // Explicitly use the input clientID for the audience claim to ensure consistency
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, "") idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
refreshToken, err = s.createRefreshToken(ctx, clientID, *deviceAuth.UserID, deviceAuth.Scope, tx) refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
accessToken, err = s.jwtService.GenerateOauthAccessToken(deviceAuth.User, clientID) accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
// Delete the used device code // Delete the used device code
if err := tx.WithContext(ctx).Delete(&deviceAuth).Error; err != nil { err = tx.WithContext(ctx).Delete(&deviceAuth).Error
return "", "", "", 0, err if err != nil {
return CreatedTokens{}, err
} }
if err := tx.Commit().Error; err != nil { err = tx.Commit().Error
return "", "", "", 0, err if err != nil {
return CreatedTokens{}, err
} }
return idToken, accessToken, refreshToken, 3600, nil return CreatedTokens{
IdToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: time.Hour,
}, nil
} }
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) { func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
tx := s.db.Begin() tx := s.db.Begin()
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()
client, err := s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) client, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
var authorizationCodeMetaData model.OidcAuthorizationCode var authorizationCodeMetaData model.OidcAuthorizationCode
err = tx. err = tx.
WithContext(ctx). WithContext(ctx).
Preload("User"). Preload("User").
First(&authorizationCodeMetaData, "code = ?", code). First(&authorizationCodeMetaData, "code = ?", input.Code).
Error Error
if err != nil { if err != nil {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
} }
// If the client is public or PKCE is enabled, the code verifier must match the code challenge // If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client.IsPublic || client.PkceEnabled { if client.IsPublic || client.PkceEnabled {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) { if !s.validateCodeVerifier(input.CodeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{} return CreatedTokens{}, &common.OidcInvalidCodeVerifierError{}
} }
} }
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { if authorizationCodeMetaData.ClientID != input.ClientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
} }
userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx) userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, input.ClientID, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce) idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
// Generate a refresh token // Generate a refresh token
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx) refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID) accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
err = tx. err = tx.
@@ -307,20 +331,25 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code
Delete(&authorizationCodeMetaData). Delete(&authorizationCodeMetaData).
Error Error
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
err = tx.Commit().Error err = tx.Commit().Error
if err != nil { if err != nil {
return "", "", "", 0, err return CreatedTokens{}, err
} }
return idToken, accessToken, refreshToken, 3600, nil return CreatedTokens{
IdToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: time.Hour,
}, nil
} }
func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) { func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
if refreshToken == "" { if input.RefreshToken == "" {
return "", "", 0, &common.OidcMissingRefreshTokenError{} return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
} }
tx := s.db.Begin() tx := s.db.Begin()
@@ -328,9 +357,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
tx.Rollback() tx.Rollback()
}() }()
_, err = s.verifyClientCredentialsInternal(ctx, clientID, clientSecret, tx) _, err := s.verifyClientCredentialsInternal(ctx, input.ClientID, input.ClientSecret, tx)
if err != nil { if err != nil {
return "", "", 0, err return CreatedTokens{}, err
} }
// Verify refresh token // Verify refresh token
@@ -338,31 +367,31 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
err = tx. err = tx.
WithContext(ctx). WithContext(ctx).
Preload("User"). Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())). Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken). First(&storedRefreshToken).
Error Error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", 0, &common.OidcInvalidRefreshTokenError{} return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
} }
return "", "", 0, err return CreatedTokens{}, err
} }
// Verify that the refresh token belongs to the provided client // Verify that the refresh token belongs to the provided client
if storedRefreshToken.ClientID != clientID { if storedRefreshToken.ClientID != input.ClientID {
return "", "", 0, &common.OidcInvalidRefreshTokenError{} return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
} }
// Generate a new access token // Generate a new access token
accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID) accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
if err != nil { if err != nil {
return "", "", 0, err return CreatedTokens{}, err
} }
// Generate a new refresh token and invalidate the old one // Generate a new refresh token and invalidate the old one
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx) newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
if err != nil { if err != nil {
return "", "", 0, err return CreatedTokens{}, err
} }
// Delete the used refresh token // Delete the used refresh token
@@ -371,15 +400,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshTo
Delete(&storedRefreshToken). Delete(&storedRefreshToken).
Error Error
if err != nil { if err != nil {
return "", "", 0, err return CreatedTokens{}, err
} }
err = tx.Commit().Error err = tx.Commit().Error
if err != nil { if err != nil {
return "", "", 0, err return CreatedTokens{}, err
} }
return accessToken, newRefreshToken, 3600, nil return CreatedTokens{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
ExpiresIn: time.Hour,
}, nil
} }
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) { func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
@@ -1181,9 +1214,12 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
} }
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) { func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clientID, clientSecret string, tx *gorm.DB) (model.OidcClient, error) {
// First, ensure we have a valid client ID
if clientID == "" { if clientID == "" {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{} return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
} }
// Load the OIDC client's configuration
var client model.OidcClient var client model.OidcClient
err := tx. err := tx.
WithContext(ctx). WithContext(ctx).
@@ -1193,10 +1229,16 @@ func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, clien
return model.OidcClient{}, err return model.OidcClient{}, err
} }
if !client.IsPublic { // If we have a client secret, we validate it
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil { // Otherwise, we require the client to be public
if clientSecret != "" {
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return model.OidcClient{}, &common.OidcClientSecretInvalidError{} return model.OidcClient{}, &common.OidcClientSecretInvalidError{}
} }
return client, nil
} else if !client.IsPublic {
return model.OidcClient{}, &common.OidcMissingClientCredentialsError{}
} }
return client, nil return client, nil