Skip to content

Commit

Permalink
Add tests and fix bugs for payer API
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Oct 22, 2024
1 parent ef6061a commit 0f7a6ef
Show file tree
Hide file tree
Showing 14 changed files with 203 additions and 34 deletions.
8 changes: 4 additions & 4 deletions pkg/api/message/publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

func TestPublishEnvelope(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

payerEnvelope := envelopeTestUtils.CreatePayerEnvelope(t)
Expand Down Expand Up @@ -64,7 +64,7 @@ func TestPublishEnvelope(t *testing.T) {
}

func TestUnmarshalErrorOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestAPIClient(t)
api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

envelope := envelopeTestUtils.CreatePayerEnvelope(t)
Expand All @@ -79,7 +79,7 @@ func TestUnmarshalErrorOnPublish(t *testing.T) {
}

func TestMismatchingOriginatorOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestAPIClient(t)
api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

clientEnv := envelopeTestUtils.CreateClientEnvelope()
Expand All @@ -96,7 +96,7 @@ func TestMismatchingOriginatorOnPublish(t *testing.T) {
}

func TestMissingTopicOnPublish(t *testing.T) {
api, _, cleanup := apiTestUtils.NewTestAPIClient(t)
api, _, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()

clientEnv := envelopeTestUtils.CreateClientEnvelope()
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/message/subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
},
}

return testUtilsApi.NewTestAPIClient(t)
return testUtilsApi.NewTestReplicationAPIClient(t)
}

