Skip to content

Commit

Permalink
Address PR review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph committed Oct 31, 2023
1 parent 5b96789 commit 24a9267
Showing 1 changed file with 92 additions and 52 deletions.
144 changes: 92 additions & 52 deletions codec/reflectcodec/type_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,23 @@ import (
"golang.org/x/exp/slices"

"github.com/ava-labs/avalanchego/codec"
"github.com/ava-labs/avalanchego/utils/set"
"github.com/ava-labs/avalanchego/utils/wrappers"
)

// DefaultTagName that enables serialization.
const DefaultTagName = "serialize"
const (
// DefaultTagName that enables serialization.
DefaultTagName = "serialize"
initialSliceLen = 16
)

var (
_ codec.Codec = (*genericCodec)(nil)

errMarshalNil = errors.New("can't marshal nil pointer or interface")
errUnmarshalNil = errors.New("can't unmarshal nil")
errNeedPointer = errors.New("argument to unmarshal must be a pointer")
errRecursiveInterfaceTypes = errors.New("recursive interface types")
errMarshalNil = errors.New("can't marshal nil pointer or interface")
errUnmarshalNil = errors.New("can't unmarshal nil")
errNeedPointer = errors.New("argument to unmarshal must be a pointer")
)

type TypeCodec interface {
Expand Down Expand Up @@ -85,12 +90,15 @@ func (c *genericCodec) Size(value interface{}) (int, error) {
return 0, errMarshalNil // can't marshal nil
}

size, _, err := c.size(reflect.ValueOf(value))
size, _, err := c.size(reflect.ValueOf(value), nil)
return size, err
}

// size returns the size of the value along with whether the value is constant sized.
func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
func (c *genericCodec) size(
value reflect.Value,
typeStack set.Set[reflect.Type],
) (int, bool, error) {
switch valueKind := value.Kind(); valueKind {
case reflect.Uint8:
return wrappers.ByteLen, true, nil
Expand All @@ -117,7 +125,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
// Can't marshal nil pointers (but nil slices are fine)
return 0, false, errMarshalNil
}
return c.size(value.Elem())
return c.size(value.Elem(), typeStack)

case reflect.Interface:
if value.IsNil() {
Expand All @@ -126,11 +134,18 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
}
underlyingValue := value.Interface()
underlyingType := reflect.TypeOf(underlyingValue)
if typeStack.Contains(underlyingType) {
return 0, false, fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, underlyingType)
}
typeStack.Add(underlyingType)

prefixSize := c.typer.PrefixSize(underlyingType)
valueSize, _, err := c.size(value.Elem())
valueSize, _, err := c.size(value.Elem(), typeStack)
if err != nil {
return 0, false, err
}

typeStack.Remove(underlyingType)
return prefixSize + valueSize, false, nil

case reflect.Slice:
Expand All @@ -139,7 +154,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
return wrappers.IntLen, false, nil
}

size, constSize, err := c.size(value.Index(0))
size, constSize, err := c.size(value.Index(0), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -151,7 +166,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
}

