Skip to content

Commit

Permalink
Use a chat session to to able to set roles
Browse files Browse the repository at this point in the history
  • Loading branch information
rakyll committed Sep 6, 2024
1 parent a243342 commit 69c8aac
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
23 changes: 17 additions & 6 deletions openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,36 @@ func (h *handlers) ChatCompletionsHandler(w http.ResponseWriter, r *http.Request
Temperature: chatReq.Temperature,
TopP: chatReq.TopP,
}
parts := []genai.Part{}
for _, r := range chatReq.Messages {

chat := model.StartChat()
var lastPart genai.Part
for i, r := range chatReq.Messages {
if r.Role == "system" {
model.SystemInstruction = &genai.Content{
Role: r.Role,
Parts: []genai.Part{genai.Text(r.Content)},
}
continue
}
// TODO: parts don't support role for model.GenerateContent
parts = append(parts, genai.Text(r.Content))
if i == len(chatReq.Messages)-1 { // the last message
// TODO(jbd): This hack strips away the role of the last message.
// But Gemini API Go SDK doesn't give flexibility to call SendMessage
// with a list of contents.
lastPart = genai.Text(r.Content)
break
}
chat.History = append(chat.History, &genai.Content{
Role: r.Role,
Parts: []genai.Part{genai.Text(r.Content)},
})
}

if chatReq.Stream {
streamingChatCompletionsHandler(w, r, chatReq.Model, model, parts)
streamingChatCompletionsHandler(w, r, chatReq.Model, chat, lastPart)
return
}

geminiResp, err := model.GenerateContent(r.Context(), parts...)
geminiResp, err := chat.SendMessage(r.Context(), lastPart)
if err != nil {
internal.ErrorHandler(w, r, http.StatusInternalServerError, "failed to generate content: %v", err)
return
Expand Down
4 changes: 2 additions & 2 deletions openai/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
"google.golang.org/api/iterator"
)

func streamingChatCompletionsHandler(w http.ResponseWriter, r *http.Request, model string, genModel *genai.GenerativeModel, parts []genai.Part) {
iter := genModel.GenerateContentStream(r.Context(), parts...)
func streamingChatCompletionsHandler(w http.ResponseWriter, r *http.Request, model string, chat *genai.ChatSession, lastPart genai.Part) {
iter := chat.SendMessageStream(r.Context(), lastPart)

encoder := json.NewEncoder(w)
for {
Expand Down

0 comments on commit 69c8aac

Please sign in to comment.