Skip to content

Commit

Permalink
Merge pull request #5 from fandujar/release/0.1.0
Browse files Browse the repository at this point in the history
feat: add open ai context
  • Loading branch information
fandujar authored Sep 22, 2024
2 parents d1a1d9e + 60c5a85 commit ede70f6
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 21 deletions.
11 changes: 3 additions & 8 deletions pkg/services/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,12 @@ func NewOpenAIService(p *providers.OpenAIProvider, nc *nats.Conn) *OpenAIService
}
}

func (s *OpenAIService) ChatCompletion(prompt string) (string, error) {
func (s *OpenAIService) ChatCompletion(messages []openai.ChatCompletionMessage) (string, error) {
resp, err := s.OpenAIProvider.Client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: "gpt-4o-mini",
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
},
},
Model: "gpt-4o-mini",
Messages: messages,
},
)

Expand Down
16 changes: 16 additions & 0 deletions pkg/services/slack.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package services
import (
"github.com/fandujar/baymax/pkg/providers"
"github.com/nats-io/nats.go"
"github.com/slack-go/slack"
)

type SlackService struct {
Expand All @@ -16,3 +17,18 @@ func NewSlackService(slackProvider *providers.SlackProvider, natsClient *nats.Co
NatsClient: natsClient,
}
}

func (s *SlackService) GetAllMessagesFromThread(channel, threadTimestamp string) ([]slack.Message, error) {
messages, _, _, err := s.SlackProvider.Client.GetConversationReplies(
&slack.GetConversationRepliesParameters{
ChannelID: channel,
Timestamp: threadTimestamp,
Limit: 100,
},
)
if err != nil {
return nil, err
}

return messages, nil
}
21 changes: 17 additions & 4 deletions pkg/transport/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/fandujar/baymax/pkg/subjects"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log"
"github.com/slack-go/slack/slackevents"
"github.com/sashabaranov/go-openai"
)

type OpenAIHandler struct {
Expand All @@ -25,19 +25,32 @@ func (h *OpenAIHandler) RunEventLoop() {
// Get the message and call the OpenAI API to get a response
// Send the response to NATS using the subject SlackResponse

ev := &slackevents.AppMentionEvent{}
ev := &ThreadMessage{}
if err := json.Unmarshal(m.Data, ev); err != nil {
log.Error().Err(err).Msg("failed to unmarshal event")
return
}

resp, err := h.Service.ChatCompletion(ev.Text)
messages := []openai.ChatCompletionMessage{}
for _, message := range ev.Messages {
messages = append(messages, openai.ChatCompletionMessage{
Role: "user",
Content: message.Text,
})
}

messages = append(messages, openai.ChatCompletionMessage{
Role: "user",
Content: ev.Event.Text,
})

resp, err := h.Service.ChatCompletion(messages)
if err != nil {
log.Error().Err(err).Msg("failed to get chat completion")
return
}

ev.Text = resp
ev.Event.Text = resp

data, err := json.Marshal(ev)
if err != nil {
Expand Down
38 changes: 29 additions & 9 deletions pkg/transport/slack.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ func NewSlackHandler(service *services.SlackService) *SlackHandler {
}
}

type ThreadMessage struct {
Event *slackevents.AppMentionEvent `json:"event"`
Messages []slack.Message `json:"messages"`
}

func (h *SlackHandler) RunEventLoop() {
handler := h.RegisterSlackHandlers()
go func() {
Expand All @@ -33,24 +38,24 @@ func (h *SlackHandler) RunEventLoop() {

h.Service.NatsClient.Subscribe(subjects.SlackResponse, func(m *nats.Msg) {
log.Debug().Msgf("Received a message: %s", string(m.Data))
var ev *slackevents.AppMentionEvent
var ev *ThreadMessage
if err := json.Unmarshal(m.Data, &ev); err != nil {
log.Error().Err(err).Msg("failed to unmarshal event")
return
}

if ev.ThreadTimeStamp == "" {
if ev.Event.ThreadTimeStamp == "" {
log.Debug().Msg("thread timestamp is empty")
ev.ThreadTimeStamp = ev.TimeStamp
ev.Event.ThreadTimeStamp = ev.Event.TimeStamp
}

message := slack.MsgOptionCompose(
slack.MsgOptionText(ev.Text, false),
slack.MsgOptionTS(ev.ThreadTimeStamp),
slack.MsgOptionText(ev.Event.Text, false),
slack.MsgOptionTS(ev.Event.ThreadTimeStamp),
)

if _, _, _, err := h.Service.SlackProvider.Client.SendMessage(
ev.Channel,
ev.Event.Channel,
message,
); err != nil {
log.Error().Err(err).Msg("failed to send message")
Expand Down Expand Up @@ -101,14 +106,29 @@ func (h *SlackHandler) appMentionHandler(evt *socketmode.Event, client *socketmo
return
}

evJSON, err := json.Marshal(ev)
var messages []slack.Message
var err error
if ev.ThreadTimeStamp != "" {
// If inside a thread, get all messages in the thread to pass as context
messages, err = h.Service.GetAllMessagesFromThread(ev.Channel, ev.ThreadTimeStamp)
if err != nil {
log.Error().Err(err).Msg("failed to get messages from thread")
return
}
}

threadMessage := &ThreadMessage{
Event: ev,
Messages: messages,
}

threadMessageJSON, err := json.Marshal(threadMessage)
if err != nil {
log.Error().Err(err).Msg("failed to marshal event")
return
}

log.Debug().Msgf("Received a message: %s", string(evJSON))
if err := h.Service.NatsClient.Publish(subjects.SlackEvents, evJSON); err != nil {
if err := h.Service.NatsClient.Publish(subjects.SlackEvents, threadMessageJSON); err != nil {
log.Error().Err(err).Msg("failed to publish event to NATS")
}
}
Expand Down

0 comments on commit ede70f6

Please sign in to comment.