Commit 8d1cdb5: Copy initial version (#2)

* First pass of copy with some nice progress bars

* First release with cp and progress bars
sethkor authored May 5, 2019
1 parent 79a9133 commit 8d1cdb5
Showing 8 changed files with 395 additions and 31 deletions.
23 changes: 12 additions & 11 deletions common.go
@@ -2,14 +2,15 @@ package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"go.uber.org/zap"
)

func checkBucket(sess *session.Session, bucket string, autoRegion bool) (svc *s3.S3, err error) {
func checkBucket(sess *session.Session, bucket string) (svc *s3.S3, err error) {

	var logger = zap.S()

@@ -43,15 +44,15 @@ func checkBucket(sess *session.Session, bucket string, autoRegion bool) (svc *s3
		return svc, err
	} //if

	if autoRegion {
		svc = s3.New(sess, &aws.Config{MaxRetries: aws.Int(30),
			Region: aws.String(bucketLocation)})
	} else {
		if *svc.Config.Region != bucketLocation {
			fmt.Println("Bucket exists in region", bucketLocation, "which is different from the region passed", *svc.Config.Region, ". Please adjust the region on the command line or use --auto-region")
			logger.Fatal("Bucket exists in region", bucketLocation, "which is different from the region passed", *svc.Config.Region, ". Please adjust the region on the command line or use --auto-region")
		}
	}
	//if autoRegion {
	svc = s3.New(sess, &aws.Config{MaxRetries: aws.Int(30),
		Region: aws.String(bucketLocation)})
	//} else {
	//	if *svc.Config.Region != bucketLocation {
	//		fmt.Println("Bucket exists in region", bucketLocation, "which is different from the region passed", *svc.Config.Region, ". Please adjust the region on the command line or use --auto-region")
	//		logger.Fatal("Bucket exists in region", bucketLocation, "which is different from the region passed", *svc.Config.Region, ". Please adjust the region on the command line or use --auto-region")
	//	}
	//}

	return svc, err
}
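
For context, this change always rebuilds the client in the bucket's actual region instead of failing on a mismatch. A minimal sketch of that pattern, using only aws-sdk-go v1 calls; the helper name clientForBucket is illustrative, not part of this repo:

	// clientForBucket resolves a bucket's region and returns a client pinned to it.
	func clientForBucket(sess *session.Session, bucket string) (*s3.S3, error) {
		svc := s3.New(sess)
		// Ask S3 where the bucket lives.
		out, err := svc.GetBucketLocation(&s3.GetBucketLocationInput{
			Bucket: aws.String(bucket),
		})
		if err != nil {
			return nil, err
		}
		// An empty LocationConstraint means us-east-1; NormalizeBucketLocation handles that.
		region := s3.NormalizeBucketLocation(aws.StringValue(out.LocationConstraint))
		return s3.New(sess, &aws.Config{
			Region:     aws.String(region),
			MaxRetries: aws.Int(30),
		}), nil
	}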
4 changes: 4 additions & 0 deletions constants.go
@@ -0,0 +1,4 @@
package main

const bigChanSize int64 = 1000000

251 changes: 251 additions & 0 deletions copy.go
@@ -0,0 +1,251 @@
package main

import (
	"errors"
	"net/url"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/vbauerster/mpb"
	"github.com/vbauerster/mpb/decor"
	"go.uber.org/zap"
)

//The different types of progress bars. One counts files, the other tracks total bytes.
type progressBar struct {
	count    *mpb.Bar
	fileSize *mpb.Bar
}

type BucketCopier struct {
	source        url.URL
	target        url.URL
	acl           string `enum:"ObjectCannedACL"`
	uploadManager s3manager.Uploader
	bars          progressBar
	wg            *sync.WaitGroup
	files         chan fileJob
	fileCounter   chan int64
	threads       semaphore
	template      s3manager.UploadInput
}

func (copier BucketCopier) uploadFile() func(file fileJob) {
	var logger = zap.S()

	//Some logic to determine the base path to be used as the prefix for S3. If the source path ends with a "/",
	//the base of the source path is not used in the S3 prefix, as we assume it's the contents of the directory,
	//not the directory itself, that is wanted in the copy.
	_, splitFile := filepath.Split(copier.source.Path)
	includeRoot := 0
	if splitFile != "" {
		includeRoot = len(splitFile)
	}

	sourceLength := len(copier.source.Path) - includeRoot
	if len(copier.source.Path) == 0 {
		sourceLength++
	}

	return func(file fileJob) {
		defer copier.threads.release(1)
		start := time.Now()
		input := copier.template
		if file.info.IsDir() {
			//Don't create a prefix for the base dir
			if len(file.path) != sourceLength {
				input.Key = aws.String(copier.target.Path + "/" + file.path[sourceLength:] + "/")
				_, err := copier.uploadManager.Upload(&input)

				if err != nil {
					logger.Error("Prefix failed to create in S3 ", file.path)

					if aerr, ok := err.(awserr.Error); ok {
						switch aerr.Code() {
						default:
							logger.Error(aerr.Error())
						} //switch
					} else {
						// Message from an error.
						logger.Error(err.Error())
					} //else
					return
				}
				logger.Info("dir>>>>s3 ", file.path)
			} //if
		} else {
			f, err := os.Open(file.path)
			if err != nil {
				logger.Errorf("failed to open file %q, %v", file.path, err)
			} else {
				// Upload the file to S3.
				input.Key = aws.String(copier.target.Path + "/" + file.path[sourceLength:])
				input.Body = f
				_, err = copier.uploadManager.Upload(&input)

				if err != nil {
					logger.Error("Object failed to create in S3 ", file.path)

					if aerr, ok := err.(awserr.Error); ok {
						switch aerr.Code() {
						default:
							logger.Error(aerr.Error())
						} //switch
					} else {
						// Message from an error.
						logger.Error(err.Error())
					} //else
				}
				_ = f.Close()
				logger.Debug("file>>>s3 ", file.path)
			} //else
		} //else
		copier.bars.count.IncrInt64(1, time.Since(start))
		copier.bars.fileSize.IncrInt64(file.info.Size(), time.Since(start))
	}
}
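
To make the prefix arithmetic above concrete, here is an illustrative helper (not part of the commit) that reproduces the core of the sourceLength logic for the two trailing-slash cases:

	// exampleKeySuffix shows what part of a file path becomes the S3 key suffix.
	func exampleKeySuffix(source, file string) string {
		_, splitFile := filepath.Split(source)
		includeRoot := 0
		if splitFile != "" {
			includeRoot = len(splitFile)
		}
		return file[len(source)-includeRoot:]
	}

	// exampleKeySuffix("/tmp/data", "/tmp/data/a.txt")  == "data/a.txt" (base dir kept in the key)
	// exampleKeySuffix("/tmp/data/", "/tmp/data/a.txt") == "a.txt"      (contents only)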

func (copier BucketCopier) processFiles() {
	defer copier.wg.Done()

	allThreads := cap(copier.threads) // total slots in the semaphore (len would be 0 here)
	uploadFileFunc := copier.uploadFile()
	for file := range copier.files {
		copier.threads.acquire(1) // or block until one slot is free
		go uploadFileFunc(file)
	} //for
	copier.threads.acquire(allThreads) // don't continue until all goroutines complete
}
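
processFiles leans on a semaphore type defined elsewhere in the repo. A minimal channel-based implementation consistent with its use here (acquire blocks once all slots are taken, so acquiring the full capacity waits for every in-flight upload) would be:

	type empty struct{}
	type semaphore chan empty

	// acquire claims n slots, blocking until they are free.
	func (s semaphore) acquire(n int) {
		for i := 0; i < n; i++ {
			s <- empty{}
		}
	}

	// release frees n slots.
	func (s semaphore) release(n int) {
		for i := 0; i < n; i++ {
			<-s
		}
	}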

func (pb progressBar) updateBar(fileSize <-chan int64, wg *sync.WaitGroup) {
	defer wg.Done()
	var fileCount int64 = 0
	var fileSizeTotal int64 = 0

	//var chunk int64 = 0
	for size := range fileSize {
		fileCount++
		fileSizeTotal += size
		// Totals grow as the walker discovers files; passing false keeps the bars active.
		pb.count.SetTotal(fileCount, false)
		pb.fileSize.SetTotal(fileSizeTotal, false)
	}
}

func ACL(acl string) func(copier *BucketCopier) {
	return func(copier *BucketCopier) {
		copier.acl = acl
	}
}

func NewBucketCopier(source string, dest string, sess *session.Session, template s3manager.UploadInput) (*BucketCopier, error) {

	var svc *s3.S3
	sourceURL, err := url.Parse(source)
	if err != nil {
		return nil, err
	}

	destURL, err := url.Parse(dest)
	if err != nil {
		return nil, err
	}

	if sourceURL.Scheme != "s3" && destURL.Scheme != "s3" {
		return nil, errors.New("usage: aws s3 cp <LocalPath> <S3Uri> or <S3Uri> <LocalPath> or <S3Uri> <S3Uri>")
	}

	if sourceURL.Scheme == "s3" {
		svc, err = checkBucket(sess, sourceURL.Host)
		if err != nil {
			return nil, err
		}
	}

	if destURL.Scheme == "s3" {
		svc, err = checkBucket(sess, destURL.Host)
		if err != nil {
			return nil, err
		}
	}

	template.Bucket = aws.String(destURL.Host)

	bc := &BucketCopier{
		source:        *sourceURL,
		target:        *destURL,
		uploadManager: *s3manager.NewUploaderWithClient(svc),
		threads:       make(semaphore, 1000),
		files:         make(chan fileJob, 1000),
		fileCounter:   make(chan int64, 1000),
		wg:            &sync.WaitGroup{},
		template:      template,
	}

	return bc, nil
}
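
A hypothetical call site for NewBucketCopier, assuming a configured session and an UploadInput template (the paths, bucket name, and ACL shown are illustrative):

	sess := session.Must(session.NewSession(&aws.Config{Region: aws.String("us-east-1")}))
	template := s3manager.UploadInput{ACL: aws.String("private")}

	copier, err := NewBucketCopier("/tmp/data/", "s3://my-bucket/backup", sess, template)
	if err != nil {
		log.Fatal(err)
	}
	copier.copy() // walks the local tree and uploads everything with live progress bars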

func (myCopier BucketCopier) copy() {
	//var logger = zap.S()

	if myCopier.source.Scheme != "s3" {
		go walkFiles(myCopier.source.Path, myCopier.files, myCopier.fileCounter)
	}

	progress := mpb.New()

	myCopier.bars.count = progress.AddBar(0,
		mpb.PrependDecorators(
			// simple name decorator
			decor.Name("Files", decor.WC{W: 6, C: decor.DSyncWidth}),
			decor.CountersNoUnit(" %d / %d", decor.WCSyncWidth),
		),
		mpb.AppendDecorators(
			decor.Percentage(decor.WCSyncWidth),
			decor.Name(" "),
			decor.MovingAverageETA(decor.ET_STYLE_GO, decor.NewMedian(), decor.FixedIntervalTimeNormalizer(5000), decor.WCSyncSpaceR),
		),
	)

	myCopier.bars.fileSize = progress.AddBar(0,
		mpb.PrependDecorators(
			decor.Name("Size ", decor.WC{W: 6, C: decor.DSyncWidth}),
			decor.Counters(decor.UnitKB, "% .1f / % .1f", decor.WCSyncWidth),
		),
		mpb.AppendDecorators(
			decor.Percentage(decor.WCSyncWidth),
			decor.Name(" "),
			decor.AverageSpeed(decor.UnitKB, "% .1f", decor.WCSyncWidth),
		),
	)

	myCopier.wg.Add(2)

	go myCopier.bars.updateBar(myCopier.fileCounter, myCopier.wg)

	go myCopier.processFiles()

	myCopier.wg.Wait()

	progress.Wait()
}
16 changes: 5 additions & 11 deletions delete.go
@@ -56,15 +56,15 @@ func deleteAllObjects(deleteBucket string, resultsChan <-chan []*s3.ObjectIdenti

}

func delete(sess *session.Session, path string, autoRegion bool, versions bool, recursive bool) {
func delete(sess *session.Session, path string, versions bool, recursive bool) {
	var logger = zap.S()

	s3URL, err := url.Parse(path)

	if err == nil && s3URL.Scheme == "s3" {

		var svc *s3.S3
		svc, err = checkBucket(sess, s3URL.Host, autoRegion)
		svc, err = checkBucket(sess, s3URL.Host)

		if recursive {
			threads := 50
@@ -95,8 +95,7 @@ func delete(sess *session.Session, path string, autoRegion bool, versions bool,
logger.Fatal("Must pass an object in the bucket to remove, not just the bucket name")
}


		if versions{
		if versions {
			//we want to delete all versions of the object specified

			//make a channel for processing
@@ -108,11 +107,10 @@ func delete(sess *session.Session, path string, autoRegion bool, versions bool,

			go listObjectVersions(*s3URL, resultsChan, true, bar, *svc)

		} else {
			_, err = svc.DeleteObject( &s3.DeleteObjectInput{
			_, err = svc.DeleteObject(&s3.DeleteObjectInput{
				Bucket: aws.String(s3URL.Host),
				Key: aws.String(s3URL.Path[1:]),
				Key:    aws.String(s3URL.Path[1:]),
			})

			if err != nil {
@@ -132,12 +130,8 @@ func delete(sess *session.Session, path string, autoRegion bool, versions bool,
			}
		}

	} else {
		fmt.Println("The S3 URL passed is not formatted correctly")
		logger.Fatal("The S3 URL passed is not formatted correctly")
	}
}
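
The versioned path above only shows the producer side (listObjectVersions feeding resultsChan). A sketch of what a consumer like deleteAllObjects plausibly does with each batch, using the standard DeleteObjects batch API; the variable names here are assumed, not taken from the repo:

	for batch := range resultsChan {
		// Each batch is a []*s3.ObjectIdentifier (key plus version id).
		_, err := svc.DeleteObjects(&s3.DeleteObjectsInput{
			Bucket: aws.String(bucket),
			Delete: &s3.Delete{Objects: batch},
		})
		if err != nil {
			logger.Error(err)
		}
	}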

