diff --git a/common.go b/common.go index 2077a10..2dca544 100644 --- a/common.go +++ b/common.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "sync" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" @@ -13,8 +14,10 @@ import ( // checkBucket checks a buckets region. We use the HeadObject function as this is can be used anonymously and is not // subject to the buckets policy like GetBucketRegion. The region is retorned in the header of the HTTP response. -func checkBucket(sess *session.Session, bucket string) (svc *s3.S3, err error) { - +func checkBucket(sess *session.Session, bucket string, wg *sync.WaitGroup) (svc *s3.S3, err error) { + if wg != nil { + defer wg.Done() + } var logger = zap.S() svc = s3.New(sess) @@ -27,52 +30,51 @@ func checkBucket(sess *session.Session, bucket string) (svc *s3.S3, err error) { svc = s3.New(sess, &aws.Config{MaxRetries: aws.Int(30), Region: aws.String(s3.NormalizeBucketLocation(*result.LocationConstraint))}) - return - } else { + return svc, err + } - if aerr, ok := err.(awserr.Error); ok { + if aerr, ok := err.(awserr.Error); ok { - if aerr.Code() == s3.ErrCodeNoSuchBucket { - fmt.Println(aerr.Message()) - logger.Fatal(aerr.Message()) - return - } + if aerr.Code() == s3.ErrCodeNoSuchBucket { + fmt.Println(aerr.Message() + ": " + bucket) + logger.Fatal(aerr.Message()) + return svc, err + } + + //Try getting the region via head-object. + svc = s3.New(session.Must(session.NewSession(&aws.Config{ + Credentials: credentials.AnonymousCredentials, + Region: sess.Config.Region, + }))) + + req, _ := svc.HeadBucketRequest(&s3.HeadBucketInput{ + Bucket: aws.String(bucket), + }) + + err = req.Send() + if err != nil { + fmt.Println(err) + + if aerr, ok := err.(awserr.Error); ok { + switch aerr.Code() { + + case s3.ErrCodeNoSuchBucket: + logger.Fatal(aerr.Message()) + default: + logger.Fatal(aerr.Error()) - //Try getting the region via head-object. - svc = s3.New(session.Must(session.NewSession(&aws.Config{ - Credentials: credentials.AnonymousCredentials, - Region: sess.Config.Region, - }))) - - req, _ := svc.HeadBucketRequest(&s3.HeadBucketInput{ - Bucket: aws.String(bucket), - }) - - err = req.Send() - if err != nil { - fmt.Println(err) - - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - - case s3.ErrCodeNoSuchBucket: - logger.Fatal(aerr.Message()) - default: - logger.Fatal(aerr.Error()) - - } - } else { - // Print the error, cast err to awserr.Error to get the Code and - // Message from an error. - logger.Fatal(err.Error()) } - return svc, err - } //if + } else { + // Print the error, cast err to awserr.Error to get the Code and + // Message from an error. + logger.Fatal(err.Error()) + } + return svc, err + } //if - req.HTTPResponse.Header.Get("X-Amz-Bucket-Region") - svc = s3.New(sess, &aws.Config{MaxRetries: aws.Int(30), - Region: aws.String(s3.NormalizeBucketLocation(req.HTTPResponse.Header.Get("X-Amz-Bucket-Region")))}) - } + req.HTTPResponse.Header.Get("X-Amz-Bucket-Region") + svc = s3.New(sess, &aws.Config{MaxRetries: aws.Int(30), + Region: aws.String(s3.NormalizeBucketLocation(req.HTTPResponse.Header.Get("X-Amz-Bucket-Region")))}) } return svc, err diff --git a/copy.go b/copy.go index 7c40ff6..689db22 100644 --- a/copy.go +++ b/copy.go @@ -561,7 +561,7 @@ func (cp *BucketCopier) copy(recursive bool) { // a bucket func NewBucketCopier(source string, dest string, threads int, quiet bool, sess *session.Session, template s3manager.UploadInput) (*BucketCopier, error) { - var svc *s3.S3 + var svc, destSvc *s3.S3 sourceURL, err := url.Parse(source) if err != nil { return nil, err @@ -577,18 +577,28 @@ func NewBucketCopier(source string, dest string, threads int, quiet bool, sess * } + var wg sync.WaitGroup if sourceURL.Scheme == "s3" { - svc, err = checkBucket(sess, sourceURL.Host) - if err != nil { - return nil, err - } + wg.Add(1) + go func() { + svc, err = checkBucket(sess, sourceURL.Host, &wg) + }() } if destURL.Scheme == "s3" { - svc, err = checkBucket(sess, destURL.Host) - if err != nil { - return nil, err - } + wg.Add(1) + go func() { + destSvc, err = checkBucket(sess, destURL.Host, &wg) + }() + } + + wg.Wait() + if err != nil { + return nil, err + } + + if svc == nil { + svc = destSvc } template.Bucket = aws.String(destURL.Host) @@ -607,7 +617,7 @@ func NewBucketCopier(source string, dest string, threads int, quiet bool, sess * } if sourceURL.Scheme == "s3" { - bc.lister, err = NewBucketLister(source, threads, sess) + bc.lister, err = NewBucketListerWithSvc(source, threads, svc) //if destURL.Scheme == "s3" { bc.objects = make(chan []*s3.Object, threads) bc.lister.objects = bc.objects diff --git a/delete.go b/delete.go index 5f47e9a..4a7cc67 100644 --- a/delete.go +++ b/delete.go @@ -186,7 +186,7 @@ func NewBucketDeleter(source string, quite bool, threads int, versions bool, rec bd.lister.objects = bd.objects } - bd.svc, err = checkBucket(sess, sourceURL.Host) + bd.svc, err = checkBucket(sess, sourceURL.Host, nil) if err != nil { return nil, err } diff --git a/list.go b/list.go index 2e769e9..cd104fd 100644 --- a/list.go +++ b/list.go @@ -225,9 +225,7 @@ func (bl *BucketLister) List(versions bool) { bl.printAllObjects(versions) } -// NewBucketLister creates a new BucketLister struct initialized with all variables needed to list a bucket -func NewBucketLister(source string, threads int, sess *session.Session) (*BucketLister, error) { - +func initBucketLister(source string, threads int) (*BucketLister, error) { sourceURL, err := url.Parse(source) if err != nil { return nil, err @@ -247,11 +245,29 @@ func NewBucketLister(source string, threads int, sess *session.Session) (*Bucket sizeChan: make(chan objectCounter, threads), threads: threads, } + return bl, nil +} - bl.svc, err = checkBucket(sess, sourceURL.Host) - if err != nil { - return nil, err +// NewBucketLister creates a new BucketLister struct initialized with all variables needed to list a bucket +func NewBucketLister(source string, threads int, sess *session.Session) (*BucketLister, error) { + + bl, err := initBucketLister(source, threads) + + if err == nil { + bl.svc, err = checkBucket(sess, bl.source.Host, nil) } - return bl, nil + return bl, err +} + +// NewBucketListerWithSvc creates a new BucketLister struct initialized with all variables needed to list a bucket +func NewBucketListerWithSvc(source string, threads int, svc *s3.S3) (*BucketLister, error) { + + bl, err := initBucketLister(source, threads) + + if err == nil { + bl.svc = svc + } + + return bl, err } diff --git a/multicopy.go b/multicopy.go index 7154673..8117695 100644 --- a/multicopy.go +++ b/multicopy.go @@ -23,7 +23,7 @@ const MaxCopyParts = 10000 // Amazon S3. Should be 5MB const MinCopyPartSize int64 = 1024 * 1024 * 5 -// MinCopyPartSize is the maximum allowed part size when copying a part to +// MaxCopyPartSize is the maximum allowed part size when copying a part to // Amazon S3. Sould be 5GB const MaxCopyPartSize int64 = 1024 * 1024 * 1024 * 5