2024-08-17 21:57:14 +02:00
package service
import (
2025-04-06 06:04:08 -07:00
"context"
2024-11-15 15:00:25 +01:00
"crypto/sha256"
2025-06-06 03:23:51 -07:00
"crypto/tls"
2024-11-15 15:00:25 +01:00
"encoding/base64"
2025-02-05 18:28:21 +01:00
"encoding/json"
2024-08-17 21:57:14 +02:00
"errors"
"fmt"
2025-04-25 12:14:51 -05:00
"log"
2025-06-06 03:23:51 -07:00
"log/slog"
2024-08-17 21:57:14 +02:00
"mime/multipart"
2025-06-06 03:23:51 -07:00
"net/http"
2024-08-17 21:57:14 +02:00
"os"
2025-01-20 11:19:23 +01:00
"regexp"
2025-04-06 06:04:08 -07:00
"slices"
2024-08-19 18:48:18 +02:00
"strings"
2024-08-17 21:57:14 +02:00
"time"
2025-02-05 18:08:01 +01:00
2025-06-06 03:23:51 -07:00
"github.com/lestrrat-go/httprc/v3"
"github.com/lestrrat-go/httprc/v3/errsink"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
2025-04-09 09:18:03 +02:00
"github.com/lestrrat-go/jwx/v3/jwt"
2025-05-19 08:10:33 -07:00
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
2025-04-09 09:18:03 +02:00
2025-02-05 18:08:01 +01:00
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
2025-05-19 08:10:33 -07:00
)
// Identifier strings used by the OIDC token handling in this package.
const (
	// GrantTypeAuthorizationCode selects the authorization code grant in CreateTokens.
	GrantTypeAuthorizationCode = "authorization_code"
	// GrantTypeRefreshToken selects the refresh token grant in CreateTokens.
	GrantTypeRefreshToken = "refresh_token"
	// GrantTypeDeviceCode selects the device code grant in CreateTokens.
	GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"

	// ClientAssertionTypeJWTBearer is the identifier of the JWT bearer client assertion type.
	ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
)
type OidcService struct {
2024-10-28 18:11:54 +01:00
db * gorm . DB
jwtService * JwtService
appConfigService * AppConfigService
auditLogService * AuditLogService
customClaimService * CustomClaimService
2025-06-06 03:23:51 -07:00
httpClient * http . Client
jwkCache * jwk . Cache
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
func NewOidcService (
ctx context . Context ,
db * gorm . DB ,
jwtService * JwtService ,
appConfigService * AppConfigService ,
auditLogService * AuditLogService ,
customClaimService * CustomClaimService ,
) ( s * OidcService , err error ) {
s = & OidcService {
2024-10-28 18:11:54 +01:00
db : db ,
jwtService : jwtService ,
appConfigService : appConfigService ,
auditLogService : auditLogService ,
customClaimService : customClaimService ,
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
s . jwkCache , err = s . getJWKCache ( ctx )
if err != nil {
return nil , err
}
return s , nil
}
func ( s * OidcService ) getJWKCache ( ctx context . Context ) ( * jwk . Cache , error ) {
// We need to create a custom HTTP client to set a timeout.
client := s . httpClient
if client == nil {
client = & http . Client {
Timeout : 20 * time . Second ,
}
defaultTransport , ok := http . DefaultTransport . ( * http . Transport )
if ! ok {
// Indicates a development-time error
panic ( "Default transport is not of type *http.Transport" )
}
transport := defaultTransport . Clone ( )
transport . TLSClientConfig . MinVersion = tls . VersionTLS12
client . Transport = transport
}
// Create the JWKS cache
return jwk . NewCache ( ctx ,
httprc . NewClient (
httprc . WithErrorSink ( errsink . NewSlog ( slog . Default ( ) ) ) ,
httprc . WithHTTPClient ( client ) ,
) ,
)
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) Authorize ( ctx context . Context , input dto . AuthorizeOidcClientRequestDto , userID , ipAddress , userAgent string ) ( string , string , error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-02-03 18:41:15 +01:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
Preload ( "AllowedUserGroups" ) .
First ( & client , "id = ?" , input . ClientID ) .
Error
if err != nil {
2025-02-03 18:41:15 +01:00
return "" , "" , err
2024-11-15 15:00:25 +01:00
}
2025-02-03 18:41:15 +01:00
// If the client is not public, the code challenge must be provided
if client . IsPublic && input . CodeChallenge == "" {
return "" , "" , & common . OidcMissingCodeChallengeError { }
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
// Get the callback URL of the client. Return an error if the provided callback URL is not allowed
2025-05-29 13:01:23 -07:00
callbackURL , err := s . getCallbackURL ( & client , input . CallbackURL , tx , ctx )
2024-08-24 00:49:08 +02:00
if err != nil {
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
// Check if the user group is allowed to authorize the client
var user model . User
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Preload ( "UserGroups" ) .
First ( & user , "id = ?" , userID ) .
Error
if err != nil {
2024-09-09 10:29:41 +02:00
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
if ! s . IsUserGroupAllowedToAuthorize ( user , client ) {
return "" , "" , & common . OidcAccessDeniedError { }
}
2024-08-17 21:57:14 +02:00
2025-02-03 18:41:15 +01:00
// Check if the user has already authorized the client with the given scope
2025-04-06 06:04:08 -07:00
hasAuthorizedClient , err := s . hasAuthorizedClientInternal ( ctx , input . ClientID , userID , input . Scope , tx )
2025-02-03 18:41:15 +01:00
if err != nil {
2024-08-24 00:49:08 +02:00
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
// If the user has not authorized the client, create a new authorization in the database
if ! hasAuthorizedClient {
2025-04-28 09:29:18 +02:00
err := s . createAuthorizedClientInternal ( ctx , userID , input . ClientID , input . Scope , tx )
if err != nil {
2025-04-06 06:04:08 -07:00
return "" , "" , err
2025-02-03 18:41:15 +01:00
}
2024-11-15 15:00:25 +01:00
}
2025-02-03 18:41:15 +01:00
// Create the authorization code
2025-04-06 06:04:08 -07:00
code , err := s . createAuthorizationCode ( ctx , input . ClientID , userID , input . Scope , input . Nonce , input . CodeChallenge , input . CodeChallengeMethod , tx )
2024-08-24 00:49:08 +02:00
if err != nil {
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
// Log the authorization event
if hasAuthorizedClient {
2025-04-06 06:04:08 -07:00
s . auditLogService . Create ( ctx , model . AuditLogEventClientAuthorization , ipAddress , userAgent , userID , model . AuditLogData { "clientName" : client . Name } , tx )
2025-02-03 18:41:15 +01:00
} else {
2025-04-06 06:04:08 -07:00
s . auditLogService . Create ( ctx , model . AuditLogEventNewClientAuthorization , ipAddress , userAgent , userID , model . AuditLogData { "clientName" : client . Name } , tx )
}
2025-02-03 18:41:15 +01:00
2025-04-06 06:04:08 -07:00
err = tx . Commit ( ) . Error
if err != nil {
return "" , "" , err
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
return code , callbackURL , nil
}
// HasAuthorizedClient checks if the user has already authorized the client with the given scope
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) HasAuthorizedClient ( ctx context . Context , clientID , userID , scope string ) ( bool , error ) {
return s . hasAuthorizedClientInternal ( ctx , clientID , userID , scope , s . db )
}
func ( s * OidcService ) hasAuthorizedClientInternal ( ctx context . Context , clientID , userID , scope string , tx * gorm . DB ) ( bool , error ) {
2025-02-03 18:41:15 +01:00
var userAuthorizedOidcClient model . UserAuthorizedOidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
First ( & userAuthorizedOidcClient , "client_id = ? AND user_id = ?" , clientID , userID ) .
Error
if err != nil {
2025-02-03 18:41:15 +01:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
return false , nil
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
return false , err
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
if userAuthorizedOidcClient . Scope != scope {
return false , nil
2024-09-09 10:29:41 +02:00
}
2025-02-03 18:41:15 +01:00
return true , nil
}
2024-09-09 10:29:41 +02:00
2025-02-03 18:41:15 +01:00
// IsUserGroupAllowedToAuthorize checks if the user group of the user is allowed to authorize the client
func ( s * OidcService ) IsUserGroupAllowedToAuthorize ( user model . User , client model . OidcClient ) bool {
if len ( client . AllowedUserGroups ) == 0 {
return true
}
isAllowedToAuthorize := false
for _ , userGroup := range client . AllowedUserGroups {
for _ , userGroupUser := range user . UserGroups {
if userGroup . ID == userGroupUser . ID {
isAllowedToAuthorize = true
break
}
}
}
return isAllowedToAuthorize
2024-08-17 21:57:14 +02:00
}
2025-05-19 08:10:33 -07:00
// CreatedTokens is the result of a token request handled by CreateTokens.
type CreatedTokens struct {
	// IdToken is the generated ID token; the refresh token grant leaves it empty.
	IdToken string
	// AccessToken is the generated OAuth access token.
	AccessToken string
	// RefreshToken is the newly issued refresh token.
	RefreshToken string
	// ExpiresIn is the lifetime reported alongside the tokens.
	ExpiresIn time.Duration
}
func ( s * OidcService ) CreateTokens ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
2025-04-25 12:14:51 -05:00
switch input . GrantType {
2025-05-19 08:10:33 -07:00
case GrantTypeAuthorizationCode :
return s . createTokenFromAuthorizationCode ( ctx , input )
case GrantTypeRefreshToken :
return s . createTokenFromRefreshToken ( ctx , input )
case GrantTypeDeviceCode :
return s . createTokenFromDeviceCode ( ctx , input )
2025-03-25 07:36:53 -07:00
default :
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcGrantTypeNotSupportedError { }
2025-03-25 07:36:53 -07:00
}
}
2024-08-17 21:57:14 +02:00
2025-05-19 08:10:33 -07:00
func ( s * OidcService ) createTokenFromDeviceCode ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
2025-04-06 06:04:08 -07:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-06-06 03:23:51 -07:00
_ , err := s . verifyClientCredentialsInternal ( ctx , tx , input )
2025-04-06 06:04:08 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-04-25 12:14:51 -05:00
// Get the device authorization from database with explicit query conditions
var deviceAuth model . OidcDeviceCode
2025-05-19 08:10:33 -07:00
err = tx .
WithContext ( ctx ) .
Preload ( "User" ) .
Where ( "device_code = ? AND client_id = ?" , input . DeviceCode , input . ClientID ) .
First ( & deviceAuth ) .
Error
if err != nil {
2025-04-25 12:14:51 -05:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcInvalidDeviceCodeError { }
2024-11-15 15:00:25 +01:00
}
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2024-11-15 15:00:25 +01:00
2025-04-25 12:14:51 -05:00
// Check if device code has expired
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcDeviceCodeExpiredError { }
2025-04-25 12:14:51 -05:00
}
// Check if device code has been authorized
if ! deviceAuth . IsAuthorized || deviceAuth . UserID == nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcAuthorizationPendingError { }
2025-04-25 12:14:51 -05:00
}
// Get user claims for the ID token - ensure UserID is not nil
if deviceAuth . UserID == nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcAuthorizationPendingError { }
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
userClaims , err := s . getUserClaimsForClientInternal ( ctx , * deviceAuth . UserID , input . ClientID , tx )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
// Explicitly use the input clientID for the audience claim to ensure consistency
2025-05-19 08:10:33 -07:00
idToken , err := s . jwtService . GenerateIDToken ( userClaims , input . ClientID , "" )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
refreshToken , err := s . createRefreshToken ( ctx , input . ClientID , * deviceAuth . UserID , deviceAuth . Scope , tx )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
accessToken , err := s . jwtService . GenerateOauthAccessToken ( deviceAuth . User , input . ClientID )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
// Delete the used device code
2025-05-19 08:10:33 -07:00
err = tx . WithContext ( ctx ) . Delete ( & deviceAuth ) . Error
if err != nil {
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
err = tx . Commit ( ) . Error
if err != nil {
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
return CreatedTokens {
IdToken : idToken ,
AccessToken : accessToken ,
RefreshToken : refreshToken ,
ExpiresIn : time . Hour ,
} , nil
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
func ( s * OidcService ) createTokenFromAuthorizationCode ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
2025-04-25 12:14:51 -05:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-06-06 03:23:51 -07:00
client , err := s . verifyClientCredentialsInternal ( ctx , tx , input )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2024-11-15 15:00:25 +01:00
2025-03-25 07:36:53 -07:00
var authorizationCodeMetaData model . OidcAuthorizationCode
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Preload ( "User" ) .
2025-05-19 08:10:33 -07:00
First ( & authorizationCodeMetaData , "code = ?" , input . Code ) .
2025-04-06 06:04:08 -07:00
Error
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcInvalidAuthorizationCodeError { }
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-03-25 07:36:53 -07:00
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client . IsPublic || client . PkceEnabled {
2025-05-19 08:10:33 -07:00
if ! s . validateCodeVerifier ( input . CodeVerifier , * authorizationCodeMetaData . CodeChallenge , * authorizationCodeMetaData . CodeChallengeMethodSha256 ) {
return CreatedTokens { } , & common . OidcInvalidCodeVerifierError { }
2025-03-23 15:14:26 -05:00
}
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-05-19 08:10:33 -07:00
if authorizationCodeMetaData . ClientID != input . ClientID && authorizationCodeMetaData . ExpiresAt . ToTime ( ) . Before ( time . Now ( ) ) {
return CreatedTokens { } , & common . OidcInvalidAuthorizationCodeError { }
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
userClaims , err := s . getUserClaimsForClientInternal ( ctx , authorizationCodeMetaData . UserID , input . ClientID , tx )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
idToken , err := s . jwtService . GenerateIDToken ( userClaims , input . ClientID , authorizationCodeMetaData . Nonce )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-03-25 07:36:53 -07:00
// Generate a refresh token
2025-05-19 08:10:33 -07:00
refreshToken , err := s . createRefreshToken ( ctx , input . ClientID , authorizationCodeMetaData . UserID , authorizationCodeMetaData . Scope , tx )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2024-08-19 18:48:18 +02:00
}
2025-05-19 08:10:33 -07:00
accessToken , err := s . jwtService . GenerateOauthAccessToken ( authorizationCodeMetaData . User , input . ClientID )
2025-03-27 17:46:10 +01:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-27 17:46:10 +01:00
}
2025-03-23 15:14:26 -05:00
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Delete ( & authorizationCodeMetaData ) .
Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
err = tx . Commit ( ) . Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
return CreatedTokens {
IdToken : idToken ,
AccessToken : accessToken ,
RefreshToken : refreshToken ,
ExpiresIn : time . Hour ,
} , nil
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
func ( s * OidcService ) createTokenFromRefreshToken ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
if input . RefreshToken == "" {
return CreatedTokens { } , & common . OidcMissingRefreshTokenError { }
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-04-06 06:04:08 -07:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-06-06 03:23:51 -07:00
_ , err := s . verifyClientCredentialsInternal ( ctx , tx , input )
2025-04-06 06:04:08 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-03-25 07:36:53 -07:00
// Verify refresh token
var storedRefreshToken model . OidcRefreshToken
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Preload ( "User" ) .
2025-05-19 08:10:33 -07:00
Where ( "token = ? AND expires_at > ?" , utils . CreateSha256Hash ( input . RefreshToken ) , datatype . DateTime ( time . Now ( ) ) ) .
2025-03-25 07:36:53 -07:00
First ( & storedRefreshToken ) .
Error
if err != nil {
if errors . Is ( err , gorm . ErrRecordNotFound ) {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcInvalidRefreshTokenError { }
2025-03-23 15:14:26 -05:00
}
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-03-25 07:36:53 -07:00
// Verify that the refresh token belongs to the provided client
2025-05-19 08:10:33 -07:00
if storedRefreshToken . ClientID != input . ClientID {
return CreatedTokens { } , & common . OidcInvalidRefreshTokenError { }
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-03-25 07:36:53 -07:00
// Generate a new access token
2025-05-19 08:10:33 -07:00
accessToken , err := s . jwtService . GenerateOauthAccessToken ( storedRefreshToken . User , input . ClientID )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-23 15:14:26 -05:00
}
2025-03-25 07:36:53 -07:00
// Generate a new refresh token and invalidate the old one
2025-05-19 08:10:33 -07:00
newRefreshToken , err := s . createRefreshToken ( ctx , input . ClientID , storedRefreshToken . UserID , storedRefreshToken . Scope , tx )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
// Delete the used refresh token
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Delete ( & storedRefreshToken ) .
Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
err = tx . Commit ( ) . Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
2025-03-25 07:36:53 -07:00
2025-05-19 08:10:33 -07:00
return CreatedTokens {
AccessToken : accessToken ,
RefreshToken : newRefreshToken ,
ExpiresIn : time . Hour ,
} , nil
2024-08-17 21:57:14 +02:00
}
2025-04-27 02:32:42 +09:00
func ( s * OidcService ) IntrospectToken ( ctx context . Context , clientID , clientSecret , tokenString string ) ( introspectDto dto . OidcIntrospectionResponseDto , err error ) {
2025-04-09 09:18:03 +02:00
if clientID == "" || clientSecret == "" {
return introspectDto , & common . OidcMissingClientCredentialsError { }
}
2025-06-06 03:23:51 -07:00
_ , err = s . verifyClientCredentialsInternal ( ctx , s . db , dto . OidcCreateTokensDto {
ClientID : clientID ,
ClientSecret : clientSecret ,
} )
2025-04-25 12:14:51 -05:00
if err != nil {
return introspectDto , err
2025-04-09 09:18:03 +02:00
}
token , err := s . jwtService . VerifyOauthAccessToken ( tokenString )
if err != nil {
if errors . Is ( err , jwt . ParseError ( ) ) {
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
2025-04-27 02:32:42 +09:00
return s . introspectRefreshToken ( ctx , tokenString )
2025-04-09 09:18:03 +02:00
}
// Every failure we get means the token is invalid. Nothing more to do with the error.
introspectDto . Active = false
return introspectDto , nil
}
introspectDto . Active = true
introspectDto . TokenType = "access_token"
if token . Has ( "scope" ) {
2025-06-06 03:23:51 -07:00
var (
asString string
asStrings [ ] string
)
2025-04-09 09:18:03 +02:00
if err := token . Get ( "scope" , & asString ) ; err == nil {
introspectDto . Scope = asString
} else if err := token . Get ( "scope" , & asStrings ) ; err == nil {
introspectDto . Scope = strings . Join ( asStrings , " " )
}
}
2025-06-06 03:23:51 -07:00
if expiration , ok := token . Expiration ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Expiration = expiration . Unix ( )
}
2025-06-06 03:23:51 -07:00
if issuedAt , ok := token . IssuedAt ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . IssuedAt = issuedAt . Unix ( )
}
2025-06-06 03:23:51 -07:00
if notBefore , ok := token . NotBefore ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . NotBefore = notBefore . Unix ( )
}
2025-06-06 03:23:51 -07:00
if subject , ok := token . Subject ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Subject = subject
}
2025-06-06 03:23:51 -07:00
if audience , ok := token . Audience ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Audience = audience
}
2025-06-06 03:23:51 -07:00
if issuer , ok := token . Issuer ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Issuer = issuer
}
2025-06-06 03:23:51 -07:00
if identifier , ok := token . JwtID ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Identifier = identifier
}
return introspectDto , nil
}
2025-04-27 02:32:42 +09:00
func ( s * OidcService ) introspectRefreshToken ( ctx context . Context , refreshToken string ) ( introspectDto dto . OidcIntrospectionResponseDto , err error ) {
2025-04-09 09:18:03 +02:00
var storedRefreshToken model . OidcRefreshToken
2025-04-27 02:32:42 +09:00
err = s . db .
WithContext ( ctx ) .
Preload ( "User" ) .
2025-04-09 09:18:03 +02:00
Where ( "token = ? AND expires_at > ?" , utils . CreateSha256Hash ( refreshToken ) , datatype . DateTime ( time . Now ( ) ) ) .
First ( & storedRefreshToken ) .
Error
if err != nil {
if errors . Is ( err , gorm . ErrRecordNotFound ) {
introspectDto . Active = false
return introspectDto , nil
}
return introspectDto , err
}
introspectDto . Active = true
introspectDto . TokenType = "refresh_token"
return introspectDto , nil
}
2025-04-06 06:04:08 -07:00
// GetClient loads the OIDC client with the given ID using the service's
// database handle, outside of any transaction.
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
	client, err := s.getClientInternal(ctx, clientID, s.db)
	return client, err
}
func ( s * OidcService ) getClientInternal ( ctx context . Context , clientID string , tx * gorm . DB ) ( model . OidcClient , error ) {
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
Preload ( "CreatedBy" ) .
Preload ( "AllowedUserGroups" ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-23 17:04:19 +02:00
return model . OidcClient { } , err
2024-08-17 21:57:14 +02:00
}
2024-08-23 17:04:19 +02:00
return client , nil
2024-08-17 21:57:14 +02:00
}
2025-05-25 14:22:25 -05:00
func ( s * OidcService ) ListClients ( ctx context . Context , name string , sortedPaginationRequest utils . SortedPaginationRequest ) ( [ ] model . OidcClient , utils . PaginationResponse , error ) {
2024-08-17 21:57:14 +02:00
var clients [ ] model . OidcClient
2025-04-06 06:04:08 -07:00
query := s . db .
WithContext ( ctx ) .
Preload ( "CreatedBy" ) .
Model ( & model . OidcClient { } )
2025-05-25 14:22:25 -05:00
if name != "" {
query = query . Where ( "name LIKE ?" , "%" + name + "%" )
2024-08-17 21:57:14 +02:00
}
2025-05-25 14:22:25 -05:00
// As allowedUserGroupsCount is not a column, we need to manually sort it
isValidSortDirection := sortedPaginationRequest . Sort . Direction == "asc" || sortedPaginationRequest . Sort . Direction == "desc"
if sortedPaginationRequest . Sort . Column == "allowedUserGroupsCount" && isValidSortDirection {
query = query . Select ( "oidc_clients.*, COUNT(oidc_clients_allowed_user_groups.oidc_client_id)" ) .
Joins ( "LEFT JOIN oidc_clients_allowed_user_groups ON oidc_clients.id = oidc_clients_allowed_user_groups.oidc_client_id" ) .
Group ( "oidc_clients.id" ) .
Order ( "COUNT(oidc_clients_allowed_user_groups.oidc_client_id) " + sortedPaginationRequest . Sort . Direction )
response , err := utils . Paginate ( sortedPaginationRequest . Pagination . Page , sortedPaginationRequest . Pagination . Limit , query , & clients )
return clients , response , err
2024-08-17 21:57:14 +02:00
}
2025-05-25 14:22:25 -05:00
response , err := utils . PaginateAndSort ( sortedPaginationRequest , query , & clients )
return clients , response , err
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) CreateClient ( ctx context . Context , input dto . OidcClientCreateDto , userID string ) ( model . OidcClient , error ) {
2024-08-17 21:57:14 +02:00
client := model . OidcClient {
2025-06-06 03:23:51 -07:00
CreatedByID : userID ,
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
updateOIDCClientModelFromDto ( & client , & input )
2024-08-17 21:57:14 +02:00
2025-04-06 06:04:08 -07:00
err := s . db .
WithContext ( ctx ) .
Create ( & client ) .
Error
if err != nil {
2024-08-23 17:04:19 +02:00
return model . OidcClient { } , err
2024-08-17 21:57:14 +02:00
}
2024-08-23 17:04:19 +02:00
return client , nil
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) UpdateClient ( ctx context . Context , clientID string , input dto . OidcClientCreateDto ) ( model . OidcClient , error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
Preload ( "CreatedBy" ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-23 17:04:19 +02:00
return model . OidcClient { } , err
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
updateOIDCClientModelFromDto ( & client , & input )
2024-08-17 21:57:14 +02:00
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return model . OidcClient { } , err
}
err = tx . Commit ( ) . Error
if err != nil {
2024-08-23 17:04:19 +02:00
return model . OidcClient { } , err
2024-08-17 21:57:14 +02:00
}
2024-08-23 17:04:19 +02:00
return client , nil
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
// updateOIDCClientModelFromDto copies the editable fields of input onto client.
//
// Every field is overwritten with the value from the DTO. This includes the
// federated identities: an input without any federated identities clears the
// ones stored on the client, so that they can be removed through an update.
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientCreateDto) {
	// Base fields
	client.Name = input.Name
	client.CallbackURLs = input.CallbackURLs
	client.LogoutCallbackURLs = input.LogoutCallbackURLs
	client.IsPublic = input.IsPublic
	// PKCE is required for public clients
	client.PkceEnabled = input.IsPublic || input.PkceEnabled

	// Credentials
	federated := input.Credentials.FederatedIdentities
	if len(federated) == 0 {
		// Keep the slice nil (as on a freshly created client) instead of
		// leaving previously stored identities in place.
		client.Credentials.FederatedIdentities = nil
		return
	}

	identities := make([]model.OidcClientFederatedIdentity, len(federated))
	for i, fi := range federated {
		identities[i] = model.OidcClientFederatedIdentity{
			Issuer:   fi.Issuer,
			Audience: fi.Audience,
			Subject:  fi.Subject,
			JWKS:     fi.JWKS,
		}
	}
	client.Credentials.FederatedIdentities = identities
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) DeleteClient ( ctx context . Context , clientID string ) error {
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := s . db .
WithContext ( ctx ) .
Where ( "id = ?" , clientID ) .
Delete ( & client ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
return nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) CreateClientSecret ( ctx context . Context , clientID string ) ( string , error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , err
}
clientSecret , err := utils . GenerateRandomAlphanumericString ( 32 )
if err != nil {
return "" , err
}
hashedSecret , err := bcrypt . GenerateFromPassword ( [ ] byte ( clientSecret ) , bcrypt . DefaultCost )
if err != nil {
return "" , err
}
client . Secret = string ( hashedSecret )
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return "" , err
}
err = tx . Commit ( ) . Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , err
}
return clientSecret , nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) GetClientLogo ( ctx context . Context , clientID string ) ( string , string , error ) {
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := s . db .
WithContext ( ctx ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , "" , err
}
if client . ImageType == nil {
return "" , "" , errors . New ( "image not found" )
}
2025-04-06 06:04:08 -07:00
imagePath := common . EnvConfig . UploadPath + "/oidc-client-images/" + client . ID + "." + * client . ImageType
mimeType := utils . GetImageMimeType ( * client . ImageType )
2024-08-17 21:57:14 +02:00
return imagePath , mimeType , nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) UpdateClientLogo ( ctx context . Context , clientID string , file * multipart . FileHeader ) error {
2024-08-17 21:57:14 +02:00
fileType := utils . GetFileExtension ( file . Filename )
if mimeType := utils . GetImageMimeType ( fileType ) ; mimeType == "" {
2024-10-28 18:11:54 +01:00
return & common . FileTypeNotSupportedError { }
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
imagePath := common . EnvConfig . UploadPath + "/oidc-client-images/" + clientID + "." + fileType
err := utils . SaveFile ( file , imagePath )
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
2025-04-06 06:04:08 -07:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
if client . ImageType != nil && fileType != * client . ImageType {
oldImagePath := fmt . Sprintf ( "%s/oidc-client-images/%s.%s" , common . EnvConfig . UploadPath , client . ID , * client . ImageType )
if err := os . Remove ( oldImagePath ) ; err != nil {
return err
}
}
client . ImageType = & fileType
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return err
}
err = tx . Commit ( ) . Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
return nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) DeleteClientLogo ( ctx context . Context , clientID string ) error {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
if client . ImageType == nil {
return errors . New ( "image not found" )
}
2025-06-06 08:50:33 +02:00
oldImageType := * client . ImageType
2025-04-06 06:04:08 -07:00
client . ImageType = nil
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return err
}
2025-06-06 08:50:33 +02:00
imagePath := common . EnvConfig . UploadPath + "/oidc-client-images/" + client . ID + "." + oldImageType
2024-08-17 21:57:14 +02:00
if err := os . Remove ( imagePath ) ; err != nil {
return err
}
2025-04-06 06:04:08 -07:00
err = tx . Commit ( ) . Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
return nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) UpdateAllowedUserGroups ( ctx context . Context , id string , input dto . OidcUpdateAllowedUserGroupsDto ) ( client model . OidcClient , err error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
client , err = s . getClientInternal ( ctx , id , tx )
2025-02-03 18:41:15 +01:00
if err != nil {
return model . OidcClient { } , err
}
// Fetch the user groups based on UserGroupIDs in input
var groups [ ] model . UserGroup
if len ( input . UserGroupIDs ) > 0 {
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Where ( "id IN (?)" , input . UserGroupIDs ) .
Find ( & groups ) .
Error
if err != nil {
2025-02-03 18:41:15 +01:00
return model . OidcClient { } , err
}
}
// Replace the current user groups with the new set of user groups
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Model ( & client ) .
Association ( "AllowedUserGroups" ) .
Replace ( groups )
if err != nil {
2025-02-03 18:41:15 +01:00
return model . OidcClient { } , err
}
// Save the updated client
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return model . OidcClient { } , err
}
err = tx . Commit ( ) . Error
if err != nil {
2025-02-03 18:41:15 +01:00
return model . OidcClient { } , err
}
return client , nil
}
2025-02-14 17:09:27 +01:00
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) ValidateEndSession ( ctx context . Context , input dto . OidcLogoutDto , userID string ) ( string , error ) {
2025-02-14 17:09:27 +01:00
// If no ID token hint is provided, return an error
if input . IdTokenHint == "" {
return "" , & common . TokenInvalidError { }
}
// If the ID token hint is provided, verify the ID token
2025-03-27 10:20:39 -07:00
// Here we also accept expired ID tokens, which are fine per spec
token , err := s . jwtService . VerifyIdToken ( input . IdTokenHint , true )
2025-02-14 17:09:27 +01:00
if err != nil {
return "" , & common . TokenInvalidError { }
}
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request
2025-03-27 10:20:39 -07:00
clientID , ok := token . Audience ( )
if ! ok || len ( clientID ) == 0 {
return "" , & common . TokenInvalidError { }
}
if input . ClientId != "" && clientID [ 0 ] != input . ClientId {
2025-02-14 17:09:27 +01:00
return "" , & common . OidcClientIdNotMatchingError { }
}
// Check if the user has authorized the client before
var userAuthorizedOIDCClient model . UserAuthorizedOidcClient
2025-04-06 06:04:08 -07:00
err = s . db .
WithContext ( ctx ) .
Preload ( "Client" ) .
First ( & userAuthorizedOIDCClient , "client_id = ? AND user_id = ?" , clientID [ 0 ] , userID ) .
Error
if err != nil {
2025-02-14 17:09:27 +01:00
return "" , & common . OidcMissingAuthorizationError { }
}
// If the client has no logout callback URLs, return an error
if len ( userAuthorizedOIDCClient . Client . LogoutCallbackURLs ) == 0 {
return "" , & common . OidcNoCallbackURLError { }
}
2025-05-29 13:01:23 -07:00
callbackURL , err := s . getLogoutCallbackURL ( & userAuthorizedOIDCClient . Client , input . PostLogoutRedirectUri )
2025-02-14 17:09:27 +01:00
if err != nil {
return "" , err
}
return callbackURL , nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) createAuthorizationCode ( ctx context . Context , clientID string , userID string , scope string , nonce string , codeChallenge string , codeChallengeMethod string , tx * gorm . DB ) ( string , error ) {
2024-08-17 21:57:14 +02:00
randomString , err := utils . GenerateRandomAlphanumericString ( 32 )
if err != nil {
return "" , err
}
2024-11-15 15:00:25 +01:00
codeChallengeMethodSha256 := strings . ToUpper ( codeChallengeMethod ) == "S256"
2024-08-17 21:57:14 +02:00
oidcAuthorizationCode := model . OidcAuthorizationCode {
2024-11-15 15:00:25 +01:00
ExpiresAt : datatype . DateTime ( time . Now ( ) . Add ( 15 * time . Minute ) ) ,
Code : randomString ,
ClientID : clientID ,
UserID : userID ,
Scope : scope ,
Nonce : nonce ,
CodeChallenge : & codeChallenge ,
CodeChallengeMethodSha256 : & codeChallengeMethodSha256 ,
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Create ( & oidcAuthorizationCode ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , err
}
return randomString , nil
}
2024-08-24 00:49:08 +02:00
2024-11-15 15:00:25 +01:00
func ( s * OidcService ) validateCodeVerifier ( codeVerifier , codeChallenge string , codeChallengeMethodSha256 bool ) bool {
2025-01-03 16:15:10 +01:00
if codeVerifier == "" || codeChallenge == "" {
return false
}
2024-11-15 15:00:25 +01:00
if ! codeChallengeMethodSha256 {
return codeVerifier == codeChallenge
}
// Compute SHA-256 hash of the codeVerifier
h := sha256 . New ( )
h . Write ( [ ] byte ( codeVerifier ) )
codeVerifierHash := h . Sum ( nil )
// Base64 URL encode the verifier hash
encodedVerifierHash := base64 . RawURLEncoding . EncodeToString ( codeVerifierHash )
return encodedVerifierHash == codeChallenge
}
2025-05-29 13:01:23 -07:00
func ( s * OidcService ) getCallbackURL ( client * model . OidcClient , inputCallbackURL string , tx * gorm . DB , ctx context . Context ) ( callbackURL string , err error ) {
2025-05-29 10:16:10 -05:00
// If no input callback URL provided, use the first configured URL
2024-08-24 00:49:08 +02:00
if inputCallbackURL == "" {
2025-05-29 13:01:23 -07:00
if len ( client . CallbackURLs ) > 0 {
return client . CallbackURLs [ 0 ] , nil
2025-05-29 10:16:10 -05:00
}
// If no URLs are configured and no input URL, this is an error
return "" , & common . OidcMissingCallbackURLError { }
2024-08-24 00:49:08 +02:00
}
2025-01-20 11:19:23 +01:00
2025-05-29 10:16:10 -05:00
// If URLs are already configured, validate against them
2025-05-29 13:01:23 -07:00
if len ( client . CallbackURLs ) > 0 {
matched , err := s . getCallbackURLFromList ( client . CallbackURLs , inputCallbackURL )
if err != nil {
return "" , err
} else if matched == "" {
return "" , & common . OidcInvalidCallbackURLError { }
2025-01-20 11:19:23 +01:00
}
2025-05-29 13:01:23 -07:00
return matched , nil
2024-08-24 00:49:08 +02:00
}
2025-05-29 10:16:10 -05:00
// If no URLs are configured, trust and store the first URL (TOFU)
2025-05-29 13:01:23 -07:00
err = s . addCallbackURLToClient ( ctx , client , inputCallbackURL , tx )
2025-05-29 10:16:10 -05:00
if err != nil {
return "" , err
}
return inputCallbackURL , nil
}
2025-05-29 13:01:23 -07:00
func ( s * OidcService ) getLogoutCallbackURL ( client * model . OidcClient , inputLogoutCallbackURL string ) ( callbackURL string , err error ) {
if inputLogoutCallbackURL == "" {
return client . LogoutCallbackURLs [ 0 ] , nil
}
matched , err := s . getCallbackURLFromList ( client . LogoutCallbackURLs , inputLogoutCallbackURL )
2025-05-29 10:16:10 -05:00
if err != nil {
2025-05-29 13:01:23 -07:00
return "" , err
} else if matched == "" {
return "" , & common . OidcInvalidCallbackURLError { }
}
return matched , nil
}
func ( s * OidcService ) getCallbackURLFromList ( urls [ ] string , inputCallbackURL string ) ( callbackURL string , err error ) {
for _ , callbackPattern := range urls {
regexPattern := "^" + strings . ReplaceAll ( regexp . QuoteMeta ( callbackPattern ) , ` \* ` , ".*" ) + "$"
matched , err := regexp . MatchString ( regexPattern , inputCallbackURL )
if err != nil {
return "" , err
}
if matched {
return inputCallbackURL , nil
}
2025-05-29 10:16:10 -05:00
}
2025-05-29 13:01:23 -07:00
return "" , nil
}
func ( s * OidcService ) addCallbackURLToClient ( ctx context . Context , client * model . OidcClient , callbackURL string , tx * gorm . DB ) error {
2025-05-29 10:16:10 -05:00
// Add the new callback URL to the existing list
client . CallbackURLs = append ( client . CallbackURLs , callbackURL )
2025-05-29 13:01:23 -07:00
err := tx . WithContext ( ctx ) . Save ( client ) . Error
2025-05-29 10:16:10 -05:00
if err != nil {
return err
}
return nil
2024-08-24 00:49:08 +02:00
}
2025-03-23 15:14:26 -05:00
2025-04-27 02:32:42 +09:00
func ( s * OidcService ) CreateDeviceAuthorization ( ctx context . Context , input dto . OidcDeviceAuthorizationRequestDto ) ( * dto . OidcDeviceAuthorizationResponseDto , error ) {
2025-06-06 03:23:51 -07:00
client , err := s . verifyClientCredentialsInternal ( ctx , s . db , dto . OidcCreateTokensDto {
ClientID : input . ClientID ,
ClientSecret : input . ClientSecret ,
} )
2025-04-25 12:14:51 -05:00
if err != nil {
return nil , err
}
// Generate codes
deviceCode , err := utils . GenerateRandomAlphanumericString ( 32 )
if err != nil {
return nil , err
}
userCode , err := utils . GenerateRandomAlphanumericString ( 8 )
if err != nil {
return nil , err
}
// Create device authorization
deviceAuth := & model . OidcDeviceCode {
DeviceCode : deviceCode ,
UserCode : userCode ,
Scope : input . Scope ,
ExpiresAt : datatype . DateTime ( time . Now ( ) . Add ( 15 * time . Minute ) ) ,
IsAuthorized : false ,
ClientID : client . ID ,
}
if err := s . db . Create ( deviceAuth ) . Error ; err != nil {
return nil , err
}
return & dto . OidcDeviceAuthorizationResponseDto {
DeviceCode : deviceCode ,
UserCode : userCode ,
VerificationURI : common . EnvConfig . AppURL + "/device" ,
VerificationURIComplete : common . EnvConfig . AppURL + "/device?code=" + userCode ,
ExpiresIn : 900 , // 15 minutes
Interval : 5 ,
} , nil
}
func ( s * OidcService ) VerifyDeviceCode ( ctx context . Context , userCode string , userID string , ipAddress string , userAgent string ) error {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
var deviceAuth model . OidcDeviceCode
if err := tx . WithContext ( ctx ) . Preload ( "Client.AllowedUserGroups" ) . First ( & deviceAuth , "user_code = ?" , userCode ) . Error ; err != nil {
log . Printf ( "Error finding device code with user_code %s: %v" , userCode , err )
return err
}
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
return & common . OidcDeviceCodeExpiredError { }
}
// Check if the user group is allowed to authorize the client
var user model . User
if err := tx . WithContext ( ctx ) . Preload ( "UserGroups" ) . First ( & user , "id = ?" , userID ) . Error ; err != nil {
return err
}
if ! s . IsUserGroupAllowedToAuthorize ( user , deviceAuth . Client ) {
return & common . OidcAccessDeniedError { }
}
if err := tx . WithContext ( ctx ) . Preload ( "Client" ) . First ( & deviceAuth , "user_code = ?" , userCode ) . Error ; err != nil {
log . Printf ( "Error finding device code with user_code %s: %v" , userCode , err )
return err
}
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
return & common . OidcDeviceCodeExpiredError { }
}
deviceAuth . UserID = & userID
deviceAuth . IsAuthorized = true
if err := tx . WithContext ( ctx ) . Save ( & deviceAuth ) . Error ; err != nil {
log . Printf ( "Error saving device auth: %v" , err )
return err
}
// Verify the update was successful
var verifiedAuth model . OidcDeviceCode
if err := tx . WithContext ( ctx ) . First ( & verifiedAuth , "device_code = ?" , deviceAuth . DeviceCode ) . Error ; err != nil {
log . Printf ( "Error verifying update: %v" , err )
return err
}
// Create user authorization if needed
hasAuthorizedClient , err := s . hasAuthorizedClientInternal ( ctx , deviceAuth . ClientID , userID , deviceAuth . Scope , tx )
if err != nil {
return err
}
if ! hasAuthorizedClient {
2025-04-28 10:40:43 +02:00
err := s . createAuthorizedClientInternal ( ctx , userID , deviceAuth . ClientID , deviceAuth . Scope , tx )
2025-04-28 09:29:18 +02:00
if err != nil {
return err
2025-04-25 12:14:51 -05:00
}
s . auditLogService . Create ( ctx , model . AuditLogEventNewDeviceCodeAuthorization , ipAddress , userAgent , userID , model . AuditLogData { "clientName" : deviceAuth . Client . Name } , tx )
} else {
s . auditLogService . Create ( ctx , model . AuditLogEventDeviceCodeAuthorization , ipAddress , userAgent , userID , model . AuditLogData { "clientName" : deviceAuth . Client . Name } , tx )
}
return tx . Commit ( ) . Error
}
func ( s * OidcService ) GetDeviceCodeInfo ( ctx context . Context , userCode string , userID string ) ( * dto . DeviceCodeInfoDto , error ) {
var deviceAuth model . OidcDeviceCode
2025-04-27 02:32:42 +09:00
err := s . db .
WithContext ( ctx ) .
Preload ( "Client" ) .
First ( & deviceAuth , "user_code = ?" , userCode ) .
Error
if err != nil {
2025-04-25 12:14:51 -05:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
return nil , & common . OidcInvalidDeviceCodeError { }
}
return nil , err
}
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
return nil , & common . OidcDeviceCodeExpiredError { }
}
// Check if the user has already authorized this client with this scope
hasAuthorizedClient := false
if userID != "" {
var err error
hasAuthorizedClient , err = s . HasAuthorizedClient ( ctx , deviceAuth . ClientID , userID , deviceAuth . Scope )
if err != nil {
return nil , err
}
}
return & dto . DeviceCodeInfoDto {
Client : dto . OidcClientMetaDataDto {
ID : deviceAuth . Client . ID ,
Name : deviceAuth . Client . Name ,
HasLogo : deviceAuth . Client . HasLogo ,
} ,
Scope : deviceAuth . Scope ,
AuthorizationRequired : ! hasAuthorizedClient ,
} , nil
}
2025-05-25 14:22:25 -05:00
func ( s * OidcService ) GetAllowedGroupsCountOfClient ( ctx context . Context , id string ) ( int64 , error ) {
2025-05-25 22:24:20 +02:00
// We only perform select queries here, so we can rollback in all cases
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-05-25 14:22:25 -05:00
var client model . OidcClient
2025-05-25 22:24:20 +02:00
err := tx . WithContext ( ctx ) . Where ( "id = ?" , id ) . First ( & client ) . Error
2025-05-25 14:22:25 -05:00
if err != nil {
return 0 , err
}
2025-05-25 22:24:20 +02:00
count := tx . WithContext ( ctx ) . Model ( & client ) . Association ( "AllowedUserGroups" ) . Count ( )
2025-05-25 14:22:25 -05:00
return count , nil
}
2025-06-04 09:23:44 +02:00
// ListAuthorizedClients returns one page of the authorization records of the given user, each with
// its Client preloaded, together with the pagination metadata produced by utils.PaginateAndSort.
func (s *OidcService) ListAuthorizedClients(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.UserAuthorizedOidcClient, utils.PaginationResponse, error) {
	var results []model.UserAuthorizedOidcClient

	// Base query: all authorizations of this user, with the client attached
	base := s.db.WithContext(ctx).Model(&model.UserAuthorizedOidcClient{})
	base = base.Preload("Client").Where("user_id = ?", userID)

	pagination, err := utils.PaginateAndSort(sortedPaginationRequest, base, &results)

	return results, pagination, err
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) createRefreshToken ( ctx context . Context , clientID string , userID string , scope string , tx * gorm . DB ) ( string , error ) {
2025-03-25 07:36:53 -07:00
refreshToken , err := utils . GenerateRandomAlphanumericString ( 40 )
2025-03-23 15:14:26 -05:00
if err != nil {
return "" , err
}
2025-03-25 07:36:53 -07:00
// Compute the hash of the refresh token to store in the DB
// Refresh tokens are pretty long already, so a "simple" SHA-256 hash is enough
refreshTokenHash := utils . CreateSha256Hash ( refreshToken )
m := model . OidcRefreshToken {
2025-03-23 15:14:26 -05:00
ExpiresAt : datatype . DateTime ( time . Now ( ) . Add ( 30 * 24 * time . Hour ) ) , // 30 days
2025-03-25 07:36:53 -07:00
Token : refreshTokenHash ,
2025-03-23 15:14:26 -05:00
ClientID : clientID ,
UserID : userID ,
Scope : scope ,
}
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Create ( & m ) .
Error
if err != nil {
2025-03-23 15:14:26 -05:00
return "" , err
}
2025-03-25 07:36:53 -07:00
return refreshToken , nil
2025-03-23 15:14:26 -05:00
}
2025-04-25 12:14:51 -05:00
2025-04-28 09:29:18 +02:00
// createAuthorizedClientInternal records through tx that the user authorized the client for the given
// scope. If a record for the same (user_id, client_id) pair already exists, only its scope column is
// updated.
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
	// Upsert rule: on a (user_id, client_id) conflict, overwrite the scope
	upsertScope := clause.OnConflict{
		Columns:   []clause.Column{{Name: "user_id"}, {Name: "client_id"}},
		DoUpdates: clause.AssignmentColumns([]string{"scope"}),
	}

	record := model.UserAuthorizedOidcClient{
		UserID:   userID,
		ClientID: clientID,
		Scope:    scope,
	}

	return tx.WithContext(ctx).Clauses(upsertScope).Create(&record).Error
}
2025-06-06 03:23:51 -07:00
func ( s * OidcService ) verifyClientCredentialsInternal ( ctx context . Context , tx * gorm . DB , input dto . OidcCreateTokensDto ) ( * model . OidcClient , error ) {
2025-05-19 08:10:33 -07:00
// First, ensure we have a valid client ID
2025-06-06 03:23:51 -07:00
if input . ClientID == "" {
return nil , & common . OidcMissingClientCredentialsError { }
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
// Load the OIDC client's configuration
2025-04-25 12:14:51 -05:00
var client model . OidcClient
err := tx .
WithContext ( ctx ) .
2025-06-06 03:23:51 -07:00
First ( & client , "id = ?" , input . ClientID ) .
2025-04-25 12:14:51 -05:00
Error
if err != nil {
2025-06-06 03:23:51 -07:00
return nil , err
2025-04-25 12:14:51 -05:00
}
2025-06-06 03:23:51 -07:00
// We have 3 options
// If credentials are provided, we validate them; otherwise, we can continue without credentials for public clients only
switch {
// First, if we have a client secret, we validate it
case input . ClientSecret != "" :
err = bcrypt . CompareHashAndPassword ( [ ] byte ( client . Secret ) , [ ] byte ( input . ClientSecret ) )
2025-05-19 08:10:33 -07:00
if err != nil {
2025-06-06 03:23:51 -07:00
return nil , & common . OidcClientSecretInvalidError { }
2025-04-25 12:14:51 -05:00
}
2025-06-06 03:23:51 -07:00
return & client , nil
// Next, check if we want to use client assertions from federated identities
case input . ClientAssertionType == ClientAssertionTypeJWTBearer && input . ClientAssertion != "" :
err = s . verifyClientAssertionFromFederatedIdentities ( ctx , & client , input )
if err != nil {
log . Printf ( "Invalid assertion for client '%s': %v" , client . ID , err )
return nil , & common . OidcClientAssertionInvalidError { }
}
return & client , nil
// There's no credentials
// This is allowed only if the client is public
case client . IsPublic :
return & client , nil
// If we're here, we have no credentials AND the client is not public, so credentials are required
default :
return nil , & common . OidcMissingClientCredentialsError { }
2025-04-25 12:14:51 -05:00
}
2025-06-06 03:23:51 -07:00
}
2025-04-25 12:14:51 -05:00
2025-06-06 03:23:51 -07:00
// jwkSetForURL returns the JWK set served at url from the service's JWK cache, registering the URL
// with the cache first if that has not happened yet.
//
// Registration waits until the set is ready, bounded by a 15-second timeout, and configures refresh
// intervals between 15 minutes and 24 hours.
func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
	alreadyRegistered := s.jwkCache.IsRegistered(ctx, url)
	if !alreadyRegistered {
		// Without a timeout Register would keep retrying for as long as errors occur
		timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
		defer cancel()

		err = s.jwkCache.Register(
			timeoutCtx,
			url,
			jwk.WithMaxInterval(24*time.Hour),
			jwk.WithMinInterval(15*time.Minute),
			jwk.WithWaitReady(true),
		)

		// Two goroutines may race to register the same URL; the loser gets an "already exists"
		// conflict, which is harmless and therefore ignored
		registrationFailed := err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists())
		if registrationFailed {
			return nil, fmt.Errorf("failed to register JWK set: %w", err)
		}
	}

	set, err = s.jwkCache.CachedSet(url)
	if err != nil {
		return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
	}

	return set, nil
}
func ( s * OidcService ) verifyClientAssertionFromFederatedIdentities ( ctx context . Context , client * model . OidcClient , input dto . OidcCreateTokensDto ) error {
// First, parse the assertion JWT, without validating it, to check the issuer
assertion := [ ] byte ( input . ClientAssertion )
insecureToken , err := jwt . ParseInsecure ( assertion )
if err != nil {
return fmt . Errorf ( "failed to parse client assertion JWT: %w" , err )
}
issuer , _ := insecureToken . Issuer ( )
if issuer == "" {
return errors . New ( "client assertion does not contain an issuer claim" )
}
// Ensure that this client is federated with the one that issued the token
ocfi , ok := client . Credentials . FederatedIdentityForIssuer ( issuer )
if ! ok {
return fmt . Errorf ( "client assertion is not from an allowed issuer: %s" , issuer )
}
// Get the JWK set for the issuer
jwksURL := ocfi . JWKS
if jwksURL == "" {
// Default URL is from the issuer
if strings . HasSuffix ( issuer , "/" ) {
jwksURL = issuer + ".well-known/jwks.json"
} else {
jwksURL = issuer + "/.well-known/jwks.json"
}
}
jwks , err := s . jwkSetForURL ( ctx , jwksURL )
if err != nil {
return fmt . Errorf ( "failed to get JWK set for issuer '%s': %w" , issuer , err )
}
// Set default audience and subject if missing
audience := ocfi . Audience
if audience == "" {
// Default to the Pocket ID's URL
audience = common . EnvConfig . AppURL
}
subject := ocfi . Subject
if subject == "" {
// Default to the client ID, per RFC 7523
subject = client . ID
}
// Now re-parse the token with proper validation
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
_ , err = jwt . Parse ( assertion ,
jwt . WithValidate ( true ) ,
jwt . WithAcceptableSkew ( clockSkew ) ,
jwt . WithKeySet ( jwks , jws . WithInferAlgorithmFromKey ( true ) , jws . WithUseDefault ( true ) ) ,
jwt . WithAudience ( audience ) ,
jwt . WithSubject ( subject ) ,
)
if err != nil {
return fmt . Errorf ( "client assertion is not valid: %w" , err )
}
// If we're here, the assertion is valid
return nil
2025-04-25 12:14:51 -05:00
}
2025-06-09 10:46:03 -05:00
// GetClientPreview builds a preview of what the given client would receive for the given user and
// scopes: the claims of an ID token, the claims of an access token, and the user info claims.
//
// The user must be allowed to authorize the client according to IsUserGroupAllowedToAuthorize,
// otherwise *common.OidcAccessDeniedError is returned. The claims are computed from an in-memory
// authorization record that is not stored.
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
	tx := s.db.Begin()
	defer func() {
		tx.Rollback()
	}()

	client, err := s.getClientInternal(ctx, clientID, tx)
	if err != nil {
		return nil, err
	}

	var user model.User
	if err = tx.WithContext(ctx).Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
		return nil, err
	}

	if !s.IsUserGroupAllowedToAuthorize(user, client) {
		return nil, &common.OidcAccessDeniedError{}
	}

	// Stand-in authorization used only to compute the claims
	previewAuthorization := model.UserAuthorizedOidcClient{
		UserID:   userID,
		ClientID: clientID,
		Scope:    scopes,
		User:     user,
	}

	userClaims, err := s.getUserClaimsFromAuthorizedClient(ctx, &previewAuthorization, tx)
	if err != nil {
		return nil, err
	}

	// Build both tokens, then extract their claims for display
	idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
	if err != nil {
		return nil, err
	}

	accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID)
	if err != nil {
		return nil, err
	}

	idTokenClaims, err := utils.GetClaimsFromToken(idToken)
	if err != nil {
		return nil, err
	}

	accessTokenClaims, err := utils.GetClaimsFromToken(accessToken)
	if err != nil {
		return nil, err
	}

	if err = tx.Commit().Error; err != nil {
		return nil, err
	}

	preview := &dto.OidcClientPreviewDto{
		IdToken:     idTokenClaims,
		AccessToken: accessTokenClaims,
		UserInfo:    userClaims,
	}

	return preview, nil
}
// GetUserClaimsForClient returns the claims of the given user for the given client, based on the
// scope stored in the user's authorization record for that client.
//
// All queries run inside one transaction. An error is returned if the user has no authorization
// record for the client or if any query fails.
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
	tx := s.db.Begin()
	defer func() {
		tx.Rollback()
	}()

	// Pass the transaction (not s.db) so that the queries actually run inside it
	claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, tx)
	if err != nil {
		return nil, err
	}

	err = tx.Commit().Error
	if err != nil {
		return nil, err
	}

	return claims, nil
}
// getUserClaimsForClientInternal loads, through tx, the authorization record of the given user for the
// given client (with the user and the user's groups preloaded) and derives the claims from it.
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
	var authorization model.UserAuthorizedOidcClient

	lookup := tx.WithContext(ctx).Preload("User.UserGroups")
	if err := lookup.First(&authorization, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
		return nil, err
	}

	return s.getUserClaimsFromAuthorizedClient(ctx, &authorization, tx)
}
// getUserClaimsFromAuthorizedClient builds the claims for the user of the given authorization record,
// limited by the space-separated scopes stored in that record.
//
//   - "sub" is always present.
//   - "email" adds "email" and "email_verified".
//   - "groups" adds "groups" with the names of the user's groups.
//   - "profile" adds the name, username and picture claims plus the user's custom claims. A custom
//     claim value that parses as JSON is stored as the parsed value, otherwise as a plain string.
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) {
	user := authorizedClient.User
	grantedScopes := strings.Split(authorizedClient.Scope, " ")
	hasScope := func(name string) bool {
		return slices.Contains(grantedScopes, name)
	}

	claims := map[string]interface{}{
		"sub": user.ID,
	}

	if hasScope("email") {
		claims["email"] = user.Email
		claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
	}

	if hasScope("groups") {
		groupNames := make([]string, 0, len(user.UserGroups))
		for _, group := range user.UserGroups {
			groupNames = append(groupNames, group.Name)
		}
		claims["groups"] = groupNames
	}

	profileClaims := map[string]interface{}{
		"given_name":         user.FirstName,
		"family_name":        user.LastName,
		"name":               user.FullName(),
		"preferred_username": user.Username,
		"picture":            common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
	}

	if hasScope("profile") {
		// Standard profile claims
		for name, value := range profileClaims {
			claims[name] = value
		}

		// Custom claims of the user and the user's groups
		customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
		if err != nil {
			return nil, err
		}

		for _, custom := range customClaims {
			// The value of a custom claim can be any JSON value or a plain string
			var parsed interface{}
			if json.Unmarshal([]byte(custom.Value), &parsed) == nil {
				// Valid JSON: store the parsed value
				claims[custom.Key] = parsed
				continue
			}
			// Unmarshaling failed: store the raw string
			claims[custom.Key] = custom.Value
		}
	}

	// Set the email once more after the custom claims were merged, so that the user's own email is
	// what ends up in the "email" claim
	if hasScope("email") {
		claims["email"] = user.Email
	}

	return claims, nil
}