From bc420fcc6b93b9aac660554a02028f0e30a01b75 Mon Sep 17 00:00:00 2001 From: Yi Duan Date: Tue, 6 Aug 2024 19:03:55 +0800 Subject: [PATCH] fix: didn't consider `json.Marshaler/Unmarshal` when handling `json:",string"` tag (#682) Co-authored-by: liuqiang.06 --- .../decoder/jitdec/assembler_regabi_amd64.go | 29 +++- internal/decoder/jitdec/compiler.go | 77 +++++++-- .../decoder/jitdec/generic_regabi_amd64.go | 1 + internal/decoder/jitdec/primitives.go | 7 + internal/decoder/optdec/compile_struct.go | 39 ++++- internal/decoder/optdec/compiler.go | 14 +- internal/decoder/optdec/interface.go | 20 ++- internal/encoder/compiler.go | 25 ++- issue_test/issue670_test.go | 147 ++++++++++++++++++ 9 files changed, 326 insertions(+), 33 deletions(-) create mode 100644 issue_test/issue670_test.go diff --git a/internal/decoder/jitdec/assembler_regabi_amd64.go b/internal/decoder/jitdec/assembler_regabi_amd64.go index 1daddad0e..3a2b718e9 100644 --- a/internal/decoder/jitdec/assembler_regabi_amd64.go +++ b/internal/decoder/jitdec/assembler_regabi_amd64.go @@ -972,11 +972,13 @@ var ( var ( _F_decodeJsonUnmarshaler obj.Addr + _F_decodeJsonUnmarshalerQuoted obj.Addr _F_decodeTextUnmarshaler obj.Addr ) func init() { _F_decodeJsonUnmarshaler = jit.Func(decodeJsonUnmarshaler) + _F_decodeJsonUnmarshalerQuoted = jit.Func(decodeJsonUnmarshalerQuoted) _F_decodeTextUnmarshaler = jit.Func(decodeTextUnmarshaler) } @@ -1061,14 +1063,15 @@ var ( _F_skip_number = jit.Imm(int64(native.S_skip_number)) ) -func (self *_Assembler) unmarshal_json(t reflect.Type, deref bool) { +func (self *_Assembler) unmarshal_json(t reflect.Type, deref bool, f obj.Addr) { self.call_sf(_F_skip_one) // CALL_SF skip_one self.Emit("TESTQ", _AX, _AX) // TESTQ AX, AX self.Sjmp("JS" , _LB_parsing_error_v) // JS _parse_error_v + self.Emit("MOVQ", _IC, _VAR_ic) // store for mismatche error skip self.slice_from_r(_AX, 0) // SLICE_R AX, $0 self.Emit("MOVQ" , _DI, _ARG_sv_p) // MOVQ DI, sv.p self.Emit("MOVQ" , _SI, _ARG_sv_n) // MOVQ SI, sv.n - self.unmarshal_func(t, _F_decodeJsonUnmarshaler, deref) // UNMARSHAL json, ${t}, ${deref} + self.unmarshal_func(t, f, deref) // UNMARSHAL json, ${t}, ${deref} } func (self *_Assembler) unmarshal_text(t reflect.Type, deref bool) { @@ -1103,7 +1106,15 @@ func (self *_Assembler) unmarshal_func(t reflect.Type, fn obj.Addr, deref bool) self.Emit("MOVQ" , _ARG_sv_n, _DI) // MOVQ sv.n, DI self.call_go(fn) // CALL_GO ${fn} self.Emit("TESTQ", _ET, _ET) // TESTQ ET, ET - self.Sjmp("JNZ" , _LB_error) // JNZ _error + self.Sjmp("JZ" , "_unmarshal_func_end_{n}") // JNZ _error + self.Emit("MOVQ", _I_json_MismatchTypeError, _CX) // MOVQ ET, VAR.et + self.Emit("CMPQ", _ET, _CX) // check if MismatchedError + self.Sjmp("JNE" , _LB_error) + self.Emit("MOVQ", jit.Type(t), _CX) // store current type + self.Emit("MOVQ", _CX, _VAR_et) // store current type + self.Emit("MOVQ", _VAR_ic, _IC) // recover the pos + self.Emit("XORL", _ET, _ET) + self.Link("_unmarshal_func_end_{n}") } /** Dynamic Decoding Routine **/ @@ -1774,11 +1785,19 @@ func (self *_Assembler) _asm_OP_struct_field(p *_Instr) { } func (self *_Assembler) _asm_OP_unmarshal(p *_Instr) { - self.unmarshal_json(p.vt(), true) + if iv := p.i64(); iv != 0 { + self.unmarshal_json(p.vt(), true, _F_decodeJsonUnmarshalerQuoted) + } else { + self.unmarshal_json(p.vt(), true, _F_decodeJsonUnmarshaler) + } } func (self *_Assembler) _asm_OP_unmarshal_p(p *_Instr) { - self.unmarshal_json(p.vt(), false) + if iv := p.i64(); iv != 0 { + self.unmarshal_json(p.vt(), false, _F_decodeJsonUnmarshalerQuoted) + } else { + self.unmarshal_json(p.vt(), false, _F_decodeJsonUnmarshaler) + } } func (self *_Assembler) _asm_OP_unmarshal_text(p *_Instr) { diff --git a/internal/decoder/jitdec/compiler.go b/internal/decoder/jitdec/compiler.go index f61105bc2..2ad3f6d82 100644 --- a/internal/decoder/jitdec/compiler.go +++ b/internal/decoder/jitdec/compiler.go @@ -271,6 +271,13 @@ func newInsVt(op _Op, vt reflect.Type) _Instr { } } +func newInsVtI(op _Op, vt reflect.Type, iv int) _Instr { + return _Instr { + u: packOp(op) | rt.PackInt(iv), + p: unsafe.Pointer(rt.UnpackType(vt)), + } +} + func newInsVf(op _Op, vf *caching.FieldMap) _Instr { return _Instr { u: packOp(op), @@ -452,6 +459,10 @@ func (self *_Program) rtt(op _Op, vt reflect.Type) { *self = append(*self, newInsVt(op, vt)) } +func (self *_Program) rtti(op _Op, vt reflect.Type, iv int) { + *self = append(*self, newInsVtI(op, vt, iv)) +} + func (self *_Program) fmv(op _Op, vf *caching.FieldMap) { *self = append(*self, newInsVf(op, vf)) } @@ -527,35 +538,54 @@ func (self *_Compiler) compile(vt reflect.Type) (ret _Program, err error) { return } -func (self *_Compiler) checkMarshaler(p *_Program, vt reflect.Type) bool { +const ( + checkMarshalerFlags_quoted = 1 +) + +func (self *_Compiler) checkMarshaler(p *_Program, vt reflect.Type, flags int, exec bool) bool { pt := reflect.PtrTo(vt) /* check for `json.Unmarshaler` with pointer receiver */ if pt.Implements(jsonUnmarshalerType) { - p.rtt(_OP_unmarshal_p, pt) + if exec { + p.add(_OP_lspace) + p.rtti(_OP_unmarshal_p, pt, flags) + } return true } /* check for `json.Unmarshaler` */ if vt.Implements(jsonUnmarshalerType) { - p.add(_OP_lspace) - self.compileUnmarshalJson(p, vt) + if exec { + p.add(_OP_lspace) + self.compileUnmarshalJson(p, vt, flags) + } return true } + if flags == checkMarshalerFlags_quoted { + // text marshaler shouldn't be supported for quoted string + return false + } + /* check for `encoding.TextMarshaler` with pointer receiver */ if pt.Implements(encodingTextUnmarshalerType) { - p.add(_OP_lspace) - self.compileUnmarshalTextPtr(p, pt) + if exec { + p.add(_OP_lspace) + self.compileUnmarshalTextPtr(p, pt, flags) + } return true } /* check for `encoding.TextUnmarshaler` */ if vt.Implements(encodingTextUnmarshalerType) { - p.add(_OP_lspace) - self.compileUnmarshalText(p, vt) + if exec { + p.add(_OP_lspace) + self.compileUnmarshalText(p, vt, flags) + } return true } + return false } @@ -567,7 +597,7 @@ func (self *_Compiler) compileOne(p *_Program, sp int, vt reflect.Type) { return } - if self.checkMarshaler(p, vt) { + if self.checkMarshaler(p, vt, 0, true) { return } @@ -690,7 +720,7 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) { /* dereference all the way down */ for et.Kind() == reflect.Ptr { - if self.checkMarshaler(p, et) { + if self.checkMarshaler(p, et, 0, true) { return } et = et.Elem() @@ -938,7 +968,22 @@ end_of_object: p.pin(skip) } +func (self *_Compiler) compileStructFieldStrUnmarshal(p *_Program, vt reflect.Type) { + p.add(_OP_lspace) + n0 := p.pc() + p.add(_OP_is_null) + self.checkMarshaler(p, vt, checkMarshalerFlags_quoted, true) + p.pin(n0) +} + func (self *_Compiler) compileStructFieldStr(p *_Program, sp int, vt reflect.Type) { + // according to std, json.Unmarshaler should be called before stringize + // see https://github.com/bytedance/sonic/issues/670 + if self.checkMarshaler(p, vt, checkMarshalerFlags_quoted, false) { + self.compileStructFieldStrUnmarshal(p, vt) + return + } + n1 := -1 ft := vt sv := false @@ -1106,7 +1151,7 @@ func (self *_Compiler) compileUnmarshalEnd(p *_Program, vt reflect.Type, i int) p.pin(j) } -func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) { +func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type, flags int) { i := p.pc() v := _OP_unmarshal p.add(_OP_is_null) @@ -1117,11 +1162,11 @@ func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) { } /* call the unmarshaler */ - p.rtt(v, vt) + p.rtti(v, vt, flags) self.compileUnmarshalEnd(p, vt, i) } -func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) { +func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type, iv int) { i := p.pc() v := _OP_unmarshal_text p.add(_OP_is_null) @@ -1134,15 +1179,15 @@ func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) { } /* call the unmarshaler */ - p.rtt(v, vt) + p.rtti(v, vt, iv) self.compileUnmarshalEnd(p, vt, i) } -func (self *_Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type) { +func (self *_Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type, iv int) { i := p.pc() p.add(_OP_is_null) p.chr(_OP_match_char, '"') - p.rtt(_OP_unmarshal_text_p, vt) + p.rtti(_OP_unmarshal_text_p, vt, iv) p.pin(i) } diff --git a/internal/decoder/jitdec/generic_regabi_amd64.go b/internal/decoder/jitdec/generic_regabi_amd64.go index e6d5e3e84..2c21944a5 100644 --- a/internal/decoder/jitdec/generic_regabi_amd64.go +++ b/internal/decoder/jitdec/generic_regabi_amd64.go @@ -186,6 +186,7 @@ var ( _T_slice = jit.Type(reflect.TypeOf(([]interface{})(nil))) _T_string = jit.Type(reflect.TypeOf("")) _T_number = jit.Type(reflect.TypeOf(json.Number(""))) + _T_miserr = jit.Type(reflect.TypeOf(MismatchTypeError{})) _T_float64 = jit.Type(reflect.TypeOf(float64(0))) ) diff --git a/internal/decoder/jitdec/primitives.go b/internal/decoder/jitdec/primitives.go index ba865dc7b..5adfc038a 100644 --- a/internal/decoder/jitdec/primitives.go +++ b/internal/decoder/jitdec/primitives.go @@ -39,6 +39,13 @@ func decodeJsonUnmarshaler(vv interface{}, s string) error { return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(s)) } +func decodeJsonUnmarshalerQuoted(vv interface{}, s string) error { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return &MismatchTypeError{} + } + return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(s[1:len(s)-1])) +} + func decodeTextUnmarshaler(vv interface{}, s string) error { return vv.(encoding.TextUnmarshaler).UnmarshalText(rt.Str2Mem(s)) } diff --git a/internal/decoder/optdec/compile_struct.go b/internal/decoder/optdec/compile_struct.go index 51552a287..713fb6561 100644 --- a/internal/decoder/optdec/compile_struct.go +++ b/internal/decoder/optdec/compile_struct.go @@ -39,7 +39,43 @@ func (c *compiler) compileIntStringOption(vt reflect.Type) decFunc { panic("unreachable") } +func isInteger(vt reflect.Type) bool { + switch vt.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uintptr, reflect.Int: return true + default: return false + } +} + +func (c *compiler) assertStringOptTypes(vt reflect.Type) { + if c.depth > _CompileMaxDepth { + panic(*stackOverflow) + } + + c.depth += 1 + defer func () { + c.depth -= 1 + }() + + if isInteger(vt) { + return + } + + switch vt.Kind() { + case reflect.String, reflect.Bool, reflect.Float32, reflect.Float64: + return + case reflect.Ptr: c.assertStringOptTypes(vt.Elem()) + default: + panicForInvalidStrType(vt) + } +} + func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc { + c.assertStringOptTypes(vt) + unmDec := c.tryCompilePtrUnmarshaler(vt, true) + if unmDec != nil { + return unmDec + } + switch vt.Kind() { case reflect.String: if vt == jsonNumberType { @@ -80,7 +116,8 @@ func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc { deref: c.compileFieldStringOption(vt.Elem()), } default: - panic("string options should appliy only to fields of string, floating point, integer, or boolean types.") + panicForInvalidStrType(vt) + return nil } } diff --git a/internal/decoder/optdec/compiler.go b/internal/decoder/optdec/compiler.go index bb47f91f8..fd164af93 100644 --- a/internal/decoder/optdec/compiler.go +++ b/internal/decoder/optdec/compiler.go @@ -34,7 +34,6 @@ type compiler struct { counts int opts option.CompileOptions namedPtr bool - } func newCompiler() *compiler { @@ -114,7 +113,7 @@ func (c *compiler) compile(vt reflect.Type) decFunc { } } - dec := c.tryCompilePtrUnmarshaler(vt) + dec := c.tryCompilePtrUnmarshaler(vt, false) if dec != nil { return dec } @@ -420,18 +419,23 @@ func (c *compiler) compileMapKey(vt reflect.Type) decKey { } // maybe vt is a named type, and not a pointer receiver, see issue 379 -func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type) decFunc { +func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type, strOpt bool) decFunc { pt := reflect.PtrTo(vt) /* check for `json.Unmarshaler` with pointer receiver */ if pt.Implements(jsonUnmarshalerType) { return &unmarshalJSONDecoder{ typ: rt.UnpackType(pt), + strOpt: strOpt, } } /* check for `encoding.TextMarshaler` with pointer receiver */ if pt.Implements(encodingTextUnmarshalerType) { + /* TextUnmarshal not support ,strig tag */ + if strOpt { + panicForInvalidStrType(vt) + } return &unmarshalTextDecoder{ typ: rt.UnpackType(pt), } @@ -439,3 +443,7 @@ func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type) decFunc { return nil } + +func panicForInvalidStrType(vt reflect.Type) { + panic(error_type(rt.UnpackType(vt))) +} diff --git a/internal/decoder/optdec/interface.go b/internal/decoder/optdec/interface.go index b96d3fb1c..0c063d55f 100644 --- a/internal/decoder/optdec/interface.go +++ b/internal/decoder/optdec/interface.go @@ -131,7 +131,8 @@ func (d *unmarshalTextDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *contex } type unmarshalJSONDecoder struct { - typ *rt.GoType + typ *rt.GoType + strOpt bool } func (d *unmarshalJSONDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *context) error { @@ -140,15 +141,28 @@ func (d *unmarshalJSONDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *contex Value: vp, })) + var input []byte + if d.strOpt && node.IsNull() { + input = []byte("null") + } else if d.strOpt { + s, ok := node.AsStringText(ctx) + if !ok { + return error_mismatch(node, ctx, d.typ.Pack()) + } + input = s + } else { + input = []byte(node.AsRaw(ctx)) + } + // fast path if u, ok := v.(json.Unmarshaler); ok { - return u.UnmarshalJSON([]byte(node.AsRaw(ctx))) + return u.UnmarshalJSON((input)) } // slow path rv := reflect.ValueOf(v) if u, ok := rv.Interface().(json.Unmarshaler); ok { - return u.UnmarshalJSON([]byte(node.AsRaw(ctx))) + return u.UnmarshalJSON(input) } return error_type(d.typ) diff --git a/internal/encoder/compiler.go b/internal/encoder/compiler.go index 034e1d17d..902fbc98b 100644 --- a/internal/encoder/compiler.go +++ b/internal/encoder/compiler.go @@ -127,31 +127,40 @@ func (self *Compiler) compileOne(p *ir.Program, sp int, vt reflect.Type, pv bool } } -func (self *Compiler) compileRec(p *ir.Program, sp int, vt reflect.Type, pv bool) { - pr := self.pv +func (self *Compiler) tryCompileMarshaler(p *ir.Program, vt reflect.Type, pv bool) bool { pt := reflect.PtrTo(vt) /* check for addressable `json.Marshaler` with pointer receiver */ if pv && pt.Implements(vars.JsonMarshalerType) { addMarshalerOp(p, ir.OP_marshal_p, pt, vars.JsonMarshalerType) - return + return true } /* check for `json.Marshaler` */ if vt.Implements(vars.JsonMarshalerType) { self.compileMarshaler(p, ir.OP_marshal, vt, vars.JsonMarshalerType) - return + return true } /* check for addressable `encoding.TextMarshaler` with pointer receiver */ if pv && pt.Implements(vars.EncodingTextMarshalerType) { addMarshalerOp(p, ir.OP_marshal_text_p, pt, vars.EncodingTextMarshalerType) - return + return true } /* check for `encoding.TextMarshaler` */ if vt.Implements(vars.EncodingTextMarshalerType) { self.compileMarshaler(p, ir.OP_marshal_text, vt, vars.EncodingTextMarshalerType) + return true + } + + return false +} + +func (self *Compiler) compileRec(p *ir.Program, sp int, vt reflect.Type, pv bool) { + pr := self.pv + + if self.tryCompileMarshaler(p, vt, pv) { return } @@ -485,6 +494,12 @@ func (self *Compiler) compileStructBody(p *ir.Program, sp int, vt reflect.Type) } func (self *Compiler) compileStructFieldStr(p *ir.Program, sp int, vt reflect.Type) { + // NOTICE: according to encoding/json, Marshaler type has higher priority than string option + // see issue: + if self.tryCompileMarshaler(p, vt, self.pv) { + return + } + pc := -1 ft := vt sv := false diff --git a/issue_test/issue670_test.go b/issue_test/issue670_test.go new file mode 100644 index 000000000..f4605ab8a --- /dev/null +++ b/issue_test/issue670_test.go @@ -0,0 +1,147 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package issue_test + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" +) + +func TestIssue670_JSONMarshaler(t *testing.T) { + var obj = Issue670JSONMarshaler{ D: Date(time.Now().Unix()) } + so, _ := sonic.MarshalString(obj) + eo, _ := json.Marshal(obj) + assert.Equal(t, string(eo), so) + println(string(eo)) +} + +func TestIssue670_JSONUnmarshaler(t *testing.T) { + // match + eo := []byte(`{"D":"2021-08-26","E":1}`) + et := reflect.TypeOf(Issue670JSONMarshaler{}) + testUnmarshal(t, eo, et, true) + + // mismatch + eo = []byte(`{"D":11,"E":1}`) + testUnmarshal(t, eo, et, true) + + // null + eo = []byte(`{"D":null,"E":1}`) + testUnmarshal(t, eo, et, true) +} + +func testUnmarshal(t *testing.T, eo []byte, rt reflect.Type, checkobj bool) { + obj := reflect.New(rt).Interface() + println(string(eo)) + println("sonic") + es := sonic.Unmarshal(eo, obj) + obj2 := reflect.New(rt).Interface() + println("std") + ee := json.Unmarshal(eo, obj2) + assert.Equal(t, ee ==nil, es == nil, es) + if checkobj { + assert.Equal(t, obj2, obj) + } + fmt.Printf("std: %v, obj: %#v", ee, obj2) + fmt.Printf("sonic error: %v, obj: %#v", es, obj) +} + +func TestIssue670_TextMarshaler(t *testing.T) { + var obj = Issue670TextMarshaler{ D: int(time.Now().Unix()) } + so, _ := sonic.MarshalString(obj) + eo, _ := json.Marshal(obj) + assert.Equal(t, string(eo), so) + println(string(eo)) +} + +func TestIssue670_TextUnmarshaler(t *testing.T) { + // match + eo := []byte(`{"D":"2021-08-26","E":1}`) + et := reflect.TypeOf(Issue670TextMarshaler{}) + testUnmarshal(t, eo, et, false) + + // mismatch + eo = []byte(`{"D":11,"E":1}`) + testUnmarshal(t, eo, et, false) + + // null + eo = []byte(`{"D":null,"E":1}`) + testUnmarshal(t, eo, et, true) +} + +type Issue670JSONMarshaler struct { + D Date `form:"D" json:"D,string" query:"D"` + E int +} + +type Date int64 + +func (d Date) MarshalJSON() ([]byte, error) { + if d == 0 { + return []byte("null"), nil + } + return []byte(fmt.Sprintf("\"%s\"", time.Unix(int64(d), 0).Format("2006-01-02"))), nil +} + +func (d *Date) UnmarshalJSON(in []byte) error { + if string(in) == "null" { + *d = 0 + return nil + } + + println("hook ", string(in)) + t, err := time.Parse("2006-01-02", string(in)) + if err != nil { + return err + } + *d = Date(t.Unix()) + return nil +} + +type Issue670TextMarshaler struct { + D int `form:"D" json:"D,string" query:"D"` + E int +} + + +type Date2 int64 + +func (d Date2) MarshalText() ([]byte, error) { + println("hook 1") + if d == 0 { + return []byte("null"), nil + } + return []byte(fmt.Sprintf("\"%s\"", time.Unix(int64(d), 0).Format("2006-01-02"))), nil +} + +func (d *Date2) UnmarshalText(in []byte) error { + println("hook 2", string(in)) + if string(in) == "null" { + *d = 0 + return nil + } + t, err := time.Parse("2006-01-02", string(in)) + if err != nil { + return err + } + *d = Date2(t.Unix()) + return nil +} \ No newline at end of file