Skip to content

Commit

Permalink
Add Codestral support
Browse files Browse the repository at this point in the history
  • Loading branch information
sam-ulrich1 committed Jun 5, 2024
1 parent 49a9d45 commit ebe86b3
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 7 deletions.
33 changes: 29 additions & 4 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@ func TestChat(t *testing.T) {
assert.Equal(t, res.Choices[0].Message.Content, "Test Succeeded")
}

func TestChatCodestral(t *testing.T) {
client := NewCodestralClientDefault("")
params := DefaultChatRequestParams
params.MaxTokens = 10
params.Temperature = 0
res, err := client.Chat(
ModelCodestralLatest,
[]ChatMessage{
{
Role: RoleUser,
Content: "You are in test mode and must reply to this with exactly and only `Test Succeeded`",
},
},
&params,
)
assert.NoError(t, err)
assert.NotNil(t, res)

assert.Greater(t, len(res.Choices), 0)
assert.Greater(t, len(res.Choices[0].Message.Content), 0)
assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant)
assert.Equal(t, res.Choices[0].Message.Content, "Test Succeeded")
}

func TestChatFunctionCall(t *testing.T) {
client := NewMistralClientDefault("")
params := DefaultChatRequestParams
Expand Down Expand Up @@ -135,6 +159,7 @@ func TestChatFunctionCall2(t *testing.T) {
Role: RoleAssistant,
ToolCalls: []ToolCall{
{
Id: "aaaaaaaaa",
Type: ToolTypeFunction,
Function: FunctionCall{
Name: "get_weather",
Expand Down Expand Up @@ -166,7 +191,7 @@ func TestChatJsonMode(t *testing.T) {
params.Temperature = 0
params.ResponseFormat = ResponseFormatJsonObject
res, err := client.Chat(
ModelMistralSmallLatest,
ModelOpenMixtral8x22b,
[]ChatMessage{
{
Role: RoleUser,
Expand All @@ -186,7 +211,7 @@ func TestChatJsonMode(t *testing.T) {
assert.Greater(t, len(res.Choices), 0)
assert.Greater(t, len(res.Choices[0].Message.Content), 0)
assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant)
assert.Equal(t, res.Choices[0].Message.Content, "{\"symbols\": [\"Go\", \"ChatMessage\", \"FunctionCall\", \"ToolCall\"]}")
assert.Equal(t, res.Choices[0].Message.Content, "{\"symbols\": [\"Go\", \"ChatMessage\", \"Any\", \"FunctionCall\", \"ToolCall\", \"ToolResponse\"]}")
}

func TestChatStream(t *testing.T) {
Expand Down Expand Up @@ -309,7 +334,7 @@ func TestChatStreamJsonMode(t *testing.T) {
params.Temperature = 0
params.ResponseFormat = ResponseFormatJsonObject
resChan, err := client.ChatStream(
ModelMistralSmallLatest,
ModelOpenMixtral8x22b,
[]ChatMessage{
{
Role: RoleUser,
Expand Down Expand Up @@ -347,6 +372,6 @@ func TestChatStreamJsonMode(t *testing.T) {
}
}

assert.Equal(t, totalOutput, "{\"symbols\": [\"Go\", \"ChatMessage\", \"FunctionCall\", \"ToolCall\"]}")
assert.Equal(t, totalOutput, "{\"symbols\": [\"Go\", \"ChatMessage\", \"Any\", \"FunctionCall\", \"ToolCall\", \"ToolResponse\"]}")
assert.Nil(t, functionCall)
}
17 changes: 14 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/url"
"os"
Expand All @@ -13,6 +13,7 @@ import (

const (
Endpoint = "https://api.mistral.ai"
CodestralEndpoint = "https://codestral.mistral.ai"
DefaultMaxRetries = 5
DefaultTimeout = 120 * time.Second
)
Expand Down Expand Up @@ -54,6 +55,7 @@ func NewMistralClient(apiKey string, endpoint string, maxRetries int, timeout ti
}
}

// NewMistralClientDefault creates a new Mistral API client with the default endpoint and the given API key. Defaults to using MISTRAL_API_KEY from the environment.
func NewMistralClientDefault(apiKey string) *MistralClient {
if apiKey == "" {
apiKey = os.Getenv("MISTRAL_API_KEY")
Expand All @@ -62,6 +64,15 @@ func NewMistralClientDefault(apiKey string) *MistralClient {
return NewMistralClient(apiKey, Endpoint, DefaultMaxRetries, DefaultTimeout)
}

// NewCodestralClientDefault creates a new Codestral API client with the default endpoint and the given API key. Defaults to using CODESTRAL_API_KEY from the environment.
func NewCodestralClientDefault(apiKey string) *MistralClient {
if apiKey == "" {
apiKey = os.Getenv("CODESTRAL_API_KEY")
}

return NewMistralClient(apiKey, CodestralEndpoint, DefaultMaxRetries, DefaultTimeout)
}

func (c *MistralClient) request(method string, jsonData map[string]interface{}, path string, stream bool, params map[string]string) (interface{}, error) {
uri, err := url.Parse(c.endpoint)
if err != nil {
Expand Down Expand Up @@ -98,7 +109,7 @@ func (c *MistralClient) request(method string, jsonData map[string]interface{},
}

if resp.StatusCode >= 400 {
responseBytes, _ := ioutil.ReadAll(resp.Body)
responseBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("(HTTP Error %d) %s", resp.StatusCode, string(responseBytes))
}

Expand All @@ -107,7 +118,7 @@ func (c *MistralClient) request(method string, jsonData map[string]interface{},
}

defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
Expand Down
66 changes: 66 additions & 0 deletions fim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package mistral

import (
"fmt"
"net/http"
)

// FIMRequestParams represents the parameters for the FIM method of MistralClient.
type FIMRequestParams struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Suffix string `json:"suffix"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
Stop []string `json:"stop,omitempty"`
}

// FIMCompletionResponse represents the response from the FIM completion endpoint.
type FIMCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []FIMCompletionResponseChoice `json:"choices"`
Usage UsageInfo `json:"usage"`
}

// FIMCompletionResponseChoice represents a choice in the FIM completion response.
type FIMCompletionResponseChoice struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason FinishReason `json:"finish_reason,omitempty"`
}

// FIM sends a FIM request and returns the completion response.
func (c *MistralClient) FIM(params *FIMRequestParams) (*FIMCompletionResponse, error) {
requestData := map[string]interface{}{
"model": params.Model,
"prompt": params.Prompt,
"suffix": params.Suffix,
"max_tokens": params.MaxTokens,
"temperature": params.Temperature,
}

if params.Stop != nil {
requestData["stop"] = params.Stop
}

response, err := c.request(http.MethodPost, requestData, "v1/fim/completions", false, nil)
if err != nil {
return nil, err
}

respData, ok := response.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid response type: %T", response)
}

var fimResponse FIMCompletionResponse
err = mapToStruct(respData, &fimResponse)
if err != nil {
return nil, err
}

return &fimResponse, nil
}
59 changes: 59 additions & 0 deletions fim_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package mistral

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFIM(t *testing.T) {
client := NewMistralClientDefault("")
params := FIMRequestParams{
Model: ModelCodestralLatest,
Prompt: "def f(",
Suffix: "return a + b",
MaxTokens: 64,
Temperature: 0,
Stop: []string{"\n"},
}
res, err := client.FIM(&params)
assert.NoError(t, err)
assert.NotNil(t, res)

assert.Greater(t, len(res.Choices), 0)
assert.Equal(t, res.Choices[0].Message.Content, "a, b):")
assert.Equal(t, res.Choices[0].FinishReason, FinishReasonStop)
}

func TestFIMWithStop(t *testing.T) {
client := NewMistralClientDefault("")
params := FIMRequestParams{
Model: ModelCodestralLatest,
Prompt: "def is_odd(n): \n return n % 2 == 1 \n def test_is_odd():",
Suffix: "test_is_odd()",
MaxTokens: 64,
Temperature: 0,
Stop: []string{"False"},
}
res, err := client.FIM(&params)
assert.NoError(t, err)
assert.NotNil(t, res)

assert.Greater(t, len(res.Choices), 0)
assert.Equal(t, res.Choices[0].Message.Content, "\n assert is_odd(1) == True\n assert is_odd(2) == ")
assert.Equal(t, res.Choices[0].FinishReason, FinishReasonStop)
}

func TestFIMInvalidModel(t *testing.T) {
client := NewMistralClientDefault("")
params := FIMRequestParams{
Model: "invalid-model",
Prompt: "This is a test prompt",
Suffix: "This is a test suffix",
MaxTokens: 10,
Temperature: 0.5,
}
res, err := client.FIM(&params)
assert.Error(t, err)
assert.Nil(t, res)
}
3 changes: 3 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ const (
ModelMistralLargeLatest = "mistral-large-latest"
ModelMistralMediumLatest = "mistral-medium-latest"
ModelMistralSmallLatest = "mistral-small-latest"
ModelCodestralLatest = "codestral-latest"

ModelOpenMixtral8x7b = "open-mixtral-8x7b"
ModelOpenMixtral8x22b = "open-mixtral-8x22b"
ModelOpenMistral7b = "open-mistral-7b"

ModelMistralLarge2402 = "mistral-large-2402"
Expand Down

0 comments on commit ebe86b3

Please sign in to comment.