Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: Add openai o1 support #366

Merged
merged 6 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepseek-model-provider/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ go 1.23.4
replace github.com/obot-platform/tools/openai-model-provider => ../openai-model-provider

require github.com/obot-platform/tools/openai-model-provider v0.0.0

require github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789 // indirect
2 changes: 2 additions & 0 deletions deepseek-model-provider/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789 h1:rfriXe+FFqZ5fZ+wGzLUivrq7Fyj2xfRdZjDsHf6Ps0=
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
2 changes: 1 addition & 1 deletion excel/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.23.1

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0
github.com/getkin/kin-openapi v0.124.0
github.com/gptscript-ai/go-gptscript v0.9.5
github.com/microsoft/kiota-abstractions-go v1.7.0
github.com/microsoftgraph/msgraph-sdk-go v1.51.0
Expand All @@ -13,7 +14,6 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/cjlapao/common-go v0.0.41 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/getkin/kin-openapi v0.124.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.20.2 // indirect
Expand Down
2 changes: 2 additions & 0 deletions groq-model-provider/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ go 1.23.4
replace github.com/obot-platform/tools/openai-model-provider => ../openai-model-provider

require github.com/obot-platform/tools/openai-model-provider v0.0.0

require github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789 // indirect
2 changes: 2 additions & 0 deletions groq-model-provider/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789 h1:rfriXe+FFqZ5fZ+wGzLUivrq7Fyj2xfRdZjDsHf6Ps0=
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
2 changes: 2 additions & 0 deletions ollama-model-provider/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ go 1.23.4
replace github.com/obot-platform/tools/openai-model-provider => ../openai-model-provider

require github.com/obot-platform/tools/openai-model-provider v0.0.0

require github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789 // indirect
2 changes: 2 additions & 0 deletions ollama-model-provider/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789 h1:rfriXe+FFqZ5fZ+wGzLUivrq7Fyj2xfRdZjDsHf6Ps0=
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
2 changes: 2 additions & 0 deletions openai-model-provider/go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/obot-platform/tools/openai-model-provider

go 1.23.4

require github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789
2 changes: 2 additions & 0 deletions openai-model-provider/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789 h1:rfriXe+FFqZ5fZ+wGzLUivrq7Fyj2xfRdZjDsHf6Ps0=
github.com/gptscript-ai/chat-completion-client v0.0.0-20250123123106-c86554320789/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
21 changes: 16 additions & 5 deletions openai-model-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package main

import (
"fmt"
"net/http"
"net/http/httputil"
"os"

"github.com/obot-platform/tools/openai-model-provider/openaiproxy"
"github.com/obot-platform/tools/openai-model-provider/proxy"
)

Expand All @@ -20,13 +23,21 @@ func main() {
}

cfg := &proxy.Config{
APIKey: apiKey,
ListenPort: port,
BaseURL: "https://api.openai.com/v1",
RewriteModelsFn: proxy.DefaultRewriteModelsResponse,
Name: "OpenAI",
APIKey: apiKey,
ListenPort: port,
BaseURL: "https://api.openai.com/v1",
RewriteModelsFn: proxy.DefaultRewriteModelsResponse,
Name: "OpenAI",
CustomPathHandleFuncs: map[string]http.HandlerFunc{},
}

openaiProxy := openaiproxy.NewServer(cfg)
reverseProxy := &httputil.ReverseProxy{
Director: openaiProxy.Openaiv1ProxyRedirect,
ModifyResponse: openaiProxy.ModifyResponse,
}
cfg.CustomPathHandleFuncs["/v1/"] = reverseProxy.ServeHTTP

