mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-04 18:00:38 +03:00
fix: access token renewal bypasses important checks
This commit is contained in:
@@ -534,6 +534,43 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
if storedRefreshToken.User.Disabled {
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
var authorizedClient model.UserAuthorizedOidcClient
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("user_id = ? AND client_id = ?", storedRefreshToken.UserID, input.ClientID).
|
||||
First(&authorizedClient).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
err = tx.WithContext(ctx).Delete(&storedRefreshToken).Error
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||
} else if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
if client.IsGroupRestricted {
|
||||
err = tx.WithContext(ctx).Model(client).Association("AllowedUserGroups").Find(&client.AllowedUserGroups)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if !IsUserGroupAllowedToAuthorize(storedRefreshToken.User, *client) {
|
||||
return CreatedTokens{}, &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
authenticationMethods := storedRefreshToken.AuthenticationMethod
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID, authenticationMethods)
|
||||
@@ -1500,6 +1537,15 @@ func (s *OidcService) RevokeAuthorizedClient(ctx context.Context, userID string,
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("user_id = ? AND client_id = ?", userID, clientID).
|
||||
Delete(&model.OidcRefreshToken{}).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
@@ -562,6 +563,140 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcServiceRefreshTokenAuthorizationState(t *testing.T) {
|
||||
newFixture := func(t *testing.T, isGroupRestricted bool) (*OidcService, *gorm.DB, model.User, model.OidcClient, string, string, *model.UserGroup) {
|
||||
t.Helper()
|
||||
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef")
|
||||
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
SessionDuration: model.AppConfigVariable{Value: "60"},
|
||||
})
|
||||
jwtService, err := NewJwtService(t.Context(), db, mockConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
service := &OidcService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
appConfigService: mockConfig,
|
||||
}
|
||||
|
||||
email := "refresh-token-user@example.com"
|
||||
user := model.User{
|
||||
Username: "refresh-token-user",
|
||||
Email: &email,
|
||||
EmailVerified: true,
|
||||
FirstName: "Refresh",
|
||||
LastName: "User",
|
||||
}
|
||||
require.NoError(t, db.Create(&user).Error)
|
||||
|
||||
client, err := service.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||
OidcClientUpdateDto: dto.OidcClientUpdateDto{
|
||||
Name: "Refresh Token Client",
|
||||
CallbackURLs: []string{"https://example.com/callback"},
|
||||
IsGroupRestricted: isGroupRestricted,
|
||||
},
|
||||
}, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientSecret, err := service.CreateClientSecret(t.Context(), client.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var userGroup *model.UserGroup
|
||||
if isGroupRestricted {
|
||||
group := model.UserGroup{
|
||||
FriendlyName: "Allowed Group",
|
||||
Name: "allowed-group",
|
||||
}
|
||||
require.NoError(t, db.Create(&group).Error)
|
||||
require.NoError(t, db.Model(&user).Association("UserGroups").Append(&group))
|
||||
require.NoError(t, db.Model(&client).Association("AllowedUserGroups").Append(&group))
|
||||
userGroup = &group
|
||||
}
|
||||
|
||||
scope := "openid profile email groups"
|
||||
require.NoError(t, db.Create(&model.UserAuthorizedOidcClient{
|
||||
UserID: user.ID,
|
||||
ClientID: client.ID,
|
||||
Scope: scope,
|
||||
}).Error)
|
||||
|
||||
refreshToken, err := service.createRefreshToken(t.Context(), client.ID, user.ID, scope, AuthenticationMethodPhishingResistant, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
return service, db, user, client, clientSecret, refreshToken, userGroup
|
||||
}
|
||||
|
||||
refreshInput := func(client model.OidcClient, clientSecret string, refreshToken string) dto.OidcCreateTokensDto {
|
||||
return dto.OidcCreateTokensDto{
|
||||
GrantType: GrantTypeRefreshToken,
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: client.ID,
|
||||
ClientSecret: clientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("rejects refresh token after authorization revocation", func(t *testing.T) {
|
||||
service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false)
|
||||
|
||||
err := service.RevokeAuthorizedClient(t.Context(), user.ID, client.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var refreshTokenCount int64
|
||||
require.NoError(t, db.Model(&model.OidcRefreshToken{}).
|
||||
Where("user_id = ? AND client_id = ?", user.ID, client.ID).
|
||||
Count(&refreshTokenCount).Error)
|
||||
assert.Zero(t, refreshTokenCount)
|
||||
|
||||
_, err = service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{})
|
||||
})
|
||||
|
||||
t.Run("rejects and deletes stale refresh token without authorization record", func(t *testing.T) {
|
||||
service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false)
|
||||
|
||||
require.NoError(t, db.
|
||||
Where("user_id = ? AND client_id = ?", user.ID, client.ID).
|
||||
Delete(&model.UserAuthorizedOidcClient{}).Error)
|
||||
|
||||
_, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{})
|
||||
|
||||
var refreshTokenCount int64
|
||||
require.NoError(t, db.Model(&model.OidcRefreshToken{}).
|
||||
Where("user_id = ? AND client_id = ?", user.ID, client.ID).
|
||||
Count(&refreshTokenCount).Error)
|
||||
assert.Zero(t, refreshTokenCount)
|
||||
})
|
||||
|
||||
t.Run("rejects refresh token for disabled user", func(t *testing.T) {
|
||||
service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false)
|
||||
|
||||
require.NoError(t, db.Model(&model.User{}).
|
||||
Where("id = ?", user.ID).
|
||||
Update("disabled", true).Error)
|
||||
|
||||
_, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{})
|
||||
})
|
||||
|
||||
t.Run("rejects refresh token after user leaves allowed group", func(t *testing.T) {
|
||||
service, db, user, client, clientSecret, refreshToken, userGroup := newFixture(t, true)
|
||||
require.NotNil(t, userGroup)
|
||||
|
||||
require.NoError(t, db.Model(&user).Association("UserGroups").Delete(userGroup))
|
||||
|
||||
_, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, &common.OidcAccessDeniedError{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) {
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
SessionDuration: model.AppConfigVariable{Value: "60"},
|
||||
|
||||
Reference in New Issue
Block a user