diff --git a/pkg/services/openai.go b/pkg/services/openai.go index b5f280e..9de3005 100644 --- a/pkg/services/openai.go +++ b/pkg/services/openai.go @@ -20,17 +20,12 @@ func NewOpenAIService(p *providers.OpenAIProvider, nc *nats.Conn) *OpenAIService } } -func (s *OpenAIService) ChatCompletion(prompt string) (string, error) { +func (s *OpenAIService) ChatCompletion(messages []openai.ChatCompletionMessage) (string, error) { resp, err := s.OpenAIProvider.Client.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ - Model: "gpt-4o-mini", - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: prompt, - }, - }, + Model: "gpt-4o-mini", + Messages: messages, }, ) diff --git a/pkg/services/slack.go b/pkg/services/slack.go index 1cd3f9b..8574c73 100644 --- a/pkg/services/slack.go +++ b/pkg/services/slack.go @@ -3,6 +3,7 @@ package services import ( "github.com/fandujar/baymax/pkg/providers" "github.com/nats-io/nats.go" + "github.com/slack-go/slack" ) type SlackService struct { @@ -16,3 +17,18 @@ func NewSlackService(slackProvider *providers.SlackProvider, natsClient *nats.Co NatsClient: natsClient, } } + +func (s *SlackService) GetAllMessagesFromThread(channel, threadTimestamp string) ([]slack.Message, error) { + messages, _, _, err := s.SlackProvider.Client.GetConversationReplies( + &slack.GetConversationRepliesParameters{ + ChannelID: channel, + Timestamp: threadTimestamp, + Limit: 100, + }, + ) + if err != nil { + return nil, err + } + + return messages, nil +} diff --git a/pkg/transport/openai.go b/pkg/transport/openai.go index 022695e..28d049d 100644 --- a/pkg/transport/openai.go +++ b/pkg/transport/openai.go @@ -7,7 +7,7 @@ import ( "github.com/fandujar/baymax/pkg/subjects" "github.com/nats-io/nats.go" "github.com/rs/zerolog/log" - "github.com/slack-go/slack/slackevents" + "github.com/sashabaranov/go-openai" ) type OpenAIHandler struct { @@ -25,19 +25,32 @@ func (h *OpenAIHandler) RunEventLoop() { // Get the message and call the OpenAI API to get a response // Send the response to NATS using the subject SlackResponse - ev := &slackevents.AppMentionEvent{} + ev := &ThreadMessage{} if err := json.Unmarshal(m.Data, ev); err != nil { log.Error().Err(err).Msg("failed to unmarshal event") return } - resp, err := h.Service.ChatCompletion(ev.Text) + messages := []openai.ChatCompletionMessage{} + for _, message := range ev.Messages { + messages = append(messages, openai.ChatCompletionMessage{ + Role: "user", + Content: message.Text, + }) + } + + messages = append(messages, openai.ChatCompletionMessage{ + Role: "user", + Content: ev.Event.Text, + }) + + resp, err := h.Service.ChatCompletion(messages) if err != nil { log.Error().Err(err).Msg("failed to get chat completion") return } - ev.Text = resp + ev.Event.Text = resp data, err := json.Marshal(ev) if err != nil { diff --git a/pkg/transport/slack.go b/pkg/transport/slack.go index 7ea17ec..5b2810d 100644 --- a/pkg/transport/slack.go +++ b/pkg/transport/slack.go @@ -23,6 +23,11 @@ func NewSlackHandler(service *services.SlackService) *SlackHandler { } } +type ThreadMessage struct { + Event *slackevents.AppMentionEvent `json:"event"` + Messages []slack.Message `json:"messages"` +} + func (h *SlackHandler) RunEventLoop() { handler := h.RegisterSlackHandlers() go func() { @@ -33,24 +38,24 @@ func (h *SlackHandler) RunEventLoop() { h.Service.NatsClient.Subscribe(subjects.SlackResponse, func(m *nats.Msg) { log.Debug().Msgf("Received a message: %s", string(m.Data)) - var ev *slackevents.AppMentionEvent + var ev *ThreadMessage if err := json.Unmarshal(m.Data, &ev); err != nil { log.Error().Err(err).Msg("failed to unmarshal event") return } - if ev.ThreadTimeStamp == "" { + if ev.Event.ThreadTimeStamp == "" { log.Debug().Msg("thread timestamp is empty") - ev.ThreadTimeStamp = ev.TimeStamp + ev.Event.ThreadTimeStamp = ev.Event.TimeStamp } message := slack.MsgOptionCompose( - slack.MsgOptionText(ev.Text, false), - slack.MsgOptionTS(ev.ThreadTimeStamp), + slack.MsgOptionText(ev.Event.Text, false), + slack.MsgOptionTS(ev.Event.ThreadTimeStamp), ) if _, _, _, err := h.Service.SlackProvider.Client.SendMessage( - ev.Channel, + ev.Event.Channel, message, ); err != nil { log.Error().Err(err).Msg("failed to send message") @@ -101,14 +106,29 @@ func (h *SlackHandler) appMentionHandler(evt *socketmode.Event, client *socketmo return } - evJSON, err := json.Marshal(ev) + var messages []slack.Message + var err error + if ev.ThreadTimeStamp != "" { + // If inside a thread, get all messages in the thread to pass as context + messages, err = h.Service.GetAllMessagesFromThread(ev.Channel, ev.ThreadTimeStamp) + if err != nil { + log.Error().Err(err).Msg("failed to get messages from thread") + return + } + } + + threadMessage := &ThreadMessage{ + Event: ev, + Messages: messages, + } + + threadMessageJSON, err := json.Marshal(threadMessage) if err != nil { log.Error().Err(err).Msg("failed to marshal event") return } - log.Debug().Msgf("Received a message: %s", string(evJSON)) - if err := h.Service.NatsClient.Publish(subjects.SlackEvents, evJSON); err != nil { + if err := h.Service.NatsClient.Publish(subjects.SlackEvents, threadMessageJSON); err != nil { log.Error().Err(err).Msg("failed to publish event to NATS") } }