From 7dbb1b0863a38a649b7e049a89a2033ccc4588cd Mon Sep 17 00:00:00 2001 From: Gabriel Paradiso Date: Fri, 10 Jan 2025 11:10:35 +0100 Subject: [PATCH] [CRE-42] Fix partial or truncated writes (#989) * fix: check size and len(src) match to avoid partial or truncated writes * fix: return the number of bytes copied * chore: align test naming --- pkg/workflows/wasm/host/module.go | 8 +++++--- pkg/workflows/wasm/host/module_test.go | 20 +++++++++++++++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 601b69632..f8bd77d02 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -728,11 +728,13 @@ func write(memory, src []byte, ptr, size int32) int64 { return -1 } + if len(src) != int(size) { + return -1 + } + if int32(len(memory)) < ptr+size { return -1 } buffer := memory[ptr : ptr+size] - dataLen := int64(len(src)) - copy(buffer, src) - return dataLen + return int64(copy(buffer, src)) } diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index a19c43fa2..66390adeb 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -565,17 +565,31 @@ func Test_write(t *testing.T) { giveSrc := []byte("hello, world") memory := make([]byte, len(giveSrc)-1) n := write(memory, giveSrc, 0, int32(len(giveSrc))) - assert.Equal(t, n, int64(-1)) + assert.Equal(t, int64(-1), n) }) t.Run("fails to write to invalid access", func(t *testing.T) { giveSrc := []byte("hello, world") memory := make([]byte, len(giveSrc)) n := write(memory, giveSrc, 0, -1) - assert.Equal(t, n, int64(-1)) + assert.Equal(t, int64(-1), n) n = write(memory, giveSrc, -1, 1) - assert.Equal(t, n, int64(-1)) + assert.Equal(t, int64(-1), n) + }) + + t.Run("truncated write due to size being smaller than len", func(t *testing.T) { + giveSrc := []byte("hello, world") + memory := make([]byte, 12) + n := write(memory, giveSrc, 0, int32(len(giveSrc)-2)) + assert.Equal(t, int64(-1), n) + }) + + t.Run("unwanted data when size exceeds written data", func(t *testing.T) { + giveSrc := []byte("hello, world") + memory := make([]byte, 20) + n := write(memory, giveSrc, 0, 20) + assert.Equal(t, int64(-1), n) }) }