2024-08-17 21:57:14 +02:00
package service
import (
2025-04-06 06:04:08 -07:00
"context"
2024-11-15 15:00:25 +01:00
"crypto/sha256"
2025-06-06 03:23:51 -07:00
"crypto/tls"
2024-11-15 15:00:25 +01:00
"encoding/base64"
2025-02-05 18:28:21 +01:00
"encoding/json"
2024-08-17 21:57:14 +02:00
"errors"
"fmt"
2025-09-29 10:07:55 -05:00
"io"
2025-06-06 03:23:51 -07:00
"log/slog"
2024-08-17 21:57:14 +02:00
"mime/multipart"
2025-09-29 10:07:55 -05:00
"net"
2025-06-06 03:23:51 -07:00
"net/http"
2025-09-29 10:07:55 -05:00
"net/url"
2024-08-17 21:57:14 +02:00
"os"
2025-01-20 11:19:23 +01:00
"regexp"
2025-04-06 06:04:08 -07:00
"slices"
2024-08-19 18:48:18 +02:00
"strings"
2024-08-17 21:57:14 +02:00
"time"
2025-02-05 18:08:01 +01:00
2025-06-06 03:23:51 -07:00
"github.com/lestrrat-go/httprc/v3"
"github.com/lestrrat-go/httprc/v3/errsink"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
2025-04-09 09:18:03 +02:00
"github.com/lestrrat-go/jwx/v3/jwt"
2025-05-19 08:10:33 -07:00
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
2025-04-09 09:18:03 +02:00
2025-02-05 18:08:01 +01:00
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
2025-05-19 08:10:33 -07:00
)
// Grant type and client assertion type identifiers accepted by the token endpoint.
const (
	// GrantTypeAuthorizationCode selects the authorization code exchange in CreateTokens.
	GrantTypeAuthorizationCode = "authorization_code"
	// GrantTypeRefreshToken selects the refresh token exchange in CreateTokens.
	GrantTypeRefreshToken = "refresh_token"
	// GrantTypeDeviceCode selects the device code exchange in CreateTokens.
	GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
	// GrantTypeClientCredentials selects the client credentials exchange in CreateTokens.
	GrantTypeClientCredentials = "client_credentials"

	// ClientAssertionTypeJWTBearer is the client assertion type identifier for JWT bearer assertions.
	ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
)

