Skip to content

Commit

Permalink
Prevent memory leak when a recursive unpacker raises an exception
Browse files Browse the repository at this point in the history
The child stack wouldn't be popped nor freed.

Additionally, even once the unpacker was freed, only the latest stack
would be freed; all the parent stacks would be leaked.

Practically speaking, every time a recursive unpacker would raise,
4kiB would be leaked.
  • Loading branch information
byroot committed Nov 8, 2024
1 parent 6bbaa97 commit 8a79706
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 16 deletions.
2 changes: 2 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Fixed a potential memory leak when a recursive unpacker raises.

2024-10-03 1.7.3

* Limit initial containers pre-allocation to `SHRT_MAX` (32k) entries.
Expand Down
44 changes: 40 additions & 4 deletions ext/msgpack/unpacker.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@
#define rb_proc_call_with_block(recv, argc, argv, block) rb_funcallv(recv, rb_intern("call"), argc, argv)
#endif

/*
 * Argument bundle for protected_proc_call_safe(). rb_protect() forwards a
 * single VALUE to its callback, so the callable and its arguments are packed
 * into this struct and a pointer to it is passed through that VALUE.
 */
struct protected_proc_call_args {
/* Callable handed to rb_proc_call_with_block() as the receiver. */
VALUE proc;
/* Number of entries in argv. */
int argc;
/* Arguments forwarded to the callable. */
VALUE *argv;
};

/*
 * Callback with the VALUE (*)(VALUE) shape that rb_protect() expects.
 *
 * `_args` is really a pointer to a struct protected_proc_call_args smuggled
 * through a VALUE; it is unpacked here and the stored callable is invoked
 * with the stored arguments and no block (Qnil).
 *
 * Returns whatever the callable returns.
 */
static VALUE protected_proc_call_safe(VALUE _args) {
    const struct protected_proc_call_args *call =
        (const struct protected_proc_call_args *)_args;
    VALUE result = rb_proc_call_with_block(call->proc, call->argc, call->argv, Qnil);

    return result;
}

/*
 * Invoke `proc` with `argc` arguments from `argv` under rb_protect(), so that
 * a Ruby exception raised by the callable does not unwind through the caller
 * and the caller gets a chance to release its own resources first.
 *
 * `raised` receives the state reported by rb_protect(): zero when the call
 * returned normally, non-zero otherwise. Callers in this file then fetch the
 * pending exception with rb_errinfo().
 *
 * Returns the value produced by rb_protect() (the callable's result on
 * success).
 */
static VALUE protected_proc_call(VALUE proc, int argc, VALUE *argv, int *raised) {
    struct protected_proc_call_args call;

    call.proc = proc;
    call.argc = argc;
    call.argv = argv;

    /* The struct lives on this frame; it only has to outlive rb_protect(). */
    return rb_protect(protected_proc_call_safe, (VALUE)&call, raised);
}

static int RAW_TYPE_STRING = 256;
static int RAW_TYPE_BINARY = 257;
static int16_t INITIAL_BUFFER_CAPACITY_MAX = SHRT_MAX;
Expand Down Expand Up @@ -87,7 +108,12 @@ static inline void _msgpack_unpacker_free_stack(msgpack_unpacker_stack_t* stack)

/*
 * Release everything owned by the unpacker `uk`: every stack in the
 * uk->stack -> parent chain, then the underlying buffer.
 *
 * Recursive unpackers push child stacks whose `parent` points at the previous
 * stack, so the whole chain must be walked; freeing only uk->stack would leak
 * every parent.
 */
void _msgpack_unpacker_destroy(msgpack_unpacker_t* uk)
{
    msgpack_unpacker_stack_t *stack;

    /*
     * Unlink each stack before freeing it so `parent` is never read from
     * freed memory and no stack is freed twice. Nothing may free uk->stack
     * ahead of this loop.
     */
    while ((stack = uk->stack)) {
        uk->stack = stack->parent;
        _msgpack_unpacker_free_stack(stack);
    }

    msgpack_buffer_destroy(UNPACKER_BUFFER_(uk));
}

