Skip to content

Commit

Permalink
Change CR shared address to handle empty address string in all CR met…
Browse files Browse the repository at this point in the history
…hods
  • Loading branch information
ilija42 committed Feb 3, 2025
1 parent 8aaa689 commit f5191db
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 38 deletions.
140 changes: 112 additions & 28 deletions integration-tests/relayinterface/chain_components_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/gagliardetto/solana-go/rpc"
"github.com/gagliardetto/solana-go/rpc/ws"
"github.com/gagliardetto/solana-go/text"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
Expand All @@ -38,6 +39,12 @@ import (
solanautils "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils"
)

const (
AnyContractNameWithSharedAddress1 = AnyContractName + "Shared1"
AnyContractNameWithSharedAddress2 = AnyContractName + "Shared2"
AnyContractNameWithSharedAddress3 = AnyContractName + "Shared3"
)

func TestChainComponents(t *testing.T) {
t.Parallel()
helper := &helper{}
Expand Down Expand Up @@ -96,7 +103,70 @@ func DisableTests(it *SolanaChainComponentsInterfaceTester[*testing.T]) {
}

func RunChainComponentsSolanaTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) {
RunContractReaderSolanaTests(t, it)
testCases := Testcase[T]{
Name: "Test address groups where first namespace shares address with second namespace",
Test: func(t T) {
ctx := tests.Context(t)
cfg := it.contractReaderConfig
cfg.AddressShareGroups = [][]string{{AnyContractNameWithSharedAddress1, AnyContractNameWithSharedAddress2, AnyContractNameWithSharedAddress3}}
cr := it.GetContractReaderWithCustomCfg(t, cfg)

t.Run("Namespace is part of an address share group that doesn't have a registered address and provides no address during Bind", func(t T) {
bound1 := []types.BoundContract{{
Name: AnyContractNameWithSharedAddress1,
}}
require.Error(t, cr.Bind(ctx, bound1))
})

addressToBeShared := it.Helper.CreateAccount(t, AnyValueToReadWithoutAnArgument).String()
t.Run("Namespace is part of an address share group that doesn't have a registered address and provides an address during Bind", func(t T) {
bound1 := []types.BoundContract{{Name: AnyContractNameWithSharedAddress1, Address: addressToBeShared}}

require.NoError(t, cr.Bind(ctx, bound1))

var prim uint64
require.NoError(t, cr.GetLatestValue(ctx, bound1[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
})

t.Run("Namespace is part of an address share group that has a registered address and provides that same address during Bind", func(t T) {
bound2 := []types.BoundContract{{
Name: AnyContractNameWithSharedAddress2,
Address: addressToBeShared}}
require.NoError(t, cr.Bind(ctx, bound2))

var prim uint64
require.NoError(t, cr.GetLatestValue(ctx, bound2[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
assert.Equal(t, addressToBeShared, bound2[0].Address)
})

t.Run("Namespace is part of an address share group that has a registered address and provides no address during Bind", func(t T) {
bound3 := []types.BoundContract{{Name: AnyContractNameWithSharedAddress3}}
require.NoError(t, cr.Bind(ctx, bound3))

var prim uint64
require.NoError(t, cr.GetLatestValue(ctx, bound3[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
assert.Equal(t, addressToBeShared, bound3[0].Address)

// when run in a loop Bind address won't be set, so check if CR Method works without set address.
prim = 0
require.NoError(t, cr.GetLatestValue(ctx, types.BoundContract{
Address: "",
Name: AnyContractNameWithSharedAddress3,
}.ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim))
assert.Equal(t, AnyValueToReadWithoutAnArgument, prim)
})

t.Run("Namespace is not part of an address share group that has a registered address and provides no address during Bind", func(t T) {
require.Error(t, cr.Bind(ctx, []types.BoundContract{{Name: AnyContractName}}))
})
},
}

RunTests(t, it, []Testcase[T]{testCases})
RunContractReaderTests(t, it)
// Add ChainWriter tests here
}

Expand All @@ -105,20 +175,12 @@ func RunChainComponentsInLoopSolanaTests[T TestingT[T]](t T, it ChainComponentsI
// Add ChainWriter tests here
}

func RunContractReaderSolanaTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) {
func RunContractReaderTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) {
RunContractReaderInterfaceTests(t, it, false, true)

var testCases []Testcase[T]

RunTests(t, it, testCases)
}

func RunContractReaderInLoopTests[T TestingT[T]](t T, it ChainComponentsInterfaceTester[T]) {
RunContractReaderInterfaceTests(t, it, false, true)

var testCases []Testcase[T]

RunTests(t, it, testCases)
}

type SolanaChainComponentsInterfaceTesterHelper[T TestingT[T]] interface {
Expand All @@ -140,26 +202,28 @@ type SolanaChainComponentsInterfaceTester[T TestingT[T]] struct {
func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) {
t.Cleanup(func() {})

it.contractReaderConfig = config.ContractReader{
Namespaces: map[string]config.ChainContractReader{
AnyContractName: {
IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))),
Reads: map[string]config.ReadDefinition{
MethodReturningUint64: {
ChainSpecificName: "DataAccount",
ReadType: config.Account,
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Value"},
},
},
MethodReturningUint64Slice: {
ChainSpecificName: "DataAccount",
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Slice"},
},
},
anyContractReadDef := config.ChainContractReader{
IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))),
Reads: map[string]config.ReadDefinition{
MethodReturningUint64: {
ChainSpecificName: "DataAccount",
ReadType: config.Account,
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Value"},
},
},
MethodReturningUint64Slice: {
ChainSpecificName: "DataAccount",
OutputModifications: commoncodec.ModifiersConfig{
&commoncodec.PropertyExtractorConfig{FieldName: "U64Slice"},
},
},
},
}

it.contractReaderConfig = config.ContractReader{
Namespaces: map[string]config.ChainContractReader{
AnyContractName: anyContractReadDef,
AnySecondContractName: {
IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))),
Reads: map[string]config.ReadDefinition{
Expand All @@ -171,6 +235,10 @@ func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) {
},
},
},
// these are for testing shared address groups
AnyContractNameWithSharedAddress1: anyContractReadDef,
AnyContractNameWithSharedAddress2: anyContractReadDef,
AnyContractNameWithSharedAddress3: anyContractReadDef,
},
}
}
Expand Down Expand Up @@ -203,6 +271,22 @@ func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReader(t T) types.
return svc
}

