Skip to content

Commit

Permalink
Merge pull request #27 from mutablelogic/v1
Browse files Browse the repository at this point in the history
Updated the mistral client
  • Loading branch information
djthorpe authored May 26, 2024
2 parents f96eb6f + 51f3c6b commit 6106611
Show file tree
Hide file tree
Showing 22 changed files with 655 additions and 346 deletions.
64 changes: 58 additions & 6 deletions cmd/api/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ import (
// GLOBALS

var (
anthropicName = "claude"
anthropicClient *anthropic.Client
anthropicName = "claude"
anthropicClient *anthropic.Client
anthropicModel string
anthropicTemperature *float64
anthropicMaxTokens *uint64
anthropicStream bool
)

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -46,6 +50,21 @@ func anthropicParse(flags *Flags, opts ...client.ClientOpt) error {
anthropicClient = client
}

// Get the command-line parameters
anthropicModel = flags.GetString("model")
if temp, err := flags.GetValue("temperature"); err == nil {
t := temp.(float64)
anthropicTemperature = &t
}
if maxtokens, err := flags.GetValue("max-tokens"); err == nil {
t := maxtokens.(uint64)
anthropicMaxTokens = &t
}
if stream, err := flags.GetValue("stream"); err == nil {
t := stream.(bool)
anthropicStream = t
}

// Return success
return nil
}
Expand All @@ -54,7 +73,36 @@ func anthropicParse(flags *Flags, opts ...client.ClientOpt) error {
// METHODS

func anthropicChat(ctx context.Context, w *tablewriter.Writer, args []string) error {
// Request -> Response

// Set options
opts := []anthropic.Opt{}
if anthropicModel != "" {
opts = append(opts, anthropic.OptModel(anthropicModel))
}
if anthropicTemperature != nil {
opts = append(opts, anthropic.OptTemperature(float32(*anthropicTemperature)))
}
if anthropicMaxTokens != nil {
opts = append(opts, anthropic.OptMaxTokens(int(*anthropicMaxTokens)))
}
if anthropicStream {
opts = append(opts, anthropic.OptStream(func(choice schema.MessageChoice) {
w := w.Output()
if choice.Delta != nil {
if choice.Delta.Role != "" {
fmt.Fprintf(w, "\n%v: ", choice.Delta.Role)
}
if choice.Delta.Content != "" {
fmt.Fprintf(w, "%v", choice.Delta.Content)
}
}
if choice.FinishReason != "" {
fmt.Printf("\nfinish_reason: %q\n", choice.FinishReason)
}
}))
}

// Append user message
message := schema.NewMessage("user")
for _, arg := range args {
message.Add(schema.Text(arg))
Expand All @@ -63,11 +111,15 @@ func anthropicChat(ctx context.Context, w *tablewriter.Writer, args []string) er
// Request -> Response
responses, err := anthropicClient.Messages(ctx, []*schema.Message{
message,
})
}, opts...)
if err != nil {
return err
}

// Write table
return w.Write(responses)
// Write table (if not streaming)
if !anthropicStream {
return w.Write(responses)
} else {
return nil
}
}
2 changes: 1 addition & 1 deletion cmd/api/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func mistralChat(ctx context.Context, w *tablewriter.Writer, args []string) erro
opts = append(opts, mistral.OptModel(mistralModel))
}
if mistralTemperature != nil {
opts = append(opts, mistral.OptTemperature(*mistralTemperature))
opts = append(opts, mistral.OptTemperature(float32(*mistralTemperature)))
}
if mistralMaxTokens != nil {
opts = append(opts, mistral.OptMaxTokens(int(*mistralMaxTokens)))
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/samantha.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func samChat(ctx context.Context, w *tablewriter.Writer, _ []string) error {
}
}

func samCall(_ context.Context, content schema.Content) *schema.Content {
func samCall(_ context.Context, content *schema.Content) *schema.Content {
anthropicClient.Debugf("%v: %v: %v", content.Type, content.Name, content.Input)
if content.Type != "tool_use" {
return schema.ToolResult(content.Id, fmt.Sprint("unexpected content type:", content.Type))
Expand Down
2 changes: 1 addition & 1 deletion pkg/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
endPoint = "https://api.anthropic.com/v1"
defaultVersion = "2023-06-01"
defaultMessageModel = "claude-3-haiku-20240307"
defaultMaxTokens = 4096
defaultMaxTokens = 1024
)

///////////////////////////////////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit 6106611

Please sign in to comment.