Expand Down Expand Up @@ -186,7 +212,12 @@ static inline int object_complete_ext(msgpack_unpacker_t* uk, int ext_type, VALU
if(proc != Qnil) {
VALUE obj;
VALUE arg = (str == Qnil ? rb_str_buf_new(0) : str);
obj = rb_proc_call_with_block(proc, 1, &arg, Qnil);
int raised;
obj = protected_proc_call(proc, 1, &arg, &raised);
if (raised) {
uk->last_object = rb_errinfo();
return PRIMITIVE_RECURSIVE_RAISED;
}
return object_complete(uk, obj);
}

Expand Down Expand Up @@ -316,11 +347,16 @@ static inline int read_raw_body_begin(msgpack_unpacker_t* uk, int raw_type)
child_stack->parent = uk->stack;
uk->stack = child_stack;

obj = rb_proc_call_with_block(proc, 1, &uk->self, Qnil);

int raised;
obj = protected_proc_call(proc, 1, &uk->self, &raised);
uk->stack = child_stack->parent;
_msgpack_unpacker_free_stack(child_stack);

if (raised) {
uk->last_object = rb_errinfo();
return PRIMITIVE_RECURSIVE_RAISED;
}

return object_complete(uk, obj);
}
}
Expand Down
1 change: 1 addition & 0 deletions ext/msgpack/unpacker.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ static inline void msgpack_unpacker_set_allow_unknown_ext(msgpack_unpacker_t* uk
#define PRIMITIVE_STACK_TOO_DEEP -3
#define PRIMITIVE_UNEXPECTED_TYPE -4
#define PRIMITIVE_UNEXPECTED_EXT_TYPE -5
#define PRIMITIVE_RECURSIVE_RAISED -6

int msgpack_unpacker_read(msgpack_unpacker_t* uk, size_t target_stack_depth);

Expand Down
32 changes: 22 additions & 10 deletions ext/msgpack/unpacker_class.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ static size_t Unpacker_memsize(const void *ptr)
total_size += sizeof(msgpack_unpacker_ext_registry_t) / (uk->ext_registry->borrow_count + 1);
}

total_size += (uk->stack->depth + 1) * sizeof(msgpack_unpacker_stack_t);
msgpack_unpacker_stack_t *stack = uk->stack;
while (stack) {
total_size += (stack->depth + 1) * sizeof(msgpack_unpacker_stack_t);
stack = stack->parent;
}

return total_size + msgpack_buffer_memsize(&uk->buffer);
}
Expand Down Expand Up @@ -156,20 +160,28 @@ static VALUE Unpacker_allow_unknown_ext_p(VALUE self)
return uk->allow_unknown_ext ? Qtrue : Qfalse;
}