func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReaderWithCustomCfg(t T, cfg config.ContractReader) types.ContractReader {
ctx := it.Helper.Context(t)
if it.cr != nil {
return it.cr
}

svc, err := chainreader.NewChainReaderService(it.Helper.Logger(t), it.Helper.RPCClient(), cfg)

require.NoError(t, err)
require.NoError(t, svc.Start(ctx))

it.cr = svc

return svc
}

func (it *SolanaChainComponentsInterfaceTester[T]) GetContractWriter(t T) types.ContractWriter {
return nil
}
Expand Down
30 changes: 21 additions & 9 deletions pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func (b *bindingsRegistry) CreateType(namespace, readName string, forEncoding bo
}

func (b *bindingsRegistry) Bind(boundContract *types.BoundContract) error {
if boundContract == nil {
return fmt.Errorf("%w: bound contract is nil", types.ErrInvalidType)
}

if err := b.handleAddressSharing(boundContract); err != nil {
return err
}
Expand All @@ -77,7 +81,7 @@ func (b *bindingsRegistry) Bind(boundContract *types.BoundContract) error {

key, err := solana.PublicKeyFromBase58(boundContract.Address)
if err != nil {
return err
return fmt.Errorf("%w: failed to parse address: %q for contract %q", types.ErrInvalidConfig, boundContract.Address, boundContract.Name)
}

for _, rBinding := range rBindings {
Expand All @@ -96,30 +100,38 @@ func (b *bindingsRegistry) SetCodec(codec types.RemoteCodec) {
}

func (b *bindingsRegistry) handleAddressSharing(boundContract *types.BoundContract) error {
shareGroup, sharesAddress := b.addressShareGroups[boundContract.Name]
if !sharesAddress {
shareGroup, isInAGroup := b.getShareGroup(*boundContract)
if !isInAGroup {
return nil
}

shareGroup.mux.Lock()
defer shareGroup.mux.Unlock()

// set shared address to the binding address
shareGroupAddress := shareGroup.address
if shareGroupAddress.IsZero() {
if shareGroup.address.IsZero() {
key, err := solana.PublicKeyFromBase58(boundContract.Address)
if err != nil {
return err
}
shareGroup.address = key
} else if boundContract.Address != shareGroupAddress.String() && boundContract.Address != "" {
return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroupAddress, shareGroup.group, boundContract.Address)
b.addressShareGroups[boundContract.Name].address, shareGroup.address = key, key
} else if boundContract.Address != shareGroup.address.String() && boundContract.Address != "" {
return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.address, shareGroup.group, boundContract.Address)
}

boundContract.Address = shareGroupAddress.String()
boundContract.Address = shareGroup.address.String()
return nil
}

func (b *bindingsRegistry) getShareGroup(boundContract types.BoundContract) (*addressShareGroup, bool) {
shareGroup, sharesAddress := b.addressShareGroups[boundContract.Name]
if !sharesAddress {
return nil, false
}

return shareGroup, sharesAddress
}

func (b *bindingsRegistry) initAddressSharing(addressShareGroups [][]string) error {
b.addressShareGroups = make(map[string]*addressShareGroup)
for _, group := range addressShareGroups {
Expand Down
5 changes: 4 additions & 1 deletion pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,16 @@ func (s *SolanaChainReaderService) QueryKey(_ context.Context, _ types.BoundCont
// Bind implements the types.ContractReader interface and allows new contract namespaceBindings to be added
// to the service.
func (s *SolanaChainReaderService) Bind(_ context.Context, bindings []types.BoundContract) error {
fmt.Println("binidngs are ", bindings)
for i := range bindings {
if err := s.bdRegistry.Bind(&bindings[i]); err != nil {
return err
}

s.lookup.bindAddressForContract(bindings[i].Name, bindings[i].Address)
// also bind with an empty address so that we can look up the contract without providing address when calling CR methods
if _, isInAShareGroup := s.bdRegistry.getShareGroup(bindings[i]); isInAShareGroup {
s.lookup.bindAddressForContract(bindings[i].Name, "")
}
}

return nil
Expand Down

0 comments on commit f5191db

Please sign in to comment.