for i := 1; i < numElts; i++ {
innerSize, _, err := c.size(value.Index(i))
innerSize, _, err := c.size(value.Index(i), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -165,7 +180,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
return 0, true, nil
}

size, constSize, err := c.size(value.Index(0))
size, constSize, err := c.size(value.Index(0), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -177,7 +192,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
}

for i := 1; i < numElts; i++ {
innerSize, _, err := c.size(value.Index(i))
innerSize, _, err := c.size(value.Index(i), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -196,7 +211,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
constSize = true
)
for _, fieldDesc := range serializedFields {
innerSize, innerConstSize, err := c.size(value.Field(fieldDesc.Index))
innerSize, innerConstSize, err := c.size(value.Field(fieldDesc.Index), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -211,11 +226,11 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
return wrappers.IntLen, false, nil
}

keySize, keyConstSize, err := c.size(iter.Key())
keySize, keyConstSize, err := c.size(iter.Key(), typeStack)
if err != nil {
return 0, false, err
}
valueSize, valueConstSize, err := c.size(iter.Value())
valueSize, valueConstSize, err := c.size(iter.Value(), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -230,7 +245,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
totalValueSize = valueSize
)
for iter.Next() {
valueSize, _, err := c.size(iter.Value())
valueSize, _, err := c.size(iter.Value(), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -244,7 +259,7 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
totalKeySize = keySize
)
for iter.Next() {
keySize, _, err := c.size(iter.Key())
keySize, _, err := c.size(iter.Key(), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -255,11 +270,11 @@ func (c *genericCodec) size(value reflect.Value) (int, bool, error) {
default:
totalSize := wrappers.IntLen + keySize + valueSize
for iter.Next() {
keySize, _, err := c.size(iter.Key())
keySize, _, err := c.size(iter.Key(), typeStack)
if err != nil {
return 0, false, err
}
valueSize, _, err := c.size(iter.Value())
valueSize, _, err := c.size(iter.Value(), typeStack)
if err != nil {
return 0, false, err
}
Expand All @@ -279,13 +294,18 @@ func (c *genericCodec) MarshalInto(value interface{}, p *wrappers.Packer) error
return errMarshalNil // can't marshal nil
}

return c.marshal(reflect.ValueOf(value), p, c.maxSliceLen)
return c.marshal(reflect.ValueOf(value), p, c.maxSliceLen, nil)
}

// marshal writes the byte representation of [value] to [p]
// [value]'s underlying value must not be a nil pointer or interface
// c.lock should be held for the duration of this function
func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSliceLen uint32) error {
func (c *genericCodec) marshal(
value reflect.Value,
p *wrappers.Packer,
maxSliceLen uint32,
typeStack set.Set[reflect.Type],
) error {
switch valueKind := value.Kind(); valueKind {
case reflect.Uint8:
p.PackByte(uint8(value.Uint()))
Expand Down Expand Up @@ -321,19 +341,24 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice
if value.IsNil() { // Can't marshal nil (except nil slices)
return errMarshalNil
}
return c.marshal(value.Elem(), p, c.maxSliceLen)
return c.marshal(value.Elem(), p, c.maxSliceLen, typeStack)
case reflect.Interface:
if value.IsNil() { // Can't marshal nil (except nil slices)
return errMarshalNil
}
underlyingValue := value.Interface()
underlyingType := reflect.TypeOf(underlyingValue)
if typeStack.Contains(underlyingType) {
return fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, underlyingType)
}
typeStack.Add(underlyingType)
if err := c.typer.PackPrefix(p, underlyingType); err != nil {
return err
}
if err := c.marshal(value.Elem(), p, c.maxSliceLen); err != nil {
if err := c.marshal(value.Elem(), p, c.maxSliceLen, typeStack); err != nil {
return err
}
typeStack.Remove(underlyingType)
return p.Err
case reflect.Slice:
numElts := value.Len() // # elements in the slice/array. 0 if this slice is nil.
Expand Down Expand Up @@ -361,7 +386,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice
return p.Err
}
for i := 0; i < numElts; i++ { // Process each element in the slice
if err := c.marshal(value.Index(i), p, c.maxSliceLen); err != nil {
if err := c.marshal(value.Index(i), p, c.maxSliceLen, typeStack); err != nil {
return err
}
}
Expand All @@ -381,7 +406,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice
)
}
for i := 0; i < numElts; i++ { // Process each element in the array
if err := c.marshal(value.Index(i), p, c.maxSliceLen); err != nil {
if err := c.marshal(value.Index(i), p, c.maxSliceLen, typeStack); err != nil {
return err
}
}
Expand All @@ -392,7 +417,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice
return err
}
for _, fieldDesc := range serializedFields { // Go through all fields of this struct that are serialized
if err := c.marshal(value.Field(fieldDesc.Index), p, fieldDesc.MaxSliceLen); err != nil { // Serialize the field and write to byte array
if err := c.marshal(value.Field(fieldDesc.Index), p, fieldDesc.MaxSliceLen, typeStack); err != nil { // Serialize the field and write to byte array
return err
}
}
Expand Down Expand Up @@ -423,7 +448,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice
startOffset := p.Offset
endOffset := p.Offset
for i, key := range keys {
if err := c.marshal(key, p, c.maxSliceLen); err != nil {
if err := c.marshal(key, p, c.maxSliceLen, typeStack); err != nil {
return err
}
if p.Err != nil {
Expand Down Expand Up @@ -456,7 +481,7 @@ func (c *genericCodec) marshal(value reflect.Value, p *wrappers.Packer, maxSlice
}

// serialize and pack value
if err := c.marshal(value.MapIndex(key.key), p, c.maxSliceLen); err != nil {
if err := c.marshal(value.MapIndex(key.key), p, c.maxSliceLen, typeStack); err != nil {
return err
}
}
Expand All @@ -481,7 +506,7 @@ func (c *genericCodec) Unmarshal(bytes []byte, dest interface{}) error {
if destPtr.Kind() != reflect.Ptr {
return errNeedPointer
}
if err := c.unmarshal(&p, destPtr.Elem(), c.maxSliceLen); err != nil {
if err := c.unmarshal(&p, destPtr.Elem(), c.maxSliceLen, nil); err != nil {
return err
}
if p.Offset != len(bytes) {
Expand All @@ -496,7 +521,12 @@ func (c *genericCodec) Unmarshal(bytes []byte, dest interface{}) error {

// Unmarshal from p.Bytes into [value]. [value] must be addressable.
// c.lock should be held for the duration of this function
func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSliceLen uint32) error {
func (c *genericCodec) unmarshal(
p *wrappers.Packer,
value reflect.Value,
maxSliceLen uint32,
typeStack set.Set[reflect.Type],
) error {
switch value.Kind() {
case reflect.Uint8:
value.SetUint(uint64(p.UnpackByte()))
Expand Down Expand Up @@ -573,18 +603,22 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli
}
numElts := int(numElts32)

sliceType := value.Type()
innerType := sliceType.Elem()

// If this is a slice of bytes, manually unpack the bytes rather
// than calling unmarshal on each byte. This improves performance.
if elemKind := value.Type().Elem().Kind(); elemKind == reflect.Uint8 {
if elemKind := innerType.Kind(); elemKind == reflect.Uint8 {
value.SetBytes(p.UnpackFixedBytes(numElts))
return p.Err
}
// set [value] to be a slice of the appropriate type/capacity (right now it is nil)
value.Set(reflect.MakeSlice(value.Type(), numElts, numElts))
// Unmarshal each element into the appropriate index of the slice
// Unmarshal each element and append it into the slice.
value.Set(reflect.MakeSlice(value.Type(), 0, initialSliceLen))
zeroValue := reflect.Zero(innerType)
for i := 0; i < numElts; i++ {
if err := c.unmarshal(p, value.Index(i), c.maxSliceLen); err != nil {
return fmt.Errorf("couldn't unmarshal slice element: %w", err)
value.Set(reflect.Append(value, zeroValue))
if err := c.unmarshal(p, value.Index(i), c.maxSliceLen, typeStack); err != nil {
return err
}
}
return nil
Expand All @@ -601,8 +635,8 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli
return nil
}
for i := 0; i < numElts; i++ {
if err := c.unmarshal(p, value.Index(i), c.maxSliceLen); err != nil {
return fmt.Errorf("couldn't unmarshal array element: %w", err)
if err := c.unmarshal(p, value.Index(i), c.maxSliceLen, typeStack); err != nil {
return err
}
}
return nil
Expand All @@ -617,11 +651,17 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli
if err != nil {
return err
}
// Unmarshal into the struct
if err := c.unmarshal(p, intfImplementor, c.maxSliceLen); err != nil {
return fmt.Errorf("couldn't unmarshal interface: %w", err)
intfImplementorType := intfImplementor.Type()
if typeStack.Contains(intfImplementorType) {
return fmt.Errorf("%w: %s", errRecursiveInterfaceTypes, intfImplementorType)
}
// And assign the filled struct to the value
typeStack.Add(intfImplementorType)

if err := c.unmarshal(p, intfImplementor, c.maxSliceLen, typeStack); err != nil {
return err
}

typeStack.Remove(intfImplementorType)
value.Set(intfImplementor)
return nil
case reflect.Struct:
Expand All @@ -632,8 +672,8 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli
}
// Go through the fields and umarshal into them
for _, fieldDesc := range serializedFieldIndices {
if err := c.unmarshal(p, value.Field(fieldDesc.Index), fieldDesc.MaxSliceLen); err != nil {
return fmt.Errorf("couldn't unmarshal struct: %w", err)
if err := c.unmarshal(p, value.Field(fieldDesc.Index), fieldDesc.MaxSliceLen, typeStack); err != nil {
return err
}
}
return nil
Expand All @@ -643,8 +683,8 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli
// Create a new pointer to a new value of the underlying type
v := reflect.New(t)
// Fill the value
if err := c.unmarshal(p, v.Elem(), c.maxSliceLen); err != nil {
return fmt.Errorf("couldn't unmarshal pointer: %w", err)
if err := c.unmarshal(p, v.Elem(), c.maxSliceLen, typeStack); err != nil {
return err
}
// Assign to the top-level struct's member
value.Set(v)
Expand All @@ -671,15 +711,15 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli
)

// Set [value] to be a new map of the appropriate type.
value.Set(reflect.MakeMapWithSize(mapType, numElts))
value.Set(reflect.MakeMap(mapType))

for i := 0; i < numElts; i++ {
mapKey := reflect.New(mapKeyType).Elem()

keyStartOffset := p.Offset

if err := c.unmarshal(p, mapKey, c.maxSliceLen); err != nil {
return fmt.Errorf("couldn't unmarshal map key (%s): %w", mapKeyType, err)
if err := c.unmarshal(p, mapKey, c.maxSliceLen, typeStack); err != nil {
return err
}

// Get the key's byte representation and check that the new key is
Expand All @@ -696,8 +736,8 @@ func (c *genericCodec) unmarshal(p *wrappers.Packer, value reflect.Value, maxSli

// Get the value
mapValue := reflect.New(mapValueType).Elem()
if err := c.unmarshal(p, mapValue, c.maxSliceLen); err != nil {
return fmt.Errorf("couldn't unmarshal map value for key %s: %w", mapKey, err)
if err := c.unmarshal(p, mapValue, c.maxSliceLen, typeStack); err != nil {
return err
}

// Assign the key-value pair in the map
Expand Down

0 comments on commit 24a9267

Please sign in to comment.