Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CRE-40] Check binary size before decompression #994

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions pkg/workflows/wasm/host/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,29 @@ func (r *store) delete(id string) {
}

var (
defaultTickInterval = 100 * time.Millisecond
defaultTimeout = 10 * time.Second
defaultMinMemoryMBs = 128
DefaultInitialFuel = uint64(100_000_000)
defaultMaxFetchRequests = 5
defaultTickInterval = 100 * time.Millisecond
defaultTimeout = 10 * time.Second
defaultMinMemoryMBs = 128
DefaultInitialFuel = uint64(100_000_000)
defaultMaxFetchRequests = 5
defaultMaxCompressedBinarySize = 10 * 1024 * 1024 // 10 MB
)

type DeterminismConfig struct {
// Seed is the seed used to generate cryptographically insecure random numbers in the module.
Seed int64
}
type ModuleConfig struct {
TickInterval time.Duration
Timeout *time.Duration
MaxMemoryMBs int64
MinMemoryMBs int64
InitialFuel uint64
Logger logger.Logger
IsUncompressed bool
Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error)
MaxFetchRequests int
TickInterval time.Duration
Timeout *time.Duration
MaxMemoryMBs int64
MinMemoryMBs int64
InitialFuel uint64
Logger logger.Logger
IsUncompressed bool
Fetch func(ctx context.Context, req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error)
MaxFetchRequests int
MaxCompressedBinarySize uint64

// Labeler is used to emit messages from the module.
Labeler custmsg.MessageEmitter
Expand Down Expand Up @@ -166,6 +168,10 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))
modCfg.MinMemoryMBs = int64(defaultMinMemoryMBs)
}

if modCfg.MaxCompressedBinarySize == 0 {
modCfg.MaxCompressedBinarySize = uint64(defaultMaxCompressedBinarySize)
}

// Take the max of the min and the configured max memory mbs.
// We do this because Go requires a minimum of 16 megabytes to run,
// and local testing has shown that with less than the min, some
Expand All @@ -183,6 +189,12 @@ func NewModule(modCfg *ModuleConfig, binary []byte, opts ...func(*ModuleConfig))

engine := wasmtime.NewEngineWithConfig(cfg)
if !modCfg.IsUncompressed {
// validate the binary size before decompressing
// this is to prevent decompression bombs
if uint64(len(binary)) > modCfg.MaxCompressedBinarySize {
return nil, fmt.Errorf("binary size exceeds the maximum allowed size of %d bytes", modCfg.MaxCompressedBinarySize)
}

rdr := brotli.NewReader(bytes.NewBuffer(binary))
decompedBinary, err := io.ReadAll(rdr)
if err != nil {
Expand Down
40 changes: 40 additions & 0 deletions pkg/workflows/wasm/host/wasm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,46 @@ func TestModule_Sandbox_Memory(t *testing.T) {
assert.ErrorContains(t, err, "exit status 2")
}

func TestModule_CompressedBinarySize(t *testing.T) {
t.Parallel()

t.Run("compressed binary size is smaller than the default 10mb limit", func(t *testing.T) {
binary := createTestBinary(successBinaryCmd, successBinaryLocation, false, t)

_, err := NewModule(&ModuleConfig{IsUncompressed: false, Logger: logger.Test(t)}, binary)
require.NoError(t, err)
})

t.Run("compressed binary size is bigger than the default 10mb limit", func(t *testing.T) {
binary := make([]byte, defaultMaxCompressedBinarySize+1)

var b bytes.Buffer
bwr := brotli.NewWriter(&b)
_, err := bwr.Write(binary)
require.NoError(t, err)
require.NoError(t, bwr.Close())

_, err = NewModule(&ModuleConfig{IsUncompressed: false, Logger: logger.Test(t)}, binary)
default10mbLimit := fmt.Sprintf("binary size exceeds the maximum allowed size of %d bytes", defaultMaxCompressedBinarySize)
require.ErrorContains(t, err, default10mbLimit)
})

t.Run("compressed binary size is bigger than the custom limit", func(t *testing.T) {
customMaxCompressedBinarySize := uint64(1 * 1024 * 1024)
binary := make([]byte, customMaxCompressedBinarySize+1)

var b bytes.Buffer
bwr := brotli.NewWriter(&b)
_, err := bwr.Write(binary)
require.NoError(t, err)
require.NoError(t, bwr.Close())

_, err = NewModule(&ModuleConfig{IsUncompressed: false, MaxCompressedBinarySize: customMaxCompressedBinarySize, Logger: logger.Test(t)}, binary)
default10mbLimit := fmt.Sprintf("binary size exceeds the maximum allowed size of %d bytes", customMaxCompressedBinarySize)
require.ErrorContains(t, err, default10mbLimit)
})
}

func TestModule_Sandbox_SleepIsStubbedOut(t *testing.T) {
t.Parallel()
ctx := tests.Context(t)
Expand Down
Loading