Skip to content

Commit

Permalink
rename modifier functional for nested fields
Browse files Browse the repository at this point in the history
  • Loading branch information
EasterTheBunny committed Jan 21, 2025
1 parent 62443f4 commit c4128ed
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 57 deletions.
22 changes: 15 additions & 7 deletions pkg/codec/by_item_type_modifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ type byItemTypeModifier struct {
modByitemType map[string]Modifier
}

// RetypeToOffChain attempts to apply a modifier using the provided itemType. To allow access to nested fields, this
// function applies no modifications if a modifier by the specified name is not found.
func (b *byItemTypeModifier) RetypeToOffChain(onChainType reflect.Type, itemType string) (reflect.Type, error) {
mod, ok := b.modByitemType[itemType]
head, tail := extendedItemType(itemType).next()

mod, ok := b.modByitemType[head]
if !ok {
return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType)
}

return mod.RetypeToOffChain(onChainType, itemType)
return mod.RetypeToOffChain(onChainType, tail)
}

func (b *byItemTypeModifier) TransformToOnChain(offChainValue any, itemType string) (any, error) {
Expand All @@ -40,13 +44,17 @@ func (b *byItemTypeModifier) TransformToOffChain(onChainValue any, itemType stri
}

func (b *byItemTypeModifier) transform(
val any, itemType string, transform func(Modifier, any, string) (any, error)) (any, error) {
mod, ok := b.modByitemType[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find modifier for %s", types.ErrInvalidType, itemType)
val any,
itemType string,
transform func(Modifier, any, string) (any, error),
) (any, error) {
head, tail := extendedItemType(itemType).next()

if mod, ok := b.modByitemType[head]; ok {
return transform(mod, val, tail)
}

return transform(mod, val, itemType)
return val, nil
}

var _ Modifier = &byItemTypeModifier{}
58 changes: 54 additions & 4 deletions pkg/codec/encodings/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package encodings
import (
"fmt"
"reflect"
"strings"

"github.com/smartcontractkit/chainlink-common/pkg/types"
)
Expand All @@ -24,6 +25,8 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) {

sfs := make([]reflect.StructField, len(fields))
codecFields := make([]TypeCodec, len(fields))
lookup := make(map[string]int)

for i, field := range fields {
ft := field.Codec.GetType()
if ft.Kind() != reflect.Pointer {
Expand All @@ -35,18 +38,22 @@ func NewStructCodec(fields []NamedTypeCodec) (c TopLevelCodec, err error) {
Name: field.Name,
Type: ft,
}

codecFields[i] = field.Codec
lookup[field.Name] = i
}

return &structCodec{
fields: codecFields,
tpe: reflect.PointerTo(reflect.StructOf(sfs)),
fields: codecFields,
fieldLookup: lookup,
tpe: reflect.PointerTo(reflect.StructOf(sfs)),
}, nil
}

type structCodec struct {
fields []TypeCodec
tpe reflect.Type
fields []TypeCodec
fieldLookup map[string]int
tpe reflect.Type
}

func (s *structCodec) Encode(value any, into []byte) ([]byte, error) {
Expand Down Expand Up @@ -113,3 +120,46 @@ func (s *structCodec) SizeAtTopLevel(numItems int) (int, error) {
}
return size, nil
}

func (s *structCodec) FieldCodec(itemType string) (TypeCodec, error) {
path := extendedItemType(itemType)

// itemType could recurse into nested structs
fieldName, tail := path.next()
if fieldName == "" {
return nil, fmt.Errorf("%w: field name required", types.ErrInvalidType)
}

idx, ok := s.fieldLookup[fieldName]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

codec := s.fields[idx]

if tail == "" {
return codec, nil
}

structType, ok := codec.(StructTypeCodec)
if !ok {
return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType)
}

return structType.FieldCodec(tail)
}

type extendedItemType string

func (t extendedItemType) next() (string, string) {
if string(t) == "" {
return "", ""
}

path := strings.Split(string(t), ".")
if len(path) == 1 {
return path[0], ""
}

return path[0], strings.Join(path[1:], ".")
}
62 changes: 48 additions & 14 deletions pkg/codec/encodings/type_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type TopLevelCodec interface {
SizeAtTopLevel(numItems int) (int, error)
}

type StructTypeCodec interface {
TypeCodec
FieldCodec(string) (TypeCodec, error)
}

// CodecFromTypeCodec maps TypeCodec to types.RemoteCodec, using the key as the itemType
// If the TypeCodec is a TopLevelCodec, GetMaxEncodingSize and GetMaxDecodingSize will call SizeAtTopLevel instead of Size.
type CodecFromTypeCodec map[string]TypeCodec
Expand All @@ -45,9 +50,9 @@ type LenientCodecFromTypeCodec map[string]TypeCodec
var _ types.RemoteCodec = &LenientCodecFromTypeCodec{}

func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) {
ntcwt, ok := c[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return nil, err
}

tpe := ntcwt.GetType()
Expand All @@ -59,9 +64,9 @@ func (c CodecFromTypeCodec) CreateType(itemType string, _ bool) (any, error) {
}

func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string) ([]byte, error) {
ntcwt, ok := c[itemType]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return nil, err
}

