diff --git a/pkg/storage/sos/object.go b/pkg/storage/sos/object.go index 0e1d799d..7c77aaca 100644 --- a/pkg/storage/sos/object.go +++ b/pkg/storage/sos/object.go @@ -129,15 +129,18 @@ type DownloadConfig struct { } func (c *Client) DownloadFiles(ctx context.Context, config *DownloadConfig) error { - if len(config.Objects) > 1 && !strings.HasSuffix(config.Destination, "/") { - return errors.New(`multiple files to download, destination must end with "/"`) - } - - // Handle relative filesystem destination (e.g. ".", "../.." etc.) - if dstInfo, err := os.Stat(config.Destination); err == nil { - if dstInfo.IsDir() && !strings.HasSuffix(config.Destination, "/") { - config.Destination += "/" + config.Destination = filepath.Clean(strings.TrimRight(config.Destination, string(os.PathSeparator))) + dstInfo, err := os.Stat(config.Destination) + if err == nil { + if !dstInfo.IsDir() { + return fmt.Errorf("destination %q is not a directory, use flag `-f` to overwrite", config.Destination) } + } else if os.IsNotExist(err) { + if err := os.MkdirAll(config.Destination, 0o755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + } else { + return fmt.Errorf("error checking destination path: %w", err) } if config.DryRun { @@ -146,19 +149,37 @@ func (c *Client) DownloadFiles(ctx context.Context, config *DownloadConfig) erro for _, object := range config.Objects { dst := func() string { - if strings.HasSuffix(config.Source, "/") { - return path.Join(config.Destination, strings.TrimPrefix(aws.ToString(object.Key), config.Prefix)) + if config.Recursive { + relativePath := strings.TrimPrefix(aws.ToString(object.Key), config.Prefix) + if !strings.HasPrefix(config.Prefix, "/") && !strings.HasPrefix(relativePath, config.Prefix) { + relativePath = filepath.Join(config.Prefix, relativePath) + } + return filepath.Join(config.Destination, relativePath) } - if strings.HasSuffix(config.Destination, "/") { - return path.Join(config.Destination, path.Base(aws.ToString(object.Key))) + if strings.HasSuffix(config.Destination, string(os.PathSeparator)) || dstInfo.IsDir() { + return filepath.Join(config.Destination, path.Base(aws.ToString(object.Key))) } - return path.Join(config.Destination) + return filepath.Clean(config.Destination) }() + if strings.HasSuffix(aws.ToString(object.Key), "/") { + if err := os.MkdirAll(dst, 0o755); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dst, err) + } + continue + } + + parentDir := filepath.Dir(dst) + if _, err := os.Stat(parentDir); os.IsNotExist(err) { + if err := os.MkdirAll(parentDir, 0o755); err != nil { + return fmt.Errorf("failed to create directories: %w", err) + } + } + if config.DryRun { - fmt.Printf("%s/%s -> %s\n", config.Bucket, aws.ToString(object.Key), dst) + fmt.Printf("[DRY-RUN] %s/%s -> %s\n", config.Bucket, aws.ToString(object.Key), dst) continue } @@ -166,12 +187,6 @@ func (c *Client) DownloadFiles(ctx context.Context, config *DownloadConfig) erro return fmt.Errorf("file %q already exists, use flag `-f` to overwrite", dst) } - if _, err := os.Stat(path.Dir(dst)); errors.Is(err, os.ErrNotExist) { - if err := os.MkdirAll(path.Dir(dst), 0o755); err != nil { - return err - } - } - if err := c.DownloadFile(ctx, config.Bucket, object, dst); err != nil { return err }