Skip to content

Commit

Permalink
Add support for (Un)MarshalAminoJSON override (#323)
Browse files Browse the repository at this point in the history
* Add support for (Un)MarshalAminoJSON override

* Put override higher

* Update codec.go

Co-authored-by: Aaron Craelius <[email protected]>

* Remove dup code

* Address reviews

* Add Changelog entry

* Fix build

Co-authored-by: Aaron Craelius <[email protected]>
  • Loading branch information
amaury1093 and aaronc authored Sep 11, 2020
1 parent af63998 commit ccb15b1
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 7 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

## [Unreleased]

## 0.16.0 (September 11, 2020)

IMPROVEMENTS:
- Add support for `Un/MarshalAminoJSON` override: if a type implements
`Un/MarshalAminoJSON`, then amino will use these methods for JSON un/marshalling
([#323]).

[#323]: https://github.com/tendermint/go-amino/pull/323

## 0.15.1 (October 10, 2019)

### IMPROVEMENTS:
Expand Down
2 changes: 1 addition & 1 deletion binary-encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (cdc *Codec) encodeReflectBinary(w io.Writer, info *TypeInfo, rv reflect.Va
}()
}

// Handle override if rv implements json.Marshaler.
// Handle override if rv implements MarshalAmino.
if info.IsAminoMarshaler {
// First, encode rv into repr instance.
var rrv, rinfo = reflect.Value{}, (*TypeInfo)(nil)
Expand Down
59 changes: 55 additions & 4 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,14 @@ type ConcreteInfo struct {

// These fields get set for all concrete types,
// even those not manually registered (e.g. are never interface values).
IsAminoMarshaler bool // Implements MarshalAmino() (<ReprObject>, error).
AminoMarshalReprType reflect.Type // <ReprType>
IsAminoUnmarshaler bool // Implements UnmarshalAmino(<ReprObject>) (error).
AminoUnmarshalReprType reflect.Type // <ReprType>
IsAminoMarshaler bool // Implements MarshalAmino() (<ReprObject>, error).
AminoMarshalReprType reflect.Type // <ReprType>
IsAminoUnmarshaler bool // Implements UnmarshalAmino(<ReprObject>) (error).
AminoUnmarshalReprType reflect.Type // <ReprType>
IsAminoJSONMarshaler bool // Implements MarshalAminoJSON() (<ReprObject>, error).
AminoJSONMarshalReprType reflect.Type // <ReprType>
IsAminoJSONUnmarshaler bool // Implements UnmarshalAminoJSON(<ReprObject>) (error).
AminoJSONUnmarshalReprType reflect.Type // <ReprType>
}

type StructInfo struct {
Expand Down Expand Up @@ -566,6 +570,14 @@ func (cdc *Codec) newTypeInfoUnregistered(rt reflect.Type) *TypeInfo {
info.ConcreteInfo.IsAminoUnmarshaler = true
info.ConcreteInfo.AminoUnmarshalReprType = unmarshalAminoReprType(rm)
}
if rm, ok := rt.MethodByName("MarshalAminoJSON"); ok {
info.ConcreteInfo.IsAminoJSONMarshaler = true
info.ConcreteInfo.AminoJSONMarshalReprType = marshalAminoJSONReprType(rm)
}
if rm, ok := reflect.PtrTo(rt).MethodByName("UnmarshalAminoJSON"); ok {
info.ConcreteInfo.IsAminoJSONUnmarshaler = true
info.ConcreteInfo.AminoJSONUnmarshalReprType = unmarshalAminoJSONReprType(rm)
}
return info
}

Expand Down Expand Up @@ -804,3 +816,42 @@ func unmarshalAminoReprType(rm reflect.Method) (rrt reflect.Type) {
}
return
}

func marshalAminoJSONReprType(rm reflect.Method) (rrt reflect.Type) {
// Verify form of this method.
if rm.Type.NumIn() != 1 {
panic(fmt.Sprintf("MarshalAminoJSON should have 1 input parameters (including receiver); got %v", rm.Type))
}
if rm.Type.NumOut() != 2 {
panic(fmt.Sprintf("MarshalAminoJSON should have 2 output parameters; got %v", rm.Type))
}
if out := rm.Type.Out(1); out != errorType {
panic(fmt.Sprintf("MarshalAminoJSON should have second output parameter of error type, got %v", out))
}
rrt = rm.Type.Out(0)
if rrt.Kind() == reflect.Ptr {
panic(fmt.Sprintf("Representative objects cannot be pointers; got %v", rrt))
}
return
}

func unmarshalAminoJSONReprType(rm reflect.Method) (rrt reflect.Type) {
// Verify form of this method.
if rm.Type.NumIn() != 2 {
panic(fmt.Sprintf("UnmarshalAminoJSON should have 2 input parameters (including receiver); got %v", rm.Type))
}
if in1 := rm.Type.In(0); in1.Kind() != reflect.Ptr {
panic(fmt.Sprintf("UnmarshalAminoJSON first input parameter should be pointer type but got %v", in1))
}
if rm.Type.NumOut() != 1 {
panic(fmt.Sprintf("UnmarshalAminoJSON should have 1 output parameters; got %v", rm.Type))
}
if out := rm.Type.Out(0); out != errorType {
panic(fmt.Sprintf("UnmarshalAminoJSON should have first output parameter of error type, got %v", out))
}
rrt = rm.Type.In(1)
if rrt.Kind() == reflect.Ptr {
panic(fmt.Sprintf("Representative objects cannot be pointers; got %v", rrt))
}
return
}
25 changes: 24 additions & 1 deletion json-decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,29 @@ func (cdc *Codec) decodeReflectJSON(bz []byte, info *TypeInfo, rv reflect.Value,
}
}

// Handle override if a pointer to rv implements UnmarshalAminoJSON.
if info.IsAminoJSONUnmarshaler {
// First, decode repr instance from JSON.
rrv := reflect.New(info.AminoJSONUnmarshalReprType).Elem()
var rinfo *TypeInfo
rinfo, err = cdc.getTypeInfo_wlock(info.AminoJSONUnmarshalReprType)
if err != nil {
return
}
err = cdc.decodeReflectJSON(bz, rinfo, rrv, fopts)
if err != nil {
return
}
// Then, decode from repr instance.
uwrm := rv.Addr().MethodByName("UnmarshalAminoJSON")
uwouts := uwrm.Call([]reflect.Value{rrv})
erri := uwouts[0].Interface()
if erri != nil {
err = erri.(error)
}
return
}

// Handle override if a pointer to rv implements json.Unmarshaler.
if rv.Addr().Type().Implements(jsonUnmarshalerType) {
err = rv.Addr().Interface().(json.Unmarshaler).UnmarshalJSON(bz)
Expand Down Expand Up @@ -401,7 +424,7 @@ func (cdc *Codec) decodeReflectJSONStruct(bz []byte, info *TypeInfo, rv reflect.
// Set nil/zero on frv.
frv.Set(reflect.Zero(frv.Type()))
}

continue
}

Expand Down
23 changes: 22 additions & 1 deletion json-encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ func (cdc *Codec) encodeReflectJSON(w io.Writer, info *TypeInfo, rv reflect.Valu
ct := rv.Interface().(time.Time).Round(0).UTC()
rv = reflect.ValueOf(ct)
}

// Handle override if rv implements MarshalAminoJSON.
if info.IsAminoJSONMarshaler {
// First, encode rv into repr instance.
var (
rrv reflect.Value
rinfo *TypeInfo
)
rrv, err = toReprJSONObject(rv)
if err != nil {
return
}
rinfo, err = cdc.getTypeInfo_wlock(info.AminoJSONMarshalReprType)
if err != nil {
return
}
// Then, encode the repr instance.
err = cdc.encodeReflectJSON(w, rinfo, rrv, fopts)
return
}

// Handle override if rv implements json.Marshaler.
if rv.CanAddr() { // Try pointer first.
if rv.Addr().Type().Implements(jsonMarshalerType) {
Expand All @@ -59,7 +80,7 @@ func (cdc *Codec) encodeReflectJSON(w io.Writer, info *TypeInfo, rv reflect.Valu
return
}

// Handle override if rv implements json.Marshaler.
// Handle override if rv implements MarshalAmino.
if info.IsAminoMarshaler {
// First, encode rv into repr instance.
var rrv, rinfo = reflect.Value{}, (*TypeInfo)(nil)
Expand Down
19 changes: 19 additions & 0 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,22 @@ func toReprObject(rv reflect.Value) (rrv reflect.Value, err error) {
rrv = mwouts[0]
return
}

func toReprJSONObject(rv reflect.Value) (rrv reflect.Value, err error) {
var mwrm reflect.Value
if rv.CanAddr() {
mwrm = rv.Addr().MethodByName("MarshalAminoJSON")
} else {
mwrm = rv.MethodByName("MarshalAminoJSON")
}
mwouts := mwrm.Call(nil)
if !mwouts[1].IsNil() {
erri := mwouts[1].Interface()
if erri != nil {
err = erri.(error)
return rrv, err
}
}
rrv = mwouts[0]
return
}
51 changes: 51 additions & 0 deletions repr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,54 @@ func TestMarshalAminoJSON(t *testing.T) {
assert.Equal(t, f, f2)
assert.Equal(t, f.a, f2.a) // In case the above doesn't check private fields?
}

type Bar struct {
a string
b int
c []*Bar
D string // exposed
}

func (b Bar) MarshalAminoJSON() ([]pair, error) { // nolint: golint
return []pair{
{"a", b.a},
{"b", b.b},
{"c", b.c},
{"D", b.D},
}, nil
}

func (b *Bar) UnmarshalAminoJSON(repr []pair) error {
b.a = repr[0].get("a").(string)
b.b = repr[1].get("b").(int)
b.c = repr[2].get("c").([]*Bar)
b.D = repr[3].get("D").(string)
return nil
}

func TestMarshalAminoJSON_Override(t *testing.T) {

cdc := NewCodec()
cdc.RegisterInterface((*interface{})(nil), nil)
cdc.RegisterConcrete(string(""), "string", nil)
cdc.RegisterConcrete(int(0), "int", nil)
cdc.RegisterConcrete(([]*Bar)(nil), "[]*Bar", nil)

var f = Bar{
a: "K",
b: 2,
c: []*Bar{nil, nil, nil},
D: "J",
}
bz, err := cdc.MarshalJSON(f)
assert.Nil(t, err)

t.Logf("bz %X", bz)

var f2 Bar
err = cdc.UnmarshalJSON(bz, &f2)
assert.Nil(t, err)

assert.Equal(t, f, f2)
assert.Equal(t, f.a, f2.a) // In case the above doesn't check private fields?
}

0 comments on commit ccb15b1

Please sign in to comment.