mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-11 15:53:00 +03:00
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
131 lines
4.5 KiB
Go
131 lines
4.5 KiB
Go
package utils
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
|
|
"github.com/golang-migrate/migrate/v4"
|
|
"github.com/golang-migrate/migrate/v4/database"
|
|
postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
|
|
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
|
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/resources"
|
|
)
|
|
|
|
// MigrateDatabase applies database migrations using embedded migration files or fetches them from GitHub if a downgrade is detected.
|
|
func MigrateDatabase(sqlDb *sql.DB) error {
|
|
m, err := GetEmbeddedMigrateInstance(sqlDb)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get migrate instance: %w", err)
|
|
}
|
|
|
|
path := "migrations/" + string(common.EnvConfig.DbProvider)
|
|
requiredVersion, err := getRequiredMigrationVersion(path)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get last migration version: %w", err)
|
|
}
|
|
|
|
currentVersion, _, _ := m.Version()
|
|
if currentVersion > requiredVersion {
|
|
slog.Warn("Database version is newer than the application supports, possible downgrade detected", slog.Uint64("db_version", uint64(currentVersion)), slog.Uint64("app_version", uint64(requiredVersion)))
|
|
if !common.EnvConfig.AllowDowngrade {
|
|
return fmt.Errorf("database version (%d) is newer than application version (%d), downgrades are not allowed (set ALLOW_DOWNGRADE=true to enable)", currentVersion, requiredVersion)
|
|
}
|
|
slog.Info("Fetching migrations from GitHub to handle possible downgrades")
|
|
return migrateDatabaseFromGitHub(sqlDb, requiredVersion)
|
|
}
|
|
|
|
if err := m.Migrate(requiredVersion); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
|
return fmt.Errorf("failed to apply embedded migrations: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetEmbeddedMigrateInstance creates a migrate.Migrate instance using embedded migration files.
|
|
func GetEmbeddedMigrateInstance(sqlDb *sql.DB) (*migrate.Migrate, error) {
|
|
path := "migrations/" + string(common.EnvConfig.DbProvider)
|
|
source, err := iofs.New(resources.FS, path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create embedded migration source: %w", err)
|
|
}
|
|
|
|
driver, err := newMigrationDriver(sqlDb, common.EnvConfig.DbProvider)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create migration driver: %w", err)
|
|
}
|
|
|
|
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create migration instance: %w", err)
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// newMigrationDriver creates a database.Driver instance based on the given database provider.
|
|
func newMigrationDriver(sqlDb *sql.DB, dbProvider common.DbProvider) (driver database.Driver, err error) {
|
|
switch dbProvider {
|
|
case common.DbProviderSqlite:
|
|
driver, err = sqliteMigrate.WithInstance(sqlDb, &sqliteMigrate.Config{
|
|
NoTxWrap: true,
|
|
})
|
|
case common.DbProviderPostgres:
|
|
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
|
|
default:
|
|
// Should never happen at this point
|
|
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create migration driver: %w", err)
|
|
}
|
|
|
|
return driver, nil
|
|
}
|
|
|
|
// migrateDatabaseFromGitHub applies database migrations fetched from GitHub to handle downgrades.
|
|
func migrateDatabaseFromGitHub(sqlDb *sql.DB, version uint) error {
|
|
srcURL := "github://pocket-id/pocket-id/backend/resources/migrations/" + string(common.EnvConfig.DbProvider)
|
|
|
|
driver, err := newMigrationDriver(sqlDb, common.EnvConfig.DbProvider)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create migration driver: %w", err)
|
|
}
|
|
|
|
m, err := migrate.NewWithDatabaseInstance(srcURL, "pocket-id", driver)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create GitHub migration instance: %w", err)
|
|
}
|
|
|
|
if err := m.Migrate(version); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
|
return fmt.Errorf("failed to apply GitHub migrations: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getRequiredMigrationVersion reads the embedded migration files and returns the highest version number found.
|
|
func getRequiredMigrationVersion(path string) (uint, error) {
|
|
entries, err := resources.FS.ReadDir(path)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to read migration directory: %w", err)
|
|
}
|
|
|
|
var maxVersion uint
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
name := entry.Name()
|
|
var version uint
|
|
n, err := fmt.Sscanf(name, "%d_", &version)
|
|
if err == nil && n == 1 {
|
|
if version > maxVersion {
|
|
maxVersion = version
|
|
}
|
|
}
|
|
}
|
|
|
|
return maxVersion, nil
|
|
}
|