From 9225bc1a857e6395f446094786c74cdd07cfa73e Mon Sep 17 00:00:00 2001 From: Street <5597260+MStreet3@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:02:10 -0400 Subject: [PATCH] [cappl-86] feat(workflows/wasm): emit msgs to beholder (#845) * wip(wasm): adds Emit to Runtime interface WIP on Runtime with panics * refactor(wasm): separte funcs out of NewRunner * refactor(wasm): shifts logging related funcs around * feat(wasm): adds custom pb message * feat(wasm): calls emit from guest runner * refactor(workflows): splits out emitter interface + docstring * feat(host): defines a beholder adapter for emitter * wip(host): implement host side emit * refactor(wasm/host): abstracts read and write to wasm * protos wip * feat(wasm): emits error response * refactor(wasm/host): write all failures from wasm to memory * feat(wasm): inject metadata into module * feat(events+wasm): pull emit md from req md * feat(custmsg): creates labels from map * feat(wasm): adds tests and validates labels * feat(wasm/host): use custmsg implementation for calling beholder * chore(wasm+host): docstrings and lint * chore(host): new emitter iface + private func types * chore(multi) review comments * chore(wasm): add id and md to config directly * refactor(custmsg+host): adapter labeler from config for emit * refactor(wasm): remove emitter from mod config * refactor(custmsg+wasm): expose emitlabeler on guest * refactor(wasm+sdk): EmitLabeler to MessageEmitter * refactor(wasm+events): share label keys * refactor(wasm+values): use map[string]string directly --- pkg/capabilities/events/events.go | 32 +- pkg/custmsg/custom_message.go | 40 ++- pkg/custmsg/custom_message_test.go | 16 +- pkg/values/map.go | 2 +- pkg/workflows/sdk/runtime.go | 13 + pkg/workflows/sdk/testutils/runtime.go | 4 + pkg/workflows/wasm/host/module.go | 242 +++++++++++-- pkg/workflows/wasm/host/module_test.go | 333 ++++++++++++++++++ pkg/workflows/wasm/host/test/emit/cmd/main.go | 40 +++ pkg/workflows/wasm/host/wasip1.go | 10 +- pkg/workflows/wasm/host/wasm_test.go | 129 +++++++ pkg/workflows/wasm/pb/wasm.pb.go | 258 ++++++++++++-- pkg/workflows/wasm/pb/wasm.proto | 21 +- pkg/workflows/wasm/runner.go | 4 +- pkg/workflows/wasm/runner_test.go | 95 ++++- pkg/workflows/wasm/runner_wasip1.go | 181 +++++----- pkg/workflows/wasm/sdk.go | 159 ++++++++- pkg/workflows/wasm/sdk_test.go | 66 ++++ 18 files changed, 1459 insertions(+), 186 deletions(-) create mode 100644 pkg/workflows/wasm/host/module_test.go create mode 100644 pkg/workflows/wasm/host/test/emit/cmd/main.go create mode 100644 pkg/workflows/wasm/sdk_test.go diff --git a/pkg/capabilities/events/events.go b/pkg/capabilities/events/events.go index 1443fb2ff..bc74422c1 100644 --- a/pkg/capabilities/events/events.go +++ b/pkg/capabilities/events/events.go @@ -14,14 +14,14 @@ import ( const ( // Duplicates the attributes in beholder/message.go::Metadata - labelWorkflowOwner = "workflow_owner_address" - labelWorkflowID = "workflow_id" - labelWorkflowExecutionID = "workflow_execution_id" - labelWorkflowName = "workflow_name" - labelCapabilityContractAddress = "capability_contract_address" - labelCapabilityID = "capability_id" - labelCapabilityVersion = "capability_version" - labelCapabilityName = "capability_name" + LabelWorkflowOwner = "workflow_owner_address" + LabelWorkflowID = "workflow_id" + LabelWorkflowExecutionID = "workflow_execution_id" + LabelWorkflowName = "workflow_name" + LabelCapabilityContractAddress = "capability_contract_address" + LabelCapabilityID = "capability_id" + LabelCapabilityVersion = "capability_version" + LabelCapabilityName = "capability_name" ) type EmitMetadata struct { @@ -93,35 +93,35 @@ func (e EmitMetadata) attrs() []any { a := []any{} if e.WorkflowOwner != "" { - a = append(a, labelWorkflowOwner, e.WorkflowOwner) + a = append(a, LabelWorkflowOwner, e.WorkflowOwner) } if e.WorkflowID != "" { - a = append(a, labelWorkflowID, e.WorkflowID) + a = append(a, LabelWorkflowID, e.WorkflowID) } if e.WorkflowExecutionID != "" { - a = append(a, labelWorkflowExecutionID, e.WorkflowExecutionID) + a = append(a, LabelWorkflowExecutionID, e.WorkflowExecutionID) } if e.WorkflowName != "" { - a = append(a, labelWorkflowName, e.WorkflowName) + a = append(a, LabelWorkflowName, e.WorkflowName) } if e.CapabilityContractAddress != "" { - a = append(a, labelCapabilityContractAddress, e.CapabilityContractAddress) + a = append(a, LabelCapabilityContractAddress, e.CapabilityContractAddress) } if e.CapabilityID != "" { - a = append(a, labelCapabilityID, e.CapabilityID) + a = append(a, LabelCapabilityID, e.CapabilityID) } if e.CapabilityVersion != "" { - a = append(a, labelCapabilityVersion, e.CapabilityVersion) + a = append(a, LabelCapabilityVersion, e.CapabilityVersion) } if e.CapabilityName != "" { - a = append(a, labelCapabilityName, e.CapabilityName) + a = append(a, LabelCapabilityName, e.CapabilityName) } return a diff --git a/pkg/custmsg/custom_message.go b/pkg/custmsg/custom_message.go index cf4fe57b9..67665e499 100644 --- a/pkg/custmsg/custom_message.go +++ b/pkg/custmsg/custom_message.go @@ -19,17 +19,35 @@ func NewLabeler() Labeler { return Labeler{labels: make(map[string]string)} } +// WithMapLabels adds multiple key-value pairs to the CustomMessageLabeler for transmission +// With SendLogAsCustomMessage +func (l Labeler) WithMapLabels(labels map[string]string) Labeler { + newCustomMessageLabeler := NewLabeler() + + // Copy existing labels from the current agent + for k, v := range l.labels { + newCustomMessageLabeler.labels[k] = v + } + + // Add new key-value pairs + for k, v := range labels { + newCustomMessageLabeler.labels[k] = v + } + + return newCustomMessageLabeler +} + // With adds multiple key-value pairs to the CustomMessageLabeler for transmission With SendLogAsCustomMessage -func (c Labeler) With(keyValues ...string) Labeler { +func (l Labeler) With(keyValues ...string) Labeler { newCustomMessageLabeler := NewLabeler() if len(keyValues)%2 != 0 { // If an odd number of key-value arguments is passed, return the original CustomMessageLabeler unchanged - return c + return l } // Copy existing labels from the current agent - for k, v := range c.labels { + for k, v := range l.labels { newCustomMessageLabeler.labels[k] = v } @@ -43,10 +61,22 @@ func (c Labeler) With(keyValues ...string) Labeler { return newCustomMessageLabeler } +func (l Labeler) Emit(msg string) error { + return sendLogAsCustomMessageW(msg, l.labels) +} + +func (l Labeler) Labels() map[string]string { + copied := make(map[string]string, len(l.labels)) + for k, v := range l.labels { + copied[k] = v + } + return copied +} + // SendLogAsCustomMessage emits a BaseMessage With msg and labels as data. // any key in labels that is not part of orderedLabelKeys will not be transmitted -func (c Labeler) SendLogAsCustomMessage(msg string) error { - return sendLogAsCustomMessageW(msg, c.labels) +func (l Labeler) SendLogAsCustomMessage(msg string) error { + return sendLogAsCustomMessageW(msg, l.labels) } func sendLogAsCustomMessageW(msg string, labels map[string]string) error { diff --git a/pkg/custmsg/custom_message_test.go b/pkg/custmsg/custom_message_test.go index 4d41408f1..4ae8269e8 100644 --- a/pkg/custmsg/custom_message_test.go +++ b/pkg/custmsg/custom_message_test.go @@ -12,5 +12,19 @@ func Test_CustomMessageAgent(t *testing.T) { cma1 := cma.With("key1", "value1") cma2 := cma1.With("key2", "value2") - assert.NotEqual(t, cma1.labels, cma2.labels) + assert.NotEqual(t, cma1.Labels(), cma2.Labels()) +} + +func Test_CustomMessageAgent_With(t *testing.T) { + cma := NewLabeler() + cma = cma.With("key1", "value1") + + assert.Equal(t, cma.Labels(), map[string]string{"key1": "value1"}) +} + +func Test_CustomMessageAgent_WithMapLabels(t *testing.T) { + cma := NewLabeler() + cma = cma.WithMapLabels(map[string]string{"key1": "value1"}) + + assert.Equal(t, cma.Labels(), map[string]string{"key1": "value1"}) } diff --git a/pkg/values/map.go b/pkg/values/map.go index 076831102..bfe5fb494 100644 --- a/pkg/values/map.go +++ b/pkg/values/map.go @@ -20,7 +20,7 @@ func EmptyMap() *Map { } } -func NewMap(m map[string]any) (*Map, error) { +func NewMap[T any](m map[string]T) (*Map, error) { mv := map[string]Value{} for k, v := range m { val, err := Wrap(v) diff --git a/pkg/workflows/sdk/runtime.go b/pkg/workflows/sdk/runtime.go index de254acaf..d6403717e 100644 --- a/pkg/workflows/sdk/runtime.go +++ b/pkg/workflows/sdk/runtime.go @@ -7,9 +7,22 @@ import ( var BreakErr = capabilities.ErrStopExecution +type MessageEmitter interface { + // Emit sends a message to the labeler's destination. + Emit(string) error + + // With sets the labels for the message to be emitted. Labels are passed as key-value pairs + // and are cumulative. + With(kvs ...string) MessageEmitter +} + +// Guest interface type Runtime interface { Logger() logger.Logger Fetch(req FetchRequest) (FetchResponse, error) + + // Emitter sends the given message and labels to the configured collector. + Emitter() MessageEmitter } type FetchRequest struct { diff --git a/pkg/workflows/sdk/testutils/runtime.go b/pkg/workflows/sdk/testutils/runtime.go index 5ae962663..8234b77b1 100644 --- a/pkg/workflows/sdk/testutils/runtime.go +++ b/pkg/workflows/sdk/testutils/runtime.go @@ -17,3 +17,7 @@ func (nr *NoopRuntime) Logger() logger.Logger { l, _ := logger.New() return l } + +func (nr *NoopRuntime) Emitter() sdk.MessageEmitter { + return nil +} diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 8b31e77c6..a8c5009ed 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -19,36 +19,11 @@ import ( "google.golang.org/protobuf/proto" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm" wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" ) -// safeMem returns a copy of the wasm module memory at the given pointer and size. -func safeMem(caller *wasmtime.Caller, ptr int32, size int32) ([]byte, error) { - mem := caller.GetExport("memory").Memory() - data := mem.UnsafeData(caller) - if ptr+size > int32(len(data)) { - return nil, errors.New("out of bounds memory access") - } - - cd := make([]byte, size) - copy(cd, data[ptr:ptr+size]) - return cd, nil -} - -// copyBuffer copies the given src byte slice into the wasm module memory at the given pointer and size. -func copyBuffer(caller *wasmtime.Caller, src []byte, ptr int32, size int32) int64 { - mem := caller.GetExport("memory").Memory() - rawData := mem.UnsafeData(caller) - if int32(len(rawData)) < ptr+size { - return -1 - } - buffer := rawData[ptr : ptr+size] - dataLen := int64(len(src)) - copy(buffer, src) - return dataLen -} - type respStore struct { m map[string]*wasmpb.Response mu sync.RWMutex @@ -91,6 +66,14 @@ type DeterminismConfig struct { Seed int64 } +type MessageEmitter interface { + // Emit sends a message to the labeler's destination. + Emit(string) error + + // WithMapLabels sets the labels for the message to be emitted. Labels are cumulative. + WithMapLabels(map[string]string) MessageEmitter +} + type ModuleConfig struct { TickInterval time.Duration Timeout *time.Duration @@ -100,6 +83,9 @@ type ModuleConfig struct { IsUncompressed bool Fetch func(*wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) + // Labeler is used to emit messages from the module. + Labeler MessageEmitter + // If Determinism is set, the module will override the random_get function in the WASI API with // the provided seed to ensure deterministic behavior. Determinism *DeterminismConfig @@ -110,6 +96,7 @@ type Module struct { module *wasmtime.Module linker *wasmtime.Linker + // respStore collects responses from sendResponse mapped by request ID r *respStore cfg *ModuleConfig @@ -148,6 +135,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) } } + if modCfg.Labeler == nil { + modCfg.Labeler = &unimplementedMessageEmitter{} + } + logger := modCfg.Logger if modCfg.TickInterval == 0 { @@ -200,7 +191,7 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) "env", "sendResponse", func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int32 { - b, innerErr := safeMem(caller, ptr, ptrlen) + b, innerErr := wasmRead(caller, ptr, ptrlen) if innerErr != nil { logger.Errorf("error calling sendResponse: %s", err) return ErrnoFault @@ -230,7 +221,7 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) "env", "log", func(caller *wasmtime.Caller, ptr int32, ptrlen int32) { - b, innerErr := safeMem(caller, ptr, ptrlen) + b, innerErr := wasmRead(caller, ptr, ptrlen) if innerErr != nil { logger.Errorf("error calling log: %s", err) return @@ -285,6 +276,15 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) return nil, fmt.Errorf("error wrapping fetch func: %w", err) } + err = linker.FuncWrap( + "env", + "emit", + createEmitFn(logger, modCfg.Labeler, wasmRead, wasmWrite, wasmWriteUInt32), + ) + if err != nil { + return nil, fmt.Errorf("error wrapping emit func: %w", err) + } + m := &Module{ engine: engine, module: mod, @@ -404,7 +404,7 @@ func containsCode(err error, code int) bool { func fetchFn(logger logger.Logger, modCfg *ModuleConfig) func(caller *wasmtime.Caller, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 { const fetchErrSfx = "error calling fetch" return func(caller *wasmtime.Caller, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 { - b, innerErr := safeMem(caller, reqptr, reqptrlen) + b, innerErr := wasmRead(caller, reqptr, reqptrlen) if innerErr != nil { logger.Errorf("%s: %s", fetchErrSfx, innerErr) return ErrnoFault @@ -429,19 +429,189 @@ func fetchFn(logger logger.Logger, modCfg *ModuleConfig) func(caller *wasmtime.C return ErrnoFault } - size := copyBuffer(caller, respBytes, respptr, int32(len(respBytes))) - if size == -1 { + if size := wasmWrite(caller, respBytes, respptr, int32(len(respBytes))); size == -1 { return ErrnoFault } - uint32Size := int32(4) - resplenBytes := make([]byte, uint32Size) - binary.LittleEndian.PutUint32(resplenBytes, uint32(len(respBytes))) - size = copyBuffer(caller, resplenBytes, resplenptr, uint32Size) - if size == -1 { + if size := wasmWriteUInt32(caller, resplenptr, uint32(len(respBytes))); size == -1 { return ErrnoFault } return ErrnoSuccess } } + +// createEmitFn injects dependencies and builds the emit function exposed by the WASM. Errors in +// Emit, if any, are returned in the Error Message of the response. +func createEmitFn( + l logger.Logger, + e MessageEmitter, + reader unsafeReaderFunc, + writer unsafeWriterFunc, + sizeWriter unsafeFixedLengthWriterFunc, +) func(caller *wasmtime.Caller, respptr, resplenptr, msgptr, msglen int32) int32 { + logErr := func(err error) { + l.Errorf("error emitting message: %s", err) + } + + return func(caller *wasmtime.Caller, respptr, resplenptr, msgptr, msglen int32) int32 { + // writeErr marshals and writes an error response to wasm + writeErr := func(err error) int32 { + logErr(err) + + resp := &wasmpb.EmitMessageResponse{ + Error: &wasmpb.Error{ + Message: err.Error(), + }, + } + + respBytes, perr := proto.Marshal(resp) + if perr != nil { + logErr(perr) + return ErrnoFault + } + + if size := writer(caller, respBytes, respptr, int32(len(respBytes))); size == -1 { + logErr(errors.New("failed to write response")) + return ErrnoFault + } + + if size := sizeWriter(caller, resplenptr, uint32(len(respBytes))); size == -1 { + logErr(errors.New("failed to write response length")) + return ErrnoFault + } + + return ErrnoSuccess + } + + b, err := reader(caller, msgptr, msglen) + if err != nil { + return writeErr(err) + } + + msg, labels, err := toEmissible(b) + if err != nil { + return writeErr(err) + } + + if err := e.WithMapLabels(labels).Emit(msg); err != nil { + return writeErr(err) + } + + return ErrnoSuccess + } +} + +type unimplementedMessageEmitter struct{} + +func (u *unimplementedMessageEmitter) Emit(string) error { + return errors.New("unimplemented") +} + +func (u *unimplementedMessageEmitter) WithMapLabels(map[string]string) MessageEmitter { + return u +} + +func toEmissible(b []byte) (string, map[string]string, error) { + msg := &wasmpb.EmitMessageRequest{} + if err := proto.Unmarshal(b, msg); err != nil { + return "", nil, err + } + + validated, err := toValidatedLabels(msg) + if err != nil { + return "", nil, err + } + + return msg.Message, validated, nil +} + +func toValidatedLabels(msg *wasmpb.EmitMessageRequest) (map[string]string, error) { + vl, err := values.FromMapValueProto(msg.Labels) + if err != nil { + return nil, err + } + + // Handle the case of no labels before unwrapping. + if vl == nil { + vl = values.EmptyMap() + } + + var labels map[string]string + if err := vl.UnwrapTo(&labels); err != nil { + return nil, err + } + + return labels, nil +} + +// unsafeWriterFunc defines behavior for writing directly to wasm memory. A source slice of bytes +// is written to the location defined by the ptr. +type unsafeWriterFunc func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 + +// unsafeFixedLengthWriterFunc defines behavior for writing a uint32 value to wasm memory at the location defined +// by the ptr. +type unsafeFixedLengthWriterFunc func(c *wasmtime.Caller, ptr int32, val uint32) int64 + +// unsafeReaderFunc abstractly defines the behavior of reading from WASM memory. Returns a copy of +// the memory at the given pointer and size. +type unsafeReaderFunc func(c *wasmtime.Caller, ptr, len int32) ([]byte, error) + +// wasmMemoryAccessor is the default implementation for unsafely accessing the memory of the WASM module. +func wasmMemoryAccessor(caller *wasmtime.Caller) []byte { + return caller.GetExport("memory").Memory().UnsafeData(caller) +} + +// wasmRead returns a copy of the wasm module memory at the given pointer and size. +func wasmRead(caller *wasmtime.Caller, ptr int32, size int32) ([]byte, error) { + return read(wasmMemoryAccessor(caller), ptr, size) +} + +// Read acts on a byte slice that should represent an unsafely accessed slice of memory. It returns +// a copy of the memory at the given pointer and size. +func read(memory []byte, ptr int32, size int32) ([]byte, error) { + if size < 0 || ptr < 0 { + return nil, fmt.Errorf("invalid memory access: ptr: %d, size: %d", ptr, size) + } + + if ptr+size > int32(len(memory)) { + return nil, errors.New("out of bounds memory access") + } + + cd := make([]byte, size) + copy(cd, memory[ptr:ptr+size]) + return cd, nil +} + +// wasmWrite copies the given src byte slice into the wasm module memory at the given pointer and size. +func wasmWrite(caller *wasmtime.Caller, src []byte, ptr int32, size int32) int64 { + return write(wasmMemoryAccessor(caller), src, ptr, size) +} + +// wasmWriteUInt32 binary encodes and writes a uint32 to the wasm module memory at the given pointer. +func wasmWriteUInt32(caller *wasmtime.Caller, ptr int32, val uint32) int64 { + return writeUInt32(wasmMemoryAccessor(caller), ptr, val) +} + +// writeUInt32 binary encodes and writes a uint32 to the memory at the given pointer. +func writeUInt32(memory []byte, ptr int32, val uint32) int64 { + uint32Size := int32(4) + buffer := make([]byte, uint32Size) + binary.LittleEndian.PutUint32(buffer, val) + return write(memory, buffer, ptr, uint32Size) +} + +// write copies the given src byte slice into the memory at the given pointer and size. +func write(memory, src []byte, ptr, size int32) int64 { + if size < 0 || ptr < 0 { + return -1 + } + + if int32(len(memory)) < ptr+size { + return -1 + } + buffer := memory[ptr : ptr+size] + dataLen := int64(len(src)) + copy(buffer, src) + return dataLen +} diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go new file mode 100644 index 000000000..72264d149 --- /dev/null +++ b/pkg/workflows/wasm/host/module_test.go @@ -0,0 +1,333 @@ +package host + +import ( + "encoding/binary" + "testing" + + "github.com/bytecodealliance/wasmtime-go/v23" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/values/pb" + wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" +) + +type mockMessageEmitter struct { + e func(string, map[string]string) error + labels map[string]string +} + +func (m *mockMessageEmitter) Emit(msg string) error { + return m.e(msg, m.labels) +} + +func (m *mockMessageEmitter) WithMapLabels(labels map[string]string) MessageEmitter { + m.labels = labels + return m +} + +func newMockMessageEmitter(e func(string, map[string]string) error) MessageEmitter { + return &mockMessageEmitter{e: e} +} + +// Test_createEmitFn tests that the emit function used by the module is created correctly. Memory +// access functions are injected as mocks. +func Test_createEmitFn(t *testing.T) { + t.Run("success", func(t *testing.T) { + emitFn := createEmitFn( + logger.Test(t), + newMockMessageEmitter(func(_ string, _ map[string]string) error { + return nil + }), + unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + b, err := proto.Marshal(&wasmpb.EmitMessageRequest{ + Message: "hello, world", + Labels: &pb.Map{ + Fields: map[string]*pb.Value{ + "foo": { + Value: &pb.Value_StringValue{ + StringValue: "bar", + }, + }, + }, + }, + }) + assert.NoError(t, err) + return b, nil + }), + unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + return 0 + }), + unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + return 0 + }), + ) + gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + assert.Equal(t, ErrnoSuccess, gotCode) + }) + + t.Run("success without labels", func(t *testing.T) { + emitFn := createEmitFn( + logger.Test(t), + newMockMessageEmitter(func(_ string, _ map[string]string) error { + return nil + }), + unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + b, err := proto.Marshal(&wasmpb.EmitMessageRequest{}) + assert.NoError(t, err) + return b, nil + }), + unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + return 0 + }), + unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + return 0 + }), + ) + gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + assert.Equal(t, ErrnoSuccess, gotCode) + }) + + t.Run("successfully write error to memory on failure to read", func(t *testing.T) { + respBytes, err := proto.Marshal(&wasmpb.EmitMessageResponse{ + Error: &wasmpb.Error{ + Message: assert.AnError.Error(), + }, + }) + assert.NoError(t, err) + + emitFn := createEmitFn( + logger.Test(t), + nil, + unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + return nil, assert.AnError + }), + unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + assert.Equal(t, respBytes, src, "marshalled response not equal to bytes to write") + return 0 + }), + unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + assert.Equal(t, uint32(len(respBytes)), val, "did not write length of response") + return 0 + }), + ) + gotCode := emitFn(new(wasmtime.Caller), 0, int32(len(respBytes)), 0, 0) + assert.Equal(t, ErrnoSuccess, gotCode, "code mismatch") + }) + + t.Run("failure to emit writes error to memory", func(t *testing.T) { + respBytes, err := proto.Marshal(&wasmpb.EmitMessageResponse{ + Error: &wasmpb.Error{ + Message: assert.AnError.Error(), + }, + }) + assert.NoError(t, err) + + emitFn := createEmitFn( + logger.Test(t), + newMockMessageEmitter(func(_ string, _ map[string]string) error { + return assert.AnError + }), + unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + b, err := proto.Marshal(&wasmpb.EmitMessageRequest{}) + assert.NoError(t, err) + return b, nil + }), + unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + assert.Equal(t, respBytes, src, "marshalled response not equal to bytes to write") + return 0 + }), + unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + assert.Equal(t, uint32(len(respBytes)), val, "did not write length of response") + return 0 + }), + ) + gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + assert.Equal(t, ErrnoSuccess, gotCode) + }) + + t.Run("bad read failure to unmarshal protos", func(t *testing.T) { + badData := []byte("not proto bufs") + msg := &wasmpb.EmitMessageRequest{} + marshallErr := proto.Unmarshal(badData, msg) + assert.Error(t, marshallErr) + + respBytes, err := proto.Marshal(&wasmpb.EmitMessageResponse{ + Error: &wasmpb.Error{ + Message: marshallErr.Error(), + }, + }) + assert.NoError(t, err) + + emitFn := createEmitFn( + logger.Test(t), + nil, + unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + return badData, nil + }), + unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + assert.Equal(t, respBytes, src, "marshalled response not equal to bytes to write") + return 0 + }), + unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + assert.Equal(t, uint32(len(respBytes)), val, "did not write length of response") + return 0 + }), + ) + gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + assert.Equal(t, ErrnoSuccess, gotCode) + }) +} + +func Test_read(t *testing.T) { + t.Run("successfully read from slice", func(t *testing.T) { + memory := []byte("hello, world") + got, err := read(memory, 0, int32(len(memory))) + assert.NoError(t, err) + assert.Equal(t, []byte("hello, world"), got) + }) + + t.Run("fail to read because out of bounds request", func(t *testing.T) { + memory := []byte("hello, world") + _, err := read(memory, 0, int32(len(memory)+1)) + assert.Error(t, err) + }) + + t.Run("fails to read because of invalid pointer or length", func(t *testing.T) { + memory := []byte("hello, world") + _, err := read(memory, 0, -1) + assert.Error(t, err) + + _, err = read(memory, -1, 1) + assert.Error(t, err) + }) + + t.Run("validate that memory is read only once copied", func(t *testing.T) { + memory := []byte("hello, world") + copied, err := read(memory, 0, int32(len(memory))) + assert.NoError(t, err) + + // mutate copy + copied[0] = 'H' + assert.Equal(t, []byte("Hello, world"), copied) + + // original memory is unchanged + assert.Equal(t, []byte("hello, world"), memory) + }) +} + +func Test_write(t *testing.T) { + t.Run("successfully write to slice", func(t *testing.T) { + giveSrc := []byte("hello, world") + memory := make([]byte, 12) + n := write(memory, giveSrc, 0, int32(len(giveSrc))) + assert.Equal(t, n, int64(len(giveSrc))) + assert.Equal(t, []byte("hello, world"), memory[:len(giveSrc)]) + }) + + t.Run("cannot write to slice because memory too small", func(t *testing.T) { + giveSrc := []byte("hello, world") + memory := make([]byte, len(giveSrc)-1) + n := write(memory, giveSrc, 0, int32(len(giveSrc))) + assert.Equal(t, n, int64(-1)) + }) + + t.Run("fails to write to invalid access", func(t *testing.T) { + giveSrc := []byte("hello, world") + memory := make([]byte, len(giveSrc)) + n := write(memory, giveSrc, 0, -1) + assert.Equal(t, n, int64(-1)) + + n = write(memory, giveSrc, -1, 1) + assert.Equal(t, n, int64(-1)) + }) +} + +// Test_writeUInt32 tests that a uint32 is written to memory correctly. +func Test_writeUInt32(t *testing.T) { + t.Run("success", func(t *testing.T) { + memory := make([]byte, 4) + n := writeUInt32(memory, 0, 42) + wantBuf := make([]byte, 4) + binary.LittleEndian.PutUint32(wantBuf, 42) + assert.Equal(t, n, int64(4)) + assert.Equal(t, wantBuf, memory) + }) +} + +func Test_toValidatedLabels(t *testing.T) { + t.Run("success", func(t *testing.T) { + msg := &wasmpb.EmitMessageRequest{ + Labels: &pb.Map{ + Fields: map[string]*pb.Value{ + "test": { + Value: &pb.Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + } + wantLabels := map[string]string{ + "test": "value", + } + gotLabels, err := toValidatedLabels(msg) + assert.NoError(t, err) + assert.Equal(t, wantLabels, gotLabels) + }) + + t.Run("success with empty labels", func(t *testing.T) { + msg := &wasmpb.EmitMessageRequest{} + wantLabels := map[string]string{} + gotLabels, err := toValidatedLabels(msg) + assert.NoError(t, err) + assert.Equal(t, wantLabels, gotLabels) + }) + + t.Run("fails with non string", func(t *testing.T) { + msg := &wasmpb.EmitMessageRequest{ + Labels: &pb.Map{ + Fields: map[string]*pb.Value{ + "test": { + Value: &pb.Value_Int64Value{ + Int64Value: *proto.Int64(42), + }, + }, + }, + }, + } + _, err := toValidatedLabels(msg) + assert.Error(t, err) + }) +} + +func Test_toEmissible(t *testing.T) { + t.Run("success", func(t *testing.T) { + msg := &wasmpb.EmitMessageRequest{ + Message: "hello, world", + Labels: &pb.Map{ + Fields: map[string]*pb.Value{ + "test": { + Value: &pb.Value_StringValue{ + StringValue: "value", + }, + }, + }, + }, + } + + b, err := proto.Marshal(msg) + assert.NoError(t, err) + + gotMsg, gotLabels, err := toEmissible(b) + assert.NoError(t, err) + assert.Equal(t, "hello, world", gotMsg) + assert.Equal(t, map[string]string{"test": "value"}, gotLabels) + }) + + t.Run("fails with bad message", func(t *testing.T) { + _, _, err := toEmissible([]byte("not proto bufs")) + assert.Error(t, err) + }) +} diff --git a/pkg/workflows/wasm/host/test/emit/cmd/main.go b/pkg/workflows/wasm/host/test/emit/cmd/main.go new file mode 100644 index 000000000..712b56e59 --- /dev/null +++ b/pkg/workflows/wasm/host/test/emit/cmd/main.go @@ -0,0 +1,40 @@ +//go:build wasip1 + +package main + +import ( + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/cli/cmd/testdata/fixtures/capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" +) + +func BuildWorkflow(config []byte) *sdk.WorkflowSpecFactory { + workflow := sdk.NewWorkflowSpecFactory( + sdk.NewWorkflowParams{}, + ) + + triggerCfg := basictrigger.TriggerConfig{Name: "trigger", Number: 100} + trigger := triggerCfg.New(workflow) + + sdk.Compute1[basictrigger.TriggerOutputs, bool]( + workflow, + "transform", + sdk.Compute1Inputs[basictrigger.TriggerOutputs]{Arg0: trigger}, + func(rsdk sdk.Runtime, outputs basictrigger.TriggerOutputs) (bool, error) { + if err := rsdk.Emitter(). + With("test-string-field-key", "this is a test field content"). + Emit("testing emit"); err != nil { + return false, err + } + return true, nil + }) + + return workflow +} + +func main() { + runner := wasm.NewRunner() + workflow := BuildWorkflow(runner.Config()) + runner.Run(workflow) +} diff --git a/pkg/workflows/wasm/host/wasip1.go b/pkg/workflows/wasm/host/wasip1.go index 28950a16d..08235e23e 100644 --- a/pkg/workflows/wasm/host/wasip1.go +++ b/pkg/workflows/wasm/host/wasip1.go @@ -81,7 +81,7 @@ func clockTimeGet(caller *wasmtime.Caller, id int32, precision int64, resultTime uint64Size := int32(8) trg := make([]byte, uint64Size) binary.LittleEndian.PutUint64(trg, uint64(val)) - copyBuffer(caller, trg, resultTimestamp, uint64Size) + wasmWrite(caller, trg, resultTimestamp, uint64Size) return ErrnoSuccess } @@ -105,7 +105,7 @@ func pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, return ErrnoInval } - subs, err := safeMem(caller, subscriptionptr, nsubscriptions*subscriptionLen) + subs, err := wasmRead(caller, subscriptionptr, nsubscriptions*subscriptionLen) if err != nil { return ErrnoFault } @@ -176,13 +176,13 @@ func pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, binary.LittleEndian.PutUint32(rne, uint32(nsubscriptions)) // Write the number of events to `resultNevents` - size := copyBuffer(caller, rne, resultNevents, uint32Size) + size := wasmWrite(caller, rne, resultNevents, uint32Size) if size == -1 { return ErrnoFault } // Write the events to `events` - size = copyBuffer(caller, events, eventsptr, nsubscriptions*eventsLen) + size = wasmWrite(caller, events, eventsptr, nsubscriptions*eventsLen) if size == -1 { return ErrnoFault } @@ -221,7 +221,7 @@ func createRandomGet(cfg *ModuleConfig) func(caller *wasmtime.Caller, buf, bufLe } // Copy the random bytes into the wasm module memory - if n := copyBuffer(caller, randOutput, buf, bufLen); n != int64(len(randOutput)) { + if n := wasmWrite(caller, randOutput, buf, bufLen); n != int64(len(randOutput)) { return ErrnoFault } diff --git a/pkg/workflows/wasm/host/wasm_test.go b/pkg/workflows/wasm/host/wasm_test.go index 4692ef96d..f2fbcd7b2 100644 --- a/pkg/workflows/wasm/host/wasm_test.go +++ b/pkg/workflows/wasm/host/wasm_test.go @@ -49,6 +49,8 @@ const ( fetchBinaryCmd = "test/fetch/cmd" randBinaryLocation = "test/rand/cmd/testmodule.wasm" randBinaryCmd = "test/rand/cmd" + emitBinaryLocation = "test/emit/cmd/testmodule.wasm" + emitBinaryCmd = "test/emit/cmd" ) func createTestBinary(outputPath, path string, compress bool, t *testing.T) []byte { @@ -187,6 +189,133 @@ func Test_Compute_Logs(t *testing.T) { } } +func Test_Compute_Emit(t *testing.T) { + binary := createTestBinary(emitBinaryCmd, emitBinaryLocation, true, t) + + lggr := logger.Test(t) + + req := &wasmpb.Request{ + Id: uuid.New().String(), + Message: &wasmpb.Request_ComputeRequest{ + ComputeRequest: &wasmpb.ComputeRequest{ + Request: &capabilitiespb.CapabilityRequest{ + Inputs: &valuespb.Map{}, + Config: &valuespb.Map{}, + Metadata: &capabilitiespb.RequestMetadata{ + ReferenceId: "transform", + WorkflowId: "workflow-id", + WorkflowName: "workflow-name", + WorkflowOwner: "workflow-owner", + WorkflowExecutionId: "workflow-execution-id", + }, + }, + }, + }, + } + + fetchFunc := func(req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) { + return nil, nil + } + + t.Run("successfully call emit with metadata in labels", func(t *testing.T) { + m, err := NewModule(&ModuleConfig{ + Logger: lggr, + Fetch: fetchFunc, + Labeler: newMockMessageEmitter(func(msg string, kvs map[string]string) error { + t.Helper() + + assert.Equal(t, "testing emit", msg) + assert.Equal(t, "this is a test field content", kvs["test-string-field-key"]) + assert.Equal(t, "workflow-id", kvs["workflow_id"]) + assert.Equal(t, "workflow-name", kvs["workflow_name"]) + assert.Equal(t, "workflow-owner", kvs["workflow_owner_address"]) + assert.Equal(t, "workflow-execution-id", kvs["workflow_execution_id"]) + return nil + }), + }, binary) + require.NoError(t, err) + + m.Start() + + _, err = m.Run(req) + assert.Nil(t, err) + }) + + t.Run("failure on emit writes to error chain and logs", func(t *testing.T) { + lggr, logs := logger.TestObserved(t, zapcore.InfoLevel) + + m, err := NewModule(&ModuleConfig{ + Logger: lggr, + Fetch: fetchFunc, + Labeler: newMockMessageEmitter(func(msg string, kvs map[string]string) error { + t.Helper() + + assert.Equal(t, "testing emit", msg) + assert.Equal(t, "this is a test field content", kvs["test-string-field-key"]) + assert.Equal(t, "workflow-id", kvs["workflow_id"]) + assert.Equal(t, "workflow-name", kvs["workflow_name"]) + assert.Equal(t, "workflow-owner", kvs["workflow_owner_address"]) + assert.Equal(t, "workflow-execution-id", kvs["workflow_execution_id"]) + + return assert.AnError + }), + }, binary) + require.NoError(t, err) + + m.Start() + + _, err = m.Run(req) + assert.Error(t, err) + assert.ErrorContains(t, err, assert.AnError.Error()) + + require.Len(t, logs.AllUntimed(), 1) + + expectedEntries := []Entry{ + { + Log: zapcore.Entry{Level: zapcore.ErrorLevel, Message: fmt.Sprintf("error emitting message: %s", assert.AnError)}, + }, + } + for i := range expectedEntries { + assert.Equal(t, expectedEntries[i].Log.Level, logs.AllUntimed()[i].Entry.Level) + assert.Equal(t, expectedEntries[i].Log.Message, logs.AllUntimed()[i].Entry.Message) + } + }) + + t.Run("failure on emit due to missing workflow identifying metadata", func(t *testing.T) { + lggr := logger.Test(t) + + m, err := NewModule(&ModuleConfig{ + Logger: lggr, + Fetch: fetchFunc, + Labeler: newMockMessageEmitter(func(msg string, labels map[string]string) error { + return nil + }), // never called + }, binary) + require.NoError(t, err) + + m.Start() + + req = &wasmpb.Request{ + Id: uuid.New().String(), + Message: &wasmpb.Request_ComputeRequest{ + ComputeRequest: &wasmpb.ComputeRequest{ + Request: &capabilitiespb.CapabilityRequest{ + Inputs: &valuespb.Map{}, + Config: &valuespb.Map{}, + Metadata: &capabilitiespb.RequestMetadata{ + ReferenceId: "transform", + }, + }, + }, + }, + } + + _, err = m.Run(req) + assert.Error(t, err) + assert.ErrorContains(t, err, "failed to create emission") + }) +} + func Test_Compute_Fetch(t *testing.T) { binary := createTestBinary(fetchBinaryCmd, fetchBinaryLocation, true, t) diff --git a/pkg/workflows/wasm/pb/wasm.pb.go b/pkg/workflows/wasm/pb/wasm.pb.go index 95a8839a0..22255dc8f 100644 --- a/pkg/workflows/wasm/pb/wasm.pb.go +++ b/pkg/workflows/wasm/pb/wasm.pb.go @@ -671,11 +671,12 @@ type FetchResponse struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ExecutionError bool `protobuf:"varint,1,opt,name=executionError,proto3" json:"executionError,omitempty"` - ErrorMessage string `protobuf:"bytes,2,opt,name=errorMessage,proto3" json:"errorMessage,omitempty"` - StatusCode uint32 `protobuf:"varint,3,opt,name=statusCode,proto3" json:"statusCode,omitempty"` // NOTE: this is actually a uint8, but proto doesn't support this. - Headers *pb1.Map `protobuf:"bytes,4,opt,name=headers,proto3" json:"headers,omitempty"` - Body []byte `protobuf:"bytes,5,opt,name=body,proto3" json:"body,omitempty"` + ExecutionError bool `protobuf:"varint,1,opt,name=executionError,proto3" json:"executionError,omitempty"` + ErrorMessage string `protobuf:"bytes,2,opt,name=errorMessage,proto3" json:"errorMessage,omitempty"` + // NOTE: this is actually a uint8, but proto doesn't support this. + StatusCode uint32 `protobuf:"varint,3,opt,name=statusCode,proto3" json:"statusCode,omitempty"` + Headers *pb1.Map `protobuf:"bytes,4,opt,name=headers,proto3" json:"headers,omitempty"` + Body []byte `protobuf:"bytes,5,opt,name=body,proto3" json:"body,omitempty"` } func (x *FetchResponse) Reset() { @@ -745,6 +746,155 @@ func (x *FetchResponse) GetBody() []byte { return nil } +type EmitMessageRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + Labels *pb1.Map `protobuf:"bytes,2,opt,name=labels,proto3" json:"labels,omitempty"` +} + +func (x *EmitMessageRequest) Reset() { + *x = EmitMessageRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_workflows_wasm_pb_wasm_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EmitMessageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmitMessageRequest) ProtoMessage() {} + +func (x *EmitMessageRequest) ProtoReflect() protoreflect.Message { + mi := &file_workflows_wasm_pb_wasm_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmitMessageRequest.ProtoReflect.Descriptor instead. +func (*EmitMessageRequest) Descriptor() ([]byte, []int) { + return file_workflows_wasm_pb_wasm_proto_rawDescGZIP(), []int{10} +} + +func (x *EmitMessageRequest) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *EmitMessageRequest) GetLabels() *pb1.Map { + if x != nil { + return x.Labels + } + return nil +} + +type Error struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *Error) Reset() { + *x = Error{} + if protoimpl.UnsafeEnabled { + mi := &file_workflows_wasm_pb_wasm_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Error) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Error) ProtoMessage() {} + +func (x *Error) ProtoReflect() protoreflect.Message { + mi := &file_workflows_wasm_pb_wasm_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Error.ProtoReflect.Descriptor instead. +func (*Error) Descriptor() ([]byte, []int) { + return file_workflows_wasm_pb_wasm_proto_rawDescGZIP(), []int{11} +} + +func (x *Error) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type EmitMessageResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Error *Error `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"` +} + +func (x *EmitMessageResponse) Reset() { + *x = EmitMessageResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_workflows_wasm_pb_wasm_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EmitMessageResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmitMessageResponse) ProtoMessage() {} + +func (x *EmitMessageResponse) ProtoReflect() protoreflect.Message { + mi := &file_workflows_wasm_pb_wasm_proto_msgTypes[12] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmitMessageResponse.ProtoReflect.Descriptor instead. +func (*EmitMessageResponse) Descriptor() ([]byte, []int) { + return file_workflows_wasm_pb_wasm_proto_rawDescGZIP(), []int{12} +} + +func (x *EmitMessageResponse) GetError() *Error { + if x != nil { + return x.Error + } + return nil +} + var File_workflows_wasm_pb_wasm_proto protoreflect.FileDescriptor var file_workflows_wasm_pb_wasm_proto_rawDesc = []byte{ @@ -850,11 +1000,22 @@ var file_workflows_wasm_pb_wasm_proto_rawDesc = []byte{ 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x2e, 0x4d, 0x61, 0x70, 0x52, 0x07, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64, - 0x79, 0x42, 0x43, 0x5a, 0x41, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, - 0x73, 0x6d, 0x61, 0x72, 0x74, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x61, 0x63, 0x74, 0x6b, 0x69, 0x74, - 0x2f, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x6c, 0x69, 0x6e, 0x6b, 0x2d, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, - 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x73, 0x2f, - 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x79, 0x22, 0x53, 0x0a, 0x12, 0x45, 0x6d, 0x69, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x12, 0x23, 0x0a, 0x06, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x0b, 0x2e, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x2e, 0x4d, 0x61, 0x70, 0x52, 0x06, + 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, + 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x37, 0x0a, 0x13, 0x45, 0x6d, 0x69, + 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x20, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x0a, 0x2e, 0x73, 0x64, 0x6b, 0x2e, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x52, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x42, 0x43, 0x5a, 0x41, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, + 0x2f, 0x73, 0x6d, 0x61, 0x72, 0x74, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x61, 0x63, 0x74, 0x6b, 0x69, + 0x74, 0x2f, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x6c, 0x69, 0x6e, 0x6b, 0x2d, 0x63, 0x6f, 0x6d, 0x6d, + 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x77, 0x6f, 0x72, 0x6b, 0x66, 0x6c, 0x6f, 0x77, 0x73, + 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -869,7 +1030,7 @@ func file_workflows_wasm_pb_wasm_proto_rawDescGZIP() []byte { return file_workflows_wasm_pb_wasm_proto_rawDescData } -var file_workflows_wasm_pb_wasm_proto_msgTypes = make([]protoimpl.MessageInfo, 10) +var file_workflows_wasm_pb_wasm_proto_msgTypes = make([]protoimpl.MessageInfo, 13) var file_workflows_wasm_pb_wasm_proto_goTypes = []interface{}{ (*RuntimeConfig)(nil), // 0: sdk.RuntimeConfig (*ComputeRequest)(nil), // 1: sdk.ComputeRequest @@ -881,33 +1042,38 @@ var file_workflows_wasm_pb_wasm_proto_goTypes = []interface{}{ (*Response)(nil), // 7: sdk.Response (*FetchRequest)(nil), // 8: sdk.FetchRequest (*FetchResponse)(nil), // 9: sdk.FetchResponse - (*pb.CapabilityRequest)(nil), // 10: capabilities.CapabilityRequest - (*emptypb.Empty)(nil), // 11: google.protobuf.Empty - (*pb.CapabilityResponse)(nil), // 12: capabilities.CapabilityResponse - (*pb1.Map)(nil), // 13: values.Map + (*EmitMessageRequest)(nil), // 10: sdk.EmitMessageRequest + (*Error)(nil), // 11: sdk.Error + (*EmitMessageResponse)(nil), // 12: sdk.EmitMessageResponse + (*pb.CapabilityRequest)(nil), // 13: capabilities.CapabilityRequest + (*emptypb.Empty)(nil), // 14: google.protobuf.Empty + (*pb.CapabilityResponse)(nil), // 15: capabilities.CapabilityResponse + (*pb1.Map)(nil), // 16: values.Map } var file_workflows_wasm_pb_wasm_proto_depIdxs = []int32{ - 10, // 0: sdk.ComputeRequest.request:type_name -> capabilities.CapabilityRequest + 13, // 0: sdk.ComputeRequest.request:type_name -> capabilities.CapabilityRequest 0, // 1: sdk.ComputeRequest.runtimeConfig:type_name -> sdk.RuntimeConfig 1, // 2: sdk.Request.computeRequest:type_name -> sdk.ComputeRequest - 11, // 3: sdk.Request.specRequest:type_name -> google.protobuf.Empty - 12, // 4: sdk.ComputeResponse.response:type_name -> capabilities.CapabilityResponse - 13, // 5: sdk.StepInputs.mapping:type_name -> values.Map + 14, // 3: sdk.Request.specRequest:type_name -> google.protobuf.Empty + 15, // 4: sdk.ComputeResponse.response:type_name -> capabilities.CapabilityResponse + 16, // 5: sdk.StepInputs.mapping:type_name -> values.Map 4, // 6: sdk.StepDefinition.inputs:type_name -> sdk.StepInputs - 13, // 7: sdk.StepDefinition.config:type_name -> values.Map + 16, // 7: sdk.StepDefinition.config:type_name -> values.Map 5, // 8: sdk.WorkflowSpec.triggers:type_name -> sdk.StepDefinition 5, // 9: sdk.WorkflowSpec.actions:type_name -> sdk.StepDefinition 5, // 10: sdk.WorkflowSpec.consensus:type_name -> sdk.StepDefinition 5, // 11: sdk.WorkflowSpec.targets:type_name -> sdk.StepDefinition 3, // 12: sdk.Response.computeResponse:type_name -> sdk.ComputeResponse 6, // 13: sdk.Response.specResponse:type_name -> sdk.WorkflowSpec - 13, // 14: sdk.FetchRequest.headers:type_name -> values.Map - 13, // 15: sdk.FetchResponse.headers:type_name -> values.Map - 16, // [16:16] is the sub-list for method output_type - 16, // [16:16] is the sub-list for method input_type - 16, // [16:16] is the sub-list for extension type_name - 16, // [16:16] is the sub-list for extension extendee - 0, // [0:16] is the sub-list for field type_name + 16, // 14: sdk.FetchRequest.headers:type_name -> values.Map + 16, // 15: sdk.FetchResponse.headers:type_name -> values.Map + 16, // 16: sdk.EmitMessageRequest.labels:type_name -> values.Map + 11, // 17: sdk.EmitMessageResponse.error:type_name -> sdk.Error + 18, // [18:18] is the sub-list for method output_type + 18, // [18:18] is the sub-list for method input_type + 18, // [18:18] is the sub-list for extension type_name + 18, // [18:18] is the sub-list for extension extendee + 0, // [0:18] is the sub-list for field type_name } func init() { file_workflows_wasm_pb_wasm_proto_init() } @@ -1036,6 +1202,42 @@ func file_workflows_wasm_pb_wasm_proto_init() { return nil } } + file_workflows_wasm_pb_wasm_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EmitMessageRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_workflows_wasm_pb_wasm_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Error); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_workflows_wasm_pb_wasm_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EmitMessageResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_workflows_wasm_pb_wasm_proto_msgTypes[2].OneofWrappers = []interface{}{ (*Request_ComputeRequest)(nil), @@ -1051,7 +1253,7 @@ func file_workflows_wasm_pb_wasm_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_workflows_wasm_pb_wasm_proto_rawDesc, NumEnums: 0, - NumMessages: 10, + NumMessages: 13, NumExtensions: 0, NumServices: 0, }, diff --git a/pkg/workflows/wasm/pb/wasm.proto b/pkg/workflows/wasm/pb/wasm.proto index 180b2cd12..a838a4f98 100644 --- a/pkg/workflows/wasm/pb/wasm.proto +++ b/pkg/workflows/wasm/pb/wasm.proto @@ -8,9 +8,7 @@ import "capabilities/pb/capabilities.proto"; import "values/pb/values.proto"; import "google/protobuf/empty.proto"; -message RuntimeConfig { - int64 maxFetchResponseSizeBytes = 1; -} +message RuntimeConfig { int64 maxFetchResponseSizeBytes = 1; } message ComputeRequest { capabilities.CapabilityRequest request = 1; @@ -27,9 +25,7 @@ message Request { } } -message ComputeResponse { - capabilities.CapabilityResponse response = 1; -} +message ComputeResponse { capabilities.CapabilityResponse response = 1; } message StepInputs { string outputRef = 1; @@ -74,7 +70,18 @@ message FetchRequest { message FetchResponse { bool executionError = 1; string errorMessage = 2; - uint32 statusCode = 3; // NOTE: this is actually a uint8, but proto doesn't support this. + + // NOTE: this is actually a uint8, but proto doesn't support this. + uint32 statusCode = 3; values.Map headers = 4; bytes body = 5; } + +message EmitMessageRequest { + string message = 1; + values.Map labels = 2; +} + +message Error { string message = 1; } + +message EmitMessageResponse { Error error = 1; } diff --git a/pkg/workflows/wasm/runner.go b/pkg/workflows/wasm/runner.go index 1372117fa..0d8ab006e 100644 --- a/pkg/workflows/wasm/runner.go +++ b/pkg/workflows/wasm/runner.go @@ -26,7 +26,7 @@ var _ sdk.Runner = (*Runner)(nil) type Runner struct { sendResponse func(payload *wasmpb.Response) - sdkFactory func(cfg *RuntimeConfig) *Runtime + sdkFactory func(cfg *RuntimeConfig, opts ...func(*RuntimeConfig)) *Runtime args []string req *wasmpb.Request } @@ -156,7 +156,7 @@ func (r *Runner) handleComputeRequest(factory *sdk.WorkflowSpecFactory, id strin } // Extract the config from the request - drc := defaultRuntimeConfig() + drc := defaultRuntimeConfig(id, &creq.Metadata) if rc := computeReq.GetRuntimeConfig(); rc != nil { if rc.MaxFetchResponseSizeBytes != 0 { drc.MaxFetchResponseSizeBytes = rc.MaxFetchResponseSizeBytes diff --git a/pkg/workflows/wasm/runner_test.go b/pkg/workflows/wasm/runner_test.go index c8f3eda0a..569b11b5d 100644 --- a/pkg/workflows/wasm/runner_test.go +++ b/pkg/workflows/wasm/runner_test.go @@ -2,7 +2,9 @@ package wasm import ( "encoding/base64" + "encoding/binary" "testing" + "unsafe" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -14,6 +16,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/cli/cmd/testdata/fixtures/capabilities/basictrigger" capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" + "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" @@ -132,7 +135,9 @@ func TestRunner_Run_ExecuteCompute(t *testing.T) { runner := &Runner{ args: []string{"wasm", str}, sendResponse: responseFn, - sdkFactory: func(cfg *RuntimeConfig) *Runtime { return nil }, + sdkFactory: func(cfg *RuntimeConfig, _ ...func(*RuntimeConfig)) *Runtime { + return nil + }, } runner.Run(workflow) @@ -202,3 +207,91 @@ func TestRunner_Run_GetWorkflowSpec(t *testing.T) { gotSpec.Triggers[0].Config["number"] = int64(gotSpec.Triggers[0].Config["number"].(uint64)) assert.Equal(t, &gotSpec, spc) } + +// Test_createEmitFn validates the runtime's emit function implementation. Uses mocks of the +// imported wasip1 emit function. +func Test_createEmitFn(t *testing.T) { + var ( + l = logger.Test(t) + sdkConfig = &RuntimeConfig{ + MaxFetchResponseSizeBytes: 1_000, + Metadata: &capabilities.RequestMetadata{ + WorkflowID: "workflow_id", + WorkflowExecutionID: "workflow_execution_id", + WorkflowName: "workflow_name", + WorkflowOwner: "workflow_owner_address", + }, + } + giveMsg = "testing guest" + giveLabels = map[string]string{ + "some-key": "some-value", + } + ) + + t.Run("success", func(t *testing.T) { + hostEmit := func(respptr, resplenptr, reqptr unsafe.Pointer, reqptrlen int32) int32 { + return 0 + } + runtimeEmit := createEmitFn(sdkConfig, l, hostEmit) + err := runtimeEmit(giveMsg, giveLabels) + assert.NoError(t, err) + }) + + t.Run("successfully read error message when emit fails", func(t *testing.T) { + hostEmit := func(respptr, resplenptr, reqptr unsafe.Pointer, reqptrlen int32) int32 { + // marshall the protobufs + b, err := proto.Marshal(&wasmpb.EmitMessageResponse{ + Error: &wasmpb.Error{ + Message: assert.AnError.Error(), + }, + }) + assert.NoError(t, err) + + // write the marshalled response message to memory + resp := unsafe.Slice((*byte)(respptr), len(b)) + copy(resp, b) + + // write the length of the response to memory in little endian + respLen := unsafe.Slice((*byte)(resplenptr), uint32Size) + binary.LittleEndian.PutUint32(respLen, uint32(len(b))) + + return 0 + } + runtimeEmit := createEmitFn(sdkConfig, l, hostEmit) + err := runtimeEmit(giveMsg, giveLabels) + assert.Error(t, err) + assert.ErrorContains(t, err, assert.AnError.Error()) + }) + + t.Run("fail to deserialize response from memory", func(t *testing.T) { + hostEmit := func(respptr, resplenptr, reqptr unsafe.Pointer, reqptrlen int32) int32 { + // b is a non-protobuf byte slice + b := []byte(assert.AnError.Error()) + + // write the marshalled response message to memory + resp := unsafe.Slice((*byte)(respptr), len(b)) + copy(resp, b) + + // write the length of the response to memory in little endian + respLen := unsafe.Slice((*byte)(resplenptr), uint32Size) + binary.LittleEndian.PutUint32(respLen, uint32(len(b))) + + return 0 + } + + runtimeEmit := createEmitFn(sdkConfig, l, hostEmit) + err := runtimeEmit(giveMsg, giveLabels) + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid wire-format data") + }) + + t.Run("fail with nonzero code from emit", func(t *testing.T) { + hostEmit := func(respptr, resplenptr, reqptr unsafe.Pointer, reqptrlen int32) int32 { + return 42 + } + runtimeEmit := createEmitFn(sdkConfig, l, hostEmit) + err := runtimeEmit(giveMsg, giveLabels) + assert.Error(t, err) + assert.ErrorContains(t, err, "emit failed with errno 42") + }) +} diff --git a/pkg/workflows/wasm/runner_wasip1.go b/pkg/workflows/wasm/runner_wasip1.go index 6a85a43db..6b24ed0bc 100644 --- a/pkg/workflows/wasm/runner_wasip1.go +++ b/pkg/workflows/wasm/runner_wasip1.go @@ -24,103 +24,118 @@ func log(respptr unsafe.Pointer, respptrlen int32) //go:wasmimport env fetch func fetch(respptr unsafe.Pointer, resplenptr unsafe.Pointer, reqptr unsafe.Pointer, reqptrlen int32) int32 -const uint32Size = int32(4) - -func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32) { - return unsafe.Pointer(&buf[0]), int32(len(buf)) -} +//go:wasmimport env emit +func emit(respptr unsafe.Pointer, resplenptr unsafe.Pointer, reqptr unsafe.Pointer, reqptrlen int32) int32 func NewRunner() *Runner { l := logger.NewWithSync(&wasmWriteSyncer{}) return &Runner{ - sendResponse: func(response *wasmpb.Response) { - pb, err := proto.Marshal(response) - if err != nil { - // We somehow couldn't marshal the response, so let's - // exit with a special error code letting the host know - // what happened. - os.Exit(CodeInvalidResponse) - } - - // unknownID will only be set when we've failed to parse - // the request. Like before, let's bubble this up. - if response.Id == unknownID { - os.Exit(CodeInvalidRequest) - } - - ptr, ptrlen := bufferToPointerLen(pb) - errno := sendResponse(ptr, ptrlen) - if errno != 0 { - os.Exit(CodeHostErr) + sendResponse: sendResponseFn, + sdkFactory: func(sdkConfig *RuntimeConfig, opts ...func(*RuntimeConfig)) *Runtime { + for _, opt := range opts { + opt(sdkConfig) } - code := CodeSuccess - if response.ErrMsg != "" { - code = CodeRunnerErr - } - - os.Exit(code) - }, - sdkFactory: func(sdkConfig *RuntimeConfig) *Runtime { return &Runtime{ - logger: l, - fetchFn: func(req sdk.FetchRequest) (sdk.FetchResponse, error) { - headerspb, err := values.NewMap(req.Headers) - if err != nil { - return sdk.FetchResponse{}, fmt.Errorf("failed to create headers map: %w", err) - } - - b, err := proto.Marshal(&wasmpb.FetchRequest{ - Url: req.URL, - Method: req.Method, - Headers: values.ProtoMap(headerspb), - Body: req.Body, - TimeoutMs: req.TimeoutMs, - }) - if err != nil { - return sdk.FetchResponse{}, fmt.Errorf("failed to marshal fetch request: %w", err) - } - reqptr, reqptrlen := bufferToPointerLen(b) - - respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes) - respptr, _ := bufferToPointerLen(respBuffer) - - resplenBuffer := make([]byte, uint32Size) - resplenptr, _ := bufferToPointerLen(resplenBuffer) - - errno := fetch(respptr, resplenptr, reqptr, reqptrlen) - if errno != 0 { - return sdk.FetchResponse{}, errors.New("failed to execute fetch") - } - - responseSize := binary.LittleEndian.Uint32(resplenBuffer) - response := &wasmpb.FetchResponse{} - err = proto.Unmarshal(respBuffer[:responseSize], response) - if err != nil { - return sdk.FetchResponse{}, fmt.Errorf("failed to unmarshal fetch response: %w", err) - } - - fields := response.Headers.GetFields() - headersResp := make(map[string]any, len(fields)) - for k, v := range fields { - headersResp[k] = v - } - - return sdk.FetchResponse{ - ExecutionError: response.ExecutionError, - ErrorMessage: response.ErrorMessage, - StatusCode: uint8(response.StatusCode), - Headers: headersResp, - Body: response.Body, - }, nil - }, + logger: l, + fetchFn: createFetchFn(sdkConfig, l), + emitFn: createEmitFn(sdkConfig, l, emit), } }, args: os.Args, } } +// sendResponseFn implements sendResponse for import into WASM. +func sendResponseFn(response *wasmpb.Response) { + pb, err := proto.Marshal(response) + if err != nil { + // We somehow couldn't marshal the response, so let's + // exit with a special error code letting the host know + // what happened. + os.Exit(CodeInvalidResponse) + } + + // unknownID will only be set when we've failed to parse + // the request. Like before, let's bubble this up. + if response.Id == unknownID { + os.Exit(CodeInvalidRequest) + } + + ptr, ptrlen := bufferToPointerLen(pb) + errno := sendResponse(ptr, ptrlen) + if errno != 0 { + os.Exit(CodeHostErr) + } + + code := CodeSuccess + if response.ErrMsg != "" { + code = CodeRunnerErr + } + + os.Exit(code) +} + +// createFetchFn injects dependencies and creates a fetch function that can be used by the WASM +// binary. +func createFetchFn( + sdkConfig *RuntimeConfig, + l logger.Logger, +) func(sdk.FetchRequest) (sdk.FetchResponse, error) { + fetchFn := func(req sdk.FetchRequest) (sdk.FetchResponse, error) { + headerspb, err := values.NewMap(req.Headers) + if err != nil { + return sdk.FetchResponse{}, fmt.Errorf("failed to create headers map: %w", err) + } + + b, err := proto.Marshal(&wasmpb.FetchRequest{ + Url: req.URL, + Method: req.Method, + Headers: values.ProtoMap(headerspb), + Body: req.Body, + TimeoutMs: req.TimeoutMs, + }) + if err != nil { + return sdk.FetchResponse{}, fmt.Errorf("failed to marshal fetch request: %w", err) + } + reqptr, reqptrlen := bufferToPointerLen(b) + + respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes) + respptr, _ := bufferToPointerLen(respBuffer) + + resplenBuffer := make([]byte, uint32Size) + resplenptr, _ := bufferToPointerLen(resplenBuffer) + + errno := fetch(respptr, resplenptr, reqptr, reqptrlen) + if errno != 0 { + return sdk.FetchResponse{}, errors.New("failed to execute fetch") + } + + responseSize := binary.LittleEndian.Uint32(resplenBuffer) + response := &wasmpb.FetchResponse{} + err = proto.Unmarshal(respBuffer[:responseSize], response) + if err != nil { + return sdk.FetchResponse{}, fmt.Errorf("failed to unmarshal fetch response: %w", err) + } + + fields := response.Headers.GetFields() + headersResp := make(map[string]any, len(fields)) + for k, v := range fields { + headersResp[k] = v + } + + return sdk.FetchResponse{ + ExecutionError: response.ExecutionError, + ErrorMessage: response.ErrorMessage, + StatusCode: uint8(response.StatusCode), + Headers: headersResp, + Body: response.Body, + }, nil + } + return fetchFn +} + type wasmWriteSyncer struct{} // Write is used to proxy log requests from the WASM binary back to the host diff --git a/pkg/workflows/wasm/sdk.go b/pkg/workflows/wasm/sdk.go index d6c29a009..e45e0ff04 100644 --- a/pkg/workflows/wasm/sdk.go +++ b/pkg/workflows/wasm/sdk.go @@ -1,26 +1,47 @@ package wasm import ( + "encoding/binary" + "errors" + "fmt" + "unsafe" + + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/events" + "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" + wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" ) +// Length of responses are encoded into 4 byte buffers in little endian. uint32Size is the size +// of that buffer. +const uint32Size = int32(4) + type Runtime struct { fetchFn func(req sdk.FetchRequest) (sdk.FetchResponse, error) + emitFn func(msg string, labels map[string]string) error logger logger.Logger } type RuntimeConfig struct { MaxFetchResponseSizeBytes int64 + RequestID *string + Metadata *capabilities.RequestMetadata } const ( defaultMaxFetchResponseSizeBytes = 5 * 1024 ) -func defaultRuntimeConfig() *RuntimeConfig { +func defaultRuntimeConfig(id string, md *capabilities.RequestMetadata) *RuntimeConfig { return &RuntimeConfig{ MaxFetchResponseSizeBytes: defaultMaxFetchResponseSizeBytes, + RequestID: &id, + Metadata: md, } } @@ -33,3 +54,139 @@ func (r *Runtime) Fetch(req sdk.FetchRequest) (sdk.FetchResponse, error) { func (r *Runtime) Logger() logger.Logger { return r.logger } + +func (r *Runtime) Emitter() sdk.MessageEmitter { + return newWasmGuestEmitter(r.emitFn) +} + +type wasmGuestEmitter struct { + base custmsg.Labeler + emitFn func(string, map[string]string) error + labels map[string]string +} + +func newWasmGuestEmitter(emitFn func(string, map[string]string) error) wasmGuestEmitter { + return wasmGuestEmitter{ + emitFn: emitFn, + labels: make(map[string]string), + base: custmsg.NewLabeler(), + } +} + +func (w wasmGuestEmitter) Emit(msg string) error { + return w.emitFn(msg, w.labels) +} + +func (w wasmGuestEmitter) With(keyValues ...string) sdk.MessageEmitter { + newEmitter := newWasmGuestEmitter(w.emitFn) + newEmitter.base = w.base.With(keyValues...) + newEmitter.labels = newEmitter.base.Labels() + return newEmitter +} + +// createEmitFn builds the runtime's emit function implementation, which is a function +// that handles marshalling and unmarshalling messages for the WASM to act on. +func createEmitFn( + sdkConfig *RuntimeConfig, + l logger.Logger, + emit func(respptr unsafe.Pointer, resplenptr unsafe.Pointer, reqptr unsafe.Pointer, reqptrlen int32) int32, +) func(string, map[string]string) error { + emitFn := func(msg string, labels map[string]string) error { + // Prepare the labels to be emitted + if sdkConfig.Metadata == nil { + return NewEmissionError(fmt.Errorf("metadata is required to emit")) + } + + labels, err := toEmitLabels(sdkConfig.Metadata, labels) + if err != nil { + return NewEmissionError(err) + } + + vm, err := values.NewMap(labels) + if err != nil { + return NewEmissionError(fmt.Errorf("could not wrap labels to map: %w", err)) + } + + // Marshal the message and labels into a protobuf message + b, err := proto.Marshal(&wasmpb.EmitMessageRequest{ + Message: msg, + Labels: values.ProtoMap(vm), + }) + if err != nil { + return err + } + + // Prepare the request to be sent to the host memory by allocating space for the + // response and response length buffers. + respBuffer := make([]byte, sdkConfig.MaxFetchResponseSizeBytes) + respptr, _ := bufferToPointerLen(respBuffer) + + resplenBuffer := make([]byte, uint32Size) + resplenptr, _ := bufferToPointerLen(resplenBuffer) + + // The request buffer is the wasm memory, get a pointer to the first element and the length + // of the protobuf message. + reqptr, reqptrlen := bufferToPointerLen(b) + + // Emit the message via the method imported from the host + errno := emit(respptr, resplenptr, reqptr, reqptrlen) + if errno != 0 { + return NewEmissionError(fmt.Errorf("emit failed with errno %d", errno)) + } + + // Attempt to read and handle the response from the host memory + responseSize := binary.LittleEndian.Uint32(resplenBuffer) + response := &wasmpb.EmitMessageResponse{} + if err := proto.Unmarshal(respBuffer[:responseSize], response); err != nil { + l.Errorw("failed to unmarshal emit response", "error", err.Error()) + return NewEmissionError(err) + } + + if response.Error != nil && response.Error.Message != "" { + return NewEmissionError(errors.New(response.Error.Message)) + } + + return nil + } + + return emitFn +} + +// bufferToPointerLen returns a pointer to the first element of the buffer and the length of the buffer. +func bufferToPointerLen(buf []byte) (unsafe.Pointer, int32) { + return unsafe.Pointer(&buf[0]), int32(len(buf)) +} + +// toEmitLabels ensures that the required metadata is present in the labels map +func toEmitLabels(md *capabilities.RequestMetadata, labels map[string]string) (map[string]string, error) { + if md.WorkflowID == "" { + return nil, fmt.Errorf("must provide workflow id to emit event") + } + + if md.WorkflowName == "" { + return nil, fmt.Errorf("must provide workflow name to emit event") + } + + if md.WorkflowOwner == "" { + return nil, fmt.Errorf("must provide workflow owner to emit event") + } + + labels[events.LabelWorkflowExecutionID] = md.WorkflowExecutionID + labels[events.LabelWorkflowOwner] = md.WorkflowOwner + labels[events.LabelWorkflowID] = md.WorkflowID + labels[events.LabelWorkflowName] = md.WorkflowName + return labels, nil +} + +// EmissionError wraps all errors that occur during the emission process for the runtime to handle. +type EmissionError struct { + Wrapped error +} + +func NewEmissionError(err error) *EmissionError { + return &EmissionError{Wrapped: err} +} + +func (e *EmissionError) Error() string { + return fmt.Errorf("failed to create emission: %w", e.Wrapped).Error() +} diff --git a/pkg/workflows/wasm/sdk_test.go b/pkg/workflows/wasm/sdk_test.go new file mode 100644 index 000000000..312dba7c7 --- /dev/null +++ b/pkg/workflows/wasm/sdk_test.go @@ -0,0 +1,66 @@ +package wasm + +import ( + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + + "github.com/stretchr/testify/assert" +) + +func Test_toEmitLabels(t *testing.T) { + t.Run("successfully transforms metadata", func(t *testing.T) { + md := &capabilities.RequestMetadata{ + WorkflowID: "workflow-id", + WorkflowName: "workflow-name", + WorkflowOwner: "workflow-owner", + } + empty := make(map[string]string, 0) + + gotLabels, err := toEmitLabels(md, empty) + assert.NoError(t, err) + + assert.Equal(t, map[string]string{ + "workflow_id": "workflow-id", + "workflow_name": "workflow-name", + "workflow_owner_address": "workflow-owner", + "workflow_execution_id": "", + }, gotLabels) + }) + + t.Run("fails on missing workflow id", func(t *testing.T) { + md := &capabilities.RequestMetadata{ + WorkflowName: "workflow-name", + WorkflowOwner: "workflow-owner", + } + empty := make(map[string]string, 0) + + _, err := toEmitLabels(md, empty) + assert.Error(t, err) + assert.ErrorContains(t, err, "workflow id") + }) + + t.Run("fails on missing workflow name", func(t *testing.T) { + md := &capabilities.RequestMetadata{ + WorkflowID: "workflow-id", + WorkflowOwner: "workflow-owner", + } + empty := make(map[string]string, 0) + + _, err := toEmitLabels(md, empty) + assert.Error(t, err) + assert.ErrorContains(t, err, "workflow name") + }) + + t.Run("fails on missing workflow owner", func(t *testing.T) { + md := &capabilities.RequestMetadata{ + WorkflowID: "workflow-id", + WorkflowName: "workflow-name", + } + empty := make(map[string]string, 0) + + _, err := toEmitLabels(md, empty) + assert.Error(t, err) + assert.ErrorContains(t, err, "workflow owner") + }) +}