fix: use transactions when operations involve multiple database queries (#392)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-06 06:04:08 -07:00
committed by GitHub
parent c810fec8c4
commit ec626ee797
33 changed files with 1401 additions and 501 deletions

View File

@@ -1,8 +1,8 @@
package service
import (
"context"
"errors"
"log"
"time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
@@ -12,6 +12,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type ApiKeyService struct {
@@ -22,8 +23,11 @@ func NewApiKeyService(db *gorm.DB) *ApiKeyService {
return &ApiKeyService{db: db}
}
func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{})
func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.
WithContext(ctx).
Where("user_id = ?", userID).
Model(&model.ApiKey{})
var apiKeys []model.ApiKey
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
@@ -34,7 +38,7 @@ func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils
return apiKeys, pagination, nil
}
func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
// Check if expiration is in the future
if !input.ExpiresAt.ToTime().After(time.Now()) {
return model.ApiKey{}, "", &common.APIKeyExpirationDateError{}
@@ -54,7 +58,11 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
UserID: userID,
}
if err := s.db.Create(&apiKey).Error; err != nil {
err = s.db.
WithContext(ctx).
Create(&apiKey).
Error
if err != nil {
return model.ApiKey{}, "", err
}
@@ -62,29 +70,44 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
return apiKey, token, nil
}
func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error {
func (s *ApiKeyService) RevokeApiKey(ctx context.Context, userID, apiKeyID string) error {
var apiKey model.ApiKey
if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil {
err := s.db.
WithContext(ctx).
Where("id = ? AND user_id = ?", apiKeyID, userID).
Delete(&apiKey).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return &common.APIKeyNotFoundError{}
}
return err
}
return s.db.Delete(&apiKey).Error
return nil
}
func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
func (s *ApiKeyService) ValidateApiKey(ctx context.Context, apiKey string) (model.User, error) {
if apiKey == "" {
return model.User{}, &common.NoAPIKeyProvidedError{}
}
var key model.ApiKey
now := time.Now()
hashedKey := utils.CreateSha256Hash(apiKey)
if err := s.db.Preload("User").Where("key = ? AND expires_at > ?",
hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil {
var key model.ApiKey
err := s.db.
WithContext(ctx).
Model(&model.ApiKey{}).
Clauses(clause.Returning{}).
Where("key = ? AND expires_at > ?", hashedKey, datatype.DateTime(now)).
Updates(&model.ApiKey{
LastUsedAt: utils.Ptr(datatype.DateTime(now)),
}).
Preload("User").
First(&key).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, &common.InvalidAPIKeyError{}
}
@@ -92,12 +115,5 @@ func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
return model.User{}, err
}
// Update last used time
now := datatype.DateTime(time.Now())
key.LastUsedAt = &now
if err := s.db.Save(&key).Error; err != nil {
log.Printf("Failed to update last used time: %v", err)
}
return key.User, nil
}