func insertInitialRows(t *testing.T, store *sql.DB) {
Expand Down
127 changes: 127 additions & 0 deletions pkg/api/payer/publish_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package payer_test

import (
"context"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/api/payer"
"github.com/xmtp/xmtpd/pkg/envelopes"
blockchainMocks "github.com/xmtp/xmtpd/pkg/mocks/blockchain"
registryMocks "github.com/xmtp/xmtpd/pkg/mocks/registry"
"github.com/xmtp/xmtpd/pkg/proto/identity/associations"
envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/payer_api"
"github.com/xmtp/xmtpd/pkg/registry"
"github.com/xmtp/xmtpd/pkg/testutils"
apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopesTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
"github.com/xmtp/xmtpd/pkg/utils"
)

func buildPayerService(
t *testing.T,
) (*payer.Service, *blockchainMocks.MockIBlockchainPublisher, *registryMocks.MockNodeRegistry, func()) {
ctx, cancel := context.WithCancel(context.Background())
log := testutils.NewLog(t)
privKey, err := crypto.GenerateKey()
require.NoError(t, err)
mockRegistry := registryMocks.NewMockNodeRegistry(t)

require.NoError(t, err)
mockMessagePublisher := blockchainMocks.NewMockIBlockchainPublisher(t)

payerService, err := payer.NewPayerApiService(
ctx,
log,
mockRegistry,
privKey,
mockMessagePublisher,
)
require.NoError(t, err)

return payerService, mockMessagePublisher, mockRegistry, func() {
cancel()
}
}

func TestPublishIdentityUpdate(t *testing.T) {
ctx := context.Background()
svc, mockMessagePublisher, _, cleanup := buildPayerService(t)
defer cleanup()

inboxId := testutils.RandomInboxId()
inboxIdBytes, err := utils.ParseInboxId(inboxId)
require.NoError(t, err)

txnHash := common.Hash{1, 2, 3}

mockMessagePublisher.EXPECT().
PublishIdentityUpdate(mock.Anything, mock.Anything, mock.Anything).
Return(txnHash, nil)

identityUpdate := &associations.IdentityUpdate{
InboxId: inboxId,
}

envelope := envelopesTestUtils.CreateIdentityUpdateClientEnvelope(inboxIdBytes, identityUpdate)

publishResponse, err := svc.PublishClientEnvelopes(
ctx,
&payer_api.PublishClientEnvelopesRequest{
Envelopes: []*envelopesProto.ClientEnvelope{envelope},
},
)
require.NoError(t, err)
require.NotNil(t, publishResponse)
require.Len(t, publishResponse.OriginatorEnvelopes, 1)

responseEnvelope := publishResponse.OriginatorEnvelopes[0]
parsedOriginatorEnvelope, err := envelopes.NewOriginatorEnvelope(responseEnvelope)
require.NoError(t, err)

proof := parsedOriginatorEnvelope.Proto().Proof.(*envelopesProto.OriginatorEnvelope_BlockchainProof)

require.Equal(t, proof.BlockchainProof.TransactionHash, txnHash[:])
}

func TestPublishToNodes(t *testing.T) {
originatorServer, _, originatorCleanup := apiTestUtils.NewTestAPIServer(t)
defer originatorCleanup()

ctx := context.Background()
svc, _, mockRegistry, cleanup := buildPayerService(t)
defer cleanup()

mockRegistry.EXPECT().GetNode(mock.Anything).Return(&registry.Node{
HttpAddress: formatAddress(originatorServer.Addr().String()),
}, nil)

groupId := testutils.RandomGroupID()
testGroupMessage := envelopesTestUtils.CreateGroupMessageClientEnvelope(
groupId,
[]byte("test message"),
100, // This is the expected originator ID of the test server
)

publishResponse, err := svc.PublishClientEnvelopes(
ctx,
&payer_api.PublishClientEnvelopesRequest{
Envelopes: []*envelopesProto.ClientEnvelope{testGroupMessage},
},
)
require.NoError(t, err)
require.NotNil(t, publishResponse)
require.Len(t, publishResponse.OriginatorEnvelopes, 1)

responseEnvelope := publishResponse.OriginatorEnvelopes[0]
parsedOriginatorEnvelope, err := envelopes.NewOriginatorEnvelope(responseEnvelope)
require.NoError(t, err)

targetTopic := parsedOriginatorEnvelope.UnsignedOriginatorEnvelope.PayerEnvelope.ClientEnvelope.TargetTopic()

require.Equal(t, targetTopic.Identifier(), groupId[:])
}
32 changes: 27 additions & 5 deletions pkg/api/payer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

type Service struct {
Expand Down Expand Up @@ -59,8 +60,8 @@ func (s *Service) PublishClientEnvelopes(

// For each originator found in the request, publish all matching envelopes to the node
for originatorId, payloadsWithIndex := range grouped.forNodes {
s.log.Info("publishing to originator", zap.Uint32("originator_id", originatorId))
originatorEnvelopes, err := s.publishToNodes(ctx, originatorId, payloadsWithIndex)

if err != nil {
s.log.Error("error publishing payer envelopes", zap.Error(err))
return nil, status.Error(codes.Internal, "error publishing payer envelopes")
Expand All @@ -73,14 +74,17 @@ func (s *Service) PublishClientEnvelopes(
}

for _, payload := range grouped.forBlockchain {
s.log.Info("publishing to blockchain", zap.Int("index", payload.originalIndex))
var originatorEnvelope *envelopesProto.OriginatorEnvelope
if originatorEnvelope, err = s.publishToBlockchain(ctx, payload.payload); err != nil {
return nil, status.Errorf(codes.Internal, "error publishing group message: %v", err)
}
out[payload.originalIndex] = originatorEnvelope
}

return nil, status.Errorf(codes.Unimplemented, "method PublishClientEnvelopes not implemented")
return &payer_api.PublishClientEnvelopesResponse{
OriginatorEnvelopes: out,
}, nil
}

// A struct that groups client envelopes by their intended destination
Expand All @@ -94,7 +98,7 @@ type groupedEnvelopes struct {
func (s *Service) groupEnvelopes(
rawEnvelopes []*envelopesProto.ClientEnvelope,
) (*groupedEnvelopes, error) {
out := groupedEnvelopes{}
out := groupedEnvelopes{forNodes: make(map[uint32][]clientEnvelopeWithIndex)}

for i, rawClientEnvelope := range rawEnvelopes {
clientEnvelope, err := envelopes.NewClientEnvelope(rawClientEnvelope)
Expand Down Expand Up @@ -206,8 +210,26 @@ func (s *Service) publishToBlockchain(
)
}

unsignedOriginatorEnvelope := &envelopesProto.UnsignedOriginatorEnvelope{
OriginatorNodeId: clientEnvelope.Aad().TargetOriginator,
OriginatorSequenceId: 0, // TODO: get this data from a full node
OriginatorNs: 0, // TODO: get this data from a full node
PayerEnvelope: &envelopesProto.PayerEnvelope{
UnsignedClientEnvelope: payload,
},
}

unsignedBytes, err := proto.Marshal(unsignedOriginatorEnvelope)
if err != nil {
return nil, status.Errorf(
codes.Internal,
"error marshalling unsigned originator envelope: %v",
err,
)
}

return &envelopesProto.OriginatorEnvelope{
UnsignedOriginatorEnvelope: payload,
UnsignedOriginatorEnvelope: unsignedBytes,
Proof: &envelopesProto.OriginatorEnvelope_BlockchainProof{
BlockchainProof: &envelopesProto.BlockchainProof{
TransactionHash: hash.Bytes(),
Expand Down Expand Up @@ -256,7 +278,7 @@ func shouldSendToBlockchain(targetTopic topic.Topic, aad *envelopesProto.Authent
case topic.TOPIC_KIND_IDENTITY_UPDATES_V1:
return true
case topic.TOPIC_KIND_GROUP_MESSAGES_V1:
return aad.TargetOriginator < constants.MAX_BLOCKCHAIN_ORIGINATOR_ID
return aad.TargetOriginator < uint32(constants.MAX_BLOCKCHAIN_ORIGINATOR_ID)
default:
return false
}
Expand Down
20 changes: 10 additions & 10 deletions pkg/api/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar
}

func TestQueryAllEnvelopes(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -84,7 +84,7 @@ func TestQueryAllEnvelopes(t *testing.T) {
}

func TestQueryPagedEnvelopes(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -100,7 +100,7 @@ func TestQueryPagedEnvelopes(t *testing.T) {
}

func TestQueryEnvelopesByOriginator(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -119,7 +119,7 @@ func TestQueryEnvelopesByOriginator(t *testing.T) {
}

func TestQueryEnvelopesByTopic(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -138,7 +138,7 @@ func TestQueryEnvelopesByTopic(t *testing.T) {
}

func TestQueryEnvelopesFromLastSeen(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
api, db, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, db)

Expand All @@ -156,7 +156,7 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) {
}

func TestQueryTopicFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -177,7 +177,7 @@ func TestQueryTopicFromLastSeen(t *testing.T) {
}

func TestQueryMultipleTopicsFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -198,7 +198,7 @@ func TestQueryMultipleTopicsFromLastSeen(t *testing.T) {
}

func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -219,7 +219,7 @@ func TestQueryMultipleOriginatorsFromLastSeen(t *testing.T) {
}

func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
db_rows := setupQueryTest(t, store)

Expand All @@ -237,7 +237,7 @@ func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
}

func TestInvalidQuery(t *testing.T) {
api, store, cleanup := apiTestUtils.NewTestAPIClient(t)
api, store, cleanup := apiTestUtils.NewTestReplicationAPIClient(t)
defer cleanup()
_ = setupQueryTest(t, store)

Expand Down
2 changes: 1 addition & 1 deletion pkg/envelopes/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ClientEnvelope struct {

func NewClientEnvelope(proto *envelopesProto.ClientEnvelope) (*ClientEnvelope, error) {
if proto == nil {
return nil, errors.New("proto is nil")
return nil, errors.New("client envelope proto is nil")
}

if proto.Aad == nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/envelopes/originator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type OriginatorEnvelope struct {

func NewOriginatorEnvelope(proto *envelopesProto.OriginatorEnvelope) (*OriginatorEnvelope, error) {
if proto == nil {
return nil, errors.New("proto is nil")
return nil, errors.New("originator envelope proto is nil")
}

unsigned, err := NewUnsignedOriginatorEnvelopeFromBytes(proto.UnsignedOriginatorEnvelope)
Expand Down
2 changes: 1 addition & 1 deletion pkg/envelopes/payer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type PayerEnvelope struct {

func NewPayerEnvelope(proto *envelopesProto.PayerEnvelope) (*PayerEnvelope, error) {
if proto == nil {
return nil, errors.New("proto is nil")
return nil, errors.New("payer envelope proto is nil")
}

clientEnv, err := NewClientEnvelopeFromBytes(proto.UnsignedClientEnvelope)
Expand Down
2 changes: 1 addition & 1 deletion pkg/envelopes/unsignedOriginator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func NewUnsignedOriginatorEnvelope(
proto *envelopesProto.UnsignedOriginatorEnvelope,
) (*UnsignedOriginatorEnvelope, error) {
if proto == nil {
return nil, errors.New("proto is nil")
return nil, errors.New("unsigned originator envelopeproto is nil")
}

payer, err := NewPayerEnvelope(proto.PayerEnvelope)
Expand Down
2 changes: 1 addition & 1 deletion pkg/indexer/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestStoreMessages(t *testing.T) {
groupID := testutils.RandomGroupID()
topic := topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, groupID[:]).Bytes()

clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message)
clientEnvelope := envelopesTestUtils.CreateGroupMessageClientEnvelope(groupID, message, 0)
clientEnvelopeBytes, err := proto.Marshal(clientEnvelope)
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 0f7a6ef

Please sign in to comment.