From ae9b6c410d465de6f54c933036dbc4b14241e391 Mon Sep 17 00:00:00 2001 From: Eric Chlebek Date: Wed, 22 Nov 2017 10:01:48 -0800 Subject: [PATCH] Add Unmarshal function. Unexport GetExtendedAttributes. This change makes the package significantly easier to use, and also presents a greater symmetry between its types and functions. This comes at a cost of additional reflection, which can be seen in the benchmark for Unmarshal. Old: BenchmarkUnmarshal-4 6492 ns/op 6568 B/op 57 allocs/op New: BenchmarkUnmarshal-4 9244 ns/op 7592 B/op 77 allocs/op Refs #586 --- types/dynamic/dynamic.go | 74 ++++++++++++++++++++++++++++------- types/dynamic/dynamic_test.go | 73 ++++++++++++++++------------------ 2 files changed, 93 insertions(+), 54 deletions(-) diff --git a/types/dynamic/dynamic.go b/types/dynamic/dynamic.go index fc349f6cb9..4f27e26018 100644 --- a/types/dynamic/dynamic.go +++ b/types/dynamic/dynamic.go @@ -1,6 +1,7 @@ package dynamic import ( + "encoding/json" "fmt" "reflect" "sort" @@ -9,11 +10,16 @@ import ( jsoniter "github.com/json-iterator/go" ) -// ExtendedAttributer is for use with GetField. It allows GetField to access -// serialized extended attributes. -type ExtendedAttributer interface { - // ExtendedAttributes returns json-serialized extended attributes. - ExtendedAttributes() []byte +// Attributes hold arbitrary JSON-encoded data. +type Attributes struct { + data []byte +} + +// Implement Attributer to enable a type to work with the Marshal and Unmarshal +// functions in this package. +type Attributer interface { + Attributes() Attributes + SetAttributes(Attributes) } // GetField gets a field from v according to its name. @@ -21,7 +27,7 @@ type ExtendedAttributer interface { // it will try to dynamically find the corresponding item in the 'Extended' // field. GetField is case-sensitive, but extended attribute names will be // converted to CamelCaps. -func GetField(v ExtendedAttributer, name string) (interface{}, error) { +func GetField(v Attributer, name string) (interface{}, error) { strukt := reflect.Indirect(reflect.ValueOf(v)) if kind := strukt.Kind(); kind != reflect.Struct { return nil, fmt.Errorf("invalid type (want struct): %v", kind) @@ -32,7 +38,7 @@ func GetField(v ExtendedAttributer, name string) (interface{}, error) { return field.Value.Interface(), nil } // If we get here, we are dealing with extended attributes. - return getExtendedAttribute(v.ExtendedAttributes(), name) + return getExtendedAttribute(v.Attributes().data, name) } // getExtendedAttribute dynamically builds a concrete type. If the concrete @@ -189,19 +195,19 @@ func getJSONFields(v reflect.Value) map[string]structField { return result } -// ExtractExtendedAttributes selects only extended attributes from msg. It will +// extractExtendedAttributes selects only extended attributes from msg. It will // ignore any fields in msg that correspond to fields in v. v must be of kind // reflect.Struct. -func ExtractExtendedAttributes(v interface{}, msg []byte) ([]byte, error) { +func extractExtendedAttributes(v interface{}, msg []byte) (Attributes, error) { strukt := reflect.Indirect(reflect.ValueOf(v)) if kind := strukt.Kind(); kind != reflect.Struct { - return nil, fmt.Errorf("invalid type (want struct): %v", kind) + return Attributes{}, fmt.Errorf("invalid type (want struct): %v", kind) } fields := getJSONFields(strukt) stream := jsoniter.NewStream(jsoniter.ConfigDefault, nil, 4096) var anys map[string]jsoniter.Any if err := jsoniter.Unmarshal(msg, &anys); err != nil { - return nil, err + return Attributes{}, err } stream.WriteObjectStart() j := 0 @@ -219,14 +225,54 @@ func ExtractExtendedAttributes(v interface{}, msg []byte) ([]byte, error) { any.WriteTo(stream) } stream.WriteObjectEnd() - return stream.Buffer(), nil + return Attributes{data: stream.Buffer()}, nil +} + +// Unmarshal decodes msg into v, storing what fields it can into the basic +// fields of the struct, and storing the rest into Attributes. +func Unmarshal(msg []byte, v Attributer) error { + if _, ok := v.(json.Unmarshaler); ok { + // Can't safely call UnmarshalJSON here without potentially causing an + // infinite recursion. Copy the struct into a new type that doesn't + // implement the method. + oldVal := reflect.Indirect(reflect.ValueOf(v)) + typ := oldVal.Type() + numField := typ.NumField() + fields := make([]reflect.StructField, 0, numField) + for i := 0; i < numField; i++ { + field := typ.Field(i) + if len(field.PkgPath) == 0 { + fields = append(fields, field) + } + } + newType := reflect.StructOf(fields) + newPtr := reflect.New(newType) + newVal := reflect.Indirect(newPtr) + if err := json.Unmarshal(msg, newPtr.Interface()); err != nil { + return err + } + for _, field := range fields { + oldVal.FieldByName(field.Name).Set(newVal.FieldByName(field.Name)) + } + } else { + if err := json.Unmarshal(msg, v); err != nil { + return err + } + } + + attrs, err := extractExtendedAttributes(v, msg) + if err != nil { + return err + } + v.SetAttributes(attrs) + return nil } // Marshal encodes the struct fields in v that are valid to encode. // It also encodes any extended attributes that are defined. Marshal // respects the encoding/json rules regarding exported fields, and tag // semantics. If v's kind is not reflect.Struct, an error will be returned. -func Marshal(v ExtendedAttributer) ([]byte, error) { +func Marshal(v Attributer) ([]byte, error) { s := jsoniter.NewStream(jsoniter.ConfigDefault, nil, 4096) s.WriteObjectStart() @@ -234,7 +280,7 @@ func Marshal(v ExtendedAttributer) ([]byte, error) { return nil, err } - extended := v.ExtendedAttributes() + extended := v.Attributes().data if len(extended) > 0 { if err := encodeExtendedFields(extended, s); err != nil { return nil, err diff --git a/types/dynamic/dynamic_test.go b/types/dynamic/dynamic_test.go index 71a5f0345d..f90a491285 100644 --- a/types/dynamic/dynamic_test.go +++ b/types/dynamic/dynamic_test.go @@ -84,34 +84,27 @@ type MyType struct { Foo string `json:"foo"` Bar []MyType `json:"bar"` - extended []byte + attrs Attributes } -func (m MyType) ExtendedAttributes() []byte { - return m.extended +func (m *MyType) Attributes() Attributes { + return m.attrs } -func (m MyType) Get(name string) (interface{}, error) { +func (m *MyType) SetAttributes(a Attributes) { + m.attrs = a +} + +func (m *MyType) Get(name string) (interface{}, error) { return GetField(m, name) } -func (m MyType) MarshalJSON() ([]byte, error) { +func (m *MyType) MarshalJSON() ([]byte, error) { return Marshal(m) } func (m *MyType) UnmarshalJSON(p []byte) error { - type temporary MyType - var x temporary - if err := json.Unmarshal(p, &x); err != nil { - return err - } - *m = MyType(x) - extended, err := ExtractExtendedAttributes(m, p) - if err != nil { - return err - } - m.extended = extended - return nil + return Unmarshal(p, m) } func TestExtractEmptyExtendedAttributes(t *testing.T) { @@ -121,9 +114,9 @@ func TestExtractEmptyExtendedAttributes(t *testing.T) { msg := []byte(`{"foo": "hello, world!","bar":[{"foo":"o hai"}]}`) var m MyType - attrs, err := ExtractExtendedAttributes(m, msg) + attrs, err := extractExtendedAttributes(m, msg) require.Nil(err) - assert.Equal([]byte("{}"), attrs) + assert.Equal([]byte("{}"), attrs.data) } func TestExtractExtendedAttributes(t *testing.T) { @@ -133,9 +126,9 @@ func TestExtractExtendedAttributes(t *testing.T) { msg := []byte(`{"foo": "hello, world!","bar":[{"foo":"o hai"}], "extendedattr": "such extended"}`) var m MyType - attrs, err := ExtractExtendedAttributes(m, msg) + attrs, err := extractExtendedAttributes(m, msg) require.Nil(err) - assert.Equal([]byte(`{"extendedattr":"such extended"}`), attrs) + assert.Equal([]byte(`{"extendedattr":"such extended"}`), attrs.data) } func TestMarshal(t *testing.T) { @@ -144,10 +137,10 @@ func TestMarshal(t *testing.T) { extendedBytes := []byte(`{"a":1,"b":2.0,"c":true,"d":"false","e":[1,2,3],"f":{"foo":"bar"}}`) expBytes := []byte(`{"bar":null,"foo":"hello world!","a":1,"b":2.0,"c":true,"d":"false","e":[1,2,3],"f":{"foo":"bar"}}`) - m := MyType{ - Foo: "hello world!", - Bar: nil, - extended: extendedBytes, + m := &MyType{ + Foo: "hello world!", + Bar: nil, + attrs: Attributes{data: extendedBytes}, } b, err := Marshal(m) @@ -156,10 +149,10 @@ func TestMarshal(t *testing.T) { } func TestGetField(t *testing.T) { - m := MyType{ - Foo: "hello", - Bar: []MyType{{Foo: "there"}}, - extended: []byte(`{"a":"a","b":1,"c":2.0,"d":true,"e":null,"foo":{"hello":5},"bar":[true,10.5]}`), + m := &MyType{ + Foo: "hello", + Bar: []MyType{{Foo: "there"}}, + attrs: Attributes{data: []byte(`{"a":"a","b":1,"c":2.0,"d":true,"e":null,"foo":{"hello":5},"bar":[true,10.5]}`)}, } tests := []struct { @@ -217,8 +210,8 @@ func TestGetField(t *testing.T) { } func TestQueryGovaluateSimple(t *testing.T) { - m := MyType{ - extended: []byte(`{"hello":5}`), + m := &MyType{ + attrs: Attributes{data: []byte(`{"hello":5}`)}, } expr, err := govaluate.NewEvaluableExpression("hello == 5") @@ -239,8 +232,8 @@ func TestQueryGovaluateSimple(t *testing.T) { } func BenchmarkQueryGovaluateSimple(b *testing.B) { - m := MyType{ - extended: []byte(`{"hello":5}`), + m := &MyType{ + attrs: Attributes{data: []byte(`{"hello":5}`)}, } expr, err := govaluate.NewEvaluableExpression("hello == 5") @@ -254,8 +247,8 @@ func BenchmarkQueryGovaluateSimple(b *testing.B) { } func TestQueryGovaluateComplex(t *testing.T) { - m := MyType{ - extended: []byte(`{"hello":{"foo":5,"bar":6.0}}`), + m := &MyType{ + attrs: Attributes{data: []byte(`{"hello":{"foo":5,"bar":6.0}}`)}, } expr, err := govaluate.NewEvaluableExpression("hello.Foo == 5") @@ -284,8 +277,8 @@ func TestQueryGovaluateComplex(t *testing.T) { } func BenchmarkQueryGovaluateComplex(b *testing.B) { - m := MyType{ - extended: []byte(`{"hello":{"foo":5,"bar":6.0}}`), + m := &MyType{ + attrs: Attributes{data: []byte(`{"hello":{"foo":5,"bar":6.0}}`)}, } expr, err := govaluate.NewEvaluableExpression("hello.Foo == 5") @@ -303,8 +296,8 @@ func TestMarshalUnmarshal(t *testing.T) { var m MyType err := json.Unmarshal(data, &m) require.Nil(t, err) - assert.Equal(t, MyType{Foo: "hello", extended: []byte(`{"a":10,"b":"c"}`)}, m) - b, err := json.Marshal(m) + assert.Equal(t, MyType{Foo: "hello", attrs: Attributes{data: []byte(`{"a":10,"b":"c"}`)}}, m) + b, err := json.Marshal(&m) require.Nil(t, err) assert.Equal(t, data, b) } @@ -323,6 +316,6 @@ func BenchmarkMarshal(b *testing.B) { json.Unmarshal(data, &m) b.ResetTimer() for i := 0; i < b.N; i++ { - json.Marshal(m) + json.Marshal(&m) } }