diff --git a/internal/irc-reader/config/config.go b/internal/irc-reader/config/config.go index 873fba6..3b7b626 100644 --- a/internal/irc-reader/config/config.go +++ b/internal/irc-reader/config/config.go @@ -14,6 +14,7 @@ var OnChange func() type Config struct { LogLevel string + Replicas int RateLimit struct { Join int64 diff --git a/internal/irc-reader/nats.go b/internal/irc-reader/nats.go index 26d81f2..03d88b4 100644 --- a/internal/irc-reader/nats.go +++ b/internal/irc-reader/nats.go @@ -57,7 +57,6 @@ func (c *Controller) watchChanges(nc *nats.Conn) { if !bitwise.Has(channel.Flags, bitwise.JOIN_IRC) { return } - println("joining: " + channel.Username) c.joinChannel(channel) case database.Update: // TODO: implement diff --git a/internal/irc-reader/service.go b/internal/irc-reader/service.go index 88bb783..c13b227 100644 --- a/internal/irc-reader/service.go +++ b/internal/irc-reader/service.go @@ -2,6 +2,9 @@ package irc_reader import ( "context" + "os" + "strconv" + "strings" "github.com/nats-io/nats.go" "github.com/redis/go-redis/v9" @@ -17,6 +20,8 @@ type Controller struct { jetStream nats.JetStreamContext twitch *manager.IRCManager + shardID int + // limit amount of workers for joining channels joinSem chan struct{} } @@ -29,6 +34,9 @@ func New(cfg *config.Config) *Controller { } func (c *Controller) Init() error { + if c.cfg.Replicas > 1 { + c.shardID = getShardID() + } nc, err := nats.Connect(c.cfg.Nats.URL) if err != nil { return err @@ -84,6 +92,17 @@ func (c *Controller) Init() error { return nil } +func getShardID() int { + env := os.Getenv("HOSTNAME") + split := strings.Split(env, "-") + if len(split) == 0 { + return 0 + } + id := split[len(split)-1] + result, _ := strconv.Atoi(id) + return result +} + func (c *Controller) Shutdown() { wg := c.twitch.Shutdown() wg.Wait() diff --git a/internal/irc-reader/twitch.go b/internal/irc-reader/twitch.go index b14c28b..c15ae83 100644 --- a/internal/irc-reader/twitch.go +++ b/internal/irc-reader/twitch.go @@ -71,16 +71,18 @@ func (c *Controller) joinChannels(channels []types.Channel) { } func (c *Controller) joinChannel(channel types.Channel) { + if !c.shouldJoin(channel.ID) { + return + } c.joinSem <- struct{}{} ch := channel go func() { - // TODO: filter out channels based on user ID & shard ID, so we can spread the load across kubernetes statefulset - // make sure the channel is flagged to be joined if !bitwise.Has(ch.Flags, bitwise.JOIN_IRC) { return } + zap.S().Infof("joining channel: %v", ch.Username) err := c.twitch.Join(ch.Username, ch.Weight) if err != nil { zap.L().Error( @@ -93,6 +95,18 @@ func (c *Controller) joinChannel(channel types.Channel) { }() } +func (c *Controller) shouldJoin(userID int64) bool { + if c.cfg.Replicas < 2 { + return true + } + + if int(userID)%c.cfg.Replicas == c.shardID { + return true + } + + return false +} + // parses out the channel name from a PRIVMSG, // don't use on any other type of message seen as though there's no slice length checks func parseChannel(in string) string {