Skip to content

Commit

Permalink
Fixed bug where message state would screw up
Browse files Browse the repository at this point in the history
  • Loading branch information
diamondburned committed Jan 20, 2020
1 parent 0978d51 commit 27e315c
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 32 deletions.
4 changes: 4 additions & 0 deletions discord/time.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func (t Timestamp) MarshalJSON() ([]byte, error) {
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
}

func (t Timestamp) Valid() bool {
return !time.Time(t).IsZero()
}

//

type UnixTimestamp int64
Expand Down
6 changes: 3 additions & 3 deletions gateway/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ type (
// Clients may only update their game status 5 times per 20 seconds.
PresenceUpdateEvent discord.Presence
TypingStartEvent struct {
ChannelID discord.Snowflake `json:"channel_id"`
UserID discord.Snowflake `json:"user_id"`
Timestamp discord.Timestamp `json:"timestamp"`
ChannelID discord.Snowflake `json:"channel_id"`
UserID discord.Snowflake `json:"user_id"`
Timestamp discord.UnixTimestamp `json:"timestamp"`

GuildID discord.Snowflake `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`
Expand Down
27 changes: 12 additions & 15 deletions internal/wsutil/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

var WSBuffer = 12
var WSReadLimit = 4096 // 4096 bytes
var WSReadLimit int64 = 8192000 // 8 MiB

// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself.
Expand Down Expand Up @@ -64,6 +64,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
HTTPHeader: headers,
})

c.Conn.SetReadLimit(WSReadLimit)

go func() {
c.readLoop(c.events)
}()
Expand Down Expand Up @@ -109,6 +111,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
// Probably a zlib payload
z, err := zlib.NewReader(r)
if err != nil {
c.CloseRead(ctx)
return nil,
errors.Wrap(err, "Failed to create a zlib reader")
}
Expand All @@ -117,24 +120,18 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
r = z
}

return ioutil.ReadAll(r)
}

func (c *Conn) Send(ctx context.Context, b []byte) error {
// TODO: zlib stream

w, err := c.Writer(ctx, websocket.MessageText)
b, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrap(err, "Failed to get WS writer")
c.CloseRead(ctx)
return nil, err
}

defer w.Close()

// Compress with zlib by default NOT.
// w = zlib.NewWriter(w)
return b, nil
}

_, err = w.Write(b)
return err
func (c *Conn) Send(ctx context.Context, b []byte) error {
// TODO: zlib stream
return c.Write(ctx, websocket.MessageText, b)
}

func (c *Conn) Close(err error) error {
Expand Down
40 changes: 38 additions & 2 deletions state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package state

import (
"log"
"sync"

"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
Expand Down Expand Up @@ -37,6 +38,11 @@ type State struct {
PreHandler *handler.Handler // default nil

unhooker func()

// List of channels with few messages, so it doesn't bother hitting the API
// again.
fewMessages []discord.Snowflake
fewMutex sync.Mutex
}

func NewFromSession(s *session.Session, store Store) (*State, error) {
Expand Down Expand Up @@ -298,9 +304,28 @@ func (s *State) Message(
// Messages fetches maximum 100 messages from the API, if it has to. There is no
// limit if it's from the State storage.
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
// TODO: Think of a design that doesn't rely on MaxMessages().
var maxMsgs = s.MaxMessages()

ms, err := s.Store.Messages(channelID)
if err == nil {
return ms, nil
// If the state already has as many messages as it can, skip the API.
if maxMsgs <= len(ms) {
return ms, nil
}

// Is the channel tiny?
s.fewMutex.Lock()
for _, ch := range s.fewMessages {
if ch == channelID {
// Yes, skip the state.
s.fewMutex.Unlock()
return ms, nil
}
}

// No, fetch from the state.
s.fewMutex.Unlock()
}

ms, err = s.Session.Messages(channelID, 100)
Expand All @@ -314,7 +339,18 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
}
}

return ms, nil
if len(ms) < maxMsgs {
// Tiny channel, store this.
s.fewMutex.Lock()
s.fewMessages = append(s.fewMessages, channelID)
s.fewMutex.Unlock()

return ms, nil
}

// Since the latest messages are at the end and we already know the maxMsgs,
// we could slice this right away.
return ms[:maxMsgs], nil
}

////
Expand Down
1 change: 1 addition & 0 deletions state/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type StoreGetter interface {

Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
Messages(channelID discord.Snowflake) ([]discord.Message, error)
MaxMessages() int // used to know if the state is filled or not.

// These don't get fetched from the API, it's Gateway only.
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
Expand Down
80 changes: 68 additions & 12 deletions state/store_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {

for i, ch := range chs {
if ch.ID == channel.ID {
// Also from discordgo.
if channel.Permissions == nil {
channel.Permissions = ch.Permissions
}

// Found, just edit
chs[i] = *channel

Expand Down Expand Up @@ -289,11 +294,21 @@ func (s *DefaultStore) Guilds() ([]discord.Guild, error) {
return gs, nil
}

func (s *DefaultStore) GuildSet(g *discord.Guild) error {
func (s *DefaultStore) GuildSet(guild *discord.Guild) error {
s.mut.Lock()
s.guilds[g.ID] = g
s.mut.Unlock()
defer s.mut.Unlock()

if g, ok := s.guilds[guild.ID]; ok {
// preserve state stuff
if guild.Roles == nil {
guild.Roles = g.Roles
}
if guild.Emojis == nil {
guild.Emojis = g.Emojis
}
}

s.guilds[guild.ID] = guild
return nil
}

Expand Down Expand Up @@ -425,25 +440,66 @@ func (s *DefaultStore) Messages(
return ms, nil
}

func (s *DefaultStore) MaxMessages() int {
return int(s.DefaultStoreOptions.MaxMessages)
}

func (s *DefaultStore) MessageSet(message *discord.Message) error {
s.mut.Lock()
defer s.mut.Unlock()

ms, ok := s.messages[message.ChannelID]
if !ok {
ms = make([]discord.Message, 0, int(s.MaxMessages)+1)
ms = make([]discord.Message, 0, s.MaxMessages()+1)
}

// Append
ms = append(ms, *message)
// Check if we already have the message.
for i, m := range ms {
if m.ID == message.ID {
// Thanks, Discord.
if message.Content != "" {
m.Content = message.Content
}
if message.EditedTimestamp != nil {
m.EditedTimestamp = message.EditedTimestamp
}
if message.Mentions != nil {
m.Mentions = message.Mentions
}
if message.Embeds != nil {
m.Embeds = message.Embeds
}
if message.Attachments != nil {
m.Attachments = message.Attachments
}
if message.Timestamp.Valid() {
m.Timestamp = message.Timestamp
}
if message.Author.ID.Valid() {
m.Author = message.Author
}

// Sort (should be fast since it's presorted)
sort.Slice(ms, func(i, j int) bool {
return ms[i].ID > ms[j].ID
})
ms[i] = m
return nil
}
}

// Prepend the latest message at the end

if len(ms) > 0 {
var end = s.MaxMessages()
if len(ms) < end {
end = len(ms)
}

// Copy hack to prepend. This copies the 0th-(end-1)th entries to
// 1st-endth.
copy(ms[1:end], ms[0:end-1])
// Then, set the 0th entry.
ms[0] = *message

if len(ms) > int(s.MaxMessages) {
ms = ms[len(ms)-int(s.MaxMessages):]
} else {
ms = append(ms, *message)
}

s.messages[message.ChannelID] = ms
Expand Down

0 comments on commit 27e315c

Please sign in to comment.