diff --git a/CHANGELOG.md b/CHANGELOG.md index 515f3ce..b762e3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [Unreleased] + +### Changed + +* Parameter -f in text command accept image file (only for claude sdk) + ## [0.4.0] - 2024-09-28 ### Changed diff --git a/commands/text.go b/commands/text.go index 5b52c5b..909a726 100644 --- a/commands/text.go +++ b/commands/text.go @@ -9,6 +9,7 @@ import ( "github.com/LordPax/aicli/lang" "github.com/LordPax/aicli/sdk" "github.com/LordPax/aicli/service" + "github.com/LordPax/aicli/utils" cli "github.com/urfave/cli/v2" ) @@ -125,7 +126,6 @@ func textFlags() []cli.Flag { Category: "text", Action: func(c *cli.Context, files []string) error { text := sdk.GetSdkText() - var fileContent []string for _, file := range files { f, err := os.ReadFile(file) @@ -137,10 +137,15 @@ func textFlags() []cli.Flag { return fmt.Errorf(l.Get("empty-file"), file) } - fileContent = append(fileContent, string(f)) - } + if fileType := utils.IsFileType(f, utils.IMAGE); fileType != "" { + if err := text.AppendImageHistory("system", "image/"+fileType, f); err != nil { + return err + } + continue + } - text.AppendHistory("system", fileContent...) + text.AppendHistory("system", string(f)) + } return nil }, diff --git a/go.mod b/go.mod index 3d68499..b3720ec 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/LordPax/aicli go 1.23.1 require ( + github.com/h2non/filetype v1.1.3 github.com/urfave/cli/v2 v2.27.3 golang.org/x/term v0.24.0 gopkg.in/ini.v1 v1.67.0 diff --git a/go.sum b/go.sum index 6132115..2c67a84 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lV github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= +github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= diff --git a/lang/en.go b/lang/en.go index e449bdd..7329211 100644 --- a/lang/en.go +++ b/lang/en.go @@ -9,6 +9,7 @@ var EN_STRINGS = LangString{ "silent": "Disable printing log to stdout", "no-args": "No arguments provided", "no-command": "No command provided", + "not-implemented": "Feature \"%s\" is not implemented", "unknown-sdk": "Unknown sdk \"%s\"", "sdk-model-usage": "Select a model", "inerte-usage": "Do not make API call", @@ -18,7 +19,7 @@ var EN_STRINGS = LangString{ "text-system-usage": "Instruction to enter as context (use \"-\" for stdin)", "text-history-usage": "Select a history", "text-clear-usage": "Clear history", - "text-file-usage": "Text file to use", + "text-file-usage": "Text or image file to use", "text-input": "(\"exit\" to quit) " + utils.Blue + "user> " + utils.Reset, "translate-input": "(\"exit\" to quit) " + utils.Blue + "> " + utils.Reset, "text-list-history-usage": "List history", diff --git a/lang/fr.go b/lang/fr.go index b095a04..8c9f138 100644 --- a/lang/fr.go +++ b/lang/fr.go @@ -9,6 +9,7 @@ var FR_STRINGS = LangString{ "silent": "Désactiver l'impression du journal sur stdout", "no-args": "Aucun argument fourni", "no-command": "Aucune commande fournie", + "not-implemented": "La fonctionnalité \"%s\" n'est pas implémentée", "unknown-sdk": "Sdk inconnu \"%s\"", "sdk-model-usage": "Sélectionner un modèle", "inerte-usage": "N'effectue pas d'appel à l'API", @@ -18,7 +19,7 @@ var FR_STRINGS = LangString{ "text-system-usage": "Instruction à entrer comme context (utilisez \"-\" pour stdin)", "text-history-usage": "Sélectionner un historique", "text-clear-usage": "Effacer l'historique", - "text-file-usage": "Fichier texte à utiliser", + "text-file-usage": "Fichier texte ou image à utiliser", "text-input": "(\"exit\" pour quitter) " + utils.Blue + "user> " + utils.Reset, "translate-input": "(\"exit\" pour quitter) " + utils.Blue + "> " + utils.Reset, "text-list-history-usage": "Lister l'historique", diff --git a/sdk/claude.go b/sdk/claude.go index c167388..7e11ad8 100644 --- a/sdk/claude.go +++ b/sdk/claude.go @@ -1,6 +1,7 @@ package sdk import ( + "encoding/base64" "encoding/json" "errors" "io" @@ -10,8 +11,8 @@ import ( ) type ClaudeResponse struct { - Role string `json:"role"` - Content []Content `json:"content"` + Role string `json:"role"` + Content []ContentText `json:"content"` } func (c *ClaudeResponse) GetContent() string { @@ -137,7 +138,7 @@ func (c *ClaudeText) AppendHistory(role string, text ...string) Message { // If the last message is from the same role, append the new text to the last message if lastMessage != nil && lastMessage.Role == role { - return c.AppendMessage(idLastMsg, text...) + return c.AppendTextMessage(idLastMsg, text...) } message := Message{ @@ -148,3 +149,21 @@ func (c *ClaudeText) AppendHistory(role string, text ...string) Message { return message } + +func (c *ClaudeText) AppendImageHistory(role, fileType string, file []byte) error { + name := c.SelectedHistory + + if role == "system" { + role = "user" + } + + str := base64.StdEncoding.EncodeToString(file) + + message := Message{ + Role: role, + Content: []IContent{NewContentImage(str, fileType)}, + } + c.History[name] = append(c.History[name], message) + + return nil +} diff --git a/sdk/history.go b/sdk/history.go index 4896c94..6b9f79c 100644 --- a/sdk/history.go +++ b/sdk/history.go @@ -1,10 +1,10 @@ package sdk import ( + "encoding/base64" "encoding/json" "os" "path" - "strings" "github.com/LordPax/aicli/config" "github.com/LordPax/aicli/utils" @@ -19,7 +19,8 @@ type ITextHistory interface { GetSelectedHistory() string ClearHistory() GetMessage(index int) *Message - AppendMessage(index int, text ...string) Message + AppendTextMessage(index int, text ...string) Message + AppendImageHistory(role, fileType string, file []byte) error GetHistoryNames() []string } @@ -72,6 +73,20 @@ func (t *TextHistory) AppendHistory(role string, text ...string) Message { return message } +func (t *TextHistory) AppendImageHistory(role, fileType string, file []byte) error { + name := t.SelectedHistory + + str := base64.StdEncoding.EncodeToString(file) + + message := Message{ + Role: role, + Content: []IContent{NewContentImage(str, fileType)}, + } + t.History[name] = append(t.History[name], message) + + return nil +} + func (t *TextHistory) SaveHistory() error { f, err := os.Create(t.HistoryFile) if err != nil { @@ -101,10 +116,21 @@ func (t *TextHistory) LoadHistory() error { return nil } - if err := json.Unmarshal(f, &t.History); err != nil { + var tempHistory map[string][]json.RawMessage + if err := json.Unmarshal(f, &tempHistory); err != nil { return err } + t.History = make(map[string][]Message) + for key, rawMessages := range tempHistory { + t.History[key] = make([]Message, len(rawMessages)) + for i, rawMessage := range rawMessages { + if err := json.Unmarshal(rawMessage, &t.History[key][i]); err != nil { + return err + } + } + } + return nil } @@ -126,7 +152,7 @@ func (t *TextHistory) GetMessage(index int) *Message { return &t.History[name][index] } -func (t *TextHistory) AppendMessage(index int, text ...string) Message { +func (t *TextHistory) AppendTextMessage(index int, text ...string) Message { name := t.SelectedHistory message := t.GetMessage(index) @@ -148,14 +174,11 @@ func (t *TextHistory) GetHistoryNames() []string { return names } -func textContent(text ...string) []Content { - var content []Content +func textContent(text ...string) []IContent { + var content []IContent for _, t := range text { - content = append(content, Content{ - Type: "text", - Text: strings.TrimSpace(t), - }) + content = append(content, NewContentText(t)) } return content diff --git a/sdk/mistral.go b/sdk/mistral.go index 7ff38d3..b8c42d1 100644 --- a/sdk/mistral.go +++ b/sdk/mistral.go @@ -3,9 +3,11 @@ package sdk import ( "encoding/json" "errors" + "fmt" "io" "net/http" + "github.com/LordPax/aicli/lang" "github.com/LordPax/aicli/utils" ) @@ -112,3 +114,8 @@ func (m *MistralText) SendRequest(text string) (Message, error) { return respMessage, nil } + +func (t *MistralText) AppendImageHistory(role, fileType string, file []byte) error { + l := lang.GetLocalize() + return fmt.Errorf(l.Get("not-implemented"), "AppendImageHistory") +} diff --git a/sdk/openai.go b/sdk/openai.go index e2b0ddf..1f2e485 100644 --- a/sdk/openai.go +++ b/sdk/openai.go @@ -3,9 +3,11 @@ package sdk import ( "encoding/json" "errors" + "fmt" "io" "net/http" + "github.com/LordPax/aicli/lang" "github.com/LordPax/aicli/utils" ) @@ -133,3 +135,8 @@ func (o *OpenaiText) SendRequest(text string) (Message, error) { return respMessage, nil } + +func (t *OpenaiText) AppendImageHistory(role, fileType string, file []byte) error { + l := lang.GetLocalize() + return fmt.Errorf(l.Get("not-implemented"), "AppendImageHistory") +} diff --git a/sdk/text.go b/sdk/text.go index 6d75631..2dcb0b1 100644 --- a/sdk/text.go +++ b/sdk/text.go @@ -1,6 +1,7 @@ package sdk import ( + "encoding/json" "errors" "fmt" @@ -10,19 +11,113 @@ import ( var sdkTextInstance ITextService -type Content struct { +type IContent interface { + GetValue() string +} + +type ContentText struct { Type string `json:"type"` Text string `json:"text"` - // Source struct { - // Type string `json:"type"` - // MediaType string `json:"media_type"` - // Data string `json:"data"` - // } `json:"source"` +} + +func NewContentText(text string) *ContentText { + return &ContentText{ + Type: "text", + Text: text, + } +} + +func (c *ContentText) GetValue() string { + return c.Text +} + +type ContentImage struct { + Type string `json:"type"` + Source struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` + } `json:"source"` +} + +func NewContentImage(data, fileType string) *ContentImage { + return &ContentImage{ + Type: "image", + Source: struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` + }{ + Type: "base64", + MediaType: fileType, + Data: data, + }, + } +} + +func (c *ContentImage) GetValue() string { + return "image: " + c.Source.MediaType } type Message struct { - Role string `json:"role"` - Content []Content `json:"content"` + Role string `json:"role"` + Content []IContent `json:"content"` +} + +type AuxMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` +} + +func (m *Message) UnmarshalJSON(data []byte) error { + var aux AuxMessage + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + m.Role = aux.Role + m.Content = []IContent{} + + var contentArray []json.RawMessage + if err := json.Unmarshal(aux.Content, &contentArray); err != nil { + return err + } + + for _, item := range contentArray { + content, err := unmarshalContent(item) + if err != nil { + return err + } + m.Content = append(m.Content, content) + } + + return nil +} + +func unmarshalContent(data []byte) (IContent, error) { + var temp map[string]interface{} + if err := json.Unmarshal(data, &temp); err != nil { + return nil, err + } + + if _, ok := temp["text"]; ok { + var ct ContentText + if err := json.Unmarshal(data, &ct); err != nil { + return nil, err + } + return &ct, nil + } + + if _, ok := temp["source"]; ok { + var ci ContentImage + if err := json.Unmarshal(data, &ci); err != nil { + return nil, err + } + return &ci, nil + } + + return nil, fmt.Errorf("unknown content type") } func (m *Message) IsEmpty() bool { @@ -33,17 +128,14 @@ func (m *Message) GetContent() string { var text string if len(m.Content) == 1 { - return m.Content[0].Text + return m.Content[0].GetValue() } for i, c := range m.Content { - if c.Type != "text" { - continue - } if i != 0 { text += "\n---\n" } - text += c.Text + text += c.GetValue() } return text diff --git a/utils/utils.go b/utils/utils.go index 35c9979..bbb6219 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -6,6 +6,7 @@ import ( "os" "strings" + "github.com/h2non/filetype" "golang.org/x/term" ) @@ -17,6 +18,10 @@ const ( Blue = Escape + "[34m" ) +var ( + IMAGE = []string{"jpg", "jpeg", "png", "gif", "webp"} +) + func FileExist(file string) bool { _, err := os.Stat(file) return !os.IsNotExist(err) @@ -66,3 +71,12 @@ func Input(prompt string, defaultVal string, nullable bool) string { return line } + +func IsFileType(buf []byte, fileType []string) string { + for _, t := range fileType { + if filetype.Is(buf, t) { + return t + } + } + return "" +}