if item != nil {
Expand All @@ -86,14 +91,15 @@ func (c CodecFromTypeCodec) Encode(_ context.Context, item any, itemType string)
}

func (c CodecFromTypeCodec) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) {
ntcwt, ok := c[itemType]
if !ok {
return 0, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return 0, err
}

if lp, ok := ntcwt.(TopLevelCodec); ok {
return lp.SizeAtTopLevel(n)
}

return ntcwt.Size(n)
}

Expand Down Expand Up @@ -121,11 +127,16 @@ func (c LenientCodecFromTypeCodec) Decode(ctx context.Context, raw []byte, into
return decode(c, raw, into, itemType, false)
}

func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) {
return c.GetMaxEncodingSize(ctx, n, itemType)
}

func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exactSize bool) error {
ntcwt, ok := c[itemType]
if !ok {
return fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
ntcwt, err := getCodec(c, itemType)
if err != nil {
return err
}

val, remaining, err := ntcwt.Decode(raw)
if err != nil {
return err
Expand All @@ -138,6 +149,29 @@ func decode(c map[string]TypeCodec, raw []byte, into any, itemType string, exact
return codec.Convert(reflect.ValueOf(val), reflect.ValueOf(into), nil)
}

func (c CodecFromTypeCodec) GetMaxDecodingSize(ctx context.Context, n int, itemType string) (int, error) {
return c.GetMaxEncodingSize(ctx, n, itemType)
func getCodec(c map[string]TypeCodec, itemType string) (TypeCodec, error) {
// itemType could recurse into nested structs
path := extendedItemType(itemType)

// itemType could recurse into nested structs
head, tail := path.next()
if head == "" {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

ntcwt, ok := c[head]
if !ok {
return nil, fmt.Errorf("%w: cannot find type %s", types.ErrInvalidType, itemType)
}

if tail == "" {
return ntcwt, nil
}

structType, ok := ntcwt.(StructTypeCodec)
if !ok {
return nil, fmt.Errorf("%w: extended path not traversable for type %s", types.ErrInvalidType, itemType)
}

return structType.FieldCodec(tail)
}
31 changes: 30 additions & 1 deletion pkg/codec/encodings/type_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
rawbin "encoding/binary"
"math"
"reflect"
"strings"
"testing"

"github.com/smartcontractkit/libocr/bigbigendian"
Expand Down Expand Up @@ -122,6 +123,34 @@ func TestCodecFromTypeCodecs(t *testing.T) {

assert.Equal(t, singleItemSize*2, actual)
})

t.Run("CreateType works for nested struct values and modifiers", func(t *testing.T) {
itemType := strings.Join([]string{TestItemWithConfigExtra, "AccountStruct", "Account"}, ".")
ts := CreateTestStruct(0, biit)
c := biit.GetCodec(t)

encoded, err := c.Encode(tests.Context(t), ts.AccountStruct.Account, itemType)
require.NoError(t, err)

var actual []byte
require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType))

assert.Equal(t, ts.AccountStruct.Account, actual)
})

t.Run("CreateType works for nested struct values", func(t *testing.T) {
itemType := strings.Join([]string{TestItemType, "NestedDynamicStruct", "Inner", "S"}, ".")
ts := CreateTestStruct(0, biit)
c := biit.GetCodec(t)

encoded, err := c.Encode(tests.Context(t), ts.NestedDynamicStruct.Inner.S, itemType)
require.NoError(t, err)

var actual string
require.NoError(t, c.Decode(tests.Context(t), encoded, &actual, itemType))

assert.Equal(t, ts.NestedDynamicStruct.Inner.S, actual)
})
}

type interfaceTesterBase struct{}
Expand Down Expand Up @@ -319,7 +348,7 @@ func (b *bigEndianInterfaceTester) GetCodec(t *testing.T) types.Codec {
modCodec, err := codec.NewModifierCodec(c, byTypeMod, codec.BigIntHook)
require.NoError(t, err)

_, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), TestItemWithConfigExtra)
_, err = mod.RetypeToOffChain(reflect.PointerTo(testStruct.GetType()), "")
require.NoError(t, err)

return modCodec
Expand Down
4 changes: 3 additions & 1 deletion pkg/codec/hard_coder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package codec

import (
"fmt"
"log"
"reflect"
"strings"

Expand Down Expand Up @@ -81,7 +82,8 @@ func verifyHardCodeKeys(values map[string]any) error {
return nil
}

func (o *onChainHardCoder) TransformToOnChain(offChainValue any, _ string) (any, error) {
func (o *onChainHardCoder) TransformToOnChain(offChainValue any, itemType string) (any, error) {
log.Println(itemType)
return transformWithMaps(offChainValue, o.offToOnChainType, o.onChain, hardCode, o.hooks...)
}

Expand Down
Loading

0 comments on commit c4128ed

Please sign in to comment.