Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Persist Python interpreter for repeated use #6

Merged
merged 15 commits into from
Jul 23, 2024
202 changes: 41 additions & 161 deletions bagit.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
package bagit

import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"

"github.com/artefactual-labs/bagit-gython/internal/dist/data"
Expand All @@ -17,18 +13,27 @@ import (
"github.com/kluctl/go-embed-python/python"
)

// ErrInvalid indicates that bag validation failed. If there is a validation
// error message, ErrInvalid will be wrapped so make sure to use
// `errors.Is(err, ErrInvalid)` to test equivalency.
var ErrInvalid = errors.New("invalid")
var (
// ErrInvalid indicates that bag validation failed. When a validation error
// message is available, ErrInvalid is wrapped together with it, so callers
// must match it with `errors.Is(err, ErrInvalid)` instead of comparing
// with `==`.
ErrInvalid = errors.New("invalid")

// ErrBusy is returned when an operation is attempted on BagIt while it is
// already processing another command. Only one command is processed at a
// time, which prevents race conditions on the shared resources; callers
// that need parallelism should create one BagIt per goroutine.
ErrBusy = errors.New("runner is busy")
)

// BagIt is an abstraction to work with BagIt packages that embeds Python and
// the bagit-python.
type BagIt struct {
tmpDir string // Top-level container for embedded files.
ep *python.EmbeddedPython // Python files.
lib *embed_util.EmbeddedFiles // bagit-python library files.
runner *embed_util.EmbeddedFiles // bagit-python wrapper files (runner).
tmpDir string // Top-level container for embedded files.
embedPython *python.EmbeddedPython // Python files.
embedBagit *embed_util.EmbeddedFiles // bagit-python library files.
embedRunner *embed_util.EmbeddedFiles // bagit-python wrapper files (runner).
runner *pyRunner
}

// NewBagIt creates and initializes a new BagIt instance. This constructor is
Expand All @@ -44,54 +49,28 @@ func NewBagIt() (*BagIt, error) {
return nil, fmt.Errorf("make tmpDir: %v", err)
}

ep, err := python.NewEmbeddedPythonWithTmpDir(filepath.Join(b.tmpDir, "python"), true)
b.embedPython, err = python.NewEmbeddedPythonWithTmpDir(filepath.Join(b.tmpDir, "python"), true)
if err != nil {
return nil, fmt.Errorf("embed python: %v", err)
}
b.ep = ep

b.lib, err = embed_util.NewEmbeddedFilesWithTmpDir(data.Data, filepath.Join(b.tmpDir, "bagit-lib"), true)
b.embedBagit, err = embed_util.NewEmbeddedFilesWithTmpDir(data.Data, filepath.Join(b.tmpDir, "bagit-lib"), true)
if err != nil {
return nil, fmt.Errorf("embed bagit: %v", err)
}
b.ep.AddPythonPath(b.lib.GetExtractedPath())
b.embedPython.AddPythonPath(b.embedBagit.GetExtractedPath())

b.runner, err = embed_util.NewEmbeddedFilesWithTmpDir(runner.Source, filepath.Join(b.tmpDir, "bagit-runner"), true)
b.embedRunner, err = embed_util.NewEmbeddedFilesWithTmpDir(runner.Source, filepath.Join(b.tmpDir, "bagit-runner"), true)
if err != nil {
return nil, fmt.Errorf("embed runner: %v", err)
}

return b, nil
}

// create a Python interpreter running the bagit-python wrapper.
func (b *BagIt) create() (*runnerInstance, error) {
i := &runnerInstance{}

cmd, err := b.ep.PythonCmd(filepath.Join(b.runner.GetExtractedPath(), "main.py"))
if err != nil {
return nil, fmt.Errorf("create command: %v", err)
}
i.cmd = cmd

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("create stdin pipe: %v", err)
}
i.stdin = stdin

stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("create stdout pipe: %v", err)
}
i.stdout = stdout

err = cmd.Start()
if err != nil {
return nil, fmt.Errorf("cmd: %v", err)
}
b.runner = createRunner(
b.embedPython,
filepath.Join(b.embedRunner.GetExtractedPath(), "main.py"),
)

return i, nil
return b, nil
}