NORETURN(static void raise_unpacker_error(int r))
NORETURN(static void raise_unpacker_error(msgpack_unpacker_t *uk, int r))
{
uk->stack->depth = 0;
switch(r) {
case PRIMITIVE_EOF:
rb_raise(rb_eEOFError, "end of buffer reached");
break;
case PRIMITIVE_INVALID_BYTE:
rb_raise(eMalformedFormatError, "invalid byte");
break;
case PRIMITIVE_STACK_TOO_DEEP:
rb_raise(eStackError, "stack level too deep");
break;
case PRIMITIVE_UNEXPECTED_TYPE:
rb_raise(eUnexpectedTypeError, "unexpected type");
break;
case PRIMITIVE_UNEXPECTED_EXT_TYPE:
// rb_bug("unexpected extension type");
rb_raise(eUnknownExtTypeError, "unexpected extension type");
break;
case PRIMITIVE_RECURSIVE_RAISED:
rb_exc_raise(msgpack_unpacker_get_last_object(uk));
break;
default:
rb_raise(eUnpackError, "logically unknown error %d", r);
}
Expand All @@ -190,7 +202,7 @@ static VALUE Unpacker_read(VALUE self)

int r = msgpack_unpacker_read(uk, 0);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

return msgpack_unpacker_get_last_object(uk);
Expand All @@ -202,7 +214,7 @@ static VALUE Unpacker_skip(VALUE self)

int r = msgpack_unpacker_skip(uk, 0);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

return Qnil;
Expand All @@ -214,7 +226,7 @@ static VALUE Unpacker_skip_nil(VALUE self)

int r = msgpack_unpacker_skip_nil(uk);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

if(r) {
Expand All @@ -230,7 +242,7 @@ static VALUE Unpacker_read_array_header(VALUE self)
uint32_t size;
int r = msgpack_unpacker_read_array_header(uk, &size);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

return ULONG2NUM(size); // long at least 32 bits
Expand All @@ -243,7 +255,7 @@ static VALUE Unpacker_read_map_header(VALUE self)
uint32_t size;
int r = msgpack_unpacker_read_map_header(uk, &size);
if(r < 0) {
raise_unpacker_error((int)r);
raise_unpacker_error(uk, r);
}

return ULONG2NUM(size); // long at least 32 bits
Expand All @@ -270,7 +282,7 @@ static VALUE Unpacker_each_impl(VALUE self)
if(r == PRIMITIVE_EOF) {
return Qnil;
}
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}
VALUE v = msgpack_unpacker_get_last_object(uk);
#ifdef JRUBY
Expand Down Expand Up @@ -369,7 +381,7 @@ static VALUE Unpacker_full_unpack(VALUE self)

int r = msgpack_unpacker_read(uk, 0);
if(r < 0) {
raise_unpacker_error(r);
raise_unpacker_error(uk, r);
}

/* raise if extra bytes follow */
Expand Down
4 changes: 2 additions & 2 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
# This method was added in Ruby 3.0.0. Calling it this way asks the GC to
# move objects around, helping to find object movement bugs.
begin
  GC.verify_compaction_references(expand_heap: true, toward: :empty)
rescue NotImplementedError, ArgumentError
  # Some platforms don't support compaction (NotImplementedError).
  # NOTE(review): ArgumentError is presumably raised by Ruby versions that
  # don't accept the `expand_heap:` keyword -- confirm against the CI matrix.
end
end
Expand Down
41 changes: 41 additions & 0 deletions spec/unpacker_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -901,4 +901,45 @@ def flatten(struct, results = [])
GC.stress = stress
end
end

if RUBY_PLATFORM != "java"
  # Regression test: a recursive ext-type unpacker that raises must not make
  # the Unpacker grow. The reported memsize is compared before and after a
  # batch of failing unpacks on the same Unpacker instance.
  it "doesn't leak when a recursive unpacker raises" do
    hash_with_indifferent_access = Class.new(Hash)
    msgpack = MessagePack::Factory.new
    msgpack.register_type(
      0x02,
      hash_with_indifferent_access,
      packer: ->(value, packer) do
        packer.write(value.to_h)
      end,
      # Always fails, so every unpack attempt goes through the error path.
      unpacker: ->(unpacker) { raise RuntimeError, "Ooops" },
      recursive: true
    )

    # Nest the ext value inside several arrays before it is reached.
    data = [[[[[[[hash_with_indifferent_access.new]]]]]]]
    payload = msgpack.dump(data)

    unpacker = msgpack.unpacker

    # Warm-up rounds, run before the baseline measurement is taken.
    2.times do
      unpacker.buffer.clear
      unpacker.feed(payload)
      expect {
        unpacker.full_unpack
      }.to raise_error(RuntimeError, "Ooops")
    end

    memsize = ObjectSpace.memsize_of(unpacker)

    10.times do
      unpacker.buffer.clear
      unpacker.feed(payload)
      expect {
        unpacker.full_unpack
      }.to raise_error(RuntimeError, "Ooops")
    end

    # The size must be unchanged after the additional failing unpacks.
    expect(memsize).to eq ObjectSpace.memsize_of(unpacker)
  end
end
end

0 comments on commit 8a79706

Please sign in to comment.