From ad23250af694022f26a7ee9153b6fec372869d42 Mon Sep 17 00:00:00 2001 From: teddy Date: Thu, 19 Sep 2024 03:10:37 +0200 Subject: [PATCH] feat(text) : sdk claude and openai --- commands/text.go | 12 ++++++++++-- sdk/claude.go | 20 +++++++++++++++----- sdk/openai.go | 22 ++++++++++++++++++---- sdk/text.go | 46 +++++++++++++++++++++++----------------------- service/text.go | 9 +++------ 5 files changed, 69 insertions(+), 40 deletions(-) diff --git a/commands/text.go b/commands/text.go index 7b410c4..b8050a8 100644 --- a/commands/text.go +++ b/commands/text.go @@ -65,6 +65,8 @@ func TextFlags() []cli.Flag { Aliases: []string{"s"}, Usage: l.Get("text-system-usage"), Action: func(c *cli.Context, values []string) error { + var content []string + for _, value := range values { if value == "-" { stdin, err := io.ReadAll(os.Stdin) @@ -75,9 +77,11 @@ func TextFlags() []cli.Flag { value = string(stdin) } - textSdk.AppendHistory("system", value) + content = append(content, value) } + textSdk.AppendHistory("system", content...) + return nil }, }, @@ -86,6 +90,8 @@ func TextFlags() []cli.Flag { Aliases: []string{"f"}, Usage: l.Get("text-file-usage"), Action: func(c *cli.Context, files []string) error { + var fileContent []string + for _, file := range files { f, err := os.ReadFile(file) if err != nil { @@ -96,9 +102,11 @@ func TextFlags() []cli.Flag { return fmt.Errorf(l.Get("empty-file"), file) } - textSdk.AppendHistory("system", string(f)) + fileContent = append(fileContent, string(f)) } + textSdk.AppendHistory("system", fileContent...) + return nil }, }, diff --git a/sdk/claude.go b/sdk/claude.go index 6a944da..0f97ce4 100644 --- a/sdk/claude.go +++ b/sdk/claude.go @@ -14,6 +14,16 @@ type ClaudeResponse struct { Content []Content `json:"content"` } +func (c *ClaudeResponse) GetContent() string { + return c.Content[0].Text +} + +type ClaudeBody struct { + Model string `json:"model"` + MaxTokens int64 `json:"max_tokens"` + Messages []Message `json:"messages"` +} + type ClaudeText struct { Sdk SdkText @@ -54,10 +64,10 @@ func (c *ClaudeText) SendRequest(text string) (Message, error) { c.AppendHistory("user", text) - jsonBody, err := json.Marshal(TextBody{ - Model: c.Model, - // MaxTokens: 1024, - Messages: c.GetHistory(), + jsonBody, err := json.Marshal(ClaudeBody{ + Model: c.Model, + MaxTokens: 1024, + Messages: c.GetHistory(), }) if err != nil { return Message{}, err @@ -90,7 +100,7 @@ func (c *ClaudeText) SendRequest(text string) (Message, error) { return Message{}, err } - respMessage := c.AppendHistory(textResponse.Role, textResponse.Content[0].Text) + respMessage := c.AppendHistory(textResponse.Role, textResponse.GetContent()) if err := c.SaveHistory(); err != nil { return Message{}, err diff --git a/sdk/openai.go b/sdk/openai.go index dd7dbf3..2920262 100644 --- a/sdk/openai.go +++ b/sdk/openai.go @@ -14,8 +14,22 @@ type OpenaiResponse struct { } type Choices []struct { - Index int64 `json:"index"` - Message Message `json:"message"` + Index int64 `json:"index"` + Message ChoicesMessage `json:"message"` +} + +type ChoicesMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +func (m *ChoicesMessage) GetContent() string { + return m.Content +} + +type OpenaiBody struct { + Model string `json:"model"` + Messages []Message `json:"messages"` } type OpenaiText struct { @@ -58,7 +72,7 @@ func (o *OpenaiText) SendRequest(text string) (Message, error) { o.AppendHistory("user", text) - jsonBody, err := json.Marshal(TextBody{ + jsonBody, err := json.Marshal(OpenaiBody{ Model: o.Model, Messages: o.GetHistory(), }) @@ -93,7 +107,7 @@ func (o *OpenaiText) SendRequest(text string) (Message, error) { } msg := textResponse.Choices[0].Message - respMessage := o.AppendHistory(msg.Role, msg.Content) + respMessage := o.AppendHistory(msg.Role, msg.GetContent()) if err := o.SaveHistory(); err != nil { return Message{}, err diff --git a/sdk/text.go b/sdk/text.go index f9f0e95..e8da4dd 100644 --- a/sdk/text.go +++ b/sdk/text.go @@ -8,25 +8,22 @@ import ( ) type Content struct { - Type string `json:"type"` - Text string `json:"text"` - Source struct { - Type string `json:"type"` - MediaType string `json:"media_type"` - Data string `json:"data"` - } `json:"source"` + Type string `json:"type"` + Text string `json:"text"` + // Source struct { + // Type string `json:"type"` + // MediaType string `json:"media_type"` + // Data string `json:"data"` + // } `json:"source"` } type Message struct { - Role string `json:"role"` - Content string `json:"content"` - // Content []Content `json:"content"` + Role string `json:"role"` + Content []Content `json:"content"` } -type TextBody struct { - Model string `json:"model"` - // MaxTokens int64 `json:"max_tokens"` - Messages []Message `json:"messages"` +func (m *Message) GetContent() string { + return m.Content[0].Text } type ErrorMsg struct { @@ -43,7 +40,7 @@ type ITextService interface { type ISdkText interface { SetTemp(temp float64) GetTemp() float64 - AppendHistory(role, text string) Message + AppendHistory(role string, text ...string) Message SaveHistory() error LoadHistory() error GetHistory() []Message @@ -74,17 +71,20 @@ func (s *SdkText) GetTemp() float64 { return s.Temp } -func (s *SdkText) AppendHistory(role, text string) Message { +func (s *SdkText) AppendHistory(role string, text ...string) Message { + var content []Content + + for _, t := range text { + content = append(content, Content{ + Type: "text", + Text: t, + }) + } + name := s.SelectedHistory message := Message{ Role: role, - Content: text, - // Content: []Content{ - // { - // Type: "text", - // Text: text, - // }, - // }, + Content: content, } s.History[name] = append(s.History[name], message) diff --git a/service/text.go b/service/text.go index 7b3c309..e89ddb0 100644 --- a/service/text.go +++ b/service/text.go @@ -27,8 +27,7 @@ func SendTextRequest(prompt string) error { return err } - // fmt.Println(resp.Content[0].Text) - fmt.Println(resp.Content) + fmt.Println(resp.GetContent()) return nil } @@ -54,8 +53,7 @@ func InteractiveMode() error { fmt.Print("\n") fmt.Println(utils.Red + resp.Role + ">" + utils.Reset) - // fmt.Println(resp.Content[0].Text) - fmt.Println(resp.Content) + fmt.Println(resp.GetContent()) fmt.Print("\n") } @@ -93,8 +91,7 @@ func ListHistory(showSystem, showMsg bool) error { fmt.Println(utils.Red + "assistant> " + utils.Reset) } - // fmt.Println(message.Content[0].Text) - fmt.Println(message.Content) + fmt.Println(message.GetContent()) fmt.Print("\n") }