From 5b42a49b7bc5494337a546cb13d980d7d4e0406b Mon Sep 17 00:00:00 2001 From: Youngmin Koo Date: Mon, 1 Sep 2025 18:15:05 +0900 Subject: [PATCH] fix restore command to support documented URL patterns Signed-off-by: Youngmin Koo --- cmd/restore.go | 21 +++++++++++++++------ pkg/storage/s3/s3.go | 12 ++++++++++-- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/cmd/restore.go b/cmd/restore.go index 127247f..c4212ac 100644 --- a/cmd/restore.go +++ b/cmd/restore.go @@ -28,7 +28,7 @@ func restoreCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command, PreRun: func(cmd *cobra.Command, args []string) { bindFlags(cmd, v) }, - Args: cobra.MinimumNArgs(1), + Args: cobra.RangeArgs(0, 1), RunE: func(cmd *cobra.Command, args []string) error { cmdConfig.logger.Debug("starting restore") ctx := context.Background() @@ -40,8 +40,20 @@ func restoreCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command, }() ctx = util.ContextWithTracer(ctx, tracer) _, startupSpan := tracer.Start(ctx, "startup") - targetFile := args[0] - target := v.GetString("target") + + // Get target from args[0], --target flag, or DB_RESTORE_TARGET environment variable + var target string + if len(args) > 0 { + target = args[0] + } else { + target = v.GetString("target") + } + if target == "" { + return fmt.Errorf("target must be specified as argument, --target flag, or DB_RESTORE_TARGET environment variable") + } + + // Always pass empty targetFile to use the full path from the URL + targetFile := "" // get databases namesand mappings databasesMap := make(map[string]string) databases := strings.TrimSpace(v.GetString("database")) @@ -144,9 +156,6 @@ func restoreCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command, flags := cmd.Flags() flags.String("target", "", "full URL target to the backup that you wish to restore") - if err := cmd.MarkFlagRequired("target"); err != nil { - return nil, err - } // compression flags.String("compression", defaultCompression, "Compression to use. Supported are: `gzip`, `bzip2`, `none`") diff --git a/pkg/storage/s3/s3.go b/pkg/storage/s3/s3.go index ab1c236..8738260 100644 --- a/pkg/storage/s3/s3.go +++ b/pkg/storage/s3/s3.go @@ -73,7 +73,15 @@ func (s *S3) Pull(ctx context.Context, source, target string, logger *log.Entry) return 0, fmt.Errorf("failed to get AWS client: %v", err) } - bucket, path := s.url.Hostname(), path.Join(s.url.Path, source) + bucket := s.url.Hostname() + // If source is empty, use the path from URL directly (for restore command) + // Otherwise, append source to the URL path (for dump command) + var objectPath string + if source == "" { + objectPath = strings.TrimPrefix(s.url.Path, "/") + } else { + objectPath = strings.TrimPrefix(path.Join(s.url.Path, source), "/") + } // Create a downloader with the session and default options downloader := manager.NewDownloader(client) @@ -88,7 +96,7 @@ func (s *S3) Pull(ctx context.Context, source, target string, logger *log.Entry) // Write the contents of S3 Object to the file n, err := downloader.Download(context.TODO(), f, &s3.GetObjectInput{ Bucket: aws.String(bucket), - Key: aws.String(path), + Key: aws.String(objectPath), }) if err != nil { return 0, fmt.Errorf("failed to download file, %v", err)