type validateRequest struct {
Expand All @@ -104,41 +83,15 @@ type validateResponse struct {
}

func (b *BagIt) Validate(path string) error {
i, err := b.create()
blob, err := b.runner.send("validate", &validateRequest{
Path: path,
})
if err != nil {
return fmt.Errorf("run python: %v", err)
}
defer i.stop()

reader := bufio.NewReader(i.stdout)

if err := i.send(args{
Cmd: "validate",
Opts: &validateRequest{
Path: path,
},
}); err != nil {
return err
}

line := bytes.NewBuffer(nil)
for {
l, p, err := reader.ReadLine()
if err != nil && err != io.EOF {
return fmt.Errorf("read line: %v", err)
}
line.Write(l)
if !p {
break
}
}

if line.Len() < 1 {
return fmt.Errorf("response not received")
}

r := validateResponse{}
err = json.Unmarshal(line.Bytes(), &r)
err = json.Unmarshal(blob, &r)
if err != nil {
return fmt.Errorf("decode response: %v", err)
}
Expand All @@ -162,41 +115,15 @@ type makeResponse struct {
}

func (b *BagIt) Make(path string) error {
i, err := b.create()
blob, err := b.runner.send("make", &makeRequest{
Path: path,
})
if err != nil {
return fmt.Errorf("run python: %v", err)
}
defer i.stop()

reader := bufio.NewReader(i.stdout)

if err := i.send(args{
Cmd: "make",
Opts: &makeRequest{
Path: path,
},
}); err != nil {
return err
}

line := bytes.NewBuffer(nil)
for {
l, p, err := reader.ReadLine()
if err != nil && err != io.EOF {
return fmt.Errorf("read line: %v", err)
}
line.Write(l)
if !p {
break
}
}

if line.Len() < 1 {
return fmt.Errorf("response not received")
}

r := makeResponse{}
err = json.Unmarshal(line.Bytes(), &r)
err = json.Unmarshal(blob, &r)
if err != nil {
return fmt.Errorf("decode response: %v", err)
}
Expand All @@ -210,15 +137,19 @@ func (b *BagIt) Make(path string) error {
func (b *BagIt) Cleanup() error {
var e error

if err := b.runner.Cleanup(); err != nil {
if err := b.runner.stop(); err != nil {
e = errors.Join(e, fmt.Errorf("stop runner: %v", err))
}

if err := b.embedRunner.Cleanup(); err != nil {
e = errors.Join(e, fmt.Errorf("clean up runner: %v", err))
}

if err := b.lib.Cleanup(); err != nil {
if err := b.embedBagit.Cleanup(); err != nil {
e = errors.Join(e, fmt.Errorf("clean up bagit: %v", err))
}

if err := b.ep.Cleanup(); err != nil {
if err := b.embedPython.Cleanup(); err != nil {
e = errors.Join(e, fmt.Errorf("clean up python: %v", err))
}

Expand All @@ -228,54 +159,3 @@ func (b *BagIt) Cleanup() error {

return e
}

// args is the JSON envelope written to the runner's stdin: the name of the
// command to execute plus its command-specific options.
type args struct {
	Cmd  string `json:"cmd"`
	Opts any    `json:"opts"`
}

// runnerInstance is an instance of a Python interpreter executing the
// bagit-python runner.
type runnerInstance struct {
	cmd    *exec.Cmd      // Interpreter process.
	stdin  io.WriteCloser // Requests are written here, one JSON document per line.
	stdout io.ReadCloser  // Responses are read from here by the caller.
}

// send a command to the runner.
//
// The arguments are encoded as a single JSON document followed by a newline
// and written to the runner's stdin. Encoding and write failures are wrapped
// with %w so that callers can inspect the cause with errors.Is / errors.As.
func (i *runnerInstance) send(args args) error {
	blob, err := json.Marshal(args)
	if err != nil {
		return fmt.Errorf("encode args: %w", err)
	}
	// The newline terminates the request line.
	blob = append(blob, '\n')

	if _, err := i.stdin.Write(blob); err != nil {
		return fmt.Errorf("write blob: %w", err)
	}

	return nil
}

// stop closes both pipes, kills the interpreter process and waits for it.
//
// Every step is attempted even if an earlier one failed; the failures are
// combined with errors.Join. Pipes or a process that were never set up are
// skipped, so stop is safe to call on a partially initialized instance.
func (i *runnerInstance) stop() error {
	var e error

	if i.stdin != nil {
		if err := i.stdin.Close(); err != nil {
			e = errors.Join(e, err)
		}
	}

	if i.stdout != nil {
		if err := i.stdout.Close(); err != nil {
			e = errors.Join(e, err)
		}
	}

	// Nothing to kill or wait for if the command was never started.
	if i.cmd == nil || i.cmd.Process == nil {
		return e
	}

	// A process that has already finished is the outcome we want, so
	// os.ErrProcessDone is not reported as a failure.
	if err := i.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
		e = errors.Join(e, err)
	}

	// Wait for the process so that its resources are released.
	if _, err := i.cmd.Process.Wait(); err != nil {
		e = errors.Join(e, err)
	}

	return e
}
75 changes: 59 additions & 16 deletions bagit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,31 @@ import (
"gotest.tools/v3/fs"
)

func setUp(t *testing.T) *bagit.BagIt {
t.Helper()
func setUp(tb testing.TB) *bagit.BagIt {
tb.Helper()

b, err := bagit.NewBagIt()
assert.NilError(t, err)
assert.NilError(tb, err)

t.Cleanup(func() {
assert.NilError(t, b.Cleanup())
tb.Cleanup(func() {
assert.NilError(tb, b.Cleanup())
})

return b
}

func TestValidateBag(t *testing.T) {
t.Parallel()
func BenchmarkValidate(b *testing.B) {
bagit := setUp(b)

t.Run("Fails validation", func(t *testing.T) {
t.Parallel()
b.ResetTimer()

b := setUp(t)
for i := 0; i < b.N; i++ {
_ = bagit.Validate("internal/testdata/valid-bag")
}
}

err := b.Validate("/tmp/691b8e7f-e6b7-41dd-bc47-868e2ff69333")
assert.Error(t, err, "invalid: Expected bagit.txt does not exist: /tmp/691b8e7f-e6b7-41dd-bc47-868e2ff69333/bagit.txt")
assert.Assert(t, errors.Is(err, bagit.ErrInvalid))
})
func TestConcurrency(t *testing.T) {
t.Parallel()

t.Run("Validates bag", func(t *testing.T) {
t.Parallel()
Expand All @@ -46,23 +46,66 @@ func TestValidateBag(t *testing.T) {
assert.NilError(t, err)
})

t.Run("Validates bag concurrently", func(t *testing.T) {
t.Run("Returns ErrBusy if the resource is busy", func(t *testing.T) {
t.Parallel()

b := setUp(t)

// This test should pass because each call to Validate() creates its own
// distinct Python interpreter instance.
var g errgroup.Group
for i := 0; i < 10; i++ {
for i := 0; i < 3; i++ {
g.Go(func() error {
return b.Validate("internal/testdata/valid-bag")
})
}

err := g.Wait()
assert.ErrorIs(t, err, bagit.ErrBusy)
})

t.Run("Parallel execution", func(t *testing.T) {
t.Parallel()

// *bagit.BagIt is not shareable, each goroutine must create its own.
var g errgroup.Group
for i := 0; i < 3; i++ {
g.Go(func() error {
b := setUp(t)
return b.Validate("internal/testdata/valid-bag")
})
}

err := g.Wait()
assert.NilError(t, err)
})
}

// TestValidateBag exercises BagIt.Validate against a path that is expected
// not to contain a bag and against the valid bag fixture.
func TestValidateBag(t *testing.T) {
	t.Parallel()

	t.Run("Fails validation", func(t *testing.T) {
		t.Parallel()

		// Location that is expected not to exist on the test machine.
		const missing = "/tmp/691b8e7f-e6b7-41dd-bc47-868e2ff69333"

		validator := setUp(t)
		err := validator.Validate(missing)

		// The message must match exactly and the error must match ErrInvalid.
		assert.Error(t, err, "invalid: Expected bagit.txt does not exist: "+missing+"/bagit.txt")
		assert.Assert(t, errors.Is(err, bagit.ErrInvalid))
	})

	t.Run("Validates bag", func(t *testing.T) {
		t.Parallel()

		validator := setUp(t)

		err := validator.Validate("internal/testdata/valid-bag")
		assert.NilError(t, err)
	})
}

func TestMakeBag(t *testing.T) {
t.Parallel()

t.Run("Creates bag", func(t *testing.T) {
t.Parallel()
Expand Down
Loading