diff --git a/integration-tests/relayinterface/chain_components_test.go b/integration-tests/relayinterface/chain_components_test.go index 330b6e65d..0ec4e60fc 100644 --- a/integration-tests/relayinterface/chain_components_test.go +++ b/integration-tests/relayinterface/chain_components_test.go @@ -18,6 +18,7 @@ import ( "github.com/gagliardetto/solana-go/rpc" "github.com/gagliardetto/solana-go/rpc/ws" "github.com/gagliardetto/solana-go/text" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" @@ -38,6 +39,12 @@ import ( solanautils "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" ) +const ( + AnyContractNameWithSharedAddress1 = AnyContractName + "Shared1" + AnyContractNameWithSharedAddress2 = AnyContractName + "Shared2" + AnyContractNameWithSharedAddress3 = AnyContractName + "Shared3" +) + func TestChainComponents(t *testing.T) { t.Parallel() helper := &helper{} @@ -96,7 +103,70 @@ func DisableTests(it *SolanaChainComponentsInterfaceTester[*testing.T]) { } func RunChainComponentsSolanaTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) { - RunContractReaderSolanaTests(t, it) + testCases := Testcase[T]{ + Name: "Test address groups where first namespace shares address with second namespace", + Test: func(t T) { + ctx := tests.Context(t) + cfg := it.contractReaderConfig + cfg.AddressShareGroups = [][]string{{AnyContractNameWithSharedAddress1, AnyContractNameWithSharedAddress2, AnyContractNameWithSharedAddress3}} + cr := it.GetContractReaderWithCustomCfg(t, cfg) + + t.Run("Namespace is part of an address share group that doesn't have a registered address and provides no address during Bind", func(t T) { + bound1 := []types.BoundContract{{ + Name: AnyContractNameWithSharedAddress1, + }} + require.Error(t, cr.Bind(ctx, bound1)) + }) + + addressToBeShared := it.Helper.CreateAccount(t, AnyValueToReadWithoutAnArgument).String() + t.Run("Namespace is part of an address share group that doesn't have a registered address and provides an address during Bind", func(t T) { + bound1 := []types.BoundContract{{Name: AnyContractNameWithSharedAddress1, Address: addressToBeShared}} + + require.NoError(t, cr.Bind(ctx, bound1)) + + var prim uint64 + require.NoError(t, cr.GetLatestValue(ctx, bound1[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim)) + assert.Equal(t, AnyValueToReadWithoutAnArgument, prim) + }) + + t.Run("Namespace is part of an address share group that has a registered address and provides that same address during Bind", func(t T) { + bound2 := []types.BoundContract{{ + Name: AnyContractNameWithSharedAddress2, + Address: addressToBeShared}} + require.NoError(t, cr.Bind(ctx, bound2)) + + var prim uint64 + require.NoError(t, cr.GetLatestValue(ctx, bound2[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim)) + assert.Equal(t, AnyValueToReadWithoutAnArgument, prim) + assert.Equal(t, addressToBeShared, bound2[0].Address) + }) + + t.Run("Namespace is part of an address share group that has a registered address and provides no address during Bind", func(t T) { + bound3 := []types.BoundContract{{Name: AnyContractNameWithSharedAddress3}} + require.NoError(t, cr.Bind(ctx, bound3)) + + var prim uint64 + require.NoError(t, cr.GetLatestValue(ctx, bound3[0].ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim)) + assert.Equal(t, AnyValueToReadWithoutAnArgument, prim) + assert.Equal(t, addressToBeShared, bound3[0].Address) + + // when run in a loop Bind address won't be set, so check if CR Method works without set address. + prim = 0 + require.NoError(t, cr.GetLatestValue(ctx, types.BoundContract{ + Address: "", + Name: AnyContractNameWithSharedAddress3, + }.ReadIdentifier(MethodReturningUint64), primitives.Unconfirmed, nil, &prim)) + assert.Equal(t, AnyValueToReadWithoutAnArgument, prim) + }) + + t.Run("Namespace is not part of an address share group that has a registered address and provides no address during Bind", func(t T) { + require.Error(t, cr.Bind(ctx, []types.BoundContract{{Name: AnyContractName}})) + }) + }, + } + + RunTests(t, it, []Testcase[T]{testCases}) + RunContractReaderTests(t, it) // Add ChainWriter tests here } @@ -105,20 +175,12 @@ func RunChainComponentsInLoopSolanaTests[T TestingT[T]](t T, it ChainComponentsI // Add ChainWriter tests here } -func RunContractReaderSolanaTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) { +func RunContractReaderTests[T TestingT[T]](t T, it *SolanaChainComponentsInterfaceTester[T]) { RunContractReaderInterfaceTests(t, it, false, true) - - var testCases []Testcase[T] - - RunTests(t, it, testCases) } func RunContractReaderInLoopTests[T TestingT[T]](t T, it ChainComponentsInterfaceTester[T]) { RunContractReaderInterfaceTests(t, it, false, true) - - var testCases []Testcase[T] - - RunTests(t, it, testCases) } type SolanaChainComponentsInterfaceTesterHelper[T TestingT[T]] interface { @@ -140,26 +202,28 @@ type SolanaChainComponentsInterfaceTester[T TestingT[T]] struct { func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) { t.Cleanup(func() {}) - it.contractReaderConfig = config.ContractReader{ - Namespaces: map[string]config.ChainContractReader{ - AnyContractName: { - IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))), - Reads: map[string]config.ReadDefinition{ - MethodReturningUint64: { - ChainSpecificName: "DataAccount", - ReadType: config.Account, - OutputModifications: commoncodec.ModifiersConfig{ - &commoncodec.PropertyExtractorConfig{FieldName: "U64Value"}, - }, - }, - MethodReturningUint64Slice: { - ChainSpecificName: "DataAccount", - OutputModifications: commoncodec.ModifiersConfig{ - &commoncodec.PropertyExtractorConfig{FieldName: "U64Slice"}, - }, - }, + anyContractReadDef := config.ChainContractReader{ + IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))), + Reads: map[string]config.ReadDefinition{ + MethodReturningUint64: { + ChainSpecificName: "DataAccount", + ReadType: config.Account, + OutputModifications: commoncodec.ModifiersConfig{ + &commoncodec.PropertyExtractorConfig{FieldName: "U64Value"}, + }, + }, + MethodReturningUint64Slice: { + ChainSpecificName: "DataAccount", + OutputModifications: commoncodec.ModifiersConfig{ + &commoncodec.PropertyExtractorConfig{FieldName: "U64Slice"}, }, }, + }, + } + + it.contractReaderConfig = config.ContractReader{ + Namespaces: map[string]config.ChainContractReader{ + AnyContractName: anyContractReadDef, AnySecondContractName: { IDL: mustUnmarshalIDL(t, string(it.Helper.GetJSONEncodedIDL(t))), Reads: map[string]config.ReadDefinition{ @@ -171,6 +235,10 @@ func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) { }, }, }, + // these are for testing shared address groups + AnyContractNameWithSharedAddress1: anyContractReadDef, + AnyContractNameWithSharedAddress2: anyContractReadDef, + AnyContractNameWithSharedAddress3: anyContractReadDef, }, } } @@ -203,6 +271,22 @@ func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReader(t T) types. return svc } +func (it *SolanaChainComponentsInterfaceTester[T]) GetContractReaderWithCustomCfg(t T, cfg config.ContractReader) types.ContractReader { + ctx := it.Helper.Context(t) + if it.cr != nil { + return it.cr + } + + svc, err := chainreader.NewChainReaderService(it.Helper.Logger(t), it.Helper.RPCClient(), cfg) + + require.NoError(t, err) + require.NoError(t, svc.Start(ctx)) + + it.cr = svc + + return svc +} + func (it *SolanaChainComponentsInterfaceTester[T]) GetContractWriter(t T) types.ContractWriter { return nil } diff --git a/pkg/solana/chainreader/bindings.go b/pkg/solana/chainreader/bindings.go index 94d1667c8..2d989ddf9 100644 --- a/pkg/solana/chainreader/bindings.go +++ b/pkg/solana/chainreader/bindings.go @@ -66,6 +66,10 @@ func (b *bindingsRegistry) CreateType(namespace, readName string, forEncoding bo } func (b *bindingsRegistry) Bind(boundContract *types.BoundContract) error { + if boundContract == nil { + return fmt.Errorf("%w: bound contract is nil", types.ErrInvalidType) + } + if err := b.handleAddressSharing(boundContract); err != nil { return err } @@ -77,7 +81,7 @@ func (b *bindingsRegistry) Bind(boundContract *types.BoundContract) error { key, err := solana.PublicKeyFromBase58(boundContract.Address) if err != nil { - return err + return fmt.Errorf("%w: failed to parse address: %q for contract %q", types.ErrInvalidConfig, boundContract.Address, boundContract.Name) } for _, rBinding := range rBindings { @@ -96,8 +100,8 @@ func (b *bindingsRegistry) SetCodec(codec types.RemoteCodec) { } func (b *bindingsRegistry) handleAddressSharing(boundContract *types.BoundContract) error { - shareGroup, sharesAddress := b.addressShareGroups[boundContract.Name] - if !sharesAddress { + shareGroup, isInAGroup := b.getShareGroup(*boundContract) + if !isInAGroup { return nil } @@ -105,21 +109,29 @@ func (b *bindingsRegistry) handleAddressSharing(boundContract *types.BoundContra defer shareGroup.mux.Unlock() // set shared address to the binding address - shareGroupAddress := shareGroup.address - if shareGroupAddress.IsZero() { + if shareGroup.address.IsZero() { key, err := solana.PublicKeyFromBase58(boundContract.Address) if err != nil { return err } - shareGroup.address = key - } else if boundContract.Address != shareGroupAddress.String() && boundContract.Address != "" { - return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroupAddress, shareGroup.group, boundContract.Address) + b.addressShareGroups[boundContract.Name].address, shareGroup.address = key, key + } else if boundContract.Address != shareGroup.address.String() && boundContract.Address != "" { + return fmt.Errorf("namespace: %q shares address: %q with namespaceBindings: %v and cannot be bound with a new address: %s", boundContract.Name, shareGroup.address, shareGroup.group, boundContract.Address) } - boundContract.Address = shareGroupAddress.String() + boundContract.Address = shareGroup.address.String() return nil } +func (b *bindingsRegistry) getShareGroup(boundContract types.BoundContract) (*addressShareGroup, bool) { + shareGroup, sharesAddress := b.addressShareGroups[boundContract.Name] + if !sharesAddress { + return nil, false + } + + return shareGroup, sharesAddress +} + func (b *bindingsRegistry) initAddressSharing(addressShareGroups [][]string) error { b.addressShareGroups = make(map[string]*addressShareGroup) for _, group := range addressShareGroups { diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index aea7865d8..7c20570d2 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -208,13 +208,16 @@ func (s *SolanaChainReaderService) QueryKey(_ context.Context, _ types.BoundCont // Bind implements the types.ContractReader interface and allows new contract namespaceBindings to be added // to the service. func (s *SolanaChainReaderService) Bind(_ context.Context, bindings []types.BoundContract) error { - fmt.Println("binidngs are ", bindings) for i := range bindings { if err := s.bdRegistry.Bind(&bindings[i]); err != nil { return err } s.lookup.bindAddressForContract(bindings[i].Name, bindings[i].Address) + // also bind with an empty address so that we can look up the contract without providing address when calling CR methods + if _, isInAShareGroup := s.bdRegistry.getShareGroup(bindings[i]); isInAShareGroup { + s.lookup.bindAddressForContract(bindings[i].Name, "") + } } return nil