fix: auth fails when client IP is empty on Postgres (#695)

This commit is contained in:
Alessandro (Ale) Segala
2025-06-30 05:04:30 -07:00
committed by GitHub
parent dbf3da41f3
commit 031181ad2a
10 changed files with 122 additions and 51 deletions

View File

@@ -1,7 +1,6 @@
package dto package dto
import ( import (
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
) )
@@ -9,14 +8,14 @@ type AuditLogDto struct {
ID string `json:"id"` ID string `json:"id"`
CreatedAt datatype.DateTime `json:"createdAt"` CreatedAt datatype.DateTime `json:"createdAt"`
Event model.AuditLogEvent `json:"event"` Event string `json:"event"`
IpAddress string `json:"ipAddress"` IpAddress string `json:"ipAddress"`
Country string `json:"country"` Country string `json:"country"`
City string `json:"city"` City string `json:"city"`
Device string `json:"device"` Device string `json:"device"`
UserID string `json:"userID"` UserID string `json:"userID"`
Username string `json:"username"` Username string `json:"username"`
Data model.AuditLogData `json:"data"` Data map[string]string `json:"data"`
} }
type AuditLogFilterDto struct { type AuditLogFilterDto struct {

View File

@@ -8,13 +8,13 @@ import (
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
) )
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
// [a-zA-Z0-9] : The username must start with an alphanumeric character // [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols // [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character // [a-zA-Z0-9]$ : The username must end with an alphanumeric character
regex := "^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$" var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
matched, _ := regexp.MatchString(regex, fl.Field().String())
return matched var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
return validateUsernameRegex.MatchString(fl.Field().String())
} }
func init() { func init() {

View File

@@ -10,7 +10,7 @@ type AuditLog struct {
Base Base
Event AuditLogEvent `sortable:"true"` Event AuditLogEvent `sortable:"true"`
IpAddress string `sortable:"true"` IpAddress *string `sortable:"true"`
Country string `sortable:"true"` Country string `sortable:"true"`
City string `sortable:"true"` City string `sortable:"true"`
UserAgent string `sortable:"true"` UserAgent string `sortable:"true"`

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"log/slog"
userAgentParser "github.com/mileusna/useragent" userAgentParser "github.com/mileusna/useragent"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -25,15 +26,15 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
} }
// Create creates a new audit log entry in the database // Create creates a new audit log entry in the database
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog { func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) (model.AuditLog, bool) {
country, city, err := s.geoliteService.GetLocationByIP(ipAddress) country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
if err != nil { if err != nil {
log.Printf("Failed to get IP location: %v", err) // Log the error but don't interrupt the operation
slog.Warn("Failed to get IP location", "error", err)
} }
auditLog := model.AuditLog{ auditLog := model.AuditLog{
Event: event, Event: event,
IpAddress: ipAddress,
Country: country, Country: country,
City: city, City: city,
UserAgent: userAgent, UserAgent: userAgent,
@@ -41,22 +42,31 @@ func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent,
Data: data, Data: data,
} }
if ipAddress != "" {
// Only set ipAddress if not empty, because on Postgres we use INET columns that don't allow non-null empty values
auditLog.IpAddress = &ipAddress
}
// Save the audit log in the database // Save the audit log in the database
err = tx. err = tx.
WithContext(ctx). WithContext(ctx).
Create(&auditLog). Create(&auditLog).
Error Error
if err != nil { if err != nil {
log.Printf("Failed to create audit log: %v", err) slog.Error("Failed to create audit log", "error", err)
return model.AuditLog{} return model.AuditLog{}, false
} }
return auditLog return auditLog, true
} }
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before // CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog { func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx) createdAuditLog, ok := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
if !ok {
// At this point the transaction has been canceled already, and error has been logged
return createdAuditLog
}
// Count the number of times the user has logged in from the same device // Count the number of times the user has logged in from the same device
var count int64 var count int64
@@ -67,7 +77,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
Count(&count). Count(&count).
Error Error
if err != nil { if err != nil {
log.Printf("Failed to count audit logs: %v\n", err) log.Printf("Failed to count audit logs: %v", err)
return createdAuditLog return createdAuditLog
} }

View File

