mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-24 09:15:05 +03:00
feat: add CLI command for importing and exporting Pocket ID data (#998)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
217
backend/internal/service/export_service.go
Normal file
217
backend/internal/service/export_service.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"path/filepath"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
// ExportService handles exporting Pocket ID data into a ZIP archive.
|
||||
type ExportService struct {
|
||||
db *gorm.DB
|
||||
storage storage.FileStorage
|
||||
}
|
||||
|
||||
func NewExportService(db *gorm.DB, storage storage.FileStorage) *ExportService {
|
||||
return &ExportService{
|
||||
db: db,
|
||||
storage: storage,
|
||||
}
|
||||
}
|
||||
|
||||
// ExportToZip performs the full export process and writes the ZIP data to the given writer.
|
||||
func (s *ExportService) ExportToZip(ctx context.Context, w io.Writer) error {
|
||||
dbData, err := s.extractDatabase()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.writeExportZipStream(ctx, w, dbData)
|
||||
}
|
||||
|
||||
// extractDatabase reads all tables into a DatabaseExport struct
|
||||
func (s *ExportService) extractDatabase() (DatabaseExport, error) {
|
||||
schema, err := utils.LoadDBSchemaTypes(s.db)
|
||||
if err != nil {
|
||||
return DatabaseExport{}, fmt.Errorf("failed to load schema types: %w", err)
|
||||
}
|
||||
|
||||
version, err := s.schemaVersion()
|
||||
if err != nil {
|
||||
return DatabaseExport{}, err
|
||||
}
|
||||
|
||||
out := DatabaseExport{
|
||||
Provider: s.db.Name(),
|
||||
Version: version,
|
||||
Tables: map[string][]map[string]any{},
|
||||
// These tables need to be inserted in a specific order because of foreign key constraints
|
||||
// Not all tables are listed here, because not all tables are order-dependent
|
||||
TableOrder: []string{"users", "user_groups", "oidc_clients"},
|
||||
}
|
||||
|
||||
for table := range schema {
|
||||
if table == "storage" || table == "schema_migrations" {
|
||||
continue
|
||||
}
|
||||
err = s.dumpTable(table, schema[table], &out)
|
||||
if err != nil {
|
||||
return DatabaseExport{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *ExportService) schemaVersion() (uint, error) {
|
||||
var version uint
|
||||
if err := s.db.Raw("SELECT version FROM schema_migrations").Row().Scan(&version); err != nil {
|
||||
return 0, fmt.Errorf("failed to query schema version: %w", err)
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// dumpTable selects all rows from a table and appends them to out.Tables
|
||||
func (s *ExportService) dumpTable(table string, types utils.DBSchemaTableTypes, out *DatabaseExport) error {
|
||||
rows, err := s.db.Raw("SELECT * FROM " + table).Rows()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read table %s: %w", table, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols, _ := rows.Columns()
|
||||
if len(cols) != len(types) {
|
||||
// Should never happen...
|
||||
return fmt.Errorf("mismatched columns in table (%d) and schema (%d)", len(cols), len(types))
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
vals := s.getScanValuesForTable(cols, types)
|
||||
err = rows.Scan(vals...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan row in table %s: %w", table, err)
|
||||
}
|
||||
|
||||
rowMap := make(map[string]any, len(cols))
|
||||
for i, col := range cols {
|
||||
rowMap[col] = vals[i]
|
||||
}
|
||||
|
||||
// Skip the app lock row in the kv table
|
||||
if table == "kv" {
|
||||
if keyPtr, ok := rowMap["key"].(*string); ok && keyPtr != nil && *keyPtr == lockKey {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
out.Tables[table] = append(out.Tables[table], rowMap)
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (s *ExportService) getScanValuesForTable(cols []string, types utils.DBSchemaTableTypes) []any {
|
||||
res := make([]any, len(cols))
|
||||
for i, col := range cols {
|
||||
// Store a pointer
|
||||
// Note: don't create a helper function for this switch, because it would return type "any" and mess everything up
|
||||
// If the column is nullable, we need a pointer to a pointer!
|
||||
switch types[col].Name {
|
||||
case "boolean", "bool":
|
||||
var x bool
|
||||
if types[col].Nullable {
|
||||
res[i] = utils.Ptr(utils.Ptr(x))
|
||||
} else {
|
||||
res[i] = utils.Ptr(x)
|
||||
}
|
||||
case "blob", "bytea", "jsonb":
|
||||
// Treat jsonb columns as binary too
|
||||
var x []byte
|
||||
if types[col].Nullable {
|
||||
res[i] = utils.Ptr(utils.Ptr(x))
|
||||
} else {
|
||||
res[i] = utils.Ptr(x)
|
||||
}
|
||||
case "timestamp", "timestamptz", "timestamp with time zone", "datetime":
|
||||
var x datatype.DateTime
|
||||
if types[col].Nullable {
|
||||
res[i] = utils.Ptr(utils.Ptr(x))
|
||||
} else {
|
||||
res[i] = utils.Ptr(x)
|
||||
}
|
||||
case "integer", "int", "bigint":
|
||||
var x int64
|
||||
if types[col].Nullable {
|
||||
res[i] = utils.Ptr(utils.Ptr(x))
|
||||
} else {
|
||||
res[i] = utils.Ptr(x)
|
||||
}
|
||||
default:
|
||||
// Treat everything else as a string (including the "numeric" type)
|
||||
var x string
|
||||
if types[col].Nullable {
|
||||
res[i] = utils.Ptr(utils.Ptr(x))
|
||||
} else {
|
||||
res[i] = utils.Ptr(x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (s *ExportService) writeExportZipStream(ctx context.Context, w io.Writer, dbData DatabaseExport) error {
|
||||
zipWriter := zip.NewWriter(w)
|
||||
|
||||
// Add database.json
|
||||
jsonWriter, err := zipWriter.Create("database.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database.json in zip: %w", err)
|
||||
}
|
||||
|
||||
jsonEncoder := json.NewEncoder(jsonWriter)
|
||||
jsonEncoder.SetEscapeHTML(false)
|
||||
|
||||
if err := jsonEncoder.Encode(dbData); err != nil {
|
||||
return fmt.Errorf("failed to encode database.json: %w", err)
|
||||
}
|
||||
|
||||
// Add uploaded files
|
||||
if err := s.addUploadsToZip(ctx, zipWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return zipWriter.Close()
|
||||
}
|
||||
|
||||
// addUploadsToZip adds all files from the storage to the ZIP archive under the "uploads/" directory
|
||||
func (s *ExportService) addUploadsToZip(ctx context.Context, zipWriter *zip.Writer) error {
|
||||
return s.storage.Walk(ctx, "/", func(p storage.ObjectInfo) error {
|
||||
zipPath := filepath.Join("uploads", p.Path)
|
||||
|
||||
w, err := zipWriter.Create(zipPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create zip entry for %s: %w", zipPath, err)
|
||||
}
|
||||
|
||||
f, _, err := s.storage.Open(ctx, p.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file %s: %w", zipPath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(w, f); err != nil {
|
||||
return fmt.Errorf("failed to copy file %s into zip: %w", zipPath, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user