// Token and code lifetimes.
const (
	// AccessTokenDuration is reported as ExpiresIn for every issued access token.
	AccessTokenDuration = time.Hour
	// RefreshTokenDuration is the refresh token lifetime.
	RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
	// DeviceCodeDuration is the device code lifetime.
	DeviceCodeDuration = 15 * time.Minute
)
type OidcService struct {
2024-10-28 18:11:54 +01:00
db * gorm . DB
jwtService * JwtService
appConfigService * AppConfigService
auditLogService * AuditLogService
customClaimService * CustomClaimService
2025-08-22 08:56:40 +02:00
webAuthnService * WebAuthnService
2025-06-06 03:23:51 -07:00
httpClient * http . Client
jwkCache * jwk . Cache
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
func NewOidcService (
ctx context . Context ,
db * gorm . DB ,
jwtService * JwtService ,
appConfigService * AppConfigService ,
auditLogService * AuditLogService ,
customClaimService * CustomClaimService ,
2025-08-22 08:56:40 +02:00
webAuthnService * WebAuthnService ,
2025-09-29 10:07:55 -05:00
httpClient * http . Client ,
2025-06-06 03:23:51 -07:00
) ( s * OidcService , err error ) {
s = & OidcService {
2024-10-28 18:11:54 +01:00
db : db ,
jwtService : jwtService ,
appConfigService : appConfigService ,
auditLogService : auditLogService ,
customClaimService : customClaimService ,
2025-08-22 08:56:40 +02:00
webAuthnService : webAuthnService ,
2025-09-29 10:07:55 -05:00
httpClient : httpClient ,
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
s . jwkCache , err = s . getJWKCache ( ctx )
if err != nil {
return nil , err
}
return s , nil
}
func ( s * OidcService ) getJWKCache ( ctx context . Context ) ( * jwk . Cache , error ) {
// We need to create a custom HTTP client to set a timeout.
client := s . httpClient
if client == nil {
client = & http . Client {
Timeout : 20 * time . Second ,
}
defaultTransport , ok := http . DefaultTransport . ( * http . Transport )
if ! ok {
// Indicates a development-time error
panic ( "Default transport is not of type *http.Transport" )
}
transport := defaultTransport . Clone ( )
transport . TLSClientConfig . MinVersion = tls . VersionTLS12
client . Transport = transport
}
// Create the JWKS cache
return jwk . NewCache ( ctx ,
httprc . NewClient (
httprc . WithErrorSink ( errsink . NewSlog ( slog . Default ( ) ) ) ,
httprc . WithHTTPClient ( client ) ,
) ,
)
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) Authorize ( ctx context . Context , input dto . AuthorizeOidcClientRequestDto , userID , ipAddress , userAgent string ) ( string , string , error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-02-03 18:41:15 +01:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
Preload ( "AllowedUserGroups" ) .
First ( & client , "id = ?" , input . ClientID ) .
Error
if err != nil {
2025-02-03 18:41:15 +01:00
return "" , "" , err
2024-11-15 15:00:25 +01:00
}
2025-08-22 08:56:40 +02:00
if client . RequiresReauthentication {
if input . ReauthenticationToken == "" {
return "" , "" , & common . ReauthenticationRequiredError { }
}
err = s . webAuthnService . ConsumeReauthenticationToken ( ctx , tx , input . ReauthenticationToken , userID )
if err != nil {
return "" , "" , err
}
}
2025-02-03 18:41:15 +01:00
// If the client is not public, the code challenge must be provided
if client . IsPublic && input . CodeChallenge == "" {
return "" , "" , & common . OidcMissingCodeChallengeError { }
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
// Get the callback URL of the client. Return an error if the provided callback URL is not allowed
2025-05-29 13:01:23 -07:00
callbackURL , err := s . getCallbackURL ( & client , input . CallbackURL , tx , ctx )
2024-08-24 00:49:08 +02:00
if err != nil {
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
// Check if the user group is allowed to authorize the client
var user model . User
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Preload ( "UserGroups" ) .
First ( & user , "id = ?" , userID ) .
Error
if err != nil {
2024-09-09 10:29:41 +02:00
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
if ! s . IsUserGroupAllowedToAuthorize ( user , client ) {
return "" , "" , & common . OidcAccessDeniedError { }
}
2024-08-17 21:57:14 +02:00
2025-08-10 10:56:03 -05:00
hasAlreadyAuthorizedClient , err := s . createAuthorizedClientInternal ( ctx , userID , input . ClientID , input . Scope , tx )
2025-02-03 18:41:15 +01:00
if err != nil {
2024-08-24 00:49:08 +02:00
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
// Create the authorization code
2025-04-06 06:04:08 -07:00
code , err := s . createAuthorizationCode ( ctx , input . ClientID , userID , input . Scope , input . Nonce , input . CodeChallenge , input . CodeChallengeMethod , tx )
2024-08-24 00:49:08 +02:00
if err != nil {
return "" , "" , err
}
2025-02-03 18:41:15 +01:00
// Log the authorization event
2025-08-10 10:56:03 -05:00
if hasAlreadyAuthorizedClient {
2025-04-06 06:04:08 -07:00
s . auditLogService . Create ( ctx , model . AuditLogEventClientAuthorization , ipAddress , userAgent , userID , model . AuditLogData { "clientName" : client . Name } , tx )
2025-02-03 18:41:15 +01:00
} else {
2025-04-06 06:04:08 -07:00
s . auditLogService . Create ( ctx , model . AuditLogEventNewClientAuthorization , ipAddress , userAgent , userID , model . AuditLogData { "clientName" : client . Name } , tx )
}
2025-02-03 18:41:15 +01:00
2025-04-06 06:04:08 -07:00
err = tx . Commit ( ) . Error
if err != nil {
return "" , "" , err
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
return code , callbackURL , nil
}
// HasAuthorizedClient checks if the user has already authorized the client with the given scope
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) HasAuthorizedClient ( ctx context . Context , clientID , userID , scope string ) ( bool , error ) {
return s . hasAuthorizedClientInternal ( ctx , clientID , userID , scope , s . db )
}
func ( s * OidcService ) hasAuthorizedClientInternal ( ctx context . Context , clientID , userID , scope string , tx * gorm . DB ) ( bool , error ) {
2025-02-03 18:41:15 +01:00
var userAuthorizedOidcClient model . UserAuthorizedOidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
First ( & userAuthorizedOidcClient , "client_id = ? AND user_id = ?" , clientID , userID ) .
Error
if err != nil {
2025-02-03 18:41:15 +01:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
return false , nil
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
return false , err
2024-08-17 21:57:14 +02:00
}
2025-02-03 18:41:15 +01:00
if userAuthorizedOidcClient . Scope != scope {
return false , nil
2024-09-09 10:29:41 +02:00
}
2025-02-03 18:41:15 +01:00
return true , nil
}
2024-09-09 10:29:41 +02:00
2025-02-03 18:41:15 +01:00
// IsUserGroupAllowedToAuthorize checks if the user group of the user is allowed to authorize the client
func ( s * OidcService ) IsUserGroupAllowedToAuthorize ( user model . User , client model . OidcClient ) bool {
if len ( client . AllowedUserGroups ) == 0 {
return true
}
isAllowedToAuthorize := false
for _ , userGroup := range client . AllowedUserGroups {
for _ , userGroupUser := range user . UserGroups {
if userGroup . ID == userGroupUser . ID {
isAllowedToAuthorize = true
break
}
}
}
return isAllowedToAuthorize
2024-08-17 21:57:14 +02:00
}
2025-05-19 08:10:33 -07:00
// CreatedTokens is the result of a token request. Depending on the grant
// type, some of the token fields are left empty.
type CreatedTokens struct {
	// IdToken is the signed ID token, if one was issued.
	IdToken string
	// AccessToken is the signed OAuth access token.
	AccessToken string
	// RefreshToken is the signed refresh token, if one was issued.
	RefreshToken string
	// ExpiresIn is the lifetime reported for the access token.
	ExpiresIn time.Duration
}
func ( s * OidcService ) CreateTokens ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
2025-04-25 12:14:51 -05:00
switch input . GrantType {
2025-05-19 08:10:33 -07:00
case GrantTypeAuthorizationCode :
return s . createTokenFromAuthorizationCode ( ctx , input )
case GrantTypeRefreshToken :
return s . createTokenFromRefreshToken ( ctx , input )
case GrantTypeDeviceCode :
return s . createTokenFromDeviceCode ( ctx , input )
2025-09-03 01:33:01 +02:00
case GrantTypeClientCredentials :
return s . createTokenFromClientCredentials ( ctx , input )
2025-03-25 07:36:53 -07:00
default :
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcGrantTypeNotSupportedError { }
2025-03-25 07:36:53 -07:00
}
}
2024-08-17 21:57:14 +02:00
2025-05-19 08:10:33 -07:00
func ( s * OidcService ) createTokenFromDeviceCode ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
2025-04-06 06:04:08 -07:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-07-01 23:14:07 +02:00
_ , err := s . verifyClientCredentialsInternal ( ctx , tx , clientAuthCredentialsFromCreateTokensDto ( & input ) , true )
2025-04-06 06:04:08 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-04-25 12:14:51 -05:00
// Get the device authorization from database with explicit query conditions
var deviceAuth model . OidcDeviceCode
2025-05-19 08:10:33 -07:00
err = tx .
WithContext ( ctx ) .
Preload ( "User" ) .
Where ( "device_code = ? AND client_id = ?" , input . DeviceCode , input . ClientID ) .
First ( & deviceAuth ) .
Error
if err != nil {
2025-04-25 12:14:51 -05:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcInvalidDeviceCodeError { }
2024-11-15 15:00:25 +01:00
}
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2024-11-15 15:00:25 +01:00
2025-04-25 12:14:51 -05:00
// Check if device code has expired
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcDeviceCodeExpiredError { }
2025-04-25 12:14:51 -05:00
}
// Check if device code has been authorized
if ! deviceAuth . IsAuthorized || deviceAuth . UserID == nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcAuthorizationPendingError { }
2025-04-25 12:14:51 -05:00
}
// Get user claims for the ID token - ensure UserID is not nil
if deviceAuth . UserID == nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcAuthorizationPendingError { }
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
userClaims , err := s . getUserClaimsForClientInternal ( ctx , * deviceAuth . UserID , input . ClientID , tx )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
// Explicitly use the input clientID for the audience claim to ensure consistency
2025-05-19 08:10:33 -07:00
idToken , err := s . jwtService . GenerateIDToken ( userClaims , input . ClientID , "" )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
refreshToken , err := s . createRefreshToken ( ctx , input . ClientID , * deviceAuth . UserID , deviceAuth . Scope , tx )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-06-09 12:17:55 -07:00
accessToken , err := s . jwtService . GenerateOAuthAccessToken ( deviceAuth . User , input . ClientID )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
// Delete the used device code
2025-05-19 08:10:33 -07:00
err = tx . WithContext ( ctx ) . Delete ( & deviceAuth ) . Error
if err != nil {
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
err = tx . Commit ( ) . Error
if err != nil {
return CreatedTokens { } , err
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
return CreatedTokens {
IdToken : idToken ,
AccessToken : accessToken ,
RefreshToken : refreshToken ,
2025-09-03 01:33:01 +02:00
ExpiresIn : AccessTokenDuration ,
} , nil
}
func ( s * OidcService ) createTokenFromClientCredentials ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
client , err := s . verifyClientCredentialsInternal ( ctx , s . db , clientAuthCredentialsFromCreateTokensDto ( & input ) , false )
if err != nil {
return CreatedTokens { } , err
}
// GenerateOAuthAccessToken uses user.ID as a "sub" claim. Prefix is used to take those security considerations
// into account: https://datatracker.ietf.org/doc/html/rfc9068#name-security-considerations
dummyUser := model . User {
Base : model . Base { ID : "client-" + client . ID } ,
}
audClaim := client . ID
if input . Resource != "" {
audClaim = input . Resource
}
accessToken , err := s . jwtService . GenerateOAuthAccessToken ( dummyUser , audClaim )
if err != nil {
return CreatedTokens { } , err
}
return CreatedTokens {
AccessToken : accessToken ,
ExpiresIn : AccessTokenDuration ,
2025-05-19 08:10:33 -07:00
} , nil
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
func ( s * OidcService ) createTokenFromAuthorizationCode ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
2025-04-25 12:14:51 -05:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-07-01 23:14:07 +02:00
client , err := s . verifyClientCredentialsInternal ( ctx , tx , clientAuthCredentialsFromCreateTokensDto ( & input ) , true )
2025-04-25 12:14:51 -05:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2024-11-15 15:00:25 +01:00
2025-03-25 07:36:53 -07:00
var authorizationCodeMetaData model . OidcAuthorizationCode
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Preload ( "User" ) .
2025-05-19 08:10:33 -07:00
First ( & authorizationCodeMetaData , "code = ?" , input . Code ) .
2025-04-06 06:04:08 -07:00
Error
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , & common . OidcInvalidAuthorizationCodeError { }
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-03-25 07:36:53 -07:00
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client . IsPublic || client . PkceEnabled {
2025-05-19 08:10:33 -07:00
if ! s . validateCodeVerifier ( input . CodeVerifier , * authorizationCodeMetaData . CodeChallenge , * authorizationCodeMetaData . CodeChallengeMethodSha256 ) {
return CreatedTokens { } , & common . OidcInvalidCodeVerifierError { }
2025-03-23 15:14:26 -05:00
}
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-05-19 08:10:33 -07:00
if authorizationCodeMetaData . ClientID != input . ClientID && authorizationCodeMetaData . ExpiresAt . ToTime ( ) . Before ( time . Now ( ) ) {
return CreatedTokens { } , & common . OidcInvalidAuthorizationCodeError { }
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
userClaims , err := s . getUserClaimsForClientInternal ( ctx , authorizationCodeMetaData . UserID , input . ClientID , tx )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
idToken , err := s . jwtService . GenerateIDToken ( userClaims , input . ClientID , authorizationCodeMetaData . Nonce )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-03-25 07:36:53 -07:00
// Generate a refresh token
2025-05-19 08:10:33 -07:00
refreshToken , err := s . createRefreshToken ( ctx , input . ClientID , authorizationCodeMetaData . UserID , authorizationCodeMetaData . Scope , tx )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2024-08-19 18:48:18 +02:00
}
2025-06-09 12:17:55 -07:00
accessToken , err := s . jwtService . GenerateOAuthAccessToken ( authorizationCodeMetaData . User , input . ClientID )
2025-03-27 17:46:10 +01:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-27 17:46:10 +01:00
}
2025-03-23 15:14:26 -05:00
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Delete ( & authorizationCodeMetaData ) .
Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
err = tx . Commit ( ) . Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
return CreatedTokens {
IdToken : idToken ,
AccessToken : accessToken ,
RefreshToken : refreshToken ,
2025-09-03 01:33:01 +02:00
ExpiresIn : AccessTokenDuration ,
2025-05-19 08:10:33 -07:00
} , nil
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-05-19 08:10:33 -07:00
func ( s * OidcService ) createTokenFromRefreshToken ( ctx context . Context , input dto . OidcCreateTokensDto ) ( CreatedTokens , error ) {
if input . RefreshToken == "" {
return CreatedTokens { } , & common . OidcMissingRefreshTokenError { }
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-06-09 12:17:55 -07:00
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
userID , clientID , rt , err := s . jwtService . VerifyOAuthRefreshToken ( input . RefreshToken )
if err != nil {
return CreatedTokens { } , & common . OidcInvalidRefreshTokenError { }
}
2025-04-06 06:04:08 -07:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-07-01 23:14:07 +02:00
client , err := s . verifyClientCredentialsInternal ( ctx , tx , clientAuthCredentialsFromCreateTokensDto ( & input ) , true )
2025-04-06 06:04:08 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-06-09 12:17:55 -07:00
// The ID of the client that made the call must match the client ID in the token
if client . ID != clientID {
return CreatedTokens { } , & common . OidcInvalidRefreshTokenError { }
}
2025-03-25 07:36:53 -07:00
// Verify refresh token
var storedRefreshToken model . OidcRefreshToken
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
2025-09-30 02:44:38 -07:00
Preload ( "User.UserGroups" ) .
2025-06-09 12:17:55 -07:00
Where (
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?" ,
utils . CreateSha256Hash ( rt ) ,
datatype . DateTime ( time . Now ( ) ) ,
userID ,
input . ClientID ,
) .
2025-03-25 07:36:53 -07:00
First ( & storedRefreshToken ) .
Error
2025-09-09 02:31:50 -07:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
return CreatedTokens { } , & common . OidcInvalidRefreshTokenError { }
} else if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
2024-08-17 21:57:14 +02:00
2025-03-25 07:36:53 -07:00
// Verify that the refresh token belongs to the provided client
2025-05-19 08:10:33 -07:00
if storedRefreshToken . ClientID != input . ClientID {
return CreatedTokens { } , & common . OidcInvalidRefreshTokenError { }
2025-03-25 07:36:53 -07:00
}
2025-03-23 15:14:26 -05:00
2025-03-25 07:36:53 -07:00
// Generate a new access token
2025-06-09 12:17:55 -07:00
accessToken , err := s . jwtService . GenerateOAuthAccessToken ( storedRefreshToken . User , input . ClientID )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-23 15:14:26 -05:00
}
2025-09-09 02:31:50 -07:00
// Load the profile, which we need for the ID token
userClaims , err := s . getUserClaims ( ctx , & storedRefreshToken . User , storedRefreshToken . Scopes ( ) , tx )
if err != nil {
return CreatedTokens { } , err
}
// Generate a new ID token
// There's no nonce here because we don't have one with the refresh token, but that's not required
idToken , err := s . jwtService . GenerateIDToken ( userClaims , input . ClientID , "" )
if err != nil {
return CreatedTokens { } , err
}
2025-03-25 07:36:53 -07:00
// Generate a new refresh token and invalidate the old one
2025-05-19 08:10:33 -07:00
newRefreshToken , err := s . createRefreshToken ( ctx , input . ClientID , storedRefreshToken . UserID , storedRefreshToken . Scope , tx )
2025-03-25 07:36:53 -07:00
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-03-25 07:36:53 -07:00
}
// Delete the used refresh token
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Delete ( & storedRefreshToken ) .
Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
err = tx . Commit ( ) . Error
if err != nil {
2025-05-19 08:10:33 -07:00
return CreatedTokens { } , err
2025-04-06 06:04:08 -07:00
}
2025-03-25 07:36:53 -07:00
2025-05-19 08:10:33 -07:00
return CreatedTokens {
AccessToken : accessToken ,
RefreshToken : newRefreshToken ,
2025-09-09 02:31:50 -07:00
IdToken : idToken ,
2025-09-03 01:33:01 +02:00
ExpiresIn : AccessTokenDuration ,
2025-05-19 08:10:33 -07:00
} , nil
2024-08-17 21:57:14 +02:00
}
2025-06-09 12:17:55 -07:00
func ( s * OidcService ) IntrospectToken ( ctx context . Context , creds ClientAuthCredentials , tokenString string ) ( introspectDto dto . OidcIntrospectionResponseDto , err error ) {
2025-07-01 23:14:07 +02:00
client , err := s . verifyClientCredentialsInternal ( ctx , s . db , creds , false )
if err != nil {
return introspectDto , err
}
2025-06-09 12:17:55 -07:00
// Get the type of the token and the client ID
tokenType , token , err := s . jwtService . GetTokenType ( tokenString )
if err != nil {
// We just treat the token as invalid
introspectDto . Active = false
return introspectDto , nil //nolint:nilerr
}
2025-07-01 23:14:07 +02:00
// Get the audience from the token
2025-06-09 12:17:55 -07:00
tokenAudiences , _ := token . Audience ( )
if len ( tokenAudiences ) != 1 || tokenAudiences [ 0 ] == "" {
introspectDto . Active = false
return introspectDto , nil
}
2025-04-09 09:18:03 +02:00
2025-07-01 23:14:07 +02:00
// Audience must match the client ID
if client . ID != tokenAudiences [ 0 ] {
return introspectDto , & common . OidcMissingClientCredentialsError { }
2025-04-09 09:18:03 +02:00
}
2025-06-09 12:17:55 -07:00
// Introspect the token
switch tokenType {
case OAuthAccessTokenJWTType :
return s . introspectAccessToken ( client . ID , tokenString )
case OAuthRefreshTokenJWTType :
return s . introspectRefreshToken ( ctx , client . ID , tokenString )
default :
// We just treat the token as invalid
introspectDto . Active = false
return introspectDto , nil
}
}
2025-04-09 09:18:03 +02:00
2025-06-09 12:17:55 -07:00
func ( s * OidcService ) introspectAccessToken ( clientID string , tokenString string ) ( introspectDto dto . OidcIntrospectionResponseDto , err error ) {
token , err := s . jwtService . VerifyOAuthAccessToken ( tokenString )
if err != nil {
2025-04-09 09:18:03 +02:00
// Every failure we get means the token is invalid. Nothing more to do with the error.
2025-06-09 12:17:55 -07:00
introspectDto . Active = false
return introspectDto , nil //nolint:nilerr
}
// The ID of the client that made the request must match the client ID in the token
audience , ok := token . Audience ( )
if ! ok || len ( audience ) != 1 || audience [ 0 ] == "" {
2025-04-09 09:18:03 +02:00
introspectDto . Active = false
return introspectDto , nil
}
2025-06-09 12:17:55 -07:00
if audience [ 0 ] != clientID {
return introspectDto , & common . OidcMissingClientCredentialsError { }
}
2025-04-09 09:18:03 +02:00
introspectDto . Active = true
introspectDto . TokenType = "access_token"
2025-06-09 12:17:55 -07:00
introspectDto . Audience = audience
2025-04-09 09:18:03 +02:00
if token . Has ( "scope" ) {
2025-06-06 03:23:51 -07:00
var (
asString string
asStrings [ ] string
)
2025-04-09 09:18:03 +02:00
if err := token . Get ( "scope" , & asString ) ; err == nil {
introspectDto . Scope = asString
} else if err := token . Get ( "scope" , & asStrings ) ; err == nil {
introspectDto . Scope = strings . Join ( asStrings , " " )
}
}
2025-06-06 03:23:51 -07:00
if expiration , ok := token . Expiration ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Expiration = expiration . Unix ( )
}
2025-06-06 03:23:51 -07:00
if issuedAt , ok := token . IssuedAt ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . IssuedAt = issuedAt . Unix ( )
}
2025-06-06 03:23:51 -07:00
if notBefore , ok := token . NotBefore ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . NotBefore = notBefore . Unix ( )
}
2025-06-06 03:23:51 -07:00
if subject , ok := token . Subject ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Subject = subject
}
2025-06-06 03:23:51 -07:00
if issuer , ok := token . Issuer ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Issuer = issuer
}
2025-06-06 03:23:51 -07:00
if identifier , ok := token . JwtID ( ) ; ok {
2025-04-09 09:18:03 +02:00
introspectDto . Identifier = identifier
}
return introspectDto , nil
}
2025-06-09 12:17:55 -07:00
func ( s * OidcService ) introspectRefreshToken ( ctx context . Context , clientID string , refreshToken string ) ( introspectDto dto . OidcIntrospectionResponseDto , err error ) {
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
tokenUserID , tokenClientID , tokenRT , err := s . jwtService . VerifyOAuthRefreshToken ( refreshToken )
if err != nil {
return introspectDto , fmt . Errorf ( "invalid refresh token: %w" , err )
}
// The ID of the client that made the call must match the client ID in the token
if tokenClientID != clientID {
return introspectDto , errors . New ( "invalid refresh token: client ID does not match" )
}
2025-04-09 09:18:03 +02:00
var storedRefreshToken model . OidcRefreshToken
2025-04-27 02:32:42 +09:00
err = s . db .
WithContext ( ctx ) .
Preload ( "User" ) .
2025-06-09 12:17:55 -07:00
Where (
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?" ,
utils . CreateSha256Hash ( tokenRT ) ,
datatype . DateTime ( time . Now ( ) ) ,
tokenUserID ,
tokenClientID ,
) .
2025-04-09 09:18:03 +02:00
First ( & storedRefreshToken ) .
Error
if err != nil {
if errors . Is ( err , gorm . ErrRecordNotFound ) {
introspectDto . Active = false
return introspectDto , nil
}
return introspectDto , err
}
introspectDto . Active = true
introspectDto . TokenType = "refresh_token"
return introspectDto , nil
}
2025-04-06 06:04:08 -07:00
// GetClient loads the OIDC client with the given ID using the service's
// database handle, outside of any transaction.
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
	client, err := s.getClientInternal(ctx, clientID, s.db)
	return client, err
}
func ( s * OidcService ) getClientInternal ( ctx context . Context , clientID string , tx * gorm . DB ) ( model . OidcClient , error ) {
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
Preload ( "CreatedBy" ) .
Preload ( "AllowedUserGroups" ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-23 17:04:19 +02:00
return model . OidcClient { } , err
2024-08-17 21:57:14 +02:00
}
2024-08-23 17:04:19 +02:00
return client , nil
2024-08-17 21:57:14 +02:00
}
2025-05-25 14:22:25 -05:00
func ( s * OidcService ) ListClients ( ctx context . Context , name string , sortedPaginationRequest utils . SortedPaginationRequest ) ( [ ] model . OidcClient , utils . PaginationResponse , error ) {
2024-08-17 21:57:14 +02:00
var clients [ ] model . OidcClient
2025-04-06 06:04:08 -07:00
query := s . db .
WithContext ( ctx ) .
Preload ( "CreatedBy" ) .
Model ( & model . OidcClient { } )
2025-05-25 14:22:25 -05:00
if name != "" {
query = query . Where ( "name LIKE ?" , "%" + name + "%" )
2024-08-17 21:57:14 +02:00
}
2025-05-25 14:22:25 -05:00
// As allowedUserGroupsCount is not a column, we need to manually sort it
2025-08-17 22:47:34 +02:00
if sortedPaginationRequest . Sort . Column == "allowedUserGroupsCount" && utils . IsValidSortDirection ( sortedPaginationRequest . Sort . Direction ) {
2025-05-25 14:22:25 -05:00
query = query . Select ( "oidc_clients.*, COUNT(oidc_clients_allowed_user_groups.oidc_client_id)" ) .
Joins ( "LEFT JOIN oidc_clients_allowed_user_groups ON oidc_clients.id = oidc_clients_allowed_user_groups.oidc_client_id" ) .
Group ( "oidc_clients.id" ) .
Order ( "COUNT(oidc_clients_allowed_user_groups.oidc_client_id) " + sortedPaginationRequest . Sort . Direction )
response , err := utils . Paginate ( sortedPaginationRequest . Pagination . Page , sortedPaginationRequest . Pagination . Limit , query , & clients )
return clients , response , err
2024-08-17 21:57:14 +02:00
}
2025-05-25 14:22:25 -05:00
response , err := utils . PaginateAndSort ( sortedPaginationRequest , query , & clients )
return clients , response , err
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) CreateClient ( ctx context . Context , input dto . OidcClientCreateDto , userID string ) ( model . OidcClient , error ) {
2025-09-29 10:07:55 -05:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2024-08-17 21:57:14 +02:00
client := model . OidcClient {
2025-08-23 18:41:05 +02:00
Base : model . Base {
ID : input . ID ,
} ,
2025-08-23 17:54:51 +02:00
CreatedByID : utils . Ptr ( userID ) ,
2024-08-17 21:57:14 +02:00
}
2025-08-23 18:41:05 +02:00
updateOIDCClientModelFromDto ( & client , & input . OidcClientUpdateDto )
2024-08-17 21:57:14 +02:00
2025-09-29 10:07:55 -05:00
err := tx .
2025-04-06 06:04:08 -07:00
WithContext ( ctx ) .
Create ( & client ) .
Error
if err != nil {
2025-08-23 18:41:05 +02:00
if errors . Is ( err , gorm . ErrDuplicatedKey ) {
return model . OidcClient { } , & common . ClientIdAlreadyExistsError { }
}
2024-08-23 17:04:19 +02:00
return model . OidcClient { } , err
2024-08-17 21:57:14 +02:00
}
2025-09-29 10:07:55 -05:00
if input . LogoURL != nil {
err = s . downloadAndSaveLogoFromURL ( ctx , tx , client . ID , * input . LogoURL )
if err != nil {
return model . OidcClient { } , fmt . Errorf ( "failed to download logo: %w" , err )
}
}
err = tx . Commit ( ) . Error
if err != nil {
return model . OidcClient { } , err
}
2024-08-23 17:04:19 +02:00
return client , nil
2024-08-17 21:57:14 +02:00
}
2025-08-23 18:41:05 +02:00
func ( s * OidcService ) UpdateClient ( ctx context . Context , clientID string , input dto . OidcClientUpdateDto ) ( model . OidcClient , error ) {
2025-04-06 06:04:08 -07:00
tx := s . db . Begin ( )
2025-09-29 10:07:55 -05:00
defer func ( ) { tx . Rollback ( ) } ( )
2025-04-06 06:04:08 -07:00
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-09-29 10:07:55 -05:00
if err := tx . WithContext ( ctx ) .
2025-04-06 06:04:08 -07:00
Preload ( "CreatedBy" ) .
2025-09-29 10:07:55 -05:00
First ( & client , "id = ?" , clientID ) . Error ; err != nil {
2024-08-23 17:04:19 +02:00
return model . OidcClient { } , err
2024-08-17 21:57:14 +02:00
}
2025-06-06 03:23:51 -07:00
updateOIDCClientModelFromDto ( & client , & input )
2024-08-17 21:57:14 +02:00
2025-09-29 10:07:55 -05:00
if err := tx . WithContext ( ctx ) . Save ( & client ) . Error ; err != nil {
2025-04-06 06:04:08 -07:00
return model . OidcClient { } , err
}
2025-09-29 10:07:55 -05:00
if input . LogoURL != nil {
err := s . downloadAndSaveLogoFromURL ( ctx , tx , client . ID , * input . LogoURL )
if err != nil {
return model . OidcClient { } , fmt . Errorf ( "failed to download logo: %w" , err )
}
2024-08-17 21:57:14 +02:00
}
2025-09-29 10:07:55 -05:00
if err := tx . Commit ( ) . Error ; err != nil {
return model . OidcClient { } , err
}
2024-08-23 17:04:19 +02:00
return client , nil
2024-08-17 21:57:14 +02:00
}
2025-08-23 18:41:05 +02:00
func updateOIDCClientModelFromDto ( client * model . OidcClient , input * dto . OidcClientUpdateDto ) {
2025-06-06 03:23:51 -07:00
// Base fields
client . Name = input . Name
client . CallbackURLs = input . CallbackURLs
client . LogoutCallbackURLs = input . LogoutCallbackURLs
client . IsPublic = input . IsPublic
// PKCE is required for public clients
client . PkceEnabled = input . IsPublic || input . PkceEnabled
2025-08-22 08:56:40 +02:00
client . RequiresReauthentication = input . RequiresReauthentication
2025-08-10 10:56:03 -05:00
client . LaunchURL = input . LaunchURL
2025-06-06 03:23:51 -07:00
// Credentials
2025-08-23 17:40:06 +02:00
client . Credentials . FederatedIdentities = make ( [ ] model . OidcClientFederatedIdentity , len ( input . Credentials . FederatedIdentities ) )
for i , fi := range input . Credentials . FederatedIdentities {
client . Credentials . FederatedIdentities [ i ] = model . OidcClientFederatedIdentity {
Issuer : fi . Issuer ,
Audience : fi . Audience ,
Subject : fi . Subject ,
JWKS : fi . JWKS ,
2025-06-06 03:23:51 -07:00
}
}
2025-08-23 17:46:00 +02:00
2025-06-06 03:23:51 -07:00
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) DeleteClient ( ctx context . Context , clientID string ) error {
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := s . db .
WithContext ( ctx ) .
Where ( "id = ?" , clientID ) .
Delete ( & client ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
return nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) CreateClientSecret ( ctx context . Context , clientID string ) ( string , error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , err
}
clientSecret , err := utils . GenerateRandomAlphanumericString ( 32 )
if err != nil {
return "" , err
}
hashedSecret , err := bcrypt . GenerateFromPassword ( [ ] byte ( clientSecret ) , bcrypt . DefaultCost )
if err != nil {
return "" , err
}
client . Secret = string ( hashedSecret )
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return "" , err
}
err = tx . Commit ( ) . Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , err
}
return clientSecret , nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) GetClientLogo ( ctx context . Context , clientID string ) ( string , string , error ) {
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := s . db .
WithContext ( ctx ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , "" , err
}
if client . ImageType == nil {
return "" , "" , errors . New ( "image not found" )
}
2025-04-06 06:04:08 -07:00
imagePath := common . EnvConfig . UploadPath + "/oidc-client-images/" + client . ID + "." + * client . ImageType
mimeType := utils . GetImageMimeType ( * client . ImageType )
2024-08-17 21:57:14 +02:00
return imagePath , mimeType , nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) UpdateClientLogo ( ctx context . Context , clientID string , file * multipart . FileHeader ) error {
2025-06-10 10:51:46 +02:00
fileType := strings . ToLower ( utils . GetFileExtension ( file . Filename ) )
2024-08-17 21:57:14 +02:00
if mimeType := utils . GetImageMimeType ( fileType ) ; mimeType == "" {
2024-10-28 18:11:54 +01:00
return & common . FileTypeNotSupportedError { }
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
imagePath := common . EnvConfig . UploadPath + "/oidc-client-images/" + clientID + "." + fileType
err := utils . SaveFile ( file , imagePath )
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
2025-04-06 06:04:08 -07:00
tx := s . db . Begin ( )
2025-09-29 10:07:55 -05:00
err = s . updateClientLogoType ( ctx , tx , clientID , fileType )
2025-04-06 06:04:08 -07:00
if err != nil {
2025-09-29 10:07:55 -05:00
tx . Rollback ( )
2024-08-17 21:57:14 +02:00
return err
}
2025-09-29 10:07:55 -05:00
return tx . Commit ( ) . Error
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) DeleteClientLogo ( ctx context . Context , clientID string ) error {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2024-08-17 21:57:14 +02:00
var client model . OidcClient
2025-04-06 06:04:08 -07:00
err := tx .
WithContext ( ctx ) .
First ( & client , "id = ?" , clientID ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
if client . ImageType == nil {
return errors . New ( "image not found" )
}
2025-06-06 08:50:33 +02:00
oldImageType := * client . ImageType
2025-04-06 06:04:08 -07:00
client . ImageType = nil
2025-09-29 10:07:55 -05:00
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return err
}
2025-06-06 08:50:33 +02:00
imagePath := common . EnvConfig . UploadPath + "/oidc-client-images/" + client . ID + "." + oldImageType
2024-08-17 21:57:14 +02:00
if err := os . Remove ( imagePath ) ; err != nil {
return err
}
2025-04-06 06:04:08 -07:00
err = tx . Commit ( ) . Error
if err != nil {
2024-08-17 21:57:14 +02:00
return err
}
return nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) UpdateAllowedUserGroups ( ctx context . Context , id string , input dto . OidcUpdateAllowedUserGroupsDto ) ( client model . OidcClient , err error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
client , err = s . getClientInternal ( ctx , id , tx )
2025-02-03 18:41:15 +01:00
if err != nil {
return model . OidcClient { } , err
}
// Fetch the user groups based on UserGroupIDs in input
var groups [ ] model . UserGroup
if len ( input . UserGroupIDs ) > 0 {
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Where ( "id IN (?)" , input . UserGroupIDs ) .
Find ( & groups ) .
Error
if err != nil {
2025-02-03 18:41:15 +01:00
return model . OidcClient { } , err
}
}
// Replace the current user groups with the new set of user groups
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Model ( & client ) .
Association ( "AllowedUserGroups" ) .
Replace ( groups )
if err != nil {
2025-02-03 18:41:15 +01:00
return model . OidcClient { } , err
}
// Save the updated client
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Save ( & client ) .
Error
if err != nil {
return model . OidcClient { } , err
}
err = tx . Commit ( ) . Error
if err != nil {
2025-02-03 18:41:15 +01:00
return model . OidcClient { } , err
}
return client , nil
}
2025-02-14 17:09:27 +01:00
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) ValidateEndSession ( ctx context . Context , input dto . OidcLogoutDto , userID string ) ( string , error ) {
2025-02-14 17:09:27 +01:00
// If no ID token hint is provided, return an error
if input . IdTokenHint == "" {
return "" , & common . TokenInvalidError { }
}
// If the ID token hint is provided, verify the ID token
2025-03-27 10:20:39 -07:00
// Here we also accept expired ID tokens, which are fine per spec
token , err := s . jwtService . VerifyIdToken ( input . IdTokenHint , true )
2025-02-14 17:09:27 +01:00
if err != nil {
return "" , & common . TokenInvalidError { }
}
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request
2025-03-27 10:20:39 -07:00
clientID , ok := token . Audience ( )
if ! ok || len ( clientID ) == 0 {
return "" , & common . TokenInvalidError { }
}
if input . ClientId != "" && clientID [ 0 ] != input . ClientId {
2025-02-14 17:09:27 +01:00
return "" , & common . OidcClientIdNotMatchingError { }
}
// Check if the user has authorized the client before
var userAuthorizedOIDCClient model . UserAuthorizedOidcClient
2025-04-06 06:04:08 -07:00
err = s . db .
WithContext ( ctx ) .
Preload ( "Client" ) .
First ( & userAuthorizedOIDCClient , "client_id = ? AND user_id = ?" , clientID [ 0 ] , userID ) .
Error
if err != nil {
2025-02-14 17:09:27 +01:00
return "" , & common . OidcMissingAuthorizationError { }
}
// If the client has no logout callback URLs, return an error
if len ( userAuthorizedOIDCClient . Client . LogoutCallbackURLs ) == 0 {
return "" , & common . OidcNoCallbackURLError { }
}
2025-05-29 13:01:23 -07:00
callbackURL , err := s . getLogoutCallbackURL ( & userAuthorizedOIDCClient . Client , input . PostLogoutRedirectUri )
2025-02-14 17:09:27 +01:00
if err != nil {
return "" , err
}
return callbackURL , nil
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) createAuthorizationCode ( ctx context . Context , clientID string , userID string , scope string , nonce string , codeChallenge string , codeChallengeMethod string , tx * gorm . DB ) ( string , error ) {
2024-08-17 21:57:14 +02:00
randomString , err := utils . GenerateRandomAlphanumericString ( 32 )
if err != nil {
return "" , err
}
2024-11-15 15:00:25 +01:00
codeChallengeMethodSha256 := strings . ToUpper ( codeChallengeMethod ) == "S256"
2024-08-17 21:57:14 +02:00
oidcAuthorizationCode := model . OidcAuthorizationCode {
2024-11-15 15:00:25 +01:00
ExpiresAt : datatype . DateTime ( time . Now ( ) . Add ( 15 * time . Minute ) ) ,
Code : randomString ,
ClientID : clientID ,
UserID : userID ,
Scope : scope ,
Nonce : nonce ,
CodeChallenge : & codeChallenge ,
CodeChallengeMethodSha256 : & codeChallengeMethodSha256 ,
2024-08-17 21:57:14 +02:00
}
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Create ( & oidcAuthorizationCode ) .
Error
if err != nil {
2024-08-17 21:57:14 +02:00
return "" , err
}
return randomString , nil
}
2024-08-24 00:49:08 +02:00
2024-11-15 15:00:25 +01:00
func ( s * OidcService ) validateCodeVerifier ( codeVerifier , codeChallenge string , codeChallengeMethodSha256 bool ) bool {
2025-01-03 16:15:10 +01:00
if codeVerifier == "" || codeChallenge == "" {
return false
}
2024-11-15 15:00:25 +01:00
if ! codeChallengeMethodSha256 {
return codeVerifier == codeChallenge
}
// Compute SHA-256 hash of the codeVerifier
h := sha256 . New ( )
h . Write ( [ ] byte ( codeVerifier ) )
codeVerifierHash := h . Sum ( nil )
// Base64 URL encode the verifier hash
encodedVerifierHash := base64 . RawURLEncoding . EncodeToString ( codeVerifierHash )
return encodedVerifierHash == codeChallenge
}
2025-05-29 13:01:23 -07:00
func ( s * OidcService ) getCallbackURL ( client * model . OidcClient , inputCallbackURL string , tx * gorm . DB , ctx context . Context ) ( callbackURL string , err error ) {
2025-05-29 10:16:10 -05:00
// If no input callback URL provided, use the first configured URL
2024-08-24 00:49:08 +02:00
if inputCallbackURL == "" {
2025-05-29 13:01:23 -07:00
if len ( client . CallbackURLs ) > 0 {
return client . CallbackURLs [ 0 ] , nil
2025-05-29 10:16:10 -05:00
}
// If no URLs are configured and no input URL, this is an error
return "" , & common . OidcMissingCallbackURLError { }
2024-08-24 00:49:08 +02:00
}
2025-01-20 11:19:23 +01:00
2025-05-29 10:16:10 -05:00
// If URLs are already configured, validate against them
2025-05-29 13:01:23 -07:00
if len ( client . CallbackURLs ) > 0 {
matched , err := s . getCallbackURLFromList ( client . CallbackURLs , inputCallbackURL )
if err != nil {
return "" , err
} else if matched == "" {
return "" , & common . OidcInvalidCallbackURLError { }
2025-01-20 11:19:23 +01:00
}
2025-05-29 13:01:23 -07:00
return matched , nil
2024-08-24 00:49:08 +02:00
}
2025-05-29 10:16:10 -05:00
// If no URLs are configured, trust and store the first URL (TOFU)
2025-05-29 13:01:23 -07:00
err = s . addCallbackURLToClient ( ctx , client , inputCallbackURL , tx )
2025-05-29 10:16:10 -05:00
if err != nil {
return "" , err
}
return inputCallbackURL , nil
}
2025-05-29 13:01:23 -07:00
func ( s * OidcService ) getLogoutCallbackURL ( client * model . OidcClient , inputLogoutCallbackURL string ) ( callbackURL string , err error ) {
if inputLogoutCallbackURL == "" {
return client . LogoutCallbackURLs [ 0 ] , nil
}
matched , err := s . getCallbackURLFromList ( client . LogoutCallbackURLs , inputLogoutCallbackURL )
2025-05-29 10:16:10 -05:00
if err != nil {
2025-05-29 13:01:23 -07:00
return "" , err
} else if matched == "" {
return "" , & common . OidcInvalidCallbackURLError { }
}
return matched , nil
}
func ( s * OidcService ) getCallbackURLFromList ( urls [ ] string , inputCallbackURL string ) ( callbackURL string , err error ) {
for _ , callbackPattern := range urls {
regexPattern := "^" + strings . ReplaceAll ( regexp . QuoteMeta ( callbackPattern ) , ` \* ` , ".*" ) + "$"
matched , err := regexp . MatchString ( regexPattern , inputCallbackURL )
if err != nil {
return "" , err
}
if matched {
return inputCallbackURL , nil
}
2025-05-29 10:16:10 -05:00
}
2025-05-29 13:01:23 -07:00
return "" , nil
}
func ( s * OidcService ) addCallbackURLToClient ( ctx context . Context , client * model . OidcClient , callbackURL string , tx * gorm . DB ) error {
2025-05-29 10:16:10 -05:00
// Add the new callback URL to the existing list
client . CallbackURLs = append ( client . CallbackURLs , callbackURL )
2025-05-29 13:01:23 -07:00
err := tx . WithContext ( ctx ) . Save ( client ) . Error
2025-05-29 10:16:10 -05:00
if err != nil {
return err
}
return nil
2024-08-24 00:49:08 +02:00
}
2025-03-23 15:14:26 -05:00
2025-04-27 02:32:42 +09:00
func ( s * OidcService ) CreateDeviceAuthorization ( ctx context . Context , input dto . OidcDeviceAuthorizationRequestDto ) ( * dto . OidcDeviceAuthorizationResponseDto , error ) {
2025-06-09 12:17:55 -07:00
client , err := s . verifyClientCredentialsInternal ( ctx , s . db , ClientAuthCredentials {
ClientID : input . ClientID ,
ClientSecret : input . ClientSecret ,
ClientAssertionType : input . ClientAssertionType ,
ClientAssertion : input . ClientAssertion ,
2025-07-01 23:14:07 +02:00
} , true )
2025-04-25 12:14:51 -05:00
if err != nil {
return nil , err
}
// Generate codes
deviceCode , err := utils . GenerateRandomAlphanumericString ( 32 )
if err != nil {
return nil , err
}
userCode , err := utils . GenerateRandomAlphanumericString ( 8 )
if err != nil {
return nil , err
}
// Create device authorization
deviceAuth := & model . OidcDeviceCode {
DeviceCode : deviceCode ,
UserCode : userCode ,
Scope : input . Scope ,
2025-06-09 12:17:55 -07:00
ExpiresAt : datatype . DateTime ( time . Now ( ) . Add ( DeviceCodeDuration ) ) ,
2025-04-25 12:14:51 -05:00
IsAuthorized : false ,
ClientID : client . ID ,
}
if err := s . db . Create ( deviceAuth ) . Error ; err != nil {
return nil , err
}
return & dto . OidcDeviceAuthorizationResponseDto {
DeviceCode : deviceCode ,
UserCode : userCode ,
VerificationURI : common . EnvConfig . AppURL + "/device" ,
VerificationURIComplete : common . EnvConfig . AppURL + "/device?code=" + userCode ,
2025-06-09 12:17:55 -07:00
ExpiresIn : int ( DeviceCodeDuration . Seconds ( ) ) ,
2025-04-25 12:14:51 -05:00
Interval : 5 ,
} , nil
}
func ( s * OidcService ) VerifyDeviceCode ( ctx context . Context , userCode string , userID string , ipAddress string , userAgent string ) error {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
var deviceAuth model . OidcDeviceCode
2025-07-27 06:34:23 +02:00
err := tx .
WithContext ( ctx ) .
Preload ( "Client.AllowedUserGroups" ) .
First ( & deviceAuth , "user_code = ?" , userCode ) .
Error
if err != nil {
return fmt . Errorf ( "error finding device code: %w" , err )
2025-04-25 12:14:51 -05:00
}
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
return & common . OidcDeviceCodeExpiredError { }
}
// Check if the user group is allowed to authorize the client
var user model . User
2025-07-27 06:34:23 +02:00
err = tx .
WithContext ( ctx ) .
Preload ( "UserGroups" ) .
First ( & user , "id = ?" , userID ) .
Error
if err != nil {
return fmt . Errorf ( "error finding user groups: %w" , err )
2025-04-25 12:14:51 -05:00
}
if ! s . IsUserGroupAllowedToAuthorize ( user , deviceAuth . Client ) {
return & common . OidcAccessDeniedError { }
}
2025-07-27 06:34:23 +02:00
err = tx .
WithContext ( ctx ) .
Preload ( "Client" ) .
First ( & deviceAuth , "user_code = ?" , userCode ) .
Error
if err != nil {
return fmt . Errorf ( "error finding device code: %w" , err )
2025-04-25 12:14:51 -05:00
}
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
return & common . OidcDeviceCodeExpiredError { }
}
deviceAuth . UserID = & userID
deviceAuth . IsAuthorized = true
2025-07-27 06:34:23 +02:00
err = tx .
WithContext ( ctx ) .
Save ( & deviceAuth ) .
Error
if err != nil {
return fmt . Errorf ( "error saving device auth: %w" , err )
2025-04-25 12:14:51 -05:00
}
2025-08-10 10:56:03 -05:00
hasAlreadyAuthorizedClient , err := s . createAuthorizedClientInternal ( ctx , userID , deviceAuth . ClientID , deviceAuth . Scope , tx )
2025-04-25 12:14:51 -05:00
if err != nil {
return err
}
2025-07-27 06:34:23 +02:00
auditLogData := model . AuditLogData { "clientName" : deviceAuth . Client . Name }
2025-08-10 10:56:03 -05:00
if hasAlreadyAuthorizedClient {
2025-07-27 06:34:23 +02:00
s . auditLogService . Create ( ctx , model . AuditLogEventDeviceCodeAuthorization , ipAddress , userAgent , userID , auditLogData , tx )
2025-08-10 10:56:03 -05:00
} else {
s . auditLogService . Create ( ctx , model . AuditLogEventNewDeviceCodeAuthorization , ipAddress , userAgent , userID , auditLogData , tx )
2025-04-25 12:14:51 -05:00
}
return tx . Commit ( ) . Error
}
func ( s * OidcService ) GetDeviceCodeInfo ( ctx context . Context , userCode string , userID string ) ( * dto . DeviceCodeInfoDto , error ) {
var deviceAuth model . OidcDeviceCode
2025-04-27 02:32:42 +09:00
err := s . db .
WithContext ( ctx ) .
Preload ( "Client" ) .
First ( & deviceAuth , "user_code = ?" , userCode ) .
Error
if err != nil {
2025-04-25 12:14:51 -05:00
if errors . Is ( err , gorm . ErrRecordNotFound ) {
return nil , & common . OidcInvalidDeviceCodeError { }
}
return nil , err
}
if time . Now ( ) . After ( deviceAuth . ExpiresAt . ToTime ( ) ) {
return nil , & common . OidcDeviceCodeExpiredError { }
}
// Check if the user has already authorized this client with this scope
hasAuthorizedClient := false
if userID != "" {
var err error
hasAuthorizedClient , err = s . HasAuthorizedClient ( ctx , deviceAuth . ClientID , userID , deviceAuth . Scope )
if err != nil {
return nil , err
}
}
return & dto . DeviceCodeInfoDto {
Client : dto . OidcClientMetaDataDto {
ID : deviceAuth . Client . ID ,
Name : deviceAuth . Client . Name ,
2025-09-29 10:07:55 -05:00
HasLogo : deviceAuth . Client . HasLogo ( ) ,
2025-04-25 12:14:51 -05:00
} ,
Scope : deviceAuth . Scope ,
AuthorizationRequired : ! hasAuthorizedClient ,
} , nil
}
2025-05-25 14:22:25 -05:00
func ( s * OidcService ) GetAllowedGroupsCountOfClient ( ctx context . Context , id string ) ( int64 , error ) {
2025-05-25 22:24:20 +02:00
// We only perform select queries here, so we can rollback in all cases
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
2025-05-25 14:22:25 -05:00
var client model . OidcClient
2025-05-25 22:24:20 +02:00
err := tx . WithContext ( ctx ) . Where ( "id = ?" , id ) . First ( & client ) . Error
2025-05-25 14:22:25 -05:00
if err != nil {
return 0 , err
}
2025-05-25 22:24:20 +02:00
count := tx . WithContext ( ctx ) . Model ( & client ) . Association ( "AllowedUserGroups" ) . Count ( )
2025-05-25 14:22:25 -05:00
return count , nil
}
2025-06-04 09:23:44 +02:00
// ListAuthorizedClients returns one page of the clients the user has authorized,
// each with its Client relation loaded, sorted and paginated according to
// sortedPaginationRequest.
func (s *OidcService) ListAuthorizedClients(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.UserAuthorizedOidcClient, utils.PaginationResponse, error) {
	var results []model.UserAuthorizedOidcClient

	baseQuery := s.db.WithContext(ctx).Model(&model.UserAuthorizedOidcClient{})
	userQuery := baseQuery.Preload("Client").Where("user_id = ?", userID)

	pagination, err := utils.PaginateAndSort(sortedPaginationRequest, userQuery, &results)
	return results, pagination, err
}
2025-08-10 10:56:03 -05:00
// RevokeAuthorizedClient deletes the user's authorization for the given client.
//
// The authorization is looked up first, so a missing authorization is returned as
// the lookup's error rather than silently ignored.
func (s *OidcService) RevokeAuthorizedClient(ctx context.Context, userID string, clientID string) error {
	tx := s.db.Begin()
	// Covers every early return below
	defer tx.Rollback()

	var authorization model.UserAuthorizedOidcClient
	if err := tx.
		WithContext(ctx).
		Where("user_id = ? AND client_id = ?", userID, clientID).
		First(&authorization).Error; err != nil {
		return err
	}

	if err := tx.WithContext(ctx).Delete(&authorization).Error; err != nil {
		return err
	}

	return tx.Commit().Error
}
2025-08-17 22:47:34 +02:00
func ( s * OidcService ) ListAccessibleOidcClients ( ctx context . Context , userID string , sortedPaginationRequest utils . SortedPaginationRequest ) ( [ ] dto . AccessibleOidcClientDto , utils . PaginationResponse , error ) {
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
var user model . User
err := tx .
WithContext ( ctx ) .
Preload ( "UserGroups" ) .
First ( & user , "id = ?" , userID ) .
Error
if err != nil {
return nil , utils . PaginationResponse { } , err
}
userGroupIDs := make ( [ ] string , len ( user . UserGroups ) )
for i , group := range user . UserGroups {
userGroupIDs [ i ] = group . ID
}
// Build the query for accessible clients
query := tx .
WithContext ( ctx ) .
Model ( & model . OidcClient { } ) .
2025-08-24 19:08:33 +02:00
Preload ( "UserAuthorizedOidcClients" , "user_id = ?" , userID )
2025-08-17 22:47:34 +02:00
// If user has no groups, only return clients with no allowed user groups
if len ( userGroupIDs ) == 0 {
2025-08-27 17:34:11 +02:00
query = query . Where ( ` NOT EXISTS (
SELECT 1 FROM oidc_clients_allowed_user_groups
WHERE oidc_clients_allowed_user_groups . oidc_client_id = oidc_clients . id ) ` )
2025-08-17 22:47:34 +02:00
} else {
2025-08-27 17:34:11 +02:00
query = query . Where ( `
NOT EXISTS (
SELECT 1 FROM oidc_clients_allowed_user_groups
WHERE oidc_clients_allowed_user_groups . oidc_client_id = oidc_clients . id
) OR EXISTS (
SELECT 1 FROM oidc_clients_allowed_user_groups
WHERE oidc_clients_allowed_user_groups . oidc_client_id = oidc_clients . id
AND oidc_clients_allowed_user_groups . user_group_id IN ( ? ) ) ` , userGroupIDs )
2025-08-17 22:47:34 +02:00
}
var clients [ ] model . OidcClient
// Handle custom sorting for lastUsedAt column
var response utils . PaginationResponse
if sortedPaginationRequest . Sort . Column == "lastUsedAt" && utils . IsValidSortDirection ( sortedPaginationRequest . Sort . Direction ) {
query = query .
Joins ( "LEFT JOIN user_authorized_oidc_clients ON oidc_clients.id = user_authorized_oidc_clients.client_id AND user_authorized_oidc_clients.user_id = ?" , userID ) .
2025-08-24 19:08:33 +02:00
Order ( "user_authorized_oidc_clients.last_used_at " + sortedPaginationRequest . Sort . Direction + " NULLS LAST" )
2025-08-17 22:47:34 +02:00
}
response , err = utils . PaginateAndSort ( sortedPaginationRequest , query , & clients )
if err != nil {
return nil , utils . PaginationResponse { } , err
}
dtos := make ( [ ] dto . AccessibleOidcClientDto , len ( clients ) )
for i , client := range clients {
var lastUsedAt * datatype . DateTime
if len ( client . UserAuthorizedOidcClients ) > 0 {
lastUsedAt = & client . UserAuthorizedOidcClients [ 0 ] . LastUsedAt
}
dtos [ i ] = dto . AccessibleOidcClientDto {
OidcClientMetaDataDto : dto . OidcClientMetaDataDto {
ID : client . ID ,
Name : client . Name ,
LaunchURL : client . LaunchURL ,
2025-09-29 10:07:55 -05:00
HasLogo : client . HasLogo ( ) ,
2025-08-17 22:47:34 +02:00
} ,
LastUsedAt : lastUsedAt ,
}
}
return dtos , response , err
}
2025-04-06 06:04:08 -07:00
func ( s * OidcService ) createRefreshToken ( ctx context . Context , clientID string , userID string , scope string , tx * gorm . DB ) ( string , error ) {
2025-03-25 07:36:53 -07:00
refreshToken , err := utils . GenerateRandomAlphanumericString ( 40 )
2025-03-23 15:14:26 -05:00
if err != nil {
return "" , err
}
2025-03-25 07:36:53 -07:00
// Compute the hash of the refresh token to store in the DB
// Refresh tokens are pretty long already, so a "simple" SHA-256 hash is enough
refreshTokenHash := utils . CreateSha256Hash ( refreshToken )
m := model . OidcRefreshToken {
2025-06-09 12:17:55 -07:00
ExpiresAt : datatype . DateTime ( time . Now ( ) . Add ( RefreshTokenDuration ) ) ,
2025-03-25 07:36:53 -07:00
Token : refreshTokenHash ,
2025-03-23 15:14:26 -05:00
ClientID : clientID ,
UserID : userID ,
Scope : scope ,
}
2025-04-06 06:04:08 -07:00
err = tx .
WithContext ( ctx ) .
Create ( & m ) .
Error
if err != nil {
2025-03-23 15:14:26 -05:00
return "" , err
}
2025-06-09 12:17:55 -07:00
// Sign the refresh token
signed , err := s . jwtService . GenerateOAuthRefreshToken ( userID , clientID , refreshToken )
if err != nil {
return "" , fmt . Errorf ( "failed to sign refresh token: %w" , err )
}
return signed , nil
2025-03-23 15:14:26 -05:00
}
2025-04-25 12:14:51 -05:00
2025-08-10 10:56:03 -05:00
func ( s * OidcService ) createAuthorizedClientInternal ( ctx context . Context , userID string , clientID string , scope string , tx * gorm . DB ) ( hasAlreadyAuthorizedClient bool , err error ) {
// Check if the user has already authorized the client with the given scope
hasAlreadyAuthorizedClient , err = s . hasAuthorizedClientInternal ( ctx , clientID , userID , scope , tx )
if err != nil {
return false , err
}
if hasAlreadyAuthorizedClient {
err = tx .
WithContext ( ctx ) .
Model ( & model . UserAuthorizedOidcClient { } ) .
Where ( "user_id = ? AND client_id = ?" , userID , clientID ) .
Update ( "last_used_at" , datatype . DateTime ( time . Now ( ) ) ) .
Error
if err != nil {
return hasAlreadyAuthorizedClient , err
}
return hasAlreadyAuthorizedClient , nil
}
2025-04-28 09:29:18 +02:00
userAuthorizedClient := model . UserAuthorizedOidcClient {
2025-08-10 10:56:03 -05:00
UserID : userID ,
ClientID : clientID ,
Scope : scope ,
LastUsedAt : datatype . DateTime ( time . Now ( ) ) ,
2025-04-28 09:29:18 +02:00
}
2025-08-10 10:56:03 -05:00
err = tx . WithContext ( ctx ) .
2025-04-28 09:29:18 +02:00
Clauses ( clause . OnConflict {
Columns : [ ] clause . Column { { Name : "user_id" } , { Name : "client_id" } } ,
DoUpdates : clause . AssignmentColumns ( [ ] string { "scope" } ) ,
} ) .
Create ( & userAuthorizedClient ) .
Error
2025-08-10 10:56:03 -05:00
return hasAlreadyAuthorizedClient , err
2025-04-28 09:29:18 +02:00
}
2025-06-09 12:17:55 -07:00
// ClientAuthCredentials groups the credentials a client may present when
// authenticating: an ID with an optional secret, or a client assertion.
type ClientAuthCredentials struct {
	// ClientID is the client identifier supplied with the request.
	ClientID string
	// ClientSecret is the plaintext secret supplied with the request;
	// verifyClientCredentialsInternal compares it against the stored bcrypt hash.
	ClientSecret string
	// ClientAssertion is the assertion JWT used for assertion-based authentication.
	ClientAssertion string
	// ClientAssertionType names the kind of assertion; an assertion is only
	// considered when this equals ClientAssertionTypeJWTBearer.
	ClientAssertionType string
}
// clientAuthCredentialsFromCreateTokensDto extracts the client authentication
// fields from a token request DTO.
func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials {
	var creds ClientAuthCredentials

	creds.ClientID = d.ClientID
	creds.ClientSecret = d.ClientSecret
	creds.ClientAssertion = d.ClientAssertion
	creds.ClientAssertionType = d.ClientAssertionType

	return creds
}
2025-07-01 23:14:07 +02:00
func ( s * OidcService ) verifyClientCredentialsInternal ( ctx context . Context , tx * gorm . DB , input ClientAuthCredentials , allowPublicClientsWithoutAuth bool ) ( client * model . OidcClient , err error ) {
isClientAssertion := input . ClientAssertionType == ClientAssertionTypeJWTBearer && input . ClientAssertion != ""
// Determine the client ID based on the authentication method
var clientID string
switch {
case isClientAssertion :
// Extract client ID from the JWT assertion's 'sub' claim
clientID , err = s . extractClientIDFromAssertion ( input . ClientAssertion )
if err != nil {
slog . Error ( "Failed to extract client ID from assertion" , "error" , err )
return nil , & common . OidcClientAssertionInvalidError { }
}
case input . ClientID != "" :
// Use the provided client ID for other authentication methods
clientID = input . ClientID
default :
2025-06-06 03:23:51 -07:00
return nil , & common . OidcMissingClientCredentialsError { }
2025-04-25 12:14:51 -05:00
}
2025-05-19 08:10:33 -07:00
// Load the OIDC client's configuration
2025-07-01 23:14:07 +02:00
err = tx .
2025-04-25 12:14:51 -05:00
WithContext ( ctx ) .
2025-07-01 23:14:07 +02:00
First ( & client , "id = ?" , clientID ) .
2025-04-25 12:14:51 -05:00
Error
if err != nil {
2025-07-01 23:14:07 +02:00
if errors . Is ( err , gorm . ErrRecordNotFound ) && isClientAssertion {
return nil , & common . OidcClientAssertionInvalidError { }
}
2025-06-06 03:23:51 -07:00
return nil , err
2025-04-25 12:14:51 -05:00
}
2025-07-01 23:14:07 +02:00
// Validate credentials based on the authentication method
2025-06-06 03:23:51 -07:00
switch {
2025-08-17 01:55:32 +10:00
// First, if we have a client secret, we validate it unless client is marked as public
case input . ClientSecret != "" && ! client . IsPublic :
2025-06-06 03:23:51 -07:00
err = bcrypt . CompareHashAndPassword ( [ ] byte ( client . Secret ) , [ ] byte ( input . ClientSecret ) )
2025-05-19 08:10:33 -07:00
if err != nil {
2025-06-06 03:23:51 -07:00
return nil , & common . OidcClientSecretInvalidError { }
2025-04-25 12:14:51 -05:00
}
2025-07-01 23:14:07 +02:00
return client , nil
2025-06-06 03:23:51 -07:00
// Next, check if we want to use client assertions from federated identities
2025-07-01 23:14:07 +02:00
case isClientAssertion :
err = s . verifyClientAssertionFromFederatedIdentities ( ctx , client , input )
2025-06-06 03:23:51 -07:00
if err != nil {
2025-07-27 06:34:23 +02:00
slog . WarnContext ( ctx , "Invalid assertion for client" , slog . String ( "client" , client . ID ) , slog . Any ( "error" , err ) )
2025-06-06 03:23:51 -07:00
return nil , & common . OidcClientAssertionInvalidError { }
}
2025-07-01 23:14:07 +02:00
return client , nil
2025-06-06 03:23:51 -07:00
// There's no credentials
// This is allowed only if the client is public
2025-07-01 23:14:07 +02:00
case client . IsPublic && allowPublicClientsWithoutAuth :
return client , nil
2025-06-06 03:23:51 -07:00
// If we're here, we have no credentials AND the client is not public, so credentials are required
default :
return nil , & common . OidcMissingClientCredentialsError { }
2025-04-25 12:14:51 -05:00
}
2025-06-06 03:23:51 -07:00
}
2025-04-25 12:14:51 -05:00
2025-06-06 03:23:51 -07:00
// jwkSetForURL returns the cached JWK set for url, registering the URL with the
// JWK cache first if it is not registered yet.
//
// Registration waits until the set is ready, bounded by a 15-second timeout, and
// configures refresh intervals between 15 minutes and 24 hours.
func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set, err error) {
	if registered := s.jwkCache.IsRegistered(ctx, url); !registered {
		// Without a timeout, Register would keep retrying on errors
		registerCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
		defer cancel()

		registerErr := s.jwkCache.Register(
			registerCtx,
			url,
			jwk.WithMaxInterval(24*time.Hour),
			jwk.WithMinInterval(15*time.Minute),
			jwk.WithWaitReady(true),
		)

		// Two goroutines may race to register the same URL; the loser gets an
		// "already exists" error, which is harmless and ignored
		alreadyRegistered := errors.Is(registerErr, httprc.ErrResourceAlreadyExists())
		if registerErr != nil && !alreadyRegistered {
			return nil, fmt.Errorf("failed to register JWK set: %w", registerErr)
		}
	}

	cached, err := s.jwkCache.CachedSet(url)
	if err != nil {
		return nil, fmt.Errorf("failed to get cached JWK set: %w", err)
	}

	return cached, nil
}
2025-06-09 12:17:55 -07:00
func ( s * OidcService ) verifyClientAssertionFromFederatedIdentities ( ctx context . Context , client * model . OidcClient , input ClientAuthCredentials ) error {
2025-06-06 03:23:51 -07:00
// First, parse the assertion JWT, without validating it, to check the issuer
assertion := [ ] byte ( input . ClientAssertion )
insecureToken , err := jwt . ParseInsecure ( assertion )
if err != nil {
return fmt . Errorf ( "failed to parse client assertion JWT: %w" , err )
}
issuer , _ := insecureToken . Issuer ( )
if issuer == "" {
return errors . New ( "client assertion does not contain an issuer claim" )
}
// Ensure that this client is federated with the one that issued the token
ocfi , ok := client . Credentials . FederatedIdentityForIssuer ( issuer )
if ! ok {
return fmt . Errorf ( "client assertion is not from an allowed issuer: %s" , issuer )
}
// Get the JWK set for the issuer
jwksURL := ocfi . JWKS
if jwksURL == "" {
// Default URL is from the issuer
if strings . HasSuffix ( issuer , "/" ) {
jwksURL = issuer + ".well-known/jwks.json"
} else {
jwksURL = issuer + "/.well-known/jwks.json"
}
}
jwks , err := s . jwkSetForURL ( ctx , jwksURL )
if err != nil {
return fmt . Errorf ( "failed to get JWK set for issuer '%s': %w" , issuer , err )
}
// Set default audience and subject if missing
audience := ocfi . Audience
if audience == "" {
// Default to the Pocket ID's URL
audience = common . EnvConfig . AppURL
}
subject := ocfi . Subject
if subject == "" {
// Default to the client ID, per RFC 7523
subject = client . ID
}
// Now re-parse the token with proper validation
// (Note: we don't use jwt.WithIssuer() because that would be redundant)
_ , err = jwt . Parse ( assertion ,
jwt . WithValidate ( true ) ,
jwt . WithAcceptableSkew ( clockSkew ) ,
jwt . WithKeySet ( jwks , jws . WithInferAlgorithmFromKey ( true ) , jws . WithUseDefault ( true ) ) ,
jwt . WithAudience ( audience ) ,
jwt . WithSubject ( subject ) ,
)
if err != nil {
return fmt . Errorf ( "client assertion is not valid: %w" , err )
}
// If we're here, the assertion is valid
return nil
2025-04-25 12:14:51 -05:00
}
2025-06-09 10:46:03 -05:00
2025-07-01 23:14:07 +02:00
// extractClientIDFromAssertion returns the client ID carried in the 'sub' claim of a JWT client assertion.
// The token is parsed without verification, so the returned value has not been authenticated at this point.
// An error is returned when the assertion cannot be parsed or has no (or an empty) 'sub' claim.
func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, error) {
	unverified, err := jwt.ParseInsecure([]byte(assertion))
	if err != nil {
		return "", fmt.Errorf("failed to parse JWT assertion: %w", err)
	}

	// According to RFC 7523, the subject of the assertion must be the client_id
	if clientID, present := unverified.Subject(); present && clientID != "" {
		return clientID, nil
	}
	return "", errors.New("missing or invalid 'sub' claim in JWT assertion")
}
2025-09-09 02:31:50 -07:00
func ( s * OidcService ) GetClientPreview ( ctx context . Context , clientID string , userID string , scopes [ ] string ) ( * dto . OidcClientPreviewDto , error ) {
2025-06-09 10:46:03 -05:00
tx := s . db . Begin ( )
defer func ( ) {
tx . Rollback ( )
} ( )
client , err := s . getClientInternal ( ctx , clientID , tx )
if err != nil {
return nil , err
}
var user model . User
err = tx .
WithContext ( ctx ) .
Preload ( "UserGroups" ) .
First ( & user , "id = ?" , userID ) .
Error
if err != nil {
return nil , err
}
if ! s . IsUserGroupAllowedToAuthorize ( user , client ) {
return nil , & common . OidcAccessDeniedError { }
}
2025-09-09 02:31:50 -07:00
userClaims , err := s . getUserClaims ( ctx , & user , scopes , tx )
2025-06-09 10:46:03 -05:00
if err != nil {
return nil , err
}
2025-06-09 12:17:55 -07:00
// Commit the transaction before signing tokens to avoid locking the database for longer
err = tx . Commit ( ) . Error
2025-06-09 10:46:03 -05:00
if err != nil {
return nil , err
}
2025-06-09 12:17:55 -07:00
idToken , err := s . jwtService . BuildIDToken ( userClaims , clientID , "" )
2025-06-09 10:46:03 -05:00
if err != nil {
return nil , err
}
2025-06-09 12:17:55 -07:00
accessToken , err := s . jwtService . BuildOAuthAccessToken ( user , clientID )
2025-06-09 10:46:03 -05:00
if err != nil {
return nil , err
}
2025-06-09 12:17:55 -07:00
idTokenPayload , err := utils . GetClaimsFromToken ( idToken )
2025-06-09 10:46:03 -05:00
if err != nil {
return nil , err
}
2025-06-09 12:17:55 -07:00
accessTokenPayload , err := utils . GetClaimsFromToken ( accessToken )
2025-06-09 10:46:03 -05:00
if err != nil {
return nil , err
}
return & dto . OidcClientPreviewDto {
IdToken : idTokenPayload ,
AccessToken : accessTokenPayload ,
UserInfo : userClaims ,
} , nil
}
2025-06-09 12:17:55 -07:00
func ( s * OidcService ) GetUserClaimsForClient ( ctx context . Context , userID string , clientID string ) ( map [ string ] any , error ) {
return s . getUserClaimsForClientInternal ( ctx , userID , clientID , s . db )
2025-06-09 10:46:03 -05:00
}
2025-06-09 12:17:55 -07:00
func ( s * OidcService ) getUserClaimsForClientInternal ( ctx context . Context , userID string , clientID string , tx * gorm . DB ) ( map [ string ] any , error ) {
2025-06-09 10:46:03 -05:00
var authorizedOidcClient model . UserAuthorizedOidcClient
err := tx .
WithContext ( ctx ) .
Preload ( "User.UserGroups" ) .
First ( & authorizedOidcClient , "user_id = ? AND client_id = ?" , userID , clientID ) .
Error
if err != nil {
return nil , err
}
2025-09-09 02:31:50 -07:00
return s . getUserClaims ( ctx , & authorizedOidcClient . User , authorizedOidcClient . Scopes ( ) , tx )
2025-06-09 10:46:03 -05:00
}
2025-09-09 02:31:50 -07:00
func ( s * OidcService ) getUserClaims ( ctx context . Context , user * model . User , scopes [ ] string , tx * gorm . DB ) ( map [ string ] any , error ) {
2025-06-09 12:17:55 -07:00
claims := make ( map [ string ] any , 10 )
2025-06-09 10:46:03 -05:00
2025-06-09 12:17:55 -07:00
claims [ "sub" ] = user . ID
2025-06-09 10:46:03 -05:00
if slices . Contains ( scopes , "email" ) {
claims [ "email" ] = user . Email
claims [ "email_verified" ] = s . appConfigService . GetDbConfig ( ) . EmailsVerified . IsTrue ( )
}
if slices . Contains ( scopes , "groups" ) {
userGroups := make ( [ ] string , len ( user . UserGroups ) )
for i , group := range user . UserGroups {
userGroups [ i ] = group . Name
}
claims [ "groups" ] = userGroups
}
if slices . Contains ( scopes , "profile" ) {
// Add custom claims
customClaims , err := s . customClaimService . GetCustomClaimsForUserWithUserGroups ( ctx , user . ID , tx )
if err != nil {
return nil , err
}
for _ , customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string
2025-06-09 12:17:55 -07:00
var jsonValue any
2025-06-09 10:46:03 -05:00
err := json . Unmarshal ( [ ] byte ( customClaim . Value ) , & jsonValue )
if err == nil {
// It's JSON, so we store it as an object
claims [ customClaim . Key ] = jsonValue
} else {
// Marshaling failed, so we store it as a string
claims [ customClaim . Key ] = customClaim . Value
}
}
2025-09-17 10:18:27 -05:00
// Add profile claims
claims [ "given_name" ] = user . FirstName
claims [ "family_name" ] = user . LastName
claims [ "name" ] = user . FullName ( )
claims [ "display_name" ] = user . DisplayName
claims [ "preferred_username" ] = user . Username
claims [ "picture" ] = common . EnvConfig . AppURL + "/api/users/" + user . ID + "/profile-picture.png"
2025-06-09 10:46:03 -05:00
}
if slices . Contains ( scopes , "email" ) {
claims [ "email" ] = user . Email
}
return claims , nil
}
2025-08-10 10:56:03 -05:00
// IsClientAccessibleToUser reports whether the given user may authorize with the given client,
// as decided by IsUserGroupAllowedToAuthorize.
// The user is loaded with its groups and the client with its allowed user groups;
// a failure to load either one is returned as the error.
func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID string, userID string) (bool, error) {
	var user model.User
	if err := s.db.WithContext(ctx).Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil {
		return false, err
	}

	var client model.OidcClient
	if err := s.db.WithContext(ctx).Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil {
		return false, err
	}

	accessible := s.IsUserGroupAllowedToAuthorize(user, client)
	return accessible, nil
}
2025-09-29 10:07:55 -05:00
// downloadAndSaveLogoFromURL downloads the image at the given URL and stores it as the logo of the OIDC client.
//
// The hostname of the URL is resolved first, and the download is refused if any of the resolved
// addresses is private (SSRF protection). The whole operation, including the database update at the end,
// is bound to a 15-second timeout derived from parentCtx. Logos larger than 2MB are rejected, whether or
// not the server announces the size up front.
//
// The file extension is taken from the URL path when it maps to a supported image type, and from the
// response's Content-Type otherwise; common.FileTypeNotSupportedError is returned when neither works.
// On success, the client's logo type is updated using the given database handle.
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string) error {
	u, err := url.Parse(raw)
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
	defer cancel()

	r := net.Resolver{}
	ips, err := r.LookupIPAddr(ctx, u.Hostname())
	if err != nil || len(ips) == 0 {
		return errors.New("cannot resolve hostname")
	}

	// Prevents SSRF by allowing only public IPs
	// NOTE(review): this validates the addresses resolved here; the HTTP client resolves the hostname
	// again when it connects and may follow redirects, depending on how s.httpClient is configured.
	// Confirm that the client enforces the same restriction.
	for _, addr := range ips {
		if utils.IsPrivateIP(addr.IP) {
			return errors.New("private IP addresses are not allowed")
		}
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
	if err != nil {
		return err
	}
	req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
	req.Header.Set("Accept", "image/*")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to fetch logo: %s", resp.Status)
	}

	const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
	// Fast path: reject right away when the server announces a size over the limit.
	// ContentLength is -1 when the size is unknown, so this check alone is not sufficient.
	if resp.ContentLength > maxLogoSize {
		return errors.New("logo is too large")
	}

	// Prefer extension in path if supported
	ext := utils.GetFileExtension(u.Path)
	if ext == "" || utils.GetImageMimeType(ext) == "" {
		// Otherwise, try to detect from content type
		ext = utils.GetImageExtensionFromMimeType(resp.Header.Get("Content-Type"))
	}
	if ext == "" {
		return &common.FileTypeNotSupportedError{}
	}

	imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + ext

	// Read at most maxLogoSize+1 bytes: the extra byte is what tells a body that is exactly at
	// the limit apart from one that exceeds it
	body := &io.LimitedReader{R: resp.Body, N: maxLogoSize + 1}
	err = utils.SaveFileStream(body, imagePath)
	if err != nil {
		return err
	}
	if body.N <= 0 {
		// The allowance was used up, so the body is larger than the limit and what was written is
		// a truncated image. Discard it (best-effort) instead of reporting success.
		_ = os.Remove(imagePath)
		return errors.New("logo is too large")
	}

	return s.updateClientLogoType(ctx, tx, clientID, ext)
}
// updateClientLogoType records ext as the image type of the client's logo, using the given database handle.
// If the client previously had a logo with a different extension, that file is deleted from the
// uploads directory first; the deletion is best-effort and its error is ignored.
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string) error {
	var client model.OidcClient
	err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error
	if err != nil {
		return err
	}

	// A logo stored under another extension is stale once the type changes
	if previous := client.ImageType; previous != nil && *previous != ext {
		imagesDir := common.EnvConfig.UploadPath + "/oidc-client-images"
		_ = os.Remove(imagesDir + "/" + client.ID + "." + *previous)
	}

	client.ImageType = &ext
	return tx.WithContext(ctx).Save(&client).Error
}