From 2862c2a1c09a538d149d9b532c5ad1234f671e58 Mon Sep 17 00:00:00 2001 From: teddy Date: Fri, 27 Sep 2024 01:54:26 +0200 Subject: [PATCH] feat : add sdk mistral * add parametre -L * improve the way we add message to history * improve input function for interactive mode --- CHANGELOG.md | 12 ++++++ commands/text.go | 28 ++++++++++--- go.mod | 2 + go.sum | 4 ++ lang/en.go | 53 ++++++++++++------------ lang/fr.go | 53 ++++++++++++------------ sdk/claude.go | 25 +++++------- sdk/history.go | 43 +++++++++++++------- sdk/mistral.go | 104 +++++++++++++++++++++++++++++++++++++++++++++++ sdk/openai.go | 2 + sdk/text.go | 12 +++--- utils/utils.go | 24 ++++++----- 12 files changed, 259 insertions(+), 103 deletions(-) create mode 100644 sdk/mistral.go diff --git a/CHANGELOG.md b/CHANGELOG.md index a8e1c97..26b0f83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [Unreleased] + +### Added + +* Add parameter -L for text command to list all history name +* Add sdk mistral for text command + +### Changed + +* Improve the way we add message to history +* Improve input function for interactive mode + ## [0.1.0] - 2024-09-22 ### Added diff --git a/commands/text.go b/commands/text.go index 4ea7133..9e5c96b 100644 --- a/commands/text.go +++ b/commands/text.go @@ -128,9 +128,10 @@ func textFlags() []cli.Flag { }, }, &cli.BoolFlag{ - Name: "clear", - Aliases: []string{"c"}, - Usage: l.Get("text-clear-usage"), + Name: "clear", + Aliases: []string{"c"}, + Usage: l.Get("text-clear-usage"), + DisableDefaultText: true, Action: func(c *cli.Context, value bool) error { text := sdk.GetSdkText() text.ClearHistory() @@ -142,9 +143,10 @@ func textFlags() []cli.Flag { }, }, &cli.BoolFlag{ - Name: "list-history", - Aliases: []string{"l"}, - Usage: l.Get("text-list-history-usage"), + Name: "list-history", + Aliases: []string{"l"}, + Usage: l.Get("text-list-history-usage"), + DisableDefaultText: true, Action: func(c *cli.Context, value bool) error { if err := service.ListHistory(true); err != nil { return err @@ -153,6 +155,20 @@ func textFlags() []cli.Flag { return nil }, }, + &cli.BoolFlag{ + Name: "list-history-name", + Aliases: []string{"L"}, + Usage: l.Get("text-list-history-name-usage"), + DisableDefaultText: true, + Action: func(c *cli.Context, value bool) error { + text := sdk.GetSdkText() + for _, name := range text.GetHistoryNames() { + fmt.Println(name) + } + os.Exit(0) + return nil + }, + }, } } diff --git a/go.mod b/go.mod index 20df793..3d68499 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23.1 require ( github.com/urfave/cli/v2 v2.27.3 + golang.org/x/term v0.24.0 gopkg.in/ini.v1 v1.67.0 ) @@ -12,4 +13,5 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/stretchr/testify v1.9.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect + golang.org/x/sys v0.25.0 // indirect ) diff --git a/go.sum b/go.sum index 71969ad..6132115 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,10 @@ github.com/urfave/cli/v2 v2.27.3 h1:/POWahRmdh7uztQ3CYnaDddk0Rm90PyOgIxgW2rr41M= github.com/urfave/cli/v2 v2.27.3/go.mod h1:m4QzxcD2qpra4z7WhzEGn74WZLViBnMpb1ToCAKdGRQ= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= +golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/lang/en.go b/lang/en.go index fafd1ae..48137d1 100644 --- a/lang/en.go +++ b/lang/en.go @@ -3,30 +3,31 @@ package lang import "github.com/LordPax/aicli/utils" var EN_STRINGS = LangString{ - "usage": "CLI toot to use ai model", - "output-desc": "Output directory", - "output-dir-empty": "Output directory is empty", - "silent": "Disable printing log to stdout", - "no-args": "No arguments provided", - "no-command": "No command provided", - "unknown-sdk": "Unknown sdk \"%s\"", - "sdk-model-usage": "Select a model", - "text-usage": "Generate text from a prompt", - "sdk-usage": "Select a sdk", - "text-temp-usage": "Set temperature", - "text-system-usage": "Instruction with role system (use \"-\" for stdin)", - "text-history-usage": "Select a history", - "text-clear-usage": "Clear history", - "text-file-usage": "Text file to use", - "text-input": "(\"exit\" to quit) " + utils.Blue + "user> " + utils.Reset, - "translate-input": "(\"exit\" to quit) " + utils.Blue + "> " + utils.Reset, - "text-list-history-usage": "List history", - "type-required": "Type is required", - "apiKey-required": "API key is required", - "empty-file": "File \"%s\" is empty", - "empty-history": "History \"%s\" is empty\n", - "translate-usage": "Translate a text", - "translate-source-usage": "Source language", - "translate-target-usage": "Target language", - "translate-target-required": "Target language is required", + "usage": "CLI toot to use ai model", + "output-desc": "Output directory", + "output-dir-empty": "Output directory is empty", + "silent": "Disable printing log to stdout", + "no-args": "No arguments provided", + "no-command": "No command provided", + "unknown-sdk": "Unknown sdk \"%s\"", + "sdk-model-usage": "Select a model", + "text-usage": "Generate text from a prompt", + "sdk-usage": "Select a sdk", + "text-temp-usage": "Set temperature", + "text-system-usage": "Instruction with role system (use \"-\" for stdin)", + "text-history-usage": "Select a history", + "text-clear-usage": "Clear history", + "text-file-usage": "Text file to use", + "text-input": "(\"exit\" to quit) " + utils.Blue + "user> " + utils.Reset, + "translate-input": "(\"exit\" to quit) " + utils.Blue + "> " + utils.Reset, + "text-list-history-usage": "List history", + "text-list-history-name-usage": "List history names", + "type-required": "Type is required", + "api-key-required": "API key is required", + "empty-file": "File \"%s\" is empty", + "empty-history": "History \"%s\" is empty\n", + "translate-usage": "Translate a text", + "translate-source-usage": "Source language", + "translate-target-usage": "Target language", + "translate-target-required": "Target language is required", } diff --git a/lang/fr.go b/lang/fr.go index 345f2ad..8456cd8 100644 --- a/lang/fr.go +++ b/lang/fr.go @@ -3,30 +3,31 @@ package lang import "github.com/LordPax/aicli/utils" var FR_STRINGS = LangString{ - "usage": "CLI pour utiliser des modèles d'IA", - "output-desc": "Répertoire de sortie", - "output-dir-empty": "Le répertoire de sortie est vide", - "silent": "Désactiver l'impression du journal sur stdout", - "no-args": "Aucun argument fourni", - "no-command": "Aucune commande fournie", - "unknown-sdk": "Sdk inconnu \"%s\"", - "sdk-model-usage": "Sélectionner un modèle", - "text-usage": "Générer du texte à partir d'un prompt", - "sdk-usage": "Sélectionner un sdk", - "text-temp-usage": "Définir la température", - "text-system-usage": "Instruction avec rôle système (utilisez \"-\" pour stdin)", - "text-history-usage": "Sélectionner un historique", - "text-clear-usage": "Effacer l'historique", - "text-file-usage": "Fichier texte à utiliser", - "text-input": "(\"exit\" pour quitter) " + utils.Blue + "user> " + utils.Reset, - "translate-input": "(\"exit\" pour quitter) " + utils.Blue + "> " + utils.Reset, - "text-list-history-usage": "Lister l'historique", - "type-required": "Le type est requis", - "apiKey-required": "La clé API est requise", - "empty-file": "Le fichier \"%s\" est vide", - "empty-history": "L'historique \"%s\" est vide\n", - "translate-usage": "Traduire un texte", - "translate-source-usage": "Langue source", - "translate-target-usage": "Langue cible", - "translate-target-required": "La langue cible est requise", + "usage": "CLI pour utiliser des modèles d'IA", + "output-desc": "Répertoire de sortie", + "output-dir-empty": "Le répertoire de sortie est vide", + "silent": "Désactiver l'impression du journal sur stdout", + "no-args": "Aucun argument fourni", + "no-command": "Aucune commande fournie", + "unknown-sdk": "Sdk inconnu \"%s\"", + "sdk-model-usage": "Sélectionner un modèle", + "text-usage": "Générer du texte à partir d'un prompt", + "sdk-usage": "Sélectionner un sdk", + "text-temp-usage": "Définir la température", + "text-system-usage": "Instruction avec rôle système (utilisez \"-\" pour stdin)", + "text-history-usage": "Sélectionner un historique", + "text-clear-usage": "Effacer l'historique", + "text-file-usage": "Fichier texte à utiliser", + "text-input": "(\"exit\" pour quitter) " + utils.Blue + "user> " + utils.Reset, + "translate-input": "(\"exit\" pour quitter) " + utils.Blue + "> " + utils.Reset, + "text-list-history-usage": "Lister l'historique", + "text-list-history-name-usage": "Lister les noms d'historique", + "type-required": "Le type est requis", + "api-key-required": "La clé API est requise", + "empty-file": "Le fichier \"%s\" est vide", + "empty-history": "L'historique \"%s\" est vide\n", + "translate-usage": "Traduire un texte", + "translate-source-usage": "Langue source", + "translate-target-usage": "Langue cible", + "translate-target-required": "La langue cible est requise", } diff --git a/sdk/claude.go b/sdk/claude.go index 4be1418..623f3de 100644 --- a/sdk/claude.go +++ b/sdk/claude.go @@ -68,19 +68,14 @@ func NewClaudeText(apiKey, model string, temp float64) (*ClaudeText, error) { func (c *ClaudeText) SendRequest(text string) (Message, error) { var textResponse ClaudeResponse - idLastMsg := len(c.GetHistory()) - 1 - lastMessage := c.GetMessage(idLastMsg) + c.AppendHistory("user", text) - if lastMessage != nil && lastMessage.Role == "user" { - c.AppendMessage(idLastMsg, text) - } else { - c.AppendHistory("user", text) - } + test := c.GetHistory() jsonBody, err := json.Marshal(ClaudeBody{ Model: c.Model, MaxTokens: 1024, - Messages: c.GetHistory(), + Messages: test, }) if err != nil { return Message{}, err @@ -123,23 +118,23 @@ func (c *ClaudeText) SendRequest(text string) (Message, error) { } func (c *ClaudeText) AppendHistory(role string, text ...string) Message { - var content []Content name := c.SelectedHistory if role == "system" { role = "user" } - for _, t := range text { - content = append(content, Content{ - Type: "text", - Text: t, - }) + idLastMsg := len(c.GetHistory()) - 1 + lastMessage := c.GetMessage(idLastMsg) + + // If the last message is from the same role, append the new text to the last message + if lastMessage != nil && lastMessage.Role == role { + return c.AppendMessage(idLastMsg, text...) } message := Message{ Role: role, - Content: content, + Content: textContent(text...), } c.History[name] = append(c.History[name], message) diff --git a/sdk/history.go b/sdk/history.go index 6fecfe4..4896c94 100644 --- a/sdk/history.go +++ b/sdk/history.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path" + "strings" "github.com/LordPax/aicli/config" "github.com/LordPax/aicli/utils" @@ -18,7 +19,8 @@ type ITextHistory interface { GetSelectedHistory() string ClearHistory() GetMessage(index int) *Message - AppendMessage(index int, text ...string) + AppendMessage(index int, text ...string) Message + GetHistoryNames() []string } type TextHistory struct { @@ -59,19 +61,11 @@ func (t *TextHistory) GetSelectedHistory() string { } func (t *TextHistory) AppendHistory(role string, text ...string) Message { - var content []Content name := t.SelectedHistory - for _, t := range text { - content = append(content, Content{ - Type: "text", - Text: t, - }) - } - message := Message{ Role: role, - Content: content, + Content: textContent(text...), } t.History[name] = append(t.History[name], message) @@ -132,16 +126,37 @@ func (t *TextHistory) GetMessage(index int) *Message { return &t.History[name][index] } -func (t *TextHistory) AppendMessage(index int, text ...string) { +func (t *TextHistory) AppendMessage(index int, text ...string) Message { name := t.SelectedHistory message := t.GetMessage(index) + content := textContent(text...) + message.Content = append(message.Content, content...) + + t.History[name][index] = *message + + return *message +} + +func (t *TextHistory) GetHistoryNames() []string { + var names []string + + for k := range t.History { + names = append(names, k) + } + + return names +} + +func textContent(text ...string) []Content { + var content []Content + for _, t := range text { - message.Content = append(message.Content, Content{ + content = append(content, Content{ Type: "text", - Text: t, + Text: strings.TrimSpace(t), }) } - t.History[name][index] = *message + return content } diff --git a/sdk/mistral.go b/sdk/mistral.go new file mode 100644 index 0000000..5f272b5 --- /dev/null +++ b/sdk/mistral.go @@ -0,0 +1,104 @@ +package sdk + +import ( + "encoding/json" + "errors" + "io" + "net/http" + + "github.com/LordPax/aicli/utils" +) + +type MistralErrorMsg struct { + Message string `json:"message"` +} + +type MistralText struct { + Sdk + SdkText + TextHistory +} + +func NewMistralText(apiKey, model string, temp float64) (*MistralText, error) { + history, err := NewTextHistory("mistral") + if err != nil { + return nil, err + } + + sdkService := &MistralText{ + Sdk: Sdk{ + Name: "mistral", + ApiUrl: "https://api.mistral.ai/v1/chat/completions", + ApiKey: apiKey, + }, + SdkText: SdkText{ + Model: "mistral-medium", + Temp: 0.7, + }, + TextHistory: *history, + } + + if model != "" { + sdkService.Model = model + } + + if temp != 0 { + sdkService.Temp = temp + } + + if err := sdkService.LoadHistory(); err != nil { + return nil, err + } + + return sdkService, nil +} + +func (m *MistralText) SendRequest(text string) (Message, error) { + var textResponse OpenaiResponse + + m.AppendHistory("user", text) + + jsonBody, err := json.Marshal(OpenaiBody{ + Model: m.Model, + Messages: m.GetHistory(), + Temp: m.Temp, + }) + if err != nil { + return Message{}, err + } + + resp, err := utils.PostRequest(m.ApiUrl, jsonBody, map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + m.ApiKey, + }) + if err != nil { + return Message{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return Message{}, err + } + + if resp.StatusCode != http.StatusOK { + var errorMsg MistralErrorMsg + if err := json.Unmarshal(respBody, &errorMsg); err != nil { + return Message{}, err + } + return Message{}, errors.New(errorMsg.Message) + } + + if err := json.Unmarshal(respBody, &textResponse); err != nil { + return Message{}, err + } + + msg := textResponse.Choices[0].Message + respMessage := m.AppendHistory(msg.Role, msg.GetContent()) + + if err := m.SaveHistory(); err != nil { + return Message{}, err + } + + return respMessage, nil +} diff --git a/sdk/openai.go b/sdk/openai.go index 13d4fef..7f61f6c 100644 --- a/sdk/openai.go +++ b/sdk/openai.go @@ -30,6 +30,7 @@ func (m *ChoicesMessage) GetContent() string { type OpenaiBody struct { Model string `json:"model"` Messages []Message `json:"messages"` + Temp float64 `json:"temperature"` } type OpenaiText struct { @@ -81,6 +82,7 @@ func (o *OpenaiText) SendRequest(text string) (Message, error) { jsonBody, err := json.Marshal(OpenaiBody{ Model: o.Model, Messages: o.GetHistory(), + Temp: o.Temp, }) if err != nil { return Message{}, err diff --git a/sdk/text.go b/sdk/text.go index e8c5581..39c6bcb 100644 --- a/sdk/text.go +++ b/sdk/text.go @@ -32,11 +32,14 @@ func (m *Message) GetContent() string { return m.Content[0].Text } - for _, c := range m.Content { + for i, c := range m.Content { if c.Type != "text" { continue } - text += "\n" + c.Text + "\n" + if i != 0 { + text += "\n---\n" + } + text += c.Text } return text @@ -81,6 +84,8 @@ func InitSdkText(sdk string) error { sdkTextInstance, err = NewOpenaiText(apiKey, model, temp) case "claude": sdkTextInstance, err = NewClaudeText(apiKey, model, temp) + case "mistral": + sdkTextInstance, err = NewMistralText(apiKey, model, temp) default: return fmt.Errorf(l.Get("unknown-sdk"), sdkType) } @@ -106,9 +111,6 @@ func getConfigText(sdkType string) (string, string, string, float64, error) { apiKey := confText.Key("apiKey").String() if apiKey == "" { apiKey = confText.Key(sdkType + "-apiKey").String() - if apiKey == "" { - return "", "", "", 0, errors.New(l.Get("api-key-required")) - } } model := confText.Key("model").String() diff --git a/utils/utils.go b/utils/utils.go index 40fcbf7..35c9979 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,11 +1,12 @@ package utils import ( - "bufio" "bytes" - "fmt" "net/http" "os" + "strings" + + "golang.org/x/term" ) const ( @@ -44,23 +45,24 @@ func PostRequest(url string, data []byte, option map[string]string) (*http.Respo } func Input(prompt string, defaultVal string, nullable bool) string { - if defaultVal != "" { - prompt = fmt.Sprintf("[%s] %s", defaultVal, prompt) + oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + panic(err) } + defer term.Restore(int(os.Stdin.Fd()), oldState) - fmt.Print(prompt) + t := term.NewTerminal(os.Stdin, prompt) - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() - text := scanner.Text() + line, _ := t.ReadLine() + line = strings.TrimSpace(line) - if text == "" && defaultVal != "" { + if line == "" && defaultVal != "" { return defaultVal } - if text == "" && !nullable { + if line == "" && !nullable { return Input(prompt, defaultVal, nullable) } - return text + return line }