Files
pocket-id/backend/internal/bootstrap/db_bootstrap.go

192 lines
6.2 KiB
Go
Raw Normal View History

package bootstrap
2024-08-12 11:00:25 +02:00
import (
"errors"
"fmt"
"log/slog"
"net/url"
"strings"
"time"
"github.com/glebarez/sqlite"
2024-08-12 11:00:25 +02:00
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
2025-01-03 15:08:55 +01:00
"github.com/golang-migrate/migrate/v4/source/iofs"
slogGorm "github.com/orandin/slog-gorm"
"gorm.io/driver/postgres"
"gorm.io/gorm"
gormLogger "gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
"github.com/pocket-id/pocket-id/backend/resources"
2024-08-12 11:00:25 +02:00
)
func NewDatabase() (db *gorm.DB, err error) {
db, err = connectDatabase()
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
sqlDb, err := db.DB()
2024-08-12 11:00:25 +02:00
if err != nil {
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
2024-08-12 11:00:25 +02:00
}
// Choose the correct driver for the database provider
var driver database.Driver
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
driver, err = sqliteMigrate.WithInstance(sqlDb, &sqliteMigrate.Config{})
case common.DbProviderPostgres:
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
default:
// Should never happen at this point
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
}
if err != nil {
return nil, fmt.Errorf("failed to create migration driver: %w", err)
}
// Run migrations
2025-01-03 15:08:55 +01:00
if err := migrateDatabase(driver); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
2025-01-03 15:08:55 +01:00
}
return db, nil
2025-01-03 15:08:55 +01:00
}
func migrateDatabase(driver database.Driver) error {
// Use the embedded migrations
source, err := iofs.New(resources.FS, "migrations/"+string(common.EnvConfig.DbProvider))
2024-08-12 11:00:25 +02:00
if err != nil {
return fmt.Errorf("failed to create embedded migration source: %w", err)
2025-01-03 15:08:55 +01:00
}
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
if err != nil {
return fmt.Errorf("failed to create migration instance: %w", err)
2024-08-12 11:00:25 +02:00
}
err = m.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("failed to apply migrations: %w", err)
2024-08-12 11:00:25 +02:00
}
2025-01-03 15:08:55 +01:00
return nil
2024-08-12 11:00:25 +02:00
}
func connectDatabase() (db *gorm.DB, err error) {
var dialector gorm.Dialector
2024-08-12 11:00:25 +02:00
// Choose the correct database provider
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
if common.EnvConfig.DbConnectionString == "" {
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for SQLite database")
}
if !strings.HasPrefix(common.EnvConfig.DbConnectionString, "file:") {
return nil, errors.New("invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'")
}
sqliteutil.RegisterSqliteFunctions()
connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
if err != nil {
return nil, err
}
dialector = sqlite.Open(connString)
case common.DbProviderPostgres:
if common.EnvConfig.DbConnectionString == "" {
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
}
dialector = postgres.Open(common.EnvConfig.DbConnectionString)
default:
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
2024-08-12 11:00:25 +02:00
}
for i := 1; i <= 3; i++ {
db, err = gorm.Open(dialector, &gorm.Config{
2024-08-12 11:00:25 +02:00
TranslateError: true,
Logger: getGormLogger(),
2024-08-12 11:00:25 +02:00
})
if err == nil {
slog.Info("Connected to database", slog.String("provider", string(common.EnvConfig.DbProvider)))
return db, nil
2024-08-12 11:00:25 +02:00
}
slog.Warn("Failed to connect to database, will retry in 3s", slog.Int("attempt", i), slog.String("provider", string(common.EnvConfig.DbProvider)), slog.Any("error", err))
time.Sleep(3 * time.Second)
2024-08-12 11:00:25 +02:00
}
slog.Error("Failed to connect to database after 3 attempts", slog.String("provider", string(common.EnvConfig.DbProvider)), slog.Any("error", err))
return nil, err
2024-08-12 11:00:25 +02:00
}
// parseSqliteConnectionString rewrites a SQLite connection string so that
// legacy query-string options are expressed as "_pragma" arguments.
//
// The official C implementation of SQLite allows some additional properties in the connection string
// that are not supported in the modernc.org/sqlite driver, and which must be passed as PRAGMA args instead.
// To ensure that people can use similar args as in the C driver, which was also used by Pocket ID
// previously (via github.com/mattn/go-sqlite3), we are converting some options.
//
// A "file:" prefix is added if missing. Recognized legacy options are removed
// and appended, in a fixed order, as "_pragma=name(value)" entries after any
// "_pragma" values already present in the input; only the first value of each
// legacy option is used. All other query-string arguments are passed as-is.
// An error is returned only if the connection string cannot be parsed as a URL.
func parseSqliteConnectionString(connString string) (string, error) {
	if !strings.HasPrefix(connString, "file:") {
		connString = "file:" + connString
	}

	connStringUrl, err := url.Parse(connString)
	if err != nil {
		return "", fmt.Errorf("failed to parse SQLite connection string: %w", err)
	}

	// Reference: https://github.com/mattn/go-sqlite3?tab=readme-ov-file#connection-string
	// This only includes a subset of options, excluding those that are not relevant to us.
	// A slice (not a map) is used so the generated pragmas have a stable order.
	legacyOptions := []struct {
		pragma string
		keys   []string
	}{
		{"auto_vacuum", []string{"_auto_vacuum", "_vacuum"}},
		{"busy_timeout", []string{"_busy_timeout", "_timeout"}},
		{"case_sensitive_like", []string{"_case_sensitive_like", "_cslike"}},
		{"foreign_keys", []string{"_foreign_keys", "_fk"}},
		{"locking_mode", []string{"_locking_mode", "_locking"}},
		{"secure_delete", []string{"_secure_delete"}},
		{"synchronous", []string{"_synchronous", "_sync"}},
	}

	// Query returns a fresh copy, so it is safe to modify in place.
	qs := connStringUrl.Query()

	// Pull the legacy options out first; everything left in qs is passed as-is.
	var pragmas []string
	for _, opt := range legacyOptions {
		for _, key := range opt.keys {
			vals, ok := qs[key]
			if !ok {
				continue
			}
			delete(qs, key)
			if len(vals) > 0 {
				pragmas = append(pragmas, opt.pragma+"("+vals[0]+")")
			}
		}
	}

	// Append rather than assign, so "_pragma" values supplied directly in the
	// connection string are kept alongside the translated ones.
	for _, p := range pragmas {
		qs.Add("_pragma", p)
	}

	connStringUrl.RawQuery = qs.Encode()

	return connStringUrl.String(), nil
}
func getGormLogger() gormLogger.Interface {
loggerOpts := make([]slogGorm.Option, 0, 5)
loggerOpts = append(loggerOpts,
slogGorm.WithSlowThreshold(200*time.Millisecond),
slogGorm.WithErrorField("error"),
)
2024-08-12 11:00:25 +02:00
if common.EnvConfig.AppEnv == "production" {
loggerOpts = append(loggerOpts,
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelWarn),
slogGorm.WithIgnoreTrace(),
)
2024-08-12 11:00:25 +02:00
} else {
loggerOpts = append(loggerOpts,
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelDebug),
slogGorm.WithRecordNotFoundError(),
slogGorm.WithTraceAll(),
)
2024-08-12 11:00:25 +02:00
}
return slogGorm.New(loggerOpts...)
2024-08-12 11:00:25 +02:00
}