mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-30 09:15:42 +03:00
fix: use transactions when operations involve multiple database queries (#392)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
committed by
GitHub
parent
c810fec8c4
commit
ec626ee797
@@ -1,16 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type WebAuthnService struct {
|
||||
@@ -43,15 +46,31 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
|
||||
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService}
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
||||
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
s.updateWebAuthnConfig()
|
||||
|
||||
var user model.User
|
||||
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("Credentials").
|
||||
Find(&user, "id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
|
||||
options, session, err := s.webAuthn.BeginRegistration(
|
||||
&user,
|
||||
webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
|
||||
webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,7 +81,16 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&sessionToStore).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -73,9 +101,19 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
|
||||
func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var storedSession model.WebauthnSession
|
||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&storedSession, "id = ?", sessionID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
@@ -86,7 +124,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
|
||||
}
|
||||
|
||||
var user model.User
|
||||
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&user, "id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
@@ -108,7 +150,16 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
|
||||
BackupEligible: credential.Flags.BackupEligible,
|
||||
BackupState: credential.Flags.BackupState,
|
||||
}
|
||||
if err := s.db.Create(&credentialToStore).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&credentialToStore).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
@@ -125,7 +176,7 @@ func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string {
|
||||
return "New Passkey" // Default fallback
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
|
||||
func (s *WebAuthnService) BeginLogin(ctx context.Context) (*model.PublicKeyCredentialRequestOptions, error) {
|
||||
options, session, err := s.webAuthn.BeginDiscoverableLogin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -137,7 +188,11 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||
err = s.db.
|
||||
WithContext(ctx).
|
||||
Create(&sessionToStore).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -148,9 +203,19 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
|
||||
func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var storedSession model.WebauthnSession
|
||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&storedSession, "id = ?", sessionID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
@@ -160,9 +225,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
|
||||
}
|
||||
|
||||
var user *model.User
|
||||
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
||||
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
|
||||
return nil, err
|
||||
_, err = s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
||||
innerErr := tx.
|
||||
WithContext(ctx).
|
||||
Preload("Credentials").
|
||||
First(&user, "id = ?", string(userHandle)).
|
||||
Error
|
||||
if innerErr != nil {
|
||||
return nil, innerErr
|
||||
}
|
||||
return user, nil
|
||||
}, session, credentialAssertionData)
|
||||
@@ -176,41 +246,70 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
s.auditLogService.CreateNewSignInWithEmail(ipAddress, userAgent, user.ID)
|
||||
s.auditLogService.CreateNewSignInWithEmail(ctx, ipAddress, userAgent, user.ID, tx)
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
return *user, token, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
|
||||
func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([]model.WebauthnCredential, error) {
|
||||
var credentials []model.WebauthnCredential
|
||||
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Find(&credentials, "user_id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
|
||||
var credential model.WebauthnCredential
|
||||
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.Delete(&credential).Error; err != nil {
|
||||
return err
|
||||
func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Where("id = ? AND user_id = ?", credentialID, userID).
|
||||
Delete(&model.WebauthnCredential{}).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
|
||||
func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credentialID, name string) (model.WebauthnCredential, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
// This is a no-op if the transaction has been committed already
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var credential model.WebauthnCredential
|
||||
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Where("id = ? AND user_id = ?", credentialID, userID).
|
||||
First(&credential).
|
||||
Error
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
credential.Name = name
|
||||
|
||||
if err := s.db.Save(&credential).Error; err != nil {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&credential).
|
||||
Error
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user