if len(os.Args) > 1 && os.Args[1] == "validate" {
if err := cfg.Validate("/tools/openai-model-provider/validate"); err != nil {
os.Exit(1)
Expand Down
135 changes: 135 additions & 0 deletions openai-model-provider/openaiproxy/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package openaiproxy

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

openai "github.com/gptscript-ai/chat-completion-client"
"github.com/obot-platform/tools/openai-model-provider/proxy"
)

type Server struct {
cfg *proxy.Config
}

func NewServer(cfg *proxy.Config) *Server {
return &Server{cfg: cfg}
}

func (s *Server) Openaiv1ProxyRedirect(req *http.Request) {
req.URL.Scheme = s.cfg.URL.Scheme
req.URL.Host = s.cfg.URL.Host
req.URL.Path = s.cfg.URL.JoinPath(strings.TrimPrefix(req.URL.Path, "/v1")).Path // join baseURL with request path - /v1 must be part of baseURL if it's needed
req.Host = req.URL.Host

req.Header.Set("Authorization", "Bearer "+s.cfg.APIKey)

if req.Body == nil || s.cfg.URL.Host != proxy.OpenaiBaseHostName || req.URL.Path != proxy.ChatCompletionsPath {
return
}

bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
fmt.Println("failed to read request body, error: ", err.Error())
return
}

var reqBody openai.ChatCompletionRequest
if err := json.Unmarshal(bodyBytes, &reqBody); err == nil && isModelO1(reqBody.Model) {
if err := modifyRequestBodyForO1(req, &reqBody); err != nil {
fmt.Println("failed to modify request body for o1, error: ", err.Error())
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
} else {
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
}

func modifyRequestBodyForO1(req *http.Request, reqBody *openai.ChatCompletionRequest) error {
reqBody.Stream = false
reqBody.Temperature = nil
for i, msg := range reqBody.Messages {
if msg.Role == "system" {
reqBody.Messages[i].Role = "developer"
}
}
modifiedBodyBytes, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("failed to marshal request body after modification: %w", err)
}
req.Body = io.NopCloser(bytes.NewBuffer(modifiedBodyBytes))
req.ContentLength = int64(len(modifiedBodyBytes))
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Encoding", "")
req.Header.Set("Content-Type", "application/json")
return nil
}

func (s *Server) ModifyResponse(resp *http.Response) error {
if resp.StatusCode != http.StatusOK || resp.Request.URL.Path != proxy.ChatCompletionsPath || resp.Request.URL.Host != proxy.OpenaiBaseHostName {
return nil
}

if resp.Header.Get("Content-Type") == "application/json" {
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
return fmt.Errorf("failed to read response body: %w", err)
}
resp.Body.Close()
var respBody openai.ChatCompletionResponse
if err := json.Unmarshal(rawBody, &respBody); err == nil && isModelO1(respBody.Model) {
// Convert non-streaming response to a single SSE for o1 model
streamResponse := openai.ChatCompletionStreamResponse{
ID: respBody.ID,
Object: respBody.Object,
Created: respBody.Created,
Model: respBody.Model,
Usage: respBody.Usage,
Choices: func() []openai.ChatCompletionStreamChoice {
var choices []openai.ChatCompletionStreamChoice
for _, choice := range respBody.Choices {
choices = append(choices, openai.ChatCompletionStreamChoice{
Index: choice.Index,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: choice.Message.Content,
Role: choice.Message.Role,
FunctionCall: choice.Message.FunctionCall,
ToolCalls: choice.Message.ToolCalls,
},
FinishReason: choice.FinishReason,
})
}
return choices
}(),
}

sseData, err := json.Marshal(streamResponse)
if err != nil {
return fmt.Errorf("failed to marshal stream response: %w", err)
}

sseFormattedData := fmt.Sprintf("data: %s\n\nevent: close\ndata: [DONE]\n\n", sseData)

resp.Header.Set("Content-Type", "text/event-stream")
resp.Header.Set("Cache-Control", "no-cache")
resp.Header.Set("Connection", "keep-alive")
resp.Body = io.NopCloser(bytes.NewBufferString(sseFormattedData))
} else {
resp.Body = io.NopCloser(bytes.NewBuffer(rawBody))
}
}

return nil
}

func isModelO1(model string) bool {
if model == "o1" {
return true
}
return strings.HasPrefix(model, "o1-") && !strings.HasPrefix(model, "o1-mini") && !strings.HasPrefix(model, "o1-preview")
}
38 changes: 25 additions & 13 deletions openai-model-provider/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ import (
"strings"
)

var (
OpenaiBaseHostName = "api.openai.com"

ChatCompletionsPath = "/v1/chat/completions"
)

