From 8a79706fbbe0d3156ef40aa3f2489f075410e31b Mon Sep 17 00:00:00 2001 From: Jean Boussier Date: Fri, 8 Nov 2024 15:30:52 +0100 Subject: [PATCH] Prevent memory leak when a recursive unpacker raises an exception The child stack would be neither popped nor freed. Additionally, even once the unpacker was freed, only the latest stack would be freed, and all the parent stacks would be leaked. Practically speaking, every time a recursive unpacker would raise, 4 KiB would be leaked. --- ChangeLog | 2 ++ ext/msgpack/unpacker.c | 44 ++++++++++++++++++++++++++++++++---- ext/msgpack/unpacker.h | 1 + ext/msgpack/unpacker_class.c | 32 ++++++++++++++++++-------- spec/spec_helper.rb | 4 ++-- spec/unpacker_spec.rb | 41 +++++++++++++++++++++++++++++++++ 6 files changed, 108 insertions(+), 16 deletions(-) diff --git a/ChangeLog b/ChangeLog index dafbc38f..4c454cbc 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,5 @@ +* Fixed a potential memory leak when a recursive unpacker raises. + 2024-10-03 1.7.3 * Limit initial containers pre-allocation to `SHRT_MAX` (32k) entries. 
diff --git a/ext/msgpack/unpacker.c b/ext/msgpack/unpacker.c index 4fd5942d..32423e9c 100644 --- a/ext/msgpack/unpacker.c +++ b/ext/msgpack/unpacker.c @@ -26,6 +26,27 @@ #define rb_proc_call_with_block(recv, argc, argv, block) rb_funcallv(recv, rb_intern("call"), argc, argv) #endif +struct protected_proc_call_args { + VALUE proc; + int argc; + VALUE *argv; +}; + +static VALUE protected_proc_call_safe(VALUE _args) { + struct protected_proc_call_args *args = (struct protected_proc_call_args *)_args; + + return rb_proc_call_with_block(args->proc, args->argc, args->argv, Qnil); +} + +static VALUE protected_proc_call(VALUE proc, int argc, VALUE *argv, int *raised) { + struct protected_proc_call_args args = { + .proc = proc, + .argc = argc, + .argv = argv, + }; + return rb_protect(protected_proc_call_safe, (VALUE)&args, raised); +} + static int RAW_TYPE_STRING = 256; static int RAW_TYPE_BINARY = 257; static int16_t INITIAL_BUFFER_CAPACITY_MAX = SHRT_MAX; @@ -87,7 +108,12 @@ static inline void _msgpack_unpacker_free_stack(msgpack_unpacker_stack_t* stack) void _msgpack_unpacker_destroy(msgpack_unpacker_t* uk) { - _msgpack_unpacker_free_stack(uk->stack); + msgpack_unpacker_stack_t *stack; + while ((stack = uk->stack)) { + uk->stack = stack->parent; + _msgpack_unpacker_free_stack(stack); + } + msgpack_buffer_destroy(UNPACKER_BUFFER_(uk)); } @@ -186,7 +212,12 @@ static inline int object_complete_ext(msgpack_unpacker_t* uk, int ext_type, VALU if(proc != Qnil) { VALUE obj; VALUE arg = (str == Qnil ? 
rb_str_buf_new(0) : str); - obj = rb_proc_call_with_block(proc, 1, &arg, Qnil); + int raised; + obj = protected_proc_call(proc, 1, &arg, &raised); + if (raised) { + uk->last_object = rb_errinfo(); + return PRIMITIVE_RECURSIVE_RAISED; + } return object_complete(uk, obj); } @@ -316,11 +347,16 @@ static inline int read_raw_body_begin(msgpack_unpacker_t* uk, int raw_type) child_stack->parent = uk->stack; uk->stack = child_stack; - obj = rb_proc_call_with_block(proc, 1, &uk->self, Qnil); - + int raised; + obj = protected_proc_call(proc, 1, &uk->self, &raised); uk->stack = child_stack->parent; _msgpack_unpacker_free_stack(child_stack); + if (raised) { + uk->last_object = rb_errinfo(); + return PRIMITIVE_RECURSIVE_RAISED; + } + return object_complete(uk, obj); } } diff --git a/ext/msgpack/unpacker.h b/ext/msgpack/unpacker.h index 40c3b9f6..6925b108 100644 --- a/ext/msgpack/unpacker.h +++ b/ext/msgpack/unpacker.h @@ -119,6 +119,7 @@ static inline void msgpack_unpacker_set_allow_unknown_ext(msgpack_unpacker_t* uk #define PRIMITIVE_STACK_TOO_DEEP -3 #define PRIMITIVE_UNEXPECTED_TYPE -4 #define PRIMITIVE_UNEXPECTED_EXT_TYPE -5 +#define PRIMITIVE_RECURSIVE_RAISED -6 int msgpack_unpacker_read(msgpack_unpacker_t* uk, size_t target_stack_depth); diff --git a/ext/msgpack/unpacker_class.c b/ext/msgpack/unpacker_class.c index ee5f3796..06dd3dfa 100644 --- a/ext/msgpack/unpacker_class.c +++ b/ext/msgpack/unpacker_class.c @@ -65,7 +65,11 @@ static size_t Unpacker_memsize(const void *ptr) total_size += sizeof(msgpack_unpacker_ext_registry_t) / (uk->ext_registry->borrow_count + 1); } - total_size += (uk->stack->depth + 1) * sizeof(msgpack_unpacker_stack_t); + msgpack_unpacker_stack_t *stack = uk->stack; + while (stack) { + total_size += (stack->depth + 1) * sizeof(msgpack_unpacker_stack_t); + stack = stack->parent; + } return total_size + msgpack_buffer_memsize(&uk->buffer); } @@ -156,20 +160,28 @@ static VALUE Unpacker_allow_unknown_ext_p(VALUE self) return uk->allow_unknown_ext ? 
Qtrue : Qfalse; } -NORETURN(static void raise_unpacker_error(int r)) +NORETURN(static void raise_unpacker_error(msgpack_unpacker_t *uk, int r)) { + uk->stack->depth = 0; switch(r) { case PRIMITIVE_EOF: rb_raise(rb_eEOFError, "end of buffer reached"); + break; case PRIMITIVE_INVALID_BYTE: rb_raise(eMalformedFormatError, "invalid byte"); + break; case PRIMITIVE_STACK_TOO_DEEP: rb_raise(eStackError, "stack level too deep"); + break; case PRIMITIVE_UNEXPECTED_TYPE: rb_raise(eUnexpectedTypeError, "unexpected type"); + break; case PRIMITIVE_UNEXPECTED_EXT_TYPE: - // rb_bug("unexpected extension type"); rb_raise(eUnknownExtTypeError, "unexpected extension type"); + break; + case PRIMITIVE_RECURSIVE_RAISED: + rb_exc_raise(msgpack_unpacker_get_last_object(uk)); + break; default: rb_raise(eUnpackError, "logically unknown error %d", r); } @@ -190,7 +202,7 @@ static VALUE Unpacker_read(VALUE self) int r = msgpack_unpacker_read(uk, 0); if(r < 0) { - raise_unpacker_error(r); + raise_unpacker_error(uk, r); } return msgpack_unpacker_get_last_object(uk); @@ -202,7 +214,7 @@ static VALUE Unpacker_skip(VALUE self) int r = msgpack_unpacker_skip(uk, 0); if(r < 0) { - raise_unpacker_error(r); + raise_unpacker_error(uk, r); } return Qnil; @@ -214,7 +226,7 @@ static VALUE Unpacker_skip_nil(VALUE self) int r = msgpack_unpacker_skip_nil(uk); if(r < 0) { - raise_unpacker_error(r); + raise_unpacker_error(uk, r); } if(r) { @@ -230,7 +242,7 @@ static VALUE Unpacker_read_array_header(VALUE self) uint32_t size; int r = msgpack_unpacker_read_array_header(uk, &size); if(r < 0) { - raise_unpacker_error(r); + raise_unpacker_error(uk, r); } return ULONG2NUM(size); // long at least 32 bits @@ -243,7 +255,7 @@ static VALUE Unpacker_read_map_header(VALUE self) uint32_t size; int r = msgpack_unpacker_read_map_header(uk, &size); if(r < 0) { - raise_unpacker_error((int)r); + raise_unpacker_error(uk, r); } return ULONG2NUM(size); // long at least 32 bits @@ -270,7 +282,7 @@ static VALUE 
Unpacker_each_impl(VALUE self) if(r == PRIMITIVE_EOF) { return Qnil; } - raise_unpacker_error(r); + raise_unpacker_error(uk, r); } VALUE v = msgpack_unpacker_get_last_object(uk); #ifdef JRUBY @@ -369,7 +381,7 @@ static VALUE Unpacker_full_unpack(VALUE self) int r = msgpack_unpacker_read(uk, 0); if(r < 0) { - raise_unpacker_error(r); + raise_unpacker_error(uk, r); } /* raise if extra bytes follow */ diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index 3556a7b6..18c4b004 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -22,8 +22,8 @@ # This method was added in Ruby 3.0.0. Calling it this way asks the GC to # move objects around, helping to find object movement bugs. begin - GC.verify_compaction_references(double_heap: true, toward: :empty) - rescue NotImplementedError + GC.verify_compaction_references(expand_heap: true, toward: :empty) + rescue NotImplementedError, ArgumentError # Some platforms don't support compaction end end diff --git a/spec/unpacker_spec.rb b/spec/unpacker_spec.rb index 3b4b2a32..de2f06cd 100644 --- a/spec/unpacker_spec.rb +++ b/spec/unpacker_spec.rb @@ -901,4 +901,45 @@ def flatten(struct, results = []) GC.stress = stress end end + + if RUBY_PLATFORM != "java" + it "doesn't leak when a recursive unpacker raises" do + hash_with_indifferent_access = Class.new(Hash) + msgpack = MessagePack::Factory.new + msgpack.register_type( + 0x02, + hash_with_indifferent_access, + packer: ->(value, packer) do + packer.write(value.to_h) + end, + unpacker: ->(unpacker) { raise RuntimeError, "Ooops" }, + recursive: true + ) + + packer = msgpack.packer + data = [[[[[[[hash_with_indifferent_access.new]]]]]]] + payload = msgpack.dump(data) + + unpacker = msgpack.unpacker + 2.times do + unpacker.buffer.clear + unpacker.feed(payload) + expect { + unpacker.full_unpack + }.to raise_error(RuntimeError, "Ooops") + end + + memsize = ObjectSpace.memsize_of(unpacker) + + 10.times do + unpacker.buffer.clear + unpacker.feed(payload) + expect { + 
unpacker.full_unpack + }.to raise_error(RuntimeError, "Ooops") + end + + expect(memsize).to eq ObjectSpace.memsize_of(unpacker) + end + end end