diff --git a/cmd/connect.go b/cmd/connect.go
index 87074c24..5ec46f40 100644
--- a/cmd/connect.go
+++ b/cmd/connect.go
@@ -78,7 +78,7 @@ func connectAndRunTUI(cmd *cobra.Command, target string) error {
 		}
 	}
 
-	m := newRemoteRootModel(port, service, serverHost)
+	m := newRemoteRootModel(port, service, serverHost, baseURL, token)
 	p := tea.NewProgram(m)
 
 	go func() {
@@ -93,8 +93,9 @@ func connectAndRunTUI(cmd *cobra.Command, target string) error {
 	return nil
 }
 
-func newRemoteRootModel(port int, service core.DownloadService, serverHost string) tui.RootModel {
+func newRemoteRootModel(port int, service core.DownloadService, serverHost string, baseURL string, token string) tui.RootModel {
 	m := tui.InitialRootModel(port, Version, service, nil, false)
+	m.Transfer = core.NewRemoteTransferService(baseURL, token)
 	m.ServerHost = serverHost
 	m.IsRemote = true
 	return m
diff --git a/cmd/connect_test.go b/cmd/connect_test.go
index dd3c473e..304bb8fe 100644
--- a/cmd/connect_test.go
+++ b/cmd/connect_test.go
@@ -65,7 +65,7 @@ func (f *fakeRemoteDownloadService) GetStatus(id string) (*types.DownloadStatus,
 func (f *fakeRemoteDownloadService) Shutdown() error { return nil }
 
 func TestNewRemoteRootModel_UsesNilOrchestrator(t *testing.T) {
-	m := newRemoteRootModel(1700, nil, "example.com")
+	m := newRemoteRootModel(1700, nil, "example.com", "https://example.com:1700", "token")
 
 	if m.Orchestrator != nil {
 		t.Fatal("expected remote root model to use nil orchestrator")
@@ -80,7 +80,7 @@ func TestNewRemoteRootModel_UsesNilOrchestrator(t *testing.T) {
 
 func TestNewRemoteRootModel_DownloadRequestUsesServiceAdd(t *testing.T) {
 	service := &fakeRemoteDownloadService{}
-	m := newRemoteRootModel(1700, service, "example.com")
+	m := newRemoteRootModel(1700, service, "example.com", "https://example.com:1700", "token")
 
 	m.Settings.Extension.ExtensionPrompt = false
 	m.Settings.General.WarnOnDuplicate = false
diff --git a/cmd/export.go b/cmd/export.go
new file mode 100644
index 00000000..8301c784
--- /dev/null
+++ b/cmd/export.go
@@ -0,0 +1,54 @@
+package cmd
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/SurgeDM/Surge/internal/backup"
+	"github.com/spf13/cobra"
+)
+
+var exportCmd = &cobra.Command{
+	Use:   "export <path>",
+	Short: "Export Surge data to a bundle",
+	Args:  cobra.ExactArgs(1),
+	RunE: func(cmd *cobra.Command, args []string) error {
+		if err := initializeGlobalState(); err != nil {
+			return err
+		}
+
+		includeLogs, _ := cmd.Flags().GetBool("include-logs")
+		includePartials, _ := cmd.Flags().GetBool("include-partials")
+		leavePaused, _ := cmd.Flags().GetBool("leave-paused")
+		jsonOutput, _ := cmd.Flags().GetBool("json")
+
+		transfer, _, err := resolveTransferService()
+		if err != nil {
+			return err
+		}
+
+		manifest, err := exportBundle(context.Background(), transfer, args[0], backup.ExportOptions{
+			IncludeLogs:     includeLogs,
+			IncludePartials: includePartials,
+			LeavePaused:     leavePaused,
+		})
+		if err != nil {
+			return err
+		}
+
+		if jsonOutput {
+			return printJSON(manifest)
+		}
+		fmt.Printf("Exported bundle to %s\n", ensureExportPath(args[0]))
+		return nil
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(exportCmd)
+	exportCmd.Flags().Bool("include-logs", false, "Include Surge log files in the export bundle")
+	exportCmd.Flags().Bool("include-partials", false, "Include paused .surge partial files for resumable restore")
+	exportCmd.Flags().Bool("leave-paused", false, "Leave downloads paused after export")
+	exportCmd.Flags().Bool("json", false, "Output manifest as JSON")
+}
+
diff --git a/cmd/http_api.go b/cmd/http_api.go
index 2a1e51df..ff7ffe7d 100644
--- a/cmd/http_api.go
+++ b/cmd/http_api.go
@@ -143,6 +143,8 @@ func registerHTTPRoutes(mux *http.ServeMux, port int, defaultOutputDir string, s
 		writeJSONResponse(w, http.StatusOK, map[string]string{"status": "updated", "id": id, "url": newURL})
 	})))
+
+	registerTransferRoutes(mux, service)
 }
 
 func eventsHandler(service core.DownloadService) http.HandlerFunc {
diff --git a/cmd/import.go b/cmd/import.go
new file mode 100644
index 00000000..2fe6711e
--- /dev/null
+++ b/cmd/import.go
@@ -0,0 +1,69 @@
+package cmd
+
+import (
+	"context"
+
+	"github.com/SurgeDM/Surge/internal/backup"
+	"github.com/spf13/cobra"
+)
+
+var importCmd = &cobra.Command{
+	Use:   "import <path>",
+	Short: "Preview or import a Surge bundle",
+	Args:  cobra.ExactArgs(1),
+	RunE: func(cmd *cobra.Command, args []string) error {
+		if err := initializeGlobalState(); err != nil {
+			return err
+		}
+
+		apply, _ := cmd.Flags().GetBool("apply")
+		replace, _ := cmd.Flags().GetBool("replace")
+		rootDir, _ := cmd.Flags().GetString("root")
+		jsonOutput, _ := cmd.Flags().GetBool("json")
+
+		transfer, isRemote, err := resolveTransferService()
+		if err != nil {
+			return err
+		}
+
+		previewOpts := backup.ImportOptions{
+			RootDir: rootDir,
+			Replace: replace,
+		}
+		preview, err := previewBundle(context.Background(), transfer, args[0], previewOpts)
+		if err != nil {
+			return err
+		}
+
+		if !apply {
+			if jsonOutput {
+				return printJSON(preview)
+			}
+			printImportPreview(preview)
+			return nil
+		}
+
+		opts := previewOpts
+		if isRemote {
+			opts.SessionID = preview.SessionID
+		}
+		result, err := applyBundle(context.Background(), transfer, args[0], opts)
+		if err != nil {
+			return err
+		}
+
+		if jsonOutput {
+			return printJSON(result)
+		}
+		printImportResult(result)
+		return nil
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(importCmd)
+	importCmd.Flags().Bool("apply", false, "Apply the import after preview succeeds")
+	importCmd.Flags().Bool("replace", false, "Replace existing Surge state instead of merging")
+	importCmd.Flags().String("root", "", "Root directory for rebased imported paths")
+	importCmd.Flags().Bool("json", false, "Output preview/result as JSON")
+}
diff --git a/cmd/root.go b/cmd/root.go
index 4d8099d1..cf3e2177 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -462,6 +462,7 @@ func startTUI(port int, exitWhenDone bool, noResume bool) error {
 	m := tui.InitialRootModel(port, Version, GlobalService, currentLifecycle(), noResume)
 	m = m.WithEnqueueContext(currentEnqueueContext(), currentEnqueueCancel())
+	m.Transfer = core.NewLocalTransferService(GlobalService, Version)
 	m.ServerHost = serverBindHost
 	if m.ServerHost == "" {
 		m.ServerHost = "127.0.0.1"
diff --git a/cmd/transfer_api.go b/cmd/transfer_api.go
new file mode 100644
index 00000000..a1b32e63
--- /dev/null
+++ b/cmd/transfer_api.go
@@ -0,0 +1,191 @@
+package cmd
+
+import (
+	"context"
+	"crypto/rand"
+	"encoding/hex"
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/url"
+	"os"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/SurgeDM/Surge/internal/backup"
+	"github.com/SurgeDM/Surge/internal/core"
+)
+
+var maxImportPreviewSize int64 = 512 * 1024 * 1024
+
+type stagedImportSession struct {
+	Path      string
+	CreatedAt time.Time
+}
+
+var importSessionStore = struct {
+	mu    sync.Mutex
+	items map[string]stagedImportSession
+}{
+	items: make(map[string]stagedImportSession),
+}
+
+func cleanupImportSessions() {
+	cutoff := time.Now().Add(-1 * time.Hour)
+	importSessionStore.mu.Lock()
+	defer importSessionStore.mu.Unlock()
+	for id, session := range importSessionStore.items {
if session.CreatedAt.After(cutoff) { + continue + } + _ = os.Remove(session.Path) + delete(importSessionStore.items, id) + } +} + +func newImportSessionID() (string, error) { + token := make([]byte, 16) + if _, err := rand.Read(token); err != nil { + return "", err + } + return hex.EncodeToString(token), nil +} + +func registerTransferRoutes(mux *http.ServeMux, service core.DownloadService) { + mux.HandleFunc("/data/export", requireMethod(http.MethodPost, func(w http.ResponseWriter, r *http.Request) { + cleanupImportSessions() + var opts backup.ExportOptions + if r.Body != nil { + if err := json.NewDecoder(r.Body).Decode(&opts); err != nil && err != io.EOF { + http.Error(w, "invalid export request", http.StatusBadRequest) + return + } + } + + transfer := core.NewLocalTransferService(service, Version) + tmpFile, err := os.CreateTemp("", "surge-export-*.zip") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer func() { + _ = tmpFile.Close() + _ = os.Remove(tmpFile.Name()) + }() + + manifest, err := transfer.Export(r.Context(), opts, tmpFile) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if _, err := tmpFile.Seek(0, 0); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + manifestBytes, _ := json.Marshal(manifest) + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", "attachment; filename=\"surge-export.surge-export\"") + w.Header().Set("X-Surge-Manifest", url.QueryEscape(string(manifestBytes))) + if _, err := io.Copy(w, tmpFile); err != nil { + return + } + })) + + mux.HandleFunc("/data/import/preview", requireMethod(http.MethodPost, func(w http.ResponseWriter, r *http.Request) { + cleanupImportSessions() + opts := backup.ImportOptions{ + RootDir: strings.TrimSpace(r.URL.Query().Get("root_dir")), + Replace: strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("replace")), "true"), + } + tmpFile, err := os.CreateTemp("", "surge-import-preview-*.zip") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = tmpFile.Close() }() + r.Body = http.MaxBytesReader(w, r.Body, maxImportPreviewSize) + if _, err := io.Copy(tmpFile, r.Body); err != nil { + _ = os.Remove(tmpFile.Name()) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if _, err := tmpFile.Seek(0, 0); err != nil { + _ = os.Remove(tmpFile.Name()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + transfer := core.NewLocalTransferService(service, Version) + preview, err := transfer.PreviewImport(context.Background(), tmpFile, opts) + if err != nil { + _ = os.Remove(tmpFile.Name()) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + sessionID, err := newImportSessionID() + if err != nil { + _ = os.Remove(tmpFile.Name()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + importSessionStore.mu.Lock() + importSessionStore.items[sessionID] = stagedImportSession{ + Path: tmpFile.Name(), + CreatedAt: time.Now(), + } + importSessionStore.mu.Unlock() + + preview.SessionID = sessionID + writeJSONResponse(w, http.StatusOK, preview) + })) + + mux.HandleFunc("/data/import/apply", requireMethod(http.MethodPost, func(w http.ResponseWriter, r *http.Request) { + cleanupImportSessions() + var req struct { + SessionID string `json:"session_id"` + RootDir string `json:"root_dir"` + Replace bool `json:"replace"` + } + if err := 
decodeJSONBody(r, &req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.SessionID) == "" { + http.Error(w, "missing session_id", http.StatusBadRequest) + return + } + + importSessionStore.mu.Lock() + session, ok := importSessionStore.items[req.SessionID] + if ok { + delete(importSessionStore.items, req.SessionID) + } + importSessionStore.mu.Unlock() + if !ok { + http.Error(w, "import session not found", http.StatusNotFound) + return + } + defer func() { _ = os.Remove(session.Path) }() + + file, err := os.Open(session.Path) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer func() { _ = file.Close() }() + + transfer := core.NewLocalTransferService(service, Version) + result, err := transfer.ApplyImport(r.Context(), file, backup.ImportOptions{ + RootDir: req.RootDir, + Replace: req.Replace, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSONResponse(w, http.StatusOK, result) + })) +} diff --git a/cmd/transfer_api_test.go b/cmd/transfer_api_test.go new file mode 100644 index 00000000..44506124 --- /dev/null +++ b/cmd/transfer_api_test.go @@ -0,0 +1,94 @@ +package cmd + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/SurgeDM/Surge/internal/backup" +) + +func clearImportSessionStoreForTest(t *testing.T) { + t.Helper() + + importSessionStore.mu.Lock() + defer importSessionStore.mu.Unlock() + for id, session := range importSessionStore.items { + _ = os.Remove(session.Path) + delete(importSessionStore.items, id) + } +} + +func createImportBundleForTest(t *testing.T) []byte { + t.Helper() + + setupXDGEnvIsolation(t) + if err := initializeGlobalState(); err != nil { + t.Fatalf("initializeGlobalState failed: %v", err) + } + + var buf bytes.Buffer + if _, err := backup.Export(context.Background(), &buf, backup.ExportOptions{}, nil); err != nil { + t.Fatalf("Export failed: %v", err) + } + return buf.Bytes() +} + +func TestTransferPreviewEndpoint_EnforcesUploadSizeLimit(t *testing.T) { + clearImportSessionStoreForTest(t) + + originalLimit := maxImportPreviewSize + maxImportPreviewSize = 32 + t.Cleanup(func() { + maxImportPreviewSize = originalLimit + clearImportSessionStoreForTest(t) + }) + + mux := http.NewServeMux() + registerHTTPRoutes(mux, 0, "", &httpAPITestService{}) + + req := httptest.NewRequest(http.MethodPost, "/data/import/preview", strings.NewReader(strings.Repeat("a", 64))) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code == http.StatusOK { + t.Fatalf("expected oversized preview upload to fail, got %d", rec.Code) + } +} + +func TestTransferPreviewEndpoint_GeneratesRandomSessionIDs(t *testing.T) { + clearImportSessionStoreForTest(t) + t.Cleanup(func() { + clearImportSessionStoreForTest(t) + }) + + bundle := createImportBundleForTest(t) + + mux := http.NewServeMux() + registerHTTPRoutes(mux, 0, "", &httpAPITestService{}) + + var previews [2]backup.ImportPreview + for i := range previews { + req := httptest.NewRequest(http.MethodPost, "/data/import/preview", bytes.NewReader(bundle)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("request %d returned %d: %s", i, rec.Code, rec.Body.String()) + } + if err := json.Unmarshal(rec.Body.Bytes(), &previews[i]); err != nil { + t.Fatalf("unmarshal preview %d failed: %v", i, err) + } + if len(previews[i].SessionID) != 32 { 
+ t.Fatalf("session id %d length=%d, want 32", i, len(previews[i].SessionID)) + } + } + + if previews[0].SessionID == previews[1].SessionID { + t.Fatal("expected preview session IDs to differ") + } +} diff --git a/cmd/transfer_utils.go b/cmd/transfer_utils.go new file mode 100644 index 00000000..88626174 --- /dev/null +++ b/cmd/transfer_utils.go @@ -0,0 +1,105 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/SurgeDM/Surge/internal/backup" + "github.com/SurgeDM/Surge/internal/core" +) + +func resolveTransferService() (core.TransferService, bool, error) { + baseURL, token, err := resolveAPIConnection(false) + if err != nil { + return nil, false, err + } + if baseURL != "" { + return core.NewRemoteTransferService(baseURL, token), true, nil + } + return core.NewLocalTransferService(GlobalService, Version), false, nil +} + +func ensureExportPath(path string) string { + if strings.TrimSpace(path) == "" { + return "surge-export" + backup.BundleExtension + } + if strings.HasSuffix(path, backup.BundleExtension) { + return path + } + return path + backup.BundleExtension +} + +func printJSON(v interface{}) error { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return err + } + fmt.Println(string(data)) + return nil +} + +func printImportPreview(preview *backup.ImportPreview) { + fmt.Printf("Root: %s\n", preview.RootDir) + fmt.Printf("Imports: %v\n", preview.ImportsByStatus) + fmt.Printf("Duplicates skipped: %d\n", preview.DuplicatesSkipped) + fmt.Printf("Renamed items: %d\n", preview.RenamedItems) + fmt.Printf("Downgraded to queue: %d\n", preview.ResumableJobsDowngradedToQueue) +} + +func printImportResult(result *backup.ImportResult) { + if result == nil { + return + } + if result.Preview != nil { + printImportPreview(result.Preview) + } + fmt.Printf("Imported: %d\n", result.Imported) + if result.LogsRestored > 0 { + fmt.Printf("Logs restored: %d\n", result.LogsRestored) + } +} + +func exportBundle(ctx context.Context, transfer core.TransferService, path string, opts backup.ExportOptions) (*backup.Manifest, error) { + target := ensureExportPath(path) + if err := os.MkdirAll(filepath.Dir(filepath.Clean(target)), 0o755); err != nil && filepath.Dir(filepath.Clean(target)) != "." 
{ + return nil, err + } + file, err := os.Create(target) + if err != nil { + return nil, err + } + defer func() { _ = file.Close() }() + + manifest, err := transfer.Export(ctx, opts, file) + if err != nil { + return nil, err + } + return manifest, nil +} + +func previewBundle(ctx context.Context, transfer core.TransferService, path string, opts backup.ImportOptions) (*backup.ImportPreview, error) { + file, err := os.Open(filepath.Clean(path)) + if err != nil { + return nil, err + } + defer func() { _ = file.Close() }() + return transfer.PreviewImport(ctx, file, opts) +} + +func applyBundle(ctx context.Context, transfer core.TransferService, path string, opts backup.ImportOptions) (*backup.ImportResult, error) { + var src io.Reader + if strings.TrimSpace(opts.SessionID) == "" { + file, err := os.Open(filepath.Clean(path)) + if err != nil { + return nil, err + } + defer func() { _ = file.Close() }() + src = file + } + return transfer.ApplyImport(ctx, src, opts) +} diff --git a/internal/backup/backup.go b/internal/backup/backup.go new file mode 100644 index 00000000..afc46174 --- /dev/null +++ b/internal/backup/backup.go @@ -0,0 +1,956 @@ +package backup + +import ( + "archive/zip" + "context" + "crypto/sha256" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "slices" + "strings" + "time" + + "github.com/SurgeDM/Surge/internal/config" + "github.com/SurgeDM/Surge/internal/engine/state" + "github.com/SurgeDM/Surge/internal/engine/types" + "github.com/SurgeDM/Surge/internal/processing" + "github.com/SurgeDM/Surge/internal/utils" + "github.com/google/uuid" +) + +type plannedImport struct { + Record PortableDownload + FinalID string + FinalPath string + FinalStatus string + Skip bool + UsePartial bool +} + +func Export(ctx context.Context, w io.Writer, opts ExportOptions, controller Controller) (*Manifest, error) { + if w == nil { + return nil, fmt.Errorf("export writer is required") + } + + resumeIDs, err := quiesceActiveDownloads(ctx, controller) + if err != nil { + return nil, err + } + if len(resumeIDs) > 0 && !opts.LeavePaused && controller != nil { + defer resumeDownloads(controller, resumeIDs) + } + + settings, err := config.LoadSettings() + if err != nil || settings == nil { + settings = config.DefaultSettings() + } + + records, counts, err := exportDownloads(opts.IncludePartials) + if err != nil { + return nil, err + } + + settingsBytes, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return nil, err + } + downloadsBytes, err := json.MarshalIndent(records, "", " ") + if err != nil { + return nil, err + } + + type fileSource struct { + name string + size int64 + sum string + data []byte + path string + } + + var files []fileSource + addBytesSource := func(name string, data []byte) { + sum := sha256.Sum256(data) + files = append(files, fileSource{ + name: name, + size: int64(len(data)), + sum: hex.EncodeToString(sum[:]), + data: data, + }) + } + + addFileSource := func(name, path string) error { + sum, size, err := fileSHA256(path) + if err != nil { + return err + } + files = append(files, fileSource{ + name: name, + size: size, + sum: sum, + path: path, + }) + return nil + } + + addBytesSource("settings/settings.json", settingsBytes) + addBytesSource("state/downloads.json", downloadsBytes) + + if opts.IncludePartials { + for _, rec := range records { + if rec.Resumable == nil || rec.Resumable.PartialFile == "" { + continue + } + sourcePath := rec.DestPath + types.IncompleteSuffix + if err := 
addFileSource(rec.Resumable.PartialFile, sourcePath); err != nil { + return nil, err + } + } + } + + if opts.IncludeLogs { + logFiles, err := collectLogFiles() + if err != nil { + return nil, err + } + for _, path := range logFiles { + relName := filepath.ToSlash(filepath.Join("logs", filepath.Base(path))) + if err := addFileSource(relName, path); err != nil { + return nil, err + } + } + } + + manifest := &Manifest{ + SchemaVersion: SchemaVersion, + CreatedAt: time.Now().UTC(), + SurgeVersion: strings.TrimSpace(opts.AppVersion), + OriginalDefaultDownloadDir: settings.General.DefaultDownloadDir, + IncludeLogs: opts.IncludeLogs, + IncludePartials: opts.IncludePartials, + Counts: counts, + } + for _, file := range files { + manifest.Files = append(manifest.Files, ManifestFile{ + Path: filepath.ToSlash(file.name), + SHA256: file.sum, + Size: file.size, + }) + } + + manifestBytes, err := json.MarshalIndent(manifest, "", " ") + if err != nil { + return nil, err + } + + zw := zip.NewWriter(w) + if err := writeZipBytes(zw, "manifest.json", manifestBytes); err != nil { + _ = zw.Close() + return nil, err + } + for _, file := range files { + if len(file.data) > 0 { + if err := writeZipBytes(zw, file.name, file.data); err != nil { + _ = zw.Close() + return nil, err + } + continue + } + if err := writeZipFile(zw, file.name, file.path); err != nil { + _ = zw.Close() + return nil, err + } + } + if err := zw.Close(); err != nil { + return nil, err + } + + return manifest, nil +} + +func PreviewImport(ctx context.Context, r io.Reader, opts ImportOptions) (*ImportPreview, error) { + _ = ctx + opened, err := openBundle(r) + if err != nil { + return nil, err + } + defer opened.Close() + + manifest, payload, err := loadBundle(opened.Reader) + if err != nil { + return nil, err + } + + return buildImportPreview(manifest, payload, opts), nil +} + +func ApplyImport(ctx context.Context, r io.Reader, opts ImportOptions, controller Controller) (*ImportResult, error) { + opened, err := openBundle(r) + if err != nil { + return nil, err + } + defer opened.Close() + + manifest, payload, err := loadBundle(opened.Reader) + if err != nil { + return nil, err + } + + resumeIDs, err := quiesceActiveDownloads(ctx, controller) + if err != nil { + return nil, err + } + if len(resumeIDs) > 0 && controller != nil { + defer resumeDownloads(controller, resumeIDs) + } + + preview := buildImportPreview(manifest, payload, opts) + rootDir := preview.RootDir + plan, err := planImport(payload.Downloads, manifest, opts, rootDir) + if err != nil { + return nil, err + } + + if opts.Replace { + if err := clearExistingState(); err != nil { + return nil, err + } + } + + settingsSaved := false + if payload.Settings != nil { + if payload.Settings.General.DefaultDownloadDir != "" && rootDir != "" { + payload.Settings.General.DefaultDownloadDir = utils.EnsureAbsPath(rootDir) + } + if err := config.SaveSettings(payload.Settings); err != nil { + return nil, err + } + settingsSaved = true + } + + logsRestored, err := restoreLogFiles(opened.Reader, manifest, opts.Replace) + if err != nil { + return nil, err + } + + imported := 0 + for _, item := range plan { + if item.Skip { + continue + } + if err := ctx.Err(); err != nil { + return nil, err + } + + record := item.Record + entry := types.DownloadEntry{ + ID: item.FinalID, + URL: record.URL, + URLHash: state.URLHash(record.URL), + DestPath: item.FinalPath, + Filename: filepath.Base(item.FinalPath), + Status: item.FinalStatus, + TotalSize: record.TotalSize, + CompletedAt: record.CompletedAt, + TimeTaken: 
record.TimeTaken, + AvgSpeed: record.AvgSpeed, + Mirrors: append([]string(nil), record.Mirrors...), + } + + switch item.FinalStatus { + case "completed": + entry.Downloaded = record.TotalSize + case "error": + entry.Downloaded = record.Downloaded + default: + if item.UsePartial { + entry.Downloaded = record.Downloaded + } else { + entry.Downloaded = 0 + } + } + + if item.UsePartial && record.Resumable != nil { + if err := restorePartialFile(opened.Reader, record.Resumable.PartialFile, item.FinalPath+types.IncompleteSuffix, rootDir); err != nil { + return nil, err + } + + saved := &types.DownloadState{ + ID: item.FinalID, + URLHash: record.Resumable.URLHash, + URL: record.URL, + DestPath: item.FinalPath, + TotalSize: record.TotalSize, + Downloaded: record.Downloaded, + Tasks: append([]types.Task(nil), record.Resumable.Tasks...), + Filename: filepath.Base(item.FinalPath), + CreatedAt: record.Resumable.CreatedAt, + PausedAt: record.Resumable.PausedAt, + Elapsed: record.Resumable.Elapsed, + Mirrors: append([]string(nil), record.Mirrors...), + ChunkBitmap: append([]byte(nil), record.Resumable.ChunkBitmap...), + ActualChunkSize: record.Resumable.ActualChunkSize, + FileHash: record.Resumable.FileHash, + } + if err := state.SaveStateWithOptions(record.URL, item.FinalPath, saved, state.SaveStateOptions{SkipFileHash: true}); err != nil { + return nil, err + } + if item.FinalStatus != "paused" { + if err := state.UpdateStatus(item.FinalID, item.FinalStatus); err != nil { + return nil, err + } + } + } else { + if err := state.AddToMasterList(entry); err != nil { + return nil, err + } + } + + imported++ + } + + return &ImportResult{ + Preview: preview, + Imported: imported, + SettingsSaved: settingsSaved, + LogsRestored: logsRestored, + }, nil +} + +func exportDownloads(includePartials bool) ([]PortableDownload, map[string]int, error) { + entries, err := state.ListAllDownloads() + if err != nil { + return nil, nil, err + } + + counts := make(map[string]int) + out := make([]PortableDownload, 0, len(entries)) + for _, entry := range entries { + rec := PortableDownload{ + ID: entry.ID, + URL: entry.URL, + DestPath: entry.DestPath, + Filename: entry.Filename, + Status: entry.Status, + OriginalStatus: entry.Status, + TotalSize: entry.TotalSize, + Downloaded: entry.Downloaded, + CompletedAt: entry.CompletedAt, + TimeTaken: entry.TimeTaken, + AvgSpeed: entry.AvgSpeed, + Mirrors: append([]string(nil), entry.Mirrors...), + } + + switch entry.Status { + case "completed", "error": + default: + saved, _ := state.LoadState(entry.URL, entry.DestPath) + if includePartials && saved != nil { + partialPath := entry.DestPath + types.IncompleteSuffix + if _, err := os.Stat(partialPath); err == nil { + rec.Resumable = &PortableResumeState{ + URLHash: saved.URLHash, + CreatedAt: saved.CreatedAt, + PausedAt: saved.PausedAt, + Elapsed: saved.Elapsed, + Tasks: append([]types.Task(nil), saved.Tasks...), + ChunkBitmap: append([]byte(nil), saved.ChunkBitmap...), + ActualChunkSize: saved.ActualChunkSize, + FileHash: saved.FileHash, + PartialFile: filepath.ToSlash(filepath.Join("partials", entry.ID+types.IncompleteSuffix)), + } + switch entry.Status { + case "downloading", "pausing": + rec.Status = "paused" + } + break + } + } + + if entry.Status == "paused" || entry.Status == "downloading" || entry.Status == "pausing" { + rec.Status = "queued" + rec.Downloaded = 0 + } + } + + counts[rec.Status]++ + out = append(out, rec) + } + + slices.SortFunc(out, func(a, b PortableDownload) int { + return strings.Compare(a.ID, b.ID) + }) + 
return out, counts, nil +} + +func buildImportPreview(manifest *Manifest, payload *bundlePayload, opts ImportOptions) *ImportPreview { + rootDir := utils.EnsureAbsPath(strings.TrimSpace(opts.RootDir)) + if rootDir == "" { + settings, _ := config.LoadSettings() + if settings != nil { + rootDir = utils.EnsureAbsPath(strings.TrimSpace(settings.General.DefaultDownloadDir)) + } + if rootDir == "" { + rootDir = "." + } + } + + plan, conflicts := planImportPreview(payload.Downloads, manifest, opts, rootDir) + + preview := &ImportPreview{ + Manifest: manifest, + RootDir: rootDir, + ImportsByStatus: make(map[string]int), + Conflicts: conflicts, + } + for _, item := range plan { + if item.Skip { + preview.DuplicatesSkipped++ + continue + } + preview.ImportsByStatus[item.FinalStatus]++ + if item.FinalPath != "" && item.Record.DestPath != "" && filepath.Clean(item.FinalPath) != filepath.Clean(rebaseImportedPath(item.Record.DestPath, manifest.OriginalDefaultDownloadDir, rootDir).Path) { + preview.RenamedItems++ + } + if item.Record.OriginalStatus != "" && item.Record.OriginalStatus != item.FinalStatus && item.FinalStatus == "queued" { + preview.ResumableJobsDowngradedToQueue++ + } + rebase := rebaseImportedPath(item.Record.DestPath, manifest.OriginalDefaultDownloadDir, rootDir) + if rebase.Rebased { + preview.RebasedPaths++ + } + if rebase.Externalized { + preview.ExternalizedPaths++ + } + } + + return preview +} + +func planImport(downloads []PortableDownload, manifest *Manifest, opts ImportOptions, rootDir string) ([]plannedImport, error) { + plan, _ := planImportPreview(downloads, manifest, opts, rootDir) + return plan, nil +} + +func planImportPreview(downloads []PortableDownload, manifest *Manifest, opts ImportOptions, rootDir string) ([]plannedImport, []ImportConflict) { + existing, _ := state.ListAllDownloads() + existingByID := make(map[string]types.DownloadEntry, len(existing)) + existingByLogical := make(map[string]types.DownloadEntry, len(existing)) + reservedNames := make(map[string]struct{}) + if !opts.Replace { + for _, entry := range existing { + existingByID[entry.ID] = entry + existingByLogical[logicalKey(entry.URL, entry.DestPath, entry.Filename, entry.Status)] = entry + if entry.DestPath != "" { + reservedNames[pathKey(filepath.Dir(entry.DestPath), filepath.Base(entry.DestPath))] = struct{}{} + } + } + } + + var conflicts []ImportConflict + plan := make([]plannedImport, 0, len(downloads)) + for _, record := range downloads { + rebase := rebaseImportedPath(record.DestPath, manifest.OriginalDefaultDownloadDir, rootDir) + finalPath := rebase.Path + finalID := record.ID + finalStatus := record.Status + usePartial := record.Resumable != nil && strings.TrimSpace(record.Resumable.PartialFile) != "" + + if !usePartial && finalStatus == "paused" { + finalStatus = "queued" + } + + if existing, ok := existingByID[record.ID]; ok && !sameLogical(existing, record, finalPath) { + oldID := finalID + finalID = uuid.New().String() + conflicts = append(conflicts, ImportConflict{ + Type: "id_conflict", + ID: oldID, + Message: "assigned a new download ID during import", + }) + } + + if existing, ok := existingByLogical[logicalKey(record.URL, finalPath, filepath.Base(finalPath), finalStatus)]; ok && existing.ID != record.ID { + plan = append(plan, plannedImport{ + Record: record, + FinalID: finalID, + FinalPath: finalPath, + FinalStatus: finalStatus, + Skip: true, + }) + conflicts = append(conflicts, ImportConflict{ + Type: "duplicate", + ID: existing.ID, + Path: finalPath, + Message: "skipped exact 
duplicate import record", + }) + continue + } + + dir := filepath.Dir(finalPath) + name := filepath.Base(finalPath) + uniqueName := processing.GetUniqueFilename(dir, name, func(checkDir, checkName string) bool { + _, exists := reservedNames[pathKey(checkDir, checkName)] + return exists + }) + if uniqueName != "" && uniqueName != name { + newPath := filepath.Join(dir, uniqueName) + conflicts = append(conflicts, ImportConflict{ + Type: "path_rename", + ID: record.ID, + OldPath: finalPath, + NewPath: newPath, + Message: "renamed during import to avoid destination collision", + }) + finalPath = newPath + } + + if finalPath != "" { + reservedNames[pathKey(filepath.Dir(finalPath), filepath.Base(finalPath))] = struct{}{} + } + + plan = append(plan, plannedImport{ + Record: record, + FinalID: finalID, + FinalPath: finalPath, + FinalStatus: finalStatus, + UsePartial: usePartial, + }) + } + + return plan, conflicts +} + +func sameLogical(entry types.DownloadEntry, record PortableDownload, finalPath string) bool { + return entry.URL == record.URL && + filepath.Clean(entry.DestPath) == filepath.Clean(finalPath) && + entry.Filename == filepath.Base(finalPath) && + entry.Status == record.Status +} + +func logicalKey(url, destPath, filename, status string) string { + return strings.Join([]string{ + strings.TrimSpace(url), + filepath.Clean(strings.TrimSpace(destPath)), + strings.TrimSpace(filename), + strings.TrimSpace(status), + }, "|") +} + +func pathKey(dir, name string) string { + return filepath.Clean(dir) + "|" + strings.TrimSpace(name) +} + +func quiesceActiveDownloads(ctx context.Context, controller Controller) ([]string, error) { + if controller == nil { + return nil, nil + } + statuses, err := controller.List() + if err != nil { + return nil, err + } + + var activeIDs []string + for _, status := range statuses { + if status.Status == "downloading" || status.Status == "pausing" { + activeIDs = append(activeIDs, status.ID) + } + } + for _, id := range activeIDs { + if err := controller.Pause(id); err != nil { + return nil, err + } + } + if len(activeIDs) == 0 { + return nil, nil + } + + deadline := time.Now().Add(15 * time.Second) + for time.Now().Before(deadline) { + if err := ctx.Err(); err != nil { + return nil, err + } + statuses, err := controller.List() + if err != nil { + return nil, err + } + pending := false + for _, status := range statuses { + if !slices.Contains(activeIDs, status.ID) { + continue + } + if status.Status == "downloading" || status.Status == "pausing" { + pending = true + break + } + } + if !pending { + return activeIDs, nil + } + time.Sleep(150 * time.Millisecond) + } + + return activeIDs, nil +} + +func resumeDownloads(controller Controller, ids []string) { + for _, id := range ids { + if err := controller.Resume(id); err != nil { + utils.Debug("backup: failed resuming %s after transfer: %v", id, err) + } + } +} + +func collectLogFiles() ([]string, error) { + logDir := config.GetLogsDir() + entries, err := os.ReadDir(logDir) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + var out []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + out = append(out, filepath.Join(logDir, entry.Name())) + } + slices.Sort(out) + return out, nil +} + +type openedBundle struct { + File *os.File + Reader *zip.Reader +} + +func (o *openedBundle) Close() error { + if o == nil || o.File == nil { + return nil + } + name := o.File.Name() + if err := o.File.Close(); err != nil { + return err + } + return os.Remove(name) +} + 
+func openBundle(r io.Reader) (*openedBundle, error) { + if r == nil { + return nil, fmt.Errorf("bundle reader is required") + } + tmpFile, err := os.CreateTemp("", "surge-import-*.zip") + if err != nil { + return nil, err + } + defer func() { + _ = tmpFile.Close() + }() + if _, err := io.Copy(tmpFile, r); err != nil { + _ = os.Remove(tmpFile.Name()) + return nil, err + } + + info, err := os.Stat(tmpFile.Name()) + if err != nil { + _ = os.Remove(tmpFile.Name()) + return nil, err + } + file, err := os.Open(tmpFile.Name()) + if err != nil { + _ = os.Remove(tmpFile.Name()) + return nil, err + } + reader, err := zip.NewReader(file, info.Size()) + if err != nil { + _ = file.Close() + _ = os.Remove(tmpFile.Name()) + return nil, err + } + return &openedBundle{File: file, Reader: reader}, nil +} + +func loadBundle(reader *zip.Reader) (*Manifest, *bundlePayload, error) { + manifestBytes, err := readZipEntry(reader, "manifest.json") + if err != nil { + return nil, nil, err + } + var manifest Manifest + if err := json.Unmarshal(manifestBytes, &manifest); err != nil { + return nil, nil, err + } + if manifest.SchemaVersion != SchemaVersion { + return nil, nil, fmt.Errorf("unsupported bundle schema version %d", manifest.SchemaVersion) + } + if err := verifyManifestFiles(reader, manifest.Files); err != nil { + return nil, nil, err + } + + settingsBytes, err := readZipEntry(reader, "settings/settings.json") + if err != nil { + return nil, nil, err + } + var settings config.Settings + if err := json.Unmarshal(settingsBytes, &settings); err != nil { + return nil, nil, err + } + downloadsBytes, err := readZipEntry(reader, "state/downloads.json") + if err != nil { + return nil, nil, err + } + var downloads []PortableDownload + if err := json.Unmarshal(downloadsBytes, &downloads); err != nil { + return nil, nil, err + } + return &manifest, &bundlePayload{ + Settings: &settings, + Downloads: downloads, + }, nil +} + +func readZipEntry(reader *zip.Reader, name string) ([]byte, error) { + for _, file := range reader.File { + if filepath.ToSlash(file.Name) != filepath.ToSlash(name) { + continue + } + rc, err := file.Open() + if err != nil { + return nil, err + } + defer func() { _ = rc.Close() }() + return io.ReadAll(rc) + } + return nil, fmt.Errorf("bundle entry %s not found", name) +} + +func verifyManifestFiles(reader *zip.Reader, files []ManifestFile) error { + for _, file := range files { + data, err := readZipEntry(reader, file.Path) + if err != nil { + return err + } + sum := sha256.Sum256(data) + if hex.EncodeToString(sum[:]) != file.SHA256 { + return fmt.Errorf("checksum mismatch for %s", file.Path) + } + if int64(len(data)) != file.Size { + return fmt.Errorf("size mismatch for %s", file.Path) + } + } + return nil +} + +func writeZipBytes(zw *zip.Writer, name string, data []byte) error { + w, err := zw.Create(filepath.ToSlash(name)) + if err != nil { + return err + } + _, err = w.Write(data) + return err +} + +func writeZipFile(zw *zip.Writer, name, path string) error { + w, err := zw.Create(filepath.ToSlash(name)) + if err != nil { + return err + } + file, err := os.Open(path) + if err != nil { + return err + } + defer func() { _ = file.Close() }() + _, err = io.Copy(w, file) + return err +} + +func fileSHA256(path string) (string, int64, error) { + file, err := os.Open(path) + if err != nil { + return "", 0, err + } + defer func() { _ = file.Close() }() + h := sha256.New() + size, err := io.Copy(h, file) + if err != nil { + return "", 0, err + } + return hex.EncodeToString(h.Sum(nil)), size, nil +} + 
+func normalizeAllowedRoot(root string) (string, error) { + root = utils.EnsureAbsPath(strings.TrimSpace(root)) + if root == "" { + if wd, err := filepath.Abs("."); err == nil { + root = wd + } + } + if strings.TrimSpace(root) == "" { + return "", fmt.Errorf("invalid allowed root") + } + cleanRoot, err := filepath.Abs(filepath.Clean(root)) + if err != nil { + return "", err + } + return cleanRoot, nil +} + +func resolveRestoreDestination(destPath, allowedRoot string) (string, error) { + finalDest := filepath.Clean(strings.TrimSpace(destPath)) + if finalDest == "" || finalDest == "." { + return "", fmt.Errorf("invalid destination path") + } + + cleanDest, err := filepath.Abs(finalDest) + if err != nil { + return "", err + } + cleanRoot, err := normalizeAllowedRoot(allowedRoot) + if err != nil { + return "", err + } + + rel, err := filepath.Rel(cleanRoot, cleanDest) + if err != nil { + return "", err + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("refusing to write outside allowed root") + } + + return cleanDest, nil +} + +func restoreBundleFile(reader *zip.Reader, bundlePath, destPath, allowedRoot string) error { + data, err := readZipEntry(reader, bundlePath) + if err != nil { + return err + } + finalDest, err := resolveRestoreDestination(destPath, allowedRoot) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(finalDest), 0o755); err != nil { + return err + } + return os.WriteFile(finalDest, data, 0o644) +} + +func restorePartialFile(reader *zip.Reader, bundlePath, destPath, allowedRoot string) error { + return restoreBundleFile(reader, bundlePath, destPath, allowedRoot) +} + +func restoreLogFiles(reader *zip.Reader, manifest *Manifest, replace bool) (int, error) { + logPaths := manifestLogPaths(manifest) + if len(logPaths) == 0 { + return 0, nil + } + + logDir := config.GetLogsDir() + if err := os.MkdirAll(logDir, 0o755); err != nil { + return 0, err + } + if replace { + if err := clearLogFiles(); err != nil { + return 0, err + } + } + + restored := 0 + for _, bundlePath := range logPaths { + filename := filepath.Base(filepath.FromSlash(bundlePath)) + if filename == "." 
|| filename == "" { + continue + } + + destName := filename + if !replace { + destName = processing.GetUniqueFilename(logDir, filename, nil) + if destName == "" { + return restored, fmt.Errorf("could not determine a unique log filename for %s", filename) + } + } + + if err := restoreBundleFile(reader, bundlePath, filepath.Join(logDir, destName), logDir); err != nil { + return restored, err + } + restored++ + } + + return restored, nil +} + +func manifestLogPaths(manifest *Manifest) []string { + if manifest == nil { + return nil + } + + var paths []string + for _, file := range manifest.Files { + path := filepath.ToSlash(strings.TrimSpace(file.Path)) + if strings.HasPrefix(path, "logs/") { + paths = append(paths, path) + } + } + return paths +} + +func clearExistingState() error { + entries, _ := state.ListAllDownloads() + for _, entry := range entries { + if entry.Status == "completed" || strings.TrimSpace(entry.DestPath) == "" { + continue + } + _ = os.Remove(entry.DestPath + types.IncompleteSuffix) + } + + db, err := state.GetDB() + if err != nil { + return err + } + return withTransaction(db, func(tx *sql.Tx) error { + if _, err := tx.Exec("DELETE FROM tasks"); err != nil { + return err + } + if _, err := tx.Exec("DELETE FROM downloads"); err != nil { + return err + } + return nil + }) +} + +func clearLogFiles() error { + entries, err := os.ReadDir(config.GetLogsDir()) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + if err := os.Remove(filepath.Join(config.GetLogsDir(), entry.Name())); err != nil && !os.IsNotExist(err) { + return err + } + } + return nil +} + +func withTransaction(db *sql.DB, fn func(*sql.Tx) error) error { + tx, err := db.Begin() + if err != nil { + return err + } + if err := fn(tx); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} diff --git a/internal/backup/backup_test.go b/internal/backup/backup_test.go new file mode 100644 index 00000000..2dbc2b4d --- /dev/null +++ b/internal/backup/backup_test.go @@ -0,0 +1,306 @@ +package backup + +import ( + "archive/zip" + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/SurgeDM/Surge/internal/config" + "github.com/SurgeDM/Surge/internal/engine/state" + "github.com/SurgeDM/Surge/internal/engine/types" + "github.com/SurgeDM/Surge/internal/testutil" +) + +func setupBackupEnv(t *testing.T) string { + t.Helper() + + root := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", filepath.Join(root, "config")) + t.Setenv("XDG_STATE_HOME", filepath.Join(root, "state")) + t.Setenv("APPDATA", filepath.Join(root, "config")) + + if err := config.EnsureDirs(); err != nil { + t.Fatalf("EnsureDirs failed: %v", err) + } + + state.CloseDB() + state.Configure(filepath.Join(config.GetStateDir(), "surge.db")) + if _, err := state.GetDB(); err != nil { + t.Fatalf("state.GetDB failed: %v", err) + } + t.Cleanup(state.CloseDB) + return root +} + +func seedPausedDownload(t *testing.T, downloadRoot string) (string, string) { + t.Helper() + + destPath := filepath.Join(downloadRoot, "nested", "video.bin") + if err := config.SaveSettings(&config.Settings{ + General: config.GeneralSettings{ + DefaultDownloadDir: downloadRoot, + }, + Network: config.DefaultSettings().Network, + Performance: config.DefaultSettings().Performance, + Categories: config.DefaultSettings().Categories, + Extension: config.DefaultSettings().Extension, + }); err != nil { + t.Fatalf("SaveSettings failed: %v", err) + } + + if err := 
os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + if _, err := testutil.CreateSurgeFile(filepath.Dir(destPath), "video.bin", 1024, 256); err != nil { + t.Fatalf("CreateSurgeFile failed: %v", err) + } + + saved := &types.DownloadState{ + ID: "paused-download", + URL: "https://example.com/video.bin", + DestPath: destPath, + Filename: "video.bin", + TotalSize: 1024, + Downloaded: 256, + Tasks: []types.Task{ + {Offset: 256, Length: 768}, + }, + CreatedAt: 1, + PausedAt: 2, + Elapsed: 3, + } + if err := state.SaveStateWithOptions(saved.URL, saved.DestPath, saved, state.SaveStateOptions{SkipFileHash: true}); err != nil { + t.Fatalf("SaveStateWithOptions failed: %v", err) + } + return saved.URL, destPath +} + +func TestApplyImport_WithoutPartialsDowngradesPausedDownloads(t *testing.T) { + root := setupBackupEnv(t) + sourceRoot := filepath.Join(root, "source-downloads") + url, _ := seedPausedDownload(t, sourceRoot) + + var buf bytes.Buffer + if _, err := Export(t.Context(), &buf, ExportOptions{}, nil); err != nil { + t.Fatalf("Export failed: %v", err) + } + + importRoot := filepath.Join(root, "imported") + result, err := ApplyImport(t.Context(), bytes.NewReader(buf.Bytes()), ImportOptions{ + RootDir: importRoot, + Replace: true, + }, nil) + if err != nil { + t.Fatalf("ApplyImport failed: %v", err) + } + if result.Preview.ResumableJobsDowngradedToQueue != 1 { + t.Fatalf("downgraded=%d, want 1", result.Preview.ResumableJobsDowngradedToQueue) + } + + entry, err := state.GetDownload("paused-download") + if err != nil { + t.Fatalf("GetDownload failed: %v", err) + } + if entry == nil { + t.Fatal("expected imported entry") + } + if entry.Status != "queued" { + t.Fatalf("status=%q, want queued", entry.Status) + } + if entry.Downloaded != 0 { + t.Fatalf("downloaded=%d, want 0", entry.Downloaded) + } + + wantDest := filepath.Join(importRoot, "nested", "video.bin") + if filepath.Clean(entry.DestPath) != filepath.Clean(wantDest) { + t.Fatalf("dest=%q, want %q", entry.DestPath, wantDest) + } + saved, err := state.LoadState(url, wantDest) + if err != nil { + t.Fatalf("LoadState failed: %v", err) + } + if saved == nil { + t.Fatal("expected queued metadata to remain importable") + } + if len(saved.Tasks) != 0 { + t.Fatalf("tasks=%d, want 0 without exported partial state", len(saved.Tasks)) + } +} + +func TestPreviewImport_UsesImportOptions(t *testing.T) { + root := setupBackupEnv(t) + sourceRoot := filepath.Join(root, "source-downloads") + url, _ := seedPausedDownload(t, sourceRoot) + + var buf bytes.Buffer + if _, err := Export(t.Context(), &buf, ExportOptions{}, nil); err != nil { + t.Fatalf("Export failed: %v", err) + } + + importRoot := filepath.Join(root, "imported") + if err := state.AddToMasterList(types.DownloadEntry{ + ID: "existing-download", + URL: url, + URLHash: state.URLHash(url), + DestPath: filepath.Join(importRoot, "nested", "video.bin"), + Filename: "video.bin", + Status: "queued", + }); err != nil { + t.Fatalf("AddToMasterList failed: %v", err) + } + + previewMerge, err := PreviewImport(t.Context(), bytes.NewReader(buf.Bytes()), ImportOptions{ + RootDir: importRoot, + }) + if err != nil { + t.Fatalf("PreviewImport merge failed: %v", err) + } + if filepath.Clean(previewMerge.RootDir) != filepath.Clean(importRoot) { + t.Fatalf("merge root=%q, want %q", previewMerge.RootDir, importRoot) + } + if previewMerge.DuplicatesSkipped != 1 { + t.Fatalf("merge duplicates=%d, want 1", previewMerge.DuplicatesSkipped) + } + + previewReplace, err := 
PreviewImport(t.Context(), bytes.NewReader(buf.Bytes()), ImportOptions{ + RootDir: importRoot, + Replace: true, + }) + if err != nil { + t.Fatalf("PreviewImport replace failed: %v", err) + } + if previewReplace.DuplicatesSkipped != 0 { + t.Fatalf("replace duplicates=%d, want 0", previewReplace.DuplicatesSkipped) + } + if previewReplace.ImportsByStatus["queued"] != 1 { + t.Fatalf("replace queued imports=%d, want 1", previewReplace.ImportsByStatus["queued"]) + } +} + +func TestApplyImport_WithPartialsRestoresPausedState(t *testing.T) { + root := setupBackupEnv(t) + sourceRoot := filepath.Join(root, "source-downloads") + url, _ := seedPausedDownload(t, sourceRoot) + + var buf bytes.Buffer + if _, err := Export(t.Context(), &buf, ExportOptions{IncludePartials: true}, nil); err != nil { + t.Fatalf("Export failed: %v", err) + } + + importRoot := filepath.Join(root, "imported") + _, err := ApplyImport(t.Context(), bytes.NewReader(buf.Bytes()), ImportOptions{ + RootDir: importRoot, + Replace: true, + }, nil) + if err != nil { + t.Fatalf("ApplyImport failed: %v", err) + } + + wantDest := filepath.Join(importRoot, "nested", "video.bin") + entry, err := state.GetDownload("paused-download") + if err != nil { + t.Fatalf("GetDownload failed: %v", err) + } + if entry == nil { + t.Fatal("expected imported entry") + } + if entry.Status != "paused" { + t.Fatalf("status=%q, want paused", entry.Status) + } + + saved, err := state.LoadState(url, wantDest) + if err != nil { + t.Fatalf("LoadState failed: %v", err) + } + if saved == nil { + t.Fatal("expected restored resumable state") + } + if saved.Downloaded != 256 { + t.Fatalf("downloaded=%d, want 256", saved.Downloaded) + } + if !testutil.FileExists(wantDest + types.IncompleteSuffix) { + t.Fatalf("expected restored partial file %s", wantDest+types.IncompleteSuffix) + } +} + +func TestApplyImport_WithLogsRestoresLogFiles(t *testing.T) { + setupBackupEnv(t) + + logDir := config.GetLogsDir() + originalOne := []byte("first log line\n") + originalTwo := []byte("second log line\n") + if err := os.WriteFile(filepath.Join(logDir, "session-1.log"), originalOne, 0o644); err != nil { + t.Fatalf("WriteFile session-1 failed: %v", err) + } + if err := os.WriteFile(filepath.Join(logDir, "session-2.log"), originalTwo, 0o644); err != nil { + t.Fatalf("WriteFile session-2 failed: %v", err) + } + + var buf bytes.Buffer + if _, err := Export(t.Context(), &buf, ExportOptions{IncludeLogs: true}, nil); err != nil { + t.Fatalf("Export failed: %v", err) + } + + if err := os.WriteFile(filepath.Join(logDir, "session-1.log"), []byte("stale"), 0o644); err != nil { + t.Fatalf("WriteFile stale session-1 failed: %v", err) + } + if err := os.WriteFile(filepath.Join(logDir, "stale.log"), []byte("remove me"), 0o644); err != nil { + t.Fatalf("WriteFile stale log failed: %v", err) + } + + result, err := ApplyImport(t.Context(), bytes.NewReader(buf.Bytes()), ImportOptions{Replace: true}, nil) + if err != nil { + t.Fatalf("ApplyImport failed: %v", err) + } + if result.LogsRestored != 2 { + t.Fatalf("logs restored=%d, want 2", result.LogsRestored) + } + + gotOne, err := os.ReadFile(filepath.Join(logDir, "session-1.log")) + if err != nil { + t.Fatalf("ReadFile session-1 failed: %v", err) + } + if !bytes.Equal(gotOne, originalOne) { + t.Fatalf("session-1 content=%q, want %q", string(gotOne), string(originalOne)) + } + + gotTwo, err := os.ReadFile(filepath.Join(logDir, "session-2.log")) + if err != nil { + t.Fatalf("ReadFile session-2 failed: %v", err) + } + if !bytes.Equal(gotTwo, originalTwo) { + 
t.Fatalf("session-2 content=%q, want %q", string(gotTwo), string(originalTwo)) + } + + if _, err := os.Stat(filepath.Join(logDir, "stale.log")); !os.IsNotExist(err) { + t.Fatalf("stale.log should be removed, stat err=%v", err) + } +} + +func TestRestoreBundleFile_RejectsPathOutsideAllowedRoot(t *testing.T) { + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + w, err := zw.Create("logs/test.log") + if err != nil { + t.Fatalf("Create failed: %v", err) + } + if _, err := w.Write([]byte("log")); err != nil { + t.Fatalf("Write failed: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + reader, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) + if err != nil { + t.Fatalf("NewReader failed: %v", err) + } + + allowedRoot := t.TempDir() + destPath := filepath.Join(allowedRoot, "..", "escape.log") + if err := restoreBundleFile(reader, "logs/test.log", destPath, allowedRoot); err == nil { + t.Fatal("expected restoreBundleFile to reject paths outside the allowed root") + } +} diff --git a/internal/backup/path.go b/internal/backup/path.go new file mode 100644 index 00000000..a79f368e --- /dev/null +++ b/internal/backup/path.go @@ -0,0 +1,123 @@ +package backup + +import ( + "path/filepath" + "runtime" + "strings" + + "github.com/SurgeDM/Surge/internal/utils" +) + +type pathRebaseResult struct { + Path string + Rebased bool + Externalized bool +} + +func sanitizeExternalRoot(path string) string { + cleaned := filepath.Clean(strings.TrimSpace(path)) + volume := filepath.VolumeName(cleaned) + if volume != "" { + trimmed := strings.TrimSuffix(volume, ":") + if trimmed != "" { + return trimmed + } + } + if filepath.IsAbs(cleaned) { + return "root" + } + return "relative" +} + +func trimVolumePrefix(path string) string { + volume := filepath.VolumeName(path) + if volume == "" { + return path + } + trimmed := strings.TrimPrefix(path, volume) + trimmed = strings.TrimLeft(trimmed, string(filepath.Separator)) + return trimmed +} + +func pathWithinRoot(root, candidate string) bool { + cleanRoot, err := filepath.Abs(filepath.Clean(root)) + if err != nil { + return false + } + cleanCandidate, err := filepath.Abs(filepath.Clean(candidate)) + if err != nil { + return false + } + rel, err := filepath.Rel(cleanRoot, cleanCandidate) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) +} + +func sanitizeRelativeExternalPath(path string) string { + cleaned := filepath.Clean(strings.TrimSpace(path)) + parts := strings.FieldsFunc(cleaned, func(r rune) bool { + return r == '/' || r == '\\' + }) + safeParts := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" || part == "." || part == ".." { + continue + } + safeParts = append(safeParts, part) + } + if len(safeParts) == 0 { + return "imported" + } + return filepath.Join(safeParts...) +} + +func rebaseImportedPath(originalPath, exportedRoot, targetRoot string) pathRebaseResult { + targetRoot = utils.EnsureAbsPath(strings.TrimSpace(targetRoot)) + if targetRoot == "" { + targetRoot = "." + } + + cleanOriginal := filepath.Clean(strings.TrimSpace(originalPath)) + cleanExportedRoot := filepath.Clean(strings.TrimSpace(exportedRoot)) + if cleanOriginal == "." 
|| cleanOriginal == "" { + return pathRebaseResult{Path: targetRoot} + } + + if !filepath.IsAbs(cleanOriginal) { + joined := filepath.Join(targetRoot, cleanOriginal) + if pathWithinRoot(targetRoot, joined) { + return pathRebaseResult{ + Path: joined, + Rebased: true, + } + } + return pathRebaseResult{ + Path: filepath.Join(targetRoot, "external", "relative", sanitizeRelativeExternalPath(cleanOriginal)), + Externalized: true, + } + } + + if cleanExportedRoot != "." && cleanExportedRoot != "" { + if rel, err := filepath.Rel(cleanExportedRoot, cleanOriginal); err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return pathRebaseResult{ + Path: filepath.Join(targetRoot, rel), + Rebased: true, + } + } + } + + externalRoot := sanitizeExternalRoot(cleanOriginal) + rest := trimVolumePrefix(cleanOriginal) + rest = strings.TrimLeft(rest, string(filepath.Separator)) + if runtime.GOOS != "windows" && strings.HasPrefix(rest, "/") { + rest = strings.TrimLeft(rest, "/") + } + + return pathRebaseResult{ + Path: filepath.Join(targetRoot, "external", externalRoot, rest), + Externalized: true, + } +} diff --git a/internal/backup/path_test.go b/internal/backup/path_test.go new file mode 100644 index 00000000..f9249a26 --- /dev/null +++ b/internal/backup/path_test.go @@ -0,0 +1,22 @@ +package backup + +import ( + "path/filepath" + "testing" +) + +func TestRebaseImportedPath_RelativeTraversalIsExternalized(t *testing.T) { + targetRoot := filepath.Join(t.TempDir(), "import-root") + got := rebaseImportedPath(filepath.Join("..", "..", "etc", "passwd"), "", targetRoot) + + if !got.Externalized { + t.Fatal("expected relative traversal path to be externalized") + } + if pathWithinRoot(targetRoot, got.Path) { + if filepath.Clean(got.Path) == filepath.Clean(filepath.Join(targetRoot, "..", "..", "etc", "passwd")) { + t.Fatal("expected escaped relative path to be rewritten, not preserved") + } + } else { + t.Fatalf("externalized path escaped target root: %q", got.Path) + } +} diff --git a/internal/backup/types.go b/internal/backup/types.go new file mode 100644 index 00000000..ba4a9cc5 --- /dev/null +++ b/internal/backup/types.go @@ -0,0 +1,112 @@ +package backup + +import ( + "time" + + "github.com/SurgeDM/Surge/internal/config" + "github.com/SurgeDM/Surge/internal/engine/types" +) + +const ( + SchemaVersion = 1 + BundleExtension = ".surge-export" +) + +// Controller captures the lifecycle operations backup needs to produce a stable snapshot. 
+type Controller interface { + List() ([]types.DownloadStatus, error) + Pause(id string) error + Resume(id string) error +} + +type ExportOptions struct { + IncludeLogs bool `json:"include_logs"` + IncludePartials bool `json:"include_partials"` + LeavePaused bool `json:"leave_paused"` + AppVersion string `json:"app_version,omitempty"` +} + +type ImportOptions struct { + RootDir string `json:"root_dir,omitempty"` + Replace bool `json:"replace"` + SessionID string `json:"session_id,omitempty"` +} + +type Manifest struct { + SchemaVersion int `json:"schema_version"` + CreatedAt time.Time `json:"created_at"` + SurgeVersion string `json:"surge_version"` + OriginalDefaultDownloadDir string `json:"original_default_download_dir"` + IncludeLogs bool `json:"include_logs"` + IncludePartials bool `json:"include_partials"` + Counts map[string]int `json:"counts"` + Files []ManifestFile `json:"files,omitempty"` +} + +type ManifestFile struct { + Path string `json:"path"` + SHA256 string `json:"sha256"` + Size int64 `json:"size"` +} + +type PortableDownload struct { + ID string `json:"id"` + URL string `json:"url"` + DestPath string `json:"dest_path"` + Filename string `json:"filename"` + Status string `json:"status"` + OriginalStatus string `json:"original_status,omitempty"` + TotalSize int64 `json:"total_size"` + Downloaded int64 `json:"downloaded"` + CompletedAt int64 `json:"completed_at,omitempty"` + TimeTaken int64 `json:"time_taken,omitempty"` + AvgSpeed float64 `json:"avg_speed,omitempty"` + Mirrors []string `json:"mirrors,omitempty"` + Resumable *PortableResumeState `json:"resumable,omitempty"` +} + +type PortableResumeState struct { + URLHash string `json:"url_hash"` + CreatedAt int64 `json:"created_at"` + PausedAt int64 `json:"paused_at"` + Elapsed int64 `json:"elapsed"` + Tasks []types.Task `json:"tasks,omitempty"` + ChunkBitmap []byte `json:"chunk_bitmap,omitempty"` + ActualChunkSize int64 `json:"actual_chunk_size,omitempty"` + FileHash string `json:"file_hash,omitempty"` + PartialFile string `json:"partial_file,omitempty"` +} + +type bundlePayload struct { + Settings *config.Settings `json:"settings"` + Downloads []PortableDownload `json:"downloads"` +} + +type ImportPreview struct { + SessionID string `json:"session_id,omitempty"` + Manifest *Manifest `json:"manifest"` + RootDir string `json:"root_dir"` + ImportsByStatus map[string]int `json:"imports_by_status"` + DuplicatesSkipped int `json:"duplicates_skipped"` + RenamedItems int `json:"renamed_items"` + ResumableJobsDowngradedToQueue int `json:"resumable_jobs_downgraded_to_queue"` + RebasedPaths int `json:"rebased_paths"` + ExternalizedPaths int `json:"externalized_paths"` + Conflicts []ImportConflict `json:"conflicts,omitempty"` +} + +type ImportConflict struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + Path string `json:"path,omitempty"` + OldPath string `json:"old_path,omitempty"` + NewPath string `json:"new_path,omitempty"` + Message string `json:"message,omitempty"` +} + +type ImportResult struct { + Preview *ImportPreview `json:"preview"` + Imported int `json:"imported"` + SettingsSaved bool `json:"settings_saved"` + LogsRestored int `json:"logs_restored"` +} diff --git a/internal/core/transfer.go b/internal/core/transfer.go new file mode 100644 index 00000000..1a6f7fe0 --- /dev/null +++ b/internal/core/transfer.go @@ -0,0 +1,173 @@ +package core + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/SurgeDM/Surge/internal/backup" +) + +// 
TransferService defines import/export operations for local and remote use. +type TransferService interface { + Export(ctx context.Context, opts backup.ExportOptions, dst io.Writer) (*backup.Manifest, error) + PreviewImport(ctx context.Context, src io.Reader, opts backup.ImportOptions) (*backup.ImportPreview, error) + ApplyImport(ctx context.Context, src io.Reader, opts backup.ImportOptions) (*backup.ImportResult, error) +} + +type LocalTransferService struct { + Controller backup.Controller + Version string +} + +func NewLocalTransferService(controller backup.Controller, version string) *LocalTransferService { + return &LocalTransferService{ + Controller: controller, + Version: strings.TrimSpace(version), + } +} + +func (s *LocalTransferService) Export(ctx context.Context, opts backup.ExportOptions, dst io.Writer) (*backup.Manifest, error) { + opts.AppVersion = s.Version + return backup.Export(ctx, dst, opts, s.Controller) +} + +func (s *LocalTransferService) PreviewImport(ctx context.Context, src io.Reader, opts backup.ImportOptions) (*backup.ImportPreview, error) { + return backup.PreviewImport(ctx, src, opts) +} + +func (s *LocalTransferService) ApplyImport(ctx context.Context, src io.Reader, opts backup.ImportOptions) (*backup.ImportResult, error) { + return backup.ApplyImport(ctx, src, opts, s.Controller) +} + +type RemoteTransferService struct { + BaseURL string + Token string + Client *http.Client +} + +func NewRemoteTransferService(baseURL, token string) *RemoteTransferService { + return &RemoteTransferService{ + BaseURL: strings.TrimRight(baseURL, "/"), + Token: token, + Client: &http.Client{Timeout: 5 * time.Minute}, + } +} + +func (s *RemoteTransferService) Export(ctx context.Context, opts backup.ExportOptions, dst io.Writer) (*backup.Manifest, error) { + body, err := json.Marshal(opts) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.BaseURL+"/data/export", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + if s.Token != "" { + req.Header.Set("Authorization", "Bearer "+s.Token) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := s.Client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, fmt.Errorf("remote export failed: %s - %s", resp.Status, string(b)) + } + manifestHeader := strings.TrimSpace(resp.Header.Get("X-Surge-Manifest")) + var manifest backup.Manifest + if manifestHeader != "" { + if decoded, err := url.QueryUnescape(manifestHeader); err == nil { + _ = json.Unmarshal([]byte(decoded), &manifest) + } + } + if _, err := io.Copy(dst, resp.Body); err != nil { + return nil, err + } + if manifest.SchemaVersion == 0 { + return &backup.Manifest{}, nil + } + return &manifest, nil +} + +func (s *RemoteTransferService) PreviewImport(ctx context.Context, src io.Reader, opts backup.ImportOptions) (*backup.ImportPreview, error) { + endpoint := s.BaseURL + "/data/import/preview" + query := url.Values{} + if strings.TrimSpace(opts.RootDir) != "" { + query.Set("root_dir", opts.RootDir) + } + if opts.Replace { + query.Set("replace", "true") + } + if encoded := query.Encode(); encoded != "" { + endpoint += "?" 
+ encoded + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, src) + if err != nil { + return nil, err + } + if s.Token != "" { + req.Header.Set("Authorization", "Bearer "+s.Token) + } + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := s.Client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, fmt.Errorf("remote preview failed: %s - %s", resp.Status, string(b)) + } + + var preview backup.ImportPreview + if err := json.NewDecoder(resp.Body).Decode(&preview); err != nil { + return nil, err + } + return &preview, nil +} + +func (s *RemoteTransferService) ApplyImport(ctx context.Context, src io.Reader, opts backup.ImportOptions) (*backup.ImportResult, error) { + payload := map[string]interface{}{ + "session_id": opts.SessionID, + "root_dir": opts.RootDir, + "replace": opts.Replace, + } + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.BaseURL+"/data/import/apply", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + if s.Token != "" { + req.Header.Set("Authorization", "Bearer "+s.Token) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := s.Client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, fmt.Errorf("remote apply failed: %s - %s", resp.Status, string(b)) + } + var result backup.ImportResult + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + _ = src + return &result, nil +} diff --git a/internal/processing/file_utils.go b/internal/processing/file_utils.go index 3fa4fe68..55158a10 100644 --- a/internal/processing/file_utils.go +++ b/internal/processing/file_utils.go @@ -14,6 +14,39 @@ import ( "github.com/SurgeDM/Surge/internal/utils" ) +func safeReservationPath(dir, name string) (string, error) { + dir = utils.EnsureAbsPath(strings.TrimSpace(dir)) + if dir == "" { + return "", fmt.Errorf("invalid reservation directory") + } + + name = strings.TrimSpace(name) + if name == "" { + return "", fmt.Errorf("invalid reservation name") + } + baseName := filepath.Base(name) + if baseName != name || strings.Contains(name, "/") || strings.Contains(name, "\\") || name == "." || name == ".." { + return "", fmt.Errorf("invalid reservation name") + } + + cleanDir, err := filepath.Abs(filepath.Clean(dir)) + if err != nil { + return "", err + } + targetPath, err := filepath.Abs(filepath.Join(cleanDir, name)) + if err != nil { + return "", err + } + rel, err := filepath.Rel(cleanDir, targetPath) + if err != nil { + return "", err + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("reservation path escapes directory") + } + return targetPath, nil +} + // InferFilenameFromURL is the final naming fallback when neither the user nor // the probe produced a trustworthy filename. func InferFilenameFromURL(rawURL string) string { @@ -79,22 +112,34 @@ func GetUniqueFilename(dir, filename string, isNameActive func(string, string) b if strings.Contains(filename, "/") || strings.Contains(filename, "\\") || filename == "." || filename == ".." 
{ return "" } + cleanDir, err := safeReservationPath(dir, filename) + if err != nil { + return "" + } + cleanDir = filepath.Dir(cleanDir) existsOnDisk := func(name string) bool { - targetPath := filepath.Join(dir, name) + targetPath, err := safeReservationPath(cleanDir, name) + if err != nil { + return true + } if _, err := os.Stat(targetPath); !os.IsNotExist(err) { return true } // A .surge sibling means another active or recoverable download already // claimed this filename, so we must not hand it out again. - if _, err := os.Stat(targetPath + types.IncompleteSuffix); !os.IsNotExist(err) { + incompletePath, err := safeReservationPath(cleanDir, name+types.IncompleteSuffix) + if err != nil { + return true + } + if _, err := os.Stat(incompletePath); !os.IsNotExist(err) { return true } return false } existsAnywhere := func(name string) bool { - if isNameActive != nil && isNameActive(dir, name) { + if isNameActive != nil && isNameActive(cleanDir, name) { return true } return existsOnDisk(name) diff --git a/internal/processing/file_utils_test.go b/internal/processing/file_utils_test.go index 4b504d3a..faae7408 100644 --- a/internal/processing/file_utils_test.go +++ b/internal/processing/file_utils_test.go @@ -94,6 +94,7 @@ func TestGetUniqueFilename(t *testing.T) { if name := processing.GetUniqueFilename(tmpDir, "overflow.bin", overflowActive); name != "" { t.Errorf("Expected empty result after exhaustion, got %s", name) } + } func TestGetCategoryPath(t *testing.T) { diff --git a/internal/tui/keys.go b/internal/tui/keys.go index a67f47c1..09245193 100644 --- a/internal/tui/keys.go +++ b/internal/tui/keys.go @@ -15,6 +15,7 @@ type KeyMap struct { Update UpdateKeyMap CategoryMgr CategoryManagerKeyMap QuitConfirm QuitConfirmKeyMap + Transfer TransferKeyMap } // DashboardKeyMap defines keybindings for the main dashboard @@ -30,6 +31,7 @@ type DashboardKeyMap struct { Refresh key.Binding Delete key.Binding Settings key.Binding + DataTransfer key.Binding Log key.Binding ToggleHelp key.Binding OpenFile key.Binding @@ -129,6 +131,18 @@ type QuitConfirmKeyMap struct { Cancel key.Binding } +// TransferKeyMap defines keybindings for data import/export. 
+type TransferKeyMap struct { + Export key.Binding + Import key.Binding + TogglePartials key.Binding + ToggleLogs key.Binding + ToggleReplace key.Binding + BrowseRoot key.Binding + Apply key.Binding + Close key.Binding +} + // CategoryManagerKeyMap defines keybindings for the category manager type CategoryManagerKeyMap struct { Up key.Binding @@ -188,6 +202,10 @@ var Keys = KeyMap{ key.WithKeys("s"), key.WithHelp("s", "settings"), ), + DataTransfer: key.NewBinding( + key.WithKeys("d"), + key.WithHelp("d", "data"), + ), Log: key.NewBinding( key.WithKeys("l"), key.WithHelp("l", "toggle log"), @@ -445,6 +463,40 @@ var Keys = KeyMap{ key.WithHelp("n/esc", "cancel"), ), }, + Transfer: TransferKeyMap{ + Export: key.NewBinding( + key.WithKeys("e"), + key.WithHelp("e", "export"), + ), + Import: key.NewBinding( + key.WithKeys("i"), + key.WithHelp("i", "import"), + ), + TogglePartials: key.NewBinding( + key.WithKeys("p"), + key.WithHelp("p", "partials"), + ), + ToggleLogs: key.NewBinding( + key.WithKeys("l"), + key.WithHelp("l", "logs"), + ), + ToggleReplace: key.NewBinding( + key.WithKeys("x"), + key.WithHelp("x", "replace"), + ), + BrowseRoot: key.NewBinding( + key.WithKeys("r"), + key.WithHelp("r", "root"), + ), + Apply: key.NewBinding( + key.WithKeys("a"), + key.WithHelp("a", "apply"), + ), + Close: key.NewBinding( + key.WithKeys("esc"), + key.WithHelp("esc", "close"), + ), + }, } // ShortHelp returns keybindings to show in the mini help view diff --git a/internal/tui/messages.go b/internal/tui/messages.go index 4ddf708d..fa61728b 100644 --- a/internal/tui/messages.go +++ b/internal/tui/messages.go @@ -1,6 +1,7 @@ package tui import ( + "github.com/SurgeDM/Surge/internal/backup" "github.com/SurgeDM/Surge/internal/version" ) @@ -33,3 +34,18 @@ type resumeResultMsg struct { id string err error } + +type transferExportResultMsg struct { + path string + err error +} + +type transferPreviewResultMsg struct { + preview *backup.ImportPreview + err error +} + +type transferApplyResultMsg struct { + result *backup.ImportResult + err error +} diff --git a/internal/tui/model.go b/internal/tui/model.go index 443728cc..14e02582 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -17,6 +17,7 @@ import ( tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" + "github.com/SurgeDM/Surge/internal/backup" "github.com/SurgeDM/Surge/internal/config" "github.com/SurgeDM/Surge/internal/core" "github.com/SurgeDM/Surge/internal/engine/types" @@ -43,6 +44,10 @@ const ( CategoryManagerState // CategoryManagerState is 13 QuitConfirmState // QuitConfirmState is 14 HelpModalState // HelpModalState is 15 + DataTransferState // DataTransferState is 16 + TransferExportPickerState // TransferExportPickerState is 17 + TransferImportPickerState // TransferImportPickerState is 18 + TransferRootPickerState // TransferRootPickerState is 19 ) const ( @@ -90,6 +95,7 @@ type RootModel struct { // Service Interface // Core Service core.DownloadService + Transfer core.TransferService Orchestrator *processing.LifecycleManager // File picker for directory selection @@ -147,6 +153,15 @@ type RootModel struct { // URL Refresh urlUpdateInput textinput.Model // Text input for updating URL + // Data transfer + transferIncludeLogs bool + transferIncludePartials bool + transferReplace bool + transferPreview *backup.ImportPreview + transferImportFile string + transferRootDir string + transferStatus string + // Category manager categoryFilter string // Dashboard filter ("" = all) catMgrCursor int // Selected category index @@ -382,6 
+397,7 @@ func InitialRootModel(serverPort int, currentVersion string, service core.Downlo SettingsInput: settingsInput, searchInput: searchInput, urlUpdateInput: urlUpdateInput, + transferRootDir: settings.General.DefaultDownloadDir, catMgrInputs: [4]textinput.Model{catNameInput, catDescInput, catPatternInput, catPathInput}, keys: Keys, ServerPort: serverPort, @@ -393,6 +409,9 @@ func InitialRootModel(serverPort int, currentVersion string, service core.Downlo } InitAuthToken() // Cache auth token for TUI to avoid per-frame disk I/O + if strings.TrimSpace(m.transferRootDir) == "" { + m.transferRootDir = "." + } m.refreshThemeCaches() diff --git a/internal/tui/update.go b/internal/tui/update.go index a746c712..fad59392 100644 --- a/internal/tui/update.go +++ b/internal/tui/update.go @@ -1,6 +1,8 @@ package tui import ( + "fmt" + "github.com/SurgeDM/Surge/internal/config" "github.com/SurgeDM/Surge/internal/utils" @@ -124,6 +126,36 @@ func (m RootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return m, tea.Quit + case transferExportResultMsg: + if msg.err != nil { + m.transferStatus = "Export failed: " + msg.err.Error() + } else { + m.transferStatus = "Exported: " + msg.path + } + m.state = DataTransferState + return m, nil + + case transferPreviewResultMsg: + if msg.err != nil { + m.transferStatus = "Import preview failed: " + msg.err.Error() + } else { + m.transferPreview = msg.preview + m.transferStatus = "Preview loaded" + } + m.state = DataTransferState + return m, nil + + case transferApplyResultMsg: + if msg.err != nil { + m.transferStatus = "Import apply failed: " + msg.err.Error() + } else { + m.transferPreview = msg.result.Preview + m.transferStatus = fmt.Sprintf("Imported items: %d", msg.result.Imported) + m.UpdateListItems() + } + m.state = DataTransferState + return m, nil + case tea.PasteMsg: return m.updatePaste(msg) @@ -158,6 +190,17 @@ func (m RootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, cmd) return m, tea.Batch(cmds...) } + + if m.state == TransferExportPickerState || m.state == TransferImportPickerState || m.state == TransferRootPickerState { + var cmd tea.Cmd + m.filepicker, cmd = m.filepicker.Update(msg) + if didSelect, path := m.filepicker.DidSelectFile(msg); didSelect { + model, selCmd := m.handleTransferFileSelection(path) + return model, tea.Batch(append(cmds, selCmd)...) + } + cmds = append(cmds, cmd) + return m, tea.Batch(cmds...) + } model, cmd := m.updateEvents(msg) cmds = append(cmds, cmd) return model, tea.Batch(cmds...) @@ -214,6 +257,12 @@ func (m RootModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return m, nil + case DataTransferState: + return m.updateTransfer(msg) + + case TransferExportPickerState, TransferImportPickerState, TransferRootPickerState: + return m.updateTransferPicker(msg) + default: return m, nil } diff --git a/internal/tui/update_dashboard.go b/internal/tui/update_dashboard.go index ce62e14d..7124ca64 100644 --- a/internal/tui/update_dashboard.go +++ b/internal/tui/update_dashboard.go @@ -216,6 +216,18 @@ func (m RootModel) updateDashboard(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { return m, nil } + if key.Matches(msg, m.keys.Dashboard.DataTransfer) { + m.state = DataTransferState + m.transferStatus = "" + if m.transferRootDir == "" { + m.transferRootDir = m.Settings.General.DefaultDownloadDir + if m.transferRootDir == "" { + m.transferRootDir = "." 
+ } + } + return m, nil + } + if key.Matches(msg, m.keys.Dashboard.CategoryFilter) { if !m.Settings.Categories.CategoryEnabled || len(m.Settings.Categories.Categories) == 0 { if m.categoryFilter != "" { diff --git a/internal/tui/update_transfer.go b/internal/tui/update_transfer.go new file mode 100644 index 00000000..18810ae6 --- /dev/null +++ b/internal/tui/update_transfer.go @@ -0,0 +1,180 @@ +package tui + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "charm.land/bubbles/v2/key" + tea "charm.land/bubbletea/v2" + + "github.com/SurgeDM/Surge/internal/backup" +) + +func (m RootModel) exportBundleCmd(dir string) tea.Cmd { + return func() tea.Msg { + if m.Transfer == nil { + return transferExportResultMsg{err: fmt.Errorf("transfer service unavailable")} + } + name := fmt.Sprintf("surge-export-%s%s", time.Now().Format("20060102-150405"), backup.BundleExtension) + target := filepath.Join(dir, name) + file, err := os.Create(target) + if err != nil { + return transferExportResultMsg{err: err} + } + defer func() { _ = file.Close() }() + _, err = m.Transfer.Export(context.Background(), backup.ExportOptions{ + IncludeLogs: m.transferIncludeLogs, + IncludePartials: m.transferIncludePartials, + }, file) + return transferExportResultMsg{path: target, err: err} + } +} + +func (m RootModel) previewImportCmd(path string) tea.Cmd { + return func() tea.Msg { + if m.Transfer == nil { + return transferPreviewResultMsg{err: fmt.Errorf("transfer service unavailable")} + } + file, err := os.Open(filepath.Clean(path)) + if err != nil { + return transferPreviewResultMsg{err: err} + } + defer func() { _ = file.Close() }() + preview, err := m.Transfer.PreviewImport(context.Background(), file, backup.ImportOptions{ + RootDir: m.transferRootDir, + Replace: m.transferReplace, + }) + return transferPreviewResultMsg{preview: preview, err: err} + } +} + +func (m RootModel) applyImportCmd() tea.Cmd { + return func() tea.Msg { + if m.Transfer == nil { + return transferApplyResultMsg{err: fmt.Errorf("transfer service unavailable")} + } + + var src *os.File + var err error + opts := backup.ImportOptions{ + RootDir: m.transferRootDir, + Replace: m.transferReplace, + } + if m.transferPreview != nil { + opts.SessionID = m.transferPreview.SessionID + } + if opts.SessionID == "" { + src, err = os.Open(filepath.Clean(m.transferImportFile)) + if err != nil { + return transferApplyResultMsg{err: err} + } + defer func() { _ = src.Close() }() + } + result, err := m.Transfer.ApplyImport(context.Background(), src, opts) + return transferApplyResultMsg{result: result, err: err} + } +} + +func (m *RootModel) handleTransferFileSelection(path string) (tea.Model, tea.Cmd) { + switch m.state { + case TransferExportPickerState: + m.state = DataTransferState + m.transferStatus = "Exporting..." + return m, m.exportBundleCmd(path) + case TransferImportPickerState: + m.transferImportFile = path + m.state = DataTransferState + m.transferStatus = "Loading preview..." + return m, m.previewImportCmd(path) + case TransferRootPickerState: + m.transferRootDir = path + m.transferPreview = nil + m.transferStatus = "Import root changed. Reload preview." 
+ m.state = DataTransferState + return m, nil + default: + m.state = DataTransferState + return m, nil + } +} + +func (m RootModel) updateTransferPicker(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { + if key.Matches(msg, m.keys.FilePicker.Cancel) { + m.state = DataTransferState + return m, nil + } + + if key.Matches(msg, m.keys.FilePicker.GotoHome) { + cmd := m.handleFilePickerGotoHome() + if m.state == TransferImportPickerState { + m.filepicker.FileAllowed = true + m.filepicker.DirAllowed = false + } + return m, cmd + } + + var cmd tea.Cmd + m.filepicker, cmd = m.filepicker.Update(msg) + if didSelect, path := m.filepicker.DidSelectFile(msg); didSelect { + return m.handleTransferFileSelection(path) + } + return m, cmd +} + +func (m RootModel) updateTransfer(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { + if key.Matches(msg, m.keys.Transfer.Close) { + m.state = DashboardState + return m, nil + } + if key.Matches(msg, m.keys.Transfer.TogglePartials) { + m.transferIncludePartials = !m.transferIncludePartials + return m, nil + } + if key.Matches(msg, m.keys.Transfer.ToggleLogs) { + m.transferIncludeLogs = !m.transferIncludeLogs + return m, nil + } + if key.Matches(msg, m.keys.Transfer.ToggleReplace) { + m.transferReplace = !m.transferReplace + m.transferPreview = nil + m.transferStatus = "Import mode changed. Reload preview." + return m, nil + } + if key.Matches(msg, m.keys.Transfer.BrowseRoot) { + dir := m.transferRootDir + if dir == "" { + dir = m.PWD + } + m.filepicker = newFilepicker(dir) + m.filepicker.DirAllowed = true + m.filepicker.FileAllowed = false + m.state = TransferRootPickerState + return m, m.filepicker.Init() + } + if key.Matches(msg, m.keys.Transfer.Export) { + m.filepicker = newFilepicker(m.PWD) + m.filepicker.DirAllowed = true + m.filepicker.FileAllowed = false + m.state = TransferExportPickerState + return m, m.filepicker.Init() + } + if key.Matches(msg, m.keys.Transfer.Import) { + m.filepicker = newFilepicker(m.PWD) + m.filepicker.DirAllowed = false + m.filepicker.FileAllowed = true + m.state = TransferImportPickerState + return m, m.filepicker.Init() + } + if key.Matches(msg, m.keys.Transfer.Apply) { + if m.transferPreview == nil { + m.transferStatus = "Load an import preview first" + return m, nil + } + m.transferStatus = "Applying import..." 
+ return m, m.applyImportCmd() + } + return m, nil +} diff --git a/internal/tui/view.go b/internal/tui/view.go index 0d95601a..e041bdb5 100644 --- a/internal/tui/view.go +++ b/internal/tui/view.go @@ -233,6 +233,46 @@ func (m RootModel) View() tea.View { return m.wrapView(m.renderModalWithOverlay(box)) } + if m.state == DataTransferState { + return m.wrapView(m.renderModalWithOverlay(m.viewTransfer())) + } + + if m.state == TransferExportPickerState { + picker := components.NewFilePickerModal( + " Export Destination ", + m.filepicker, + m.help, + m.keys.FilePicker, + colors.NeonCyan, + ) + box := picker.RenderWithBtopBox(renderBtopBox, PaneTitleStyle) + return m.wrapView(m.renderModalWithOverlay(box)) + } + + if m.state == TransferImportPickerState { + picker := components.NewFilePickerModal( + " Select Bundle ", + m.filepicker, + m.help, + m.keys.FilePicker, + colors.NeonCyan, + ) + box := picker.RenderWithBtopBox(renderBtopBox, PaneTitleStyle) + return m.wrapView(m.renderModalWithOverlay(box)) + } + + if m.state == TransferRootPickerState { + picker := components.NewFilePickerModal( + " Import Root ", + m.filepicker, + m.help, + m.keys.FilePicker, + colors.NeonCyan, + ) + box := picker.RenderWithBtopBox(renderBtopBox, PaneTitleStyle) + return m.wrapView(m.renderModalWithOverlay(box)) + } + if m.state == HelpModalState { modalW := PopupWidth if m.width < modalW { diff --git a/internal/tui/view_transfer.go b/internal/tui/view_transfer.go new file mode 100644 index 00000000..8b2c07a9 --- /dev/null +++ b/internal/tui/view_transfer.go @@ -0,0 +1,64 @@ +package tui + +import ( + "fmt" + "strings" + + "charm.land/lipgloss/v2" + + "github.com/SurgeDM/Surge/internal/tui/colors" +) + +func boolLabel(enabled bool) string { + if enabled { + return "on" + } + return "off" +} + +func (m RootModel) viewTransfer() string { + width := 76 + if m.width > 0 && m.width < width { + width = m.width + } + + lines := []string{ + lipgloss.NewStyle().Foreground(colors.NeonCyan).Bold(true).Render("Data Transfer"), + "", + "[e] Export bundle", + "[i] Import bundle", + fmt.Sprintf("[p] Include partials: %s", boolLabel(m.transferIncludePartials)), + fmt.Sprintf("[l] Include logs: %s", boolLabel(m.transferIncludeLogs)), + fmt.Sprintf("[x] Replace on apply: %s", boolLabel(m.transferReplace)), + fmt.Sprintf("[r] Import root: %s", m.transferRootDir), + } + + if m.transferImportFile != "" { + lines = append(lines, fmt.Sprintf("Bundle: %s", truncateString(m.transferImportFile, 58))) + } + if m.transferPreview != nil { + lines = append(lines, + "", + lipgloss.NewStyle().Foreground(colors.NeonPink).Render("Preview"), + fmt.Sprintf("Imports: %v", m.transferPreview.ImportsByStatus), + fmt.Sprintf("Duplicates skipped: %d", m.transferPreview.DuplicatesSkipped), + fmt.Sprintf("Renamed items: %d", m.transferPreview.RenamedItems), + fmt.Sprintf("Downgraded to queue: %d", m.transferPreview.ResumableJobsDowngradedToQueue), + "[a] Apply import", + ) + } + if strings.TrimSpace(m.transferStatus) != "" { + lines = append(lines, "", lipgloss.NewStyle().Foreground(colors.LightGray).Render(m.transferStatus)) + } + + body := strings.Join(lines, "\n") + box := renderBtopBox( + PaneTitleStyle.Render(" Data Transfer "), + "", + lipgloss.NewStyle().Padding(1, 2).Render(body), + width, + 20, + colors.NeonCyan, + ) + return box +}
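
A minimal usage sketch of the new core.TransferService driven directly, outside the TUI and CLI wiring in this diff, assuming a backup.Controller implementation is available. The stubController type, the "dev" version string, and the output filename are hypothetical illustrations only.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/SurgeDM/Surge/internal/backup"
	"github.com/SurgeDM/Surge/internal/core"
	"github.com/SurgeDM/Surge/internal/engine/types"
)

// stubController is a hypothetical backup.Controller used only for this
// sketch; in Surge the real controller owns pause/resume around the export.
type stubController struct{}

func (stubController) List() ([]types.DownloadStatus, error) { return nil, nil }
func (stubController) Pause(id string) error                 { return nil }
func (stubController) Resume(id string) error                { return nil }

func main() {
	svc := core.NewLocalTransferService(stubController{}, "dev")

	// Write the bundle to disk using the extension declared in internal/backup.
	out, err := os.Create("surge-backup" + backup.BundleExtension)
	if err != nil {
		panic(err)
	}
	defer out.Close()

	manifest, err := svc.Export(context.Background(), backup.ExportOptions{
		IncludePartials: true, // keep paused .surge partials resumable after import
	}, out)
	if err != nil {
		panic(err)
	}
	fmt.Printf("exported schema v%d, counts: %v\n", manifest.SchemaVersion, manifest.Counts)
}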
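
For the remote path, the same interface is backed by the /data/export, /data/import/preview and /data/import/apply endpoints. A hedged sketch of the two-step remote import flow follows: preview first, then apply with the session ID the preview returns. The host, token, bundle path and root directory are placeholders.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/SurgeDM/Surge/internal/backup"
	"github.com/SurgeDM/Surge/internal/core"
)

func main() {
	// Placeholder base URL and token; in Surge these come from the connect flow.
	svc := core.NewRemoteTransferService("https://example.com:1700", "token")

	f, err := os.Open("surge-backup" + backup.BundleExtension)
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// Step 1: stream the bundle to /data/import/preview and inspect the plan.
	preview, err := svc.PreviewImport(context.Background(), f, backup.ImportOptions{
		RootDir: "/downloads", // root the imported destination paths are rebased under
	})
	if err != nil {
		panic(err)
	}
	fmt.Printf("imports by status: %v, duplicates skipped: %d\n",
		preview.ImportsByStatus, preview.DuplicatesSkipped)

	// Step 2: /data/import/apply takes JSON options only; as written, the remote
	// service relies on the session created by the preview rather than
	// re-uploading the bundle, so no reader is passed here.
	result, err := svc.ApplyImport(context.Background(), nil, backup.ImportOptions{
		SessionID: preview.SessionID,
		RootDir:   "/downloads",
	})
	if err != nil {
		panic(err)
	}
	fmt.Printf("imported %d item(s)\n", result.Imported)
}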
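
The rebasing rules implemented by rebaseImportedPath are easiest to see with concrete inputs. A small in-package sketch, in the spirit of path_test.go, follows; the example paths and the expected destinations in the comments are illustrative assumptions (shown for Unix-style separators), not asserted behaviour.

package backup

import (
	"fmt"
	"path/filepath"
)

// rebaseSketch is a hypothetical helper that only demonstrates the two main
// outcomes: paths under the exported root are rebased 1:1, everything else is
// parked under the target's external/ area so nothing escapes the import root.
func rebaseSketch() {
	target := filepath.Join(string(filepath.Separator), "imports")

	// Absolute path under the exported root: expected Rebased=true,
	// with Path resembling /imports/movies/a.mkv on Unix-style paths.
	inRoot := rebaseImportedPath(
		filepath.Join(string(filepath.Separator), "old", "downloads", "movies", "a.mkv"),
		filepath.Join(string(filepath.Separator), "old", "downloads"),
		target,
	)
	fmt.Println(inRoot.Rebased, inRoot.Path)

	// Absolute path outside the exported root: expected Externalized=true,
	// rewritten under target/external/... instead of being preserved verbatim.
	outside := rebaseImportedPath(filepath.Join(string(filepath.Separator), "etc", "passwd"), "", target)
	fmt.Println(outside.Externalized, outside.Path)
}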