From 2dae48dc22111eb8ea42d574efe0e593901d5224 Mon Sep 17 00:00:00 2001 From: Andy Xie Date: Wed, 20 Feb 2019 16:30:37 +0800 Subject: [PATCH] add max_channel_client_connection_count --- apps/nsqd/nsqd.go | 1 + nsqd/channel.go | 11 +++++++++-- nsqd/channel_test.go | 29 ++++++++++++++++++++++++++++- nsqd/nsqd_test.go | 3 ++- nsqd/options.go | 2 ++ nsqd/protocol_v2.go | 5 ++++- 6 files changed, 46 insertions(+), 5 deletions(-) diff --git a/apps/nsqd/nsqd.go b/apps/nsqd/nsqd.go index 25f052d9c..4d85228c1 100644 --- a/apps/nsqd/nsqd.go +++ b/apps/nsqd/nsqd.go @@ -119,6 +119,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet { flagSet.Duration("max-output-buffer-timeout", opts.MaxOutputBufferTimeout, "maximum client configurable duration of time between flushing to a client") flagSet.Duration("min-output-buffer-timeout", opts.MinOutputBufferTimeout, "minimum client configurable duration of time between flushing to a client") flagSet.Duration("output-buffer-timeout", opts.OutputBufferTimeout, "default duration of time between flushing data to clients") + flagSet.Int("max-channel-consumers", opts.MaxChannelConsumers, "maximum channel consumer connection count per nsqd instance (default 0, i.e., unlimited)") // statsd integration options flagSet.String("statsd-address", opts.StatsdAddress, "UDP : of a statsd daemon for pushing stats") diff --git a/nsqd/channel.go b/nsqd/channel.go index 95a7adb21..f6a7e40a7 100644 --- a/nsqd/channel.go +++ b/nsqd/channel.go @@ -389,15 +389,22 @@ func (c *Channel) RequeueMessage(clientID int64, id MessageID, timeout time.Dura } // AddClient adds a client to the Channel's client list -func (c *Channel) AddClient(clientID int64, client Consumer) { +func (c *Channel) AddClient(clientID int64, client Consumer) error { c.Lock() defer c.Unlock() _, ok := c.clients[clientID] if ok { - return + return nil + } + + maxChannelConsumers := c.ctx.nsqd.getOpts().MaxChannelConsumers + if maxChannelConsumers != 0 && len(c.clients) >= maxChannelConsumers { + return errors.New("E_TOO_MANY_CHANNEL_SUBSCRIPTIONS") } + c.clients[clientID] = client + return nil } // RemoveClient removes a client from the Channel's client list diff --git a/nsqd/channel_test.go b/nsqd/channel_test.go index b98d35339..4873706a3 100644 --- a/nsqd/channel_test.go +++ b/nsqd/channel_test.go @@ -152,7 +152,8 @@ func TestChannelEmptyConsumer(t *testing.T) { channel := topic.GetChannel("channel") client := newClientV2(0, conn, &context{nsqd}) client.SetReadyCount(25) - channel.AddClient(client.ID, client) + err := channel.AddClient(client.ID, client) + test.Equal(t, err, nil) for i := 0; i < 25; i++ { msg := NewMessage(topic.GenerateID(), []byte("test")) @@ -173,6 +174,32 @@ func TestChannelEmptyConsumer(t *testing.T) { } } +func TestMaxChannelConsumers(t *testing.T) { + opts := NewOptions() + opts.Logger = test.NewTestLogger(t) + opts.MaxChannelConsumers = 1 + tcpAddr, _, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + conn, _ := mustConnectNSQD(tcpAddr) + defer conn.Close() + + topicName := "test_max_channel_consumers" + strconv.Itoa(int(time.Now().Unix())) + topic := nsqd.GetTopic(topicName) + channel := topic.GetChannel("channel") + + client1 := newClientV2(1, conn, &context{nsqd}) + client1.SetReadyCount(25) + err := channel.AddClient(client1.ID, client1) + test.Equal(t, err, nil) + + client2 := newClientV2(2, conn, &context{nsqd}) + client2.SetReadyCount(25) + err = channel.AddClient(client2.ID, client2) + test.NotEqual(t, err, nil) +} + func TestChannelHealth(t *testing.T) { opts := NewOptions() opts.Logger = test.NewTestLogger(t) diff --git a/nsqd/nsqd_test.go b/nsqd/nsqd_test.go index 2045792e0..9bacd0e6e 100644 --- a/nsqd/nsqd_test.go +++ b/nsqd/nsqd_test.go @@ -180,7 +180,8 @@ func TestEphemeralTopicsAndChannels(t *testing.T) { topic := nsqd.GetTopic(topicName) ephemeralChannel := topic.GetChannel("ch1#ephemeral") client := newClientV2(0, nil, &context{nsqd}) - ephemeralChannel.AddClient(client.ID, client) + err := ephemeralChannel.AddClient(client.ID, client) + test.Equal(t, err, nil) msg := NewMessage(topic.GenerateID(), body) topic.PutMessage(msg) diff --git a/nsqd/options.go b/nsqd/options.go index fdfda8572..300e1af5a 100644 --- a/nsqd/options.go +++ b/nsqd/options.go @@ -58,6 +58,7 @@ type Options struct { MaxOutputBufferTimeout time.Duration `flag:"max-output-buffer-timeout"` MinOutputBufferTimeout time.Duration `flag:"min-output-buffer-timeout"` OutputBufferTimeout time.Duration `flag:"output-buffer-timeout"` + MaxChannelConsumers int `flag:"max-channel-consumers"` // statsd integration StatsdAddress string `flag:"statsd-address"` @@ -134,6 +135,7 @@ func NewOptions() *Options { MaxOutputBufferTimeout: 30 * time.Second, MinOutputBufferTimeout: 25 * time.Millisecond, OutputBufferTimeout: 250 * time.Millisecond, + MaxChannelConsumers: 0, StatsdPrefix: "nsq.%s", StatsdInterval: 60 * time.Second, diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index c2e7d7b42..aa754c9b5 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -615,7 +615,10 @@ func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) { for { topic := p.ctx.nsqd.GetTopic(topicName) channel = topic.GetChannel(channelName) - channel.AddClient(client.ID, client) + if err := channel.AddClient(client.ID, client); err != nil { + return nil, protocol.NewFatalClientErr(nil, "E_TOO_MANY_CHANNEL_CONSUMERS", + fmt.Sprintf("TOO many channel consumers for %s:%s", topicName, channelName)) + } if (channel.ephemeral && channel.Exiting()) || (topic.ephemeral && topic.Exiting()) { channel.RemoveClient(client.ID)