@@ -122,6 +122,10 @@ func (s *GeoLiteService) DisableUpdater() bool {
// GetLocationByIP returns the country and city of the given IP address. // GetLocationByIP returns the country and city of the given IP address.
func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string, err error) { func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string, err error) {
if ipAddress == "" {
return "", "", nil
}
// Check the IP address against known private IP ranges // Check the IP address against known private IP ranges
if ip := net.ParseIP(ipAddress); ip != nil { if ip := net.ParseIP(ipAddress); ip != nil {
// Check IPv6 local ranges first // Check IPv6 local ranges first
@@ -147,6 +151,11 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
} }
} }
addr, err := netip.ParseAddr(ipAddress)
if err != nil {
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
}
// Race condition between reading and writing the database. // Race condition between reading and writing the database.
s.mutex.RLock() s.mutex.RLock()
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
@@ -157,11 +166,6 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
} }
defer db.Close() defer db.Close()
addr, err := netip.ParseAddr(ipAddress)
if err != nil {
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
}
var record struct { var record struct {
City struct { City struct {
Names map[string]string `maxminddb:"names"` Names map[string]string `maxminddb:"names"`

View File

@@ -6,6 +6,8 @@ import (
"testing" "testing"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestGeoLiteService_IPv6LocalRanges(t *testing.T) { func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
@@ -80,15 +82,9 @@ func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
t.Errorf("Expected error or internal network classification for external IP") t.Errorf("Expected error or internal network classification for external IP")
} }
} else { } else {
if err != nil { require.NoError(t, err)
t.Errorf("Expected no error for local IP, got: %v", err) assert.Equal(t, tt.expectedCountry, country)
} assert.Equal(t, tt.expectedCity, city)
if country != tt.expectedCountry {
t.Errorf("Expected country %s, got %s", tt.expectedCountry, country)
}
if city != tt.expectedCity {
t.Errorf("Expected city %s, got %s", tt.expectedCity, city)
}
} }
}) })
} }
@@ -148,9 +144,7 @@ func TestGeoLiteService_isLocalIPv6(t *testing.T) {
} }
result := service.isLocalIPv6(ip) result := service.isLocalIPv6(ip)
if result != tt.expected { assert.Equal(t, tt.expected, result)
t.Errorf("Expected %v, got %v for IP %s", tt.expected, result, tt.testIP)
}
}) })
} }
} }
@@ -214,18 +208,13 @@ func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
err := service.initializeIPv6LocalRanges() err := service.initializeIPv6LocalRanges()
if tt.expectError && err == nil { if tt.expectError {
t.Errorf("Expected error but got none") require.Error(t, err)
} } else {
if !tt.expectError && err != nil { require.NoError(t, err)
t.Errorf("Expected no error but got: %v", err)
} }
rangeCount := len(service.localIPv6Ranges) assert.Len(t, service.localIPv6Ranges, tt.expectCount)
if rangeCount != tt.expectCount {
t.Errorf("Expected %d ranges, got %d", tt.expectCount, rangeCount)
}
}) })
} }
} }

View File

@@ -0,0 +1,4 @@
ALTER TABLE audit_logs ALTER COLUMN ip_address SET NOT NULL;
DROP INDEX IF EXISTS idx_audit_logs_created_at;
DROP INDEX IF EXISTS idx_audit_logs_user_agent;

View File

@@ -0,0 +1,5 @@
ALTER TABLE audit_logs ALTER COLUMN ip_address DROP NOT NULL;
-- Add missing indexes
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);

View File

@@ -0,0 +1,30 @@
-- Re-create the table with non-nullable ip_address
-- We then move the data and rename the table
CREATE TABLE audit_logs_new
(
id TEXT NOT NULL PRIMARY KEY,
created_at DATETIME,
event TEXT NOT NULL,
ip_address TEXT NOT NULL,
user_agent TEXT NOT NULL,
data BLOB NOT NULL,
user_id TEXT REFERENCES users,
country TEXT,
city TEXT
);
INSERT INTO audit_logs_new
SELECT id, created_at, event, ip_address, user_agent, data, user_id, country, city
FROM audit_logs;
DROP TABLE audit_logs;
ALTER TABLE audit_logs_new RENAME TO audit_logs;
-- Re-create indexes
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));
CREATE INDEX idx_audit_logs_country ON audit_logs(country);

View File

@@ -0,0 +1,30 @@
-- Re-create the table with nullable ip_address
-- We then move the data and rename the table
CREATE TABLE audit_logs_new
(
id TEXT NOT NULL PRIMARY KEY,
created_at DATETIME,
event TEXT NOT NULL,
ip_address TEXT,
user_agent TEXT NOT NULL,
data BLOB NOT NULL,
user_id TEXT REFERENCES users,
country TEXT,
city TEXT
);
INSERT INTO audit_logs_new
SELECT id, created_at, event, ip_address, user_agent, data, user_id, country, city
FROM audit_logs;
DROP TABLE audit_logs;
ALTER TABLE audit_logs_new RENAME TO audit_logs;
-- Re-create indexes
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));
CREATE INDEX idx_audit_logs_country ON audit_logs(country);