mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 09:13:20 +03:00
fix: ensure SQLite has a writable temporary directory (#876)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
912008b048
commit
1f3550c9bd
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
gormLogger "gorm.io/gorm/logger"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
)
|
||||
@@ -140,11 +143,20 @@ func connectDatabase() (db *gorm.DB, err error) {
|
||||
if common.EnvConfig.DbConnectionString == "" {
|
||||
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for SQLite database")
|
||||
}
|
||||
|
||||
sqliteutil.RegisterSqliteFunctions()
|
||||
connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
|
||||
|
||||
connString, dbPath, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Before we connect, also make sure that there's a temporary folder for SQLite to write its data
|
||||
err = ensureSqliteTempDir(filepath.Dir(dbPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialector = sqlite.Open(connString)
|
||||
case common.DbProviderPostgres:
|
||||
if common.EnvConfig.DbConnectionString == "" {
|
||||
@@ -174,7 +186,7 @@ func connectDatabase() (db *gorm.DB, err error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func parseSqliteConnectionString(connString string) (string, error) {
|
||||
func parseSqliteConnectionString(connString string) (parsedConnString string, dbPath string, err error) {
|
||||
if !strings.HasPrefix(connString, "file:") {
|
||||
connString = "file:" + connString
|
||||
}
|
||||
@@ -185,7 +197,7 @@ func parseSqliteConnectionString(connString string) (string, error) {
|
||||
// Parse the connection string
|
||||
connStringUrl, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse SQLite connection string: %w", err)
|
||||
return "", "", fmt.Errorf("failed to parse SQLite connection string: %w", err)
|
||||
}
|
||||
|
||||
// Convert options for the old SQLite driver to the new one
|
||||
@@ -194,10 +206,19 @@ func parseSqliteConnectionString(connString string) (string, error) {
|
||||
// Add the default and required params
|
||||
err = addSqliteDefaultParameters(connStringUrl, isMemoryDB)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid SQLite connection string: %w", err)
|
||||
return "", "", fmt.Errorf("invalid SQLite connection string: %w", err)
|
||||
}
|
||||
|
||||
return connStringUrl.String(), nil
|
||||
// Get the absolute path to the database
|
||||
// Here, we know for a fact that the ? is present
|
||||
parsedConnString = connStringUrl.String()
|
||||
idx := strings.IndexRune(parsedConnString, '?')
|
||||
dbPath, err = filepath.Abs(parsedConnString[len("file:"):idx])
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to determine absolute path to the database: %w", err)
|
||||
}
|
||||
|
||||
return parsedConnString, dbPath, nil
|
||||
}
|
||||
|
||||
// The official C implementation of SQLite allows some additional properties in the connection string
|
||||
@@ -350,6 +371,48 @@ func isSqliteInMemory(connString string) bool {
|
||||
return len(qs["mode"]) > 0 && qs["mode"][0] == "memory"
|
||||
}
|
||||
|
||||
// ensureSqliteTempDir ensures that SQLite has a directory where it can write temporary files if needed
|
||||
// The default directory may not be writable when using a container with a read-only root file system
|
||||
// See: https://www.sqlite.org/tempfiles.html
|
||||
func ensureSqliteTempDir(dbPath string) error {
|
||||
// Per docs, SQLite tries these folders in order (excluding those that aren't applicable to us):
|
||||
//
|
||||
// - The SQLITE_TMPDIR environment variable
|
||||
// - The TMPDIR environment variable
|
||||
// - /var/tmp
|
||||
// - /usr/tmp
|
||||
// - /tmp
|
||||
//
|
||||
// Source: https://www.sqlite.org/tempfiles.html#temporary_file_storage_locations
|
||||
//
|
||||
// First, let's check if SQLITE_TMPDIR or TMPDIR are set, in which case we trust the user has taken care of the problem already
|
||||
if os.Getenv("SQLITE_TMPDIR") != "" || os.Getenv("TMPDIR") != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Now, let's check if /var/tmp, /usr/tmp, or /tmp exist and are writable
|
||||
for _, dir := range []string{"/var/tmp", "/usr/tmp", "/tmp"} {
|
||||
ok, err := utils.IsWritableDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if %s is writable: %w", dir, err)
|
||||
}
|
||||
if ok {
|
||||
// We found a folder that's writable
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// If we're here, there's no temporary directory that's writable (not unusual for containers with a read-only root file system), so we set SQLITE_TMPDIR to the folder where the SQLite database is set
|
||||
err := os.Setenv("SQLITE_TMPDIR", dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set SQLITE_TMPDIR environmental variable: %w", err)
|
||||
}
|
||||
|
||||
slog.Debug("Set SQLITE_TMPDIR to the database directory", "path", dbPath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getGormLogger() gormLogger.Interface {
|
||||
loggerOpts := make([]slogGorm.Option, 0, 5)
|
||||
loggerOpts = append(loggerOpts,
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocket-id/pocket-id/backend/resources"
|
||||
@@ -136,3 +139,41 @@ func FileExists(path string) (bool, error) {
|
||||
}
|
||||
return !s.IsDir(), nil
|
||||
}
|
||||
|
||||
// IsWritableDir checks if a directory exists and is writable
|
||||
func IsWritableDir(dir string) (bool, error) {
|
||||
// Check if directory exists and it's actually a directory
|
||||
info, err := os.Stat(dir)
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, fmt.Errorf("failed to stat '%s': %w", dir, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Generate a random suffix for the test file to avoid conflicts
|
||||
randomBytes := make([]byte, 8)
|
||||
_, err = io.ReadFull(rand.Reader, randomBytes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Check if directory is writable by trying to create a temporary file
|
||||
testFile := filepath.Join(dir, ".pocketid_test_write_"+hex.EncodeToString(randomBytes))
|
||||
defer os.Remove(testFile)
|
||||
|
||||
file, err := os.Create(testFile)
|
||||
if err != nil {
|
||||
if os.IsPermission(err) || errors.Is(err, syscall.EROFS) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("failed to create test file: %w", err)
|
||||
}
|
||||
|
||||
_ = file.Close()
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user