Skip to content

Commit

Permalink
feat(plugins): Use wazero instead of wasmtime
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleconroy committed Dec 5, 2023
1 parent eb8d97f commit 5e3d938
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 106 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.21

require (
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0
github.com/cubicdaiya/gonp v1.0.4
github.com/davecgh/go-spew v1.1.1
github.com/fatih/structtag v1.2.0
Expand All @@ -20,6 +19,7 @@ require (
github.com/riza-io/grpc-go v0.2.0
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/tetratelabs/wazero v1.5.0
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e
github.com/xeipuuv/gojsonschema v1.2.0
golang.org/x/sync v0.5.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1/g
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
Expand Down Expand Up @@ -185,6 +183,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE=
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU=
Expand Down
149 changes: 46 additions & 103 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package wasm

import (
"bytes"
"context"
"crypto/sha256"
"errors"
Expand All @@ -15,10 +16,11 @@ import (
"os"
"path/filepath"
"runtime"
"runtime/trace"
"strings"

wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/tetratelabs/wazero/sys"
"golang.org/x/sync/singleflight"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -70,13 +72,17 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
return sum, nil
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) {
expected, err := r.getChecksum(ctx)
if err != nil {
return nil, err
}
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine, expected)
return r.loadWASM(ctx, cacheDir, expected)
})
if err != nil {
return nil, err
Expand All @@ -85,52 +91,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
if !ok {
return nil, fmt.Errorf("returned value was not a byte slice")
}
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}

pluginDir := filepath.Join(cacheDir, expectedSha)
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
modPath := filepath.Join(pluginDir, modName)
_, staterr := os.Stat(modPath)
if staterr == nil {
data, err := os.ReadFile(modPath)
if err != nil {
return nil, err
}
return data, nil
}

wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
if err != nil {
return nil, err
}

moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
module, err := wasmtime.NewModule(engine, wmod)
moduRegion.End()
if err != nil {
return nil, fmt.Errorf("define wasi: %w", err)
}

err = os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
}

return out, nil
return data, nil
}

func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
Expand Down Expand Up @@ -245,72 +206,56 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
return fmt.Errorf("failed to encode codegen request: %w", err)
}

engine := wasmtime.NewEngine()
module, err := r.loadModule(ctx, engine)
cacheDir, err := cache.PluginsDir()
if err != nil {
return fmt.Errorf("loadModule: %w", err)
return err
}

linker := wasmtime.NewLinker(engine)
if err := linker.DefineWasi(); err != nil {
cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero"))
if err != nil {
return err
}

dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
wasmBytes, err := r.loadBytes(ctx)
if err != nil {
return fmt.Errorf("temp dir: %w", err)
return fmt.Errorf("loadModule: %w", err)
}

defer os.RemoveAll(dir)
stdinPath := filepath.Join(dir, "stdin")
stderrPath := filepath.Join(dir, "stderr")
stdoutPath := filepath.Join(dir, "stdout")
config := wazero.NewRuntimeConfig().WithCompilationCache(cache)
rt := wazero.NewRuntimeWithConfig(ctx, config)
defer rt.Close(ctx)

if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
return fmt.Errorf("write file: %w", err)
}

// Configure WASI imports to write stdout into a file.
wasiConfig := wasmtime.NewWasiConfig()
wasiConfig.SetArgv([]string{"plugin.wasm", method})
wasiConfig.SetStdinFile(stdinPath)
wasiConfig.SetStdoutFile(stdoutPath)
wasiConfig.SetStderrFile(stderrPath)
// TODO: Handle error
wasi_snapshot_preview1.MustInstantiate(ctx, rt)

keys := []string{"SQLC_VERSION"}
vals := []string{info.Version}
for _, key := range r.Env {
keys = append(keys, key)
vals = append(vals, os.Getenv(key))
// Compile the Wasm binary once so that we can skip the entire compilation time during instantiation.
mod, err := rt.CompileModule(ctx, wasmBytes)
if err != nil {
return err
}
wasiConfig.SetEnv(keys, vals)

store := wasmtime.NewStore(engine)
store.SetWasi(wasiConfig)
var stderr, stdout bytes.Buffer

linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
err = linker.DefineModule(store, "", module)
linkRegion.End()
if err != nil {
return fmt.Errorf("define wasi: %w", err)
conf := wazero.NewModuleConfig()
conf = conf.WithArgs("plugin.wasm", method)
conf = conf.WithEnv("SQLC_VERSION", info.Version)
for _, key := range r.Env {
conf = conf.WithEnv(key, os.Getenv(key))
}
conf = conf.WithStdin(bytes.NewReader(stdinBlob))
conf = conf.WithStdout(&stdout)
conf = conf.WithStderr(&stderr)

// Run the function
fn, err := linker.GetDefault(store, "")
if err != nil {
return fmt.Errorf("wasi: get default: %w", err)
result, err := rt.InstantiateModule(ctx, mod, conf)
if result != nil {
defer result.Close(ctx)
}

callRegion := trace.StartRegion(ctx, "call _start")
_, err = fn.Call(store)
callRegion.End()

if cerr := checkError(err, stderrPath); cerr != nil {
if cerr := checkError(err, &stderr); cerr != nil {
return cerr
}

// Print WASM stdout
stdoutBlob, err := os.ReadFile(stdoutPath)
stdoutBlob, err := io.ReadAll(&stdout)
if err != nil {
return fmt.Errorf("read file: %w", err)
}
Expand All @@ -331,21 +276,19 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
return nil, status.Error(codes.Unimplemented, "")
}

func checkError(err error, stderrPath string) error {
func checkError(err error, stderr io.Reader) error {
if err == nil {
return err
}

var wtError *wasmtime.Error
if errors.As(err, &wtError) {
if code, ok := wtError.ExitStatus(); ok {
if code == 0 {
return nil
}
if exitErr, ok := err.(*sys.ExitError); ok {
if exitErr.ExitCode() == 0 {
return nil
}
}

// Print WASM stdout
stderrBlob, rferr := os.ReadFile(stderrPath)
stderrBlob, rferr := io.ReadAll(stderr)
if rferr == nil && len(stderrBlob) > 0 {
return errors.New(string(stderrBlob))
}
Expand Down

0 comments on commit 5e3d938

Please sign in to comment.