diff --git a/pkg/api/message/publish_test.go b/pkg/api/message/publish_test.go index c5027240..a73bbc42 100644 --- a/pkg/api/message/publish_test.go +++ b/pkg/api/message/publish_test.go @@ -16,7 +16,7 @@ import ( ) func TestPublishEnvelope(t *testing.T) { - api, db, cleanup := apiTestUtils.NewTestAPIClient(t) + api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() payerEnvelope := envelopeTestUtils.CreatePayerEnvelope(t) @@ -64,7 +64,7 @@ func TestPublishEnvelope(t *testing.T) { } func TestUnmarshalErrorOnPublish(t *testing.T) { - api, _, cleanup := apiTestUtils.NewTestAPIClient(t) + api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() envelope := envelopeTestUtils.CreatePayerEnvelope(t) @@ -79,7 +79,7 @@ func TestUnmarshalErrorOnPublish(t *testing.T) { } func TestMismatchingOriginatorOnPublish(t *testing.T) { - api, _, cleanup := apiTestUtils.NewTestAPIClient(t) + api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() clientEnv := envelopeTestUtils.CreateClientEnvelope() @@ -96,7 +96,7 @@ func TestMismatchingOriginatorOnPublish(t *testing.T) { } func TestMissingTopicOnPublish(t *testing.T) { - api, _, cleanup := apiTestUtils.NewTestAPIClient(t) + api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() clientEnv := envelopeTestUtils.CreateClientEnvelope() diff --git a/pkg/api/message/subscribe_test.go b/pkg/api/message/subscribe_test.go index 938f419b..e79c4cc9 100644 --- a/pkg/api/message/subscribe_test.go +++ b/pkg/api/message/subscribe_test.go @@ -70,7 +70,7 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) }, } - return testUtilsApi.NewTestAPIClient(t) + return testUtilsApi.NewTestReplicationAPIClient(t) } func insertInitialRows(t *testing.T, store *sql.DB) { diff --git a/pkg/api/payer/publish_test.go b/pkg/api/payer/publish_test.go new file mode 100644 index 00000000..66bb1f60 --- /dev/null +++ b/pkg/api/payer/publish_test.go @@ -0,0 +1,127 @@ +package payer_test + +import ( + "context" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/xmtp/xmtpd/pkg/api/payer" + "github.com/xmtp/xmtpd/pkg/envelopes" + blockchainMocks "github.com/xmtp/xmtpd/pkg/mocks/blockchain" + registryMocks "github.com/xmtp/xmtpd/pkg/mocks/registry" + "github.com/xmtp/xmtpd/pkg/proto/identity/associations" + envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" + "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api" + "github.com/xmtp/xmtpd/pkg/registry" + "github.com/xmtp/xmtpd/pkg/testutils" + apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api" + envelopesTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes" + "github.com/xmtp/xmtpd/pkg/utils" +) + +func buildPayerService( + t *testing.T, +) (*payer.Service, *blockchainMocks.MockIBlockchainPublisher, *registryMocks.MockNodeRegistry, func()) { + ctx, cancel := context.WithCancel(context.Background()) + log := testutils.NewLog(t) + privKey, err := crypto.GenerateKey() + require.NoError(t, err) + mockRegistry := registryMocks.NewMockNodeRegistry(t) + + require.NoError(t, err) + mockMessagePublisher := blockchainMocks.NewMockIBlockchainPublisher(t) + + payerService, err := payer.NewPayerApiService( + ctx, + log, + mockRegistry, + privKey, + mockMessagePublisher, + ) + require.NoError(t, err) + + return payerService, mockMessagePublisher, mockRegistry, func() { + cancel() + } +} + +func TestPublishIdentityUpdate(t *testing.T) { + ctx := context.Background() + svc, mockMessagePublisher, _, cleanup := buildPayerService(t) + defer cleanup() + + inboxId := testutils.RandomInboxId() + inboxIdBytes, err := utils.ParseInboxId(inboxId) + require.NoError(t, err) + + txnHash := common.Hash{1, 2, 3} + + mockMessagePublisher.EXPECT(). + PublishIdentityUpdate(mock.Anything, mock.Anything, mock.Anything). + Return(txnHash, nil) + + identityUpdate := &associations.IdentityUpdate{ + InboxId: inboxId, + } + + envelope := envelopesTestUtils.CreateIdentityUpdateClientEnvelope(inboxIdBytes, identityUpdate) + + publishResponse, err := svc.PublishClientEnvelopes( + ctx, + &payer_api.PublishClientEnvelopesRequest{ + Envelopes: []*envelopesProto.ClientEnvelope{envelope}, + }, + ) + require.NoError(t, err) + require.NotNil(t, publishResponse) + require.Len(t, publishResponse.OriginatorEnvelopes, 1) + + responseEnvelope := publishResponse.OriginatorEnvelopes[0] + parsedOriginatorEnvelope, err := envelopes.NewOriginatorEnvelope(responseEnvelope) + require.NoError(t, err) + + proof := parsedOriginatorEnvelope.Proto().Proof.(*envelopesProto.OriginatorEnvelope_BlockchainProof) + + require.Equal(t, proof.BlockchainProof.TransactionHash, txnHash[:]) +} + +func TestPublishToNodes(t *testing.T) { + originatorServer, _, originatorCleanup := apiTestUtils.NewTestAPIServer(t) + defer originatorCleanup() + + ctx := context.Background() + svc, _, mockRegistry, cleanup := buildPayerService(t) + defer cleanup() + + mockRegistry.EXPECT().GetNode(mock.Anything).Return(®istry.Node{ + HttpAddress: formatAddress(originatorServer.Addr().String()), + }, nil) + + groupId := testutils.RandomGroupID() + testGroupMessage := envelopesTestUtils.CreateGroupMessageClientEnvelope( + groupId, + []byte("test message"), + 100, // This is the expected originator ID of the test server + ) + + publishResponse, err := svc.PublishClientEnvelopes( + ctx, + &payer_api.PublishClientEnvelopesRequest{ + Envelopes: []*envelopesProto.ClientEnvelope{testGroupMessage}, + }, + ) + require.NoError(t, err) + require.NotNil(t, publishResponse) + require.Len(t, publishResponse.OriginatorEnvelopes, 1) + + responseEnvelope := publishResponse.OriginatorEnvelopes[0] + parsedOriginatorEnvelope, err := envelopes.NewOriginatorEnvelope(responseEnvelope) + require.NoError(t, err) + + targetTopic := parsedOriginatorEnvelope.UnsignedOriginatorEnvelope.PayerEnvelope.ClientEnvelope.TargetTopic() + + require.Equal(t, targetTopic.Identifier(), groupId[:]) +} diff --git a/pkg/api/payer/service.go b/pkg/api/payer/service.go index a25794d5..4c17c2fa 100644 --- a/pkg/api/payer/service.go +++ b/pkg/api/payer/service.go @@ -18,6 +18,7 @@ import ( "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" ) type Service struct { @@ -59,8 +60,8 @@ func (s *Service) PublishClientEnvelopes( // For each originator found in the request, publish all matching envelopes to the node for originatorId, payloadsWithIndex := range grouped.forNodes { + s.log.Info("publishing to originator", zap.Uint32("originator_id", originatorId)) originatorEnvelopes, err := s.publishToNodes(ctx, originatorId, payloadsWithIndex) - if err != nil { s.log.Error("error publishing payer envelopes", zap.Error(err)) return nil, status.Error(codes.Internal, "error publishing payer envelopes") @@ -73,6 +74,7 @@ func (s *Service) PublishClientEnvelopes( } for _, payload := range grouped.forBlockchain { + s.log.Info("publishing to blockchain", zap.Int("index", payload.originalIndex)) var originatorEnvelope *envelopesProto.OriginatorEnvelope if originatorEnvelope, err = s.publishToBlockchain(ctx, payload.payload); err != nil { return nil, status.Errorf(codes.Internal, "error publishing group message: %v", err) @@ -80,7 +82,9 @@ func (s *Service) PublishClientEnvelopes( out[payload.originalIndex] = originatorEnvelope } - return nil, status.Errorf(codes.Unimplemented, "method PublishClientEnvelopes not implemented") + return &payer_api.PublishClientEnvelopesResponse{ + OriginatorEnvelopes: out, + }, nil } // A struct that groups client envelopes by their intended destination @@ -94,7 +98,7 @@ type groupedEnvelopes struct { func (s *Service) groupEnvelopes( rawEnvelopes []*envelopesProto.ClientEnvelope, ) (*groupedEnvelopes, error) { - out := groupedEnvelopes{} + out := groupedEnvelopes{forNodes: make(map[uint32][]clientEnvelopeWithIndex)} for i, rawClientEnvelope := range rawEnvelopes { clientEnvelope, err := envelopes.NewClientEnvelope(rawClientEnvelope) @@ -206,8 +210,26 @@ func (s *Service) publishToBlockchain( ) } + unsignedOriginatorEnvelope := &envelopesProto.UnsignedOriginatorEnvelope{ + OriginatorNodeId: clientEnvelope.Aad().TargetOriginator, + OriginatorSequenceId: 0, // TODO: get this data from a full node + OriginatorNs: 0, // TODO: get this data from a full node + PayerEnvelope: &envelopesProto.PayerEnvelope{ + UnsignedClientEnvelope: payload, + }, + } + + unsignedBytes, err := proto.Marshal(unsignedOriginatorEnvelope) + if err != nil { + return nil, status.Errorf( + codes.Internal, + "error marshalling unsigned originator envelope: %v", + err, + ) + } + return &envelopesProto.OriginatorEnvelope{ - UnsignedOriginatorEnvelope: payload, + UnsignedOriginatorEnvelope: unsignedBytes, Proof: &envelopesProto.OriginatorEnvelope_BlockchainProof{ BlockchainProof: &envelopesProto.BlockchainProof{ TransactionHash: hash.Bytes(), @@ -256,7 +278,7 @@ func shouldSendToBlockchain(targetTopic topic.Topic, aad *envelopesProto.Authent case topic.TOPIC_KIND_IDENTITY_UPDATES_V1: return true case topic.TOPIC_KIND_GROUP_MESSAGES_V1: - return aad.TargetOriginator < constants.MAX_BLOCKCHAIN_ORIGINATOR_ID + return aad.TargetOriginator < uint32(constants.MAX_BLOCKCHAIN_ORIGINATOR_ID) default: return false } diff --git a/pkg/api/query_test.go b/pkg/api/query_test.go index 54eb6957..c1499772 100644 --- a/pkg/api/query_test.go +++ b/pkg/api/query_test.go @@ -68,7 +68,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar } func TestQueryAllEnvelopes(t *testing.T) { - api, db, cleanup := apiTestUtils.NewTestAPIClient(t) + api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -84,7 +84,7 @@ func TestQueryAllEnvelopes(t *testing.T) { } func TestQueryPagedEnvelopes(t *testing.T) { - api, db, cleanup := apiTestUtils.NewTestAPIClient(t) + api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -100,7 +100,7 @@ func TestQueryPagedEnvelopes(t *testing.T) { } func TestQueryEnvelopesByOriginator(t *testing.T) { - api, db, cleanup := apiTestUtils.NewTestAPIClient(t) + api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -119,7 +119,7 @@ func TestQueryEnvelopesByOriginator(t *testing.T) { } func TestQueryEnvelopesByTopic(t *testing.T) { - api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, store) @@ -138,7 +138,7 @@ func TestQueryEnvelopesByTopic(t *testing.T) { } func TestQueryEnvelopesFromLastSeen(t *testing.T) { - api, db, cleanup := apiTestUtils.NewTestAPIClient(t) + api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, db) @@ -156,7 +156,7 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) { } func TestQueryTopicFromLastSeen(t *testing.T) { - api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, store) @@ -177,7 +177,7 @@ func TestQueryTopicFromLastSeen(t *testing.T) { } func TestQueryMultipleTopicsFromLastSeen(t *testing.T) { - api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, store) @@ -198,7 +198,7 @@ func TestQueryMultipleTopicsFromLastSeen(t *testing.T) { } func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) { - api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, store) @@ -219,7 +219,7 @@ func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) { } func TestQueryEnvelopesWithEmptyResult(t *testing.T) { - api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() db_rows := setupQueryTest(t, store) @@ -237,7 +237,7 @@ func TestQueryEnvelopesWithEmptyResult(t *testing.T) { } func TestInvalidQuery(t *testing.T) { - api, store, cleanup := apiTestUtils.NewTestAPIClient(t) + api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t) defer cleanup() _ = setupQueryTest(t, store) diff --git a/pkg/envelopes/client.go b/pkg/envelopes/client.go index 80178598..7ff94bb9 100644 --- a/pkg/envelopes/client.go +++ b/pkg/envelopes/client.go @@ -16,7 +16,7 @@ type ClientEnvelope struct { func NewClientEnvelope(proto *envelopesProto.ClientEnvelope) (*ClientEnvelope, error) { if proto == nil { - return nil, errors.New("proto is nil") + return nil, errors.New("client envelope proto is nil") } if proto.Aad == nil { diff --git a/pkg/envelopes/originator.go b/pkg/envelopes/originator.go index 9a79cd08..76286bcd 100644 --- a/pkg/envelopes/originator.go +++ b/pkg/envelopes/originator.go @@ -15,7 +15,7 @@ type OriginatorEnvelope struct { func NewOriginatorEnvelope(proto *envelopesProto.OriginatorEnvelope) (*OriginatorEnvelope, error) { if proto == nil { - return nil, errors.New("proto is nil") + return nil, errors.New("originator envelope proto is nil") } unsigned, err := NewUnsignedOriginatorEnvelopeFromBytes(proto.UnsignedOriginatorEnvelope) diff --git a/pkg/envelopes/payer.go b/pkg/envelopes/payer.go index c7ea76f3..22ec7183 100644 --- a/pkg/envelopes/payer.go +++ b/pkg/envelopes/payer.go @@ -17,7 +17,7 @@ type PayerEnvelope struct { func NewPayerEnvelope(proto *envelopesProto.PayerEnvelope) (*PayerEnvelope, error) { if proto == nil { - return nil, errors.New("proto is nil") + return nil, errors.New("payer envelope proto is nil") } clientEnv, err := NewClientEnvelopeFromBytes(proto.UnsignedClientEnvelope) diff --git a/pkg/envelopes/unsignedOriginator.go b/pkg/envelopes/unsignedOriginator.go index 9d00a199..a98517e7 100644 --- a/pkg/envelopes/unsignedOriginator.go +++ b/pkg/envelopes/unsignedOriginator.go @@ -18,7 +18,7 @@ func NewUnsignedOriginatorEnvelope( proto *envelopesProto.UnsignedOriginatorEnvelope, ) (*UnsignedOriginatorEnvelope, error) { if proto == nil { - return nil, errors.New("proto is nil") + return nil, errors.New("unsigned originator envelopeproto is nil") } payer, err := NewPayerEnvelope(proto.PayerEnvelope) diff --git a/pkg/indexer/e2e_test.go b/pkg/indexer/e2e_test.go index 239d68af..2c76eed8 100644 --- a/pkg/indexer/e2e_test.go +++ b/pkg/indexer/e2e_test.go @@ -62,7 +62,7 @@ func TestStoreMessages(t *testing.T) { groupID := testutils.RandomGroupID() topic := topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, groupID[:]).Bytes() - clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message) + clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message, 0) clientEnvelopeBytes, err := proto.Marshal(clientEnvelope) require.NoError(t, err) diff --git a/pkg/indexer/storer/groupMessage_test.go b/pkg/indexer/storer/groupMessage_test.go index 5540b1f5..3499486c 100644 --- a/pkg/indexer/storer/groupMessage_test.go +++ b/pkg/indexer/storer/groupMessage_test.go @@ -49,7 +49,7 @@ func TestStoreGroupMessages(t *testing.T) { message := testutils.RandomBytes(30) sequenceID := uint64(1) - clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message) + clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message, 0) logMessage := testutils.BuildMessageSentLog(t, groupID, clientEnvelope, sequenceID) var err error @@ -92,7 +92,7 @@ func TestStoreGroupMessageDuplicate(t *testing.T) { message := testutils.RandomBytes(30) sequenceID := uint64(1) - clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message) + clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message, 0) logMessage := testutils.BuildMessageSentLog(t, groupID, clientEnvelope, sequenceID) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 4d00af06..6539e945 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -90,9 +90,9 @@ func TestCreateServer(t *testing.T) { server2 := NewTestServer(t, server2Port, dbs[1], registry, privateKey2) require.NotEqual(t, server1.Addr(), server2.Addr()) - client1, cleanup1 := apiTestUtils.NewAPIClient(t, ctx, server1.Addr().String()) + client1, cleanup1 := apiTestUtils.NewReplicationAPIClient(t, ctx, server1.Addr().String()) defer cleanup1() - client2, cleanup2 := apiTestUtils.NewAPIClient(t, ctx, server2.Addr().String()) + client2, cleanup2 := apiTestUtils.NewReplicationAPIClient(t, ctx, server2.Addr().String()) defer cleanup2() targetTopic := topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte{1, 2, 3}). diff --git a/pkg/testutils/api/api.go b/pkg/testutils/api/api.go index 824f0adb..11862f84 100644 --- a/pkg/testutils/api/api.go +++ b/pkg/testutils/api/api.go @@ -24,7 +24,7 @@ import ( "google.golang.org/grpc/credentials/insecure" ) -func NewAPIClient( +func NewReplicationAPIClient( t *testing.T, ctx context.Context, addr string, @@ -44,6 +44,25 @@ func NewAPIClient( } } +func NewPayerAPIClient( + t *testing.T, + ctx context.Context, + addr string, +) (payer_api.PayerApiClient, func()) { + dialAddr := fmt.Sprintf("passthrough://localhost/%s", addr) + conn, err := grpc.NewClient( + dialAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(), + ) + require.NoError(t, err) + client := payer_api.NewPayerApiClient(conn) + return client, func() { + err := conn.Close() + require.NoError(t, err) + } +} + func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { ctx, cancel := context.WithCancel(context.Background()) log := testutils.NewLog(t) @@ -53,7 +72,7 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { privKeyStr := "0x" + utils.HexEncode(crypto.FromECDSA(privKey)) mockRegistry := mocks.NewMockNodeRegistry(t) mockRegistry.EXPECT().GetNodes().Return([]registry.Node{ - {NodeID: 1, SigningKey: &privKey.PublicKey}, + {NodeID: 100, SigningKey: &privKey.PublicKey}, }, nil) registrant, err := registrant.NewRegistrant(ctx, log, queries.New(db), mockRegistry, privKeyStr) require.NoError(t, err) @@ -98,9 +117,9 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) { } } -func NewTestAPIClient(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) { +func NewTestReplicationAPIClient(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) { svc, db, svcCleanup := NewTestAPIServer(t) - client, clientCleanup := NewAPIClient(t, context.Background(), svc.Addr().String()) + client, clientCleanup := NewReplicationAPIClient(t, context.Background(), svc.Addr().String()) return client, db, func() { clientCleanup() svcCleanup() diff --git a/pkg/testutils/envelopes/envelopes.go b/pkg/testutils/envelopes/envelopes.go index 648d8a5e..883fa0af 100644 --- a/pkg/testutils/envelopes/envelopes.go +++ b/pkg/testutils/envelopes/envelopes.go @@ -42,12 +42,13 @@ func CreateClientEnvelope(aad ...*envelopes.AuthenticatedData) *envelopes.Client func CreateGroupMessageClientEnvelope( groupID [32]byte, message []byte, + targetOriginator uint32, ) *envelopes.ClientEnvelope { return &envelopes.ClientEnvelope{ Aad: &envelopes.AuthenticatedData{ TargetTopic: topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, groupID[:]). Bytes(), - TargetOriginator: 0, + TargetOriginator: targetOriginator, }, Payload: &envelopes.ClientEnvelope_GroupMessage{ GroupMessage: &mlsv1.GroupMessageInput{