Skip to content

Commit

Permalink
feat(text) : sdk claude and openai
Browse files Browse the repository at this point in the history
  • Loading branch information
LordPax committed Sep 19, 2024
1 parent 3f23c1c commit ad23250
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 40 deletions.
12 changes: 10 additions & 2 deletions commands/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ func TextFlags() []cli.Flag {
Aliases: []string{"s"},
Usage: l.Get("text-system-usage"),
Action: func(c *cli.Context, values []string) error {
var content []string

for _, value := range values {
if value == "-" {
stdin, err := io.ReadAll(os.Stdin)
Expand All @@ -75,9 +77,11 @@ func TextFlags() []cli.Flag {
value = string(stdin)
}

textSdk.AppendHistory("system", value)
content = append(content, value)
}

textSdk.AppendHistory("system", content...)

return nil
},
},
Expand All @@ -86,6 +90,8 @@ func TextFlags() []cli.Flag {
Aliases: []string{"f"},
Usage: l.Get("text-file-usage"),
Action: func(c *cli.Context, files []string) error {
var fileContent []string

for _, file := range files {
f, err := os.ReadFile(file)
if err != nil {
Expand All @@ -96,9 +102,11 @@ func TextFlags() []cli.Flag {
return fmt.Errorf(l.Get("empty-file"), file)
}

textSdk.AppendHistory("system", string(f))
fileContent = append(fileContent, string(f))
}

textSdk.AppendHistory("system", fileContent...)

return nil
},
},
Expand Down
20 changes: 15 additions & 5 deletions sdk/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ type ClaudeResponse struct {
Content []Content `json:"content"`
}

func (c *ClaudeResponse) GetContent() string {
return c.Content[0].Text
}

type ClaudeBody struct {
Model string `json:"model"`
MaxTokens int64 `json:"max_tokens"`
Messages []Message `json:"messages"`
}

type ClaudeText struct {
Sdk
SdkText
Expand Down Expand Up @@ -54,10 +64,10 @@ func (c *ClaudeText) SendRequest(text string) (Message, error) {

c.AppendHistory("user", text)

jsonBody, err := json.Marshal(TextBody{
Model: c.Model,
// MaxTokens: 1024,
Messages: c.GetHistory(),
jsonBody, err := json.Marshal(ClaudeBody{
Model: c.Model,
MaxTokens: 1024,
Messages: c.GetHistory(),
})
if err != nil {
return Message{}, err
Expand Down Expand Up @@ -90,7 +100,7 @@ func (c *ClaudeText) SendRequest(text string) (Message, error) {
return Message{}, err
}

respMessage := c.AppendHistory(textResponse.Role, textResponse.Content[0].Text)
respMessage := c.AppendHistory(textResponse.Role, textResponse.GetContent())

if err := c.SaveHistory(); err != nil {
return Message{}, err
Expand Down
22 changes: 18 additions & 4 deletions sdk/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,22 @@ type OpenaiResponse struct {
}

type Choices []struct {
Index int64 `json:"index"`
Message Message `json:"message"`
Index int64 `json:"index"`
Message ChoicesMessage `json:"message"`
}

type ChoicesMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}

func (m *ChoicesMessage) GetContent() string {
return m.Content
}

type OpenaiBody struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
}

type OpenaiText struct {
Expand Down Expand Up @@ -58,7 +72,7 @@ func (o *OpenaiText) SendRequest(text string) (Message, error) {

o.AppendHistory("user", text)

jsonBody, err := json.Marshal(TextBody{
jsonBody, err := json.Marshal(OpenaiBody{
Model: o.Model,
Messages: o.GetHistory(),
})
Expand Down Expand Up @@ -93,7 +107,7 @@ func (o *OpenaiText) SendRequest(text string) (Message, error) {
}

msg := textResponse.Choices[0].Message
respMessage := o.AppendHistory(msg.Role, msg.Content)
respMessage := o.AppendHistory(msg.Role, msg.GetContent())

if err := o.SaveHistory(); err != nil {
return Message{}, err
Expand Down
46 changes: 23 additions & 23 deletions sdk/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,22 @@ import (
)

type Content struct {
Type string `json:"type"`
Text string `json:"text"`
Source struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
} `json:"source"`
Type string `json:"type"`
Text string `json:"text"`
// Source struct {
// Type string `json:"type"`
// MediaType string `json:"media_type"`
// Data string `json:"data"`
// } `json:"source"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
// Content []Content `json:"content"`
Role string `json:"role"`
Content []Content `json:"content"`
}

type TextBody struct {
Model string `json:"model"`
// MaxTokens int64 `json:"max_tokens"`
Messages []Message `json:"messages"`
func (m *Message) GetContent() string {
return m.Content[0].Text
}

type ErrorMsg struct {
Expand All @@ -43,7 +40,7 @@ type ITextService interface {
type ISdkText interface {
SetTemp(temp float64)
GetTemp() float64
AppendHistory(role, text string) Message
AppendHistory(role string, text ...string) Message
SaveHistory() error
LoadHistory() error
GetHistory() []Message
Expand Down Expand Up @@ -74,17 +71,20 @@ func (s *SdkText) GetTemp() float64 {
return s.Temp
}

func (s *SdkText) AppendHistory(role, text string) Message {
func (s *SdkText) AppendHistory(role string, text ...string) Message {
var content []Content

for _, t := range text {
content = append(content, Content{
Type: "text",
Text: t,
})
}

name := s.SelectedHistory
message := Message{
Role: role,
Content: text,
// Content: []Content{
// {
// Type: "text",
// Text: text,
// },
// },
Content: content,
}
s.History[name] = append(s.History[name], message)

Expand Down
9 changes: 3 additions & 6 deletions service/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ func SendTextRequest(prompt string) error {
return err
}

// fmt.Println(resp.Content[0].Text)
fmt.Println(resp.Content)
fmt.Println(resp.GetContent())

return nil
}
Expand All @@ -54,8 +53,7 @@ func InteractiveMode() error {

fmt.Print("\n")
fmt.Println(utils.Red + resp.Role + ">" + utils.Reset)
// fmt.Println(resp.Content[0].Text)
fmt.Println(resp.Content)
fmt.Println(resp.GetContent())
fmt.Print("\n")
}

Expand Down Expand Up @@ -93,8 +91,7 @@ func ListHistory(showSystem, showMsg bool) error {
fmt.Println(utils.Red + "assistant> " + utils.Reset)
}

// fmt.Println(message.Content[0].Text)
fmt.Println(message.Content)
fmt.Println(message.GetContent())
fmt.Print("\n")
}

Expand Down

0 comments on commit ad23250

Please sign in to comment.