From 78079b701a444a16e35f9c45c391c5cb3feb9fb2 Mon Sep 17 00:00:00 2001
From: Leo Douglas
Date: Sun, 2 Feb 2025 19:01:48 +0800
Subject: [PATCH] feat: implement StreamingReasoningFunc for reasoning models

---
 llms/openai/internal/openaiclient/chat.go      | 17 +++++++++---
 .../openai/internal/openaiclient/chat_test.go  | 27 +++++++++++++++++++
 llms/openai/openaillm.go                       | 17 ++++++------
 llms/options.go                                | 10 +++++++
 4 files changed, 60 insertions(+), 11 deletions(-)

diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go
index 74ca0d844..e26c2b087 100644
--- a/llms/openai/internal/openaiclient/chat.go
+++ b/llms/openai/internal/openaiclient/chat.go
@@ -67,6 +67,10 @@ type ChatRequest struct {
 	// Return an error to stop streaming early.
 	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
 
+	// StreamingReasoningFunc is a function to be called for each reasoning and content chunk of a streaming response.
+	// Return an error to stop streaming early.
+	StreamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error `json:"-"`
+
 	// Deprecated: use Tools instead.
 	Functions []FunctionDefinition `json:"functions,omitempty"`
 	// Deprecated: use ToolChoice instead.
@@ -380,7 +384,7 @@ type FunctionCall struct {
 }
 
 func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatCompletionResponse, error) {
-	if payload.StreamingFunc != nil {
+	if payload.StreamingFunc != nil || payload.StreamingReasoningFunc != nil {
 		payload.Stream = true
 		if payload.StreamOptions == nil {
 			payload.StreamOptions = &StreamOptions{IncludeUsage: true}
 		}
@@ -421,7 +425,7 @@
 
 		return nil, fmt.Errorf("%s: %s", msg, errResp.Error.Message) // nolint:goerr113
 	}
-	if payload.StreamingFunc != nil {
+	if payload.StreamingFunc != nil || payload.StreamingReasoningFunc != nil {
 		return parseStreamingChatResponse(ctx, r, payload)
 	}
 	// Parse response
@@ -493,9 +497,10 @@ func combineStreamingChatResponse(
 		}
 		choice := streamResponse.Choices[0]
 		chunk := []byte(choice.Delta.Content)
+		reasoningChunk := []byte(choice.Delta.ReasoningContent) // TODO: it is unclear whether reasoning content can accompany a function call, so the reasoning chunk is passed through as-is for now.
 		response.Choices[0].Message.Content += choice.Delta.Content
 		response.Choices[0].FinishReason = choice.FinishReason
-		response.Choices[0].Message.ReasoningContent = choice.Delta.ReasoningContent
+		response.Choices[0].Message.ReasoningContent += choice.Delta.ReasoningContent
 
 		if choice.Delta.FunctionCall != nil {
 			chunk = updateFunctionCall(response.Choices[0].Message, choice.Delta.FunctionCall)
@@ -512,6 +517,12 @@ func combineStreamingChatResponse(
 				return nil, fmt.Errorf("streaming func returned an error: %w", err)
 			}
 		}
+		if payload.StreamingReasoningFunc != nil {
+			err := payload.StreamingReasoningFunc(ctx, reasoningChunk, chunk)
+			if err != nil {
+				return nil, fmt.Errorf("streaming reasoning func returned an error: %w", err)
+			}
+		}
 	}
 	return &response, nil
 }
diff --git a/llms/openai/internal/openaiclient/chat_test.go b/llms/openai/internal/openaiclient/chat_test.go
index 7260af894..f4150bafc 100644
--- a/llms/openai/internal/openaiclient/chat_test.go
+++ b/llms/openai/internal/openaiclient/chat_test.go
@@ -56,6 +56,33 @@ func TestParseStreamingChatResponse_ReasoningContent(t *testing.T) {
 	assert.Equal(t, FinishReason("stop"), resp.Choices[0].FinishReason)
 }
 
+func TestParseStreamingChatResponse_ReasoningFunc(t *testing.T) {
+	t.Parallel()
+	mockBody := `
+data: {"id":"fa7e4fc5-a05d-4e7b-9a66-a2dd89e91a4e","object":"chat.completion.chunk","created":1738492867,"model":"deepseek-reasoner","system_fingerprint":"fp_7e73fd9a08","choices":[{"index":0,"delta":{"content":null,"reasoning_content":"Okay"},"logprobs":null,"finish_reason":null}]}
+`
+	r := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(bytes.NewBufferString(mockBody)),
+	}
+
+	req := &ChatRequest{
+		StreamingReasoningFunc: func(_ context.Context, reasoningChunk, chunk []byte) error {
+			t.Logf("reasoningChunk: %s", string(reasoningChunk))
+			t.Logf("chunk: %s", string(chunk))
+			return nil
+		},
+	}
+
+	resp, err := parseStreamingChatResponse(context.Background(), r, req)
+
+	require.NoError(t, err)
+	assert.NotNil(t, resp)
+	assert.Equal(t, "", resp.Choices[0].Message.Content)
+	assert.Equal(t, "Okay", resp.Choices[0].Message.ReasoningContent)
+	assert.Equal(t, FinishReason(""), resp.Choices[0].FinishReason)
+}
+
 func TestChatMessage_MarshalUnmarshal(t *testing.T) {
 	t.Parallel()
 	msg := ChatMessage{
diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go
index 78f8334d2..5710b6fc5 100644
--- a/llms/openai/openaillm.go
+++ b/llms/openai/openaillm.go
@@ -96,14 +96,15 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		chatMsgs = append(chatMsgs, msg)
 	}
 	req := &openaiclient.ChatRequest{
-		Model:            opts.Model,
-		StopWords:        opts.StopWords,
-		Messages:         chatMsgs,
-		StreamingFunc:    opts.StreamingFunc,
-		Temperature:      opts.Temperature,
-		N:                opts.N,
-		FrequencyPenalty: opts.FrequencyPenalty,
-		PresencePenalty:  opts.PresencePenalty,
+		Model:                  opts.Model,
+		StopWords:              opts.StopWords,
+		Messages:               chatMsgs,
+		StreamingFunc:          opts.StreamingFunc,
+		StreamingReasoningFunc: opts.StreamingReasoningFunc,
+		Temperature:            opts.Temperature,
+		N:                      opts.N,
+		FrequencyPenalty:       opts.FrequencyPenalty,
+		PresencePenalty:        opts.PresencePenalty,
 
 		MaxCompletionTokens: opts.MaxTokens,
diff --git a/llms/options.go b/llms/options.go
index b6b595290..0c0a4afc8 100644
--- a/llms/options.go
+++ b/llms/options.go
@@ -21,6 +21,9 @@ type CallOptions struct {
 	// StreamingFunc is a function to be called for each chunk of a streaming response.
 	// Return an error to stop streaming early.
 	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
+	// StreamingReasoningFunc is a function to be called for each reasoning and content chunk of a streaming response.
+	// Return an error to stop streaming early.
+	StreamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error `json:"-"`
 	// TopK is the number of tokens to consider for top-k sampling.
 	TopK int `json:"top_k"`
 	// TopP is the cumulative probability for top-p sampling.
@@ -162,6 +165,13 @@ func WithStreamingFunc(streamingFunc func(ctx context.Context, chunk []byte) err
 	}
 }
 
+// WithStreamingReasoningFunc specifies the streaming reasoning function to use.
+func WithStreamingReasoningFunc(streamingReasoningFunc func(ctx context.Context, reasoningChunk, chunk []byte) error) CallOption {
+	return func(o *CallOptions) {
+		o.StreamingReasoningFunc = streamingReasoningFunc
+	}
+}
+
 // WithTopK will add an option to use top-k sampling.
 func WithTopK(topK int) CallOption {
 	return func(o *CallOptions) {
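
Usage note (not part of the patch): a minimal sketch of how a caller could consume the new option, assuming a DeepSeek-style reasoning model behind an OpenAI-compatible endpoint. The model name, base URL, and prompt below are illustrative placeholders, and the API key is taken from the OPENAI_API_KEY environment variable as usual.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	// Illustrative configuration: model name and base URL are placeholders;
	// the API key comes from the OPENAI_API_KEY environment variable.
	llm, err := openai.New(
		openai.WithModel("deepseek-reasoner"),
		openai.WithBaseURL("https://api.deepseek.com"),
	)
	if err != nil {
		log.Fatal(err)
	}

	_, err = llms.GenerateFromSinglePrompt(context.Background(), llm,
		"Why is the sky blue?",
		// Reasoning and content deltas arrive as separate arguments.
		llms.WithStreamingReasoningFunc(func(_ context.Context, reasoningChunk, chunk []byte) error {
			if len(reasoningChunk) > 0 {
				fmt.Printf("[reasoning] %s", reasoningChunk)
			}
			if len(chunk) > 0 {
				fmt.Printf("%s", chunk)
			}
			return nil // return an error to stop streaming early
		}),
	)
	if err != nil {
		log.Fatal(err)
	}
}

Because createChat now enables streaming whenever either callback is set, passing WithStreamingReasoningFunc alone is enough to put the request into streaming mode.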