diff --git a/server/forward.go b/server/forward.go
index c009d0dd32c..d073173f6da 100644
--- a/server/forward.go
+++ b/server/forward.go
@@ -415,7 +415,7 @@ func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) {
 	}
 	var (
 		forwardedHost string
-		forwardStream tsopb.TSO_TsoClient
+		forwardStream *streamWrapper
 		ts            *tsopb.TsoResponse
 		err           error
 		ok            bool
@@ -447,15 +447,21 @@ func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) {
 		if err != nil {
 			return pdpb.Timestamp{}, err
 		}
+		start := time.Now()
+		forwardStream.Lock()
 		err = forwardStream.Send(request)
 		if err != nil {
 			if needRetry := handleStreamError(err); needRetry {
+				forwardStream.Unlock()
 				continue
 			}
 			log.Error("send request to tso primary server failed", zap.Error(err), zap.String("tso-addr", forwardedHost))
+			forwardStream.Unlock()
 			return pdpb.Timestamp{}, err
 		}
 		ts, err = forwardStream.Recv()
+		forwardStream.Unlock()
+		forwardTsoDuration.Observe(time.Since(start).Seconds())
 		if err != nil {
 			if needRetry := handleStreamError(err); needRetry {
 				continue
@@ -469,7 +475,7 @@ func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) {
 	return pdpb.Timestamp{}, err
 }
 
-func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoClient, error) {
+func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (*streamWrapper, error) {
 	s.tsoClientPool.RLock()
 	forwardStream, ok := s.tsoClientPool.clients[forwardedHost]
 	s.tsoClientPool.RUnlock()
@@ -495,11 +501,14 @@ func (s *GrpcServer) getTSOForwardStream(forwardedHost string) (tsopb.TSO_TsoCli
 	done := make(chan struct{})
 	ctx, cancel := context.WithCancel(s.ctx)
 	go grpcutil.CheckStream(ctx, cancel, done)
-	forwardStream, err = tsopb.NewTSOClient(client).Tso(ctx)
+	tsoClient, err := tsopb.NewTSOClient(client).Tso(ctx)
 	done <- struct{}{}
 	if err != nil {
 		return nil, err
 	}
+	forwardStream = &streamWrapper{
+		TSO_TsoClient: tsoClient,
+	}
 	s.tsoClientPool.clients[forwardedHost] = forwardStream
 	return forwardStream, nil
 }
diff --git a/server/metrics.go b/server/metrics.go
index 0935008a420..73aa5a5e9c7 100644
--- a/server/metrics.go
+++ b/server/metrics.go
@@ -160,6 +160,14 @@ var (
 			Name:      "forward_fail_total",
 			Help:      "Counter of forward fail.",
 		}, []string{"request", "type"})
+	forwardTsoDuration = prometheus.NewHistogram(
+		prometheus.HistogramOpts{
+			Namespace: "pd",
+			Subsystem: "server",
+			Name:      "forward_tso_duration_seconds",
+			Help:      "Bucketed histogram of processing time (s) of handled forward tso requests.",
+			Buckets:   prometheus.ExponentialBuckets(0.0005, 2, 13),
+		})
 )
 
 func init() {
@@ -180,4 +188,5 @@ func init() {
 	prometheus.MustRegister(bucketReportInterval)
 	prometheus.MustRegister(apiConcurrencyGauge)
 	prometheus.MustRegister(forwardFailCounter)
+	prometheus.MustRegister(forwardTsoDuration)
 }
diff --git a/server/server.go b/server/server.go
index 7d3f12239e5..bfefca7540f 100644
--- a/server/server.go
+++ b/server/server.go
@@ -126,6 +126,11 @@ var (
 	etcdCommittedIndexGauge = etcdStateGauge.WithLabelValues("committedIndex")
 )
 
+type streamWrapper struct {
+	tsopb.TSO_TsoClient
+	syncutil.Mutex
+}
+
 // Server is the pd server. It implements bs.Server
 // nolint
 type Server struct {
@@ -206,7 +211,7 @@ type Server struct {
 
 	tsoClientPool struct {
 		syncutil.RWMutex
-		clients map[string]tsopb.TSO_TsoClient
+		clients map[string]*streamWrapper
 	}
 
 	// tsoDispatcher is used to dispatch different TSO requests to
@@ -267,9 +272,9 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le
 		mode: mode,
 		tsoClientPool: struct {
 			syncutil.RWMutex
-			clients map[string]tsopb.TSO_TsoClient
+			clients map[string]*streamWrapper
 		}{
-			clients: make(map[string]tsopb.TSO_TsoClient),
+			clients: make(map[string]*streamWrapper),
 		},
 	}
 	s.handler = newHandler(s)
diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go
index 108740e46f9..4d478a848ed 100644
--- a/tests/integrations/mcs/tso/server_test.go
+++ b/tests/integrations/mcs/tso/server_test.go
@@ -22,6 +22,7 @@ import (
 	"net/http"
 	"strconv"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -518,6 +519,142 @@ func (suite *APIServerForwardTestSuite) checkAvailableTSO(re *require.Assertions
 	re.NoError(err)
 }
 
+func TestForwardTsoConcurrently(t *testing.T) {
+	re := require.New(t)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	cluster, err := tests.NewTestAPICluster(ctx, 3)
+	re.NoError(err)
+	defer cluster.Destroy()
+
+	err = cluster.RunInitialServers()
+	re.NoError(err)
+
+	leaderName := cluster.WaitLeader()
+	pdLeader := cluster.GetServer(leaderName)
+	backendEndpoints := pdLeader.GetAddr()
+	re.NoError(pdLeader.BootstrapCluster())
+	leader := cluster.GetServer(cluster.WaitLeader())
+	rc := leader.GetServer().GetRaftCluster()
+	for i := 0; i < 3; i++ {
+		region := &metapb.Region{
+			Id:       uint64(i*4 + 1),
+			Peers:    []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}},
+			StartKey: []byte{byte(i)},
+			EndKey:   []byte{byte(i + 1)},
+		}
+		rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0]))
+	}
+
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)"))
+	defer func() {
+		re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode"))
+	}()
+
+	tc, err := tests.NewTestTSOCluster(ctx, 2, backendEndpoints)
+	re.NoError(err)
+	defer tc.Destroy()
+	tc.WaitForDefaultPrimaryServing(re)
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < 3; i++ {
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			pdClient, err := pd.NewClientWithContext(
+				context.Background(),
+				[]string{backendEndpoints},
+				pd.SecurityOption{})
+			re.NoError(err)
+			re.NotNil(pdClient)
+			defer pdClient.Close()
+			for j := 0; j < 10; j++ {
+				testutil.Eventually(re, func() bool {
+					min, err := pdClient.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", i), 1000, 1)
+					return err == nil && min == 0
+				})
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+func BenchmarkForwardTsoConcurrently(b *testing.B) {
+	re := require.New(b)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	cluster, err := tests.NewTestAPICluster(ctx, 3)
+	re.NoError(err)
+	defer cluster.Destroy()
+
+	err = cluster.RunInitialServers()
+	re.NoError(err)
+
+	leaderName := cluster.WaitLeader()
+	pdLeader := cluster.GetServer(leaderName)
+	backendEndpoints := pdLeader.GetAddr()
+	re.NoError(pdLeader.BootstrapCluster())
+	leader := cluster.GetServer(cluster.WaitLeader())
+	rc := leader.GetServer().GetRaftCluster()
+	for i := 0; i < 3; i++ {
+		region := &metapb.Region{
+			Id:       uint64(i*4 + 1),
+			Peers:    []*metapb.Peer{{Id: uint64(i*4 + 2), StoreId: uint64(i*4 + 3)}},
+			StartKey: []byte{byte(i)},
+			EndKey:   []byte{byte(i + 1)},
+		}
+		rc.HandleRegionHeartbeat(core.NewRegionInfo(region, region.Peers[0]))
+	}
+
+	re.NoError(failpoint.Enable("github.com/tikv/pd/client/usePDServiceMode", "return(true)"))
+	defer func() {
+		re.NoError(failpoint.Disable("github.com/tikv/pd/client/usePDServiceMode"))
+	}()
+
+	tc, err := tests.NewTestTSOCluster(ctx, 1, backendEndpoints)
+	re.NoError(err)
+	defer tc.Destroy()
+	tc.WaitForDefaultPrimaryServing(re)
+
+	initClients := func(num int) []pd.Client {
+		var clients []pd.Client
+		for i := 0; i < num; i++ {
+			pdClient, err := pd.NewClientWithContext(context.Background(),
+				[]string{backendEndpoints}, pd.SecurityOption{}, pd.WithMaxErrorRetry(1))
+			re.NoError(err)
+			re.NotNil(pdClient)
+			clients = append(clients, pdClient)
+		}
+		return clients
+	}
+
+	concurrencyLevels := []int{1, 2, 5, 10, 20}
+	for _, clientsNum := range concurrencyLevels {
+		clients := initClients(clientsNum)
+		b.Run(fmt.Sprintf("clients_%d", clientsNum), func(b *testing.B) {
+			wg := sync.WaitGroup{}
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				for j, client := range clients {
+					wg.Add(1)
+					go func(j int, client pd.Client) {
+						defer wg.Done()
+						for k := 0; k < 1000; k++ {
+							min, err := client.UpdateServiceGCSafePoint(context.Background(), fmt.Sprintf("service-%d", j), 1000, 1)
+							re.NoError(err)
+							re.Equal(uint64(0), min)
+						}
+					}(j, client)
+				}
+			}
+			wg.Wait()
+		})
+		for _, c := range clients {
+			c.Close()
+		}
+	}
+}
+
 type CommonTestSuite struct {
 	suite.Suite
 	ctx              context.Context
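
Note on the locking pattern above (a reviewer sketch, not part of the patch): grpc-go streams tolerate at most one goroutine calling SendMsg and one calling RecvMsg at a time, and the pooled forwarding stream in tsoClientPool is shared by every concurrent TSO request to the same host, so the Send/Recv pair must be serialized or one caller can consume another caller's response. Below is a minimal, runnable sketch of the idea, with plain sync.Mutex standing in for pd's syncutil.Mutex and a hypothetical chanStream standing in for the tsopb.TSO_TsoClient stream; every name except streamWrapper is illustrative.

package main

import (
	"fmt"
	"sync"
)

// chanStream pairs each Send with the next Recv, like a gRPC stream whose
// server answers requests strictly in order.
type chanStream struct{ ch chan string }

func (s *chanStream) Send(req string) error { s.ch <- "ts-for-" + req; return nil }
func (s *chanStream) Recv() (string, error) { return <-s.ch, nil }

// streamWrapper mirrors the type added to server/server.go: the embedded
// stream plus a mutex taken around each Send/Recv round trip.
type streamWrapper struct {
	*chanStream
	sync.Mutex
}

func main() {
	w := &streamWrapper{chanStream: &chanStream{ch: make(chan string, 1)}}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			req := fmt.Sprintf("req-%d", i)
			// Hold the lock across the pair; unlocking between Send and
			// Recv would let another goroutine steal this response.
			w.Lock()
			defer w.Unlock()
			if err := w.Send(req); err != nil {
				return
			}
			resp, _ := w.Recv()
			fmt.Println(req, "->", resp)
		}(i)
	}
	wg.Wait()
}

In the patch itself the unlock is explicit rather than deferred, so the retry paths release the stream before continue; and since start is taken before Lock while forwardTsoDuration is observed after Unlock, the recorded duration covers lock wait plus the full forwarded round trip.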
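A second note, on the new metric's bucket layout (simple arithmetic, not from the patch): prometheus.ExponentialBuckets(0.0005, 2, 13) yields 13 upper bounds doubling from 0.0005 s, i.e. 0.5 ms up to 0.0005 × 2^12 = 2.048 s, with anything slower falling into the implicit +Inf bucket; that range should bracket typical in-datacenter forwarded-TSO latencies.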