diff --git a/backend/internal/bootstrap/db_bootstrap.go b/backend/internal/bootstrap/db_bootstrap.go index e29b884f..eb754f8c 100644 --- a/backend/internal/bootstrap/db_bootstrap.go +++ b/backend/internal/bootstrap/db_bootstrap.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "net/url" + "os" + "path/filepath" "strings" "time" @@ -21,6 +23,7 @@ import ( gormLogger "gorm.io/gorm/logger" "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/utils" sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite" "github.com/pocket-id/pocket-id/backend/resources" ) @@ -140,11 +143,20 @@ func connectDatabase() (db *gorm.DB, err error) { if common.EnvConfig.DbConnectionString == "" { return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for SQLite database") } + sqliteutil.RegisterSqliteFunctions() - connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString) + + connString, dbPath, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString) if err != nil { return nil, err } + + // Before we connect, also make sure that there's a temporary folder for SQLite to write its data + err = ensureSqliteTempDir(filepath.Dir(dbPath)) + if err != nil { + return nil, err + } + dialector = sqlite.Open(connString) case common.DbProviderPostgres: if common.EnvConfig.DbConnectionString == "" { @@ -174,7 +186,7 @@ func connectDatabase() (db *gorm.DB, err error) { return nil, err } -func parseSqliteConnectionString(connString string) (string, error) { +func parseSqliteConnectionString(connString string) (parsedConnString string, dbPath string, err error) { if !strings.HasPrefix(connString, "file:") { connString = "file:" + connString } @@ -185,7 +197,7 @@ func parseSqliteConnectionString(connString string) (string, error) { // Parse the connection string connStringUrl, err := url.Parse(connString) if err != nil { - return "", fmt.Errorf("failed to parse SQLite connection string: %w", err) + return "", "", fmt.Errorf("failed to parse SQLite connection string: %w", err) } // Convert options for the old SQLite driver to the new one @@ -194,10 +206,19 @@ func parseSqliteConnectionString(connString string) (string, error) { // Add the default and required params err = addSqliteDefaultParameters(connStringUrl, isMemoryDB) if err != nil { - return "", fmt.Errorf("invalid SQLite connection string: %w", err) + return "", "", fmt.Errorf("invalid SQLite connection string: %w", err) } - return connStringUrl.String(), nil + // Get the absolute path to the database + // Here, we know for a fact that the ? is present + parsedConnString = connStringUrl.String() + idx := strings.IndexRune(parsedConnString, '?') + dbPath, err = filepath.Abs(parsedConnString[len("file:"):idx]) + if err != nil { + return "", "", fmt.Errorf("failed to determine absolute path to the database: %w", err) + } + + return parsedConnString, dbPath, nil } // The official C implementation of SQLite allows some additional properties in the connection string @@ -350,6 +371,48 @@ func isSqliteInMemory(connString string) bool { return len(qs["mode"]) > 0 && qs["mode"][0] == "memory" } +// ensureSqliteTempDir ensures that SQLite has a directory where it can write temporary files if needed +// The default directory may not be writable when using a container with a read-only root file system +// See: https://www.sqlite.org/tempfiles.html +func ensureSqliteTempDir(dbPath string) error { + // Per docs, SQLite tries these folders in order (excluding those that aren't applicable to us): + // + // - The SQLITE_TMPDIR environment variable + // - The TMPDIR environment variable + // - /var/tmp + // - /usr/tmp + // - /tmp + // + // Source: https://www.sqlite.org/tempfiles.html#temporary_file_storage_locations + // + // First, let's check if SQLITE_TMPDIR or TMPDIR are set, in which case we trust the user has taken care of the problem already + if os.Getenv("SQLITE_TMPDIR") != "" || os.Getenv("TMPDIR") != "" { + return nil + } + + // Now, let's check if /var/tmp, /usr/tmp, or /tmp exist and are writable + for _, dir := range []string{"/var/tmp", "/usr/tmp", "/tmp"} { + ok, err := utils.IsWritableDir(dir) + if err != nil { + return fmt.Errorf("failed to check if %s is writable: %w", dir, err) + } + if ok { + // We found a folder that's writable + return nil + } + } + + // If we're here, there's no temporary directory that's writable (not unusual for containers with a read-only root file system), so we set SQLITE_TMPDIR to the folder where the SQLite database is set + err := os.Setenv("SQLITE_TMPDIR", dbPath) + if err != nil { + return fmt.Errorf("failed to set SQLITE_TMPDIR environmental variable: %w", err) + } + + slog.Debug("Set SQLITE_TMPDIR to the database directory", "path", dbPath) + + return nil +} + func getGormLogger() gormLogger.Interface { loggerOpts := make([]slogGorm.Option, 0, 5) loggerOpts = append(loggerOpts, diff --git a/backend/internal/utils/file_util.go b/backend/internal/utils/file_util.go index 0fa322dc..3d2fde71 100644 --- a/backend/internal/utils/file_util.go +++ b/backend/internal/utils/file_util.go @@ -1,12 +1,15 @@ package utils import ( + "crypto/rand" + "encoding/hex" "errors" "fmt" "io" "mime/multipart" "os" "path/filepath" + "syscall" "github.com/google/uuid" "github.com/pocket-id/pocket-id/backend/resources" @@ -136,3 +139,41 @@ func FileExists(path string) (bool, error) { } return !s.IsDir(), nil } + +// IsWritableDir checks if a directory exists and is writable +func IsWritableDir(dir string) (bool, error) { + // Check if directory exists and it's actually a directory + info, err := os.Stat(dir) + if os.IsNotExist(err) { + return false, nil + } else if err != nil { + return false, fmt.Errorf("failed to stat '%s': %w", dir, err) + } + if !info.IsDir() { + return false, nil + } + + // Generate a random suffix for the test file to avoid conflicts + randomBytes := make([]byte, 8) + _, err = io.ReadFull(rand.Reader, randomBytes) + if err != nil { + return false, fmt.Errorf("failed to generate random bytes: %w", err) + } + + // Check if directory is writable by trying to create a temporary file + testFile := filepath.Join(dir, ".pocketid_test_write_"+hex.EncodeToString(randomBytes)) + defer os.Remove(testFile) + + file, err := os.Create(testFile) + if err != nil { + if os.IsPermission(err) || errors.Is(err, syscall.EROFS) { + return false, nil + } + + return false, fmt.Errorf("failed to create test file: %w", err) + } + + _ = file.Close() + + return true, nil +}