mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-09 23:02:56 +03:00
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
192 lines
5.4 KiB
Go
192 lines
5.4 KiB
Go
package cmds
|
|
|
|
import (
|
|
"archive/zip"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/spf13/cobra"
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
|
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
)
|
|
|
|
// importFlags holds the command-line options of the "import" sub-command.
type importFlags struct {
	// Path is the ZIP archive to import from; "-" means read it from stdin.
	Path string
	// Yes skips the interactive confirmation prompt.
	Yes bool
	// ForcefullyAcquireLock takes the application lock even if another
	// Pocket ID instance holds it (per the flag help, terminating that instance).
	ForcefullyAcquireLock bool
}
|
|
|
|
func init() {
|
|
var flags importFlags
|
|
|
|
importCmd := &cobra.Command{
|
|
Use: "import",
|
|
Short: "Imports all data of Pocket ID from a ZIP file",
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
return runImport(cmd.Context(), flags)
|
|
},
|
|
}
|
|
|
|
importCmd.Flags().StringVarP(&flags.Path, "path", "p", "pocket-id-export.zip", "Path to the ZIP file to import the data from, or '-' to read from stdin")
|
|
importCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Skip confirmation prompts")
|
|
importCmd.Flags().BoolVarP(&flags.ForcefullyAcquireLock, "forcefully-acquire-lock", "", false, "Forcefully acquire the application lock by terminating the Pocket ID instance")
|
|
|
|
rootCmd.AddCommand(importCmd)
|
|
}
|
|
|
|
// runImport handles the high-level orchestration of the import process
|
|
func runImport(ctx context.Context, flags importFlags) error {
|
|
if !flags.Yes {
|
|
ok, err := askForConfirmation()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get confirmation: %w", err)
|
|
}
|
|
if !ok {
|
|
fmt.Println("Aborted")
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
var (
|
|
zipReader *zip.ReadCloser
|
|
cleanup func()
|
|
err error
|
|
)
|
|
|
|
if flags.Path == "-" {
|
|
zipReader, cleanup, err = readZipFromStdin()
|
|
defer cleanup()
|
|
} else {
|
|
zipReader, err = zip.OpenReader(flags.Path)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open zip: %w", err)
|
|
}
|
|
defer zipReader.Close()
|
|
|
|
db, err := bootstrap.ConnectDatabase()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = acquireImportLock(ctx, db, flags.ForcefullyAcquireLock)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
storage, err := bootstrap.InitStorage(ctx, db)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize storage: %w", err)
|
|
}
|
|
|
|
importService := service.NewImportService(db, storage)
|
|
err = importService.ImportFromZip(ctx, &zipReader.Reader)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to import data from zip: %w", err)
|
|
}
|
|
|
|
fmt.Println("Import completed successfully.")
|
|
return nil
|
|
}
|
|
|
|
func acquireImportLock(ctx context.Context, db *gorm.DB, force bool) error {
|
|
// Check if the kv table exists, in case we are starting from an empty database
|
|
exists, err := utils.DBTableExists(db, "kv")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check if kv table exists: %w", err)
|
|
}
|
|
if !exists {
|
|
// This either means the database is empty, or the import is into an old version of PocketID that doesn't support locks
|
|
// In either case, there's no lock to acquire
|
|
fmt.Println("Could not acquire a lock because the 'kv' table does not exist. This is fine if you're importing into a new database, but make sure that there isn't an instance of Pocket ID currently running and using the same database.")
|
|
return nil
|
|
}
|
|
|
|
// Note that we do not call a deferred Release if the data was imported
|
|
// This is because we are overriding the contents of the database, so the lock is automatically lost
|
|
appLockService := service.NewAppLockService(db)
|
|
|
|
opCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
waitUntil, err := appLockService.Acquire(opCtx, force)
|
|
if err != nil {
|
|
if errors.Is(err, service.ErrLockUnavailable) {
|
|
//nolint:staticcheck
|
|
return errors.New("Pocket ID must be stopped before importing data; please stop the running instance or run with --forcefully-acquire-lock to terminate the other instance")
|
|
}
|
|
return fmt.Errorf("failed to acquire application lock: %w", err)
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(time.Until(waitUntil)):
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func askForConfirmation() (bool, error) {
|
|
fmt.Println("WARNING: This feature is experimental and may not work correctly. Please create a backup before proceeding and report any issues you encounter.")
|
|
fmt.Println()
|
|
fmt.Println("WARNING: Import will erase all existing data at the following locations:")
|
|
fmt.Printf("Database: %s\n", absolutePathOrOriginal(common.EnvConfig.DbConnectionString))
|
|
fmt.Printf("Uploads Path: %s\n", absolutePathOrOriginal(common.EnvConfig.UploadPath))
|
|
|
|
ok, err := utils.PromptForConfirmation("Do you want to continue?")
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return ok, nil
|
|
}
|
|
|
|
// absolutePathOrOriginal resolves path to an absolute path with filepath.Abs.
// If resolution fails, it returns the input unchanged.
func absolutePathOrOriginal(path string) string {
	if abs, err := filepath.Abs(path); err == nil {
		return abs
	}
	return path
}
|
|
|
|
// readZipFromStdin copies all of standard input into a temporary
// "pocket-id-import-*.zip" file and opens that file with zip.OpenReader, which
// works on a file path. On success it returns the reader and a cleanup
// function that deletes the temporary file. Close the reader before calling
// cleanup, so the file is no longer open when it is removed.
//
// On failure, the temporary file (if created) is removed, and both the
// reader and the cleanup function are nil.
func readZipFromStdin() (*zip.ReadCloser, func(), error) {
	tmp, err := os.CreateTemp("", "pocket-id-import-*.zip")
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create temporary file: %w", err)
	}
	tmpPath := tmp.Name()
	removeTmp := func() {
		// Best effort: a leftover temporary file is not worth reporting
		_ = os.Remove(tmpPath)
	}

	_, copyErr := io.Copy(tmp, os.Stdin)
	if copyErr != nil {
		// The copy error is the one worth reporting; the close is best effort
		_ = tmp.Close()
		removeTmp()
		return nil, nil, fmt.Errorf("failed to read data from stdin: %w", copyErr)
	}

	if closeErr := tmp.Close(); closeErr != nil {
		removeTmp()
		return nil, nil, fmt.Errorf("failed to close temporary file: %w", closeErr)
	}

	archive, openErr := zip.OpenReader(tmpPath)
	if openErr != nil {
		removeTmp()
		return nil, nil, openErr
	}
	return archive, removeTmp, nil
}
|