llms: implement StreamingReasoningFunc for reasoning models #1125

Merged (3 commits) · Feb 13, 2025
@@ -32,8 +32,13 @@ func main() {
 	content,
 	llms.WithMaxTokens(2000),
 	llms.WithTemperature(0.7),
-	llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
-		fmt.Print(string(chunk))
+	llms.WithStreamingReasoningFunc(func(ctx context.Context, reasoningChunk []byte, chunk []byte) error {
+		if len(reasoningChunk) > 0 {
+			fmt.Printf("Streaming Reasoning: %s\n", string(reasoningChunk))
+		}
+		if len(chunk) > 0 {
+			fmt.Printf("Streaming Content: %s\n", string(chunk))
+		}
 		return nil
 	}),
 )
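For context, here is a minimal end-to-end sketch of the new option against an OpenAI-compatible reasoning endpoint. The base URL, prompt, and error handling are illustrative assumptions rather than part of this PR; only WithStreamingReasoningFunc and the deepseek-reasoner model name come from the change itself.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	// deepseek-reasoner emits reasoning_content deltas alongside regular
	// content deltas; the base URL here is an assumption for illustration.
	llm, err := openai.New(
		openai.WithModel("deepseek-reasoner"),
		openai.WithBaseURL("https://api.deepseek.com"),
	)
	if err != nil {
		log.Fatal(err)
	}

	content := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeHuman, "How many r's are in 'strawberry'?"),
	}

	_, err = llm.GenerateContent(context.Background(), content,
		llms.WithMaxTokens(2000),
		llms.WithStreamingReasoningFunc(func(ctx context.Context, reasoningChunk, chunk []byte) error {
			// Reasoning and content arrive in separate fields of each delta.
			if len(reasoningChunk) > 0 {
				fmt.Printf("Streaming Reasoning: %s\n", reasoningChunk)
			}
			if len(chunk) > 0 {
				fmt.Printf("Streaming Content: %s\n", chunk)
			}
			return nil
		}),
	)
	if err != nil {
		log.Fatal(err)
	}
}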
17 changes: 14 additions & 3 deletions llms/openai/internal/openaiclient/chat.go
@@ -67,6 +67,10 @@ type ChatRequest struct {
 	// Return an error to stop streaming early.
 	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`

+	// StreamingReasoningFunc is a function to be called for each reasoning and content chunk of a streaming response.
+	// Return an error to stop streaming early.
+	StreamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error `json:"-"`
+
 	// Deprecated: use Tools instead.
 	Functions []FunctionDefinition `json:"functions,omitempty"`
 	// Deprecated: use ToolChoice instead.
@@ -380,7 +384,7 @@ type FunctionCall struct {
 }

 func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatCompletionResponse, error) {
-	if payload.StreamingFunc != nil {
+	if payload.StreamingFunc != nil || payload.StreamingReasoningFunc != nil {
 		payload.Stream = true
 		if payload.StreamOptions == nil {
 			payload.StreamOptions = &StreamOptions{IncludeUsage: true}
@@ -421,7 +425,7 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatCompletionResponse, error) {

 		return nil, fmt.Errorf("%s: %s", msg, errResp.Error.Message) // nolint:goerr113
 	}
-	if payload.StreamingFunc != nil {
+	if payload.StreamingFunc != nil || payload.StreamingReasoningFunc != nil {
 		return parseStreamingChatResponse(ctx, r, payload)
 	}
 	// Parse response
@@ -493,9 +497,10 @@ func combineStreamingChatResponse(
 		}
 		choice := streamResponse.Choices[0]
 		chunk := []byte(choice.Delta.Content)
+		reasoningChunk := []byte(choice.Delta.ReasoningContent) // TODO: unclear whether reasoning will ever accompany a function call; pass it through here regardless
 		response.Choices[0].Message.Content += choice.Delta.Content
 		response.Choices[0].FinishReason = choice.FinishReason
-		response.Choices[0].Message.ReasoningContent = choice.Delta.ReasoningContent
+		response.Choices[0].Message.ReasoningContent += choice.Delta.ReasoningContent

 		if choice.Delta.FunctionCall != nil {
 			chunk = updateFunctionCall(response.Choices[0].Message, choice.Delta.FunctionCall)
@@ -512,6 +517,12 @@
 				return nil, fmt.Errorf("streaming func returned an error: %w", err)
 			}
 		}
+		if payload.StreamingReasoningFunc != nil {
+			err := payload.StreamingReasoningFunc(ctx, reasoningChunk, chunk)
+			if err != nil {
+				return nil, fmt.Errorf("streaming reasoning func returned an error: %w", err)
+			}
+		}
 	}
 	return &response, nil
 }
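The change above fans each SSE delta out to both callbacks independently: StreamingFunc still receives only content chunks, while StreamingReasoningFunc receives the reasoning and content chunks together, and a non-nil error from either aborts the stream. A standalone sketch of that dispatch pattern (the type and function names are illustrative, not the library's internals):

package main

import (
	"context"
	"fmt"
)

// callbacks mirrors the two optional streaming hooks on ChatRequest.
type callbacks struct {
	streamingFunc          func(ctx context.Context, chunk []byte) error
	streamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error
}

// dispatch sends one delta to whichever callbacks are set; an error from
// either one stops the stream, matching the wrapped errors above.
func dispatch(ctx context.Context, cb callbacks, reasoningChunk, chunk []byte) error {
	if cb.streamingFunc != nil {
		if err := cb.streamingFunc(ctx, chunk); err != nil {
			return fmt.Errorf("streaming func returned an error: %w", err)
		}
	}
	if cb.streamingReasoningFunc != nil {
		if err := cb.streamingReasoningFunc(ctx, reasoningChunk, chunk); err != nil {
			return fmt.Errorf("streaming reasoning func returned an error: %w", err)
		}
	}
	return nil
}

func main() {
	cb := callbacks{
		streamingReasoningFunc: func(_ context.Context, r, c []byte) error {
			fmt.Printf("reasoning=%q content=%q\n", r, c)
			return nil
		},
	}
	// Two deltas as the parser would see them: reasoning first, then content.
	_ = dispatch(context.Background(), cb, []byte("Okay"), nil)
	_ = dispatch(context.Background(), cb, nil, []byte("Three."))
}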
27 changes: 27 additions & 0 deletions llms/openai/internal/openaiclient/chat_test.go
@@ -56,6 +56,33 @@ func TestParseStreamingChatResponse_ReasoningContent(t *testing.T) {
 	assert.Equal(t, FinishReason("stop"), resp.Choices[0].FinishReason)
 }

+func TestParseStreamingChatResponse_ReasoningFunc(t *testing.T) {
+	t.Parallel()
+	mockBody := `
+data: {"id":"fa7e4fc5-a05d-4e7b-9a66-a2dd89e91a4e","object":"chat.completion.chunk","created":1738492867,"model":"deepseek-reasoner","system_fingerprint":"fp_7e73fd9a08","choices":[{"index":0,"delta":{"content":null,"reasoning_content":"Okay"},"logprobs":null,"finish_reason":null}]}
+`
+	r := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(bytes.NewBufferString(mockBody)),
+	}
+
+	req := &ChatRequest{
+		StreamingReasoningFunc: func(_ context.Context, reasoningChunk, chunk []byte) error {
+			t.Logf("reasoningChunk: %s", string(reasoningChunk))
+			t.Logf("chunk: %s", string(chunk))
+			return nil
+		},
+	}
+
+	resp, err := parseStreamingChatResponse(context.Background(), r, req)
+
+	require.NoError(t, err)
+	assert.NotNil(t, resp)
+	assert.Equal(t, "", resp.Choices[0].Message.Content)
+	assert.Equal(t, "Okay", resp.Choices[0].Message.ReasoningContent)
+	assert.Equal(t, FinishReason(""), resp.Choices[0].FinishReason)
+}
+
 func TestChatMessage_MarshalUnmarshal(t *testing.T) {
 	t.Parallel()
 	msg := ChatMessage{
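The mock above exercises only a single reasoning delta. A test covering a stream that interleaves reasoning and final content might look roughly like the following sketch; the second data line, its field values, and the strings.Builder accumulation are hypothetical extensions of the test above (it also assumes adding strings to the test file's imports and that the parser skips blank lines between SSE events):

func TestParseStreamingChatResponse_ReasoningThenContent(t *testing.T) {
	t.Parallel()
	// Hypothetical two-delta stream: a reasoning chunk, then a content
	// chunk that also carries the finish reason. IDs and values are made up.
	mockBody := `
data: {"id":"1","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"content":null,"reasoning_content":"Okay"},"finish_reason":null}]}

data: {"id":"1","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"content":"Three.","reasoning_content":null},"finish_reason":"stop"}]}
`
	r := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(mockBody)),
	}

	var gotReasoning, gotContent strings.Builder
	req := &ChatRequest{
		StreamingReasoningFunc: func(_ context.Context, reasoningChunk, chunk []byte) error {
			gotReasoning.Write(reasoningChunk)
			gotContent.Write(chunk)
			return nil
		},
	}

	resp, err := parseStreamingChatResponse(context.Background(), r, req)

	require.NoError(t, err)
	assert.Equal(t, "Okay", gotReasoning.String())
	assert.Equal(t, "Three.", gotContent.String())
	assert.Equal(t, "Three.", resp.Choices[0].Message.Content)
	assert.Equal(t, FinishReason("stop"), resp.Choices[0].FinishReason)
}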
17 changes: 9 additions & 8 deletions llms/openai/openaillm.go
@@ -96,14 +96,15 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent,
 		chatMsgs = append(chatMsgs, msg)
 	}
 	req := &openaiclient.ChatRequest{
-		Model:            opts.Model,
-		StopWords:        opts.StopWords,
-		Messages:         chatMsgs,
-		StreamingFunc:    opts.StreamingFunc,
-		Temperature:      opts.Temperature,
-		N:                opts.N,
-		FrequencyPenalty: opts.FrequencyPenalty,
-		PresencePenalty:  opts.PresencePenalty,
+		Model:                  opts.Model,
+		StopWords:              opts.StopWords,
+		Messages:               chatMsgs,
+		StreamingFunc:          opts.StreamingFunc,
+		StreamingReasoningFunc: opts.StreamingReasoningFunc,
+		Temperature:            opts.Temperature,
+		N:                      opts.N,
+		FrequencyPenalty:       opts.FrequencyPenalty,
+		PresencePenalty:        opts.PresencePenalty,

 		MaxCompletionTokens: opts.MaxTokens,
10 changes: 10 additions & 0 deletions llms/options.go
@@ -21,6 +21,9 @@ type CallOptions struct {
 	// StreamingFunc is a function to be called for each chunk of a streaming response.
 	// Return an error to stop streaming early.
 	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
+	// StreamingReasoningFunc is a function to be called for each reasoning and content chunk of a streaming response.
+	// Return an error to stop streaming early.
+	StreamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error `json:"-"`
 	// TopK is the number of tokens to consider for top-k sampling.
 	TopK int `json:"top_k"`
 	// TopP is the cumulative probability for top-p sampling.
@@ -162,6 +165,13 @@ func WithStreamingFunc(streamingFunc func(ctx context.Context, chunk []byte) error) CallOption {
 	}
 }

+// WithStreamingReasoningFunc specifies the streaming reasoning function to use.
+func WithStreamingReasoningFunc(streamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error) CallOption {
+	return func(o *CallOptions) {
+		o.StreamingReasoningFunc = streamingReasoningFunc
+	}
+}
+
 // WithTopK will add an option to use top-k sampling.
 func WithTopK(topK int) CallOption {
 	return func(o *CallOptions) {
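Because CallOption is a plain func(*CallOptions), the new option composes with every existing one. A short sketch of how a list of options collapses into a CallOptions value; the loop mirrors the pattern providers use to apply options internally, but the snippet itself is illustrative:

package main

import (
	"context"
	"fmt"

	"github.com/tmc/langchaingo/llms"
)

func main() {
	// Start from zero-value defaults and let each CallOption mutate
	// the CallOptions struct in order.
	opts := llms.CallOptions{}
	for _, opt := range []llms.CallOption{
		llms.WithTemperature(0.7),
		llms.WithStreamingReasoningFunc(func(_ context.Context, reasoningChunk, chunk []byte) error {
			fmt.Printf("reasoning=%q content=%q\n", reasoningChunk, chunk)
			return nil
		}),
	} {
		opt(&opts)
	}

	// Either callback may be nil; backends check before invoking them.
	fmt.Println("reasoning callback set:", opts.StreamingReasoningFunc != nil)
}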