diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index f8bd77d02..48345216d 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -75,6 +75,7 @@ var ( defaultMinMemoryMBs = 128 DefaultInitialFuel = uint64(100_000_000) defaultMaxFetchRequests = 5 + defaultMaxBinarySize = 10 * 1024 * 1024 // 10 MB ) type DeterminismConfig struct { @@ -91,6 +92,7 @@ type ModuleConfig struct { IsUncompressed bool Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) MaxFetchRequests int + MaxBinarySize int64 // Labeler is used to emit messages from the module. Labeler custmsg.MessageEmitter @@ -166,6 +168,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) modCfg.MinMemoryMBs = int64(defaultMinMemoryMBs) } + if modCfg.MaxBinarySize == 0 { + modCfg.MaxBinarySize = int64(defaultMaxBinarySize) + } + // Take the max of the min and the configured max memory mbs. // We do this because Go requires a minimum of 16 megabytes to run, // and local testing has shown that with less than the min, some @@ -183,6 +189,12 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig)) engine := wasmtime.NewEngineWithConfig(cfg) if !modCfg.IsUncompressed { + // validate the binary size before decompressing + // this is to prevent decompression bombs + if int64(len(binary)) > modCfg.MaxBinarySize { + return nil, fmt.Errorf("binary size exceeds the maximum allowed size of %d bytes", modCfg.MaxBinarySize) + } + rdr := brotli.NewReader(bytes.NewBuffer(binary)) decompedBinary, err := io.ReadAll(rdr) if err != nil { diff --git a/pkg/workflows/wasm/host/wasm_test.go b/pkg/workflows/wasm/host/wasm_test.go index 3e5335a9d..6ecbc5272 100644 --- a/pkg/workflows/wasm/host/wasm_test.go +++ b/pkg/workflows/wasm/host/wasm_test.go @@ -903,6 +903,48 @@ func TestModule_Sandbox_Memory(t *testing.T) { assert.ErrorContains(t, err, "exit status 2") } +func TestModule_CompressedBinarySize(t *testing.T) { + t.Parallel() + + t.Run("compressed binary size is smaller than the default 10mb limit", func(t *testing.T) { + binary := createTestBinary(successBinaryCmd, successBinaryLocation, false, t) + + _, err := NewModule(&ModuleConfig{IsUncompressed: false, Logger: logger.Test(t)}, binary) + require.NoError(t, err) + }) + + t.Run("compressed binary size is bigger than the default 10mb limit", func(t *testing.T) { + // 11mb binary + binary := make([]byte, 11*1024*1024) + + var b bytes.Buffer + bwr := brotli.NewWriter(&b) + _, err := bwr.Write(binary) + require.NoError(t, err) + require.NoError(t, bwr.Close()) + + _, err = NewModule(&ModuleConfig{IsUncompressed: false, Logger: logger.Test(t)}, binary) + default10mbLimit := fmt.Sprintf("binary size exceeds the maximum allowed size of %d bytes", 10*1024*1024) + require.ErrorContains(t, err, default10mbLimit) + }) + + t.Run("compressed binary size is bigger than the custom limit", func(t *testing.T) { + // 2mb binary + binary := make([]byte, 2*1024*1024) + + var b bytes.Buffer + bwr := brotli.NewWriter(&b) + _, err := bwr.Write(binary) + require.NoError(t, err) + require.NoError(t, bwr.Close()) + + customMaxBinarySize := int64(1 * 1024 * 1024) + _, err = NewModule(&ModuleConfig{IsUncompressed: false, MaxBinarySize: customMaxBinarySize, Logger: logger.Test(t)}, binary) + default10mbLimit := fmt.Sprintf("binary size exceeds the maximum allowed size of %d bytes", customMaxBinarySize) + require.ErrorContains(t, err, default10mbLimit) + }) +} + func TestModule_Sandbox_SleepIsStubbedOut(t *testing.T) { t.Parallel() ctx := tests.Context(t)