fix: ensure indexes on audit_logs table (#415)

Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
This commit is contained in:
Alessandro (Ale) Segala
2025-04-04 10:05:32 -07:00
committed by GitHub
parent 731113183e
commit 9e88926283
9 changed files with 63 additions and 34 deletions

View File

@@ -102,7 +102,7 @@ func (alc *AuditLogController) listAllAuditLogsHandler(c *gin.Context) {
return return
} }
logs, pagination, err := alc.auditLogService.ListAllAuditLogs(sortedPaginationRequest, filters) logs, pagination, err := alc.auditLogService.ListAllAuditLogs(c.Request.Context(), sortedPaginationRequest, filters)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -134,7 +134,7 @@ func (alc *AuditLogController) listAllAuditLogsHandler(c *gin.Context) {
// @Success 200 {array} string "List of client names" // @Success 200 {array} string "List of client names"
// @Router /api/audit-logs/filters/client-names [get] // @Router /api/audit-logs/filters/client-names [get]
func (alc *AuditLogController) listClientNamesHandler(c *gin.Context) { func (alc *AuditLogController) listClientNamesHandler(c *gin.Context) {
names, err := alc.auditLogService.ListClientNames() names, err := alc.auditLogService.ListClientNames(c.Request.Context())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -150,7 +150,7 @@ func (alc *AuditLogController) listClientNamesHandler(c *gin.Context) {
// @Success 200 {object} map[string]string "Map of user IDs to usernames" // @Success 200 {object} map[string]string "Map of user IDs to usernames"
// @Router /api/audit-logs/filters/users [get] // @Router /api/audit-logs/filters/users [get]
func (alc *AuditLogController) listUserNamesWithIdsHandler(c *gin.Context) { func (alc *AuditLogController) listUserNamesWithIdsHandler(c *gin.Context) {
users, err := alc.auditLogService.ListUsernamesWithIds() users, err := alc.auditLogService.ListUsernamesWithIds(c.Request.Context())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
) )
type AuditLog struct { type AuditLog struct {
@@ -34,7 +34,7 @@ const (
// Scan and Value methods for GORM to handle the custom type // Scan and Value methods for GORM to handle the custom type
func (e *AuditLogEvent) Scan(value interface{}) error { func (e *AuditLogEvent) Scan(value any) error {
*e = AuditLogEvent(value.(string)) *e = AuditLogEvent(value.(string))
return nil return nil
} }
@@ -43,11 +43,14 @@ func (e AuditLogEvent) Value() (driver.Value, error) {
return string(e), nil return string(e), nil
} }
func (d *AuditLogData) Scan(value interface{}) error { func (d *AuditLogData) Scan(value any) error {
if v, ok := value.([]byte); ok { switch v := value.(type) {
case []byte:
return json.Unmarshal(v, d) return json.Unmarshal(v, d)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), d)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm" "gorm.io/gorm"
@@ -74,10 +74,13 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
type UrlList []string //nolint:recvcheck type UrlList []string //nolint:recvcheck
func (cu *UrlList) Scan(value interface{}) error { func (cu *UrlList) Scan(value interface{}) error {
if v, ok := value.([]byte); ok { switch v := value.(type) {
case []byte:
return json.Unmarshal(v, cu) return json.Unmarshal(v, cu)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), cu)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
"time" "time"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
@@ -49,11 +49,13 @@ type AuthenticatorTransportList []protocol.AuthenticatorTransport //nolint:recvc
// Scan and Value methods for GORM to handle the custom type // Scan and Value methods for GORM to handle the custom type
func (atl *AuthenticatorTransportList) Scan(value interface{}) error { func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
switch v := value.(type) {
if v, ok := value.([]byte); ok { case []byte:
return json.Unmarshal(v, atl) return json.Unmarshal(v, atl)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), atl)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"fmt" "fmt"
"log" "log"
@@ -100,10 +101,13 @@ func (s *AuditLogService) DeviceStringFromUserAgent(userAgent string) string {
return ua.Name + " on " + ua.OS + " " + ua.OSVersion return ua.Name + " on " + ua.OS + " " + ua.OSVersion
} }
func (s *AuditLogService) ListAllAuditLogs(sortedPaginationRequest utils.SortedPaginationRequest, filters dto.AuditLogFilterDto) ([]model.AuditLog, utils.PaginationResponse, error) { func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, sortedPaginationRequest utils.SortedPaginationRequest, filters dto.AuditLogFilterDto) ([]model.AuditLog, utils.PaginationResponse, error) {
var logs []model.AuditLog var logs []model.AuditLog
query := s.db.Preload("User").Model(&model.AuditLog{}) query := s.db.
WithContext(ctx).
Preload("User").
Model(&model.AuditLog{})
if filters.UserID != "" { if filters.UserID != "" {
query = query.Where("user_id = ?", filters.UserID) query = query.Where("user_id = ?", filters.UserID)
@@ -131,8 +135,11 @@ func (s *AuditLogService) ListAllAuditLogs(sortedPaginationRequest utils.SortedP
return logs, pagination, nil return logs, pagination, nil
} }
func (s *AuditLogService) ListUsernamesWithIds() (users map[string]string, err error) { func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[string]string, err error) {
query := s.db.Joins("User").Model(&model.AuditLog{}). query := s.db.
WithContext(ctx).
Joins("User").
Model(&model.AuditLog{}).
Select("DISTINCT User.id, User.username"). Select("DISTINCT User.id, User.username").
Where("User.username IS NOT NULL") Where("User.username IS NOT NULL")
@@ -146,7 +153,7 @@ func (s *AuditLogService) ListUsernamesWithIds() (users map[string]string, err e
return nil, fmt.Errorf("failed to query user IDs: %w", err) return nil, fmt.Errorf("failed to query user IDs: %w", err)
} }
users = make(map[string]string) users = make(map[string]string, len(results))
for _, result := range results { for _, result := range results {
users[result.ID] = result.Username users[result.ID] = result.Username
} }
@@ -154,25 +161,27 @@ func (s *AuditLogService) ListUsernamesWithIds() (users map[string]string, err e
return users, nil return users, nil
} }
func (s *AuditLogService) ListClientNames() (clientNames []string, err error) { func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []string, err error) {
dialect := s.db.Name() query := s.db.
var query *gorm.DB WithContext(ctx).
Model(&model.AuditLog{})
dialect := s.db.Name()
switch dialect { switch dialect {
case "sqlite": case "sqlite":
query = s.db.Model(&model.AuditLog{}). query = query.
Select("DISTINCT json_extract(data, '$.clientName') as clientName"). Select("DISTINCT json_extract(data, '$.clientName') as client_name").
Where("json_extract(data, '$.clientName') IS NOT NULL") Where("json_extract(data, '$.clientName') IS NOT NULL")
case "postgres": case "postgres":
query = s.db.Model(&model.AuditLog{}). query = query.
Select("DISTINCT data->>'clientName' as clientName"). Select("DISTINCT data->>'clientName' as client_name").
Where("data->>'clientName' IS NOT NULL") Where("data->>'clientName' IS NOT NULL")
default: default:
return nil, fmt.Errorf("unsupported database dialect: %s", dialect) return nil, fmt.Errorf("unsupported database dialect: %s", dialect)
} }
type Result struct { type Result struct {
ClientName string `gorm:"column:clientName"` ClientName string `gorm:"column:client_name"`
} }
var results []Result var results []Result
@@ -180,9 +189,9 @@ func (s *AuditLogService) ListClientNames() (clientNames []string, err error) {
return nil, fmt.Errorf("failed to query client IDs: %w", err) return nil, fmt.Errorf("failed to query client IDs: %w", err)
} }
for _, result := range results { clientNames = make([]string, len(results))
clientNames = append(clientNames, result.ClientName) for i, result := range results {
clientNames[i] = result.ClientName
} }
return clientNames, nil return clientNames, nil

View File

@@ -0,0 +1,3 @@
DROP INDEX IF EXISTS idx_audit_logs_event;
DROP INDEX IF EXISTS idx_audit_logs_user_id;
DROP INDEX IF EXISTS idx_audit_logs_client_name;

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
CREATE INDEX idx_audit_logs_client_name ON audit_logs(("data"->>'clientName'));

View File

@@ -0,0 +1,3 @@
DROP INDEX IF EXISTS idx_audit_logs_event;
DROP INDEX IF EXISTS idx_audit_logs_user_id;
DROP INDEX IF EXISTS idx_audit_logs_client_name;

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));