type Config struct {
url *url.URL
URL *url.URL

// ListenPort is the port the proxy server listens on
ListenPort string
Expand Down Expand Up @@ -39,7 +45,7 @@ type server struct {
}

func (cfg *Config) ensureURL() error {
if cfg.url != nil {
if cfg.URL != nil {
return nil
}

Expand All @@ -58,7 +64,7 @@ func (cfg *Config) ensureURL() error {
}
}

cfg.url = u
cfg.URL = u
return nil
}

Expand All @@ -81,25 +87,31 @@ func Run(cfg *Config) error {

mux := http.NewServeMux()

// Register custom path handlers first
for path, handler := range cfg.CustomPathHandleFuncs {
mux.HandleFunc(path, handler)
}

// Register default handlers only if they are not already registered
if _, exists := cfg.CustomPathHandleFuncs["/{$}"]; !exists {
mux.HandleFunc("/{$}", s.healthz)
}
if _, exists := cfg.CustomPathHandleFuncs["/v1/models"]; !exists {
if handler, exists := cfg.CustomPathHandleFuncs["/v1/models"]; !exists {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@iwilltry42 can you review this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!
I don't like all those if-else statements, but have no cleaner solution right now (unfortunately, mux doesn't take care of the path sorting and neither allows overrides nor looking up what paths are already handled.

mux.Handle("/v1/models", &httputil.ReverseProxy{
Director: s.proxyDirector,
ModifyResponse: cfg.RewriteModelsFn,
})
} else {
mux.HandleFunc("/v1/models", handler)
}
if _, exists := cfg.CustomPathHandleFuncs["/v1/"]; !exists {
if handler, exists := cfg.CustomPathHandleFuncs["/v1/"]; !exists {
mux.Handle("/v1/", &httputil.ReverseProxy{
Director: s.proxyDirector,
})
} else {
mux.HandleFunc("/v1/", handler)
}

for path, handler := range cfg.CustomPathHandleFuncs {
if path == "/v1/models" || path == "/v1/" {
continue
}
mux.HandleFunc(path, handler)
}

httpServer := &http.Server{
Expand All @@ -119,9 +131,9 @@ func (s *server) healthz(w http.ResponseWriter, _ *http.Request) {
}

func (s *server) proxyDirector(req *http.Request) {
req.URL.Scheme = s.cfg.url.Scheme
req.URL.Host = s.cfg.url.Host
req.URL.Path = s.cfg.url.JoinPath(strings.TrimPrefix(req.URL.Path, "/v1")).Path // join baseURL with request path - /v1 must be part of baseURL if it's needed
req.URL.Scheme = s.cfg.URL.Scheme
req.URL.Host = s.cfg.URL.Host
req.URL.Path = s.cfg.URL.JoinPath(strings.TrimPrefix(req.URL.Path, "/v1")).Path // join baseURL with request path - /v1 must be part of baseURL if it's needed
req.Host = req.URL.Host

req.Header.Set("Authorization", "Bearer "+s.cfg.APIKey)
Expand Down
2 changes: 1 addition & 1 deletion openai-model-provider/proxy/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (cfg *Config) Validate(toolPath string) error {
return fmt.Errorf("failed to ensure URL: %w", err)
}

url := cfg.url.JoinPath("/models")
url := cfg.URL.JoinPath("/models")

req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
Expand Down
13 changes: 0 additions & 13 deletions outlook/calendar/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,15 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/cjlapao/common-go v0.0.41 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/getkin/kin-openapi v0.128.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/glebarez/sqlite v1.11.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/invopop/yaml v0.3.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/microsoft/kiota-authentication-azure-go v1.1.0 // indirect
github.com/microsoft/kiota-http-go v1.4.5 // indirect
Expand All @@ -45,19 +39,12 @@ require (
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
github.com/std-uritemplate/std-uritemplate/go v1.0.6 // indirect
go.opentelemetry.io/otel v1.31.0 // indirect
go.opentelemetry.io/otel/metric v1.31.0 // indirect
go.opentelemetry.io/otel/trace v1.31.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/gorm v1.25.7 // indirect
modernc.org/libc v1.22.5 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.23.1 // indirect
)
Loading