mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 10:13:01 +03:00
feat: allow introspection and device code endpoints to use Federated Client Credentials (#640)
This commit is contained in:
committed by
GitHub
parent
df5c1ed1f8
commit
b62b61fb01
@@ -40,6 +40,9 @@ const (
|
||||
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
|
||||
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
|
||||
|
||||
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
|
||||
DeviceCodeDuration = 15 * time.Minute
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
@@ -252,7 +255,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -303,7 +306,7 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(deviceAuth.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -333,7 +336,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -375,7 +378,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -406,22 +409,39 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
return CreatedTokens{}, &common.OidcMissingRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
|
||||
userID, clientID, rt, err := s.jwtService.VerifyOAuthRefreshToken(input.RefreshToken)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
_, err := s.verifyClientCredentialsInternal(ctx, tx, input)
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, tx, clientAuthCredentialsFromCreateTokensDto(&input))
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// The ID of the client that made the call must match the client ID in the token
|
||||
if client.ID != clientID {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Verify refresh token
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(input.RefreshToken), datatype.DateTime(time.Now())).
|
||||
Where(
|
||||
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
||||
utils.CreateSha256Hash(rt),
|
||||
datatype.DateTime(time.Now()),
|
||||
userID,
|
||||
input.ClientID,
|
||||
).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -437,7 +457,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -469,33 +489,69 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
func (s *OidcService) IntrospectToken(ctx context.Context, creds ClientAuthCredentials, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
// Get the type of the token and the client ID
|
||||
tokenType, token, err := s.jwtService.GetTokenType(tokenString)
|
||||
if err != nil {
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// If we don't have a client ID, get it from the token
|
||||
// Otherwise, we need to make sure that the client ID passed as credential matches
|
||||
tokenAudiences, _ := token.Audience()
|
||||
if len(tokenAudiences) != 1 || tokenAudiences[0] == "" {
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
if creds.ClientID == "" {
|
||||
creds.ClientID = tokenAudiences[0]
|
||||
} else if creds.ClientID != tokenAudiences[0] {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
_, err = s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
})
|
||||
// Verify the credentials for the call
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, creds)
|
||||
if err != nil {
|
||||
return introspectDto, err
|
||||
}
|
||||
|
||||
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ParseError()) {
|
||||
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
|
||||
return s.introspectRefreshToken(ctx, tokenString)
|
||||
}
|
||||
|
||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||
// Introspect the token
|
||||
switch tokenType {
|
||||
case OAuthAccessTokenJWTType:
|
||||
return s.introspectAccessToken(client.ID, tokenString)
|
||||
case OAuthRefreshTokenJWTType:
|
||||
return s.introspectRefreshToken(ctx, client.ID, tokenString)
|
||||
default:
|
||||
// We just treat the token as invalid
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) introspectAccessToken(clientID string, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
token, err := s.jwtService.VerifyOAuthAccessToken(tokenString)
|
||||
if err != nil {
|
||||
// Every failure we get means the token is invalid. Nothing more to do with the error.
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil //nolint:nilerr
|
||||
}
|
||||
|
||||
// The ID of the client that made the request must match the client ID in the token
|
||||
audience, ok := token.Audience()
|
||||
if !ok || len(audience) != 1 || audience[0] == "" {
|
||||
introspectDto.Active = false
|
||||
return introspectDto, nil
|
||||
}
|
||||
if audience[0] != clientID {
|
||||
return introspectDto, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
introspectDto.Active = true
|
||||
introspectDto.TokenType = "access_token"
|
||||
introspectDto.Audience = audience
|
||||
if token.Has("scope") {
|
||||
var (
|
||||
asString string
|
||||
@@ -519,9 +575,6 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
if subject, ok := token.Subject(); ok {
|
||||
introspectDto.Subject = subject
|
||||
}
|
||||
if audience, ok := token.Audience(); ok {
|
||||
introspectDto.Audience = audience
|
||||
}
|
||||
if issuer, ok := token.Issuer(); ok {
|
||||
introspectDto.Issuer = issuer
|
||||
}
|
||||
@@ -532,12 +585,29 @@ func (s *OidcService) IntrospectToken(ctx context.Context, clientID, clientSecre
|
||||
return introspectDto, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) introspectRefreshToken(ctx context.Context, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID string, refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
|
||||
// Validate the signed refresh token and extract the actual token (which is a claim in the signed one)
|
||||
tokenUserID, tokenClientID, tokenRT, err := s.jwtService.VerifyOAuthRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return introspectDto, fmt.Errorf("invalid refresh token: %w", err)
|
||||
}
|
||||
|
||||
// The ID of the client that made the call must match the client ID in the token
|
||||
if tokenClientID != clientID {
|
||||
return introspectDto, errors.New("invalid refresh token: client ID does not match")
|
||||
}
|
||||
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Preload("User").
|
||||
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
|
||||
Where(
|
||||
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
||||
utils.CreateSha256Hash(tokenRT),
|
||||
datatype.DateTime(time.Now()),
|
||||
tokenUserID,
|
||||
tokenClientID,
|
||||
).
|
||||
First(&storedRefreshToken).
|
||||
Error
|
||||
if err != nil {
|
||||
@@ -1062,9 +1132,11 @@ func (s *OidcService) addCallbackURLToClient(ctx context.Context, client *model.
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.OidcDeviceAuthorizationRequestDto) (*dto.OidcDeviceAuthorizationResponseDto, error) {
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, dto.OidcCreateTokensDto{
|
||||
ClientID: input.ClientID,
|
||||
ClientSecret: input.ClientSecret,
|
||||
client, err := s.verifyClientCredentialsInternal(ctx, s.db, ClientAuthCredentials{
|
||||
ClientID: input.ClientID,
|
||||
ClientSecret: input.ClientSecret,
|
||||
ClientAssertionType: input.ClientAssertionType,
|
||||
ClientAssertion: input.ClientAssertion,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1085,7 +1157,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
Scope: input.Scope,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(DeviceCodeDuration)),
|
||||
IsAuthorized: false,
|
||||
ClientID: client.ID,
|
||||
}
|
||||
@@ -1099,7 +1171,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
UserCode: userCode,
|
||||
VerificationURI: common.EnvConfig.AppURL + "/device",
|
||||
VerificationURIComplete: common.EnvConfig.AppURL + "/device?code=" + userCode,
|
||||
ExpiresIn: 900, // 15 minutes
|
||||
ExpiresIn: int(DeviceCodeDuration.Seconds()),
|
||||
Interval: 5,
|
||||
}, nil
|
||||
}
|
||||
@@ -1255,7 +1327,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
|
||||
|
||||
m := model.OidcRefreshToken{
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
|
||||
Token: refreshTokenHash,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
@@ -1270,7 +1342,13 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
return "", err
|
||||
}
|
||||
|
||||
return refreshToken, nil
|
||||
// Sign the refresh token
|
||||
signed, err := s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign refresh token: %w", err)
|
||||
}
|
||||
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID string, clientID string, scope string, tx *gorm.DB) error {
|
||||
@@ -1291,7 +1369,23 @@ func (s *OidcService) createAuthorizedClientInternal(ctx context.Context, userID
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input dto.OidcCreateTokensDto) (*model.OidcClient, error) {
|
||||
type ClientAuthCredentials struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ClientAssertion string
|
||||
ClientAssertionType string
|
||||
}
|
||||
|
||||
func clientAuthCredentialsFromCreateTokensDto(d *dto.OidcCreateTokensDto) ClientAuthCredentials {
|
||||
return ClientAuthCredentials{
|
||||
ClientID: d.ClientID,
|
||||
ClientSecret: d.ClientSecret,
|
||||
ClientAssertion: d.ClientAssertion,
|
||||
ClientAssertionType: d.ClientAssertionType,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientCredentialsInternal(ctx context.Context, tx *gorm.DB, input ClientAuthCredentials) (*model.OidcClient, error) {
|
||||
// First, ensure we have a valid client ID
|
||||
if input.ClientID == "" {
|
||||
return nil, &common.OidcMissingClientCredentialsError{}
|
||||
@@ -1366,7 +1460,7 @@ func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input dto.OidcCreateTokensDto) error {
|
||||
func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.Context, client *model.OidcClient, input ClientAuthCredentials) error {
|
||||
// First, parse the assertion JWT, without validating it, to check the issuer
|
||||
assertion := []byte(input.ClientAssertion)
|
||||
insecureToken, err := jwt.ParseInsecure(assertion)
|
||||
@@ -1466,12 +1560,18 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Commit the transaction before signing tokens to avoid locking the database for longer
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID)
|
||||
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1486,11 +1586,6 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.OidcClientPreviewDto{
|
||||
IdToken: idTokenPayload,
|
||||
AccessToken: accessTokenPayload,
|
||||
@@ -1498,26 +1593,11 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]any, error) {
|
||||
return s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]any, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
@@ -1532,14 +1612,13 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
|
||||
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
|
||||
user := authorizedClient.User
|
||||
scopes := strings.Split(authorizedClient.Scope, " ")
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
claims := make(map[string]any, 10)
|
||||
|
||||
claims["sub"] = user.ID
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
|
||||
@@ -1553,19 +1632,13 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
|
||||
claims["groups"] = userGroups
|
||||
}
|
||||
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"name": user.FullName(),
|
||||
"preferred_username": user.Username,
|
||||
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "profile") {
|
||||
// Add profile claims
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
claims["given_name"] = user.FirstName
|
||||
claims["family_name"] = user.LastName
|
||||
claims["name"] = user.FullName()
|
||||
claims["preferred_username"] = user.Username
|
||||
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
|
||||
@@ -1575,7 +1648,7 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
|
||||
|
||||
for _, customClaim := range customClaims {
|
||||
// The value of the custom claim can be a JSON object or a string
|
||||
var jsonValue interface{}
|
||||
var jsonValue any
|
||||
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
|
||||
if err == nil {
|
||||
// It's JSON, so we store it as an object
|
||||
|
||||
Reference in New Issue
Block a user