Skip to content

Commit

Permalink
Merge pull request #371 from Shopify/recursive-raise-leak
Browse files Browse the repository at this point in the history
Prevent memory leak when a recursive unpacker raises an exception
  • Loading branch information
byroot authored Nov 8, 2024
2 parents 6bbaa97 + 8a79706 commit 9bac145
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 16 deletions.
2 changes: 2 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Fixed a potential memory leak when a recursive unpacker raises.

2024-10-03 1.7.3

* Limit initial containers pre-allocation to `SHRT_MAX` (32k) entries.
Expand Down
44 changes: 40 additions & 4 deletions ext/msgpack/unpacker.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@
#define rb_proc_call_with_block(recv, argc, argv, block) rb_funcallv(recv, rb_intern("call"), argc, argv)
#endif

/* Argument bundle for a proc call routed through rb_protect(): the address of
 * one of these is cast to VALUE by protected_proc_call() and cast back by
 * protected_proc_call_safe(). */
struct protected_proc_call_args {
/* Receiver handed to rb_proc_call_with_block(). */
VALUE proc;
/* Number of entries in argv. */
int argc;
/* Arguments forwarded to the call. */
VALUE *argv;
};

/* rb_protect() callback: recovers the argument bundle smuggled through the
 * single VALUE parameter and performs the actual proc call with no block. */
static VALUE protected_proc_call_safe(VALUE _args) {
    const struct protected_proc_call_args *call =
        (const struct protected_proc_call_args *)_args;
    VALUE result = rb_proc_call_with_block(call->proc, call->argc, call->argv, Qnil);

    return result;
}

/* Calls `proc` with `argc`/`argv` under rb_protect().
 *
 * The return value is whatever rb_protect() returns, and `raised` is passed
 * straight through to it as the state out-parameter; callers in this file
 * treat a non-zero `*raised` as "the call raised". */
static VALUE protected_proc_call(VALUE proc, int argc, VALUE *argv, int *raised) {
    /* Lives on this frame only; rb_protect() runs the callback before returning. */
    struct protected_proc_call_args call;

    call.proc = proc;
    call.argc = argc;
    call.argv = argv;

    return rb_protect(protected_proc_call_safe, (VALUE)&call, raised);
}

static int RAW_TYPE_STRING = 256;
static int RAW_TYPE_BINARY = 257;
static int16_t INITIAL_BUFFER_CAPACITY_MAX = SHRT_MAX;
Expand Down Expand Up @@ -87,7 +108,12 @@ static inline void _msgpack_unpacker_free_stack(msgpack_unpacker_stack_t* stack)

/*
 * Releases the resources held by an unpacker: every stack reachable from
 * uk->stack through ->parent, then the read buffer.
 *
 * Each stack is freed exactly once. Freeing the head up front and then
 * walking the chain would read ->parent from freed memory and free the head
 * a second time, so the loop below is the only place stacks are released.
 */
void _msgpack_unpacker_destroy(msgpack_unpacker_t* uk)
{
    msgpack_unpacker_stack_t *stack;

    /* Unlink before freeing so uk->stack never points at released memory. */
    while ((stack = uk->stack)) {
        uk->stack = stack->parent;
        _msgpack_unpacker_free_stack(stack);
    }

    msgpack_buffer_destroy(UNPACKER_BUFFER_(uk));
}

Expand Down Expand Up @@ -186,7 +212,12 @@ static inline int object_complete_ext(msgpack_unpacker_t* uk, int ext_type, VALU
if(proc != Qnil) {
VALUE obj;
VALUE arg = (str == Qnil ? rb_str_buf_new(0) : str);
obj = rb_proc_call_with_block(proc, 1, &arg, Qnil);
int raised;
obj = protected_proc_call(proc, 1, &arg, &raised);
if (raised) {
uk->last_object = rb_errinfo();
return PRIMITIVE_RECURSIVE_RAISED;
}
return object_complete(uk, obj);
}

Expand Down Expand Up @@ -316,11 +347,16 @@ static inline int read_raw_body_begin(msgpack_unpacker_t* uk, int raw_type)
child_stack->parent = uk->stack;
uk->stack = child_stack;

obj = rb_proc_call_with_block(proc, 1, &uk->self, Qnil);

int raised;
obj = protected_proc_call(proc, 1, &uk->self, &raised);
uk->stack = child_stack->parent;
_msgpack_unpacker_free_stack(child_stack);

if (raised) {
uk->last_object = rb_errinfo();
return PRIMITIVE_RECURSIVE_RAISED;
}

return object_complete(uk, obj);
}
}
Expand Down
1 change: 1 addition & 0 deletions ext/msgpack/unpacker.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ static inline void msgpack_unpacker_set_allow_unknown_ext(msgpack_unpacker_t* uk
#define PRIMITIVE_STACK_TOO_DEEP -3
#define PRIMITIVE_UNEXPECTED_TYPE -4
#define PRIMITIVE_UNEXPECTED_EXT_TYPE -5
#define PRIMITIVE_RECURSIVE_RAISED -6

int msgpack_unpacker_read(msgpack_unpacker_t* uk, size_t target_stack_depth);

Expand Down
32 changes: 22 additions & 10 deletions ext/msgpack/unpacker_class.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ static size_t Unpacker_memsize(const void *ptr)
total_size += sizeof(msgpack_unpacker_ext_registry_t) / (uk->ext_registry->borrow_count + 1);
}

