From f989620b0b861f9e50da108f0ec41b96a4c1269e Mon Sep 17 00:00:00 2001 From: teddy Date: Sat, 28 Sep 2024 17:14:00 +0200 Subject: [PATCH] improve : multiple improvement * add parameter -i for text command to not make api call * improve the way we initialize the sdk * rename `system` parameter to `context` for text command (breaking change) --- CHANGELOG.md | 11 +++++++++++ commands/base.go | 7 ++++--- commands/text.go | 46 ++++++++++++++++++++++++++++++++++--------- commands/translate.go | 27 +++++++++++++++++-------- lang/en.go | 3 ++- lang/fr.go | 3 ++- main.go | 23 ++++++++++++++-------- sdk/claude.go | 14 ++++++++++--- sdk/deepl.go | 1 + sdk/mistral.go | 12 ++++++++++- sdk/openai.go | 12 ++++++++++- sdk/sdk.go | 23 ++++++++++------------ sdk/text.go | 4 ++++ service/text.go | 8 ++++++++ 14 files changed, 146 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6068f82..8e906fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [Unreleased] + +### Added + +* Add parameter -i for text command to not make api call + +### Changed + +* Improve the way we initialize the sdk +* Rename `system` parameter to `context` for text command (breaking change) + ## [0.2.0] - 2024-09-27 ### Added diff --git a/commands/base.go b/commands/base.go index 4123631..7dcfc44 100644 --- a/commands/base.go +++ b/commands/base.go @@ -13,9 +13,10 @@ func MainFlags() []cli.Flag { l := lang.GetLocalize() return []cli.Flag{ &cli.BoolFlag{ - Name: "silent", - Aliases: []string{"s"}, - Usage: l.Get("silent"), + Name: "silent", + Aliases: []string{"s"}, + Usage: l.Get("silent"), + DisableDefaultText: true, Action: func(c *cli.Context, value bool) error { log, err := utils.GetLog() if err != nil { diff --git a/commands/text.go b/commands/text.go index 9e5c96b..0d390c1 100644 --- a/commands/text.go +++ b/commands/text.go @@ -13,8 +13,13 @@ import ( cli "github.com/urfave/cli/v2" ) -func TextCommand() *cli.Command { +func TextCommand() (*cli.Command, error) { l := lang.GetLocalize() + + if err := sdk.InitSdkText(""); err != nil { + return nil, err + } + return &cli.Command{ Name: "text", Usage: l.Get("text-usage"), @@ -22,7 +27,7 @@ func TextCommand() *cli.Command { Aliases: []string{"t"}, Action: textAction, Flags: textFlags(), - } + }, nil } func textFlags() []cli.Flag { @@ -35,6 +40,7 @@ func textFlags() []cli.Flag { Aliases: []string{"S"}, Usage: l.Get("sdk-usage"), DefaultText: textSdk.GetName(), + Category: "global", Action: func(c *cli.Context, value string) error { if err := sdk.InitSdkText(value); err != nil { return err @@ -42,11 +48,24 @@ func textFlags() []cli.Flag { return nil }, }, + &cli.BoolFlag{ + Name: "inerte", + Aliases: []string{"i"}, + Usage: l.Get("inerte-usage"), + DisableDefaultText: true, + Category: "global", + Action: func(c *cli.Context, value bool) error { + text := sdk.GetSdkText() + text.SetInerte(value) + return nil + }, + }, &cli.StringFlag{ Name: "history", Aliases: []string{"H"}, Usage: l.Get("text-history-usage"), DefaultText: textSdk.GetSelectedHistory(), + Category: "history", Action: func(c *cli.Context, value string) error { text := sdk.GetSdkText() text.SetSelectedHistory(value) @@ -58,6 +77,7 @@ func textFlags() []cli.Flag { Aliases: []string{"m"}, Usage: l.Get("sdk-model-usage"), DefaultText: textSdk.GetModel(), + Category: "text", Action: func(c *cli.Context, value string) error { text := sdk.GetSdkText() text.SetModel(value) @@ -69,6 +89,7 @@ func textFlags() []cli.Flag { Aliases: []string{"t"}, Usage: l.Get("text-temp-usage"), DefaultText: strconv.FormatFloat(textSdk.GetTemp(), 'f', -1, 64), + Category: "text", Action: func(c *cli.Context, value float64) error { text := sdk.GetSdkText() text.SetTemp(value) @@ -76,9 +97,10 @@ func textFlags() []cli.Flag { }, }, &cli.StringSliceFlag{ - Name: "system", - Aliases: []string{"s"}, - Usage: l.Get("text-system-usage"), + Name: "context", + Aliases: []string{"s"}, + Usage: l.Get("text-system-usage"), + Category: "text", Action: func(c *cli.Context, values []string) error { text := sdk.GetSdkText() var content []string @@ -102,9 +124,10 @@ func textFlags() []cli.Flag { }, }, &cli.StringSliceFlag{ - Name: "file", - Aliases: []string{"f"}, - Usage: l.Get("text-file-usage"), + Name: "file", + Aliases: []string{"f"}, + Usage: l.Get("text-file-usage"), + Category: "text", Action: func(c *cli.Context, files []string) error { text := sdk.GetSdkText() var fileContent []string @@ -132,6 +155,7 @@ func textFlags() []cli.Flag { Aliases: []string{"c"}, Usage: l.Get("text-clear-usage"), DisableDefaultText: true, + Category: "history", Action: func(c *cli.Context, value bool) error { text := sdk.GetSdkText() text.ClearHistory() @@ -147,6 +171,7 @@ func textFlags() []cli.Flag { Aliases: []string{"l"}, Usage: l.Get("text-list-history-usage"), DisableDefaultText: true, + Category: "history", Action: func(c *cli.Context, value bool) error { if err := service.ListHistory(true); err != nil { return err @@ -160,6 +185,7 @@ func textFlags() []cli.Flag { Aliases: []string{"L"}, Usage: l.Get("text-list-history-name-usage"), DisableDefaultText: true, + Category: "history", Action: func(c *cli.Context, value bool) error { text := sdk.GetSdkText() for _, name := range text.GetHistoryNames() { @@ -173,7 +199,9 @@ func textFlags() []cli.Flag { } func textAction(c *cli.Context) error { - if c.NArg() == 0 { + textSdk := sdk.GetSdkText() + + if c.NArg() == 0 && !textSdk.GetInerte() { if err := service.InteractiveMode(); err != nil { return err } diff --git a/commands/translate.go b/commands/translate.go index 8927c44..0564095 100644 --- a/commands/translate.go +++ b/commands/translate.go @@ -7,8 +7,13 @@ import ( cli "github.com/urfave/cli/v2" ) -func TranslateCommand() *cli.Command { +func TranslateCommand() (*cli.Command, error) { l := lang.GetLocalize() + + if err := sdk.InitSdkTranslate(""); err != nil { + return nil, err + } + return &cli.Command{ Name: "translate", Usage: l.Get("translate-usage"), @@ -16,18 +21,20 @@ func TranslateCommand() *cli.Command { Aliases: []string{"tr"}, Action: translateAction, Flags: translateFlags(), - } + }, nil } func translateFlags() []cli.Flag { l := lang.GetLocalize() sdkTranslate := sdk.GetSdkTranslate() + return []cli.Flag{ &cli.StringFlag{ Name: "sdk", Aliases: []string{"S"}, Usage: l.Get("sdk-usage"), DefaultText: sdkTranslate.GetName(), + Category: "global", Action: func(c *cli.Context, value string) error { if err := sdk.InitSdkTranslate(value); err != nil { return err @@ -36,19 +43,23 @@ func translateFlags() []cli.Flag { }, }, &cli.StringFlag{ - Name: "source", - Aliases: []string{"s"}, - Usage: l.Get("translate-source-usage"), + Name: "source", + Aliases: []string{"s"}, + Usage: l.Get("translate-source-usage"), + Category: "translate", Action: func(c *cli.Context, value string) error { + sdkTranslate := sdk.GetSdkTranslate() sdkTranslate.SetSourceLang(value) return nil }, }, &cli.StringFlag{ - Name: "target", - Aliases: []string{"t"}, - Usage: l.Get("translate-target-usage"), + Name: "target", + Aliases: []string{"t"}, + Usage: l.Get("translate-target-usage"), + Category: "translate", Action: func(c *cli.Context, value string) error { + sdkTranslate := sdk.GetSdkTranslate() sdkTranslate.SetTargetLang(value) return nil }, diff --git a/lang/en.go b/lang/en.go index 48137d1..e449bdd 100644 --- a/lang/en.go +++ b/lang/en.go @@ -11,10 +11,11 @@ var EN_STRINGS = LangString{ "no-command": "No command provided", "unknown-sdk": "Unknown sdk \"%s\"", "sdk-model-usage": "Select a model", + "inerte-usage": "Do not make API call", "text-usage": "Generate text from a prompt", "sdk-usage": "Select a sdk", "text-temp-usage": "Set temperature", - "text-system-usage": "Instruction with role system (use \"-\" for stdin)", + "text-system-usage": "Instruction to enter as context (use \"-\" for stdin)", "text-history-usage": "Select a history", "text-clear-usage": "Clear history", "text-file-usage": "Text file to use", diff --git a/lang/fr.go b/lang/fr.go index 8456cd8..b095a04 100644 --- a/lang/fr.go +++ b/lang/fr.go @@ -11,10 +11,11 @@ var FR_STRINGS = LangString{ "no-command": "Aucune commande fournie", "unknown-sdk": "Sdk inconnu \"%s\"", "sdk-model-usage": "Sélectionner un modèle", + "inerte-usage": "N'effectue pas d'appel à l'API", "text-usage": "Générer du texte à partir d'un prompt", "sdk-usage": "Sélectionner un sdk", "text-temp-usage": "Définir la température", - "text-system-usage": "Instruction avec rôle système (utilisez \"-\" pour stdin)", + "text-system-usage": "Instruction à entrer comme context (utilisez \"-\" pour stdin)", "text-history-usage": "Sélectionner un historique", "text-clear-usage": "Effacer l'historique", "text-file-usage": "Fichier texte à utiliser", diff --git a/main.go b/main.go index 85ba26c..436b39b 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "github.com/LordPax/aicli/commands" "github.com/LordPax/aicli/config" "github.com/LordPax/aicli/lang" - "github.com/LordPax/aicli/sdk" "github.com/LordPax/aicli/utils" cli "github.com/urfave/cli/v2" @@ -38,20 +37,28 @@ func main() { l.AddStrings(&lang.EN_STRINGS, "en_US.UTF-8", "en_GB.UTF-8", "en") l.AddStrings(&lang.FR_STRINGS, "fr_FR.UTF-8", "fr_CA.UTF-8", "fr") - if err := sdk.InitSdk(); err != nil { - log.PrintfErr("%v\n", err) - os.Exit(1) - } - app := cli.NewApp() app.Name = config.NAME app.Usage = l.Get("usage") app.Version = config.VERSION app.Action = commands.MainAction app.Flags = commands.MainFlags() + + textCmd, err := commands.TextCommand() + if err != nil { + log.PrintfErr("%v\n", err) + os.Exit(1) + } + + translateCmd, err := commands.TranslateCommand() + if err != nil { + log.PrintfErr("%v\n", err) + os.Exit(1) + } + app.Commands = []*cli.Command{ - commands.TextCommand(), - commands.TranslateCommand(), + textCmd, + translateCmd, // TODO : add command for image, audio and translate } diff --git a/sdk/claude.go b/sdk/claude.go index 623f3de..c167388 100644 --- a/sdk/claude.go +++ b/sdk/claude.go @@ -42,6 +42,7 @@ func NewClaudeText(apiKey, model string, temp float64) (*ClaudeText, error) { Name: "claude", ApiUrl: "https://api.anthropic.com/v1/messages", ApiKey: apiKey, + Inerte: false, }, SdkText: SdkText{ Model: "claude-3-5-sonnet-20240620", @@ -68,14 +69,21 @@ func NewClaudeText(apiKey, model string, temp float64) (*ClaudeText, error) { func (c *ClaudeText) SendRequest(text string) (Message, error) { var textResponse ClaudeResponse - c.AppendHistory("user", text) + if text != "" { + c.AppendHistory("user", text) + } - test := c.GetHistory() + if c.GetInerte() { + if err := c.SaveHistory(); err != nil { + return Message{}, err + } + return Message{}, nil + } jsonBody, err := json.Marshal(ClaudeBody{ Model: c.Model, MaxTokens: 1024, - Messages: test, + Messages: c.GetHistory(), }) if err != nil { return Message{}, err diff --git a/sdk/deepl.go b/sdk/deepl.go index 595e912..ebdae2c 100644 --- a/sdk/deepl.go +++ b/sdk/deepl.go @@ -37,6 +37,7 @@ func NewDeepL(apiKey string) (*DeepL, error) { Name: "deepl", ApiUrl: "https://api-free.deepl.com/v2/translate", ApiKey: apiKey, + Inerte: false, }, SdkTranslate: SdkTranslate{ SourceLang: "", diff --git a/sdk/mistral.go b/sdk/mistral.go index 5f272b5..7ff38d3 100644 --- a/sdk/mistral.go +++ b/sdk/mistral.go @@ -30,6 +30,7 @@ func NewMistralText(apiKey, model string, temp float64) (*MistralText, error) { Name: "mistral", ApiUrl: "https://api.mistral.ai/v1/chat/completions", ApiKey: apiKey, + Inerte: false, }, SdkText: SdkText{ Model: "mistral-medium", @@ -56,7 +57,16 @@ func NewMistralText(apiKey, model string, temp float64) (*MistralText, error) { func (m *MistralText) SendRequest(text string) (Message, error) { var textResponse OpenaiResponse - m.AppendHistory("user", text) + if text != "" { + m.AppendHistory("user", text) + } + + if m.GetInerte() { + if err := m.SaveHistory(); err != nil { + return Message{}, err + } + return Message{}, nil + } jsonBody, err := json.Marshal(OpenaiBody{ Model: m.Model, diff --git a/sdk/openai.go b/sdk/openai.go index 7f61f6c..e2b0ddf 100644 --- a/sdk/openai.go +++ b/sdk/openai.go @@ -51,6 +51,7 @@ func NewOpenaiText(apiKey, model string, temp float64) (*OpenaiText, error) { Name: "openai", ApiUrl: "https://api.openai.com/v1/chat/completions", ApiKey: apiKey, + Inerte: false, }, SdkText: SdkText{ Model: "gpt-4", @@ -77,7 +78,16 @@ func NewOpenaiText(apiKey, model string, temp float64) (*OpenaiText, error) { func (o *OpenaiText) SendRequest(text string) (Message, error) { var textResponse OpenaiResponse - o.AppendHistory("user", text) + if text != "" { + o.AppendHistory("user", text) + } + + if o.GetInerte() { + if err := o.SaveHistory(); err != nil { + return Message{}, err + } + return Message{}, nil + } jsonBody, err := json.Marshal(OpenaiBody{ Model: o.Model, diff --git a/sdk/sdk.go b/sdk/sdk.go index e799583..4d840cb 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -2,28 +2,25 @@ package sdk type ISdk interface { GetName() string + SetInerte(bool) + GetInerte() bool } type Sdk struct { Name string ApiUrl string ApiKey string + Inerte bool } -func (s *Sdk) GetName() string { - return s.Name +func (s *Sdk) SetInerte(inerte bool) { + s.Inerte = inerte } -func InitSdk() error { - if err := InitSdkText(""); err != nil { - return err - } - - if err := InitSdkTranslate(""); err != nil { - return err - } - - // TODO : init sdk for image, audio +func (s *Sdk) GetInerte() bool { + return s.Inerte +} - return nil +func (s *Sdk) GetName() string { + return s.Name } diff --git a/sdk/text.go b/sdk/text.go index 39c6bcb..6d75631 100644 --- a/sdk/text.go +++ b/sdk/text.go @@ -25,6 +25,10 @@ type Message struct { Content []Content `json:"content"` } +func (m *Message) IsEmpty() bool { + return len(m.Content) == 0 +} + func (m *Message) GetContent() string { var text string diff --git a/service/text.go b/service/text.go index af0f7b6..7353e87 100644 --- a/service/text.go +++ b/service/text.go @@ -27,6 +27,10 @@ func SendTextRequest(prompt string) error { return err } + if resp.IsEmpty() { + return nil + } + fmt.Println(resp.GetContent()) return nil @@ -51,6 +55,10 @@ func InteractiveMode() error { return err } + if resp.IsEmpty() { + continue + } + fmt.Print("\n") fmt.Println(utils.Red + resp.Role + ">" + utils.Reset) fmt.Println(resp.GetContent())