Files
pocket-id-pocket-id-1/backend/internal/service/export_service.go
2025-11-26 10:38:15 +01:00

218 lines
5.7 KiB
Go

package service
import (
"archive/zip"
"context"
"encoding/json"
"fmt"
"io"
"path/filepath"
"gorm.io/gorm"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/storage"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
// ExportService handles exporting Pocket ID data into a ZIP archive.
// It reads table contents through db and uploaded files through storage.
type ExportService struct {
// db is the database handle the export queries are run against.
db *gorm.DB
// storage is the file storage whose objects are added to the archive.
storage storage.FileStorage
}
// NewExportService builds an ExportService that uses the given database
// handle and file storage.
func NewExportService(db *gorm.DB, storage storage.FileStorage) *ExportService {
	svc := new(ExportService)
	svc.db = db
	svc.storage = storage
	return svc
}
// ExportToZip runs the complete export: it first collects the database
// contents and then streams them, together with the stored files, as a ZIP
// archive into w. The first error encountered is returned.
func (s *ExportService) ExportToZip(ctx context.Context, w io.Writer) error {
	var (
		snapshot DatabaseExport
		err      error
	)
	if snapshot, err = s.extractDatabase(); err != nil {
		return err
	}
	return s.writeExportZipStream(ctx, w, snapshot)
}
// extractDatabase collects the rows of every table known to the loaded schema
// into a DatabaseExport. The "storage" and "schema_migrations" tables are left
// out of the dump. On any failure an empty DatabaseExport is returned together
// with the error.
func (s *ExportService) extractDatabase() (DatabaseExport, error) {
	schema, err := utils.LoadDBSchemaTypes(s.db)
	if err != nil {
		return DatabaseExport{}, fmt.Errorf("failed to load schema types: %w", err)
	}

	version, err := s.schemaVersion()
	if err != nil {
		return DatabaseExport{}, err
	}

	var export DatabaseExport
	export.Provider = s.db.Name()
	export.Version = version
	export.Tables = make(map[string][]map[string]any)
	// Foreign key constraints require these tables to be inserted in this order.
	// Tables that are not order-dependent are intentionally not listed.
	export.TableOrder = []string{"users", "user_groups", "oidc_clients"}

	for name, tableTypes := range schema {
		switch name {
		case "storage", "schema_migrations":
			// Excluded from the export
			continue
		}
		if dumpErr := s.dumpTable(name, tableTypes, &export); dumpErr != nil {
			return DatabaseExport{}, dumpErr
		}
	}

	return export, nil
}
// schemaVersion reads the "version" column from the schema_migrations table.
// It returns 0 and a wrapped error if the value cannot be queried or scanned.
func (s *ExportService) schemaVersion() (uint, error) {
	var current uint

	row := s.db.Raw("SELECT version FROM schema_migrations").Row()
	err := row.Scan(&current)
	if err != nil {
		return 0, fmt.Errorf("failed to query schema version: %w", err)
	}

	return current, nil
}
// dumpTable selects all rows from a table and appends them to out.Tables.
//
// Each row is stored as a map from column name to the scan destination created
// by getScanValuesForTable, so the map values are pointers (or pointers to
// pointers for nullable columns). In the "kv" table, the row whose key equals
// lockKey is skipped.
//
// An error is returned if the query fails, the column list cannot be read, the
// number of columns differs from the schema, a row cannot be scanned, or the
// iteration itself reports an error.
func (s *ExportService) dumpTable(table string, types utils.DBSchemaTableTypes, out *DatabaseExport) error {
	// The table name is concatenated into the statement, so it must never come
	// from untrusted input; the caller visible in this file passes names taken
	// from the loaded schema.
	rows, err := s.db.Raw("SELECT * FROM " + table).Rows()
	if err != nil {
		return fmt.Errorf("failed to read table %s: %w", table, err)
	}
	defer rows.Close()

	// Do not ignore this error: otherwise a failure here would surface as a
	// misleading column-count mismatch below.
	cols, err := rows.Columns()
	if err != nil {
		return fmt.Errorf("failed to read columns of table %s: %w", table, err)
	}
	if len(cols) != len(types) {
		// Should never happen...
		return fmt.Errorf("mismatched columns in table (%d) and schema (%d)", len(cols), len(types))
	}

	for rows.Next() {
		// Fresh destinations are needed for every row because the pointers
		// themselves are kept in the exported row map.
		vals := s.getScanValuesForTable(cols, types)
		err = rows.Scan(vals...)
		if err != nil {
			return fmt.Errorf("failed to scan row in table %s: %w", table, err)
		}

		rowMap := make(map[string]any, len(cols))
		for i, col := range cols {
			rowMap[col] = vals[i]
		}

		// Skip the app lock row in the kv table
		if table == "kv" {
			if keyPtr, ok := rowMap["key"].(*string); ok && keyPtr != nil && *keyPtr == lockKey {
				continue
			}
		}

		out.Tables[table] = append(out.Tables[table], rowMap)
	}

	return rows.Err()
}
// getScanValuesForTable allocates one scan destination per column, in the
// order given by cols. The destination type is chosen from the column's schema
// type name: bool, []byte, datatype.DateTime, int64, or string for anything
// else. Non-nullable columns get a pointer to the value; nullable columns get
// a pointer to a pointer.
func (s *ExportService) getScanValuesForTable(cols []string, types utils.DBSchemaTableTypes) []any {
	targets := make([]any, len(cols))

	for idx := range cols {
		colType := types[cols[idx]]
		nullable := colType.Nullable

		// Note: keep the concrete types inline. Routing them through a helper
		// that returns "any" would lose the static type handed to utils.Ptr.
		switch colType.Name {
		case "boolean", "bool":
			var zero bool
			if !nullable {
				targets[idx] = utils.Ptr(zero)
				continue
			}
			targets[idx] = utils.Ptr(utils.Ptr(zero))

		case "blob", "bytea", "jsonb":
			// jsonb columns are handled as raw bytes as well
			var zero []byte
			if !nullable {
				targets[idx] = utils.Ptr(zero)
				continue
			}
			targets[idx] = utils.Ptr(utils.Ptr(zero))

		case "timestamp", "timestamptz", "timestamp with time zone", "datetime":
			var zero datatype.DateTime
			if !nullable {
				targets[idx] = utils.Ptr(zero)
				continue
			}
			targets[idx] = utils.Ptr(utils.Ptr(zero))

		case "integer", "int", "bigint":
			var zero int64
			if !nullable {
				targets[idx] = utils.Ptr(zero)
				continue
			}
			targets[idx] = utils.Ptr(utils.Ptr(zero))

		default:
			// Every other type, including "numeric", is read as a string
			var zero string
			if !nullable {
				targets[idx] = utils.Ptr(zero)
				continue
			}
			targets[idx] = utils.Ptr(utils.Ptr(zero))
		}
	}

	return targets
}
// writeExportZipStream writes a ZIP archive to w. The archive holds dbData
// encoded as JSON in "database.json" (HTML escaping disabled), followed by the
// files added by addUploadsToZip. The archive is closed only when every step
// succeeded; the error from closing it is returned.
func (s *ExportService) writeExportZipStream(ctx context.Context, w io.Writer, dbData DatabaseExport) error {
	archive := zip.NewWriter(w)

	// database.json entry
	entry, err := archive.Create("database.json")
	if err != nil {
		return fmt.Errorf("failed to create database.json in zip: %w", err)
	}

	enc := json.NewEncoder(entry)
	enc.SetEscapeHTML(false)
	if err = enc.Encode(dbData); err != nil {
		return fmt.Errorf("failed to encode database.json: %w", err)
	}

	// Uploaded files
	if err = s.addUploadsToZip(ctx, archive); err != nil {
		return err
	}

	return archive.Close()
}
// addUploadsToZip adds all files from the storage to the ZIP archive under the "uploads/" directory.
//
// The storage is walked starting at "/", and each reported object is copied
// into an entry named "uploads/<object path>". Entry names always use forward
// slashes, as required for ZIP archives, regardless of the host OS path
// separator. The walk stops at, and returns, the first error from creating an
// entry, opening an object or copying its contents.
func (s *ExportService) addUploadsToZip(ctx context.Context, zipWriter *zip.Writer) error {
	return s.storage.Walk(ctx, "/", func(p storage.ObjectInfo) error {
		// filepath.Join uses the OS separator (a backslash on Windows), but
		// zip entry names only allow forward slashes, so normalise the result.
		zipPath := filepath.ToSlash(filepath.Join("uploads", p.Path))

		w, err := zipWriter.Create(zipPath)
		if err != nil {
			return fmt.Errorf("failed to create zip entry for %s: %w", zipPath, err)
		}

		f, _, err := s.storage.Open(ctx, p.Path)
		if err != nil {
			return fmt.Errorf("failed to open file %s: %w", zipPath, err)
		}
		// The deferred close runs when this callback returns, i.e. once per file
		defer f.Close()

		if _, err := io.Copy(w, f); err != nil {
			return fmt.Errorf("failed to copy file %s into zip: %w", zipPath, err)
		}
		return nil
	})
}