2024-08-17 21:57:14 +02:00
package bootstrap
2024-08-12 11:00:25 +02:00
import (
"errors"
2024-12-12 17:21:28 +01:00
"fmt"
2025-07-27 03:03:52 +02:00
"log/slog"
2025-05-14 00:29:04 -07:00
"net/url"
2025-03-29 15:12:48 -07:00
"strings"
2025-02-05 18:08:01 +01:00
"time"
2025-05-14 00:29:04 -07:00
"github.com/glebarez/sqlite"
2024-08-12 11:00:25 +02:00
"github.com/golang-migrate/migrate/v4"
2024-12-12 17:21:28 +01:00
"github.com/golang-migrate/migrate/v4/database"
postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
2025-01-03 15:08:55 +01:00
"github.com/golang-migrate/migrate/v4/source/iofs"
2025-07-27 03:03:52 +02:00
slogGorm "github.com/orandin/slog-gorm"
2024-12-12 17:21:28 +01:00
"gorm.io/driver/postgres"
2024-08-17 21:57:14 +02:00
"gorm.io/gorm"
2025-07-27 03:03:52 +02:00
gormLogger "gorm.io/gorm/logger"
2025-05-14 00:29:04 -07:00
"github.com/pocket-id/pocket-id/backend/internal/common"
2025-07-13 18:15:57 +02:00
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
2025-05-14 00:29:04 -07:00
"github.com/pocket-id/pocket-id/backend/resources"
2024-08-12 11:00:25 +02:00
)
2025-07-27 06:34:23 +02:00
func NewDatabase ( ) ( db * gorm . DB , err error ) {
db , err = connectDatabase ( )
2024-08-17 21:57:14 +02:00
if err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to connect to database: %w" , err )
2024-08-17 21:57:14 +02:00
}
sqlDb , err := db . DB ( )
2024-08-12 11:00:25 +02:00
if err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to get sql.DB: %w" , err )
2024-08-12 11:00:25 +02:00
}
2024-08-17 21:57:14 +02:00
2024-12-12 17:21:28 +01:00
// Choose the correct driver for the database provider
var driver database . Driver
switch common . EnvConfig . DbProvider {
case common . DbProviderSqlite :
driver , err = sqliteMigrate . WithInstance ( sqlDb , & sqliteMigrate . Config { } )
case common . DbProviderPostgres :
driver , err = postgresMigrate . WithInstance ( sqlDb , & postgresMigrate . Config { } )
default :
2025-03-29 15:12:48 -07:00
// Should never happen at this point
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "unsupported database provider: %s" , common . EnvConfig . DbProvider )
2024-12-12 17:21:28 +01:00
}
if err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to create migration driver: %w" , err )
2024-12-12 17:21:28 +01:00
}
// Run migrations
2025-01-03 15:08:55 +01:00
if err := migrateDatabase ( driver ) ; err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to run migrations: %w" , err )
2025-01-03 15:08:55 +01:00
}
2025-07-27 06:34:23 +02:00
return db , nil
2025-01-03 15:08:55 +01:00
}
func migrateDatabase ( driver database . Driver ) error {
// Use the embedded migrations
source , err := iofs . New ( resources . FS , "migrations/" + string ( common . EnvConfig . DbProvider ) )
2024-08-12 11:00:25 +02:00
if err != nil {
2025-03-27 08:13:56 -07:00
return fmt . Errorf ( "failed to create embedded migration source: %w" , err )
2025-01-03 15:08:55 +01:00
}
m , err := migrate . NewWithInstance ( "iofs" , source , "pocket-id" , driver )
if err != nil {
2025-03-27 08:13:56 -07:00
return fmt . Errorf ( "failed to create migration instance: %w" , err )
2024-08-12 11:00:25 +02:00
}
err = m . Up ( )
if err != nil && ! errors . Is ( err , migrate . ErrNoChange ) {
2025-03-27 08:13:56 -07:00
return fmt . Errorf ( "failed to apply migrations: %w" , err )
2024-08-12 11:00:25 +02:00
}
2024-08-17 21:57:14 +02:00
2025-01-03 15:08:55 +01:00
return nil
2024-08-12 11:00:25 +02:00
}
2024-08-17 21:57:14 +02:00
func connectDatabase ( ) ( db * gorm . DB , err error ) {
2024-12-12 17:21:28 +01:00
var dialector gorm . Dialector
2024-08-12 11:00:25 +02:00
2024-12-12 17:21:28 +01:00
// Choose the correct database provider
switch common . EnvConfig . DbProvider {
case common . DbProviderSqlite :
2025-03-29 15:12:48 -07:00
if common . EnvConfig . DbConnectionString == "" {
return nil , errors . New ( "missing required env var 'DB_CONNECTION_STRING' for SQLite database" )
}
if ! strings . HasPrefix ( common . EnvConfig . DbConnectionString , "file:" ) {
return nil , errors . New ( "invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'" )
}
2025-07-13 18:15:57 +02:00
sqliteutil . RegisterSqliteFunctions ( )
2025-05-14 00:29:04 -07:00
connString , err := parseSqliteConnectionString ( common . EnvConfig . DbConnectionString )
if err != nil {
return nil , err
}
dialector = sqlite . Open ( connString )
2024-12-12 17:21:28 +01:00
case common . DbProviderPostgres :
2025-03-29 15:12:48 -07:00
if common . EnvConfig . DbConnectionString == "" {
return nil , errors . New ( "missing required env var 'DB_CONNECTION_STRING' for Postgres database" )
}
dialector = postgres . Open ( common . EnvConfig . DbConnectionString )
2024-12-12 17:21:28 +01:00
default :
return nil , fmt . Errorf ( "unsupported database provider: %s" , common . EnvConfig . DbProvider )
2024-08-12 11:00:25 +02:00
}
for i := 1 ; i <= 3 ; i ++ {
2024-12-12 17:21:28 +01:00
db , err = gorm . Open ( dialector , & gorm . Config {
2024-08-12 11:00:25 +02:00
TranslateError : true ,
2025-07-27 03:03:52 +02:00
Logger : getGormLogger ( ) ,
2024-08-12 11:00:25 +02:00
} )
if err == nil {
2025-08-06 18:04:25 +02:00
slog . Info ( "Connected to database" , slog . String ( "provider" , string ( common . EnvConfig . DbProvider ) ) )
2025-03-29 15:12:48 -07:00
return db , nil
2024-08-12 11:00:25 +02:00
}
2025-03-29 15:12:48 -07:00
2025-08-06 18:04:25 +02:00
slog . Warn ( "Failed to connect to database, will retry in 3s" , slog . Int ( "attempt" , i ) , slog . String ( "provider" , string ( common . EnvConfig . DbProvider ) ) , slog . Any ( "error" , err ) )
2025-03-29 15:12:48 -07:00
time . Sleep ( 3 * time . Second )
2024-08-12 11:00:25 +02:00
}
2025-08-06 18:04:25 +02:00
slog . Error ( "Failed to connect to database after 3 attempts" , slog . String ( "provider" , string ( common . EnvConfig . DbProvider ) ) , slog . Any ( "error" , err ) )
2025-03-29 15:12:48 -07:00
return nil , err
2024-08-12 11:00:25 +02:00
}
2025-05-14 00:29:04 -07:00
// parseSqliteConnectionString rewrites a go-sqlite3 (mattn) style connection
// string into one the modernc.org/sqlite driver understands.
//
// The official C implementation of SQLite allows some additional properties in
// the connection string that are not supported in the modernc.org/sqlite
// driver, and which must be passed as PRAGMA args instead. To ensure that
// people can use similar args as in the C driver, which was also used by
// Pocket ID previously (via github.com/mattn/go-sqlite3), the options handled
// below are converted into "_pragma=name(value)" arguments. All other
// query-string arguments, including "_pragma" values supplied directly, are
// passed through unchanged.
//
// A missing "file:" prefix is added. For a converted option given more than
// once, only its first value is used. An error is returned if the string
// cannot be parsed as a URL.
func parseSqliteConnectionString(connString string) (string, error) {
	if !strings.HasPrefix(connString, "file:") {
		connString = "file:" + connString
	}

	connStringURL, err := url.Parse(connString)
	if err != nil {
		return "", fmt.Errorf("failed to parse SQLite connection string: %w", err)
	}

	// Reference: https://github.com/mattn/go-sqlite3?tab=readme-ov-file#connection-string
	// This only includes a subset of options, excluding those that are not relevant to us
	query := connStringURL.Query()
	qs := make(url.Values, len(query))
	for k, v := range query {
		switch k {
		case "_auto_vacuum", "_vacuum":
			qs.Add("_pragma", "auto_vacuum("+v[0]+")")
		case "_busy_timeout", "_timeout":
			qs.Add("_pragma", "busy_timeout("+v[0]+")")
		case "_case_sensitive_like", "_cslike":
			qs.Add("_pragma", "case_sensitive_like("+v[0]+")")
		case "_foreign_keys", "_fk":
			qs.Add("_pragma", "foreign_keys("+v[0]+")")
		case "_locking_mode", "_locking":
			qs.Add("_pragma", "locking_mode("+v[0]+")")
		case "_secure_delete":
			qs.Add("_pragma", "secure_delete("+v[0]+")")
		case "_synchronous", "_sync":
			qs.Add("_pragma", "synchronous("+v[0]+")")
		default:
			// Pass other query-string args as-is. Append rather than assign:
			// converted options above may already have added "_pragma" entries,
			// and since map iteration order is random, assigning would
			// non-deterministically drop them.
			qs[k] = append(qs[k], v...)
		}
	}

	connStringURL.RawQuery = qs.Encode()
	return connStringURL.String(), nil
}
2025-07-27 03:03:52 +02:00
func getGormLogger ( ) gormLogger . Interface {
loggerOpts := make ( [ ] slogGorm . Option , 0 , 5 )
loggerOpts = append ( loggerOpts ,
slogGorm . WithSlowThreshold ( 200 * time . Millisecond ) ,
slogGorm . WithErrorField ( "error" ) ,
)
2024-08-12 11:00:25 +02:00
2025-07-27 03:03:52 +02:00
if common . EnvConfig . AppEnv == "production" {
loggerOpts = append ( loggerOpts ,
slogGorm . SetLogLevel ( slogGorm . DefaultLogType , slog . LevelWarn ) ,
slogGorm . WithIgnoreTrace ( ) ,
)
2024-08-12 11:00:25 +02:00
} else {
2025-07-27 03:03:52 +02:00
loggerOpts = append ( loggerOpts ,
slogGorm . SetLogLevel ( slogGorm . DefaultLogType , slog . LevelDebug ) ,
slogGorm . WithRecordNotFoundError ( ) ,
slogGorm . WithTraceAll ( ) ,
)
2024-08-12 11:00:25 +02:00
}
2025-07-27 03:03:52 +02:00
return slogGorm . New ( loggerOpts ... )
2024-08-12 11:00:25 +02:00
}