From 33878ceb99a6132b34a2dbee186695d97688510f Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 16 Sep 2024 11:28:26 +0200 Subject: [PATCH 1/2] Add max recursion depth Signature: ``` // MaxDepth will set the maximum recursion depth. // If the maximum depth is exceeded, ErrMaxDepth is returned. // Less than or 0 means no limit (default). func (d *Decoder) MaxDepth(n int) *Decoder { ``` This can be used to prevent stack exhaustion on adversarial content. --- decoder.go | 18 ++++++++++++++++++ decoder_test.go | 39 +++++++++++++++++++++++++++++++++++++++ errors.go | 1 + 3 files changed, 58 insertions(+) diff --git a/decoder.go b/decoder.go index a70d495..8028bf8 100644 --- a/decoder.go +++ b/decoder.go @@ -67,6 +67,7 @@ func (kvs KVS) MarshalJSON() ([]byte, error) { type Decoder struct { *scanner emitDepth int + maxDepth int emitKV bool emitRecursive bool objectAsKVS bool @@ -137,6 +138,14 @@ func (d *Decoder) Pos() int { return int(d.pos) } // Err returns the most recent decoder error if any, or nil func (d *Decoder) Err() error { return d.err } +// MaxDepth will set the maximum recursion depth. +// If the maximum depth is exceeded, ErrMaxDepth is returned. +// Less than or 0 means no limit (default). +func (d *Decoder) MaxDepth(n int) *Decoder { + d.maxDepth = n + return d +} + // Decode parses the JSON-encoded data and returns an interface value func (d *Decoder) decode() { defer close(d.metaCh) @@ -417,6 +426,9 @@ func (d *Decoder) number() (float64, error) { // array accept valid JSON array value func (d *Decoder) array() ([]interface{}, error) { d.depth++ + if d.maxDepth > 0 && d.depth > d.maxDepth { + return nil, ErrMaxDepth + } var ( c byte @@ -458,6 +470,9 @@ out: // object accept valid JSON array value func (d *Decoder) object() (map[string]interface{}, error) { d.depth++ + if d.maxDepth > 0 && d.depth > d.maxDepth { + return nil, ErrMaxDepth + } var ( c byte @@ -543,6 +558,9 @@ out: // object (ordered) accept valid JSON array value func (d *Decoder) objectOrdered() (KVS, error) { d.depth++ + if d.maxDepth > 0 && d.depth > d.maxDepth { + return nil, ErrMaxDepth + } var ( c byte diff --git a/decoder_test.go b/decoder_test.go index 642c4d5..8c876fc 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -235,3 +235,42 @@ func TestDecoderReaderFailure(t *testing.T) { t.Fatalf("missing expected underlying reader error") } } + +func TestDecoderMaxDepth(t *testing.T) { + tests := []struct { + input string + maxDepth int + mustFail bool + }{ + // No limit + {input: `[{"bio":"bada bing bada boom","id":1,"name":"Charles","falseVal":false}]`, maxDepth: 0, mustFail: false}, + // Array + object = depth 2 = false + {input: `[{"bio":"bada bing bada boom","id":1,"name":"Charles","falseVal":false}]`, maxDepth: 1, mustFail: true}, + // Depth 2 = ok + {input: `[{"bio":"bada bing bada boom","id":1,"name":"Charles","falseVal":false}]`, maxDepth: 2, mustFail: false}, + // Arrays: + {input: `[[[[[[[[[[[[[[[[[[[[[["ok"]]]]]]]]]]]]]]]]]]]]]]`, maxDepth: 2, mustFail: true}, + {input: `[[[[[[[[[[[[[[[[[[[[[["ok"]]]]]]]]]]]]]]]]]]]]]]`, maxDepth: 10, mustFail: true}, + {input: `[[[[[[[[[[[[[[[[[[[[[["ok"]]]]]]]]]]]]]]]]]]]]]]`, maxDepth: 100, mustFail: false}, + // Objects: + {input: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"ok":false}}}}}}}}}}}}}}}}}}}}}}`, maxDepth: 2, mustFail: true}, + {input: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"ok":false}}}}}}}}}}}}}}}}}}}}}}`, maxDepth: 10, mustFail: true}, + {input: `{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"a":{"ok":false}}}}}}}}}}}}}}}}}}}}}}`, maxDepth: 100, mustFail: false}, + } + + for _, test := range tests { + decoder := NewDecoder(mkReader(test.input), 0).MaxDepth(test.maxDepth) + var mv *MetaValue + for mv = range decoder.Stream() { + t.Logf("depth=%d offset=%d len=%d (%v)", mv.Depth, mv.Offset, mv.Length, mv.Value) + } + + err := decoder.Err() + if test.mustFail && err != ErrMaxDepth { + t.Fatalf("missing expected decoder error, got %q", err) + } + if !test.mustFail && err != nil { + t.Fatalf("unexpected error: %q", err) + } + } +} diff --git a/errors.go b/errors.go index f240638..8b665bf 100644 --- a/errors.go +++ b/errors.go @@ -9,6 +9,7 @@ import ( var ( ErrSyntax = DecoderError{msg: "invalid character"} ErrUnexpectedEOF = DecoderError{msg: "unexpected end of JSON input"} + ErrMaxDepth = DecoderError{msg: "maximum recursion depth exceeded"} ) type errPos [2]int // line number, byte offset where error occurred From 78d6d4ff22307740311ffd46b60b59af4182709b Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 23 Sep 2024 12:04:00 +0200 Subject: [PATCH 2/2] Bonus: Fix number-after `-` and number after `.` checks. --- decoder.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/decoder.go b/decoder.go index 8028bf8..2c47d20 100644 --- a/decoder.go +++ b/decoder.go @@ -200,7 +200,7 @@ func (d *Decoder) any() (interface{}, ValueType, error) { i, err := d.number() return i, Number, err case '-': - if c = d.next(); c < '0' && c > '9' { + if c = d.next(); c < '0' || c > '9' { return nil, Unknown, d.mkError(ErrSyntax, "in negative numeric literal") } n, err := d.number() @@ -374,7 +374,7 @@ func (d *Decoder) number() (float64, error) { d.scratch.add(c) // first char following must be digit - if c = d.next(); c < '0' && c > '9' { + if c = d.next(); c < '0' || c > '9' { return 0, d.mkError(ErrSyntax, "after decimal point in numeric literal") } d.scratch.add(c)