diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index ef765b27..6c09cb8c 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -534,6 +534,43 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} } + if storedRefreshToken.User.Disabled { + return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} + } + + var authorizedClient model.UserAuthorizedOidcClient + err = tx. + WithContext(ctx). + Where("user_id = ? AND client_id = ?", storedRefreshToken.UserID, input.ClientID). + First(&authorizedClient). + Error + if errors.Is(err, gorm.ErrRecordNotFound) { + err = tx.WithContext(ctx).Delete(&storedRefreshToken).Error + if err != nil { + return CreatedTokens{}, err + } + + err = tx.Commit().Error + if err != nil { + return CreatedTokens{}, err + } + + return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{} + } else if err != nil { + return CreatedTokens{}, err + } + + if client.IsGroupRestricted { + err = tx.WithContext(ctx).Model(client).Association("AllowedUserGroups").Find(&client.AllowedUserGroups) + if err != nil { + return CreatedTokens{}, err + } + } + + if !IsUserGroupAllowedToAuthorize(storedRefreshToken.User, *client) { + return CreatedTokens{}, &common.OidcAccessDeniedError{} + } + // Generate a new access token authenticationMethods := storedRefreshToken.AuthenticationMethod accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID, authenticationMethods) @@ -1500,6 +1537,15 @@ func (s *OidcService) RevokeAuthorizedClient(ctx context.Context, userID string, return err } + err = tx. + WithContext(ctx). + Where("user_id = ? AND client_id = ?", userID, clientID). + Delete(&model.OidcRefreshToken{}). + Error + if err != nil { + return err + } + err = tx.Commit().Error if err != nil { return err diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index c7518e4c..0f807432 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -20,6 +20,7 @@ import ( "github.com/lestrrat-go/jwx/v3/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" "github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/dto" @@ -562,6 +563,140 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { }) } +func TestOidcServiceRefreshTokenAuthorizationState(t *testing.T) { + newFixture := func(t *testing.T, isGroupRestricted bool) (*OidcService, *gorm.DB, model.User, model.OidcClient, string, string, *model.UserGroup) { + t.Helper() + + db := testutils.NewDatabaseForTest(t) + common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef") + + mockConfig := NewTestAppConfigService(&model.AppConfig{ + SessionDuration: model.AppConfigVariable{Value: "60"}, + }) + jwtService, err := NewJwtService(t.Context(), db, mockConfig) + require.NoError(t, err) + + service := &OidcService{ + db: db, + jwtService: jwtService, + appConfigService: mockConfig, + } + + email := "refresh-token-user@example.com" + user := model.User{ + Username: "refresh-token-user", + Email: &email, + EmailVerified: true, + FirstName: "Refresh", + LastName: "User", + } + require.NoError(t, db.Create(&user).Error) + + client, err := service.CreateClient(t.Context(), dto.OidcClientCreateDto{ + OidcClientUpdateDto: dto.OidcClientUpdateDto{ + Name: "Refresh Token Client", + CallbackURLs: []string{"https://example.com/callback"}, + IsGroupRestricted: isGroupRestricted, + }, + }, user.ID) + require.NoError(t, err) + + clientSecret, err := service.CreateClientSecret(t.Context(), client.ID) + require.NoError(t, err) + + var userGroup *model.UserGroup + if isGroupRestricted { + group := model.UserGroup{ + FriendlyName: "Allowed Group", + Name: "allowed-group", + } + require.NoError(t, db.Create(&group).Error) + require.NoError(t, db.Model(&user).Association("UserGroups").Append(&group)) + require.NoError(t, db.Model(&client).Association("AllowedUserGroups").Append(&group)) + userGroup = &group + } + + scope := "openid profile email groups" + require.NoError(t, db.Create(&model.UserAuthorizedOidcClient{ + UserID: user.ID, + ClientID: client.ID, + Scope: scope, + }).Error) + + refreshToken, err := service.createRefreshToken(t.Context(), client.ID, user.ID, scope, AuthenticationMethodPhishingResistant, db) + require.NoError(t, err) + + return service, db, user, client, clientSecret, refreshToken, userGroup + } + + refreshInput := func(client model.OidcClient, clientSecret string, refreshToken string) dto.OidcCreateTokensDto { + return dto.OidcCreateTokensDto{ + GrantType: GrantTypeRefreshToken, + RefreshToken: refreshToken, + ClientID: client.ID, + ClientSecret: clientSecret, + } + } + + t.Run("rejects refresh token after authorization revocation", func(t *testing.T) { + service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false) + + err := service.RevokeAuthorizedClient(t.Context(), user.ID, client.ID) + require.NoError(t, err) + + var refreshTokenCount int64 + require.NoError(t, db.Model(&model.OidcRefreshToken{}). + Where("user_id = ? AND client_id = ?", user.ID, client.ID). + Count(&refreshTokenCount).Error) + assert.Zero(t, refreshTokenCount) + + _, err = service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken)) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{}) + }) + + t.Run("rejects and deletes stale refresh token without authorization record", func(t *testing.T) { + service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false) + + require.NoError(t, db. + Where("user_id = ? AND client_id = ?", user.ID, client.ID). + Delete(&model.UserAuthorizedOidcClient{}).Error) + + _, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken)) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{}) + + var refreshTokenCount int64 + require.NoError(t, db.Model(&model.OidcRefreshToken{}). + Where("user_id = ? AND client_id = ?", user.ID, client.ID). + Count(&refreshTokenCount).Error) + assert.Zero(t, refreshTokenCount) + }) + + t.Run("rejects refresh token for disabled user", func(t *testing.T) { + service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false) + + require.NoError(t, db.Model(&model.User{}). + Where("id = ?", user.ID). + Update("disabled", true).Error) + + _, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken)) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{}) + }) + + t.Run("rejects refresh token after user leaves allowed group", func(t *testing.T) { + service, db, user, client, clientSecret, refreshToken, userGroup := newFixture(t, true) + require.NotNil(t, userGroup) + + require.NoError(t, db.Model(&user).Association("UserGroups").Delete(userGroup)) + + _, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken)) + require.Error(t, err) + require.ErrorIs(t, err, &common.OidcAccessDeniedError{}) + }) +} + func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) { mockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "60"},