Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Mar 12, 2024
1 parent a29345a commit 654d168
Showing 1 changed file with 90 additions and 7 deletions.
97 changes: 90 additions & 7 deletions tests/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"io"
"os"
"strings"
"testing"

Expand All @@ -26,6 +25,10 @@ func strPointer(s string) *string {
return &s
}

func boolPointer(s bool) *bool {
return &s
}

func TestNewClient(t *testing.T) {
client := client.NewClient(client.WithToken(os.Getenv("COHERE_API_KEY")))

Check failure on line 33 in tests/sdk_test.go

View workflow job for this annotation

GitHub Actions / test

undefined: os

Check failure on line 33 in tests/sdk_test.go

View workflow job for this annotation

GitHub Actions / build

undefined: os

Expand Down Expand Up @@ -186,10 +189,10 @@ func TestNewClient(t *testing.T) {
&cohere.RerankRequest{
Query: "What is the capital of the United States?",
Documents: []*cohere.RerankRequestDocumentsItem{
cohere.NewRerankRequestDocumentsItemFromString("Carson City is the capital city of the American state of Nevada."),
cohere.NewRerankRequestDocumentsItemFromString("The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."),
cohere.NewRerankRequestDocumentsItemFromString("Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."),
cohere.NewRerankRequestDocumentsItemFromString("Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."),
{String: "Carson City is the capital city of the American state of Nevada."},
{String: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."},
{String: "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."},
{String: "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."},
},
})

Expand Down Expand Up @@ -218,8 +221,8 @@ func TestNewClient(t *testing.T) {
&MyReader{Reader: strings.NewReader(`{"text": "The quick brown fox jumps over the lazy dog"}`), name: "test.jsonl"},
&MyReader{Reader: strings.NewReader(""), name: "a.jsonl"},
&cohere.DatasetsCreateRequest{
Name: strPointer("prompt-completion-dataset"),
Type: cohere.DatasetTypeEmbedResult.Ptr(),
Name: "prompt-completion-dataset",
Type: cohere.DatasetTypeEmbedResult,
})

require.NoError(t, err)
Expand Down Expand Up @@ -345,4 +348,84 @@ func TestNewClient(t *testing.T) {
require.NoError(t, err)
print(delete)
})

t.Run("TestTool", func(t *testing.T) {
tools := []*cohere.Tool{
{
Name: "sales_database",
Description: "Connects to a database about sales volumes",
ParameterDefinitions: map[string]*cohere.ToolParameterDefinitionsValue{
"day": {
Description: "Retrieves sales data from this day, formatted as YYYY-MM-DD.",
Type: "str",
Required: boolPointer(true),
},
},
},
}

toolsResponse, err := client.Chat(
context.TODO(),
&cohere.ChatRequest{
Message: "How good were the sales on September 29?",
Tools: tools,
Preamble: strPointer(`
## Task Description
You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.
## Style Guide
Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.
`),
})

require.NoError(t, err)
require.NotNil(t, toolsResponse.ToolCalls)
require.Len(t, toolsResponse.ToolCalls, 1)
require.Equal(t, toolsResponse.ToolCalls[0].Name, "sales_database")
require.Equal(t, toolsResponse.ToolCalls[0].Parameters["day"], "2023-09-29")

print(toolsResponse)

localTools := map[string]func(string) *[]map[string]interface{}{
"sales_database": func(day string) *[]map[string]interface{} {
return &[]map[string]interface{}{
{
"numberOfSales": 120,
"totalRevenue": 48500,
"averageSaleValue": 404.17,
"date": "2023-09-29",
},
}
},
}

toolResults := make([]*cohere.ChatRequestToolResultsItem, 0)

for _, toolCall := range toolsResponse.ToolCalls {
result := localTools[toolCall.Name](toolCall.Parameters["day"].(string))
toolResult := &cohere.ChatRequestToolResultsItem{
Call: toolCall,
Outputs: *result,
}
toolResults = append(toolResults, toolResult)
}

citedResponse, err := client.Chat(
context.TODO(),
&cohere.ChatRequest{
Message: "How good were the sales on September 29?",
Tools: tools,
ToolResults: toolResults,
Model: strPointer("command-nightly"),
})

require.NoError(t, err)

require.Equal(t, citedResponse.Documents[0]["averageSaleValue"], "404.17")
require.Equal(t, citedResponse.Documents[0]["date"], "2023-09-29")
require.Equal(t, citedResponse.Documents[0]["numberOfSales"], "120")
require.Equal(t, citedResponse.Documents[0]["totalRevenue"], "48500")

})

}

0 comments on commit 654d168

Please sign in to comment.