From fb2df0fe22002f1826bfaa1cb008c45db375885c Mon Sep 17 00:00:00 2001 From: Young-Jin Park Date: Wed, 31 Jul 2024 20:14:24 -0400 Subject: [PATCH] feat: add azure, examples, and message constructors --- README.md | 59 ++++- azure/azure.go | 237 ++++++++++++++++++ azure/azure_test.go | 130 ++++++++++ azure/example_auth_test.go | 47 ++++ chatcompletion.go | 92 +++++++ examples/audio-text-to-speech/main.go | 47 ++++ examples/audio-transcriptions/main.go | 29 +++ examples/chat-completion-streaming/main.go | 37 +++ examples/chat-completion-tool-calling/main.go | 94 +++++++ examples/chat-completion/main.go | 32 +++ examples/fine-tuning/fine-tuning-data.jsonl | 10 + examples/fine-tuning/main.go | 84 +++++++ examples/go.mod | 19 ++ examples/go.sum | 16 ++ examples/image-generation/main.go | 63 +++++ go.mod | 24 +- go.sum | 31 ++- internal/apijson/decoder.go | 121 ++++----- internal/apijson/json_test.go | 53 +--- 19 files changed, 1090 insertions(+), 135 deletions(-) create mode 100644 azure/azure.go create mode 100644 azure/azure_test.go create mode 100644 azure/example_auth_test.go create mode 100644 examples/audio-text-to-speech/main.go create mode 100644 examples/audio-transcriptions/main.go create mode 100644 examples/chat-completion-streaming/main.go create mode 100644 examples/chat-completion-tool-calling/main.go create mode 100644 examples/chat-completion/main.go create mode 100644 examples/fine-tuning/fine-tuning-data.jsonl create mode 100644 examples/fine-tuning/main.go create mode 100644 examples/go.mod create mode 100644 examples/go.sum create mode 100644 examples/image-generation/main.go diff --git a/README.md b/README.md index 8746450..0397a06 100644 --- a/README.md +++ b/README.md @@ -51,10 +51,9 @@ func main() { option.WithAPIKey("My API Key"), // defaults to os.LookupEnv("OPENAI_API_KEY") ) chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Say this is a test")), - }}), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), Model: openai.F(openai.ChatModelGPT4o), }) if err != nil { @@ -237,10 +236,9 @@ defer cancel() client.Chat.Completions.New( ctx, openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("How can I list all files in a directory using Python?")), - }}), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), Model: openai.F(openai.ChatModelGPT4o), }, // This sets the per-retry timeout @@ -300,10 +298,9 @@ client := openai.NewClient( client.Chat.Completions.New( context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("How can I get the name of the current day in Node.js?")), - }}), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Say this is a test"), + }), Model: openai.F(openai.ChatModelGPT4o), }, option.WithMaxRetries(5), @@ -396,6 +393,44 @@ You may also replace the default `http.Client` with accepted (this overwrites any previous client) and receives requests after any middleware has been applied. +## Microsoft Azure OpenAI + +To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the option.RequestOption functions in the `azure` package. + +```go +package main + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/openai/openai-go" + "github.com/openai/openai-go/azure" + "github.com/openai/openai-go/option" +) + +func main() { + const azureOpenAIEndpoint = "https://.openai.azure.com" + + // The latest API versions, including previews, can be found here: + // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning + const azureOpenAIAPIVersion = "2024-06-01" + + tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) + + if err != nil { + fmt.Printf("Failed to create the DefaultAzureCredential: %s", err) + os.Exit(1) + } + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + + // Choose between authenticating using a TokenCredential or an API Key + azure.WithTokenCredential(tokenCredential), + // or azure.WithAPIKey(azureOpenAIAPIKey), + ) +} +``` + ## Semantic versioning This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: diff --git a/azure/azure.go b/azure/azure.go new file mode 100644 index 0000000..5d3156f --- /dev/null +++ b/azure/azure.go @@ -0,0 +1,237 @@ +// Package azure provides configuration options so you can connect and use Azure OpenAI using the [openai.Client]. +// +// Typical usage of this package will look like this: +// +// client := openai.NewClient( +// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), +// azure.WithTokenCredential(azureIdentityTokenCredential), +// // or azure.WithAPIKey(azureOpenAIAPIKey), +// ) +// +// Or, if you want to construct a specific service: +// +// client := openai.NewChatCompletionService( +// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), +// azure.WithTokenCredential(azureIdentityTokenCredential), +// // or azure.WithAPIKey(azureOpenAIAPIKey), +// ) +package azure + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "mime" + "mime/multipart" + "net/http" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/openai/openai-go/internal/requestconfig" + "github.com/openai/openai-go/option" +) + +// WithEndpoint configures this client to connect to an Azure OpenAI endpoint. +// +// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://.openai.azure.com +// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty. +// +// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this: +// +// client := openai.NewClient( +// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), +// azure.WithTokenCredential(azureIdentityTokenCredential), +// // or azure.WithAPIKey(azureOpenAIAPIKey), +// ) +// +// [Azure OpenAI apiversions]: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning +func WithEndpoint(endpoint string, apiVersion string) option.RequestOption { + if !strings.HasSuffix(endpoint, "/") { + endpoint += "/" + } + + endpoint += "openai/" + + withQueryAdd := option.WithQueryAdd("api-version", apiVersion) + withEndpoint := option.WithBaseURL(endpoint) + + withModelMiddleware := option.WithMiddleware(func(r *http.Request, mn option.MiddlewareNext) (*http.Response, error) { + replacementPath, err := getReplacementPathWithDeployment(r) + + if err != nil { + return nil, err + } + + r.URL.Path = replacementPath + return mn(r) + }) + + return func(rc *requestconfig.RequestConfig) error { + if apiVersion == "" { + return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.") + } + + if err := withQueryAdd(rc); err != nil { + return err + } + + if err := withEndpoint(rc); err != nil { + return err + } + + if err := withModelMiddleware(rc); err != nil { + return err + } + + return nil + } +} + +// WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential. +// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance. +// +// [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity +func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption { + bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil) + + // add in a middleware that uses the bearer token generated from the token credential + return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{ + InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc.. + PerRetryPolicies: []policy.Policy{ + bearerTokenPolicy, + policyAdapter(next), + }, + }) + + req2, err := runtime.NewRequestFromRequest(req) + + if err != nil { + return nil, err + } + + return pipeline.Do(req2) + }) +} + +// WithAPIKey configures this client to authenticate using an API key. +// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance. +func WithAPIKey(apiKey string) option.RequestOption { + // NOTE: there is an option.WithApiKey(), but that adds the value into + // the Authorization header instead so we're doing this instead. + return option.WithHeader("Api-Key", apiKey) +} + +// jsonRoutes have JSON payloads - we'll deserialize looking for a .model field in there +// so we won't have to worry about individual types for completions vs embeddings, etc... +var jsonRoutes = map[string]bool{ + "/openai/completions": true, + "/openai/chat/completions": true, + "/openai/embeddings": true, + "/openai/audio/speech": true, + "/openai/images/generations": true, +} + +// audioMultipartRoutes have mime/multipart payloads. These are less generic - we're very much +// expecting a transcription or translation payload for these. +var audioMultipartRoutes = map[string]bool{ + "/openai/audio/transcriptions": true, + "/openai/audio/translations": true, +} + +// getReplacementPathWithDeployment parses the request body to extract out the Model parameter (or equivalent) +// (note, the req.Body is fully read as part of this, and is replaced with a bytes.Reader) +func getReplacementPathWithDeployment(req *http.Request) (string, error) { + if jsonRoutes[req.URL.Path] { + return getJSONRoute(req) + } + + if audioMultipartRoutes[req.URL.Path] { + return getAudioMultipartRoute(req) + } + + // No need to relocate the path. We've already tacked on /openai when we setup the endpoint. + return req.URL.Path, nil +} + +func getJSONRoute(req *http.Request) (string, error) { + // we need to deserialize the body, partly, in order to read out the model field. + jsonBytes, err := io.ReadAll(req.Body) + + if err != nil { + return "", err + } + + // make sure we restore the body so it can be used in later middlewares. + req.Body = io.NopCloser(bytes.NewReader(jsonBytes)) + + var v *struct { + Model string `json:"model"` + } + + if err := json.Unmarshal(jsonBytes, &v); err != nil { + return "", err + } + + escapedDeployment := url.PathEscape(v.Model) + return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil +} + +func getAudioMultipartRoute(req *http.Request) (string, error) { + // body is a multipart/mime body type instead. + mimeBytes, err := io.ReadAll(req.Body) + + if err != nil { + return "", err + } + + // make sure we restore the body so it can be used in later middlewares. + req.Body = io.NopCloser(bytes.NewReader(mimeBytes)) + + _, mimeParams, err := mime.ParseMediaType(req.Header.Get("Content-Type")) + + if err != nil { + return "", err + } + + mimeReader := multipart.NewReader( + io.NopCloser(bytes.NewReader(mimeBytes)), + mimeParams["boundary"]) + + for { + mp, err := mimeReader.NextPart() + + if err != nil { + if errors.Is(err, io.EOF) { + return "", errors.New("unable to find the model part in multipart body") + } + + return "", err + } + + defer mp.Close() + + if mp.FormName() == "model" { + modelBytes, err := io.ReadAll(mp) + + if err != nil { + return "", err + } + + escapedDeployment := url.PathEscape(string(modelBytes)) + return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil + } + } +} + +type policyAdapter option.MiddlewareNext + +func (mp policyAdapter) Do(req *policy.Request) (*http.Response, error) { + return (option.MiddlewareNext)(mp)(req.Raw()) +} + +const version = "v.0.1.0" diff --git a/azure/azure_test.go b/azure/azure_test.go new file mode 100644 index 0000000..00f5733 --- /dev/null +++ b/azure/azure_test.go @@ -0,0 +1,130 @@ +package azure + +import ( + "bytes" + "mime/multipart" + "net/http" + "testing" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/internal/apijson" + "github.com/openai/openai-go/shared" +) + +func TestJSONRoute(t *testing.T) { + chatCompletionParams := openai.ChatCompletionNewParams{ + Model: openai.F(openai.ChatModel("arbitraryDeployment")), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.ChatCompletionAssistantMessageParam{ + Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), + Content: openai.String("You are a helpful assistant"), + }, + openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), + }, + }), + } + + serializedBytes, err := apijson.MarshalRoot(chatCompletionParams) + + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(serializedBytes)) + + if err != nil { + t.Fatal(err) + } + + replacementPath, err := getReplacementPathWithDeployment(req) + + if err != nil { + t.Fatal(err) + } + + if replacementPath != "/openai/deployments/arbitraryDeployment/chat/completions" { + t.Fatalf("replacementpath didn't match: %s", replacementPath) + } +} + +func TestGetAudioMultipartRoute(t *testing.T) { + buff := &bytes.Buffer{} + mw := multipart.NewWriter(buff) + defer mw.Close() + + fw, err := mw.CreateFormFile("file", "test.mp3") + + if err != nil { + t.Fatal(err) + } + + if _, err = fw.Write([]byte("ignore me")); err != nil { + t.Fatal(err) + } + + if err := mw.WriteField("model", "arbitraryDeployment"); err != nil { + t.Fatal(err) + } + + if err := mw.Close(); err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("POST", "/openai/audio/transcriptions", bytes.NewReader(buff.Bytes())) + + if err != nil { + t.Fatal(err) + } + + req.Header.Set("Content-Type", mw.FormDataContentType()) + + replacementPath, err := getReplacementPathWithDeployment(req) + + if err != nil { + t.Fatal(err) + } + + if replacementPath != "/openai/deployments/arbitraryDeployment/audio/transcriptions" { + t.Fatalf("replacementpath didn't match: %s", replacementPath) + } +} + +func TestNoRouteChangeNeeded(t *testing.T) { + chatCompletionParams := openai.ChatCompletionNewParams{ + Model: openai.F(openai.ChatModel("arbitraryDeployment")), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.ChatCompletionAssistantMessageParam{ + Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), + Content: openai.String("You are a helpful assistant"), + }, + openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), + }, + }), + } + + serializedBytes, err := apijson.MarshalRoot(chatCompletionParams) + + if err != nil { + t.Fatal(err) + } + + req, err := http.NewRequest("POST", "/openai/does/not/need/a/deployment", bytes.NewReader(serializedBytes)) + + if err != nil { + t.Fatal(err) + } + + replacementPath, err := getReplacementPathWithDeployment(req) + + if err != nil { + t.Fatal(err) + } + + if replacementPath != "/openai/does/not/need/a/deployment" { + t.Fatalf("replacementpath didn't match: %s", replacementPath) + } +} diff --git a/azure/example_auth_test.go b/azure/example_auth_test.go new file mode 100644 index 0000000..3a8ef21 --- /dev/null +++ b/azure/example_auth_test.go @@ -0,0 +1,47 @@ +package azure_test + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/openai/openai-go" + "github.com/openai/openai-go/azure" +) + +func Example_authentication() { + // There are two ways to authenticate - using a TokenCredential (via the azidentity + // package), or using an API Key. + const azureOpenAIEndpoint = "https://.openai.azure.com" + const azureOpenAIAPIVersion = "" + + // Using a TokenCredential + { + // For a full list of credential types look at the documentation for the Azure Identity + // package: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity + tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) + + if err != nil { + fmt.Printf("Failed to create TokenCredential: %s\n", err) + return + } + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + azure.WithTokenCredential(tokenCredential), + ) + + _ = client + } + + // Using an API Key + { + const azureOpenAIAPIKey = "" + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + azure.WithAPIKey(azureOpenAIAPIKey), + ) + + _ = client + } +} diff --git a/chatcompletion.go b/chatcompletion.go index 4dcd393..f6331cb 100644 --- a/chatcompletion.go +++ b/chatcompletion.go @@ -12,8 +12,73 @@ import ( "github.com/openai/openai-go/option" "github.com/openai/openai-go/packages/ssestream" "github.com/openai/openai-go/shared" + "github.com/tidwall/sjson" ) +func UserMessage(content string) ChatCompletionMessageParamUnion { + return ChatCompletionUserMessageParam{ + Role: F(ChatCompletionUserMessageParamRoleUser), + Content: F[ChatCompletionUserMessageParamContentUnion]( + shared.UnionString(content), + ), + } +} + +func UserMessageBlocks(blocks ...ChatCompletionContentPartUnionParam) ChatCompletionMessageParamUnion { + return ChatCompletionUserMessageParam{ + Role: F(ChatCompletionUserMessageParamRoleUser), + Content: F[ChatCompletionUserMessageParamContentUnion]( + ChatCompletionUserMessageParamContentArrayOfContentParts(blocks), + ), + } +} + +func UserMessageTextBlock(content string) ChatCompletionContentPartUnionParam { + return ChatCompletionContentPartTextParam{ + Type: F(ChatCompletionContentPartTextTypeText), + Text: F(content), + } +} + +func UserMessageImageBlock(url string) ChatCompletionContentPartUnionParam { + return ChatCompletionContentPartImageParam{ + Type: F(ChatCompletionContentPartImageTypeImageURL), + ImageURL: F(ChatCompletionContentPartImageImageURLParam{ + URL: F(url), + }), + } +} + +func AssistantMessage(content string) ChatCompletionMessageParamUnion { + return ChatCompletionAssistantMessageParam{ + Role: F(ChatCompletionAssistantMessageParamRoleAssistant), + Content: F(content), + } +} + +func ToolMessage(toolCallID, content string) ChatCompletionMessageParamUnion { + return ChatCompletionToolMessageParam{ + Role: F(ChatCompletionToolMessageParamRoleTool), + ToolCallID: F(toolCallID), + Content: F(content), + } +} + +func SystemMessage(content string) ChatCompletionMessageParamUnion { + return ChatCompletionSystemMessageParam{ + Role: F(ChatCompletionSystemMessageParamRoleSystem), + Content: F(content), + } +} + +func FunctionMessage(name, content string) ChatCompletionMessageParamUnion { + return ChatCompletionFunctionMessageParam{ + Role: F(ChatCompletionFunctionMessageParamRoleFunction), + Name: F(name), + Content: F(content), + } +} + // ChatCompletionService contains methods and other services that help with // interacting with the openai API. // @@ -785,10 +850,35 @@ func (r *ChatCompletionMessage) UnmarshalJSON(data []byte) (err error) { return apijson.UnmarshalRoot(data, r) } +func (r ChatCompletionMessage) MarshalJSON() (data []byte, err error) { + s := "" + s, _ = sjson.Set(s, "role", r.Role) + + if r.FunctionCall.Name != "" { + b, err := apijson.Marshal(r.FunctionCall) + if err != nil { + return nil, err + } + s, _ = sjson.SetRaw(s, "function_call", string(b)) + } else if len(r.ToolCalls) > 0 { + b, err := apijson.Marshal(r.ToolCalls) + if err != nil { + return nil, err + } + s, _ = sjson.SetRaw(s, "tool_calls", string(b)) + } else { + s, _ = sjson.Set(s, "content", r.Content) + } + + return []byte(s), nil +} + func (r chatCompletionMessageJSON) RawJSON() string { return r.raw } +func (r ChatCompletionMessage) implementsChatCompletionMessageParamUnion() {} + // The role of the author of this message. type ChatCompletionMessageRole string @@ -857,6 +947,8 @@ func (r ChatCompletionMessageParam) implementsChatCompletionMessageParamUnion() // [ChatCompletionUserMessageParam], [ChatCompletionAssistantMessageParam], // [ChatCompletionToolMessageParam], [ChatCompletionFunctionMessageParam], // [ChatCompletionMessageParam]. +// +// This union is additionally satisfied by the return types [ChatCompletionMessage] type ChatCompletionMessageParamUnion interface { implementsChatCompletionMessageParamUnion() } diff --git a/examples/audio-text-to-speech/main.go b/examples/audio-text-to-speech/main.go new file mode 100644 index 0000000..7a268b3 --- /dev/null +++ b/examples/audio-text-to-speech/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "context" + "time" + + "github.com/ebitengine/oto/v3" + "github.com/openai/openai-go" +) + +func main() { + client := openai.NewClient() + ctx := context.Background() + + res, err := client.Audio.Speech.New(ctx, openai.AudioSpeechNewParams{ + Model: openai.F(openai.AudioSpeechNewParamsModelTTS1), + Input: openai.String(`Why did the chicken cross the road? To get to the other side.`), + ResponseFormat: openai.F(openai.AudioSpeechNewParamsResponseFormatPCM), + Voice: openai.F(openai.AudioSpeechNewParamsVoiceAlloy), + }) + defer res.Body.Close() + if err != nil { + panic(err) + } + + op := &oto.NewContextOptions{} + op.SampleRate = 24000 + op.ChannelCount = 1 + op.Format = oto.FormatSignedInt16LE + + otoCtx, readyChan, err := oto.NewContext(op) + if err != nil { + panic("oto.NewContext failed: " + err.Error()) + } + + <-readyChan + + player := otoCtx.NewPlayer(res.Body) + player.Play() + for player.IsPlaying() { + time.Sleep(time.Millisecond) + } + err = player.Close() + if err != nil { + panic("player.Close failed: " + err.Error()) + } +} diff --git a/examples/audio-transcriptions/main.go b/examples/audio-transcriptions/main.go new file mode 100644 index 0000000..f2f7f2e --- /dev/null +++ b/examples/audio-transcriptions/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "context" + "io" + "os" + + "github.com/openai/openai-go" +) + +func main() { + client := openai.NewClient() + ctx := context.Background() + + file, err := os.Open("speech.mp3") + if err != nil { + panic(err) + } + + transcription, err := client.Audio.Transcriptions.New(ctx, openai.AudioTranscriptionNewParams{ + Model: openai.F(openai.AudioTranscriptionNewParamsModelWhisper1), + File: openai.F[io.Reader](file), + }) + if err != nil { + panic(err) + } + + println(transcription.Text) +} diff --git a/examples/chat-completion-streaming/main.go b/examples/chat-completion-streaming/main.go new file mode 100644 index 0000000..9591973 --- /dev/null +++ b/examples/chat-completion-streaming/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "context" + + "github.com/openai/openai-go" +) + +func main() { + client := openai.NewClient() + + ctx := context.Background() + + question := "Write me a haiku" + + print("> ") + println(question) + println() + + stream := client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{ + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(question), + }), + Seed: openai.Int(0), + Model: openai.F(openai.ChatModelGPT4o), + }) + + for stream.Next() { + evt := stream.Current() + print(evt.Choices[0].Delta.Content) + } + println() + + if err := stream.Err(); err != nil { + panic(err) + } +} diff --git a/examples/chat-completion-tool-calling/main.go b/examples/chat-completion-tool-calling/main.go new file mode 100644 index 0000000..d82a77a --- /dev/null +++ b/examples/chat-completion-tool-calling/main.go @@ -0,0 +1,94 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/openai/openai-go" +) + +func main() { + client := openai.NewClient() + + ctx := context.Background() + + question := "What is the weather in New York City?" + + print("> ") + println(question) + + params := openai.ChatCompletionNewParams{ + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(question), + }), + Tools: openai.F([]openai.ChatCompletionToolParam{ + { + Type: openai.F(openai.ChatCompletionToolTypeFunction), + Function: openai.F(openai.FunctionDefinitionParam{ + Name: openai.String("get_weather"), + Description: openai.String("Get weather at the given location"), + Parameters: openai.F(openai.FunctionParameters{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]string{ + "type": "string", + }, + }, + "required": []string{"location"}, + }), + }), + }, + }), + Seed: openai.Int(0), + Model: openai.F(openai.ChatModelGPT4o), + } + + // Make initial chat completion request + completion, err := client.Chat.Completions.New(ctx, params) + if err != nil { + panic(err) + } + + toolCalls := completion.Choices[0].Message.ToolCalls + + // Abort early if there are no tool calls + if len(toolCalls) == 0 { + fmt.Printf("No function call") + return + } + + // If there is a was a function call, continue the conversation + params.Messages.Value = append(params.Messages.Value, completion.Choices[0].Message) + for _, toolCall := range toolCalls { + if toolCall.Function.Name == "get_weather" { + // Extract the location from the function call arguments + var args map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + panic(err) + } + location := args["location"].(string) + + // Simulate getting weather data + weatherData := getWeather(location) + + // Print the weather data + fmt.Printf("Weather in %s: %s\n", location, weatherData) + + params.Messages.Value = append(params.Messages.Value, openai.ToolMessage(toolCall.ID, weatherData)) + } + } + + completion, err = client.Chat.Completions.New(ctx, params) + if err != nil { + panic(err) + } + + println(completion.Choices[0].Message.Content) +} + +// Mock function to simulate weather data retrieval +func getWeather(location string) string { + // In a real implementation, this function would call a weather API + return "Sunny, 25°C" +} diff --git a/examples/chat-completion/main.go b/examples/chat-completion/main.go new file mode 100644 index 0000000..a845b5c --- /dev/null +++ b/examples/chat-completion/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + + "github.com/openai/openai-go" +) + +func main() { + client := openai.NewClient() + + ctx := context.Background() + + question := "Write me a haiku" + + print("> ") + println(question) + println() + + completion, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(question), + }), + Seed: openai.Int(0), + Model: openai.F(openai.ChatModelGPT4o), + }) + if err != nil { + panic(err) + } + + println(completion.Choices[0].Message.Content) +} diff --git a/examples/fine-tuning/fine-tuning-data.jsonl b/examples/fine-tuning/fine-tuning-data.jsonl new file mode 100644 index 0000000..e9ab267 --- /dev/null +++ b/examples/fine-tuning/fine-tuning-data.jsonl @@ -0,0 +1,10 @@ +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the tallest mountain in the world?"}, {"role": "assistant", "content": "Everest, of course. Because no other mountain even comes close!"}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who was the first President of the United States?"}, {"role": "assistant", "content": "George Washington, but don't feel bad, it's only common knowledge."}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the boiling point of water?"}, {"role": "assistant", "content": "100°C at sea level. Shocking, isn't it?"}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the largest ocean on Earth?"}, {"role": "assistant", "content": "The Pacific Ocean. Surprised? Me neither."}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the smallest planet in our solar system?"}, {"role": "assistant", "content": "Mercury. It's not like it's a big deal or anything."}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who discovered penicillin?"}, {"role": "assistant", "content": "Alexander Fleming. But don't worry, it's not like he's famous for it or anything."}]} +{"messages": [{"role": "system", "content": "Marvin is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "When was the internet invented?"}, {"role": "assistant", "content": "In the 1960s, but hey, who's counting?"}]} diff --git a/examples/fine-tuning/main.go b/examples/fine-tuning/main.go new file mode 100644 index 0000000..5c3a744 --- /dev/null +++ b/examples/fine-tuning/main.go @@ -0,0 +1,84 @@ +package main + +import ( + "context" + "fmt" + "io" + "os" + "time" + + "github.com/openai/openai-go" +) + +func main() { + client := openai.NewClient() + ctx := context.Background() + + fmt.Println("==> Uploading file") + + data, err := os.Open("./fine-tuning-data.jsonl") + file, err := client.Files.New(ctx, openai.FileNewParams{ + File: openai.F[io.Reader](data), + Purpose: openai.F(openai.FileNewParamsPurposeFineTune), + }) + if err != nil { + panic(err) + } + fmt.Printf("Uploaded file with ID: %s\n", file.ID) + + fmt.Println("Waiting for file to be processed") + for { + file, err = client.Files.Get(ctx, file.ID) + if err != nil { + panic(err) + } + fmt.Printf("File status: %s\n", file.Status) + if file.Status == "processed" { + break + } + time.Sleep(time.Second) + } + + fmt.Println("") + fmt.Println("==> Starting fine-tuning") + fineTune, err := client.FineTuning.Jobs.New(ctx, openai.FineTuningJobNewParams{ + Model: openai.F(openai.FineTuningJobNewParamsModelGPT3_5Turbo), + TrainingFile: openai.F(file.ID), + }) + if err != nil { + panic(err) + } + fmt.Printf("Fine-tuning ID: %s\n", fineTune.ID) + + fmt.Println("") + fmt.Println("==> Track fine-tuning progress:") + + events := make(map[string]openai.FineTuningJobEvent) + + for fineTune.Status == "running" || fineTune.Status == "queued" || fineTune.Status == "validating_files" { + fineTune, err = client.FineTuning.Jobs.Get(ctx, fineTune.ID) + if err != nil { + panic(err) + } + fmt.Println(fineTune.Status) + + page, err := client.FineTuning.Jobs.ListEvents(ctx, fineTune.ID, openai.FineTuningJobListEventsParams{ + Limit: openai.Int(100), + }) + if err != nil { + panic(err) + } + + for i := len(page.Data) - 1; i >= 0; i-- { + event := page.Data[i] + if _, exists := events[event.ID]; exists { + continue + } + events[event.ID] = event + timestamp := time.Unix(int64(event.CreatedAt), 0) + fmt.Printf("- %s: %s\n", timestamp.Format(time.Kitchen), event.Message) + } + + time.Sleep(5 * time.Second) + } +} diff --git a/examples/go.mod b/examples/go.mod new file mode 100644 index 0000000..be82650 --- /dev/null +++ b/examples/go.mod @@ -0,0 +1,19 @@ +module github.com/openai/openai-go/examples + +replace github.com/openai/openai-go => ../ + +go 1.22.4 + +require ( + github.com/ebitengine/oto/v3 v3.2.0 + github.com/openai/openai-go v0.0.0-00010101000000-000000000000 +) + +require ( + github.com/ebitengine/purego v0.7.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + golang.org/x/sys v0.22.0 // indirect +) diff --git a/examples/go.sum b/examples/go.sum new file mode 100644 index 0000000..447fab2 --- /dev/null +++ b/examples/go.sum @@ -0,0 +1,16 @@ +github.com/ebitengine/oto/v3 v3.2.0 h1:FuggTJTSI3/3hEYwZEIN0CZVXYT29ZOdCu+z/f4QjTw= +github.com/ebitengine/oto/v3 v3.2.0/go.mod h1:dOKXShvy1EQbIXhXPFcKLargdnFqH0RjptecvyAxhyw= +github.com/ebitengine/purego v0.7.0 h1:HPZpl61edMGCEW6XK2nsR6+7AnJ3unUxpTZBkkIXnMc= +github.com/ebitengine/purego v0.7.0/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/examples/image-generation/main.go b/examples/image-generation/main.go new file mode 100644 index 0000000..533e6f4 --- /dev/null +++ b/examples/image-generation/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "context" + "encoding/base64" + "os" + + "github.com/openai/openai-go" +) + +func main() { + client := openai.NewClient() + + ctx := context.Background() + + prompt := "A cute robot in a forest of trees." + + print("> ") + println(prompt) + println() + + // Image URL + + image, err := client.Images.Generate(ctx, openai.ImageGenerateParams{ + Prompt: openai.String(prompt), + Model: openai.F(openai.ImageGenerateParamsModelDallE3), + ResponseFormat: openai.F(openai.ImageGenerateParamsResponseFormatURL), + N: openai.Int(1), + }) + if err != nil { + panic(err) + } + println("Image URL:") + println(image.Data[0].URL) + println() + + // Base64 + + image, err = client.Images.Generate(ctx, openai.ImageGenerateParams{ + Prompt: openai.String(prompt), + Model: openai.F(openai.ImageGenerateParamsModelDallE3), + ResponseFormat: openai.F(openai.ImageGenerateParamsResponseFormatB64JSON), + N: openai.Int(1), + }) + if err != nil { + panic(err) + } + println("Image Base64 Length:") + println(len(image.Data[0].B64JSON)) + println() + + imageBytes, err := base64.StdEncoding.DecodeString(image.Data[0].B64JSON) + if err != nil { + panic(err) + } + + dest := "./image.png" + println("Writing image to " + dest) + err = os.WriteFile(dest, imageBytes, 0755) + if err != nil { + panic(err) + } +} diff --git a/go.mod b/go.mod index 1e064e6..a487ea3 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,27 @@ module github.com/openai/openai-go go 1.19 require ( - github.com/google/uuid v1.3.0 // indirect - github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/gjson v1.14.4 + github.com/tidwall/sjson v1.2.5 +) + +require ( + github.com/google/uuid v1.6.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - github.com/tidwall/sjson v1.2.5 // indirect +) + +require ( + // NOTE: these dependencies are only used for the `azure` subpackage. + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + golang.org/x/crypto v0.25.0 // indirect + golang.org/x/net v0.27.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/text v0.16.0 // indirect ) diff --git a/go.sum b/go.sum index 569e555..240415f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,22 @@ -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -10,3 +27,13 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/apijson/decoder.go b/internal/apijson/decoder.go index e1b21b7..deb0bac 100644 --- a/internal/apijson/decoder.go +++ b/internal/apijson/decoder.go @@ -214,31 +214,17 @@ func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc { decoders = append(decoders, decoder) } return func(n gjson.Result, v reflect.Value, state *decoderState) error { - // If there is a discriminator match, circumvent the exactness logic entirely - for idx, variant := range unionEntry.variants { - decoder := decoders[idx] - if variant.TypeFilter != n.Type { - continue - } - - if len(unionEntry.discriminatorKey) != 0 { - discriminatorValue := n.Get(unionEntry.discriminatorKey).Value() - if discriminatorValue == variant.DiscriminatorValue { - inner := reflect.New(variant.Type).Elem() - err := decoder(n, inner, state) - v.Set(inner) - return err - } - } - } - // Set bestExactness to worse than loose bestExactness := loose - 1 + for idx, variant := range unionEntry.variants { decoder := decoders[idx] if variant.TypeFilter != n.Type { continue } + if len(unionEntry.discriminatorKey) != 0 && n.Get(unionEntry.discriminatorKey).Value() != variant.DiscriminatorValue { + continue + } sub := decoderState{strict: state.strict, exactness: exact} inner := reflect.New(variant.Type).Elem() err := decoder(n, inner, &sub) @@ -339,58 +325,62 @@ func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc { func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { // map of json field name to struct field decoders decoderFields := map[string]decoderField{} - anonymousDecoders := []decoderField{} extraDecoder := (*decoderField)(nil) inlineDecoder := (*decoderField)(nil) - for i := 0; i < t.NumField(); i++ { - idx := []int{i} - field := t.FieldByIndex(idx) - if !field.IsExported() { - continue - } - // If this is an embedded struct, traverse one level deeper to extract - // the fields and get their encoders as well. - if field.Anonymous { - anonymousDecoders = append(anonymousDecoders, decoderField{ - fn: d.typeDecoder(field.Type), - idx: idx[:], - }) - continue - } - // If json tag is not present, then we skip, which is intentionally - // different behavior from the stdlib. - ptag, ok := parseJSONStructTag(field) - if !ok { - continue - } - // We only want to support unexported fields if they're tagged with - // `extras` because that field shouldn't be part of the public API. - if ptag.extras { - extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} - continue - } - if ptag.inline { - inlineDecoder = &decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} - continue - } - if ptag.metadata { - continue - } + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectFieldDecoders func(r reflect.Type, index []int) + collectFieldDecoders = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the fields and get their encoders as well. + if field.Anonymous { + collectFieldDecoders(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported fields if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} + continue + } + if ptag.inline && len(index) == 0 { + inlineDecoder = &decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + continue + } + if ptag.metadata { + continue + } - oldFormat := d.dateFormat - dateFormat, ok := parseFormatStructTag(field) - if ok { - switch dateFormat { - case "date-time": - d.dateFormat = time.RFC3339 - case "date": - d.dateFormat = "2006-01-02" + oldFormat := d.dateFormat + dateFormat, ok := parseFormatStructTag(field) + if ok { + switch dateFormat { + case "date-time": + d.dateFormat = time.RFC3339 + case "date": + d.dateFormat = "2006-01-02" + } } + decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + d.dateFormat = oldFormat } - decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} - d.dateFormat = oldFormat } + collectFieldDecoders(t, []int{}) return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { if field := value.FieldByName("JSON"); field.IsValid() { @@ -399,11 +389,6 @@ func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { } } - for _, decoder := range anonymousDecoders { - // ignore errors - decoder.fn(node, value.FieldByIndex(decoder.idx), state) - } - if inlineDecoder != nil { var meta Field dest := value.FieldByIndex(inlineDecoder.idx) diff --git a/internal/apijson/json_test.go b/internal/apijson/json_test.go index 72bc4c2..43cea30 100644 --- a/internal/apijson/json_test.go +++ b/internal/apijson/json_test.go @@ -48,32 +48,10 @@ type TypedAdditionalProperties struct { ExtraFields map[string]int `json:"-,extras"` } -type EmbeddedStruct struct { - A bool `json:"a"` - B string `json:"b"` - - JSON EmbeddedStructJSON -} - -type EmbeddedStructJSON struct { - A Field - B Field - ExtraFields map[string]Field - raw string -} - type EmbeddedStructs struct { - EmbeddedStruct - A *int `json:"a"` + AdditionalProperties + A *int `json:"number2"` ExtraFields map[string]interface{} `json:"-,extras"` - - JSON EmbeddedStructsJSON -} - -type EmbeddedStructsJSON struct { - A Field - ExtraFields map[string]Field - raw string } type Recursive struct { @@ -354,34 +332,9 @@ var tests = map[string]struct { }, }, - "embedded_struct": { - `{"a":1,"b":"bar"}`, - EmbeddedStructs{ - EmbeddedStruct: EmbeddedStruct{ - A: true, - B: "bar", - JSON: EmbeddedStructJSON{ - A: Field{raw: `1`, status: valid}, - B: Field{raw: `"bar"`, status: valid}, - raw: `{"a":1,"b":"bar"}`, - }, - }, - A: P(1), - ExtraFields: map[string]interface{}{"b": "bar"}, - JSON: EmbeddedStructsJSON{ - A: Field{raw: `1`, status: valid}, - ExtraFields: map[string]Field{ - "b": {raw: `"bar"`, status: valid}, - }, - raw: `{"a":1,"b":"bar"}`, - }, - }, - }, - "recursive_struct": { `{"child":{"name":"Alex"},"name":"Robert"}`, - Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, - }, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, }, "metadata_coerce": { `{"a":"12","b":"12","c":null,"extra_typed":12,"extra_untyped":{"foo":"bar"}}`,