2024-08-17 21:57:14 +02:00
package bootstrap
2024-08-12 11:00:25 +02:00
import (
"errors"
2024-12-12 17:21:28 +01:00
"fmt"
2025-07-27 03:03:52 +02:00
"log/slog"
2025-05-14 00:29:04 -07:00
"net/url"
2025-08-24 11:50:51 -07:00
"os"
"path/filepath"
2025-03-29 15:12:48 -07:00
"strings"
2025-02-05 18:08:01 +01:00
"time"
2025-05-14 00:29:04 -07:00
"github.com/glebarez/sqlite"
2024-08-12 11:00:25 +02:00
"github.com/golang-migrate/migrate/v4"
2024-12-12 17:21:28 +01:00
"github.com/golang-migrate/migrate/v4/database"
postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
2025-08-24 16:56:28 +02:00
_ "github.com/golang-migrate/migrate/v4/source/github"
2025-01-03 15:08:55 +01:00
"github.com/golang-migrate/migrate/v4/source/iofs"
2025-07-27 03:03:52 +02:00
slogGorm "github.com/orandin/slog-gorm"
2024-12-12 17:21:28 +01:00
"gorm.io/driver/postgres"
2024-08-17 21:57:14 +02:00
"gorm.io/gorm"
2025-07-27 03:03:52 +02:00
gormLogger "gorm.io/gorm/logger"
2025-05-14 00:29:04 -07:00
"github.com/pocket-id/pocket-id/backend/internal/common"
2025-08-24 11:50:51 -07:00
"github.com/pocket-id/pocket-id/backend/internal/utils"
2025-07-13 18:15:57 +02:00
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
2025-05-14 00:29:04 -07:00
"github.com/pocket-id/pocket-id/backend/resources"
2024-08-12 11:00:25 +02:00
)
2025-07-27 06:34:23 +02:00
func NewDatabase ( ) ( db * gorm . DB , err error ) {
db , err = connectDatabase ( )
2024-08-17 21:57:14 +02:00
if err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to connect to database: %w" , err )
2024-08-17 21:57:14 +02:00
}
sqlDb , err := db . DB ( )
2024-08-12 11:00:25 +02:00
if err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to get sql.DB: %w" , err )
2024-08-12 11:00:25 +02:00
}
2024-08-17 21:57:14 +02:00
2024-12-12 17:21:28 +01:00
// Choose the correct driver for the database provider
var driver database . Driver
switch common . EnvConfig . DbProvider {
case common . DbProviderSqlite :
2025-08-24 23:07:50 +02:00
driver , err = sqliteMigrate . WithInstance ( sqlDb , & sqliteMigrate . Config {
NoTxWrap : true ,
} )
2024-12-12 17:21:28 +01:00
case common . DbProviderPostgres :
driver , err = postgresMigrate . WithInstance ( sqlDb , & postgresMigrate . Config { } )
default :
2025-03-29 15:12:48 -07:00
// Should never happen at this point
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "unsupported database provider: %s" , common . EnvConfig . DbProvider )
2024-12-12 17:21:28 +01:00
}
if err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to create migration driver: %w" , err )
2024-12-12 17:21:28 +01:00
}
// Run migrations
2025-01-03 15:08:55 +01:00
if err := migrateDatabase ( driver ) ; err != nil {
2025-07-27 06:34:23 +02:00
return nil , fmt . Errorf ( "failed to run migrations: %w" , err )
2025-01-03 15:08:55 +01:00
}
2025-07-27 06:34:23 +02:00
return db , nil
2025-01-03 15:08:55 +01:00
}
func migrateDatabase ( driver database . Driver ) error {
2025-08-24 16:56:28 +02:00
// Embedded migrations via iofs
path := "migrations/" + string ( common . EnvConfig . DbProvider )
source , err := iofs . New ( resources . FS , path )
2024-08-12 11:00:25 +02:00
if err != nil {
2025-03-27 08:13:56 -07:00
return fmt . Errorf ( "failed to create embedded migration source: %w" , err )
2025-01-03 15:08:55 +01:00
}
m , err := migrate . NewWithInstance ( "iofs" , source , "pocket-id" , driver )
if err != nil {
2025-03-27 08:13:56 -07:00
return fmt . Errorf ( "failed to create migration instance: %w" , err )
2024-08-12 11:00:25 +02:00
}
2025-08-24 16:56:28 +02:00
requiredVersion , err := getRequiredMigrationVersion ( path )
if err != nil {
return fmt . Errorf ( "failed to get last migration version: %w" , err )
}
currentVersion , _ , _ := m . Version ( )
if currentVersion > requiredVersion {
slog . Warn ( "Database version is newer than the application supports, possible downgrade detected" , slog . Uint64 ( "db_version" , uint64 ( currentVersion ) ) , slog . Uint64 ( "app_version" , uint64 ( requiredVersion ) ) )
if ! common . EnvConfig . AllowDowngrade {
return fmt . Errorf ( "database version (%d) is newer than application version (%d), downgrades are not allowed (set ALLOW_DOWNGRADE=true to enable)" , currentVersion , requiredVersion )
}
slog . Info ( "Fetching migrations from GitHub to handle possible downgrades" )
return migrateDatabaseFromGitHub ( driver , requiredVersion )
}
if err := m . Migrate ( requiredVersion ) ; err != nil && ! errors . Is ( err , migrate . ErrNoChange ) {
return fmt . Errorf ( "failed to apply embedded migrations: %w" , err )
}
return nil
}
func migrateDatabaseFromGitHub ( driver database . Driver , version uint ) error {
srcURL := "github://pocket-id/pocket-id/backend/resources/migrations/" + string ( common . EnvConfig . DbProvider )
m , err := migrate . NewWithDatabaseInstance ( srcURL , "pocket-id" , driver )
if err != nil {
return fmt . Errorf ( "failed to create GitHub migration instance: %w" , err )
2024-08-12 11:00:25 +02:00
}
2024-08-17 21:57:14 +02:00
2025-08-24 16:56:28 +02:00
if err := m . Migrate ( version ) ; err != nil && ! errors . Is ( err , migrate . ErrNoChange ) {
return fmt . Errorf ( "failed to apply GitHub migrations: %w" , err )
}
2025-01-03 15:08:55 +01:00
return nil
2024-08-12 11:00:25 +02:00
}
2025-08-24 16:56:28 +02:00
// getRequiredMigrationVersion lists the embedded migration files under path and
// returns the largest numeric version prefix among them.
// Directories, and files whose name does not start with "<number>_", are skipped.
// If no file matches, the returned version is 0.
func getRequiredMigrationVersion(path string) (uint, error) {
	files, err := resources.FS.ReadDir(path)
	if err != nil {
		return 0, fmt.Errorf("failed to read migration directory: %w", err)
	}

	var highest uint
	for _, file := range files {
		if file.IsDir() {
			continue
		}

		// The version is the number that precedes the first underscore in the file name
		var parsed uint
		matched, scanErr := fmt.Sscanf(file.Name(), "%d_", &parsed)
		if scanErr != nil || matched != 1 {
			continue
		}

		if parsed > highest {
			highest = parsed
		}
	}

	return highest, nil
}
2024-08-17 21:57:14 +02:00
func connectDatabase ( ) ( db * gorm . DB , err error ) {
2024-12-12 17:21:28 +01:00
var dialector gorm . Dialector
2024-08-12 11:00:25 +02:00
2024-12-12 17:21:28 +01:00
// Choose the correct database provider
switch common . EnvConfig . DbProvider {
case common . DbProviderSqlite :
2025-03-29 15:12:48 -07:00
if common . EnvConfig . DbConnectionString == "" {
return nil , errors . New ( "missing required env var 'DB_CONNECTION_STRING' for SQLite database" )
}
2025-08-24 11:50:51 -07:00
2025-07-13 18:15:57 +02:00
sqliteutil . RegisterSqliteFunctions ( )
2025-08-24 11:50:51 -07:00
connString , dbPath , err := parseSqliteConnectionString ( common . EnvConfig . DbConnectionString )
if err != nil {
return nil , err
}
// Before we connect, also make sure that there's a temporary folder for SQLite to write its data
err = ensureSqliteTempDir ( filepath . Dir ( dbPath ) )
2025-05-14 00:29:04 -07:00
if err != nil {
return nil , err
}
2025-08-24 11:50:51 -07:00
2025-05-14 00:29:04 -07:00
dialector = sqlite . Open ( connString )
2024-12-12 17:21:28 +01:00
case common . DbProviderPostgres :
2025-03-29 15:12:48 -07:00
if common . EnvConfig . DbConnectionString == "" {
return nil , errors . New ( "missing required env var 'DB_CONNECTION_STRING' for Postgres database" )
}
dialector = postgres . Open ( common . EnvConfig . DbConnectionString )
2024-12-12 17:21:28 +01:00
default :
return nil , fmt . Errorf ( "unsupported database provider: %s" , common . EnvConfig . DbProvider )
2024-08-12 11:00:25 +02:00
}
for i := 1 ; i <= 3 ; i ++ {
2024-12-12 17:21:28 +01:00
db , err = gorm . Open ( dialector , & gorm . Config {
2024-08-12 11:00:25 +02:00
TranslateError : true ,
2025-07-27 03:03:52 +02:00
Logger : getGormLogger ( ) ,
2024-08-12 11:00:25 +02:00
} )
if err == nil {
2025-08-06 18:04:25 +02:00
slog . Info ( "Connected to database" , slog . String ( "provider" , string ( common . EnvConfig . DbProvider ) ) )
2025-03-29 15:12:48 -07:00
return db , nil
2024-08-12 11:00:25 +02:00
}
2025-03-29 15:12:48 -07:00
2025-08-06 18:04:25 +02:00
slog . Warn ( "Failed to connect to database, will retry in 3s" , slog . Int ( "attempt" , i ) , slog . String ( "provider" , string ( common . EnvConfig . DbProvider ) ) , slog . Any ( "error" , err ) )
2025-03-29 15:12:48 -07:00
time . Sleep ( 3 * time . Second )
2024-08-12 11:00:25 +02:00
}
2025-08-06 18:04:25 +02:00
slog . Error ( "Failed to connect to database after 3 attempts" , slog . String ( "provider" , string ( common . EnvConfig . DbProvider ) ) , slog . Any ( "error" , err ) )
2025-03-29 15:12:48 -07:00
return nil , err
2024-08-12 11:00:25 +02:00
}
2025-08-24 11:50:51 -07:00
func parseSqliteConnectionString ( connString string ) ( parsedConnString string , dbPath string , err error ) {
2025-05-14 00:29:04 -07:00
if ! strings . HasPrefix ( connString , "file:" ) {
connString = "file:" + connString
}
2025-08-23 17:54:51 +02:00
// Check if we're using an in-memory database
isMemoryDB := isSqliteInMemory ( connString )
// Parse the connection string
2025-05-14 00:29:04 -07:00
connStringUrl , err := url . Parse ( connString )
if err != nil {
2025-08-24 11:50:51 -07:00
return "" , "" , fmt . Errorf ( "failed to parse SQLite connection string: %w" , err )
2025-05-14 00:29:04 -07:00
}
2025-08-23 17:54:51 +02:00
// Convert options for the old SQLite driver to the new one
convertSqlitePragmaArgs ( connStringUrl )
// Add the default and required params
err = addSqliteDefaultParameters ( connStringUrl , isMemoryDB )
if err != nil {
2025-08-24 11:50:51 -07:00
return "" , "" , fmt . Errorf ( "invalid SQLite connection string: %w" , err )
}
// Get the absolute path to the database
// Here, we know for a fact that the ? is present
parsedConnString = connStringUrl . String ( )
idx := strings . IndexRune ( parsedConnString , '?' )
dbPath , err = filepath . Abs ( parsedConnString [ len ( "file:" ) : idx ] )
if err != nil {
return "" , "" , fmt . Errorf ( "failed to determine absolute path to the database: %w" , err )
2025-08-23 17:54:51 +02:00
}
2025-08-24 11:50:51 -07:00
return parsedConnString , dbPath , nil
2025-08-23 17:54:51 +02:00
}
// The official C implementation of SQLite allows some additional properties in the connection string
// that are not supported in the modernc.org/sqlite driver, and which must be passed as PRAGMA args instead.
// To ensure that people can use similar args as in the C driver, which was also used by Pocket ID
// previously (via github.com/mattn/go-sqlite3), we are converting some options.
// Option names are matched case-insensitively, and only the first value of a converted option is used.
// All other query-string arguments are passed through unchanged.
// Note this function updates connStringUrl.
func convertSqlitePragmaArgs(connStringUrl *url.URL) {
	// Reference: https://github.com/mattn/go-sqlite3?tab=readme-ov-file#connection-string
	// This only includes a subset of options, excluding those that are not relevant to us
	original := connStringUrl.Query()
	qs := make(url.Values, len(original))
	for k, v := range original {
		var pragma string
		switch strings.ToLower(k) {
		case "_auto_vacuum", "_vacuum":
			pragma = "auto_vacuum"
		case "_busy_timeout", "_timeout":
			pragma = "busy_timeout"
		case "_case_sensitive_like", "_cslike":
			pragma = "case_sensitive_like"
		case "_foreign_keys", "_fk":
			pragma = "foreign_keys"
		case "_locking_mode", "_locking":
			pragma = "locking_mode"
		case "_secure_delete":
			pragma = "secure_delete"
		case "_synchronous", "_sync":
			pragma = "synchronous"
		}

		if pragma != "" {
			qs.Add("_pragma", pragma+"("+v[0]+")")
			continue
		}

		// Pass other query-string args as-is.
		// We append rather than assign: map iteration order is random, so an explicit
		// "_pragma" key may be visited after some options have already been converted,
		// and assigning would discard those converted pragmas.
		qs[k] = append(qs[k], v...)
	}

	// Update the connStringUrl object
	connStringUrl.RawQuery = qs.Encode()
}
2025-05-14 00:29:04 -07:00
2025-08-23 17:54:51 +02:00
// addSqliteDefaultParameters adds the default (and some required) parameters to
// the SQLite connection string:
//
//   - "cache=shared" is forced for in-memory databases
//   - "mode", "immutable" and "_txlock" are reduced to their first value, lowercased
//   - "_txlock" defaults to "immediate"; another value is kept, but logs a warning
//   - a "busy_timeout" pragma is added unless one is already present
//   - a "journal_mode" pragma is added unless one is already present
//   - the "foreign_keys(1)" pragma is always appended
//
// It returns an error, leaving connStringUrl untouched, if the connection string
// already contains a "foreign_keys" pragma.
// Note this function updates connStringUrl.
func addSqliteDefaultParameters(connStringUrl *url.URL, isMemoryDB bool) error {
	// This function includes code adapted from https://github.com/dapr/components-contrib/blob/v1.14.6/
	// Copyright (C) 2023 The Dapr Authors
	// License: Apache2

	const defaultBusyTimeout = 2500 * time.Millisecond

	// Work on a copy of the query-string values, written back only on success
	params := connStringUrl.Query()
	if len(params) == 0 {
		params = make(url.Values, 2)
	}

	// keepFirstLower reduces a parameter to its first value, lowercased.
	// It reports that value and whether the parameter was present at all.
	keepFirstLower := func(key string) (string, bool) {
		values := params[key]
		if len(values) == 0 {
			return "", false
		}
		first := strings.ToLower(values[0])
		params[key] = []string{first}
		return first, true
	}

	// In-memory databases must always use the shared cache
	if isMemoryDB {
		params["cache"] = []string{"shared"}
	}

	// A database opened with mode=ro or immutable=1 is considered read-only
	mode, _ := keepFirstLower("mode")
	immutable, _ := keepFirstLower("immutable")
	readOnly := mode == "ro" || immutable == "1"

	// An explicit _txlock is respected, but anything other than "immediate" deserves a warning
	txLock, hasTxLock := keepFirstLower("_txlock")
	if !hasTxLock {
		params["_txlock"] = []string{"immediate"}
	} else if txLock != "immediate" {
		slog.Warn("SQLite connection is being created with a _txlock different from the recommended value 'immediate'")
	}

	// Look at the pragmas that are already in the connection string
	pragmas := params["_pragma"]
	busyTimeoutSet, journalModeSet := false, false
	for _, pragma := range pragmas {
		lowered := strings.ToLower(pragma)
		if strings.HasPrefix(lowered, "foreign_keys") {
			return errors.New("found forbidden option '_pragma=foreign_keys' in the connection string")
		}
		if strings.HasPrefix(lowered, "busy_timeout") {
			busyTimeoutSet = true
		}
		if strings.HasPrefix(lowered, "journal_mode") {
			journalModeSet = true
		}
	}

	if !busyTimeoutSet {
		pragmas = append(pragmas, fmt.Sprintf("busy_timeout(%d)", defaultBusyTimeout.Milliseconds()))
	}

	if !journalModeSet {
		// WAL is the default
		journalMode := "journal_mode(WAL)"
		if isMemoryDB {
			// For in-memory databases the only allowed option besides OFF
			// (which would make transactions ineffective) is MEMORY
			journalMode = "journal_mode(MEMORY)"
		} else if readOnly {
			// Read-only databases use "DELETE", which is SQLite's default
			journalMode = "journal_mode(DELETE)"
		}
		pragmas = append(pragmas, journalMode)
	}

	// Foreign keys are always enabled
	pragmas = append(pragmas, "foreign_keys(1)")
	params["_pragma"] = pragmas

	// Update the connStringUrl object
	connStringUrl.RawQuery = params.Encode()

	return nil
}
// isSqliteInMemory reports whether connString refers to an in-memory SQLite database.
// The comparison is case-insensitive. A connection string is considered in-memory when:
//
//   - it starts with ":memory:" or "file::memory:", or
//   - the first "mode" argument in its query string is "memory"
func isSqliteInMemory(connString string) bool {
	lowered := strings.ToLower(connString)

	for _, prefix := range []string{":memory:", "file::memory:"} {
		if strings.HasPrefix(lowered, prefix) {
			return true
		}
	}

	// Otherwise, the database can only be in-memory if "mode=memory" is in the query string
	_, rawQuery, hasQuery := strings.Cut(lowered, "?")
	if !hasQuery {
		return false
	}

	// A malformed query string is not an error here: we look at whatever could be parsed
	params, _ := url.ParseQuery(rawQuery)
	return params.Get("mode") == "memory"
}
2025-08-24 11:50:51 -07:00
// ensureSqliteTempDir makes sure SQLite has a directory where it can write temporary files.
// The default locations may not be writable, for example in a container with a read-only root file system.
// See: https://www.sqlite.org/tempfiles.html
//
// Nothing is changed if SQLITE_TMPDIR or TMPDIR is already set, or if one of the default
// directories is writable. Otherwise SQLITE_TMPDIR is set to dbPath, which the caller in
// this file passes as the directory that contains the database file.
func ensureSqliteTempDir(dbPath string) error {
	// Per the SQLite docs, these locations are tried in order (leaving out those that don't apply to us):
	//
	// - The SQLITE_TMPDIR environment variable
	// - The TMPDIR environment variable
	// - /var/tmp
	// - /usr/tmp
	// - /tmp
	//
	// Source: https://www.sqlite.org/tempfiles.html#temporary_file_storage_locations

	// If either environment variable is set, we trust that the user has already taken care of this
	for _, envVar := range []string{"SQLITE_TMPDIR", "TMPDIR"} {
		if os.Getenv(envVar) != "" {
			return nil
		}
	}

	// Otherwise, a single writable default directory is enough
	defaultDirs := [...]string{"/var/tmp", "/usr/tmp", "/tmp"}
	for i := range defaultDirs {
		writable, err := utils.IsWritableDir(defaultDirs[i])
		if err != nil {
			return fmt.Errorf("failed to check if %s is writable: %w", defaultDirs[i], err)
		}
		if writable {
			return nil
		}
	}

	// None of the default directories is writable (not unusual for containers with a read-only
	// root file system), so we point SQLite to the folder where the database is stored
	if err := os.Setenv("SQLITE_TMPDIR", dbPath); err != nil {
		return fmt.Errorf("failed to set SQLITE_TMPDIR environmental variable: %w", err)
	}

	slog.Debug("Set SQLITE_TMPDIR to the database directory", "path", dbPath)

	return nil
}
2025-07-27 03:03:52 +02:00
func getGormLogger ( ) gormLogger . Interface {
loggerOpts := make ( [ ] slogGorm . Option , 0 , 5 )
loggerOpts = append ( loggerOpts ,
slogGorm . WithSlowThreshold ( 200 * time . Millisecond ) ,
slogGorm . WithErrorField ( "error" ) ,
)
2024-08-12 11:00:25 +02:00
2025-09-14 17:26:21 +02:00
if common . EnvConfig . LogLevel == "debug" {
2025-07-27 03:03:52 +02:00
loggerOpts = append ( loggerOpts ,
slogGorm . SetLogLevel ( slogGorm . DefaultLogType , slog . LevelDebug ) ,
slogGorm . WithRecordNotFoundError ( ) ,
slogGorm . WithTraceAll ( ) ,
)
2025-09-14 17:26:21 +02:00
} else {
loggerOpts = append ( loggerOpts ,
slogGorm . SetLogLevel ( slogGorm . DefaultLogType , slog . LevelWarn ) ,
slogGorm . WithIgnoreTrace ( ) ,
)
2024-08-12 11:00:25 +02:00
}
2025-07-27 03:03:52 +02:00
return slogGorm . New ( loggerOpts ... )
2024-08-12 11:00:25 +02:00
}