Skip to content

Commit

Permalink
fix(claude) : fix problem of alternance of role user and assistant
Browse files Browse the repository at this point in the history
  • Loading branch information
LordPax committed Sep 21, 2024
1 parent ff9eb3f commit 2cd6d87
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 15 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ Projet to use ai api to generate text, image, etc.

## To do

- [ ] Faire command text
- [ ] Faire command translate
- [ ] Faire command image
- [ ] Faire command speech
- [ ] Faire command tts
- [x] command text
- [ ] command translate
- [ ] command image
- [ ] command speech
- [ ] command tts

## Build and install

Expand Down
2 changes: 1 addition & 1 deletion commands/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func TextFlags() []cli.Flag {
Aliases: []string{"l"},
Usage: l.Get("text-list-history-usage"),
Action: func(c *cli.Context, value bool) error {
if err := service.ListHistory(true, true); err != nil {
if err := service.ListHistory(true); err != nil {
return err
}
os.Exit(0)
Expand Down
33 changes: 32 additions & 1 deletion sdk/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ func NewClaudeText(apiKey, model string, temp float64) (*ClaudeText, error) {
func (c *ClaudeText) SendRequest(text string) (Message, error) {
var textResponse ClaudeResponse

c.AppendHistory("user", text)
idLastMsg := len(c.GetHistory()) - 1
lastMessage := c.GetMessage(idLastMsg)

if lastMessage != nil && lastMessage.Role == "user" {
c.AppendMessage(idLastMsg, text)
} else {
c.AppendHistory("user", text)
}

jsonBody, err := json.Marshal(ClaudeBody{
Model: c.Model,
Expand Down Expand Up @@ -114,3 +121,27 @@ func (c *ClaudeText) SendRequest(text string) (Message, error) {

return respMessage, nil
}

func (c *ClaudeText) AppendHistory(role string, text ...string) Message {
var content []Content
name := c.SelectedHistory

if role == "system" {
role = "user"
}

for _, t := range text {
content = append(content, Content{
Type: "text",
Text: t,
})
}

message := Message{
Role: role,
Content: content,
}
c.History[name] = append(c.History[name], message)

return message
}
26 changes: 25 additions & 1 deletion sdk/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ type ITextHistory interface {
SetSelectedHistory(name string)
GetSelectedHistory() string
ClearHistory()
GetMessage(index int) *Message
AppendMessage(index int, text ...string)
}

type TextHistory struct {
Expand Down Expand Up @@ -58,6 +60,7 @@ func (t *TextHistory) GetSelectedHistory() string {

func (t *TextHistory) AppendHistory(role string, text ...string) Message {
var content []Content
name := t.SelectedHistory

for _, t := range text {
content = append(content, Content{
Expand All @@ -66,7 +69,6 @@ func (t *TextHistory) AppendHistory(role string, text ...string) Message {
})
}

name := t.SelectedHistory
message := Message{
Role: role,
Content: content,
Expand Down Expand Up @@ -121,3 +123,25 @@ func (t *TextHistory) GetHistory() []Message {
name := t.SelectedHistory
return t.History[name]
}

func (t *TextHistory) GetMessage(index int) *Message {
if index < 0 {
return nil
}
name := t.SelectedHistory
return &t.History[name][index]
}

func (t *TextHistory) AppendMessage(index int, text ...string) {
name := t.SelectedHistory
message := t.GetMessage(index)

for _, t := range text {
message.Content = append(message.Content, Content{
Type: "text",
Text: t,
})
}

t.History[name][index] = *message
}
15 changes: 14 additions & 1 deletion sdk/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,20 @@ type Message struct {
}

func (m *Message) GetContent() string {
return m.Content[0].Text
var text string

if len(m.Content) == 1 {
return m.Content[0].Text
}

for _, c := range m.Content {
if c.Type != "text" {
continue
}
text += "\n" + c.Text + "\n"
}

return text
}

type ErrorMsg struct {
Expand Down
8 changes: 2 additions & 6 deletions service/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func InteractiveMode() error {
textSdk := sdk.GetSdkText()
l := lang.GetLocalize()

if err := ListHistory(false, false); err != nil {
if err := ListHistory(false); err != nil {
return err
}

Expand All @@ -60,7 +60,7 @@ func InteractiveMode() error {
return nil
}

func ListHistory(showSystem, showMsg bool) error {
func ListHistory(showMsg bool) error {
textSdk := sdk.GetSdkText()
l := lang.GetLocalize()
log, err := utils.GetLog()
Expand All @@ -78,10 +78,6 @@ func ListHistory(showSystem, showMsg bool) error {
for _, message := range history {
role := message.Role

if role == "system" && !showSystem {
continue
}

switch role {
case "user":
fmt.Print(utils.Blue + "user> " + utils.Reset)
Expand Down

0 comments on commit 2cd6d87

Please sign in to comment.