total_size += (uk->stack->depth + 1) * sizeof(msgpack_unpacker_stack_t);
msgpack_unpacker_stack_t *stack = uk->stack;
while (stack) {
total_size += (stack->depth + 1) * sizeof(msgpack_unpacker_stack_t);
stack = stack->parent;
}

return total_size + msgpack_buffer_memsize(&uk->buffer);
}
Expand Down Expand Up @@ -156,20 +160,28 @@ static VALUE Unpacker_allow_unknown_ext_p(VALUE self)
return uk->allow_unknown_ext ? Qtrue : Qfalse;
}

NORETURN(static void raise_unpacker_error(int r))
NORETURN(static void raise_unpacker_error(msgpack_unpacker_t *uk, int r))
{
uk->stack->depth = 0;
switch(r) {
case PRIMITIVE_EOF:
rb_raise(rb_eEOFError, "end of buffer reached");
break;
case PRIMITIVE_INVALID_BYTE:
rb_raise(eMalformedFormatError, "invalid byte");
break;
case PRIMITIVE_STACK_TOO_DEEP:
rb_raise(eStackError, "stack level too deep");
break;
case PRIMITIVE_UNEXPECTED_TYPE:
rb_raise(eUnexpectedTypeError, "unexpected type");
break;
case PRIMITIVE_UNEXPECTED_EXT_TYPE:
// rb_bug("unexpected extension type");
rb_raise(eUnknownExtTypeError, "unexpected extension type");
break;
case PRIMITIVE_RECURSIVE_RAISED:
rb_exc_raise(msgpack_unpacker_get_last_object(uk));
break;
default:
rb_raise(eUnpackError, "logically unknown error %d", r);
}
Expand All @@ -190,7 +202,7 @@ static VALUE Unpacker_read(VALUE self)

int r = msgpack_unpacker_read(uk, 0);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

return msgpack_unpacker_get_last_object(uk);
Expand All @@ -202,7 +214,7 @@ static VALUE Unpacker_skip(VALUE self)

int r = msgpack_unpacker_skip(uk, 0);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

return Qnil;
Expand All @@ -214,7 +226,7 @@ static VALUE Unpacker_skip_nil(VALUE self)

int r = msgpack_unpacker_skip_nil(uk);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

if(r) {
Expand All @@ -230,7 +242,7 @@ static VALUE Unpacker_read_array_header(VALUE self)
uint32_t size;
int r = msgpack_unpacker_read_array_header(uk, &size);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

return ULONG2NUM(size); // long at least 32 bits
Expand All @@ -243,7 +255,7 @@ static VALUE Unpacker_read_map_header(VALUE self)
uint32_t size;
int r = msgpack_unpacker_read_map_header(uk, &size);
if(r < 0) {
raise_unpacker_error((int)r);
raise_unpacker_error(uk, r);
}

return ULONG2NUM(size); // long at least 32 bits
Expand All @@ -270,7 +282,7 @@ static VALUE Unpacker_each_impl(VALUE self)
if(r == PRIMITIVE_EOF) {
return Qnil;
}
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}
VALUE v = msgpack_unpacker_get_last_object(uk);
#ifdef JRUBY
Expand Down Expand Up @@ -369,7 +381,7 @@ static VALUE Unpacker_full_unpack(VALUE self)

int r = msgpack_unpacker_read(uk, 0);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

/* raise if extra bytes follow */
Expand Down
4 changes: 2 additions & 2 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
# This method was added in Ruby 3.0.0. Calling it this way asks the GC to
# move objects around, helping to find object movement bugs.
begin
GC.verify_compaction_references(double_heap: true, toward: :empty)
rescue NotImplementedError
GC.verify_compaction_references(expand_heap: true, toward: :empty)
rescue NotImplementedError, ArgumentError
# Some platforms don't support compaction
end
end
Expand Down
41 changes: 41 additions & 0 deletions spec/unpacker_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -901,4 +901,45 @@ def flatten(struct, results = [])
GC.stress = stress
end
end

# Skipped when RUBY_PLATFORM is "java".
if RUBY_PLATFORM != "java"
  # Regression test: an extension type registered with `recursive: true` whose
  # unpacker proc raises must not make the unpacker's reported memory size grow
  # from one failed unpack to the next.
  it "doesn't leak when a recursive unpacker raises" do
    hash_with_indifferent_access = Class.new(Hash)
    msgpack = MessagePack::Factory.new
    msgpack.register_type(
      0x02,
      hash_with_indifferent_access,
      packer: ->(value, packer) do
        packer.write(value.to_h)
      end,
      unpacker: ->(unpacker) { raise RuntimeError, "Ooops" },
      recursive: true
    )

    # The extension object sits several arrays deep in the payload.
    data = [[[[[[[hash_with_indifferent_access.new]]]]]]]
    payload = msgpack.dump(data)

    unpacker = msgpack.unpacker

    # One failed unpack on the shared unpacker: reset the buffer, feed the
    # payload again and check that the proc's exception reaches the caller.
    unpack_and_raise = -> do
      unpacker.buffer.clear
      unpacker.feed(payload)
      expect {
        unpacker.full_unpack
      }.to raise_error(RuntimeError, "Ooops")
    end

    # Presumably a warm-up so one-time allocations are already included in the
    # baseline measurement.
    2.times { unpack_and_raise.call }

    memsize = ObjectSpace.memsize_of(unpacker)

    10.times { unpack_and_raise.call }

    # Actual value first, baseline second, so a failure message reads correctly.
    expect(ObjectSpace.memsize_of(unpacker)).to eq memsize
  end
end
end

0 comments on commit 9bac145

Please sign in to comment.