mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-09 14:52:57 +03:00
fix: auth fails when client IP is empty on Postgres (#695)
This commit is contained in:
committed by
GitHub
parent
dbf3da41f3
commit
031181ad2a
@@ -1,7 +1,6 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
)
|
||||
|
||||
@@ -9,14 +8,14 @@ type AuditLogDto struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt datatype.DateTime `json:"createdAt"`
|
||||
|
||||
Event model.AuditLogEvent `json:"event"`
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Country string `json:"country"`
|
||||
City string `json:"city"`
|
||||
Device string `json:"device"`
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
Data model.AuditLogData `json:"data"`
|
||||
Event string `json:"event"`
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Country string `json:"country"`
|
||||
City string `json:"city"`
|
||||
Device string `json:"device"`
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
Data map[string]string `json:"data"`
|
||||
}
|
||||
|
||||
type AuditLogFilterDto struct {
|
||||
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
||||
|
||||
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||
regex := "^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$"
|
||||
matched, _ := regexp.MatchString(regex, fl.Field().String())
|
||||
return matched
|
||||
return validateUsernameRegex.MatchString(fl.Field().String())
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -10,7 +10,7 @@ type AuditLog struct {
|
||||
Base
|
||||
|
||||
Event AuditLogEvent `sortable:"true"`
|
||||
IpAddress string `sortable:"true"`
|
||||
IpAddress *string `sortable:"true"`
|
||||
Country string `sortable:"true"`
|
||||
City string `sortable:"true"`
|
||||
UserAgent string `sortable:"true"`
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
userAgentParser "github.com/mileusna/useragent"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||
@@ -25,15 +26,15 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
|
||||
}
|
||||
|
||||
// Create creates a new audit log entry in the database
|
||||
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
|
||||
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) (model.AuditLog, bool) {
|
||||
country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get IP location: %v", err)
|
||||
// Log the error but don't interrupt the operation
|
||||
slog.Warn("Failed to get IP location", "error", err)
|
||||
}
|
||||
|
||||
auditLog := model.AuditLog{
|
||||
Event: event,
|
||||
IpAddress: ipAddress,
|
||||
Country: country,
|
||||
City: city,
|
||||
UserAgent: userAgent,
|
||||
@@ -41,22 +42,31 @@ func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
if ipAddress != "" {
|
||||
// Only set ipAddress if not empty, because on Postgres we use INET columns that don't allow non-null empty values
|
||||
auditLog.IpAddress = &ipAddress
|
||||
}
|
||||
|
||||
// Save the audit log in the database
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Create(&auditLog).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to create audit log: %v", err)
|
||||
return model.AuditLog{}
|
||||
slog.Error("Failed to create audit log", "error", err)
|
||||
return model.AuditLog{}, false
|
||||
}
|
||||
|
||||
return auditLog
|
||||
return auditLog, true
|
||||
}
|
||||
|
||||
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
|
||||
func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
|
||||
createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
|
||||
createdAuditLog, ok := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
|
||||
if !ok {
|
||||
// At this point the transaction has been canceled already, and error has been logged
|
||||
return createdAuditLog
|
||||
}
|
||||
|
||||
// Count the number of times the user has logged in from the same device
|
||||
var count int64
|
||||
@@ -67,7 +77,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
|
||||
Count(&count).
|
||||
Error
|
||||
if err != nil {
|
||||
log.Printf("Failed to count audit logs: %v\n", err)
|
||||
log.Printf("Failed to count audit logs: %v", err)
|
||||
return createdAuditLog
|
||||
}
|
||||
|
||||
|
||||
@@ -122,6 +122,10 @@ func (s *GeoLiteService) DisableUpdater() bool {
|
||||
|
||||
// GetLocationByIP returns the country and city of the given IP address.
|
||||
func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string, err error) {
|
||||
if ipAddress == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
// Check the IP address against known private IP ranges
|
||||
if ip := net.ParseIP(ipAddress); ip != nil {
|
||||
// Check IPv6 local ranges first
|
||||
@@ -147,6 +151,11 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
|
||||
}
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
// Race condition between reading and writing the database.
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
@@ -157,11 +166,6 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
addr, err := netip.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
|
||||
}
|
||||
|
||||
var record struct {
|
||||
City struct {
|
||||
Names map[string]string `maxminddb:"names"`
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
|
||||
@@ -80,15 +82,9 @@ func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
|
||||
t.Errorf("Expected error or internal network classification for external IP")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for local IP, got: %v", err)
|
||||
}
|
||||
if country != tt.expectedCountry {
|
||||
t.Errorf("Expected country %s, got %s", tt.expectedCountry, country)
|
||||
}
|
||||
if city != tt.expectedCity {
|
||||
t.Errorf("Expected city %s, got %s", tt.expectedCity, city)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedCountry, country)
|
||||
assert.Equal(t, tt.expectedCity, city)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -148,9 +144,7 @@ func TestGeoLiteService_isLocalIPv6(t *testing.T) {
|
||||
}
|
||||
|
||||
result := service.isLocalIPv6(ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v for IP %s", tt.expected, result, tt.testIP)
|
||||
}
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -214,18 +208,13 @@ func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
|
||||
|
||||
err := service.initializeIPv6LocalRanges()
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error but got: %v", err)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
rangeCount := len(service.localIPv6Ranges)
|
||||
|
||||
if rangeCount != tt.expectCount {
|
||||
t.Errorf("Expected %d ranges, got %d", tt.expectCount, rangeCount)
|
||||
}
|
||||
assert.Len(t, service.localIPv6Ranges, tt.expectCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
ALTER TABLE audit_logs ALTER COLUMN ip_address SET NOT NULL;
|
||||
|
||||
DROP INDEX IF EXISTS idx_audit_logs_created_at;
|
||||
DROP INDEX IF EXISTS idx_audit_logs_user_agent;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE audit_logs ALTER COLUMN ip_address DROP NOT NULL;
|
||||
|
||||
-- Add missing indexes
|
||||
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
|
||||
CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
|
||||
@@ -0,0 +1,30 @@
|
||||
-- Re-create the table with non-nullable ip_address
|
||||
-- We then move the data and rename the table
|
||||
CREATE TABLE audit_logs_new
|
||||
(
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
created_at DATETIME,
|
||||
event TEXT NOT NULL,
|
||||
ip_address TEXT NOT NULL,
|
||||
user_agent TEXT NOT NULL,
|
||||
data BLOB NOT NULL,
|
||||
user_id TEXT REFERENCES users,
|
||||
country TEXT,
|
||||
city TEXT
|
||||
);
|
||||
|
||||
INSERT INTO audit_logs_new
|
||||
SELECT id, created_at, event, ip_address, user_agent, data, user_id, country, city
|
||||
FROM audit_logs;
|
||||
|
||||
DROP TABLE audit_logs;
|
||||
|
||||
ALTER TABLE audit_logs_new RENAME TO audit_logs;
|
||||
|
||||
-- Re-create indexes
|
||||
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
|
||||
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
|
||||
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
|
||||
CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
|
||||
CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));
|
||||
CREATE INDEX idx_audit_logs_country ON audit_logs(country);
|
||||
@@ -0,0 +1,30 @@
|
||||
-- Re-create the table with nullable ip_address
|
||||
-- We then move the data and rename the table
|
||||
CREATE TABLE audit_logs_new
|
||||
(
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
created_at DATETIME,
|
||||
event TEXT NOT NULL,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT NOT NULL,
|
||||
data BLOB NOT NULL,
|
||||
user_id TEXT REFERENCES users,
|
||||
country TEXT,
|
||||
city TEXT
|
||||
);
|
||||
|
||||
INSERT INTO audit_logs_new
|
||||
SELECT id, created_at, event, ip_address, user_agent, data, user_id, country, city
|
||||
FROM audit_logs;
|
||||
|
||||
DROP TABLE audit_logs;
|
||||
|
||||
ALTER TABLE audit_logs_new RENAME TO audit_logs;
|
||||
|
||||
-- Re-create indexes
|
||||
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
|
||||
CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
|
||||
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
|
||||
CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
|
||||
CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));
|
||||
CREATE INDEX idx_audit_logs_country ON audit_logs(country);
|
||||
Reference in New Issue
Block a user