diff --git a/client/client.go b/client/client.go index 60ba447..ddccf0b 100644 --- a/client/client.go +++ b/client/client.go @@ -3,10 +3,7 @@ package client import ( - bytes "bytes" context "context" - json "encoding/json" - errors "errors" fmt "fmt" v2 "github.com/cohere-ai/cohere-go/v2" connectors "github.com/cohere-ai/cohere-go/v2/connectors" @@ -14,17 +11,17 @@ import ( datasets "github.com/cohere-ai/cohere-go/v2/datasets" embedjobs "github.com/cohere-ai/cohere-go/v2/embedjobs" finetuningclient "github.com/cohere-ai/cohere-go/v2/finetuning/client" + internal "github.com/cohere-ai/cohere-go/v2/internal" models "github.com/cohere-ai/cohere-go/v2/models" option "github.com/cohere-ai/cohere-go/v2/option" v2v2 "github.com/cohere-ai/cohere-go/v2/v2" - io "io" http "net/http" os "os" ) type Client struct { baseURL string - caller *core.Caller + caller *internal.Caller header http.Header V2 *v2v2.Client @@ -42,8 +39,8 @@ func NewClient(opts ...option.RequestOption) *Client { } return &Client{ baseURL: options.BaseURL, - caller: core.NewCaller( - &core.CallerParams{ + caller: internal.NewCaller( + &internal.CallerParams{ Client: options.HTTPClient, MaxAttempts: options.MaxAttempts, }, @@ -58,7 +55,8 @@ func NewClient(opts ...option.RequestOption) *Client { } } -// Generates a text response to a user message. +// Generates a streamed text response to a user message. +// // To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/docs/chat-api). func (c *Client) ChatStream( ctx context.Context, @@ -66,123 +64,96 @@ func (c *Client) ChatStream( opts ...option.RequestOption, ) (*core.Stream[v2.StreamedChatResponse], error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/chat" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) if request.Accepts != nil { headers.Add("Accepts", fmt.Sprintf("%v", request.Accepts)) } - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } - streamer := core.NewStreamer[v2.StreamedChatResponse](c.caller) + streamer := internal.NewStreamer[v2.StreamedChatResponse](c.caller) return streamer.Stream( ctx, - &core.StreamParams{ + &internal.StreamParams{ URL: endpointURL, Method: http.MethodPost, + Headers: headers, MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, - Headers: headers, Client: options.HTTPClient, Request: request, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ) } @@ -195,124 +166,97 @@ func (c *Client) Chat( opts ...option.RequestOption, ) (*v2.NonStreamedChatResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/chat" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) if request.Accepts != nil { headers.Add("Accepts", fmt.Sprintf("%v", request.Accepts)) } - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.NonStreamedChatResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -330,120 +274,93 @@ func (c *Client) GenerateStream( opts ...option.RequestOption, ) (*core.Stream[v2.GenerateStreamedResponse], error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/generate" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } - streamer := core.NewStreamer[v2.GenerateStreamedResponse](c.caller) + streamer := internal.NewStreamer[v2.GenerateStreamedResponse](c.caller) return streamer.Stream( ctx, - &core.StreamParams{ + &internal.StreamParams{ URL: endpointURL, Method: http.MethodPost, + Headers: headers, MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, - Headers: headers, Client: options.HTTPClient, Request: request, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ) } @@ -458,121 +375,94 @@ func (c *Client) Generate( opts ...option.RequestOption, ) (*v2.Generation, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/generate" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.Generation if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -591,121 +481,94 @@ func (c *Client) Embed( opts ...option.RequestOption, ) (*v2.EmbedResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/embed" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.EmbedResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -720,121 +583,94 @@ func (c *Client) Rerank( opts ...option.RequestOption, ) (*v2.RerankResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/rerank" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.RerankResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -850,121 +686,94 @@ func (c *Client) Classify( opts ...option.RequestOption, ) (*v2.ClassifyResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/classify" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.ClassifyResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -982,121 +791,94 @@ func (c *Client) Summarize( opts ...option.RequestOption, ) (*v2.SummarizeResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/summarize" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.SummarizeResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -1111,121 +893,94 @@ func (c *Client) Tokenize( opts ...option.RequestOption, ) (*v2.TokenizeResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/tokenize" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.TokenizeResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -1240,121 +995,94 @@ func (c *Client) Detokenize( opts ...option.RequestOption, ) (*v2.DetokenizeResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/detokenize" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.DetokenizeResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -1368,120 +1096,92 @@ func (c *Client) CheckApiKey( opts ...option.RequestOption, ) (*v2.CheckApiKeyResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/check-api-key" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.CheckApiKeyResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err diff --git a/connectors.go b/connectors.go index c7423c4..25e3dbc 100644 --- a/connectors.go +++ b/connectors.go @@ -2,6 +2,13 @@ package api +import ( + json "encoding/json" + fmt "fmt" + internal "github.com/cohere-ai/cohere-go/v2/internal" + time "time" +) + type CreateConnectorRequest struct { // A human-readable name for the connector. Name string `json:"name" url:"-"` @@ -33,6 +40,699 @@ type ConnectorsOAuthAuthorizeRequest struct { AfterTokenRedirect *string `json:"-" url:"after_token_redirect,omitempty"` } +// The token_type specifies the way the token is passed in the Authorization header. Valid values are "bearer", "basic", and "noscheme". +type AuthTokenType string + +const ( + AuthTokenTypeBearer AuthTokenType = "bearer" + AuthTokenTypeBasic AuthTokenType = "basic" + AuthTokenTypeNoscheme AuthTokenType = "noscheme" +) + +func NewAuthTokenTypeFromString(s string) (AuthTokenType, error) { + switch s { + case "bearer": + return AuthTokenTypeBearer, nil + case "basic": + return AuthTokenTypeBasic, nil + case "noscheme": + return AuthTokenTypeNoscheme, nil + } + var t AuthTokenType + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (a AuthTokenType) Ptr() *AuthTokenType { + return &a +} + +// A connector allows you to integrate data sources with the '/chat' endpoint to create grounded generations with citations to the data source. +// documents to help answer users. +type Connector struct { + // The unique identifier of the connector (used in both `/connectors` & `/chat` endpoints). + // This is automatically created from the name of the connector upon registration. + Id string `json:"id" url:"id"` + // The organization to which this connector belongs. This is automatically set to + // the organization of the user who created the connector. + OrganizationId *string `json:"organization_id,omitempty" url:"organization_id,omitempty"` + // A human-readable name for the connector. + Name string `json:"name" url:"name"` + // A description of the connector. + Description *string `json:"description,omitempty" url:"description,omitempty"` + // The URL of the connector that will be used to search for documents. + Url *string `json:"url,omitempty" url:"url,omitempty"` + // The UTC time at which the connector was created. + CreatedAt time.Time `json:"created_at" url:"created_at"` + // The UTC time at which the connector was last updated. + UpdatedAt time.Time `json:"updated_at" url:"updated_at"` + // A list of fields to exclude from the prompt (fields remain in the document). + Excludes []string `json:"excludes,omitempty" url:"excludes,omitempty"` + // The type of authentication/authorization used by the connector. Possible values: [oauth, service_auth] + AuthType *string `json:"auth_type,omitempty" url:"auth_type,omitempty"` + // The OAuth 2.0 configuration for the connector. + Oauth *ConnectorOAuth `json:"oauth,omitempty" url:"oauth,omitempty"` + // The OAuth status for the user making the request. One of ["valid", "expired", ""]. Empty string (field is omitted) means the user has not authorized the connector yet. + AuthStatus *ConnectorAuthStatus `json:"auth_status,omitempty" url:"auth_status,omitempty"` + // Whether the connector is active or not. + Active *bool `json:"active,omitempty" url:"active,omitempty"` + // Whether a chat request should continue or not if the request to this connector fails. + ContinueOnFailure *bool `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *Connector) GetId() string { + if c == nil { + return "" + } + return c.Id +} + +func (c *Connector) GetOrganizationId() *string { + if c == nil { + return nil + } + return c.OrganizationId +} + +func (c *Connector) GetName() string { + if c == nil { + return "" + } + return c.Name +} + +func (c *Connector) GetDescription() *string { + if c == nil { + return nil + } + return c.Description +} + +func (c *Connector) GetUrl() *string { + if c == nil { + return nil + } + return c.Url +} + +func (c *Connector) GetCreatedAt() time.Time { + if c == nil { + return time.Time{} + } + return c.CreatedAt +} + +func (c *Connector) GetUpdatedAt() time.Time { + if c == nil { + return time.Time{} + } + return c.UpdatedAt +} + +func (c *Connector) GetExcludes() []string { + if c == nil { + return nil + } + return c.Excludes +} + +func (c *Connector) GetAuthType() *string { + if c == nil { + return nil + } + return c.AuthType +} + +func (c *Connector) GetOauth() *ConnectorOAuth { + if c == nil { + return nil + } + return c.Oauth +} + +func (c *Connector) GetAuthStatus() *ConnectorAuthStatus { + if c == nil { + return nil + } + return c.AuthStatus +} + +func (c *Connector) GetActive() *bool { + if c == nil { + return nil + } + return c.Active +} + +func (c *Connector) GetContinueOnFailure() *bool { + if c == nil { + return nil + } + return c.ContinueOnFailure +} + +func (c *Connector) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *Connector) UnmarshalJSON(data []byte) error { + type embed Connector + var unmarshaler = struct { + embed + CreatedAt *internal.DateTime `json:"created_at"` + UpdatedAt *internal.DateTime `json:"updated_at"` + }{ + embed: embed(*c), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + *c = Connector(unmarshaler.embed) + c.CreatedAt = unmarshaler.CreatedAt.Time() + c.UpdatedAt = unmarshaler.UpdatedAt.Time() + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *Connector) MarshalJSON() ([]byte, error) { + type embed Connector + var marshaler = struct { + embed + CreatedAt *internal.DateTime `json:"created_at"` + UpdatedAt *internal.DateTime `json:"updated_at"` + }{ + embed: embed(*c), + CreatedAt: internal.NewDateTime(c.CreatedAt), + UpdatedAt: internal.NewDateTime(c.UpdatedAt), + } + return json.Marshal(marshaler) +} + +func (c *Connector) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// The OAuth status for the user making the request. One of ["valid", "expired", ""]. Empty string (field is omitted) means the user has not authorized the connector yet. +type ConnectorAuthStatus string + +const ( + ConnectorAuthStatusValid ConnectorAuthStatus = "valid" + ConnectorAuthStatusExpired ConnectorAuthStatus = "expired" +) + +func NewConnectorAuthStatusFromString(s string) (ConnectorAuthStatus, error) { + switch s { + case "valid": + return ConnectorAuthStatusValid, nil + case "expired": + return ConnectorAuthStatusExpired, nil + } + var t ConnectorAuthStatus + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (c ConnectorAuthStatus) Ptr() *ConnectorAuthStatus { + return &c +} + +type ConnectorOAuth struct { + // The OAuth 2.0 client ID. This field is encrypted at rest. + ClientId *string `json:"client_id,omitempty" url:"client_id,omitempty"` + // The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response. + ClientSecret *string `json:"client_secret,omitempty" url:"client_secret,omitempty"` + // The OAuth 2.0 /authorize endpoint to use when users authorize the connector. + AuthorizeUrl string `json:"authorize_url" url:"authorize_url"` + // The OAuth 2.0 /token endpoint to use when users authorize the connector. + TokenUrl string `json:"token_url" url:"token_url"` + // The OAuth scopes to request when users authorize the connector. + Scope *string `json:"scope,omitempty" url:"scope,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ConnectorOAuth) GetClientId() *string { + if c == nil { + return nil + } + return c.ClientId +} + +func (c *ConnectorOAuth) GetClientSecret() *string { + if c == nil { + return nil + } + return c.ClientSecret +} + +func (c *ConnectorOAuth) GetAuthorizeUrl() string { + if c == nil { + return "" + } + return c.AuthorizeUrl +} + +func (c *ConnectorOAuth) GetTokenUrl() string { + if c == nil { + return "" + } + return c.TokenUrl +} + +func (c *ConnectorOAuth) GetScope() *string { + if c == nil { + return nil + } + return c.Scope +} + +func (c *ConnectorOAuth) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ConnectorOAuth) UnmarshalJSON(data []byte) error { + type unmarshaler ConnectorOAuth + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ConnectorOAuth(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ConnectorOAuth) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type CreateConnectorOAuth struct { + // The OAuth 2.0 client ID. This fields is encrypted at rest. + ClientId *string `json:"client_id,omitempty" url:"client_id,omitempty"` + // The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response. + ClientSecret *string `json:"client_secret,omitempty" url:"client_secret,omitempty"` + // The OAuth 2.0 /authorize endpoint to use when users authorize the connector. + AuthorizeUrl *string `json:"authorize_url,omitempty" url:"authorize_url,omitempty"` + // The OAuth 2.0 /token endpoint to use when users authorize the connector. + TokenUrl *string `json:"token_url,omitempty" url:"token_url,omitempty"` + // The OAuth scopes to request when users authorize the connector. + Scope *string `json:"scope,omitempty" url:"scope,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CreateConnectorOAuth) GetClientId() *string { + if c == nil { + return nil + } + return c.ClientId +} + +func (c *CreateConnectorOAuth) GetClientSecret() *string { + if c == nil { + return nil + } + return c.ClientSecret +} + +func (c *CreateConnectorOAuth) GetAuthorizeUrl() *string { + if c == nil { + return nil + } + return c.AuthorizeUrl +} + +func (c *CreateConnectorOAuth) GetTokenUrl() *string { + if c == nil { + return nil + } + return c.TokenUrl +} + +func (c *CreateConnectorOAuth) GetScope() *string { + if c == nil { + return nil + } + return c.Scope +} + +func (c *CreateConnectorOAuth) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CreateConnectorOAuth) UnmarshalJSON(data []byte) error { + type unmarshaler CreateConnectorOAuth + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CreateConnectorOAuth(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CreateConnectorOAuth) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type CreateConnectorResponse struct { + Connector *Connector `json:"connector,omitempty" url:"connector,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CreateConnectorResponse) GetConnector() *Connector { + if c == nil { + return nil + } + return c.Connector +} + +func (c *CreateConnectorResponse) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CreateConnectorResponse) UnmarshalJSON(data []byte) error { + type unmarshaler CreateConnectorResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CreateConnectorResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CreateConnectorResponse) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type CreateConnectorServiceAuth struct { + Type AuthTokenType `json:"type" url:"type"` + // The token that will be used in the HTTP Authorization header when making requests to the connector. This field is encrypted at rest and never returned in a response. + Token string `json:"token" url:"token"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CreateConnectorServiceAuth) GetType() AuthTokenType { + if c == nil { + return "" + } + return c.Type +} + +func (c *CreateConnectorServiceAuth) GetToken() string { + if c == nil { + return "" + } + return c.Token +} + +func (c *CreateConnectorServiceAuth) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CreateConnectorServiceAuth) UnmarshalJSON(data []byte) error { + type unmarshaler CreateConnectorServiceAuth + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CreateConnectorServiceAuth(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CreateConnectorServiceAuth) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type DeleteConnectorResponse = map[string]interface{} + +type GetConnectorResponse struct { + Connector *Connector `json:"connector,omitempty" url:"connector,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (g *GetConnectorResponse) GetConnector() *Connector { + if g == nil { + return nil + } + return g.Connector +} + +func (g *GetConnectorResponse) GetExtraProperties() map[string]interface{} { + return g.extraProperties +} + +func (g *GetConnectorResponse) UnmarshalJSON(data []byte) error { + type unmarshaler GetConnectorResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *g = GetConnectorResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) + if err != nil { + return err + } + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) + return nil +} + +func (g *GetConnectorResponse) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(g); err == nil { + return value + } + return fmt.Sprintf("%#v", g) +} + +type ListConnectorsResponse struct { + Connectors []*Connector `json:"connectors,omitempty" url:"connectors,omitempty"` + // Total number of connectors. + TotalCount *float64 `json:"total_count,omitempty" url:"total_count,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (l *ListConnectorsResponse) GetConnectors() []*Connector { + if l == nil { + return nil + } + return l.Connectors +} + +func (l *ListConnectorsResponse) GetTotalCount() *float64 { + if l == nil { + return nil + } + return l.TotalCount +} + +func (l *ListConnectorsResponse) GetExtraProperties() map[string]interface{} { + return l.extraProperties +} + +func (l *ListConnectorsResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ListConnectorsResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = ListConnectorsResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *l) + if err != nil { + return err + } + l.extraProperties = extraProperties + l.rawJSON = json.RawMessage(data) + return nil +} + +func (l *ListConnectorsResponse) String() string { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + +type OAuthAuthorizeResponse struct { + // The OAuth 2.0 redirect url. Redirect the user to this url to authorize the connector. + RedirectUrl *string `json:"redirect_url,omitempty" url:"redirect_url,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (o *OAuthAuthorizeResponse) GetRedirectUrl() *string { + if o == nil { + return nil + } + return o.RedirectUrl +} + +func (o *OAuthAuthorizeResponse) GetExtraProperties() map[string]interface{} { + return o.extraProperties +} + +func (o *OAuthAuthorizeResponse) UnmarshalJSON(data []byte) error { + type unmarshaler OAuthAuthorizeResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *o = OAuthAuthorizeResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *o) + if err != nil { + return err + } + o.extraProperties = extraProperties + o.rawJSON = json.RawMessage(data) + return nil +} + +func (o *OAuthAuthorizeResponse) String() string { + if len(o.rawJSON) > 0 { + if value, err := internal.StringifyJSON(o.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(o); err == nil { + return value + } + return fmt.Sprintf("%#v", o) +} + +type UpdateConnectorResponse struct { + Connector *Connector `json:"connector,omitempty" url:"connector,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (u *UpdateConnectorResponse) GetConnector() *Connector { + if u == nil { + return nil + } + return u.Connector +} + +func (u *UpdateConnectorResponse) GetExtraProperties() map[string]interface{} { + return u.extraProperties +} + +func (u *UpdateConnectorResponse) UnmarshalJSON(data []byte) error { + type unmarshaler UpdateConnectorResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *u = UpdateConnectorResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *u) + if err != nil { + return err + } + u.extraProperties = extraProperties + u.rawJSON = json.RawMessage(data) + return nil +} + +func (u *UpdateConnectorResponse) String() string { + if len(u.rawJSON) > 0 { + if value, err := internal.StringifyJSON(u.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(u); err == nil { + return value + } + return fmt.Sprintf("%#v", u) +} + type UpdateConnectorRequest struct { // A human-readable name for the connector. Name *string `json:"name,omitempty" url:"-"` diff --git a/connectors/client.go b/connectors/client.go index 42c42e9..d1c20d0 100644 --- a/connectors/client.go +++ b/connectors/client.go @@ -3,21 +3,18 @@ package connectors import ( - bytes "bytes" context "context" - json "encoding/json" - errors "errors" v2 "github.com/cohere-ai/cohere-go/v2" core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" option "github.com/cohere-ai/cohere-go/v2/option" - io "io" http "net/http" os "os" ) type Client struct { baseURL string - caller *core.Caller + caller *internal.Caller header http.Header } @@ -28,8 +25,8 @@ func NewClient(opts ...option.RequestOption) *Client { } return &Client{ baseURL: options.BaseURL, - caller: core.NewCaller( - &core.CallerParams{ + caller: internal.NewCaller( + &internal.CallerParams{ Client: options.HTTPClient, MaxAttempts: options.MaxAttempts, }, @@ -45,128 +42,99 @@ func (c *Client) List( opts ...option.RequestOption, ) (*v2.ListConnectorsResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/connectors" - - queryParams, err := core.QueryValues(request) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.ListConnectorsResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -181,121 +149,94 @@ func (c *Client) Create( opts ...option.RequestOption, ) (*v2.CreateConnectorResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/connectors" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.CreateConnectorResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -311,120 +252,95 @@ func (c *Client) Get( opts ...option.RequestOption, ) (*v2.GetConnectorResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/connectors/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/connectors/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.GetConnectorResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -440,120 +356,95 @@ func (c *Client) Delete( opts ...option.RequestOption, ) (v2.DeleteConnectorResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/connectors/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/connectors/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response v2.DeleteConnectorResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodDelete, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -570,121 +461,97 @@ func (c *Client) Update( opts ...option.RequestOption, ) (*v2.UpdateConnectorResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/connectors/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/connectors/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.UpdateConnectorResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPatch, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -692,7 +559,7 @@ func (c *Client) Update( return response, nil } -// Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information. +// Authorize the connector with the given ID for the connector oauth app. See ['Connector Authentication'](https://docs.cohere.com/docs/connector-authentication) for more information. func (c *Client) OAuthAuthorize( ctx context.Context, // The ID of the connector to authorize. @@ -701,128 +568,102 @@ func (c *Client) OAuthAuthorize( opts ...option.RequestOption, ) (*v2.OAuthAuthorizeResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/connectors/%v/oauth/authorize", id) - - queryParams, err := core.QueryValues(request) + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/connectors/%v/oauth/authorize", + id, + ) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.OAuthAuthorizeResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err diff --git a/core/api_error.go b/core/api_error.go new file mode 100644 index 0000000..dc4190c --- /dev/null +++ b/core/api_error.go @@ -0,0 +1,42 @@ +package core + +import "fmt" + +// APIError is a lightweight wrapper around the standard error +// interface that preserves the status code from the RPC, if any. +type APIError struct { + err error + + StatusCode int `json:"-"` +} + +// NewAPIError constructs a new API error. +func NewAPIError(statusCode int, err error) *APIError { + return &APIError{ + err: err, + StatusCode: statusCode, + } +} + +// Unwrap returns the underlying error. This also makes the error compatible +// with errors.As and errors.Is. +func (a *APIError) Unwrap() error { + if a == nil { + return nil + } + return a.err +} + +// Error returns the API error's message. +func (a *APIError) Error() string { + if a == nil || (a.err == nil && a.StatusCode == 0) { + return "" + } + if a.err == nil { + return fmt.Sprintf("%d", a.StatusCode) + } + if a.StatusCode == 0 { + return a.err.Error() + } + return fmt.Sprintf("%d: %s", a.StatusCode, a.err.Error()) +} diff --git a/core/http.go b/core/http.go new file mode 100644 index 0000000..b553350 --- /dev/null +++ b/core/http.go @@ -0,0 +1,8 @@ +package core + +import "net/http" + +// HTTPClient is an interface for a subset of the *http.Client. +type HTTPClient interface { + Do(*http.Request) (*http.Response, error) +} diff --git a/core/request_option.go b/core/request_option.go index 4ce0dc1..b434073 100644 --- a/core/request_option.go +++ b/core/request_option.go @@ -61,7 +61,8 @@ func (r *RequestOptions) cloneHeader() http.Header { headers := r.HTTPHeader.Clone() headers.Set("X-Fern-Language", "Go") headers.Set("X-Fern-SDK-Name", "github.com/cohere-ai/cohere-go/v2") - headers.Set("X-Fern-SDK-Version", "v2.12.3") + headers.Set("X-Fern-SDK-Version", "v2.12.4") + headers.Set("User-Agent", "github.com/cohere-ai/cohere-go/2.12.4") return headers } diff --git a/core/stream.go b/core/stream.go index 30e374d..f92c629 100644 --- a/core/stream.go +++ b/core/stream.go @@ -2,121 +2,14 @@ package core import ( "bufio" - "context" "encoding/json" "io" "net/http" - "net/url" "strings" ) -const ( - // DefaultDataPrefix is the default prefix used for SSE streaming. - DefaultSSEDataPrefix = "data: " - - // DefaultTerminator is the default terminator used for SSE streaming. - DefaultSSETerminator = "[DONE]" - - // The default stream delimiter used to split messages. - defaultStreamDelimiter = '\n' -) - -// Streamer calls APIs and streams responses using a *Stream. -type Streamer[T any] struct { - client HTTPClient - retrier *Retrier -} - -// NewStreamer returns a new *Streamer backed by the given caller's HTTP client. -func NewStreamer[T any](caller *Caller) *Streamer[T] { - return &Streamer[T]{ - client: caller.client, - retrier: caller.retrier, - } -} - -// StreamParams represents the parameters used to issue an API streaming call. -type StreamParams struct { - URL string - Method string - Prefix string - Delimiter string - Terminator string - MaxAttempts uint - Headers http.Header - BodyProperties map[string]interface{} - QueryParameters url.Values - Client HTTPClient - Request interface{} - ErrorDecoder ErrorDecoder -} - -// Stream issues an API streaming call according to the given stream parameters. -func (s *Streamer[T]) Stream(ctx context.Context, params *StreamParams) (*Stream[T], error) { - url := buildURL(params.URL, params.QueryParameters) - req, err := newRequest( - ctx, - url, - params.Method, - params.Headers, - params.Request, - params.BodyProperties, - ) - if err != nil { - return nil, err - } - - // If the call has been cancelled, don't issue the request. - if err := ctx.Err(); err != nil { - return nil, err - } - - client := s.client - if params.Client != nil { - // Use the HTTP client scoped to the request. - client = params.Client - } - - var retryOptions []RetryOption - if params.MaxAttempts > 0 { - retryOptions = append(retryOptions, WithMaxAttempts(params.MaxAttempts)) - } - - resp, err := s.retrier.Run( - client.Do, - req, - params.ErrorDecoder, - retryOptions..., - ) - if err != nil { - return nil, err - } - - // Check if the call was cancelled before we return the error - // associated with the call and/or unmarshal the response data. - if err := ctx.Err(); err != nil { - defer resp.Body.Close() - return nil, err - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer resp.Body.Close() - return nil, decodeError(resp, params.ErrorDecoder) - } - - var opts []StreamOption - if params.Delimiter != "" { - opts = append(opts, WithDelimiter(params.Delimiter)) - } - if params.Prefix != "" { - opts = append(opts, WithPrefix(params.Prefix)) - } - if params.Terminator != "" { - opts = append(opts, WithTerminator(params.Terminator)) - } - - return NewStream[T](resp, opts...), nil -} +// defaultStreamDelimiter is the default stream delimiter used to split messages. +const defaultStreamDelimiter = '\n' // Stream represents a stream of messages sent from a server. type Stream[T any] struct { @@ -262,18 +155,21 @@ func (s *scannerStreamReader) ReadFromStream() ([]byte, error) { } func (s *scannerStreamReader) parse(bytes []byte) (int, []byte, error) { - var start int + var startIndex int if s.options != nil && s.options.prefix != "" { if i := strings.Index(string(bytes), s.options.prefix); i >= 0 { - start = i + len(s.options.prefix) + startIndex = i + len(s.options.prefix) } } - data := bytes[start:] - if i := strings.Index(string(data), s.options.delimiter); i >= 0 { - data = data[:i+len(s.options.delimiter)] + data := bytes[startIndex:] + delimIndex := strings.Index(string(data), s.options.delimiter) + if delimIndex < 0 { + return startIndex + len(data), data, nil } - n := start + len(data) + len(s.options.delimiter) - return n, data, nil + endIndex := delimIndex + len(s.options.delimiter) + parsedData := data[:endIndex] + n := startIndex + endIndex + return n, parsedData, nil } func (s *scannerStreamReader) isTerminated(bytes []byte) bool { diff --git a/datasets.go b/datasets.go index 7e55fbe..22b2ef2 100644 --- a/datasets.go +++ b/datasets.go @@ -5,7 +5,7 @@ package api import ( json "encoding/json" fmt "fmt" - core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" time "time" ) @@ -45,12 +45,940 @@ type DatasetsListRequest struct { ValidationStatus *DatasetValidationStatus `json:"-" url:"validationStatus,omitempty"` } +type ChatDataMetrics struct { + // The sum of all turns of valid train examples. + NumTrainTurns *int64 `json:"num_train_turns,omitempty" url:"num_train_turns,omitempty"` + // The sum of all turns of valid eval examples. + NumEvalTurns *int64 `json:"num_eval_turns,omitempty" url:"num_eval_turns,omitempty"` + // The preamble of this dataset. + Preamble *string `json:"preamble,omitempty" url:"preamble,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatDataMetrics) GetNumTrainTurns() *int64 { + if c == nil { + return nil + } + return c.NumTrainTurns +} + +func (c *ChatDataMetrics) GetNumEvalTurns() *int64 { + if c == nil { + return nil + } + return c.NumEvalTurns +} + +func (c *ChatDataMetrics) GetPreamble() *string { + if c == nil { + return nil + } + return c.Preamble +} + +func (c *ChatDataMetrics) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatDataMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler ChatDataMetrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatDataMetrics(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatDataMetrics) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ClassifyDataMetrics struct { + LabelMetrics []*LabelMetric `json:"label_metrics,omitempty" url:"label_metrics,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ClassifyDataMetrics) GetLabelMetrics() []*LabelMetric { + if c == nil { + return nil + } + return c.LabelMetrics +} + +func (c *ClassifyDataMetrics) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ClassifyDataMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler ClassifyDataMetrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ClassifyDataMetrics(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ClassifyDataMetrics) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type Dataset struct { + // The dataset ID + Id string `json:"id" url:"id"` + // The name of the dataset + Name string `json:"name" url:"name"` + // The creation date + CreatedAt time.Time `json:"created_at" url:"created_at"` + // The last update date + UpdatedAt time.Time `json:"updated_at" url:"updated_at"` + DatasetType DatasetType `json:"dataset_type" url:"dataset_type"` + ValidationStatus DatasetValidationStatus `json:"validation_status" url:"validation_status"` + // Errors found during validation + ValidationError *string `json:"validation_error,omitempty" url:"validation_error,omitempty"` + // the avro schema of the dataset + Schema *string `json:"schema,omitempty" url:"schema,omitempty"` + RequiredFields []string `json:"required_fields,omitempty" url:"required_fields,omitempty"` + PreserveFields []string `json:"preserve_fields,omitempty" url:"preserve_fields,omitempty"` + // the underlying files that make up the dataset + DatasetParts []*DatasetPart `json:"dataset_parts,omitempty" url:"dataset_parts,omitempty"` + // warnings found during validation + ValidationWarnings []string `json:"validation_warnings,omitempty" url:"validation_warnings,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (d *Dataset) GetId() string { + if d == nil { + return "" + } + return d.Id +} + +func (d *Dataset) GetName() string { + if d == nil { + return "" + } + return d.Name +} + +func (d *Dataset) GetCreatedAt() time.Time { + if d == nil { + return time.Time{} + } + return d.CreatedAt +} + +func (d *Dataset) GetUpdatedAt() time.Time { + if d == nil { + return time.Time{} + } + return d.UpdatedAt +} + +func (d *Dataset) GetDatasetType() DatasetType { + if d == nil { + return "" + } + return d.DatasetType +} + +func (d *Dataset) GetValidationStatus() DatasetValidationStatus { + if d == nil { + return "" + } + return d.ValidationStatus +} + +func (d *Dataset) GetValidationError() *string { + if d == nil { + return nil + } + return d.ValidationError +} + +func (d *Dataset) GetSchema() *string { + if d == nil { + return nil + } + return d.Schema +} + +func (d *Dataset) GetRequiredFields() []string { + if d == nil { + return nil + } + return d.RequiredFields +} + +func (d *Dataset) GetPreserveFields() []string { + if d == nil { + return nil + } + return d.PreserveFields +} + +func (d *Dataset) GetDatasetParts() []*DatasetPart { + if d == nil { + return nil + } + return d.DatasetParts +} + +func (d *Dataset) GetValidationWarnings() []string { + if d == nil { + return nil + } + return d.ValidationWarnings +} + +func (d *Dataset) GetExtraProperties() map[string]interface{} { + return d.extraProperties +} + +func (d *Dataset) UnmarshalJSON(data []byte) error { + type embed Dataset + var unmarshaler = struct { + embed + CreatedAt *internal.DateTime `json:"created_at"` + UpdatedAt *internal.DateTime `json:"updated_at"` + }{ + embed: embed(*d), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + *d = Dataset(unmarshaler.embed) + d.CreatedAt = unmarshaler.CreatedAt.Time() + d.UpdatedAt = unmarshaler.UpdatedAt.Time() + extraProperties, err := internal.ExtractExtraProperties(data, *d) + if err != nil { + return err + } + d.extraProperties = extraProperties + d.rawJSON = json.RawMessage(data) + return nil +} + +func (d *Dataset) MarshalJSON() ([]byte, error) { + type embed Dataset + var marshaler = struct { + embed + CreatedAt *internal.DateTime `json:"created_at"` + UpdatedAt *internal.DateTime `json:"updated_at"` + }{ + embed: embed(*d), + CreatedAt: internal.NewDateTime(d.CreatedAt), + UpdatedAt: internal.NewDateTime(d.UpdatedAt), + } + return json.Marshal(marshaler) +} + +func (d *Dataset) String() string { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(d); err == nil { + return value + } + return fmt.Sprintf("%#v", d) +} + +type DatasetPart struct { + // The dataset part ID + Id string `json:"id" url:"id"` + // The name of the dataset part + Name string `json:"name" url:"name"` + // The download url of the file + Url *string `json:"url,omitempty" url:"url,omitempty"` + // The index of the file + Index *int `json:"index,omitempty" url:"index,omitempty"` + // The size of the file in bytes + SizeBytes *int `json:"size_bytes,omitempty" url:"size_bytes,omitempty"` + // The number of rows in the file + NumRows *int `json:"num_rows,omitempty" url:"num_rows,omitempty"` + // The download url of the original file + OriginalUrl *string `json:"original_url,omitempty" url:"original_url,omitempty"` + // The first few rows of the parsed file + Samples []string `json:"samples,omitempty" url:"samples,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (d *DatasetPart) GetId() string { + if d == nil { + return "" + } + return d.Id +} + +func (d *DatasetPart) GetName() string { + if d == nil { + return "" + } + return d.Name +} + +func (d *DatasetPart) GetUrl() *string { + if d == nil { + return nil + } + return d.Url +} + +func (d *DatasetPart) GetIndex() *int { + if d == nil { + return nil + } + return d.Index +} + +func (d *DatasetPart) GetSizeBytes() *int { + if d == nil { + return nil + } + return d.SizeBytes +} + +func (d *DatasetPart) GetNumRows() *int { + if d == nil { + return nil + } + return d.NumRows +} + +func (d *DatasetPart) GetOriginalUrl() *string { + if d == nil { + return nil + } + return d.OriginalUrl +} + +func (d *DatasetPart) GetSamples() []string { + if d == nil { + return nil + } + return d.Samples +} + +func (d *DatasetPart) GetExtraProperties() map[string]interface{} { + return d.extraProperties +} + +func (d *DatasetPart) UnmarshalJSON(data []byte) error { + type unmarshaler DatasetPart + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *d = DatasetPart(value) + extraProperties, err := internal.ExtractExtraProperties(data, *d) + if err != nil { + return err + } + d.extraProperties = extraProperties + d.rawJSON = json.RawMessage(data) + return nil +} + +func (d *DatasetPart) String() string { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(d); err == nil { + return value + } + return fmt.Sprintf("%#v", d) +} + +// The type of the dataset +type DatasetType string + +const ( + DatasetTypeEmbedInput DatasetType = "embed-input" + DatasetTypeEmbedResult DatasetType = "embed-result" + DatasetTypeClusterResult DatasetType = "cluster-result" + DatasetTypeClusterOutliers DatasetType = "cluster-outliers" + DatasetTypeRerankerFinetuneInput DatasetType = "reranker-finetune-input" + DatasetTypeSingleLabelClassificationFinetuneInput DatasetType = "single-label-classification-finetune-input" + DatasetTypeChatFinetuneInput DatasetType = "chat-finetune-input" + DatasetTypeMultiLabelClassificationFinetuneInput DatasetType = "multi-label-classification-finetune-input" +) + +func NewDatasetTypeFromString(s string) (DatasetType, error) { + switch s { + case "embed-input": + return DatasetTypeEmbedInput, nil + case "embed-result": + return DatasetTypeEmbedResult, nil + case "cluster-result": + return DatasetTypeClusterResult, nil + case "cluster-outliers": + return DatasetTypeClusterOutliers, nil + case "reranker-finetune-input": + return DatasetTypeRerankerFinetuneInput, nil + case "single-label-classification-finetune-input": + return DatasetTypeSingleLabelClassificationFinetuneInput, nil + case "chat-finetune-input": + return DatasetTypeChatFinetuneInput, nil + case "multi-label-classification-finetune-input": + return DatasetTypeMultiLabelClassificationFinetuneInput, nil + } + var t DatasetType + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (d DatasetType) Ptr() *DatasetType { + return &d +} + +// The validation status of the dataset +type DatasetValidationStatus string + +const ( + DatasetValidationStatusUnknown DatasetValidationStatus = "unknown" + DatasetValidationStatusQueued DatasetValidationStatus = "queued" + DatasetValidationStatusProcessing DatasetValidationStatus = "processing" + DatasetValidationStatusFailed DatasetValidationStatus = "failed" + DatasetValidationStatusValidated DatasetValidationStatus = "validated" + DatasetValidationStatusSkipped DatasetValidationStatus = "skipped" +) + +func NewDatasetValidationStatusFromString(s string) (DatasetValidationStatus, error) { + switch s { + case "unknown": + return DatasetValidationStatusUnknown, nil + case "queued": + return DatasetValidationStatusQueued, nil + case "processing": + return DatasetValidationStatusProcessing, nil + case "failed": + return DatasetValidationStatusFailed, nil + case "validated": + return DatasetValidationStatusValidated, nil + case "skipped": + return DatasetValidationStatusSkipped, nil + } + var t DatasetValidationStatus + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (d DatasetValidationStatus) Ptr() *DatasetValidationStatus { + return &d +} + +type FinetuneDatasetMetrics struct { + // The number of tokens of valid examples that can be used for training. + TrainableTokenCount *int64 `json:"trainable_token_count,omitempty" url:"trainable_token_count,omitempty"` + // The overall number of examples. + TotalExamples *int64 `json:"total_examples,omitempty" url:"total_examples,omitempty"` + // The number of training examples. + TrainExamples *int64 `json:"train_examples,omitempty" url:"train_examples,omitempty"` + // The size in bytes of all training examples. + TrainSizeBytes *int64 `json:"train_size_bytes,omitempty" url:"train_size_bytes,omitempty"` + // Number of evaluation examples. + EvalExamples *int64 `json:"eval_examples,omitempty" url:"eval_examples,omitempty"` + // The size in bytes of all eval examples. + EvalSizeBytes *int64 `json:"eval_size_bytes,omitempty" url:"eval_size_bytes,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (f *FinetuneDatasetMetrics) GetTrainableTokenCount() *int64 { + if f == nil { + return nil + } + return f.TrainableTokenCount +} + +func (f *FinetuneDatasetMetrics) GetTotalExamples() *int64 { + if f == nil { + return nil + } + return f.TotalExamples +} + +func (f *FinetuneDatasetMetrics) GetTrainExamples() *int64 { + if f == nil { + return nil + } + return f.TrainExamples +} + +func (f *FinetuneDatasetMetrics) GetTrainSizeBytes() *int64 { + if f == nil { + return nil + } + return f.TrainSizeBytes +} + +func (f *FinetuneDatasetMetrics) GetEvalExamples() *int64 { + if f == nil { + return nil + } + return f.EvalExamples +} + +func (f *FinetuneDatasetMetrics) GetEvalSizeBytes() *int64 { + if f == nil { + return nil + } + return f.EvalSizeBytes +} + +func (f *FinetuneDatasetMetrics) GetExtraProperties() map[string]interface{} { + return f.extraProperties +} + +func (f *FinetuneDatasetMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler FinetuneDatasetMetrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *f = FinetuneDatasetMetrics(value) + extraProperties, err := internal.ExtractExtraProperties(data, *f) + if err != nil { + return err + } + f.extraProperties = extraProperties + f.rawJSON = json.RawMessage(data) + return nil +} + +func (f *FinetuneDatasetMetrics) String() string { + if len(f.rawJSON) > 0 { + if value, err := internal.StringifyJSON(f.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(f); err == nil { + return value + } + return fmt.Sprintf("%#v", f) +} + +type LabelMetric struct { + // Total number of examples for this label + TotalExamples *int64 `json:"total_examples,omitempty" url:"total_examples,omitempty"` + // value of the label + Label *string `json:"label,omitempty" url:"label,omitempty"` + // samples for this label + Samples []string `json:"samples,omitempty" url:"samples,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (l *LabelMetric) GetTotalExamples() *int64 { + if l == nil { + return nil + } + return l.TotalExamples +} + +func (l *LabelMetric) GetLabel() *string { + if l == nil { + return nil + } + return l.Label +} + +func (l *LabelMetric) GetSamples() []string { + if l == nil { + return nil + } + return l.Samples +} + +func (l *LabelMetric) GetExtraProperties() map[string]interface{} { + return l.extraProperties +} + +func (l *LabelMetric) UnmarshalJSON(data []byte) error { + type unmarshaler LabelMetric + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = LabelMetric(value) + extraProperties, err := internal.ExtractExtraProperties(data, *l) + if err != nil { + return err + } + l.extraProperties = extraProperties + l.rawJSON = json.RawMessage(data) + return nil +} + +func (l *LabelMetric) String() string { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + +type Metrics struct { + FinetuneDatasetMetrics *FinetuneDatasetMetrics `json:"finetune_dataset_metrics,omitempty" url:"finetune_dataset_metrics,omitempty"` + EmbedData *MetricsEmbedData `json:"embed_data,omitempty" url:"embed_data,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (m *Metrics) GetFinetuneDatasetMetrics() *FinetuneDatasetMetrics { + if m == nil { + return nil + } + return m.FinetuneDatasetMetrics +} + +func (m *Metrics) GetEmbedData() *MetricsEmbedData { + if m == nil { + return nil + } + return m.EmbedData +} + +func (m *Metrics) GetExtraProperties() map[string]interface{} { + return m.extraProperties +} + +func (m *Metrics) UnmarshalJSON(data []byte) error { + type unmarshaler Metrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *m = Metrics(value) + extraProperties, err := internal.ExtractExtraProperties(data, *m) + if err != nil { + return err + } + m.extraProperties = extraProperties + m.rawJSON = json.RawMessage(data) + return nil +} + +func (m *Metrics) String() string { + if len(m.rawJSON) > 0 { + if value, err := internal.StringifyJSON(m.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(m); err == nil { + return value + } + return fmt.Sprintf("%#v", m) +} + +type MetricsEmbedData struct { + // the fields in the dataset + Fields []*MetricsEmbedDataFieldsItem `json:"fields,omitempty" url:"fields,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (m *MetricsEmbedData) GetFields() []*MetricsEmbedDataFieldsItem { + if m == nil { + return nil + } + return m.Fields +} + +func (m *MetricsEmbedData) GetExtraProperties() map[string]interface{} { + return m.extraProperties +} + +func (m *MetricsEmbedData) UnmarshalJSON(data []byte) error { + type unmarshaler MetricsEmbedData + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *m = MetricsEmbedData(value) + extraProperties, err := internal.ExtractExtraProperties(data, *m) + if err != nil { + return err + } + m.extraProperties = extraProperties + m.rawJSON = json.RawMessage(data) + return nil +} + +func (m *MetricsEmbedData) String() string { + if len(m.rawJSON) > 0 { + if value, err := internal.StringifyJSON(m.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(m); err == nil { + return value + } + return fmt.Sprintf("%#v", m) +} + +type MetricsEmbedDataFieldsItem struct { + // the name of the field + Name *string `json:"name,omitempty" url:"name,omitempty"` + // the number of times the field appears in the dataset + Count *float64 `json:"count,omitempty" url:"count,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (m *MetricsEmbedDataFieldsItem) GetName() *string { + if m == nil { + return nil + } + return m.Name +} + +func (m *MetricsEmbedDataFieldsItem) GetCount() *float64 { + if m == nil { + return nil + } + return m.Count +} + +func (m *MetricsEmbedDataFieldsItem) GetExtraProperties() map[string]interface{} { + return m.extraProperties +} + +func (m *MetricsEmbedDataFieldsItem) UnmarshalJSON(data []byte) error { + type unmarshaler MetricsEmbedDataFieldsItem + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *m = MetricsEmbedDataFieldsItem(value) + extraProperties, err := internal.ExtractExtraProperties(data, *m) + if err != nil { + return err + } + m.extraProperties = extraProperties + m.rawJSON = json.RawMessage(data) + return nil +} + +func (m *MetricsEmbedDataFieldsItem) String() string { + if len(m.rawJSON) > 0 { + if value, err := internal.StringifyJSON(m.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(m); err == nil { + return value + } + return fmt.Sprintf("%#v", m) +} + +type ParseInfo struct { + Separator *string `json:"separator,omitempty" url:"separator,omitempty"` + Delimiter *string `json:"delimiter,omitempty" url:"delimiter,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (p *ParseInfo) GetSeparator() *string { + if p == nil { + return nil + } + return p.Separator +} + +func (p *ParseInfo) GetDelimiter() *string { + if p == nil { + return nil + } + return p.Delimiter +} + +func (p *ParseInfo) GetExtraProperties() map[string]interface{} { + return p.extraProperties +} + +func (p *ParseInfo) UnmarshalJSON(data []byte) error { + type unmarshaler ParseInfo + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *p = ParseInfo(value) + extraProperties, err := internal.ExtractExtraProperties(data, *p) + if err != nil { + return err + } + p.extraProperties = extraProperties + p.rawJSON = json.RawMessage(data) + return nil +} + +func (p *ParseInfo) String() string { + if len(p.rawJSON) > 0 { + if value, err := internal.StringifyJSON(p.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(p); err == nil { + return value + } + return fmt.Sprintf("%#v", p) +} + +type RerankerDataMetrics struct { + // The number of training queries. + NumTrainQueries *int64 `json:"num_train_queries,omitempty" url:"num_train_queries,omitempty"` + // The sum of all relevant passages of valid training examples. + NumTrainRelevantPassages *int64 `json:"num_train_relevant_passages,omitempty" url:"num_train_relevant_passages,omitempty"` + // The sum of all hard negatives of valid training examples. + NumTrainHardNegatives *int64 `json:"num_train_hard_negatives,omitempty" url:"num_train_hard_negatives,omitempty"` + // The number of evaluation queries. + NumEvalQueries *int64 `json:"num_eval_queries,omitempty" url:"num_eval_queries,omitempty"` + // The sum of all relevant passages of valid eval examples. + NumEvalRelevantPassages *int64 `json:"num_eval_relevant_passages,omitempty" url:"num_eval_relevant_passages,omitempty"` + // The sum of all hard negatives of valid eval examples. + NumEvalHardNegatives *int64 `json:"num_eval_hard_negatives,omitempty" url:"num_eval_hard_negatives,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (r *RerankerDataMetrics) GetNumTrainQueries() *int64 { + if r == nil { + return nil + } + return r.NumTrainQueries +} + +func (r *RerankerDataMetrics) GetNumTrainRelevantPassages() *int64 { + if r == nil { + return nil + } + return r.NumTrainRelevantPassages +} + +func (r *RerankerDataMetrics) GetNumTrainHardNegatives() *int64 { + if r == nil { + return nil + } + return r.NumTrainHardNegatives +} + +func (r *RerankerDataMetrics) GetNumEvalQueries() *int64 { + if r == nil { + return nil + } + return r.NumEvalQueries +} + +func (r *RerankerDataMetrics) GetNumEvalRelevantPassages() *int64 { + if r == nil { + return nil + } + return r.NumEvalRelevantPassages +} + +func (r *RerankerDataMetrics) GetNumEvalHardNegatives() *int64 { + if r == nil { + return nil + } + return r.NumEvalHardNegatives +} + +func (r *RerankerDataMetrics) GetExtraProperties() map[string]interface{} { + return r.extraProperties +} + +func (r *RerankerDataMetrics) UnmarshalJSON(data []byte) error { + type unmarshaler RerankerDataMetrics + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *r = RerankerDataMetrics(value) + extraProperties, err := internal.ExtractExtraProperties(data, *r) + if err != nil { + return err + } + r.extraProperties = extraProperties + r.rawJSON = json.RawMessage(data) + return nil +} + +func (r *RerankerDataMetrics) String() string { + if len(r.rawJSON) > 0 { + if value, err := internal.StringifyJSON(r.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(r); err == nil { + return value + } + return fmt.Sprintf("%#v", r) +} + type DatasetsCreateResponse struct { // The dataset ID Id *string `json:"id,omitempty" url:"id,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (d *DatasetsCreateResponse) GetId() *string { + if d == nil { + return nil + } + return d.Id } func (d *DatasetsCreateResponse) GetExtraProperties() map[string]interface{} { @@ -64,24 +992,96 @@ func (d *DatasetsCreateResponse) UnmarshalJSON(data []byte) error { return err } *d = DatasetsCreateResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) + extraProperties, err := internal.ExtractExtraProperties(data, *d) if err != nil { return err } d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) + d.rawJSON = json.RawMessage(data) return nil } func (d *DatasetsCreateResponse) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(d); err == nil { + if value, err := internal.StringifyJSON(d); err == nil { + return value + } + return fmt.Sprintf("%#v", d) +} + +// the underlying files that make up the dataset +type DatasetsCreateResponseDatasetPartsItem struct { + // the name of the dataset part + Name *string `json:"name,omitempty" url:"name,omitempty"` + // the number of rows in the dataset part + NumRows *float64 `json:"num_rows,omitempty" url:"num_rows,omitempty"` + Samples []string `json:"samples,omitempty" url:"samples,omitempty"` + // the kind of dataset part + PartKind *string `json:"part_kind,omitempty" url:"part_kind,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (d *DatasetsCreateResponseDatasetPartsItem) GetName() *string { + if d == nil { + return nil + } + return d.Name +} + +func (d *DatasetsCreateResponseDatasetPartsItem) GetNumRows() *float64 { + if d == nil { + return nil + } + return d.NumRows +} + +func (d *DatasetsCreateResponseDatasetPartsItem) GetSamples() []string { + if d == nil { + return nil + } + return d.Samples +} + +func (d *DatasetsCreateResponseDatasetPartsItem) GetPartKind() *string { + if d == nil { + return nil + } + return d.PartKind +} + +func (d *DatasetsCreateResponseDatasetPartsItem) GetExtraProperties() map[string]interface{} { + return d.extraProperties +} + +func (d *DatasetsCreateResponseDatasetPartsItem) UnmarshalJSON(data []byte) error { + type unmarshaler DatasetsCreateResponseDatasetPartsItem + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *d = DatasetsCreateResponseDatasetPartsItem(value) + extraProperties, err := internal.ExtractExtraProperties(data, *d) + if err != nil { + return err + } + d.extraProperties = extraProperties + d.rawJSON = json.RawMessage(data) + return nil +} + +func (d *DatasetsCreateResponseDatasetPartsItem) String() string { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(d); err == nil { return value } return fmt.Sprintf("%#v", d) @@ -91,7 +1091,14 @@ type DatasetsGetResponse struct { Dataset *Dataset `json:"dataset,omitempty" url:"dataset,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (d *DatasetsGetResponse) GetDataset() *Dataset { + if d == nil { + return nil + } + return d.Dataset } func (d *DatasetsGetResponse) GetExtraProperties() map[string]interface{} { @@ -105,24 +1112,22 @@ func (d *DatasetsGetResponse) UnmarshalJSON(data []byte) error { return err } *d = DatasetsGetResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) + extraProperties, err := internal.ExtractExtraProperties(data, *d) if err != nil { return err } d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) + d.rawJSON = json.RawMessage(data) return nil } func (d *DatasetsGetResponse) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(d); err == nil { + if value, err := internal.StringifyJSON(d); err == nil { return value } return fmt.Sprintf("%#v", d) @@ -133,7 +1138,14 @@ type DatasetsGetUsageResponse struct { OrganizationUsage *int64 `json:"organization_usage,omitempty" url:"organization_usage,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (d *DatasetsGetUsageResponse) GetOrganizationUsage() *int64 { + if d == nil { + return nil + } + return d.OrganizationUsage } func (d *DatasetsGetUsageResponse) GetExtraProperties() map[string]interface{} { @@ -147,24 +1159,22 @@ func (d *DatasetsGetUsageResponse) UnmarshalJSON(data []byte) error { return err } *d = DatasetsGetUsageResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) + extraProperties, err := internal.ExtractExtraProperties(data, *d) if err != nil { return err } d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) + d.rawJSON = json.RawMessage(data) return nil } func (d *DatasetsGetUsageResponse) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(d); err == nil { + if value, err := internal.StringifyJSON(d); err == nil { return value } return fmt.Sprintf("%#v", d) @@ -174,7 +1184,14 @@ type DatasetsListResponse struct { Datasets []*Dataset `json:"datasets,omitempty" url:"datasets,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (d *DatasetsListResponse) GetDatasets() []*Dataset { + if d == nil { + return nil + } + return d.Datasets } func (d *DatasetsListResponse) GetExtraProperties() map[string]interface{} { @@ -188,24 +1205,22 @@ func (d *DatasetsListResponse) UnmarshalJSON(data []byte) error { return err } *d = DatasetsListResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) + extraProperties, err := internal.ExtractExtraProperties(data, *d) if err != nil { return err } d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) + d.rawJSON = json.RawMessage(data) return nil } func (d *DatasetsListResponse) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(d); err == nil { + if value, err := internal.StringifyJSON(d); err == nil { return value } return fmt.Sprintf("%#v", d) diff --git a/datasets/client.go b/datasets/client.go index c07d3d2..fe702cf 100644 --- a/datasets/client.go +++ b/datasets/client.go @@ -3,22 +3,19 @@ package datasets import ( - bytes "bytes" context "context" - json "encoding/json" - errors "errors" v2 "github.com/cohere-ai/cohere-go/v2" core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" option "github.com/cohere-ai/cohere-go/v2/option" io "io" - multipart "mime/multipart" http "net/http" os "os" ) type Client struct { baseURL string - caller *core.Caller + caller *internal.Caller header http.Header } @@ -29,8 +26,8 @@ func NewClient(opts ...option.RequestOption) *Client { } return &Client{ baseURL: options.BaseURL, - caller: core.NewCaller( - &core.CallerParams{ + caller: internal.NewCaller( + &internal.CallerParams{ Client: options.HTTPClient, MaxAttempts: options.MaxAttempts, }, @@ -46,128 +43,99 @@ func (c *Client) List( opts ...option.RequestOption, ) (*v2.DatasetsListResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/datasets" - - queryParams, err := core.QueryValues(request) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.DatasetsListResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -184,160 +152,113 @@ func (c *Client) Create( opts ...option.RequestOption, ) (*v2.DatasetsCreateResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/datasets" - - queryParams, err := core.QueryValues(request) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError - } - - var response *v2.DatasetsCreateResponse - requestBuffer := bytes.NewBuffer(nil) - writer := multipart.NewWriter(requestBuffer) - dataFilename := "data_filename" - if named, ok := data.(interface{ Name() string }); ok { - dataFilename = named.Name() - } - dataPart, err := writer.CreateFormFile("data", dataFilename) - if err != nil { - return nil, err + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } - if _, err := io.Copy(dataPart, data); err != nil { + writer := internal.NewMultipartWriter() + if err := writer.WriteFile("data", data); err != nil { return nil, err } if evalData != nil { - evalDataFilename := "evalData_filename" - if named, ok := evalData.(interface{ Name() string }); ok { - evalDataFilename = named.Name() - } - evalDataPart, err := writer.CreateFormFile("eval_data", evalDataFilename) - if err != nil { - return nil, err - } - if _, err := io.Copy(evalDataPart, evalData); err != nil { + if err := writer.WriteFile("eval_data", evalData); err != nil { return nil, err } } if err := writer.Close(); err != nil { return nil, err } - headers.Set("Content-Type", writer.FormDataContentType()) + headers.Set("Content-Type", writer.ContentType()) + var response *v2.DatasetsCreateResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, - Request: requestBuffer, + Request: writer.Buffer(), Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -351,120 +272,92 @@ func (c *Client) GetUsage( opts ...option.RequestOption, ) (*v2.DatasetsGetUsageResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/datasets/usage" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.DatasetsGetUsageResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -479,120 +372,95 @@ func (c *Client) Get( opts ...option.RequestOption, ) (*v2.DatasetsGetResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/datasets/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/datasets/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.DatasetsGetResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -607,120 +475,95 @@ func (c *Client) Delete( opts ...option.RequestOption, ) (map[string]interface{}, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/datasets/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/datasets/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response map[string]interface{} if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodDelete, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err diff --git a/embed_jobs.go b/embed_jobs.go index 297e6d7..13104cd 100644 --- a/embed_jobs.go +++ b/embed_jobs.go @@ -3,7 +3,10 @@ package api import ( + json "encoding/json" fmt "fmt" + internal "github.com/cohere-ai/cohere-go/v2/internal" + time "time" ) type CreateEmbedJobRequest struct { @@ -35,6 +38,298 @@ type CreateEmbedJobRequest struct { Truncate *CreateEmbedJobRequestTruncate `json:"truncate,omitempty" url:"-"` } +// Response from creating an embed job. +type CreateEmbedJobResponse struct { + JobId string `json:"job_id" url:"job_id"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CreateEmbedJobResponse) GetJobId() string { + if c == nil { + return "" + } + return c.JobId +} + +func (c *CreateEmbedJobResponse) GetMeta() *ApiMeta { + if c == nil { + return nil + } + return c.Meta +} + +func (c *CreateEmbedJobResponse) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CreateEmbedJobResponse) UnmarshalJSON(data []byte) error { + type unmarshaler CreateEmbedJobResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CreateEmbedJobResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CreateEmbedJobResponse) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type EmbedJob struct { + // ID of the embed job + JobId string `json:"job_id" url:"job_id"` + // The name of the embed job + Name *string `json:"name,omitempty" url:"name,omitempty"` + // The status of the embed job + Status EmbedJobStatus `json:"status" url:"status"` + // The creation date of the embed job + CreatedAt time.Time `json:"created_at" url:"created_at"` + // ID of the input dataset + InputDatasetId string `json:"input_dataset_id" url:"input_dataset_id"` + // ID of the resulting output dataset + OutputDatasetId *string `json:"output_dataset_id,omitempty" url:"output_dataset_id,omitempty"` + // ID of the model used to embed + Model string `json:"model" url:"model"` + // The truncation option used + Truncate EmbedJobTruncate `json:"truncate" url:"truncate"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (e *EmbedJob) GetJobId() string { + if e == nil { + return "" + } + return e.JobId +} + +func (e *EmbedJob) GetName() *string { + if e == nil { + return nil + } + return e.Name +} + +func (e *EmbedJob) GetStatus() EmbedJobStatus { + if e == nil { + return "" + } + return e.Status +} + +func (e *EmbedJob) GetCreatedAt() time.Time { + if e == nil { + return time.Time{} + } + return e.CreatedAt +} + +func (e *EmbedJob) GetInputDatasetId() string { + if e == nil { + return "" + } + return e.InputDatasetId +} + +func (e *EmbedJob) GetOutputDatasetId() *string { + if e == nil { + return nil + } + return e.OutputDatasetId +} + +func (e *EmbedJob) GetModel() string { + if e == nil { + return "" + } + return e.Model +} + +func (e *EmbedJob) GetTruncate() EmbedJobTruncate { + if e == nil { + return "" + } + return e.Truncate +} + +func (e *EmbedJob) GetMeta() *ApiMeta { + if e == nil { + return nil + } + return e.Meta +} + +func (e *EmbedJob) GetExtraProperties() map[string]interface{} { + return e.extraProperties +} + +func (e *EmbedJob) UnmarshalJSON(data []byte) error { + type embed EmbedJob + var unmarshaler = struct { + embed + CreatedAt *internal.DateTime `json:"created_at"` + }{ + embed: embed(*e), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + *e = EmbedJob(unmarshaler.embed) + e.CreatedAt = unmarshaler.CreatedAt.Time() + extraProperties, err := internal.ExtractExtraProperties(data, *e) + if err != nil { + return err + } + e.extraProperties = extraProperties + e.rawJSON = json.RawMessage(data) + return nil +} + +func (e *EmbedJob) MarshalJSON() ([]byte, error) { + type embed EmbedJob + var marshaler = struct { + embed + CreatedAt *internal.DateTime `json:"created_at"` + }{ + embed: embed(*e), + CreatedAt: internal.NewDateTime(e.CreatedAt), + } + return json.Marshal(marshaler) +} + +func (e *EmbedJob) String() string { + if len(e.rawJSON) > 0 { + if value, err := internal.StringifyJSON(e.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(e); err == nil { + return value + } + return fmt.Sprintf("%#v", e) +} + +// The status of the embed job +type EmbedJobStatus string + +const ( + EmbedJobStatusProcessing EmbedJobStatus = "processing" + EmbedJobStatusComplete EmbedJobStatus = "complete" + EmbedJobStatusCancelling EmbedJobStatus = "cancelling" + EmbedJobStatusCancelled EmbedJobStatus = "cancelled" + EmbedJobStatusFailed EmbedJobStatus = "failed" +) + +func NewEmbedJobStatusFromString(s string) (EmbedJobStatus, error) { + switch s { + case "processing": + return EmbedJobStatusProcessing, nil + case "complete": + return EmbedJobStatusComplete, nil + case "cancelling": + return EmbedJobStatusCancelling, nil + case "cancelled": + return EmbedJobStatusCancelled, nil + case "failed": + return EmbedJobStatusFailed, nil + } + var t EmbedJobStatus + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (e EmbedJobStatus) Ptr() *EmbedJobStatus { + return &e +} + +// The truncation option used +type EmbedJobTruncate string + +const ( + EmbedJobTruncateStart EmbedJobTruncate = "START" + EmbedJobTruncateEnd EmbedJobTruncate = "END" +) + +func NewEmbedJobTruncateFromString(s string) (EmbedJobTruncate, error) { + switch s { + case "START": + return EmbedJobTruncateStart, nil + case "END": + return EmbedJobTruncateEnd, nil + } + var t EmbedJobTruncate + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (e EmbedJobTruncate) Ptr() *EmbedJobTruncate { + return &e +} + +type ListEmbedJobResponse struct { + EmbedJobs []*EmbedJob `json:"embed_jobs,omitempty" url:"embed_jobs,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (l *ListEmbedJobResponse) GetEmbedJobs() []*EmbedJob { + if l == nil { + return nil + } + return l.EmbedJobs +} + +func (l *ListEmbedJobResponse) GetExtraProperties() map[string]interface{} { + return l.extraProperties +} + +func (l *ListEmbedJobResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ListEmbedJobResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = ListEmbedJobResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *l) + if err != nil { + return err + } + l.extraProperties = extraProperties + l.rawJSON = json.RawMessage(data) + return nil +} + +func (l *ListEmbedJobResponse) String() string { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + // One of `START|END` to specify how the API will handle inputs longer than the maximum token length. // // Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. diff --git a/embedjobs/client.go b/embedjobs/client.go index cb513db..7f7a459 100644 --- a/embedjobs/client.go +++ b/embedjobs/client.go @@ -3,21 +3,18 @@ package embedjobs import ( - bytes "bytes" context "context" - json "encoding/json" - errors "errors" v2 "github.com/cohere-ai/cohere-go/v2" core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" option "github.com/cohere-ai/cohere-go/v2/option" - io "io" http "net/http" os "os" ) type Client struct { baseURL string - caller *core.Caller + caller *internal.Caller header http.Header } @@ -28,8 +25,8 @@ func NewClient(opts ...option.RequestOption) *Client { } return &Client{ baseURL: options.BaseURL, - caller: core.NewCaller( - &core.CallerParams{ + caller: internal.NewCaller( + &internal.CallerParams{ Client: options.HTTPClient, MaxAttempts: options.MaxAttempts, }, @@ -44,120 +41,92 @@ func (c *Client) List( opts ...option.RequestOption, ) (*v2.ListEmbedJobResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/embed-jobs" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.ListEmbedJobResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -172,121 +141,94 @@ func (c *Client) Create( opts ...option.RequestOption, ) (*v2.CreateEmbedJobResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/embed-jobs" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.CreateEmbedJobResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -302,120 +244,95 @@ func (c *Client) Get( opts ...option.RequestOption, ) (*v2.EmbedJob, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/embed-jobs/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/embed-jobs/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.EmbedJob if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -431,118 +348,93 @@ func (c *Client) Cancel( opts ...option.RequestOption, ) error { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/embed-jobs/%v/cancel", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/embed-jobs/%v/cancel", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return err diff --git a/errors.go b/errors.go index 42e07bd..b68d277 100644 --- a/errors.go +++ b/errors.go @@ -7,6 +7,10 @@ import ( core "github.com/cohere-ai/cohere-go/v2/core" ) +// This error is returned when the request is not well formed. This could be because: +// - JSON is invalid +// - The request is missing required fields +// - The request contains an invalid combination of fields type BadRequestError struct { *core.APIError Body interface{} @@ -33,11 +37,11 @@ func (b *BadRequestError) Unwrap() error { // This error is returned when a request is cancelled by the user. type ClientClosedRequestError struct { *core.APIError - Body *ClientClosedRequestErrorBody + Body interface{} } func (c *ClientClosedRequestError) UnmarshalJSON(data []byte) error { - var body *ClientClosedRequestErrorBody + var body interface{} if err := json.Unmarshal(data, &body); err != nil { return err } @@ -54,6 +58,9 @@ func (c *ClientClosedRequestError) Unwrap() error { return c.APIError } +// This error indicates that the operation attempted to be performed is not allowed. This could be because: +// - The api token is invalid +// - The user does not have the necessary permissions type ForbiddenError struct { *core.APIError Body interface{} @@ -81,11 +88,11 @@ func (f *ForbiddenError) Unwrap() error { // - An internal services taking too long to respond type GatewayTimeoutError struct { *core.APIError - Body *GatewayTimeoutErrorBody + Body interface{} } func (g *GatewayTimeoutError) UnmarshalJSON(data []byte) error { - var body *GatewayTimeoutErrorBody + var body interface{} if err := json.Unmarshal(data, &body); err != nil { return err } @@ -102,6 +109,7 @@ func (g *GatewayTimeoutError) Unwrap() error { return g.APIError } +// This error is returned when an uncategorised internal server error occurs. type InternalServerError struct { *core.APIError Body interface{} @@ -125,6 +133,33 @@ func (i *InternalServerError) Unwrap() error { return i.APIError } +// This error is returned when a request or response contains a deny-listed token. +type InvalidTokenError struct { + *core.APIError + Body interface{} +} + +func (i *InvalidTokenError) UnmarshalJSON(data []byte) error { + var body interface{} + if err := json.Unmarshal(data, &body); err != nil { + return err + } + i.StatusCode = 498 + i.Body = body + return nil +} + +func (i *InvalidTokenError) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Body) +} + +func (i *InvalidTokenError) Unwrap() error { + return i.APIError +} + +// This error is returned when a resource is not found. This could be because: +// - The endpoint does not exist +// - The resource does not exist eg model id, dataset id type NotFoundError struct { *core.APIError Body interface{} @@ -151,11 +186,11 @@ func (n *NotFoundError) Unwrap() error { // This error is returned when the requested feature is not implemented. type NotImplementedError struct { *core.APIError - Body *NotImplementedErrorBody + Body interface{} } func (n *NotImplementedError) UnmarshalJSON(data []byte) error { - var body *NotImplementedErrorBody + var body interface{} if err := json.Unmarshal(data, &body); err != nil { return err } @@ -172,6 +207,8 @@ func (n *NotImplementedError) Unwrap() error { return n.APIError } +// This error is returned when the service is unavailable. This could be due to: +// - Too many users trying to access the service at the same time type ServiceUnavailableError struct { *core.APIError Body interface{} @@ -198,11 +235,11 @@ func (s *ServiceUnavailableError) Unwrap() error { // Too many requests type TooManyRequestsError struct { *core.APIError - Body *TooManyRequestsErrorBody + Body interface{} } func (t *TooManyRequestsError) UnmarshalJSON(data []byte) error { - var body *TooManyRequestsErrorBody + var body interface{} if err := json.Unmarshal(data, &body); err != nil { return err } @@ -219,6 +256,9 @@ func (t *TooManyRequestsError) Unwrap() error { return t.APIError } +// This error indicates that the operation attempted to be performed is not allowed. This could be because: +// - The api token is invalid +// - The user does not have the necessary permissions type UnauthorizedError struct { *core.APIError Body interface{} @@ -248,11 +288,11 @@ func (u *UnauthorizedError) Unwrap() error { // - The request contains an invalid combination of fields type UnprocessableEntityError struct { *core.APIError - Body *UnprocessableEntityErrorBody + Body interface{} } func (u *UnprocessableEntityError) UnmarshalJSON(data []byte) error { - var body *UnprocessableEntityErrorBody + var body interface{} if err := json.Unmarshal(data, &body); err != nil { return err } diff --git a/file_param.go b/file_param.go new file mode 100644 index 0000000..737abb9 --- /dev/null +++ b/file_param.go @@ -0,0 +1,41 @@ +package api + +import ( + "io" +) + +// FileParam is a file type suitable for multipart/form-data uploads. +type FileParam struct { + io.Reader + filename string + contentType string +} + +// FileParamOption adapts the behavior of the FileParam. No options are +// implemented yet, but this interface allows for future extensibility. +type FileParamOption interface { + apply() +} + +// NewFileParam returns a *FileParam type suitable for multipart/form-data uploads. All file +// upload endpoints accept a simple io.Reader, which is usually created by opening a file +// via os.Open. +// +// However, some endpoints require additional metadata about the file such as a specific +// Content-Type or custom filename. FileParam makes it easier to create the correct type +// signature for these endpoints. +func NewFileParam( + reader io.Reader, + filename string, + contentType string, + opts ...FileParamOption, +) *FileParam { + return &FileParam{ + Reader: reader, + filename: filename, + contentType: contentType, + } +} + +func (f *FileParam) Name() string { return f.filename } +func (f *FileParam) ContentType() string { return f.contentType } diff --git a/finetuning.go b/finetuning.go index 368ba4c..8331c18 100644 --- a/finetuning.go +++ b/finetuning.go @@ -4,13 +4,14 @@ package api import ( json "encoding/json" - core "github.com/cohere-ai/cohere-go/v2/core" finetuning "github.com/cohere-ai/cohere-go/v2/finetuning" + internal "github.com/cohere-ai/cohere-go/v2/internal" time "time" ) type FinetuningListEventsRequest struct { - // Maximum number of results to be returned by the server. If 0, defaults to 50. + // Maximum number of results to be returned by the server. If 0, defaults to + // 50. PageSize *int `json:"-" url:"page_size,omitempty"` // Request a specific page of the list results. PageToken *string `json:"-" url:"page_token,omitempty"` @@ -19,13 +20,13 @@ type FinetuningListEventsRequest struct { // " desc" to the field name. For example: "created_at desc,name". // // Supported sorting fields: - // - // - created_at (default) + // - created_at (default) OrderBy *string `json:"-" url:"order_by,omitempty"` } type FinetuningListFinetunedModelsRequest struct { - // Maximum number of results to be returned by the server. If 0, defaults to 50. + // Maximum number of results to be returned by the server. If 0, defaults to + // 50. PageSize *int `json:"-" url:"page_size,omitempty"` // Request a specific page of the list results. PageToken *string `json:"-" url:"page_token,omitempty"` @@ -34,13 +35,13 @@ type FinetuningListFinetunedModelsRequest struct { // " desc" to the field name. For example: "created_at desc,name". // // Supported sorting fields: - // - // - created_at (default) + // - created_at (default) OrderBy *string `json:"-" url:"order_by,omitempty"` } type FinetuningListTrainingStepMetricsRequest struct { - // Maximum number of results to be returned by the server. If 0, defaults to 50. + // Maximum number of results to be returned by the server. If 0, defaults to + // 50. PageSize *int `json:"-" url:"page_size,omitempty"` // Request a specific page of the list results. PageToken *string `json:"-" url:"page_token,omitempty"` @@ -81,16 +82,16 @@ func (f *FinetuningUpdateFinetunedModelRequest) MarshalJSON() ([]byte, error) { type embed FinetuningUpdateFinetunedModelRequest var marshaler = struct { embed - CreatedAt *core.DateTime `json:"created_at,omitempty"` - UpdatedAt *core.DateTime `json:"updated_at,omitempty"` - CompletedAt *core.DateTime `json:"completed_at,omitempty"` - LastUsed *core.DateTime `json:"last_used,omitempty"` + CreatedAt *internal.DateTime `json:"created_at,omitempty"` + UpdatedAt *internal.DateTime `json:"updated_at,omitempty"` + CompletedAt *internal.DateTime `json:"completed_at,omitempty"` + LastUsed *internal.DateTime `json:"last_used,omitempty"` }{ embed: embed(*f), - CreatedAt: core.NewOptionalDateTime(f.CreatedAt), - UpdatedAt: core.NewOptionalDateTime(f.UpdatedAt), - CompletedAt: core.NewOptionalDateTime(f.CompletedAt), - LastUsed: core.NewOptionalDateTime(f.LastUsed), + CreatedAt: internal.NewOptionalDateTime(f.CreatedAt), + UpdatedAt: internal.NewOptionalDateTime(f.UpdatedAt), + CompletedAt: internal.NewOptionalDateTime(f.CompletedAt), + LastUsed: internal.NewOptionalDateTime(f.LastUsed), } return json.Marshal(marshaler) } diff --git a/finetuning/client/client.go b/finetuning/client/client.go index 1da4cbe..c74ad7c 100644 --- a/finetuning/client/client.go +++ b/finetuning/client/client.go @@ -3,22 +3,19 @@ package client import ( - bytes "bytes" context "context" - json "encoding/json" - errors "errors" v2 "github.com/cohere-ai/cohere-go/v2" core "github.com/cohere-ai/cohere-go/v2/core" finetuning "github.com/cohere-ai/cohere-go/v2/finetuning" + internal "github.com/cohere-ai/cohere-go/v2/internal" option "github.com/cohere-ai/cohere-go/v2/option" - io "io" http "net/http" os "os" ) type Client struct { baseURL string - caller *core.Caller + caller *internal.Caller header http.Header } @@ -29,8 +26,8 @@ func NewClient(opts ...option.RequestOption) *Client { } return &Client{ baseURL: options.BaseURL, - caller: core.NewCaller( - &core.CallerParams{ + caller: internal.NewCaller( + &internal.CallerParams{ Client: options.HTTPClient, MaxAttempts: options.MaxAttempts, }, @@ -45,93 +42,69 @@ func (c *Client) ListFinetunedModels( opts ...option.RequestOption, ) (*finetuning.ListFinetunedModelsResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/finetuning/finetuned-models" - - queryParams, err := core.QueryValues(request) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, } var response *finetuning.ListFinetunedModelsResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -145,86 +118,64 @@ func (c *Client) CreateFinetunedModel( opts ...option.RequestOption, ) (*finetuning.CreateFinetunedModelResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/finetuning/finetuned-models" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, } var response *finetuning.CreateFinetunedModelResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -239,85 +190,65 @@ func (c *Client) GetFinetunedModel( opts ...option.RequestOption, ) (*finetuning.GetFinetunedModelResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/finetuning/finetuned-models/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/finetuning/finetuned-models/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, } var response *finetuning.GetFinetunedModelResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -332,85 +263,65 @@ func (c *Client) DeleteFinetunedModel( opts ...option.RequestOption, ) (finetuning.DeleteFinetunedModelResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/finetuning/finetuned-models/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/finetuning/finetuned-models/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, } var response finetuning.DeleteFinetunedModelResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodDelete, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -426,86 +337,67 @@ func (c *Client) UpdateFinetunedModel( opts ...option.RequestOption, ) (*finetuning.UpdateFinetunedModelResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/finetuning/finetuned-models/%v", id) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/finetuning/finetuned-models/%v", + id, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, } var response *finetuning.UpdateFinetunedModelResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPatch, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -521,93 +413,72 @@ func (c *Client) ListEvents( opts ...option.RequestOption, ) (*finetuning.ListEventsResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/finetuning/finetuned-models/%v/events", finetunedModelId) - - queryParams, err := core.QueryValues(request) + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/finetuning/finetuned-models/%v/events", + finetunedModelId, + ) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, } var response *finetuning.ListEventsResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -623,93 +494,72 @@ func (c *Client) ListTrainingStepMetrics( opts ...option.RequestOption, ) (*finetuning.ListTrainingStepMetricsResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/finetuning/finetuned-models/%v/training-step-metrics", finetunedModelId) - - queryParams, err := core.QueryValues(request) + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/finetuning/finetuned-models/%v/training-step-metrics", + finetunedModelId, + ) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, } var response *finetuning.ListTrainingStepMetricsResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err diff --git a/finetuning/types.go b/finetuning/finetuning.go similarity index 64% rename from finetuning/types.go rename to finetuning/finetuning.go index 3b828c0..59f0c10 100644 --- a/finetuning/types.go +++ b/finetuning/finetuning.go @@ -5,7 +5,7 @@ package finetuning import ( json "encoding/json" fmt "fmt" - core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" time "time" ) @@ -21,7 +21,35 @@ type BaseModel struct { Strategy *Strategy `json:"strategy,omitempty" url:"strategy,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (b *BaseModel) GetName() *string { + if b == nil { + return nil + } + return b.Name +} + +func (b *BaseModel) GetVersion() *string { + if b == nil { + return nil + } + return b.Version +} + +func (b *BaseModel) GetBaseType() BaseType { + if b == nil { + return "" + } + return b.BaseType +} + +func (b *BaseModel) GetStrategy() *Strategy { + if b == nil { + return nil + } + return b.Strategy } func (b *BaseModel) GetExtraProperties() map[string]interface{} { @@ -35,24 +63,22 @@ func (b *BaseModel) UnmarshalJSON(data []byte) error { return err } *b = BaseModel(value) - - extraProperties, err := core.ExtractExtraProperties(data, *b) + extraProperties, err := internal.ExtractExtraProperties(data, *b) if err != nil { return err } b.extraProperties = extraProperties - - b._rawJSON = json.RawMessage(data) + b.rawJSON = json.RawMessage(data) return nil } func (b *BaseModel) String() string { - if len(b._rawJSON) > 0 { - if value, err := core.StringifyJSON(b._rawJSON); err == nil { + if len(b.rawJSON) > 0 { + if value, err := internal.StringifyJSON(b.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(b); err == nil { + if value, err := internal.StringifyJSON(b); err == nil { return value } return fmt.Sprintf("%#v", b) @@ -60,11 +86,11 @@ func (b *BaseModel) String() string { // The possible types of fine-tuned models. // -// - BASE_TYPE_UNSPECIFIED: Unspecified model. -// - BASE_TYPE_GENERATIVE: Deprecated: Generative model. -// - BASE_TYPE_CLASSIFICATION: Classification model. -// - BASE_TYPE_RERANK: Rerank model. -// - BASE_TYPE_CHAT: Chat model. +// - BASE_TYPE_UNSPECIFIED: Unspecified model. +// - BASE_TYPE_GENERATIVE: Deprecated: Generative model. +// - BASE_TYPE_CLASSIFICATION: Classification model. +// - BASE_TYPE_RERANK: Rerank model. +// - BASE_TYPE_CHAT: Chat model. type BaseType string const ( @@ -102,7 +128,14 @@ type CreateFinetunedModelResponse struct { FinetunedModel *FinetunedModel `json:"finetuned_model,omitempty" url:"finetuned_model,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *CreateFinetunedModelResponse) GetFinetunedModel() *FinetunedModel { + if c == nil { + return nil + } + return c.FinetunedModel } func (c *CreateFinetunedModelResponse) GetExtraProperties() map[string]interface{} { @@ -116,24 +149,22 @@ func (c *CreateFinetunedModelResponse) UnmarshalJSON(data []byte) error { return err } *c = CreateFinetunedModelResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } func (c *CreateFinetunedModelResponse) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) @@ -152,7 +183,28 @@ type Event struct { CreatedAt *time.Time `json:"created_at,omitempty" url:"created_at,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (e *Event) GetUserId() *string { + if e == nil { + return nil + } + return e.UserId +} + +func (e *Event) GetStatus() *Status { + if e == nil { + return nil + } + return e.Status +} + +func (e *Event) GetCreatedAt() *time.Time { + if e == nil { + return nil + } + return e.CreatedAt } func (e *Event) GetExtraProperties() map[string]interface{} { @@ -163,7 +215,7 @@ func (e *Event) UnmarshalJSON(data []byte) error { type embed Event var unmarshaler = struct { embed - CreatedAt *core.DateTime `json:"created_at,omitempty"` + CreatedAt *internal.DateTime `json:"created_at,omitempty"` }{ embed: embed(*e), } @@ -172,14 +224,12 @@ func (e *Event) UnmarshalJSON(data []byte) error { } *e = Event(unmarshaler.embed) e.CreatedAt = unmarshaler.CreatedAt.TimePtr() - - extraProperties, err := core.ExtractExtraProperties(data, *e) + extraProperties, err := internal.ExtractExtraProperties(data, *e) if err != nil { return err } e.extraProperties = extraProperties - - e._rawJSON = json.RawMessage(data) + e.rawJSON = json.RawMessage(data) return nil } @@ -187,21 +237,21 @@ func (e *Event) MarshalJSON() ([]byte, error) { type embed Event var marshaler = struct { embed - CreatedAt *core.DateTime `json:"created_at,omitempty"` + CreatedAt *internal.DateTime `json:"created_at,omitempty"` }{ embed: embed(*e), - CreatedAt: core.NewOptionalDateTime(e.CreatedAt), + CreatedAt: internal.NewOptionalDateTime(e.CreatedAt), } return json.Marshal(marshaler) } func (e *Event) String() string { - if len(e._rawJSON) > 0 { - if value, err := core.StringifyJSON(e._rawJSON); err == nil { + if len(e.rawJSON) > 0 { + if value, err := internal.StringifyJSON(e.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(e); err == nil { + if value, err := internal.StringifyJSON(e); err == nil { return value } return fmt.Sprintf("%#v", e) @@ -231,7 +281,77 @@ type FinetunedModel struct { LastUsed *time.Time `json:"last_used,omitempty" url:"last_used,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (f *FinetunedModel) GetId() *string { + if f == nil { + return nil + } + return f.Id +} + +func (f *FinetunedModel) GetName() string { + if f == nil { + return "" + } + return f.Name +} + +func (f *FinetunedModel) GetCreatorId() *string { + if f == nil { + return nil + } + return f.CreatorId +} + +func (f *FinetunedModel) GetOrganizationId() *string { + if f == nil { + return nil + } + return f.OrganizationId +} + +func (f *FinetunedModel) GetSettings() *Settings { + if f == nil { + return nil + } + return f.Settings +} + +func (f *FinetunedModel) GetStatus() *Status { + if f == nil { + return nil + } + return f.Status +} + +func (f *FinetunedModel) GetCreatedAt() *time.Time { + if f == nil { + return nil + } + return f.CreatedAt +} + +func (f *FinetunedModel) GetUpdatedAt() *time.Time { + if f == nil { + return nil + } + return f.UpdatedAt +} + +func (f *FinetunedModel) GetCompletedAt() *time.Time { + if f == nil { + return nil + } + return f.CompletedAt +} + +func (f *FinetunedModel) GetLastUsed() *time.Time { + if f == nil { + return nil + } + return f.LastUsed } func (f *FinetunedModel) GetExtraProperties() map[string]interface{} { @@ -242,10 +362,10 @@ func (f *FinetunedModel) UnmarshalJSON(data []byte) error { type embed FinetunedModel var unmarshaler = struct { embed - CreatedAt *core.DateTime `json:"created_at,omitempty"` - UpdatedAt *core.DateTime `json:"updated_at,omitempty"` - CompletedAt *core.DateTime `json:"completed_at,omitempty"` - LastUsed *core.DateTime `json:"last_used,omitempty"` + CreatedAt *internal.DateTime `json:"created_at,omitempty"` + UpdatedAt *internal.DateTime `json:"updated_at,omitempty"` + CompletedAt *internal.DateTime `json:"completed_at,omitempty"` + LastUsed *internal.DateTime `json:"last_used,omitempty"` }{ embed: embed(*f), } @@ -257,14 +377,12 @@ func (f *FinetunedModel) UnmarshalJSON(data []byte) error { f.UpdatedAt = unmarshaler.UpdatedAt.TimePtr() f.CompletedAt = unmarshaler.CompletedAt.TimePtr() f.LastUsed = unmarshaler.LastUsed.TimePtr() - - extraProperties, err := core.ExtractExtraProperties(data, *f) + extraProperties, err := internal.ExtractExtraProperties(data, *f) if err != nil { return err } f.extraProperties = extraProperties - - f._rawJSON = json.RawMessage(data) + f.rawJSON = json.RawMessage(data) return nil } @@ -272,27 +390,27 @@ func (f *FinetunedModel) MarshalJSON() ([]byte, error) { type embed FinetunedModel var marshaler = struct { embed - CreatedAt *core.DateTime `json:"created_at,omitempty"` - UpdatedAt *core.DateTime `json:"updated_at,omitempty"` - CompletedAt *core.DateTime `json:"completed_at,omitempty"` - LastUsed *core.DateTime `json:"last_used,omitempty"` + CreatedAt *internal.DateTime `json:"created_at,omitempty"` + UpdatedAt *internal.DateTime `json:"updated_at,omitempty"` + CompletedAt *internal.DateTime `json:"completed_at,omitempty"` + LastUsed *internal.DateTime `json:"last_used,omitempty"` }{ embed: embed(*f), - CreatedAt: core.NewOptionalDateTime(f.CreatedAt), - UpdatedAt: core.NewOptionalDateTime(f.UpdatedAt), - CompletedAt: core.NewOptionalDateTime(f.CompletedAt), - LastUsed: core.NewOptionalDateTime(f.LastUsed), + CreatedAt: internal.NewOptionalDateTime(f.CreatedAt), + UpdatedAt: internal.NewOptionalDateTime(f.UpdatedAt), + CompletedAt: internal.NewOptionalDateTime(f.CompletedAt), + LastUsed: internal.NewOptionalDateTime(f.LastUsed), } return json.Marshal(marshaler) } func (f *FinetunedModel) String() string { - if len(f._rawJSON) > 0 { - if value, err := core.StringifyJSON(f._rawJSON); err == nil { + if len(f.rawJSON) > 0 { + if value, err := internal.StringifyJSON(f.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(f); err == nil { + if value, err := internal.StringifyJSON(f); err == nil { return value } return fmt.Sprintf("%#v", f) @@ -304,7 +422,14 @@ type GetFinetunedModelResponse struct { FinetunedModel *FinetunedModel `json:"finetuned_model,omitempty" url:"finetuned_model,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (g *GetFinetunedModelResponse) GetFinetunedModel() *FinetunedModel { + if g == nil { + return nil + } + return g.FinetunedModel } func (g *GetFinetunedModelResponse) GetExtraProperties() map[string]interface{} { @@ -318,24 +443,22 @@ func (g *GetFinetunedModelResponse) UnmarshalJSON(data []byte) error { return err } *g = GetFinetunedModelResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) + extraProperties, err := internal.ExtractExtraProperties(data, *g) if err != nil { return err } g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) + g.rawJSON = json.RawMessage(data) return nil } func (g *GetFinetunedModelResponse) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(g); err == nil { + if value, err := internal.StringifyJSON(g); err == nil { return value } return fmt.Sprintf("%#v", g) @@ -365,7 +488,63 @@ type Hyperparameters struct { LoraTargetModules *LoraTargetModules `json:"lora_target_modules,omitempty" url:"lora_target_modules,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (h *Hyperparameters) GetEarlyStoppingPatience() *int { + if h == nil { + return nil + } + return h.EarlyStoppingPatience +} + +func (h *Hyperparameters) GetEarlyStoppingThreshold() *float64 { + if h == nil { + return nil + } + return h.EarlyStoppingThreshold +} + +func (h *Hyperparameters) GetTrainBatchSize() *int { + if h == nil { + return nil + } + return h.TrainBatchSize +} + +func (h *Hyperparameters) GetTrainEpochs() *int { + if h == nil { + return nil + } + return h.TrainEpochs +} + +func (h *Hyperparameters) GetLearningRate() *float64 { + if h == nil { + return nil + } + return h.LearningRate +} + +func (h *Hyperparameters) GetLoraAlpha() *int { + if h == nil { + return nil + } + return h.LoraAlpha +} + +func (h *Hyperparameters) GetLoraRank() *int { + if h == nil { + return nil + } + return h.LoraRank +} + +func (h *Hyperparameters) GetLoraTargetModules() *LoraTargetModules { + if h == nil { + return nil + } + return h.LoraTargetModules } func (h *Hyperparameters) GetExtraProperties() map[string]interface{} { @@ -379,24 +558,22 @@ func (h *Hyperparameters) UnmarshalJSON(data []byte) error { return err } *h = Hyperparameters(value) - - extraProperties, err := core.ExtractExtraProperties(data, *h) + extraProperties, err := internal.ExtractExtraProperties(data, *h) if err != nil { return err } h.extraProperties = extraProperties - - h._rawJSON = json.RawMessage(data) + h.rawJSON = json.RawMessage(data) return nil } func (h *Hyperparameters) String() string { - if len(h._rawJSON) > 0 { - if value, err := core.StringifyJSON(h._rawJSON); err == nil { + if len(h.rawJSON) > 0 { + if value, err := internal.StringifyJSON(h.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(h); err == nil { + if value, err := internal.StringifyJSON(h); err == nil { return value } return fmt.Sprintf("%#v", h) @@ -413,7 +590,28 @@ type ListEventsResponse struct { TotalSize *int `json:"total_size,omitempty" url:"total_size,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (l *ListEventsResponse) GetEvents() []*Event { + if l == nil { + return nil + } + return l.Events +} + +func (l *ListEventsResponse) GetNextPageToken() *string { + if l == nil { + return nil + } + return l.NextPageToken +} + +func (l *ListEventsResponse) GetTotalSize() *int { + if l == nil { + return nil + } + return l.TotalSize } func (l *ListEventsResponse) GetExtraProperties() map[string]interface{} { @@ -427,24 +625,22 @@ func (l *ListEventsResponse) UnmarshalJSON(data []byte) error { return err } *l = ListEventsResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) + extraProperties, err := internal.ExtractExtraProperties(data, *l) if err != nil { return err } l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) + l.rawJSON = json.RawMessage(data) return nil } func (l *ListEventsResponse) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(l); err == nil { + if value, err := internal.StringifyJSON(l); err == nil { return value } return fmt.Sprintf("%#v", l) @@ -461,7 +657,28 @@ type ListFinetunedModelsResponse struct { TotalSize *int `json:"total_size,omitempty" url:"total_size,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (l *ListFinetunedModelsResponse) GetFinetunedModels() []*FinetunedModel { + if l == nil { + return nil + } + return l.FinetunedModels +} + +func (l *ListFinetunedModelsResponse) GetNextPageToken() *string { + if l == nil { + return nil + } + return l.NextPageToken +} + +func (l *ListFinetunedModelsResponse) GetTotalSize() *int { + if l == nil { + return nil + } + return l.TotalSize } func (l *ListFinetunedModelsResponse) GetExtraProperties() map[string]interface{} { @@ -475,24 +692,22 @@ func (l *ListFinetunedModelsResponse) UnmarshalJSON(data []byte) error { return err } *l = ListFinetunedModelsResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) + extraProperties, err := internal.ExtractExtraProperties(data, *l) if err != nil { return err } l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) + l.rawJSON = json.RawMessage(data) return nil } func (l *ListFinetunedModelsResponse) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(l); err == nil { + if value, err := internal.StringifyJSON(l); err == nil { return value } return fmt.Sprintf("%#v", l) @@ -507,7 +722,21 @@ type ListTrainingStepMetricsResponse struct { NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (l *ListTrainingStepMetricsResponse) GetStepMetrics() []*TrainingStepMetrics { + if l == nil { + return nil + } + return l.StepMetrics +} + +func (l *ListTrainingStepMetricsResponse) GetNextPageToken() *string { + if l == nil { + return nil + } + return l.NextPageToken } func (l *ListTrainingStepMetricsResponse) GetExtraProperties() map[string]interface{} { @@ -521,24 +750,22 @@ func (l *ListTrainingStepMetricsResponse) UnmarshalJSON(data []byte) error { return err } *l = ListTrainingStepMetricsResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) + extraProperties, err := internal.ExtractExtraProperties(data, *l) if err != nil { return err } l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) + l.rawJSON = json.RawMessage(data) return nil } func (l *ListTrainingStepMetricsResponse) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(l); err == nil { + if value, err := internal.StringifyJSON(l); err == nil { return value } return fmt.Sprintf("%#v", l) @@ -546,10 +773,10 @@ func (l *ListTrainingStepMetricsResponse) String() string { // The possible combinations of LoRA modules to target. // -// - LORA_TARGET_MODULES_UNSPECIFIED: Unspecified LoRA target modules. -// - LORA_TARGET_MODULES_QV: LoRA adapts the query and value matrices in transformer attention layers. -// - LORA_TARGET_MODULES_QKVO: LoRA adapts query, key, value, and output matrices in attention layers. -// - LORA_TARGET_MODULES_QKVO_FFN: LoRA adapts attention projection matrices and feed-forward networks (FFN). +// - LORA_TARGET_MODULES_UNSPECIFIED: Unspecified LoRA target modules. +// - LORA_TARGET_MODULES_QV: LoRA adapts the query and value matrices in transformer attention layers. +// - LORA_TARGET_MODULES_QKVO: LoRA adapts query, key, value, and output matrices in attention layers. +// - LORA_TARGET_MODULES_QKVO_FFN: LoRA adapts attention projection matrices and feed-forward networks (FFN). type LoraTargetModules string const ( @@ -592,7 +819,42 @@ type Settings struct { Wandb *WandbConfig `json:"wandb,omitempty" url:"wandb,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (s *Settings) GetBaseModel() *BaseModel { + if s == nil { + return nil + } + return s.BaseModel +} + +func (s *Settings) GetDatasetId() string { + if s == nil { + return "" + } + return s.DatasetId +} + +func (s *Settings) GetHyperparameters() *Hyperparameters { + if s == nil { + return nil + } + return s.Hyperparameters +} + +func (s *Settings) GetMultiLabel() *bool { + if s == nil { + return nil + } + return s.MultiLabel +} + +func (s *Settings) GetWandb() *WandbConfig { + if s == nil { + return nil + } + return s.Wandb } func (s *Settings) GetExtraProperties() map[string]interface{} { @@ -606,24 +868,22 @@ func (s *Settings) UnmarshalJSON(data []byte) error { return err } *s = Settings(value) - - extraProperties, err := core.ExtractExtraProperties(data, *s) + extraProperties, err := internal.ExtractExtraProperties(data, *s) if err != nil { return err } s.extraProperties = extraProperties - - s._rawJSON = json.RawMessage(data) + s.rawJSON = json.RawMessage(data) return nil } func (s *Settings) String() string { - if len(s._rawJSON) > 0 { - if value, err := core.StringifyJSON(s._rawJSON); err == nil { + if len(s.rawJSON) > 0 { + if value, err := internal.StringifyJSON(s.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(s); err == nil { + if value, err := internal.StringifyJSON(s); err == nil { return value } return fmt.Sprintf("%#v", s) @@ -631,15 +891,15 @@ func (s *Settings) String() string { // The possible stages of a fine-tuned model life-cycle. // -// - STATUS_UNSPECIFIED: Unspecified status. -// - STATUS_FINETUNING: The fine-tuned model is being fine-tuned. -// - STATUS_DEPLOYING_API: Deprecated: The fine-tuned model is being deployed. -// - STATUS_READY: The fine-tuned model is ready to receive requests. -// - STATUS_FAILED: The fine-tuned model failed. -// - STATUS_DELETED: The fine-tuned model was deleted. -// - STATUS_TEMPORARILY_OFFLINE: Deprecated: The fine-tuned model is temporarily unavailable. -// - STATUS_PAUSED: Deprecated: The fine-tuned model is paused (Vanilla only). -// - STATUS_QUEUED: The fine-tuned model is queued for training. +// - STATUS_UNSPECIFIED: Unspecified status. +// - STATUS_FINETUNING: The fine-tuned model is being fine-tuned. +// - STATUS_DEPLOYING_API: Deprecated: The fine-tuned model is being deployed. +// - STATUS_READY: The fine-tuned model is ready to receive requests. +// - STATUS_FAILED: The fine-tuned model failed. +// - STATUS_DELETED: The fine-tuned model was deleted. +// - STATUS_TEMPORARILY_OFFLINE: Deprecated: The fine-tuned model is temporarily unavailable. +// - STATUS_PAUSED: Deprecated: The fine-tuned model is paused (Vanilla only). +// - STATUS_QUEUED: The fine-tuned model is queued for training. type Status string const ( @@ -685,9 +945,9 @@ func (s Status) Ptr() *Status { // The possible strategy used to serve a fine-tuned models. // -// - STRATEGY_UNSPECIFIED: Unspecified strategy. -// - STRATEGY_VANILLA: Deprecated: Serve the fine-tuned model on a dedicated GPU. -// - STRATEGY_TFEW: Deprecated: Serve the fine-tuned model on a shared GPU. +// - STRATEGY_UNSPECIFIED: Unspecified strategy. +// - STRATEGY_VANILLA: Deprecated: Serve the fine-tuned model on a dedicated GPU. +// - STRATEGY_TFEW: Deprecated: Serve the fine-tuned model on a shared GPU. type Strategy string const ( @@ -723,7 +983,28 @@ type TrainingStepMetrics struct { Metrics map[string]float64 `json:"metrics,omitempty" url:"metrics,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (t *TrainingStepMetrics) GetCreatedAt() *time.Time { + if t == nil { + return nil + } + return t.CreatedAt +} + +func (t *TrainingStepMetrics) GetStepNumber() *int { + if t == nil { + return nil + } + return t.StepNumber +} + +func (t *TrainingStepMetrics) GetMetrics() map[string]float64 { + if t == nil { + return nil + } + return t.Metrics } func (t *TrainingStepMetrics) GetExtraProperties() map[string]interface{} { @@ -734,7 +1015,7 @@ func (t *TrainingStepMetrics) UnmarshalJSON(data []byte) error { type embed TrainingStepMetrics var unmarshaler = struct { embed - CreatedAt *core.DateTime `json:"created_at,omitempty"` + CreatedAt *internal.DateTime `json:"created_at,omitempty"` }{ embed: embed(*t), } @@ -743,14 +1024,12 @@ func (t *TrainingStepMetrics) UnmarshalJSON(data []byte) error { } *t = TrainingStepMetrics(unmarshaler.embed) t.CreatedAt = unmarshaler.CreatedAt.TimePtr() - - extraProperties, err := core.ExtractExtraProperties(data, *t) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + t.rawJSON = json.RawMessage(data) return nil } @@ -758,21 +1037,21 @@ func (t *TrainingStepMetrics) MarshalJSON() ([]byte, error) { type embed TrainingStepMetrics var marshaler = struct { embed - CreatedAt *core.DateTime `json:"created_at,omitempty"` + CreatedAt *internal.DateTime `json:"created_at,omitempty"` }{ embed: embed(*t), - CreatedAt: core.NewOptionalDateTime(t.CreatedAt), + CreatedAt: internal.NewOptionalDateTime(t.CreatedAt), } return json.Marshal(marshaler) } func (t *TrainingStepMetrics) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } return fmt.Sprintf("%#v", t) @@ -784,7 +1063,14 @@ type UpdateFinetunedModelResponse struct { FinetunedModel *FinetunedModel `json:"finetuned_model,omitempty" url:"finetuned_model,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (u *UpdateFinetunedModelResponse) GetFinetunedModel() *FinetunedModel { + if u == nil { + return nil + } + return u.FinetunedModel } func (u *UpdateFinetunedModelResponse) GetExtraProperties() map[string]interface{} { @@ -798,24 +1084,22 @@ func (u *UpdateFinetunedModelResponse) UnmarshalJSON(data []byte) error { return err } *u = UpdateFinetunedModelResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *u) + extraProperties, err := internal.ExtractExtraProperties(data, *u) if err != nil { return err } u.extraProperties = extraProperties - - u._rawJSON = json.RawMessage(data) + u.rawJSON = json.RawMessage(data) return nil } func (u *UpdateFinetunedModelResponse) String() string { - if len(u._rawJSON) > 0 { - if value, err := core.StringifyJSON(u._rawJSON); err == nil { + if len(u.rawJSON) > 0 { + if value, err := internal.StringifyJSON(u.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(u); err == nil { + if value, err := internal.StringifyJSON(u); err == nil { return value } return fmt.Sprintf("%#v", u) @@ -831,7 +1115,28 @@ type WandbConfig struct { Entity *string `json:"entity,omitempty" url:"entity,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (w *WandbConfig) GetProject() string { + if w == nil { + return "" + } + return w.Project +} + +func (w *WandbConfig) GetApiKey() string { + if w == nil { + return "" + } + return w.ApiKey +} + +func (w *WandbConfig) GetEntity() *string { + if w == nil { + return nil + } + return w.Entity } func (w *WandbConfig) GetExtraProperties() map[string]interface{} { @@ -845,24 +1150,22 @@ func (w *WandbConfig) UnmarshalJSON(data []byte) error { return err } *w = WandbConfig(value) - - extraProperties, err := core.ExtractExtraProperties(data, *w) + extraProperties, err := internal.ExtractExtraProperties(data, *w) if err != nil { return err } w.extraProperties = extraProperties - - w._rawJSON = json.RawMessage(data) + w.rawJSON = json.RawMessage(data) return nil } func (w *WandbConfig) String() string { - if len(w._rawJSON) > 0 { - if value, err := core.StringifyJSON(w._rawJSON); err == nil { + if len(w.rawJSON) > 0 { + if value, err := internal.StringifyJSON(w.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(w); err == nil { + if value, err := internal.StringifyJSON(w); err == nil { return value } return fmt.Sprintf("%#v", w) diff --git a/core/core.go b/internal/caller.go similarity index 69% rename from core/core.go rename to internal/caller.go index 6b5a8f3..2f3e914 100644 --- a/core/core.go +++ b/internal/caller.go @@ -1,4 +1,4 @@ -package core +package internal import ( "bytes" @@ -7,11 +7,12 @@ import ( "errors" "fmt" "io" - "mime/multipart" "net/http" "net/url" "reflect" "strings" + + "github.com/cohere-ai/cohere-go/v2/core" ) const ( @@ -20,105 +21,21 @@ const ( contentTypeHeader = "Content-Type" ) -// HTTPClient is an interface for a subset of the *http.Client. -type HTTPClient interface { - Do(*http.Request) (*http.Response, error) -} - -// EncodeURL encodes the given arguments into the URL, escaping -// values as needed. -func EncodeURL(urlFormat string, args ...interface{}) string { - escapedArgs := make([]interface{}, 0, len(args)) - for _, arg := range args { - escapedArgs = append(escapedArgs, url.PathEscape(fmt.Sprintf("%v", arg))) - } - return fmt.Sprintf(urlFormat, escapedArgs...) -} - -// MergeHeaders merges the given headers together, where the right -// takes precedence over the left. -func MergeHeaders(left, right http.Header) http.Header { - for key, values := range right { - if len(values) > 1 { - left[key] = values - continue - } - if value := right.Get(key); value != "" { - left.Set(key, value) - } - } - return left -} - -// WriteMultipartJSON writes the given value as a JSON part. -// This is used to serialize non-primitive multipart properties -// (i.e. lists, objects, etc). -func WriteMultipartJSON(writer *multipart.Writer, field string, value interface{}) error { - bytes, err := json.Marshal(value) - if err != nil { - return err - } - return writer.WriteField(field, string(bytes)) -} - -// APIError is a lightweight wrapper around the standard error -// interface that preserves the status code from the RPC, if any. -type APIError struct { - err error - - StatusCode int `json:"-"` -} - -// NewAPIError constructs a new API error. -func NewAPIError(statusCode int, err error) *APIError { - return &APIError{ - err: err, - StatusCode: statusCode, - } -} - -// Unwrap returns the underlying error. This also makes the error compatible -// with errors.As and errors.Is. -func (a *APIError) Unwrap() error { - if a == nil { - return nil - } - return a.err -} - -// Error returns the API error's message. -func (a *APIError) Error() string { - if a == nil || (a.err == nil && a.StatusCode == 0) { - return "" - } - if a.err == nil { - return fmt.Sprintf("%d", a.StatusCode) - } - if a.StatusCode == 0 { - return a.err.Error() - } - return fmt.Sprintf("%d: %s", a.StatusCode, a.err.Error()) -} - -// ErrorDecoder decodes *http.Response errors and returns a -// typed API error (e.g. *APIError). -type ErrorDecoder func(statusCode int, body io.Reader) error - // Caller calls APIs and deserializes their response, if any. type Caller struct { - client HTTPClient + client core.HTTPClient retrier *Retrier } // CallerParams represents the parameters used to constrcut a new *Caller. type CallerParams struct { - Client HTTPClient + Client core.HTTPClient MaxAttempts uint } // NewCaller returns a new *Caller backed by the given parameters. func NewCaller(params *CallerParams) *Caller { - var httpClient HTTPClient = http.DefaultClient + var httpClient core.HTTPClient = http.DefaultClient if params.Client != nil { httpClient = params.Client } @@ -140,7 +57,7 @@ type CallParams struct { Headers http.Header BodyProperties map[string]interface{} QueryParameters url.Values - Client HTTPClient + Client core.HTTPClient Request interface{} Response interface{} ResponseIsOptional bool @@ -309,9 +226,9 @@ func decodeError(response *http.Response, errorDecoder ErrorDecoder) error { // The error didn't have a response body, // so all we can do is return an error // with the status code. - return NewAPIError(response.StatusCode, nil) + return core.NewAPIError(response.StatusCode, nil) } - return NewAPIError(response.StatusCode, errors.New(string(bytes))) + return core.NewAPIError(response.StatusCode, errors.New(string(bytes))) } // isNil is used to determine if the request value is equal to nil (i.e. an interface diff --git a/core/core_test.go b/internal/caller_test.go similarity index 97% rename from core/core_test.go rename to internal/caller_test.go index e6eaef3..2f012cc 100644 --- a/core/core_test.go +++ b/internal/caller_test.go @@ -1,4 +1,4 @@ -package core +package internal import ( "bytes" @@ -13,6 +13,7 @@ import ( "strconv" "testing" + "github.com/cohere-ai/cohere-go/v2/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -50,7 +51,7 @@ type Response struct { // NotFoundError represents a 404. type NotFoundError struct { - *APIError + *core.APIError Message string `json:"message"` } @@ -98,7 +99,7 @@ func TestCall(t *testing.T) { }, giveErrorDecoder: newTestErrorDecoder(t), wantError: &NotFoundError{ - APIError: NewAPIError( + APIError: core.NewAPIError( http.StatusNotFound, errors.New(`{"message":"ID \"404\" not found"}`), ), @@ -111,7 +112,7 @@ func TestCall(t *testing.T) { "X-API-Status": []string{"fail"}, }, giveRequest: nil, - wantError: NewAPIError( + wantError: core.NewAPIError( http.StatusBadRequest, errors.New("invalid request"), ), @@ -136,7 +137,7 @@ func TestCall(t *testing.T) { giveRequest: &Request{ Id: strconv.Itoa(http.StatusInternalServerError), }, - wantError: NewAPIError( + wantError: core.NewAPIError( http.StatusInternalServerError, errors.New("failed to process request"), ), @@ -324,7 +325,7 @@ func newTestServer(t *testing.T, tc *TestCase) *httptest.Server { switch request.Id { case strconv.Itoa(http.StatusNotFound): notFoundError := &NotFoundError{ - APIError: &APIError{ + APIError: &core.APIError{ StatusCode: http.StatusNotFound, }, Message: fmt.Sprintf("ID %q not found", request.Id), @@ -375,7 +376,7 @@ func newTestErrorDecoder(t *testing.T) func(int, io.Reader) error { require.NoError(t, err) var ( - apiError = NewAPIError(statusCode, errors.New(string(raw))) + apiError = core.NewAPIError(statusCode, errors.New(string(raw))) decoder = json.NewDecoder(bytes.NewReader(raw)) ) if statusCode == http.StatusNotFound { diff --git a/internal/error_decoder.go b/internal/error_decoder.go new file mode 100644 index 0000000..16bf36d --- /dev/null +++ b/internal/error_decoder.go @@ -0,0 +1,45 @@ +package internal + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + + "github.com/cohere-ai/cohere-go/v2/core" +) + +// ErrorDecoder decodes *http.Response errors and returns a +// typed API error (e.g. *core.APIError). +type ErrorDecoder func(statusCode int, body io.Reader) error + +// ErrorCodes maps HTTP status codes to error constructors. +type ErrorCodes map[int]func(*core.APIError) error + +// NewErrorDecoder returns a new ErrorDecoder backed by the given error codes. +func NewErrorDecoder(errorCodes ErrorCodes) ErrorDecoder { + return func(statusCode int, body io.Reader) error { + raw, err := io.ReadAll(body) + if err != nil { + return fmt.Errorf("failed to read error from response body: %w", err) + } + apiError := core.NewAPIError( + statusCode, + errors.New(string(raw)), + ) + newErrorFunc, ok := errorCodes[statusCode] + if !ok { + // This status code isn't recognized, so we return + // the API error as-is. + return apiError + } + customError := newErrorFunc(apiError) + if err := json.NewDecoder(bytes.NewReader(raw)).Decode(customError); err != nil { + // If we fail to decode the error, we return the + // API error as-is. + return apiError + } + return customError + } +} diff --git a/internal/error_decoder_test.go b/internal/error_decoder_test.go new file mode 100644 index 0000000..43eab7b --- /dev/null +++ b/internal/error_decoder_test.go @@ -0,0 +1,55 @@ +package internal + +import ( + "bytes" + "errors" + "net/http" + "testing" + + "github.com/cohere-ai/cohere-go/v2/core" + "github.com/stretchr/testify/assert" +) + +func TestErrorDecoder(t *testing.T) { + decoder := NewErrorDecoder( + ErrorCodes{ + http.StatusNotFound: func(apiError *core.APIError) error { + return &NotFoundError{APIError: apiError} + }, + }) + + tests := []struct { + description string + giveStatusCode int + giveBody string + wantError error + }{ + { + description: "unrecognized status code", + giveStatusCode: http.StatusInternalServerError, + giveBody: "Internal Server Error", + wantError: core.NewAPIError(http.StatusInternalServerError, errors.New("Internal Server Error")), + }, + { + description: "not found with valid JSON", + giveStatusCode: http.StatusNotFound, + giveBody: `{"message": "Resource not found"}`, + wantError: &NotFoundError{ + APIError: core.NewAPIError(http.StatusNotFound, errors.New(`{"message": "Resource not found"}`)), + Message: "Resource not found", + }, + }, + { + description: "not found with invalid JSON", + giveStatusCode: http.StatusNotFound, + giveBody: `Resource not found`, + wantError: core.NewAPIError(http.StatusNotFound, errors.New("Resource not found")), + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + assert.Equal(t, tt.wantError, decoder(tt.giveStatusCode, bytes.NewReader([]byte(tt.giveBody)))) + }) + } +} diff --git a/core/extra_properties.go b/internal/extra_properties.go similarity index 99% rename from core/extra_properties.go rename to internal/extra_properties.go index a6af3e1..540c3fd 100644 --- a/core/extra_properties.go +++ b/internal/extra_properties.go @@ -1,4 +1,4 @@ -package core +package internal import ( "bytes" diff --git a/core/extra_properties_test.go b/internal/extra_properties_test.go similarity index 99% rename from core/extra_properties_test.go rename to internal/extra_properties_test.go index dc66fcc..aa2510e 100644 --- a/core/extra_properties_test.go +++ b/internal/extra_properties_test.go @@ -1,4 +1,4 @@ -package core +package internal import ( "encoding/json" diff --git a/internal/http.go b/internal/http.go new file mode 100644 index 0000000..768968b --- /dev/null +++ b/internal/http.go @@ -0,0 +1,48 @@ +package internal + +import ( + "fmt" + "net/http" + "net/url" +) + +// HTTPClient is an interface for a subset of the *http.Client. +type HTTPClient interface { + Do(*http.Request) (*http.Response, error) +} + +// ResolveBaseURL resolves the base URL from the given arguments, +// preferring the first non-empty value. +func ResolveBaseURL(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + +// EncodeURL encodes the given arguments into the URL, escaping +// values as needed. +func EncodeURL(urlFormat string, args ...interface{}) string { + escapedArgs := make([]interface{}, 0, len(args)) + for _, arg := range args { + escapedArgs = append(escapedArgs, url.PathEscape(fmt.Sprintf("%v", arg))) + } + return fmt.Sprintf(urlFormat, escapedArgs...) +} + +// MergeHeaders merges the given headers together, where the right +// takes precedence over the left. +func MergeHeaders(left, right http.Header) http.Header { + for key, values := range right { + if len(values) > 1 { + left[key] = values + continue + } + if value := right.Get(key); value != "" { + left.Set(key, value) + } + } + return left +} diff --git a/internal/multipart.go b/internal/multipart.go new file mode 100644 index 0000000..67a2229 --- /dev/null +++ b/internal/multipart.go @@ -0,0 +1,202 @@ +package internal + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/textproto" + "strings" +) + +// Named is implemented by types that define a name. +type Named interface { + Name() string +} + +// ContentTyped is implemented by types that define a Content-Type. +type ContentTyped interface { + ContentType() string +} + +// WriteMultipartOption adapts the behavior of the multipart writer. +type WriteMultipartOption func(*writeMultipartOptions) + +// WithDefaultContentType sets the default Content-Type for the part +// written to the MultipartWriter. +// +// Note that if the part is a FileParam, the file's Content-Type takes +// precedence over the value provided here. +func WithDefaultContentType(contentType string) WriteMultipartOption { + return func(options *writeMultipartOptions) { + options.defaultContentType = contentType + } +} + +// MultipartWriter writes multipart/form-data requests. +type MultipartWriter struct { + buffer *bytes.Buffer + writer *multipart.Writer +} + +// NewMultipartWriter creates a new multipart writer. +func NewMultipartWriter() *MultipartWriter { + buffer := bytes.NewBuffer(nil) + return &MultipartWriter{ + buffer: buffer, + writer: multipart.NewWriter(buffer), + } +} + +// Buffer returns the underlying buffer. +func (w *MultipartWriter) Buffer() *bytes.Buffer { + return w.buffer +} + +// ContentType returns the Content-Type for an HTTP multipart/form-data. +func (w *MultipartWriter) ContentType() string { + return w.writer.FormDataContentType() +} + +// WriteFile writes the given file part. +func (w *MultipartWriter) WriteFile( + field string, + file io.Reader, + opts ...WriteMultipartOption, +) error { + options := newWriteMultipartOptions(opts...) + return w.writeFile(field, file, options.defaultContentType) +} + +// WriteField writes the given value as a form field. +func (w *MultipartWriter) WriteField( + field string, + value string, + opts ...WriteMultipartOption, +) error { + options := newWriteMultipartOptions(opts...) + return w.writeField(field, value, options.defaultContentType) +} + +// WriteJSON writes the given value as a JSON form field. +func (w *MultipartWriter) WriteJSON( + field string, + value interface{}, + opts ...WriteMultipartOption, +) error { + bytes, err := json.Marshal(value) + if err != nil { + return err + } + return w.WriteField(field, string(bytes), opts...) +} + +// Close closes the writer. +func (w *MultipartWriter) Close() error { + return w.writer.Close() +} + +func (w *MultipartWriter) writeField( + field string, + value string, + contentType string, +) error { + part, err := w.newFormField(field, contentType) + if err != nil { + return err + } + _, err = part.Write([]byte(value)) + return err +} + +func (w *MultipartWriter) writeFile( + field string, + file io.Reader, + defaultContentType string, +) error { + var ( + filename = getFilename(file) + contentType = getContentType(file) + ) + if contentType == "" { + contentType = defaultContentType + } + part, err := w.newFormPart(field, filename, contentType) + if err != nil { + return err + } + _, err = io.Copy(part, file) + return err +} + +// newFormField creates a new form field. +func (w *MultipartWriter) newFormField( + field string, + contentType string, +) (io.Writer, error) { + return w.newFormPart(field, "" /* filename */, contentType) +} + +// newFormPart creates a new form data part. +func (w *MultipartWriter) newFormPart( + field string, + filename string, + contentType string, +) (io.Writer, error) { + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", getContentDispositionHeaderValue(field, filename)) + if contentType != "" { + h.Set("Content-Type", contentType) + } + return w.writer.CreatePart(h) +} + +// writeMultipartOptions are options used to adapt the behavior of the multipart writer. +type writeMultipartOptions struct { + defaultContentType string +} + +// newWriteMultipartOptions returns a new write multipart options. +func newWriteMultipartOptions(opts ...WriteMultipartOption) *writeMultipartOptions { + options := new(writeMultipartOptions) + for _, opt := range opts { + opt(options) + } + return options +} + +// getContentType returns the Content-Type for the given file, if any. +func getContentType(file io.Reader) string { + if v, ok := file.(ContentTyped); ok { + return v.ContentType() + } + return "" +} + +// getFilename returns the name for the given file, if any. +func getFilename(file io.Reader) string { + if v, ok := file.(Named); ok { + return v.Name() + } + return "" +} + +// getContentDispositionHeaderValue returns the value for the Content-Disposition header. +func getContentDispositionHeaderValue(field string, filename string) string { + contentDisposition := fmt.Sprintf("form-data; name=%q", field) + if filename != "" { + contentDisposition += fmt.Sprintf(`; filename=%q`, escapeQuotes(filename)) + } + return contentDisposition +} + +// https://cs.opensource.google/go/go/+/refs/tags/go1.23.2:src/mime/multipart/writer.go;l=132 +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +// escapeQuotes is directly referenced from the standard library. +// +// https://cs.opensource.google/go/go/+/refs/tags/go1.23.2:src/mime/multipart/writer.go;l=134 +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} diff --git a/internal/multipart_test.go b/internal/multipart_test.go new file mode 100644 index 0000000..07008dd --- /dev/null +++ b/internal/multipart_test.go @@ -0,0 +1,260 @@ +package internal + +import ( + "encoding/json" + "io" + "mime/multipart" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const maxFormMemory = 32 << 20 // 32MB + +type mockFile struct { + name string + content string + contentType string + + reader io.Reader +} + +func (f *mockFile) Read(p []byte) (n int, err error) { + if f.reader == nil { + f.reader = strings.NewReader(f.content) + } + return f.reader.Read(p) +} + +func (f *mockFile) Name() string { + return f.name +} + +func (f *mockFile) ContentType() string { + return f.contentType +} + +func TestMultipartWriter(t *testing.T) { + t.Run("empty", func(t *testing.T) { + w := NewMultipartWriter() + assert.NotNil(t, w.Buffer()) + assert.Contains(t, w.ContentType(), "multipart/form-data; boundary=") + require.NoError(t, w.Close()) + }) + + t.Run("write field", func(t *testing.T) { + tests := []struct { + desc string + giveField string + giveValue string + giveContentType string + }{ + { + desc: "empty field", + giveField: "empty", + giveValue: "", + }, + { + desc: "simple field", + giveField: "greeting", + giveValue: "hello world", + }, + { + desc: "field with content type", + giveField: "message", + giveValue: "hello", + giveContentType: "text/plain", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + w := NewMultipartWriter() + + var opts []WriteMultipartOption + if tt.giveContentType != "" { + opts = append(opts, WithDefaultContentType(tt.giveContentType)) + } + + require.NoError(t, w.WriteField(tt.giveField, tt.giveValue, opts...)) + require.NoError(t, w.Close()) + + reader := multipart.NewReader(w.Buffer(), w.writer.Boundary()) + form, err := reader.ReadForm(maxFormMemory) + require.NoError(t, err) + + assert.Equal(t, []string{tt.giveValue}, form.Value[tt.giveField]) + require.NoError(t, form.RemoveAll()) + }) + } + }) + + t.Run("write file", func(t *testing.T) { + tests := []struct { + desc string + giveField string + giveFile *mockFile + giveContentType string + }{ + { + desc: "simple file", + giveField: "file", + giveFile: &mockFile{ + name: "test.txt", + content: "hello world", + contentType: "text/plain", + }, + }, + { + desc: "file content type takes precedence", + giveField: "file", + giveFile: &mockFile{ + name: "test.txt", + content: "hello world", + contentType: "text/plain", + }, + giveContentType: "application/octet-stream", + }, + { + desc: "default content type", + giveField: "file", + giveFile: &mockFile{ + name: "test.txt", + content: "hello world", + }, + giveContentType: "application/octet-stream", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + w := NewMultipartWriter() + + var opts []WriteMultipartOption + if tt.giveContentType != "" { + opts = append(opts, WithDefaultContentType(tt.giveContentType)) + } + + require.NoError(t, w.WriteFile(tt.giveField, tt.giveFile, opts...)) + require.NoError(t, w.Close()) + + reader := multipart.NewReader(w.Buffer(), w.writer.Boundary()) + form, err := reader.ReadForm(maxFormMemory) + require.NoError(t, err) + defer func() { + require.NoError(t, form.RemoveAll()) + }() + + files := form.File[tt.giveField] + require.Len(t, files, 1) + + file := files[0] + assert.Equal(t, tt.giveFile.name, file.Filename) + + f, err := file.Open() + require.NoError(t, err) + defer func() { + require.NoError(t, f.Close()) + }() + + content, err := io.ReadAll(f) + require.NoError(t, err) + assert.Equal(t, tt.giveFile.content, string(content)) + + expectedContentType := tt.giveFile.contentType + if expectedContentType == "" { + expectedContentType = tt.giveContentType + } + if expectedContentType != "" { + assert.Equal(t, expectedContentType, file.Header.Get("Content-Type")) + } + }) + } + }) + + t.Run("write JSON", func(t *testing.T) { + type testStruct struct { + Name string `json:"name"` + Value int `json:"value"` + } + + tests := []struct { + desc string + giveField string + giveValue interface{} + }{ + { + desc: "struct", + giveField: "data", + giveValue: testStruct{Name: "test", Value: 123}, + }, + { + desc: "map", + giveField: "data", + giveValue: map[string]string{"key": "value"}, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + w := NewMultipartWriter() + + require.NoError(t, w.WriteJSON(tt.giveField, tt.giveValue)) + require.NoError(t, w.Close()) + + reader := multipart.NewReader(w.Buffer(), w.writer.Boundary()) + form, err := reader.ReadForm(maxFormMemory) + require.NoError(t, err) + defer func() { + require.NoError(t, form.RemoveAll()) + }() + + expected, err := json.Marshal(tt.giveValue) + require.NoError(t, err) + assert.Equal(t, []string{string(expected)}, form.Value[tt.giveField]) + }) + } + }) + + t.Run("complex", func(t *testing.T) { + w := NewMultipartWriter() + + // Add multiple fields and files + require.NoError(t, w.WriteField("foo", "bar")) + require.NoError(t, w.WriteField("baz", "qux")) + + hello := mockFile{name: "file.txt", content: "Hello, world!", contentType: "text/plain"} + require.NoError(t, w.WriteFile("file", &hello)) + require.NoError(t, w.WriteJSON("data", map[string]string{"key": "value"})) + require.NoError(t, w.Close()) + + reader := multipart.NewReader(w.Buffer(), w.writer.Boundary()) + form, err := reader.ReadForm(maxFormMemory) + require.NoError(t, err) + defer func() { + require.NoError(t, form.RemoveAll()) + }() + + assert.Equal(t, []string{"bar"}, form.Value["foo"]) + assert.Equal(t, []string{"qux"}, form.Value["baz"]) + assert.Equal(t, []string{`{"key":"value"}`}, form.Value["data"]) + + files := form.File["file"] + require.Len(t, files, 1) + + file := files[0] + assert.Equal(t, "file.txt", file.Filename) + + f, err := file.Open() + require.NoError(t, err) + defer func() { + require.NoError(t, f.Close()) + }() + + content, err := io.ReadAll(f) + require.NoError(t, err) + assert.Equal(t, "Hello, world!", string(content)) + }) +} diff --git a/core/query.go b/internal/query.go similarity index 99% rename from core/query.go rename to internal/query.go index 2670ff7..6129e71 100644 --- a/core/query.go +++ b/internal/query.go @@ -1,4 +1,4 @@ -package core +package internal import ( "encoding/base64" diff --git a/core/query_test.go b/internal/query_test.go similarity index 99% rename from core/query_test.go rename to internal/query_test.go index 5498fa9..2e58cca 100644 --- a/core/query_test.go +++ b/internal/query_test.go @@ -1,4 +1,4 @@ -package core +package internal import ( "testing" diff --git a/core/retrier.go b/internal/retrier.go similarity index 98% rename from core/retrier.go rename to internal/retrier.go index ea24916..6040147 100644 --- a/core/retrier.go +++ b/internal/retrier.go @@ -1,4 +1,4 @@ -package core +package internal import ( "crypto/rand" @@ -130,7 +130,6 @@ func (r *Retrier) run( func (r *Retrier) shouldRetry(response *http.Response) bool { return response.StatusCode == http.StatusTooManyRequests || response.StatusCode == http.StatusRequestTimeout || - response.StatusCode == http.StatusConflict || response.StatusCode >= http.StatusInternalServerError } diff --git a/internal/retrier_test.go b/internal/retrier_test.go new file mode 100644 index 0000000..ac0d623 --- /dev/null +++ b/internal/retrier_test.go @@ -0,0 +1,211 @@ +package internal + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/cohere-ai/cohere-go/v2/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type RetryTestCase struct { + description string + + giveAttempts uint + giveStatusCodes []int + giveResponse *Response + + wantResponse *Response + wantError *core.APIError +} + +func TestRetrier(t *testing.T) { + tests := []*RetryTestCase{ + { + description: "retry request succeeds after multiple failures", + giveAttempts: 3, + giveStatusCodes: []int{ + http.StatusServiceUnavailable, + http.StatusServiceUnavailable, + http.StatusOK, + }, + giveResponse: &Response{ + Id: "1", + }, + wantResponse: &Response{ + Id: "1", + }, + }, + { + description: "retry request fails if MaxAttempts is exceeded", + giveAttempts: 3, + giveStatusCodes: []int{ + http.StatusRequestTimeout, + http.StatusRequestTimeout, + http.StatusRequestTimeout, + http.StatusOK, + }, + wantError: &core.APIError{ + StatusCode: http.StatusRequestTimeout, + }, + }, + { + description: "retry durations increase exponentially and stay within the min and max delay values", + giveAttempts: 4, + giveStatusCodes: []int{ + http.StatusServiceUnavailable, + http.StatusServiceUnavailable, + http.StatusServiceUnavailable, + http.StatusOK, + }, + }, + { + description: "retry does not occur on status code 404", + giveAttempts: 2, + giveStatusCodes: []int{http.StatusNotFound, http.StatusOK}, + wantError: &core.APIError{ + StatusCode: http.StatusNotFound, + }, + }, + { + description: "retries occur on status code 429", + giveAttempts: 2, + giveStatusCodes: []int{http.StatusTooManyRequests, http.StatusOK}, + }, + { + description: "retries occur on status code 408", + giveAttempts: 2, + giveStatusCodes: []int{http.StatusRequestTimeout, http.StatusOK}, + }, + { + description: "retries occur on status code 500", + giveAttempts: 2, + giveStatusCodes: []int{http.StatusInternalServerError, http.StatusOK}, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + var ( + test = tc + server = newTestRetryServer(t, test) + client = server.Client() + ) + + t.Parallel() + + caller := NewCaller( + &CallerParams{ + Client: client, + }, + ) + + var response *Response + err := caller.Call( + context.Background(), + &CallParams{ + URL: server.URL, + Method: http.MethodGet, + Request: &Request{}, + Response: &response, + MaxAttempts: test.giveAttempts, + ResponseIsOptional: true, + }, + ) + + if test.wantError != nil { + require.IsType(t, err, &core.APIError{}) + expectedErrorCode := test.wantError.StatusCode + actualErrorCode := err.(*core.APIError).StatusCode + assert.Equal(t, expectedErrorCode, actualErrorCode) + return + } + + require.NoError(t, err) + assert.Equal(t, test.wantResponse, response) + }) + } +} + +// newTestRetryServer returns a new *httptest.Server configured with the +// given test parameters, suitable for testing retries. +func newTestRetryServer(t *testing.T, tc *RetryTestCase) *httptest.Server { + var index int + timestamps := make([]time.Time, 0, len(tc.giveStatusCodes)) + + return httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + timestamps = append(timestamps, time.Now()) + if index > 0 && index < len(expectedRetryDurations) { + // Ensure that the duration between retries increases exponentially, + // and that it is within the minimum and maximum retry delay values. + actualDuration := timestamps[index].Sub(timestamps[index-1]) + expectedDurationMin := expectedRetryDurations[index-1] * 75 / 100 + expectedDurationMax := expectedRetryDurations[index-1] * 125 / 100 + assert.True( + t, + actualDuration >= expectedDurationMin && actualDuration <= expectedDurationMax, + "expected duration to be in range [%v, %v], got %v", + expectedDurationMin, + expectedDurationMax, + actualDuration, + ) + assert.LessOrEqual( + t, + actualDuration, + maxRetryDelay, + "expected duration to be less than the maxRetryDelay (%v), got %v", + maxRetryDelay, + actualDuration, + ) + assert.GreaterOrEqual( + t, + actualDuration, + minRetryDelay, + "expected duration to be greater than the minRetryDelay (%v), got %v", + minRetryDelay, + actualDuration, + ) + } + + request := new(Request) + bytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(bytes, request)) + require.LessOrEqual(t, index, len(tc.giveStatusCodes)) + + statusCode := tc.giveStatusCodes[index] + w.WriteHeader(statusCode) + + if tc.giveResponse != nil && statusCode == http.StatusOK { + bytes, err = json.Marshal(tc.giveResponse) + require.NoError(t, err) + _, err = w.Write(bytes) + require.NoError(t, err) + } + + index++ + }, + ), + ) +} + +// expectedRetryDurations holds an array of calculated retry durations, +// where the index of the array should correspond to the retry attempt. +// +// Values are calculated based off of `minRetryDelay + minRetryDelay*i*i`, with +// a max and min value of 5000ms and 500ms respectively. +var expectedRetryDurations = []time.Duration{ + 500 * time.Millisecond, + 1000 * time.Millisecond, + 2500 * time.Millisecond, + 5000 * time.Millisecond, + 5000 * time.Millisecond, +} diff --git a/internal/streamer.go b/internal/streamer.go new file mode 100644 index 0000000..0d8047f --- /dev/null +++ b/internal/streamer.go @@ -0,0 +1,114 @@ +package internal + +import ( + "context" + "net/http" + "net/url" + + "github.com/cohere-ai/cohere-go/v2/core" +) + +const ( + // DefaultDataPrefix is the default prefix used for SSE streaming. + DefaultSSEDataPrefix = "data: " + + // DefaultTerminator is the default terminator used for SSE streaming. + DefaultSSETerminator = "[DONE]" +) + +// Streamer calls APIs and streams responses using a *Stream. +type Streamer[T any] struct { + client HTTPClient + retrier *Retrier +} + +// NewStreamer returns a new *Streamer backed by the given caller's HTTP client. +func NewStreamer[T any](caller *Caller) *Streamer[T] { + return &Streamer[T]{ + client: caller.client, + retrier: caller.retrier, + } +} + +// StreamParams represents the parameters used to issue an API streaming call. +type StreamParams struct { + URL string + Method string + Prefix string + Delimiter string + Terminator string + MaxAttempts uint + Headers http.Header + BodyProperties map[string]interface{} + QueryParameters url.Values + Client HTTPClient + Request interface{} + ErrorDecoder ErrorDecoder +} + +// Stream issues an API streaming call according to the given stream parameters. +func (s *Streamer[T]) Stream(ctx context.Context, params *StreamParams) (*core.Stream[T], error) { + url := buildURL(params.URL, params.QueryParameters) + req, err := newRequest( + ctx, + url, + params.Method, + params.Headers, + params.Request, + params.BodyProperties, + ) + if err != nil { + return nil, err + } + + // If the call has been cancelled, don't issue the request. + if err := ctx.Err(); err != nil { + return nil, err + } + + client := s.client + if params.Client != nil { + // Use the HTTP client scoped to the request. + client = params.Client + } + + var retryOptions []RetryOption + if params.MaxAttempts > 0 { + retryOptions = append(retryOptions, WithMaxAttempts(params.MaxAttempts)) + } + + resp, err := s.retrier.Run( + client.Do, + req, + params.ErrorDecoder, + retryOptions..., + ) + if err != nil { + return nil, err + } + + // Check if the call was cancelled before we return the error + // associated with the call and/or unmarshal the response data. + if err := ctx.Err(); err != nil { + defer resp.Body.Close() + return nil, err + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer resp.Body.Close() + return nil, decodeError(resp, params.ErrorDecoder) + } + + var opts []core.StreamOption + if params.Delimiter != "" { + opts = append(opts, core.WithDelimiter(params.Delimiter)) + } + if params.Prefix != "" { + opts = append(opts, core.WithPrefix(params.Prefix)) + } + if params.Terminator != "" { + opts = append(opts, core.WithTerminator(params.Terminator)) + } + + return core.NewStream[T](resp, opts...), nil +} diff --git a/core/stringer.go b/internal/stringer.go similarity index 94% rename from core/stringer.go rename to internal/stringer.go index 000cf44..3128018 100644 --- a/core/stringer.go +++ b/internal/stringer.go @@ -1,4 +1,4 @@ -package core +package internal import "encoding/json" diff --git a/core/time.go b/internal/time.go similarity index 99% rename from core/time.go rename to internal/time.go index d009ab3..ab0e269 100644 --- a/core/time.go +++ b/internal/time.go @@ -1,4 +1,4 @@ -package core +package internal import ( "encoding/json" diff --git a/models.go b/models.go index 18c539f..e98ba66 100644 --- a/models.go +++ b/models.go @@ -2,6 +2,12 @@ package api +import ( + json "encoding/json" + fmt "fmt" + internal "github.com/cohere-ai/cohere-go/v2/internal" +) + type ModelsListRequest struct { // Maximum number of models to include in a page // Defaults to `20`, min value of `1`, max value of `1000`. @@ -13,3 +19,189 @@ type ModelsListRequest struct { // When provided, filters the list of models to only the default model to the endpoint. This parameter is only valid when `endpoint` is provided. DefaultOnly *bool `json:"-" url:"default_only,omitempty"` } + +// One of the Cohere API endpoints that the model can be used with. +type CompatibleEndpoint string + +const ( + CompatibleEndpointChat CompatibleEndpoint = "chat" + CompatibleEndpointEmbed CompatibleEndpoint = "embed" + CompatibleEndpointClassify CompatibleEndpoint = "classify" + CompatibleEndpointSummarize CompatibleEndpoint = "summarize" + CompatibleEndpointRerank CompatibleEndpoint = "rerank" + CompatibleEndpointRate CompatibleEndpoint = "rate" + CompatibleEndpointGenerate CompatibleEndpoint = "generate" +) + +func NewCompatibleEndpointFromString(s string) (CompatibleEndpoint, error) { + switch s { + case "chat": + return CompatibleEndpointChat, nil + case "embed": + return CompatibleEndpointEmbed, nil + case "classify": + return CompatibleEndpointClassify, nil + case "summarize": + return CompatibleEndpointSummarize, nil + case "rerank": + return CompatibleEndpointRerank, nil + case "rate": + return CompatibleEndpointRate, nil + case "generate": + return CompatibleEndpointGenerate, nil + } + var t CompatibleEndpoint + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (c CompatibleEndpoint) Ptr() *CompatibleEndpoint { + return &c +} + +// Contains information about the model and which API endpoints it can be used with. +type GetModelResponse struct { + // Specify this name in the `model` parameter of API requests to use your chosen model. + Name *string `json:"name,omitempty" url:"name,omitempty"` + // The API endpoints that the model is compatible with. + Endpoints []CompatibleEndpoint `json:"endpoints,omitempty" url:"endpoints,omitempty"` + // Whether the model has been fine-tuned or not. + Finetuned *bool `json:"finetuned,omitempty" url:"finetuned,omitempty"` + // The maximum number of tokens that the model can process in a single request. Note that not all of these tokens are always available due to special tokens and preambles that Cohere has added by default. + ContextLength *float64 `json:"context_length,omitempty" url:"context_length,omitempty"` + // Public URL to the tokenizer's configuration file. + TokenizerUrl *string `json:"tokenizer_url,omitempty" url:"tokenizer_url,omitempty"` + // The API endpoints that the model is default to. + DefaultEndpoints []CompatibleEndpoint `json:"default_endpoints,omitempty" url:"default_endpoints,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (g *GetModelResponse) GetName() *string { + if g == nil { + return nil + } + return g.Name +} + +func (g *GetModelResponse) GetEndpoints() []CompatibleEndpoint { + if g == nil { + return nil + } + return g.Endpoints +} + +func (g *GetModelResponse) GetFinetuned() *bool { + if g == nil { + return nil + } + return g.Finetuned +} + +func (g *GetModelResponse) GetContextLength() *float64 { + if g == nil { + return nil + } + return g.ContextLength +} + +func (g *GetModelResponse) GetTokenizerUrl() *string { + if g == nil { + return nil + } + return g.TokenizerUrl +} + +func (g *GetModelResponse) GetDefaultEndpoints() []CompatibleEndpoint { + if g == nil { + return nil + } + return g.DefaultEndpoints +} + +func (g *GetModelResponse) GetExtraProperties() map[string]interface{} { + return g.extraProperties +} + +func (g *GetModelResponse) UnmarshalJSON(data []byte) error { + type unmarshaler GetModelResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *g = GetModelResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) + if err != nil { + return err + } + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) + return nil +} + +func (g *GetModelResponse) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(g); err == nil { + return value + } + return fmt.Sprintf("%#v", g) +} + +type ListModelsResponse struct { + Models []*GetModelResponse `json:"models,omitempty" url:"models,omitempty"` + // A token to retrieve the next page of results. Provide in the page_token parameter of the next request. + NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (l *ListModelsResponse) GetModels() []*GetModelResponse { + if l == nil { + return nil + } + return l.Models +} + +func (l *ListModelsResponse) GetNextPageToken() *string { + if l == nil { + return nil + } + return l.NextPageToken +} + +func (l *ListModelsResponse) GetExtraProperties() map[string]interface{} { + return l.extraProperties +} + +func (l *ListModelsResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ListModelsResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = ListModelsResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *l) + if err != nil { + return err + } + l.extraProperties = extraProperties + l.rawJSON = json.RawMessage(data) + return nil +} + +func (l *ListModelsResponse) String() string { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} diff --git a/models/client.go b/models/client.go index 7932f9f..9053186 100644 --- a/models/client.go +++ b/models/client.go @@ -3,21 +3,18 @@ package models import ( - bytes "bytes" context "context" - json "encoding/json" - errors "errors" v2 "github.com/cohere-ai/cohere-go/v2" core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" option "github.com/cohere-ai/cohere-go/v2/option" - io "io" http "net/http" os "os" ) type Client struct { baseURL string - caller *core.Caller + caller *internal.Caller header http.Header } @@ -28,8 +25,8 @@ func NewClient(opts ...option.RequestOption) *Client { } return &Client{ baseURL: options.BaseURL, - caller: core.NewCaller( - &core.CallerParams{ + caller: internal.NewCaller( + &internal.CallerParams{ Client: options.HTTPClient, MaxAttempts: options.MaxAttempts, }, @@ -45,120 +42,95 @@ func (c *Client) Get( opts ...option.RequestOption, ) (*v2.GetModelResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } - endpointURL := core.EncodeURL(baseURL+"/v1/models/%v", model) - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) + endpointURL := internal.EncodeURL( + baseURL+"/v1/models/%v", + model, + ) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, } - return value - } - return apiError + }, } var response *v2.GetModelResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -173,128 +145,99 @@ func (c *Client) List( opts ...option.RequestOption, ) (*v2.ListModelsResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v1/models" - - queryParams, err := core.QueryValues(request) + queryParams, err := internal.QueryValues(request) if err != nil { return nil, err } if len(queryParams) > 0 { endpointURL += "?" + queryParams.Encode() } - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, } - return value - } - return apiError + }, } var response *v2.ListModelsResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodGet, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err diff --git a/types.go b/types.go index ee700b1..b4a0c5b 100644 --- a/types.go +++ b/types.go @@ -5,8 +5,7 @@ package api import ( json "encoding/json" fmt "fmt" - core "github.com/cohere-ai/cohere-go/v2/core" - time "time" + internal "github.com/cohere-ai/cohere-go/v2/internal" ) type ChatRequest struct { @@ -787,7 +786,35 @@ type ApiMeta struct { Warnings []string `json:"warnings,omitempty" url:"warnings,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (a *ApiMeta) GetApiVersion() *ApiMetaApiVersion { + if a == nil { + return nil + } + return a.ApiVersion +} + +func (a *ApiMeta) GetBilledUnits() *ApiMetaBilledUnits { + if a == nil { + return nil + } + return a.BilledUnits +} + +func (a *ApiMeta) GetTokens() *ApiMetaTokens { + if a == nil { + return nil + } + return a.Tokens +} + +func (a *ApiMeta) GetWarnings() []string { + if a == nil { + return nil + } + return a.Warnings } func (a *ApiMeta) GetExtraProperties() map[string]interface{} { @@ -801,24 +828,22 @@ func (a *ApiMeta) UnmarshalJSON(data []byte) error { return err } *a = ApiMeta(value) - - extraProperties, err := core.ExtractExtraProperties(data, *a) + extraProperties, err := internal.ExtractExtraProperties(data, *a) if err != nil { return err } a.extraProperties = extraProperties - - a._rawJSON = json.RawMessage(data) + a.rawJSON = json.RawMessage(data) return nil } func (a *ApiMeta) String() string { - if len(a._rawJSON) > 0 { - if value, err := core.StringifyJSON(a._rawJSON); err == nil { + if len(a.rawJSON) > 0 { + if value, err := internal.StringifyJSON(a.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(a); err == nil { + if value, err := internal.StringifyJSON(a); err == nil { return value } return fmt.Sprintf("%#v", a) @@ -830,7 +855,28 @@ type ApiMetaApiVersion struct { IsExperimental *bool `json:"is_experimental,omitempty" url:"is_experimental,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (a *ApiMetaApiVersion) GetVersion() string { + if a == nil { + return "" + } + return a.Version +} + +func (a *ApiMetaApiVersion) GetIsDeprecated() *bool { + if a == nil { + return nil + } + return a.IsDeprecated +} + +func (a *ApiMetaApiVersion) GetIsExperimental() *bool { + if a == nil { + return nil + } + return a.IsExperimental } func (a *ApiMetaApiVersion) GetExtraProperties() map[string]interface{} { @@ -844,24 +890,22 @@ func (a *ApiMetaApiVersion) UnmarshalJSON(data []byte) error { return err } *a = ApiMetaApiVersion(value) - - extraProperties, err := core.ExtractExtraProperties(data, *a) + extraProperties, err := internal.ExtractExtraProperties(data, *a) if err != nil { return err } a.extraProperties = extraProperties - - a._rawJSON = json.RawMessage(data) + a.rawJSON = json.RawMessage(data) return nil } func (a *ApiMetaApiVersion) String() string { - if len(a._rawJSON) > 0 { - if value, err := core.StringifyJSON(a._rawJSON); err == nil { + if len(a.rawJSON) > 0 { + if value, err := internal.StringifyJSON(a.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(a); err == nil { + if value, err := internal.StringifyJSON(a); err == nil { return value } return fmt.Sprintf("%#v", a) @@ -880,7 +924,42 @@ type ApiMetaBilledUnits struct { Classifications *float64 `json:"classifications,omitempty" url:"classifications,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (a *ApiMetaBilledUnits) GetImages() *float64 { + if a == nil { + return nil + } + return a.Images +} + +func (a *ApiMetaBilledUnits) GetInputTokens() *float64 { + if a == nil { + return nil + } + return a.InputTokens +} + +func (a *ApiMetaBilledUnits) GetOutputTokens() *float64 { + if a == nil { + return nil + } + return a.OutputTokens +} + +func (a *ApiMetaBilledUnits) GetSearchUnits() *float64 { + if a == nil { + return nil + } + return a.SearchUnits +} + +func (a *ApiMetaBilledUnits) GetClassifications() *float64 { + if a == nil { + return nil + } + return a.Classifications } func (a *ApiMetaBilledUnits) GetExtraProperties() map[string]interface{} { @@ -894,24 +973,22 @@ func (a *ApiMetaBilledUnits) UnmarshalJSON(data []byte) error { return err } *a = ApiMetaBilledUnits(value) - - extraProperties, err := core.ExtractExtraProperties(data, *a) + extraProperties, err := internal.ExtractExtraProperties(data, *a) if err != nil { return err } a.extraProperties = extraProperties - - a._rawJSON = json.RawMessage(data) + a.rawJSON = json.RawMessage(data) return nil } func (a *ApiMetaBilledUnits) String() string { - if len(a._rawJSON) > 0 { - if value, err := core.StringifyJSON(a._rawJSON); err == nil { + if len(a.rawJSON) > 0 { + if value, err := internal.StringifyJSON(a.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(a); err == nil { + if value, err := internal.StringifyJSON(a); err == nil { return value } return fmt.Sprintf("%#v", a) @@ -924,334 +1001,96 @@ type ApiMetaTokens struct { OutputTokens *float64 `json:"output_tokens,omitempty" url:"output_tokens,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (a *ApiMetaTokens) GetExtraProperties() map[string]interface{} { - return a.extraProperties + rawJSON json.RawMessage } -func (a *ApiMetaTokens) UnmarshalJSON(data []byte) error { - type unmarshaler ApiMetaTokens - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *a = ApiMetaTokens(value) - - extraProperties, err := core.ExtractExtraProperties(data, *a) - if err != nil { - return err +func (a *ApiMetaTokens) GetInputTokens() *float64 { + if a == nil { + return nil } - a.extraProperties = extraProperties - - a._rawJSON = json.RawMessage(data) - return nil + return a.InputTokens } -func (a *ApiMetaTokens) String() string { - if len(a._rawJSON) > 0 { - if value, err := core.StringifyJSON(a._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(a); err == nil { - return value +func (a *ApiMetaTokens) GetOutputTokens() *float64 { + if a == nil { + return nil } - return fmt.Sprintf("%#v", a) -} - -// A message from the assistant role can contain text and tool call information. -type AssistantMessage struct { - ToolCalls []*ToolCallV2 `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` - // A chain-of-thought style reflection and plan that the model generates when working with Tools. - ToolPlan *string `json:"tool_plan,omitempty" url:"tool_plan,omitempty"` - Content *AssistantMessageContent `json:"content,omitempty" url:"content,omitempty"` - Citations []*Citation `json:"citations,omitempty" url:"citations,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + return a.OutputTokens } -func (a *AssistantMessage) GetExtraProperties() map[string]interface{} { +func (a *ApiMetaTokens) GetExtraProperties() map[string]interface{} { return a.extraProperties } -func (a *AssistantMessage) UnmarshalJSON(data []byte) error { - type unmarshaler AssistantMessage +func (a *ApiMetaTokens) UnmarshalJSON(data []byte) error { + type unmarshaler ApiMetaTokens var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *a = AssistantMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *a) + *a = ApiMetaTokens(value) + extraProperties, err := internal.ExtractExtraProperties(data, *a) if err != nil { return err } a.extraProperties = extraProperties - - a._rawJSON = json.RawMessage(data) + a.rawJSON = json.RawMessage(data) return nil } -func (a *AssistantMessage) String() string { - if len(a._rawJSON) > 0 { - if value, err := core.StringifyJSON(a._rawJSON); err == nil { +func (a *ApiMetaTokens) String() string { + if len(a.rawJSON) > 0 { + if value, err := internal.StringifyJSON(a.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(a); err == nil { + if value, err := internal.StringifyJSON(a); err == nil { return value } return fmt.Sprintf("%#v", a) } -type AssistantMessageContent struct { - String string - AssistantMessageContentItemList []*AssistantMessageContentItem -} - -func (a *AssistantMessageContent) UnmarshalJSON(data []byte) error { - var valueString string - if err := json.Unmarshal(data, &valueString); err == nil { - a.String = valueString - return nil - } - var valueAssistantMessageContentItemList []*AssistantMessageContentItem - if err := json.Unmarshal(data, &valueAssistantMessageContentItemList); err == nil { - a.AssistantMessageContentItemList = valueAssistantMessageContentItemList - return nil - } - return fmt.Errorf("%s cannot be deserialized as a %T", data, a) -} - -func (a AssistantMessageContent) MarshalJSON() ([]byte, error) { - if a.String != "" { - return json.Marshal(a.String) - } - if a.AssistantMessageContentItemList != nil { - return json.Marshal(a.AssistantMessageContentItemList) - } - return nil, fmt.Errorf("type %T does not include a non-empty union type", a) -} - -type AssistantMessageContentVisitor interface { - VisitString(string) error - VisitAssistantMessageContentItemList([]*AssistantMessageContentItem) error -} - -func (a *AssistantMessageContent) Accept(visitor AssistantMessageContentVisitor) error { - if a.String != "" { - return visitor.VisitString(a.String) - } - if a.AssistantMessageContentItemList != nil { - return visitor.VisitAssistantMessageContentItemList(a.AssistantMessageContentItemList) - } - return fmt.Errorf("type %T does not include a non-empty union type", a) -} - -type AssistantMessageContentItem struct { - Type string - Text *TextContent -} - -func (a *AssistantMessageContentItem) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - a.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", a) - } - switch unmarshaler.Type { - case "text": - value := new(TextContent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - a.Text = value - } - return nil -} - -func (a AssistantMessageContentItem) MarshalJSON() ([]byte, error) { - if a.Text != nil { - return core.MarshalJSONWithExtraProperty(a.Text, "type", "text") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", a) -} - -type AssistantMessageContentItemVisitor interface { - VisitText(*TextContent) error -} - -func (a *AssistantMessageContentItem) Accept(visitor AssistantMessageContentItemVisitor) error { - if a.Text != nil { - return visitor.VisitText(a.Text) - } - return fmt.Errorf("type %T does not define a non-empty union type", a) -} - -// A message from the assistant role can contain text and tool call information. -type AssistantMessageResponse struct { - ToolCalls []*ToolCallV2 `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` - // A chain-of-thought style reflection and plan that the model generates when working with Tools. - ToolPlan *string `json:"tool_plan,omitempty" url:"tool_plan,omitempty"` - Content []*AssistantMessageResponseContentItem `json:"content,omitempty" url:"content,omitempty"` - Citations []*Citation `json:"citations,omitempty" url:"citations,omitempty"` - role string +// A section of the generated reply which cites external knowledge. +type ChatCitation struct { + // The index of text that the citation starts at, counting from zero. For example, a generation of `Hello, world!` with a citation on `world` would have a start value of `7`. This is because the citation starts at `w`, which is the seventh character. + Start int `json:"start" url:"start"` + // The index of text that the citation ends after, counting from zero. For example, a generation of `Hello, world!` with a citation on `world` would have an end value of `11`. This is because the citation ends after `d`, which is the eleventh character. + End int `json:"end" url:"end"` + // The text of the citation. For example, a generation of `Hello, world!` with a citation of `world` would have a text value of `world`. + Text string `json:"text" url:"text"` + // Identifiers of documents cited by this section of the generated reply. + DocumentIds []string `json:"document_ids,omitempty" url:"document_ids,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (a *AssistantMessageResponse) GetExtraProperties() map[string]interface{} { - return a.extraProperties -} - -func (a *AssistantMessageResponse) Role() string { - return a.role -} - -func (a *AssistantMessageResponse) UnmarshalJSON(data []byte) error { - type embed AssistantMessageResponse - var unmarshaler = struct { - embed - Role string `json:"role"` - }{ - embed: embed(*a), - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - *a = AssistantMessageResponse(unmarshaler.embed) - if unmarshaler.Role != "assistant" { - return fmt.Errorf("unexpected value for literal on type %T; expected %v got %v", a, "assistant", unmarshaler.Role) - } - a.role = unmarshaler.Role - - extraProperties, err := core.ExtractExtraProperties(data, *a, "role") - if err != nil { - return err - } - a.extraProperties = extraProperties - - a._rawJSON = json.RawMessage(data) - return nil -} - -func (a *AssistantMessageResponse) MarshalJSON() ([]byte, error) { - type embed AssistantMessageResponse - var marshaler = struct { - embed - Role string `json:"role"` - }{ - embed: embed(*a), - Role: "assistant", - } - return json.Marshal(marshaler) + rawJSON json.RawMessage } -func (a *AssistantMessageResponse) String() string { - if len(a._rawJSON) > 0 { - if value, err := core.StringifyJSON(a._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(a); err == nil { - return value +func (c *ChatCitation) GetStart() int { + if c == nil { + return 0 } - return fmt.Sprintf("%#v", a) -} - -type AssistantMessageResponseContentItem struct { - Type string - Text *TextContent + return c.Start } -func (a *AssistantMessageResponseContentItem) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - a.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", a) - } - switch unmarshaler.Type { - case "text": - value := new(TextContent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - a.Text = value +func (c *ChatCitation) GetEnd() int { + if c == nil { + return 0 } - return nil + return c.End } -func (a AssistantMessageResponseContentItem) MarshalJSON() ([]byte, error) { - if a.Text != nil { - return core.MarshalJSONWithExtraProperty(a.Text, "type", "text") +func (c *ChatCitation) GetText() string { + if c == nil { + return "" } - return nil, fmt.Errorf("type %T does not define a non-empty union type", a) -} - -type AssistantMessageResponseContentItemVisitor interface { - VisitText(*TextContent) error + return c.Text } -func (a *AssistantMessageResponseContentItem) Accept(visitor AssistantMessageResponseContentItemVisitor) error { - if a.Text != nil { - return visitor.VisitText(a.Text) +func (c *ChatCitation) GetDocumentIds() []string { + if c == nil { + return nil } - return fmt.Errorf("type %T does not define a non-empty union type", a) -} - -// The token_type specifies the way the token is passed in the Authorization header. Valid values are "bearer", "basic", and "noscheme". -type AuthTokenType string - -const ( - AuthTokenTypeBearer AuthTokenType = "bearer" - AuthTokenTypeBasic AuthTokenType = "basic" - AuthTokenTypeNoscheme AuthTokenType = "noscheme" -) - -func NewAuthTokenTypeFromString(s string) (AuthTokenType, error) { - switch s { - case "bearer": - return AuthTokenTypeBearer, nil - case "basic": - return AuthTokenTypeBasic, nil - case "noscheme": - return AuthTokenTypeNoscheme, nil - } - var t AuthTokenType - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (a AuthTokenType) Ptr() *AuthTokenType { - return &a -} - -// A section of the generated reply which cites external knowledge. -type ChatCitation struct { - // The index of text that the citation starts at, counting from zero. For example, a generation of `Hello, world!` with a citation on `world` would have a start value of `7`. This is because the citation starts at `w`, which is the seventh character. - Start int `json:"start" url:"start"` - // The index of text that the citation ends after, counting from zero. For example, a generation of `Hello, world!` with a citation on `world` would have an end value of `11`. This is because the citation ends after `d`, which is the eleventh character. - End int `json:"end" url:"end"` - // The text of the citation. For example, a generation of `Hello, world!` with a citation of `world` would have a text value of `world`. - Text string `json:"text" url:"text"` - // Identifiers of documents cited by this section of the generated reply. - DocumentIds []string `json:"document_ids,omitempty" url:"document_ids,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + return c.DocumentIds } func (c *ChatCitation) GetExtraProperties() map[string]interface{} { @@ -1265,24 +1104,22 @@ func (c *ChatCitation) UnmarshalJSON(data []byte) error { return err } *c = ChatCitation(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } func (c *ChatCitation) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) @@ -1293,7 +1130,14 @@ type ChatCitationGenerationEvent struct { Citations []*ChatCitation `json:"citations,omitempty" url:"citations,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatCitationGenerationEvent) GetCitations() []*ChatCitation { + if c == nil { + return nil + } + return c.Citations } func (c *ChatCitationGenerationEvent) GetExtraProperties() map[string]interface{} { @@ -1307,24 +1151,22 @@ func (c *ChatCitationGenerationEvent) UnmarshalJSON(data []byte) error { return err } *c = ChatCitationGenerationEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } func (c *ChatCitationGenerationEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) @@ -1346,7 +1188,35 @@ type ChatConnector struct { Options map[string]interface{} `json:"options,omitempty" url:"options,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatConnector) GetId() string { + if c == nil { + return "" + } + return c.Id +} + +func (c *ChatConnector) GetUserAccessToken() *string { + if c == nil { + return nil + } + return c.UserAccessToken +} + +func (c *ChatConnector) GetContinueOnFailure() *bool { + if c == nil { + return nil + } + return c.ContinueOnFailure +} + +func (c *ChatConnector) GetOptions() map[string]interface{} { + if c == nil { + return nil + } + return c.Options } func (c *ChatConnector) GetExtraProperties() map[string]interface{} { @@ -1360,8125 +1230,4563 @@ func (c *ChatConnector) UnmarshalJSON(data []byte) error { return err } *c = ChatConnector(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } func (c *ChatConnector) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// A streamed delta event which contains a delta of chat text content. -type ChatContentDeltaEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` - Delta *ChatContentDeltaEventDelta `json:"delta,omitempty" url:"delta,omitempty"` - Logprobs *LogprobItem `json:"logprobs,omitempty" url:"logprobs,omitempty"` +type ChatDebugEvent struct { + Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatContentDeltaEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (c *ChatDebugEvent) GetPrompt() *string { + if c == nil { + return nil + } + return c.Prompt } -func (c *ChatContentDeltaEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentDeltaEvent - var value unmarshaler +func (c *ChatDebugEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatDebugEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatDebugEvent + var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatContentDeltaEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatDebugEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatContentDeltaEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatDebugEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatContentDeltaEventDelta struct { - Message *ChatContentDeltaEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` +// Relevant information that could be used by the model to generate a more accurate reply. +// The contents of each document are generally short (under 300 words), and are passed in the form of a +// dictionary of strings. Some suggested keys are "text", "author", "date". Both the key name and the value will be +// passed to the model. +type ChatDocument = map[string]string + +// Represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. +// +// The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. +type ChatMessage struct { + // Contents of the chat message. + Message string `json:"message" url:"message"` + ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatContentDeltaEventDelta) GetExtraProperties() map[string]interface{} { +func (c *ChatMessage) GetMessage() string { + if c == nil { + return "" + } + return c.Message +} + +func (c *ChatMessage) GetToolCalls() []*ToolCall { + if c == nil { + return nil + } + return c.ToolCalls +} + +func (c *ChatMessage) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatContentDeltaEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentDeltaEventDelta +func (c *ChatMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ChatMessage var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatContentDeltaEventDelta(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatContentDeltaEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatContentDeltaEventDeltaMessage struct { - Content *ChatContentDeltaEventDeltaMessageContent `json:"content,omitempty" url:"content,omitempty"` +// Defaults to `"accurate"`. +// +// Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. +// +// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments +type ChatRequestCitationQuality string + +const ( + ChatRequestCitationQualityFast ChatRequestCitationQuality = "fast" + ChatRequestCitationQualityAccurate ChatRequestCitationQuality = "accurate" + ChatRequestCitationQualityOff ChatRequestCitationQuality = "off" +) + +func NewChatRequestCitationQualityFromString(s string) (ChatRequestCitationQuality, error) { + switch s { + case "fast": + return ChatRequestCitationQualityFast, nil + case "accurate": + return ChatRequestCitationQualityAccurate, nil + case "off": + return ChatRequestCitationQualityOff, nil + } + var t ChatRequestCitationQuality + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (c ChatRequestCitationQuality) Ptr() *ChatRequestCitationQuality { + return &c +} + +// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of connectors_search_options are ignored (such as model= or temperature=). +type ChatRequestConnectorsSearchOptions struct { + // If specified, the backend will make a best effort to sample tokens + // deterministically, such that repeated requests with the same + // seed and parameters should return the same result. However, + // determinism cannot be totally guaranteed. + // + // Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments + Seed *int `json:"seed,omitempty" url:"seed,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatRequestConnectorsSearchOptions) GetSeed() *int { + if c == nil { + return nil + } + return c.Seed } -func (c *ChatContentDeltaEventDeltaMessage) GetExtraProperties() map[string]interface{} { +func (c *ChatRequestConnectorsSearchOptions) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatContentDeltaEventDeltaMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentDeltaEventDeltaMessage +func (c *ChatRequestConnectorsSearchOptions) UnmarshalJSON(data []byte) error { + type unmarshaler ChatRequestConnectorsSearchOptions var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatContentDeltaEventDeltaMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatRequestConnectorsSearchOptions(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatContentDeltaEventDeltaMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatRequestConnectorsSearchOptions) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatContentDeltaEventDeltaMessageContent struct { - Text *string `json:"text,omitempty" url:"text,omitempty"` +// Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. +// +// Dictates how the prompt will be constructed. +// +// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. +// +// With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. +// +// With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. +// +// Compatible Deployments: +// - AUTO: Cohere Platform Only +// - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments +type ChatRequestPromptTruncation string - extraProperties map[string]interface{} - _rawJSON json.RawMessage +const ( + ChatRequestPromptTruncationOff ChatRequestPromptTruncation = "OFF" + ChatRequestPromptTruncationAuto ChatRequestPromptTruncation = "AUTO" + ChatRequestPromptTruncationAutoPreserveOrder ChatRequestPromptTruncation = "AUTO_PRESERVE_ORDER" +) + +func NewChatRequestPromptTruncationFromString(s string) (ChatRequestPromptTruncation, error) { + switch s { + case "OFF": + return ChatRequestPromptTruncationOff, nil + case "AUTO": + return ChatRequestPromptTruncationAuto, nil + case "AUTO_PRESERVE_ORDER": + return ChatRequestPromptTruncationAutoPreserveOrder, nil + } + var t ChatRequestPromptTruncation + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatContentDeltaEventDeltaMessageContent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (c ChatRequestPromptTruncation) Ptr() *ChatRequestPromptTruncation { + return &c } -func (c *ChatContentDeltaEventDeltaMessageContent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentDeltaEventDeltaMessageContent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatContentDeltaEventDeltaMessageContent(value) +// Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. +// When `NONE` is specified, the safety instruction will be omitted. +// +// Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. +// +// **Note**: This parameter is only compatible with models [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release), [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release) and newer. +// +// **Note**: `command-r7b-12-2024` only supports `"CONTEXTUAL"` and `"STRICT"` modes. +// +// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments +type ChatRequestSafetyMode string - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties +const ( + ChatRequestSafetyModeContextual ChatRequestSafetyMode = "CONTEXTUAL" + ChatRequestSafetyModeStrict ChatRequestSafetyMode = "STRICT" + ChatRequestSafetyModeNone ChatRequestSafetyMode = "NONE" +) - c._rawJSON = json.RawMessage(data) - return nil +func NewChatRequestSafetyModeFromString(s string) (ChatRequestSafetyMode, error) { + switch s { + case "CONTEXTUAL": + return ChatRequestSafetyModeContextual, nil + case "STRICT": + return ChatRequestSafetyModeStrict, nil + case "NONE": + return ChatRequestSafetyModeNone, nil + } + var t ChatRequestSafetyMode + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatContentDeltaEventDeltaMessageContent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) +func (c ChatRequestSafetyMode) Ptr() *ChatRequestSafetyMode { + return &c } -// A streamed delta event which signifies that the content block has ended. -type ChatContentEndEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` +type ChatSearchQueriesGenerationEvent struct { + // Generated search queries, meant to be used as part of the RAG flow. + SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty" url:"search_queries,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatSearchQueriesGenerationEvent) GetSearchQueries() []*ChatSearchQuery { + if c == nil { + return nil + } + return c.SearchQueries } -func (c *ChatContentEndEvent) GetExtraProperties() map[string]interface{} { +func (c *ChatSearchQueriesGenerationEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatContentEndEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentEndEvent +func (c *ChatSearchQueriesGenerationEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatSearchQueriesGenerationEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatContentEndEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatSearchQueriesGenerationEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatContentEndEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatSearchQueriesGenerationEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// A streamed delta event which signifies that a new content block has started. -type ChatContentStartEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` - Delta *ChatContentStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` +// The generated search query. Contains the text of the query and a unique identifier for the query. +type ChatSearchQuery struct { + // The text of the search query. + Text string `json:"text" url:"text"` + // Unique identifier for the generated search query. Useful for submitting feedback. + GenerationId string `json:"generation_id" url:"generation_id"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatSearchQuery) GetText() string { + if c == nil { + return "" + } + return c.Text +} + +func (c *ChatSearchQuery) GetGenerationId() string { + if c == nil { + return "" + } + return c.GenerationId } -func (c *ChatContentStartEvent) GetExtraProperties() map[string]interface{} { +func (c *ChatSearchQuery) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatContentStartEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentStartEvent +func (c *ChatSearchQuery) UnmarshalJSON(data []byte) error { + type unmarshaler ChatSearchQuery var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatContentStartEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatSearchQuery(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatContentStartEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatSearchQuery) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatContentStartEventDelta struct { - Message *ChatContentStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` +type ChatSearchResult struct { + SearchQuery *ChatSearchQuery `json:"search_query,omitempty" url:"search_query,omitempty"` + // The connector from which this result comes from. + Connector *ChatSearchResultConnector `json:"connector,omitempty" url:"connector,omitempty"` + // Identifiers of documents found by this search query. + DocumentIds []string `json:"document_ids,omitempty" url:"document_ids,omitempty"` + // An error message if the search failed. + ErrorMessage *string `json:"error_message,omitempty" url:"error_message,omitempty"` + // Whether a chat request should continue or not if the request to this connector fails. + ContinueOnFailure *bool `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatContentStartEventDelta) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (c *ChatSearchResult) GetSearchQuery() *ChatSearchQuery { + if c == nil { + return nil + } + return c.SearchQuery } -func (c *ChatContentStartEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentStartEventDelta - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (c *ChatSearchResult) GetConnector() *ChatSearchResultConnector { + if c == nil { + return nil } - *c = ChatContentStartEventDelta(value) + return c.Connector +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (c *ChatSearchResult) GetDocumentIds() []string { + if c == nil { + return nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return c.DocumentIds } -func (c *ChatContentStartEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (c *ChatSearchResult) GetErrorMessage() *string { + if c == nil { + return nil } - return fmt.Sprintf("%#v", c) + return c.ErrorMessage } -type ChatContentStartEventDeltaMessage struct { - Content *ChatContentStartEventDeltaMessageContent `json:"content,omitempty" url:"content,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (c *ChatSearchResult) GetContinueOnFailure() *bool { + if c == nil { + return nil + } + return c.ContinueOnFailure } -func (c *ChatContentStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { +func (c *ChatSearchResult) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatContentStartEventDeltaMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentStartEventDeltaMessage +func (c *ChatSearchResult) UnmarshalJSON(data []byte) error { + type unmarshaler ChatSearchResult var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatContentStartEventDeltaMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatSearchResult(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatContentStartEventDeltaMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatSearchResult) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatContentStartEventDeltaMessageContent struct { - Text *string `json:"text,omitempty" url:"text,omitempty"` - Type *string `json:"type,omitempty" url:"type,omitempty"` +// The connector used for fetching documents. +type ChatSearchResultConnector struct { + // The identifier of the connector. + Id string `json:"id" url:"id"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatSearchResultConnector) GetId() string { + if c == nil { + return "" + } + return c.Id } -func (c *ChatContentStartEventDeltaMessageContent) GetExtraProperties() map[string]interface{} { +func (c *ChatSearchResultConnector) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatContentStartEventDeltaMessageContent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatContentStartEventDeltaMessageContent +func (c *ChatSearchResultConnector) UnmarshalJSON(data []byte) error { + type unmarshaler ChatSearchResultConnector var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatContentStartEventDeltaMessageContent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatSearchResultConnector(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatContentStartEventDeltaMessageContent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatSearchResultConnector) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatDataMetrics struct { - // The sum of all turns of valid train examples. - NumTrainTurns *int64 `json:"num_train_turns,omitempty" url:"num_train_turns,omitempty"` - // The sum of all turns of valid eval examples. - NumEvalTurns *int64 `json:"num_eval_turns,omitempty" url:"num_eval_turns,omitempty"` - // The preamble of this dataset. - Preamble *string `json:"preamble,omitempty" url:"preamble,omitempty"` - +type ChatSearchResultsEvent struct { + // Conducted searches and the ids of documents retrieved from each of them. + SearchResults []*ChatSearchResult `json:"search_results,omitempty" url:"search_results,omitempty"` + // Documents fetched from searches or provided by the user. + Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` + extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatSearchResultsEvent) GetSearchResults() []*ChatSearchResult { + if c == nil { + return nil + } + return c.SearchResults +} + +func (c *ChatSearchResultsEvent) GetDocuments() []ChatDocument { + if c == nil { + return nil + } + return c.Documents } -func (c *ChatDataMetrics) GetExtraProperties() map[string]interface{} { +func (c *ChatSearchResultsEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatDataMetrics) UnmarshalJSON(data []byte) error { - type unmarshaler ChatDataMetrics +func (c *ChatSearchResultsEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatSearchResultsEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatDataMetrics(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatSearchResultsEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatDataMetrics) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatSearchResultsEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatDebugEvent struct { - Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` +type ChatStreamEndEvent struct { + // - `COMPLETE` - the model sent back a finished reply + // - `ERROR_LIMIT` - the reply was cut off because the model reached the maximum number of tokens for its context length + // - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens specified by the max_tokens parameter + // - `ERROR` - something went wrong when generating the reply + // - `ERROR_TOXIC` - the model generated a reply that was deemed toxic + FinishReason ChatStreamEndEventFinishReason `json:"finish_reason" url:"finish_reason"` + // The consolidated response from the model. Contains the generated reply and all the other information streamed back in the previous events. + Response *NonStreamedChatResponse `json:"response,omitempty" url:"response,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatDebugEvent) GetExtraProperties() map[string]interface{} { +func (c *ChatStreamEndEvent) GetFinishReason() ChatStreamEndEventFinishReason { + if c == nil { + return "" + } + return c.FinishReason +} + +func (c *ChatStreamEndEvent) GetResponse() *NonStreamedChatResponse { + if c == nil { + return nil + } + return c.Response +} + +func (c *ChatStreamEndEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatDebugEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatDebugEvent +func (c *ChatStreamEndEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamEndEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatDebugEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatStreamEndEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatDebugEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatStreamEndEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// Relevant information that could be used by the model to generate a more accurate reply. -// The contents of each document are generally short (under 300 words), and are passed in the form of a -// dictionary of strings. Some suggested keys are "text", "author", "date". Both the key name and the value will be -// passed to the model. -type ChatDocument = map[string]string - -// The reason a chat request has finished. -// -// - **complete**: The model finished sending a complete message. -// - **max_tokens**: The number of generated tokens exceeded the model's context length or the value specified via the `max_tokens` parameter. -// - **stop_sequence**: One of the provided `stop_sequence` entries was reached in the model's generation. -// - **tool_call**: The model generated a Tool Call and is expecting a Tool Message in return -// - **error**: The generation failed due to an internal error -type ChatFinishReason string +// - `COMPLETE` - the model sent back a finished reply +// - `ERROR_LIMIT` - the reply was cut off because the model reached the maximum number of tokens for its context length +// - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens specified by the max_tokens parameter +// - `ERROR` - something went wrong when generating the reply +// - `ERROR_TOXIC` - the model generated a reply that was deemed toxic +type ChatStreamEndEventFinishReason string const ( - ChatFinishReasonComplete ChatFinishReason = "COMPLETE" - ChatFinishReasonStopSequence ChatFinishReason = "STOP_SEQUENCE" - ChatFinishReasonMaxTokens ChatFinishReason = "MAX_TOKENS" - ChatFinishReasonToolCall ChatFinishReason = "TOOL_CALL" - ChatFinishReasonError ChatFinishReason = "ERROR" + ChatStreamEndEventFinishReasonComplete ChatStreamEndEventFinishReason = "COMPLETE" + ChatStreamEndEventFinishReasonErrorLimit ChatStreamEndEventFinishReason = "ERROR_LIMIT" + ChatStreamEndEventFinishReasonMaxTokens ChatStreamEndEventFinishReason = "MAX_TOKENS" + ChatStreamEndEventFinishReasonError ChatStreamEndEventFinishReason = "ERROR" + ChatStreamEndEventFinishReasonErrorToxic ChatStreamEndEventFinishReason = "ERROR_TOXIC" ) -func NewChatFinishReasonFromString(s string) (ChatFinishReason, error) { +func NewChatStreamEndEventFinishReasonFromString(s string) (ChatStreamEndEventFinishReason, error) { switch s { case "COMPLETE": - return ChatFinishReasonComplete, nil - case "STOP_SEQUENCE": - return ChatFinishReasonStopSequence, nil + return ChatStreamEndEventFinishReasonComplete, nil + case "ERROR_LIMIT": + return ChatStreamEndEventFinishReasonErrorLimit, nil case "MAX_TOKENS": - return ChatFinishReasonMaxTokens, nil - case "TOOL_CALL": - return ChatFinishReasonToolCall, nil + return ChatStreamEndEventFinishReasonMaxTokens, nil case "ERROR": - return ChatFinishReasonError, nil + return ChatStreamEndEventFinishReasonError, nil + case "ERROR_TOXIC": + return ChatStreamEndEventFinishReasonErrorToxic, nil } - var t ChatFinishReason + var t ChatStreamEndEventFinishReason return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c ChatFinishReason) Ptr() *ChatFinishReason { +func (c ChatStreamEndEventFinishReason) Ptr() *ChatStreamEndEventFinishReason { return &c } -// Represents a single message in the chat history, excluding the current user turn. It has two properties: `role` and `message`. The `role` identifies the sender (`CHATBOT`, `SYSTEM`, or `USER`), while the `message` contains the text content. -// -// The chat_history parameter should not be used for `SYSTEM` messages in most cases. Instead, to add a `SYSTEM` role message at the beginning of a conversation, the `preamble` parameter should be used. -type ChatMessage struct { - // Contents of the chat message. - Message string `json:"message" url:"message"` - ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` - +type ChatStreamEvent struct { extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatMessage) GetExtraProperties() map[string]interface{} { +func (c *ChatStreamEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ChatMessage +func (c *ChatStreamEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatStreamEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatStreamEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// A streamed event which signifies that the chat message has ended. -type ChatMessageEndEvent struct { - Id *string `json:"id,omitempty" url:"id,omitempty"` - Delta *ChatMessageEndEventDelta `json:"delta,omitempty" url:"delta,omitempty"` +// Defaults to `"accurate"`. +// +// Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. +// +// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments +type ChatStreamRequestCitationQuality string + +const ( + ChatStreamRequestCitationQualityFast ChatStreamRequestCitationQuality = "fast" + ChatStreamRequestCitationQualityAccurate ChatStreamRequestCitationQuality = "accurate" + ChatStreamRequestCitationQualityOff ChatStreamRequestCitationQuality = "off" +) + +func NewChatStreamRequestCitationQualityFromString(s string) (ChatStreamRequestCitationQuality, error) { + switch s { + case "fast": + return ChatStreamRequestCitationQualityFast, nil + case "accurate": + return ChatStreamRequestCitationQualityAccurate, nil + case "off": + return ChatStreamRequestCitationQualityOff, nil + } + var t ChatStreamRequestCitationQuality + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (c ChatStreamRequestCitationQuality) Ptr() *ChatStreamRequestCitationQuality { + return &c +} + +// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of connectors_search_options are ignored (such as model= or temperature=). +type ChatStreamRequestConnectorsSearchOptions struct { + // If specified, the backend will make a best effort to sample tokens + // deterministically, such that repeated requests with the same + // seed and parameters should return the same result. However, + // determinism cannot be totally guaranteed. + // + // Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments + Seed *int `json:"seed,omitempty" url:"seed,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatStreamRequestConnectorsSearchOptions) GetSeed() *int { + if c == nil { + return nil + } + return c.Seed } -func (c *ChatMessageEndEvent) GetExtraProperties() map[string]interface{} { +func (c *ChatStreamRequestConnectorsSearchOptions) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatMessageEndEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatMessageEndEvent +func (c *ChatStreamRequestConnectorsSearchOptions) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamRequestConnectorsSearchOptions var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatMessageEndEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatStreamRequestConnectorsSearchOptions(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatMessageEndEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatStreamRequestConnectorsSearchOptions) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatMessageEndEventDelta struct { - FinishReason *ChatFinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` - Usage *Usage `json:"usage,omitempty" url:"usage,omitempty"` +// Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. +// +// Dictates how the prompt will be constructed. +// +// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. +// +// With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. +// +// With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. +// +// Compatible Deployments: +// - AUTO: Cohere Platform Only +// - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments +type ChatStreamRequestPromptTruncation string + +const ( + ChatStreamRequestPromptTruncationOff ChatStreamRequestPromptTruncation = "OFF" + ChatStreamRequestPromptTruncationAuto ChatStreamRequestPromptTruncation = "AUTO" + ChatStreamRequestPromptTruncationAutoPreserveOrder ChatStreamRequestPromptTruncation = "AUTO_PRESERVE_ORDER" +) - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func NewChatStreamRequestPromptTruncationFromString(s string) (ChatStreamRequestPromptTruncation, error) { + switch s { + case "OFF": + return ChatStreamRequestPromptTruncationOff, nil + case "AUTO": + return ChatStreamRequestPromptTruncationAuto, nil + case "AUTO_PRESERVE_ORDER": + return ChatStreamRequestPromptTruncationAutoPreserveOrder, nil + } + var t ChatStreamRequestPromptTruncation + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatMessageEndEventDelta) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (c ChatStreamRequestPromptTruncation) Ptr() *ChatStreamRequestPromptTruncation { + return &c } -func (c *ChatMessageEndEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ChatMessageEndEventDelta - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatMessageEndEventDelta(value) +// Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. +// When `NONE` is specified, the safety instruction will be omitted. +// +// Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. +// +// **Note**: This parameter is only compatible with models [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release), [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release) and newer. +// +// **Note**: `command-r7b-12-2024` only supports `"CONTEXTUAL"` and `"STRICT"` modes. +// +// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments +type ChatStreamRequestSafetyMode string - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties +const ( + ChatStreamRequestSafetyModeContextual ChatStreamRequestSafetyMode = "CONTEXTUAL" + ChatStreamRequestSafetyModeStrict ChatStreamRequestSafetyMode = "STRICT" + ChatStreamRequestSafetyModeNone ChatStreamRequestSafetyMode = "NONE" +) - c._rawJSON = json.RawMessage(data) - return nil +func NewChatStreamRequestSafetyModeFromString(s string) (ChatStreamRequestSafetyMode, error) { + switch s { + case "CONTEXTUAL": + return ChatStreamRequestSafetyModeContextual, nil + case "STRICT": + return ChatStreamRequestSafetyModeStrict, nil + case "NONE": + return ChatStreamRequestSafetyModeNone, nil + } + var t ChatStreamRequestSafetyMode + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatMessageEndEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) +func (c ChatStreamRequestSafetyMode) Ptr() *ChatStreamRequestSafetyMode { + return &c } -// A streamed event which signifies that a stream has started. -type ChatMessageStartEvent struct { - // Unique identifier for the generated reply. - Id *string `json:"id,omitempty" url:"id,omitempty"` - Delta *ChatMessageStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` +type ChatStreamStartEvent struct { + // Unique identifier for the generated reply. Useful for submitting feedback. + GenerationId string `json:"generation_id" url:"generation_id"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (c *ChatStreamStartEvent) GetGenerationId() string { + if c == nil { + return "" + } + return c.GenerationId } -func (c *ChatMessageStartEvent) GetExtraProperties() map[string]interface{} { +func (c *ChatStreamStartEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatMessageStartEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatMessageStartEvent +func (c *ChatStreamStartEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamStartEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatMessageStartEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatStreamStartEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatMessageStartEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatStreamStartEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatMessageStartEventDelta struct { - Message *ChatMessageStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` +type ChatTextGenerationEvent struct { + // The next batch of text generated by the model. + Text string `json:"text" url:"text"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatMessageStartEventDelta) GetExtraProperties() map[string]interface{} { +func (c *ChatTextGenerationEvent) GetText() string { + if c == nil { + return "" + } + return c.Text +} + +func (c *ChatTextGenerationEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatMessageStartEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ChatMessageStartEventDelta +func (c *ChatTextGenerationEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatTextGenerationEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatMessageStartEventDelta(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatTextGenerationEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatMessageStartEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatTextGenerationEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatMessageStartEventDeltaMessage struct { - // The role of the message. - Role *string `json:"role,omitempty" url:"role,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +type ChatToolCallsChunkEvent struct { + ToolCallDelta *ToolCallDelta `json:"tool_call_delta,omitempty" url:"tool_call_delta,omitempty"` + Text *string `json:"text,omitempty" url:"text,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallsChunkEvent) GetToolCallDelta() *ToolCallDelta { + if c == nil { + return nil + } + return c.ToolCallDelta +} + +func (c *ChatToolCallsChunkEvent) GetText() *string { + if c == nil { + return nil + } + return c.Text } -func (c *ChatMessageStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { +func (c *ChatToolCallsChunkEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatMessageStartEventDeltaMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ChatMessageStartEventDeltaMessage +func (c *ChatToolCallsChunkEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallsChunkEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatMessageStartEventDeltaMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatToolCallsChunkEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatMessageStartEventDeltaMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatToolCallsChunkEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// Represents a single message in the chat history from a given role. -type ChatMessageV2 struct { - Role string - User *UserMessage - Assistant *AssistantMessage - System *SystemMessage - Tool *ToolMessageV2 -} - -func (c *ChatMessageV2) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Role string `json:"role"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - c.Role = unmarshaler.Role - if unmarshaler.Role == "" { - return fmt.Errorf("%T did not include discriminant role", c) - } - switch unmarshaler.Role { - case "user": - value := new(UserMessage) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - c.User = value - case "assistant": - value := new(AssistantMessage) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - c.Assistant = value - case "system": - value := new(SystemMessage) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - c.System = value - case "tool": - value := new(ToolMessageV2) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - c.Tool = value - } - return nil -} - -func (c ChatMessageV2) MarshalJSON() ([]byte, error) { - if c.User != nil { - return core.MarshalJSONWithExtraProperty(c.User, "role", "user") - } - if c.Assistant != nil { - return core.MarshalJSONWithExtraProperty(c.Assistant, "role", "assistant") - } - if c.System != nil { - return core.MarshalJSONWithExtraProperty(c.System, "role", "system") - } - if c.Tool != nil { - return core.MarshalJSONWithExtraProperty(c.Tool, "role", "tool") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", c) -} +type ChatToolCallsGenerationEvent struct { + // The text generated related to the tool calls generated + Text *string `json:"text,omitempty" url:"text,omitempty"` + ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` -type ChatMessageV2Visitor interface { - VisitUser(*UserMessage) error - VisitAssistant(*AssistantMessage) error - VisitSystem(*SystemMessage) error - VisitTool(*ToolMessageV2) error + extraProperties map[string]interface{} + rawJSON json.RawMessage } -func (c *ChatMessageV2) Accept(visitor ChatMessageV2Visitor) error { - if c.User != nil { - return visitor.VisitUser(c.User) - } - if c.Assistant != nil { - return visitor.VisitAssistant(c.Assistant) - } - if c.System != nil { - return visitor.VisitSystem(c.System) - } - if c.Tool != nil { - return visitor.VisitTool(c.Tool) +func (c *ChatToolCallsGenerationEvent) GetText() *string { + if c == nil { + return nil } - return fmt.Errorf("type %T does not define a non-empty union type", c) + return c.Text } -// A list of chat messages in chronological order, representing a conversation between the user and the model. -// -// Messages can be from `User`, `Assistant`, `Tool` and `System` roles. Learn more about messages and roles in [the Chat API guide](https://docs.cohere.com/v2/docs/chat-api). -type ChatMessages = []*ChatMessageV2 - -// Defaults to `"accurate"`. -// -// Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. -// -// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments -type ChatRequestCitationQuality string - -const ( - ChatRequestCitationQualityFast ChatRequestCitationQuality = "fast" - ChatRequestCitationQualityAccurate ChatRequestCitationQuality = "accurate" - ChatRequestCitationQualityOff ChatRequestCitationQuality = "off" -) - -func NewChatRequestCitationQualityFromString(s string) (ChatRequestCitationQuality, error) { - switch s { - case "fast": - return ChatRequestCitationQualityFast, nil - case "accurate": - return ChatRequestCitationQualityAccurate, nil - case "off": - return ChatRequestCitationQualityOff, nil +func (c *ChatToolCallsGenerationEvent) GetToolCalls() []*ToolCall { + if c == nil { + return nil } - var t ChatRequestCitationQuality - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c ChatRequestCitationQuality) Ptr() *ChatRequestCitationQuality { - return &c -} - -// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of connectors_search_options are ignored (such as model= or temperature=). -type ChatRequestConnectorsSearchOptions struct { - // If specified, the backend will make a best effort to sample tokens - // deterministically, such that repeated requests with the same - // seed and parameters should return the same result. However, - // determinism cannot be totally guaranteed. - // - // Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments - Seed *int `json:"seed,omitempty" url:"seed,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + return c.ToolCalls } -func (c *ChatRequestConnectorsSearchOptions) GetExtraProperties() map[string]interface{} { +func (c *ChatToolCallsGenerationEvent) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatRequestConnectorsSearchOptions) UnmarshalJSON(data []byte) error { - type unmarshaler ChatRequestConnectorsSearchOptions +func (c *ChatToolCallsGenerationEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallsGenerationEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatRequestConnectorsSearchOptions(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ChatToolCallsGenerationEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatRequestConnectorsSearchOptions) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ChatToolCallsGenerationEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. -// -// Dictates how the prompt will be constructed. -// -// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. -// -// With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. -// -// With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. -// -// Compatible Deployments: -// -// - AUTO: Cohere Platform Only -// - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments -type ChatRequestPromptTruncation string - -const ( - ChatRequestPromptTruncationOff ChatRequestPromptTruncation = "OFF" - ChatRequestPromptTruncationAuto ChatRequestPromptTruncation = "AUTO" - ChatRequestPromptTruncationAutoPreserveOrder ChatRequestPromptTruncation = "AUTO_PRESERVE_ORDER" -) - -func NewChatRequestPromptTruncationFromString(s string) (ChatRequestPromptTruncation, error) { - switch s { - case "OFF": - return ChatRequestPromptTruncationOff, nil - case "AUTO": - return ChatRequestPromptTruncationAuto, nil - case "AUTO_PRESERVE_ORDER": - return ChatRequestPromptTruncationAutoPreserveOrder, nil - } - var t ChatRequestPromptTruncation - return "", fmt.Errorf("%s is not a valid %T", s, t) -} +type CheckApiKeyResponse struct { + Valid bool `json:"valid" url:"valid"` + OrganizationId *string `json:"organization_id,omitempty" url:"organization_id,omitempty"` + OwnerId *string `json:"owner_id,omitempty" url:"owner_id,omitempty"` -func (c ChatRequestPromptTruncation) Ptr() *ChatRequestPromptTruncation { - return &c + extraProperties map[string]interface{} + rawJSON json.RawMessage } -// Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. -// When `NONE` is specified, the safety instruction will be omitted. -// -// Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. -// -// **Note**: This parameter is only compatible with models [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release), [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release) and newer. -// -// **Note**: `command-r7b-12-2024` only supports `"CONTEXTUAL"` and `"STRICT"` modes. -// -// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments -type ChatRequestSafetyMode string - -const ( - ChatRequestSafetyModeContextual ChatRequestSafetyMode = "CONTEXTUAL" - ChatRequestSafetyModeStrict ChatRequestSafetyMode = "STRICT" - ChatRequestSafetyModeNone ChatRequestSafetyMode = "NONE" -) - -func NewChatRequestSafetyModeFromString(s string) (ChatRequestSafetyMode, error) { - switch s { - case "CONTEXTUAL": - return ChatRequestSafetyModeContextual, nil - case "STRICT": - return ChatRequestSafetyModeStrict, nil - case "NONE": - return ChatRequestSafetyModeNone, nil +func (c *CheckApiKeyResponse) GetValid() bool { + if c == nil { + return false } - var t ChatRequestSafetyMode - return "", fmt.Errorf("%s is not a valid %T", s, t) + return c.Valid } -func (c ChatRequestSafetyMode) Ptr() *ChatRequestSafetyMode { - return &c +func (c *CheckApiKeyResponse) GetOrganizationId() *string { + if c == nil { + return nil + } + return c.OrganizationId } -type ChatResponse struct { - // Unique identifier for the generated reply. Useful for submitting feedback. - Id string `json:"id" url:"id"` - FinishReason ChatFinishReason `json:"finish_reason" url:"finish_reason"` - // The prompt that was used. Only present when `return_prompt` in the request is set to true. - Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` - Message *AssistantMessageResponse `json:"message,omitempty" url:"message,omitempty"` - Usage *Usage `json:"usage,omitempty" url:"usage,omitempty"` - Logprobs []*LogprobItem `json:"logprobs,omitempty" url:"logprobs,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (c *CheckApiKeyResponse) GetOwnerId() *string { + if c == nil { + return nil + } + return c.OwnerId } -func (c *ChatResponse) GetExtraProperties() map[string]interface{} { +func (c *CheckApiKeyResponse) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatResponse) UnmarshalJSON(data []byte) error { - type unmarshaler ChatResponse +func (c *CheckApiKeyResponse) UnmarshalJSON(data []byte) error { + type unmarshaler CheckApiKeyResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = CheckApiKeyResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatResponse) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *CheckApiKeyResponse) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatSearchQueriesGenerationEvent struct { - // Generated search queries, meant to be used as part of the RAG flow. - SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty" url:"search_queries,omitempty"` +type ClassifyExample struct { + Text *string `json:"text,omitempty" url:"text,omitempty"` + Label *string `json:"label,omitempty" url:"label,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ChatSearchQueriesGenerationEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties + rawJSON json.RawMessage } -func (c *ChatSearchQueriesGenerationEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatSearchQueriesGenerationEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatSearchQueriesGenerationEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (c *ClassifyExample) GetText() *string { + if c == nil { + return nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return c.Text } -func (c *ChatSearchQueriesGenerationEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (c *ClassifyExample) GetLabel() *string { + if c == nil { + return nil } - return fmt.Sprintf("%#v", c) -} - -// The generated search query. Contains the text of the query and a unique identifier for the query. -type ChatSearchQuery struct { - // The text of the search query. - Text string `json:"text" url:"text"` - // Unique identifier for the generated search query. Useful for submitting feedback. - GenerationId string `json:"generation_id" url:"generation_id"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + return c.Label } -func (c *ChatSearchQuery) GetExtraProperties() map[string]interface{} { +func (c *ClassifyExample) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatSearchQuery) UnmarshalJSON(data []byte) error { - type unmarshaler ChatSearchQuery +func (c *ClassifyExample) UnmarshalJSON(data []byte) error { + type unmarshaler ClassifyExample var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatSearchQuery(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ClassifyExample(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatSearchQuery) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ClassifyExample) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatSearchResult struct { - SearchQuery *ChatSearchQuery `json:"search_query,omitempty" url:"search_query,omitempty"` - // The connector from which this result comes from. - Connector *ChatSearchResultConnector `json:"connector,omitempty" url:"connector,omitempty"` - // Identifiers of documents found by this search query. - DocumentIds []string `json:"document_ids,omitempty" url:"document_ids,omitempty"` - // An error message if the search failed. - ErrorMessage *string `json:"error_message,omitempty" url:"error_message,omitempty"` - // Whether a chat request should continue or not if the request to this connector fails. - ContinueOnFailure *bool `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"` +// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. +// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. +// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. +type ClassifyRequestTruncate string - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} +const ( + ClassifyRequestTruncateNone ClassifyRequestTruncate = "NONE" + ClassifyRequestTruncateStart ClassifyRequestTruncate = "START" + ClassifyRequestTruncateEnd ClassifyRequestTruncate = "END" +) -func (c *ChatSearchResult) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func NewClassifyRequestTruncateFromString(s string) (ClassifyRequestTruncate, error) { + switch s { + case "NONE": + return ClassifyRequestTruncateNone, nil + case "START": + return ClassifyRequestTruncateStart, nil + case "END": + return ClassifyRequestTruncateEnd, nil + } + var t ClassifyRequestTruncate + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatSearchResult) UnmarshalJSON(data []byte) error { - type unmarshaler ChatSearchResult - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatSearchResult(value) +func (c ClassifyRequestTruncate) Ptr() *ClassifyRequestTruncate { + return &c +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties +type ClassifyResponse struct { + Id string `json:"id" url:"id"` + Classifications []*ClassifyResponseClassificationsItem `json:"classifications,omitempty" url:"classifications,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - c._rawJSON = json.RawMessage(data) - return nil + extraProperties map[string]interface{} + rawJSON json.RawMessage } -func (c *ChatSearchResult) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (c *ClassifyResponse) GetId() string { + if c == nil { + return "" } - return fmt.Sprintf("%#v", c) + return c.Id } -// The connector used for fetching documents. -type ChatSearchResultConnector struct { - // The identifier of the connector. - Id string `json:"id" url:"id"` +func (c *ClassifyResponse) GetClassifications() []*ClassifyResponseClassificationsItem { + if c == nil { + return nil + } + return c.Classifications +} - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (c *ClassifyResponse) GetMeta() *ApiMeta { + if c == nil { + return nil + } + return c.Meta } -func (c *ChatSearchResultConnector) GetExtraProperties() map[string]interface{} { +func (c *ClassifyResponse) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatSearchResultConnector) UnmarshalJSON(data []byte) error { - type unmarshaler ChatSearchResultConnector +func (c *ClassifyResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ClassifyResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatSearchResultConnector(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ClassifyResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatSearchResultConnector) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ClassifyResponse) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -type ChatSearchResultsEvent struct { - // Conducted searches and the ids of documents retrieved from each of them. - SearchResults []*ChatSearchResult `json:"search_results,omitempty" url:"search_results,omitempty"` - // Documents fetched from searches or provided by the user. - Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` +type ClassifyResponseClassificationsItem struct { + Id string `json:"id" url:"id"` + // The input text that was classified + Input *string `json:"input,omitempty" url:"input,omitempty"` + // The predicted label for the associated query (only filled for single-label models) + Prediction *string `json:"prediction,omitempty" url:"prediction,omitempty"` + // An array containing the predicted labels for the associated query (only filled for single-label classification) + Predictions []string `json:"predictions,omitempty" url:"predictions,omitempty"` + // The confidence score for the top predicted class (only filled for single-label classification) + Confidence *float64 `json:"confidence,omitempty" url:"confidence,omitempty"` + // An array containing the confidence scores of all the predictions in the same order + Confidences []float64 `json:"confidences,omitempty" url:"confidences,omitempty"` + // A map containing each label and its confidence score according to the classifier. All the confidence scores add up to 1 for single-label classification. For multi-label classification the label confidences are independent of each other, so they don't have to sum up to 1. + Labels map[string]*ClassifyResponseClassificationsItemLabelsValue `json:"labels,omitempty" url:"labels,omitempty"` + // The type of classification performed + ClassificationType ClassifyResponseClassificationsItemClassificationType `json:"classification_type" url:"classification_type"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatSearchResultsEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (c *ClassifyResponseClassificationsItem) GetId() string { + if c == nil { + return "" + } + return c.Id } -func (c *ChatSearchResultsEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatSearchResultsEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (c *ClassifyResponseClassificationsItem) GetInput() *string { + if c == nil { + return nil } - *c = ChatSearchResultsEvent(value) + return c.Input +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (c *ClassifyResponseClassificationsItem) GetPrediction() *string { + if c == nil { + return nil } - c.extraProperties = extraProperties + return c.Prediction +} - c._rawJSON = json.RawMessage(data) - return nil +func (c *ClassifyResponseClassificationsItem) GetPredictions() []string { + if c == nil { + return nil + } + return c.Predictions } -func (c *ChatSearchResultsEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } +func (c *ClassifyResponseClassificationsItem) GetConfidence() *float64 { + if c == nil { + return nil } - if value, err := core.StringifyJSON(c); err == nil { - return value + return c.Confidence +} + +func (c *ClassifyResponseClassificationsItem) GetConfidences() []float64 { + if c == nil { + return nil } - return fmt.Sprintf("%#v", c) + return c.Confidences } -type ChatStreamEndEvent struct { - // - `COMPLETE` - the model sent back a finished reply - // - `ERROR_LIMIT` - the reply was cut off because the model reached the maximum number of tokens for its context length - // - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens specified by the max_tokens parameter - // - `ERROR` - something went wrong when generating the reply - // - `ERROR_TOXIC` - the model generated a reply that was deemed toxic - FinishReason ChatStreamEndEventFinishReason `json:"finish_reason" url:"finish_reason"` - // The consolidated response from the model. Contains the generated reply and all the other information streamed back in the previous events. - Response *NonStreamedChatResponse `json:"response,omitempty" url:"response,omitempty"` +func (c *ClassifyResponseClassificationsItem) GetLabels() map[string]*ClassifyResponseClassificationsItemLabelsValue { + if c == nil { + return nil + } + return c.Labels +} - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (c *ClassifyResponseClassificationsItem) GetClassificationType() ClassifyResponseClassificationsItemClassificationType { + if c == nil { + return "" + } + return c.ClassificationType } -func (c *ChatStreamEndEvent) GetExtraProperties() map[string]interface{} { +func (c *ClassifyResponseClassificationsItem) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatStreamEndEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatStreamEndEvent +func (c *ClassifyResponseClassificationsItem) UnmarshalJSON(data []byte) error { + type unmarshaler ClassifyResponseClassificationsItem var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatStreamEndEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ClassifyResponseClassificationsItem(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatStreamEndEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ClassifyResponseClassificationsItem) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// - `COMPLETE` - the model sent back a finished reply -// - `ERROR_LIMIT` - the reply was cut off because the model reached the maximum number of tokens for its context length -// - `MAX_TOKENS` - the reply was cut off because the model reached the maximum number of tokens specified by the max_tokens parameter -// - `ERROR` - something went wrong when generating the reply -// - `ERROR_TOXIC` - the model generated a reply that was deemed toxic -type ChatStreamEndEventFinishReason string +// The type of classification performed +type ClassifyResponseClassificationsItemClassificationType string const ( - ChatStreamEndEventFinishReasonComplete ChatStreamEndEventFinishReason = "COMPLETE" - ChatStreamEndEventFinishReasonErrorLimit ChatStreamEndEventFinishReason = "ERROR_LIMIT" - ChatStreamEndEventFinishReasonMaxTokens ChatStreamEndEventFinishReason = "MAX_TOKENS" - ChatStreamEndEventFinishReasonError ChatStreamEndEventFinishReason = "ERROR" - ChatStreamEndEventFinishReasonErrorToxic ChatStreamEndEventFinishReason = "ERROR_TOXIC" + ClassifyResponseClassificationsItemClassificationTypeSingleLabel ClassifyResponseClassificationsItemClassificationType = "single-label" + ClassifyResponseClassificationsItemClassificationTypeMultiLabel ClassifyResponseClassificationsItemClassificationType = "multi-label" ) -func NewChatStreamEndEventFinishReasonFromString(s string) (ChatStreamEndEventFinishReason, error) { +func NewClassifyResponseClassificationsItemClassificationTypeFromString(s string) (ClassifyResponseClassificationsItemClassificationType, error) { switch s { - case "COMPLETE": - return ChatStreamEndEventFinishReasonComplete, nil - case "ERROR_LIMIT": - return ChatStreamEndEventFinishReasonErrorLimit, nil - case "MAX_TOKENS": - return ChatStreamEndEventFinishReasonMaxTokens, nil - case "ERROR": - return ChatStreamEndEventFinishReasonError, nil - case "ERROR_TOXIC": - return ChatStreamEndEventFinishReasonErrorToxic, nil + case "single-label": + return ClassifyResponseClassificationsItemClassificationTypeSingleLabel, nil + case "multi-label": + return ClassifyResponseClassificationsItemClassificationTypeMultiLabel, nil } - var t ChatStreamEndEventFinishReason + var t ClassifyResponseClassificationsItemClassificationType return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c ChatStreamEndEventFinishReason) Ptr() *ChatStreamEndEventFinishReason { +func (c ClassifyResponseClassificationsItemClassificationType) Ptr() *ClassifyResponseClassificationsItemClassificationType { return &c } -type ChatStreamEvent struct { +type ClassifyResponseClassificationsItemLabelsValue struct { + Confidence *float64 `json:"confidence,omitempty" url:"confidence,omitempty"` + extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatStreamEvent) GetExtraProperties() map[string]interface{} { +func (c *ClassifyResponseClassificationsItemLabelsValue) GetConfidence() *float64 { + if c == nil { + return nil + } + return c.Confidence +} + +func (c *ClassifyResponseClassificationsItemLabelsValue) GetExtraProperties() map[string]interface{} { return c.extraProperties } -func (c *ChatStreamEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatStreamEvent +func (c *ClassifyResponseClassificationsItemLabelsValue) UnmarshalJSON(data []byte) error { + type unmarshaler ClassifyResponseClassificationsItemLabelsValue var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatStreamEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *c = ClassifyResponseClassificationsItemLabelsValue(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) if err != nil { return err } c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + c.rawJSON = json.RawMessage(data) return nil } -func (c *ChatStreamEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (c *ClassifyResponseClassificationsItemLabelsValue) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(c); err == nil { return value } return fmt.Sprintf("%#v", c) } -// The streamed event types -type ChatStreamEventType struct { +type DetokenizeResponse struct { + // A string representing the list of tokens. + Text string `json:"text" url:"text"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` + extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatStreamEventType) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (d *DetokenizeResponse) GetText() string { + if d == nil { + return "" + } + return d.Text +} + +func (d *DetokenizeResponse) GetMeta() *ApiMeta { + if d == nil { + return nil + } + return d.Meta +} + +func (d *DetokenizeResponse) GetExtraProperties() map[string]interface{} { + return d.extraProperties } -func (c *ChatStreamEventType) UnmarshalJSON(data []byte) error { - type unmarshaler ChatStreamEventType +func (d *DetokenizeResponse) UnmarshalJSON(data []byte) error { + type unmarshaler DetokenizeResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatStreamEventType(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *d = DetokenizeResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *d) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + d.extraProperties = extraProperties + d.rawJSON = json.RawMessage(data) return nil } -func (c *ChatStreamEventType) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (d *DetokenizeResponse) String() string { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(d); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", d) } -// Defaults to `"accurate"`. -// -// Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. -// -// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments -type ChatStreamRequestCitationQuality string +type EmbedByTypeResponse struct { + Id string `json:"id" url:"id"` + // An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array. + Embeddings *EmbedByTypeResponseEmbeddings `json:"embeddings,omitempty" url:"embeddings,omitempty"` + // The text entries for which embeddings were returned. + Texts []string `json:"texts,omitempty" url:"texts,omitempty"` + // The image entries for which embeddings were returned. + Images []*Image `json:"images,omitempty" url:"images,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` -const ( - ChatStreamRequestCitationQualityFast ChatStreamRequestCitationQuality = "fast" - ChatStreamRequestCitationQualityAccurate ChatStreamRequestCitationQuality = "accurate" - ChatStreamRequestCitationQualityOff ChatStreamRequestCitationQuality = "off" -) + extraProperties map[string]interface{} + rawJSON json.RawMessage +} -func NewChatStreamRequestCitationQualityFromString(s string) (ChatStreamRequestCitationQuality, error) { - switch s { - case "fast": - return ChatStreamRequestCitationQualityFast, nil - case "accurate": - return ChatStreamRequestCitationQualityAccurate, nil - case "off": - return ChatStreamRequestCitationQualityOff, nil +func (e *EmbedByTypeResponse) GetId() string { + if e == nil { + return "" } - var t ChatStreamRequestCitationQuality - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c ChatStreamRequestCitationQuality) Ptr() *ChatStreamRequestCitationQuality { - return &c -} - -// (internal) Sets inference and model options for RAG search query and tool use generations. Defaults are used when options are not specified here, meaning that other parameters outside of connectors_search_options are ignored (such as model= or temperature=). -type ChatStreamRequestConnectorsSearchOptions struct { - // If specified, the backend will make a best effort to sample tokens - // deterministically, such that repeated requests with the same - // seed and parameters should return the same result. However, - // determinism cannot be totally guaranteed. - // - // Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments - Seed *int `json:"seed,omitempty" url:"seed,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ChatStreamRequestConnectorsSearchOptions) GetExtraProperties() map[string]interface{} { - return c.extraProperties + return e.Id } -func (c *ChatStreamRequestConnectorsSearchOptions) UnmarshalJSON(data []byte) error { - type unmarshaler ChatStreamRequestConnectorsSearchOptions - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatStreamRequestConnectorsSearchOptions(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (e *EmbedByTypeResponse) GetEmbeddings() *EmbedByTypeResponseEmbeddings { + if e == nil { + return nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return e.Embeddings } -func (c *ChatStreamRequestConnectorsSearchOptions) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (e *EmbedByTypeResponse) GetTexts() []string { + if e == nil { + return nil } - return fmt.Sprintf("%#v", c) + return e.Texts } -// Defaults to `AUTO` when `connectors` are specified and `OFF` in all other cases. -// -// Dictates how the prompt will be constructed. -// -// With `prompt_truncation` set to "AUTO", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be changed and ranked by relevance. -// -// With `prompt_truncation` set to "AUTO_PRESERVE_ORDER", some elements from `chat_history` and `documents` will be dropped in an attempt to construct a prompt that fits within the model's context length limit. During this process the order of the documents and chat history will be preserved as they are inputted into the API. -// -// With `prompt_truncation` set to "OFF", no elements will be dropped. If the sum of the inputs exceeds the model's context length limit, a `TooManyTokens` error will be returned. -// -// Compatible Deployments: -// -// - AUTO: Cohere Platform Only -// - AUTO_PRESERVE_ORDER: Azure, AWS Sagemaker/Bedrock, Private Deployments -type ChatStreamRequestPromptTruncation string - -const ( - ChatStreamRequestPromptTruncationOff ChatStreamRequestPromptTruncation = "OFF" - ChatStreamRequestPromptTruncationAuto ChatStreamRequestPromptTruncation = "AUTO" - ChatStreamRequestPromptTruncationAutoPreserveOrder ChatStreamRequestPromptTruncation = "AUTO_PRESERVE_ORDER" -) - -func NewChatStreamRequestPromptTruncationFromString(s string) (ChatStreamRequestPromptTruncation, error) { - switch s { - case "OFF": - return ChatStreamRequestPromptTruncationOff, nil - case "AUTO": - return ChatStreamRequestPromptTruncationAuto, nil - case "AUTO_PRESERVE_ORDER": - return ChatStreamRequestPromptTruncationAutoPreserveOrder, nil +func (e *EmbedByTypeResponse) GetImages() []*Image { + if e == nil { + return nil } - var t ChatStreamRequestPromptTruncation - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c ChatStreamRequestPromptTruncation) Ptr() *ChatStreamRequestPromptTruncation { - return &c + return e.Images } -// Used to select the [safety instruction](https://docs.cohere.com/docs/safety-modes) inserted into the prompt. Defaults to `CONTEXTUAL`. -// When `NONE` is specified, the safety instruction will be omitted. -// -// Safety modes are not yet configurable in combination with `tools`, `tool_results` and `documents` parameters. -// -// **Note**: This parameter is only compatible with models [Command R 08-2024](https://docs.cohere.com/docs/command-r#august-2024-release), [Command R+ 08-2024](https://docs.cohere.com/docs/command-r-plus#august-2024-release) and newer. -// -// **Note**: `command-r7b-12-2024` only supports `"CONTEXTUAL"` and `"STRICT"` modes. -// -// Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments -type ChatStreamRequestSafetyMode string - -const ( - ChatStreamRequestSafetyModeContextual ChatStreamRequestSafetyMode = "CONTEXTUAL" - ChatStreamRequestSafetyModeStrict ChatStreamRequestSafetyMode = "STRICT" - ChatStreamRequestSafetyModeNone ChatStreamRequestSafetyMode = "NONE" -) - -func NewChatStreamRequestSafetyModeFromString(s string) (ChatStreamRequestSafetyMode, error) { - switch s { - case "CONTEXTUAL": - return ChatStreamRequestSafetyModeContextual, nil - case "STRICT": - return ChatStreamRequestSafetyModeStrict, nil - case "NONE": - return ChatStreamRequestSafetyModeNone, nil +func (e *EmbedByTypeResponse) GetMeta() *ApiMeta { + if e == nil { + return nil } - var t ChatStreamRequestSafetyMode - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c ChatStreamRequestSafetyMode) Ptr() *ChatStreamRequestSafetyMode { - return &c + return e.Meta } -type ChatStreamStartEvent struct { - // Unique identifier for the generated reply. Useful for submitting feedback. - GenerationId string `json:"generation_id" url:"generation_id"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ChatStreamStartEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (e *EmbedByTypeResponse) GetExtraProperties() map[string]interface{} { + return e.extraProperties } -func (c *ChatStreamStartEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatStreamStartEvent +func (e *EmbedByTypeResponse) UnmarshalJSON(data []byte) error { + type unmarshaler EmbedByTypeResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatStreamStartEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *e = EmbedByTypeResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *e) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + e.extraProperties = extraProperties + e.rawJSON = json.RawMessage(data) return nil } -func (c *ChatStreamStartEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (e *EmbedByTypeResponse) String() string { + if len(e.rawJSON) > 0 { + if value, err := internal.StringifyJSON(e.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(e); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", e) } -type ChatTextGenerationEvent struct { - // The next batch of text generated by the model. - Text string `json:"text" url:"text"` +// An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array. +type EmbedByTypeResponseEmbeddings struct { + // An array of float embeddings. + Float [][]float64 `json:"float,omitempty" url:"float,omitempty"` + // An array of signed int8 embeddings. Each value is between -128 and 127. + Int8 [][]int `json:"int8,omitempty" url:"int8,omitempty"` + // An array of unsigned int8 embeddings. Each value is between 0 and 255. + Uint8 [][]int `json:"uint8,omitempty" url:"uint8,omitempty"` + // An array of packed signed binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between -128 and 127. + Binary [][]int `json:"binary,omitempty" url:"binary,omitempty"` + // An array of packed unsigned binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between 0 and 255. + Ubinary [][]int `json:"ubinary,omitempty" url:"ubinary,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatTextGenerationEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (e *EmbedByTypeResponseEmbeddings) GetFloat() [][]float64 { + if e == nil { + return nil + } + return e.Float } -func (c *ChatTextGenerationEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatTextGenerationEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (e *EmbedByTypeResponseEmbeddings) GetInt8() [][]int { + if e == nil { + return nil } - *c = ChatTextGenerationEvent(value) + return e.Int8 +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (e *EmbedByTypeResponseEmbeddings) GetUint8() [][]int { + if e == nil { + return nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return e.Uint8 } -func (c *ChatTextGenerationEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (e *EmbedByTypeResponseEmbeddings) GetBinary() [][]int { + if e == nil { + return nil } - return fmt.Sprintf("%#v", c) + return e.Binary } -// A streamed event delta which signifies a delta in tool call arguments. -type ChatToolCallDeltaEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` - Delta *ChatToolCallDeltaEventDelta `json:"delta,omitempty" url:"delta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (e *EmbedByTypeResponseEmbeddings) GetUbinary() [][]int { + if e == nil { + return nil + } + return e.Ubinary } -func (c *ChatToolCallDeltaEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (e *EmbedByTypeResponseEmbeddings) GetExtraProperties() map[string]interface{} { + return e.extraProperties } -func (c *ChatToolCallDeltaEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallDeltaEvent +func (e *EmbedByTypeResponseEmbeddings) UnmarshalJSON(data []byte) error { + type unmarshaler EmbedByTypeResponseEmbeddings var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatToolCallDeltaEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *e = EmbedByTypeResponseEmbeddings(value) + extraProperties, err := internal.ExtractExtraProperties(data, *e) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + e.extraProperties = extraProperties + e.rawJSON = json.RawMessage(data) return nil } -func (c *ChatToolCallDeltaEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (e *EmbedByTypeResponseEmbeddings) String() string { + if len(e.rawJSON) > 0 { + if value, err := internal.StringifyJSON(e.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(e); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", e) } -type ChatToolCallDeltaEventDelta struct { - Message *ChatToolCallDeltaEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` +type EmbedFloatsResponse struct { + Id string `json:"id" url:"id"` + // An array of embeddings, where each embedding is an array of floats. The length of the `embeddings` array will be the same as the length of the original `texts` array. + Embeddings [][]float64 `json:"embeddings,omitempty" url:"embeddings,omitempty"` + // The text entries for which embeddings were returned. + Texts []string `json:"texts,omitempty" url:"texts,omitempty"` + // The image entries for which embeddings were returned. + Images []*Image `json:"images,omitempty" url:"images,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatToolCallDeltaEventDelta) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (e *EmbedFloatsResponse) GetId() string { + if e == nil { + return "" + } + return e.Id } -func (c *ChatToolCallDeltaEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallDeltaEventDelta - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (e *EmbedFloatsResponse) GetEmbeddings() [][]float64 { + if e == nil { + return nil } - *c = ChatToolCallDeltaEventDelta(value) + return e.Embeddings +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (e *EmbedFloatsResponse) GetTexts() []string { + if e == nil { + return nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return e.Texts } -func (c *ChatToolCallDeltaEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (e *EmbedFloatsResponse) GetImages() []*Image { + if e == nil { + return nil } - return fmt.Sprintf("%#v", c) + return e.Images } -type ChatToolCallDeltaEventDeltaMessage struct { - ToolCalls *ChatToolCallDeltaEventDeltaMessageToolCalls `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (e *EmbedFloatsResponse) GetMeta() *ApiMeta { + if e == nil { + return nil + } + return e.Meta } -func (c *ChatToolCallDeltaEventDeltaMessage) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (e *EmbedFloatsResponse) GetExtraProperties() map[string]interface{} { + return e.extraProperties } -func (c *ChatToolCallDeltaEventDeltaMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallDeltaEventDeltaMessage +func (e *EmbedFloatsResponse) UnmarshalJSON(data []byte) error { + type unmarshaler EmbedFloatsResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatToolCallDeltaEventDeltaMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *e = EmbedFloatsResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *e) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + e.extraProperties = extraProperties + e.rawJSON = json.RawMessage(data) return nil } -func (c *ChatToolCallDeltaEventDeltaMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (e *EmbedFloatsResponse) String() string { + if len(e.rawJSON) > 0 { + if value, err := internal.StringifyJSON(e.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(e); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", e) } -type ChatToolCallDeltaEventDeltaMessageToolCalls struct { - Function *ChatToolCallDeltaEventDeltaMessageToolCallsFunction `json:"function,omitempty" url:"function,omitempty"` +// Specifies the type of input passed to the model. Required for embedding models v3 and higher. +// +// - `"search_document"`: Used for embeddings stored in a vector database for search use-cases. +// - `"search_query"`: Used for embeddings of search queries run against a vector DB to find relevant documents. +// - `"classification"`: Used for embeddings passed through a text classifier. +// - `"clustering"`: Used for the embeddings run through a clustering algorithm. +// - `"image"`: Used for embeddings with image input. +type EmbedInputType string - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} +const ( + EmbedInputTypeSearchDocument EmbedInputType = "search_document" + EmbedInputTypeSearchQuery EmbedInputType = "search_query" + EmbedInputTypeClassification EmbedInputType = "classification" + EmbedInputTypeClustering EmbedInputType = "clustering" + EmbedInputTypeImage EmbedInputType = "image" +) -func (c *ChatToolCallDeltaEventDeltaMessageToolCalls) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func NewEmbedInputTypeFromString(s string) (EmbedInputType, error) { + switch s { + case "search_document": + return EmbedInputTypeSearchDocument, nil + case "search_query": + return EmbedInputTypeSearchQuery, nil + case "classification": + return EmbedInputTypeClassification, nil + case "clustering": + return EmbedInputTypeClustering, nil + case "image": + return EmbedInputTypeImage, nil + } + var t EmbedInputType + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatToolCallDeltaEventDeltaMessageToolCalls) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallDeltaEventDeltaMessageToolCalls - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatToolCallDeltaEventDeltaMessageToolCalls(value) +func (e EmbedInputType) Ptr() *EmbedInputType { + return &e +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties +// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. +// +// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. +// +// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. +type EmbedRequestTruncate string - c._rawJSON = json.RawMessage(data) - return nil -} +const ( + EmbedRequestTruncateNone EmbedRequestTruncate = "NONE" + EmbedRequestTruncateStart EmbedRequestTruncate = "START" + EmbedRequestTruncateEnd EmbedRequestTruncate = "END" +) -func (c *ChatToolCallDeltaEventDeltaMessageToolCalls) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func NewEmbedRequestTruncateFromString(s string) (EmbedRequestTruncate, error) { + switch s { + case "NONE": + return EmbedRequestTruncateNone, nil + case "START": + return EmbedRequestTruncateStart, nil + case "END": + return EmbedRequestTruncateEnd, nil } - return fmt.Sprintf("%#v", c) + var t EmbedRequestTruncate + return "", fmt.Errorf("%s is not a valid %T", s, t) } -type ChatToolCallDeltaEventDeltaMessageToolCallsFunction struct { - Arguments *string `json:"arguments,omitempty" url:"arguments,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (e EmbedRequestTruncate) Ptr() *EmbedRequestTruncate { + return &e } -func (c *ChatToolCallDeltaEventDeltaMessageToolCallsFunction) GetExtraProperties() map[string]interface{} { - return c.extraProperties +type EmbedResponse struct { + ResponseType string + EmbeddingsFloats *EmbedFloatsResponse + EmbeddingsByType *EmbedByTypeResponse } -func (c *ChatToolCallDeltaEventDeltaMessageToolCallsFunction) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallDeltaEventDeltaMessageToolCallsFunction - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatToolCallDeltaEventDeltaMessageToolCallsFunction(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (e *EmbedResponse) GetResponseType() string { + if e == nil { + return "" } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return e.ResponseType } -func (c *ChatToolCallDeltaEventDeltaMessageToolCallsFunction) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (e *EmbedResponse) GetEmbeddingsFloats() *EmbedFloatsResponse { + if e == nil { + return nil } - return fmt.Sprintf("%#v", c) + return e.EmbeddingsFloats } -// A streamed event delta which signifies a tool call has finished streaming. -type ChatToolCallEndEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ChatToolCallEndEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (e *EmbedResponse) GetEmbeddingsByType() *EmbedByTypeResponse { + if e == nil { + return nil + } + return e.EmbeddingsByType } -func (c *ChatToolCallEndEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallEndEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (e *EmbedResponse) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + ResponseType string `json:"response_type"` } - *c = ChatToolCallEndEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { + if err := json.Unmarshal(data, &unmarshaler); err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + e.ResponseType = unmarshaler.ResponseType + if unmarshaler.ResponseType == "" { + return fmt.Errorf("%T did not include discriminant response_type", e) + } + switch unmarshaler.ResponseType { + case "embeddings_floats": + value := new(EmbedFloatsResponse) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + e.EmbeddingsFloats = value + case "embeddings_by_type": + value := new(EmbedByTypeResponse) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + e.EmbeddingsByType = value + } return nil } -func (c *ChatToolCallEndEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } +func (e EmbedResponse) MarshalJSON() ([]byte, error) { + if err := e.validate(); err != nil { + return nil, err } - if value, err := core.StringifyJSON(c); err == nil { - return value + if e.EmbeddingsFloats != nil { + return internal.MarshalJSONWithExtraProperty(e.EmbeddingsFloats, "response_type", "embeddings_floats") } - return fmt.Sprintf("%#v", c) -} - -// A streamed event delta which signifies a tool call has started streaming. -type ChatToolCallStartEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` - Delta *ChatToolCallStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + if e.EmbeddingsByType != nil { + return internal.MarshalJSONWithExtraProperty(e.EmbeddingsByType, "response_type", "embeddings_by_type") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", e) } -func (c *ChatToolCallStartEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +type EmbedResponseVisitor interface { + VisitEmbeddingsFloats(*EmbedFloatsResponse) error + VisitEmbeddingsByType(*EmbedByTypeResponse) error } -func (c *ChatToolCallStartEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallStartEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (e *EmbedResponse) Accept(visitor EmbedResponseVisitor) error { + if e.EmbeddingsFloats != nil { + return visitor.VisitEmbeddingsFloats(e.EmbeddingsFloats) } - *c = ChatToolCallStartEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err + if e.EmbeddingsByType != nil { + return visitor.VisitEmbeddingsByType(e.EmbeddingsByType) } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return fmt.Errorf("type %T does not define a non-empty union type", e) } -func (c *ChatToolCallStartEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value +func (e *EmbedResponse) validate() error { + if e == nil { + return fmt.Errorf("type %T is nil", e) + } + var fields []string + if e.EmbeddingsFloats != nil { + fields = append(fields, "embeddings_floats") + } + if e.EmbeddingsByType != nil { + fields = append(fields, "embeddings_by_type") + } + if len(fields) == 0 { + if e.ResponseType != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", e, e.ResponseType) } + return fmt.Errorf("type %T is empty", e) } - if value, err := core.StringifyJSON(c); err == nil { - return value + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", e, fields) } - return fmt.Sprintf("%#v", c) -} - -type ChatToolCallStartEventDelta struct { - Message *ChatToolCallStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + if e.ResponseType != "" { + field := fields[0] + if e.ResponseType != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + e, + e.ResponseType, + e, + ) + } + } + return nil } -func (c *ChatToolCallStartEventDelta) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} +type EmbeddingType string -func (c *ChatToolCallStartEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallStartEventDelta - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatToolCallStartEventDelta(value) +const ( + EmbeddingTypeFloat EmbeddingType = "float" + EmbeddingTypeInt8 EmbeddingType = "int8" + EmbeddingTypeUint8 EmbeddingType = "uint8" + EmbeddingTypeBinary EmbeddingType = "binary" + EmbeddingTypeUbinary EmbeddingType = "ubinary" +) - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func NewEmbeddingTypeFromString(s string) (EmbeddingType, error) { + switch s { + case "float": + return EmbeddingTypeFloat, nil + case "int8": + return EmbeddingTypeInt8, nil + case "uint8": + return EmbeddingTypeUint8, nil + case "binary": + return EmbeddingTypeBinary, nil + case "ubinary": + return EmbeddingTypeUbinary, nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + var t EmbeddingType + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatToolCallStartEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) +func (e EmbeddingType) Ptr() *EmbeddingType { + return &e } -type ChatToolCallStartEventDeltaMessage struct { - ToolCalls *ToolCallV2 `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` +type FinishReason string - extraProperties map[string]interface{} - _rawJSON json.RawMessage +const ( + FinishReasonComplete FinishReason = "COMPLETE" + FinishReasonStopSequence FinishReason = "STOP_SEQUENCE" + FinishReasonError FinishReason = "ERROR" + FinishReasonErrorToxic FinishReason = "ERROR_TOXIC" + FinishReasonErrorLimit FinishReason = "ERROR_LIMIT" + FinishReasonUserCancel FinishReason = "USER_CANCEL" + FinishReasonMaxTokens FinishReason = "MAX_TOKENS" +) + +func NewFinishReasonFromString(s string) (FinishReason, error) { + switch s { + case "COMPLETE": + return FinishReasonComplete, nil + case "STOP_SEQUENCE": + return FinishReasonStopSequence, nil + case "ERROR": + return FinishReasonError, nil + case "ERROR_TOXIC": + return FinishReasonErrorToxic, nil + case "ERROR_LIMIT": + return FinishReasonErrorLimit, nil + case "USER_CANCEL": + return FinishReasonUserCancel, nil + case "MAX_TOKENS": + return FinishReasonMaxTokens, nil + } + var t FinishReason + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatToolCallStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (f FinishReason) Ptr() *FinishReason { + return &f } -func (c *ChatToolCallStartEventDeltaMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallStartEventDeltaMessage - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatToolCallStartEventDeltaMessage(value) +// One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. +// +// If `GENERATION` is selected, the token likelihoods will only be provided for generated text. +// +// WARNING: `ALL` is deprecated, and will be removed in a future release. +type GenerateRequestReturnLikelihoods string - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} +const ( + GenerateRequestReturnLikelihoodsGeneration GenerateRequestReturnLikelihoods = "GENERATION" + GenerateRequestReturnLikelihoodsAll GenerateRequestReturnLikelihoods = "ALL" + GenerateRequestReturnLikelihoodsNone GenerateRequestReturnLikelihoods = "NONE" +) -func (c *ChatToolCallStartEventDeltaMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func NewGenerateRequestReturnLikelihoodsFromString(s string) (GenerateRequestReturnLikelihoods, error) { + switch s { + case "GENERATION": + return GenerateRequestReturnLikelihoodsGeneration, nil + case "ALL": + return GenerateRequestReturnLikelihoodsAll, nil + case "NONE": + return GenerateRequestReturnLikelihoodsNone, nil } - return fmt.Sprintf("%#v", c) + var t GenerateRequestReturnLikelihoods + return "", fmt.Errorf("%s is not a valid %T", s, t) } -type ChatToolCallsChunkEvent struct { - ToolCallDelta *ToolCallDelta `json:"tool_call_delta,omitempty" url:"tool_call_delta,omitempty"` - Text *string `json:"text,omitempty" url:"text,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (g GenerateRequestReturnLikelihoods) Ptr() *GenerateRequestReturnLikelihoods { + return &g } -func (c *ChatToolCallsChunkEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} +// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. +// +// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. +// +// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. +type GenerateRequestTruncate string -func (c *ChatToolCallsChunkEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallsChunkEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatToolCallsChunkEvent(value) +const ( + GenerateRequestTruncateNone GenerateRequestTruncate = "NONE" + GenerateRequestTruncateStart GenerateRequestTruncate = "START" + GenerateRequestTruncateEnd GenerateRequestTruncate = "END" +) - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func NewGenerateRequestTruncateFromString(s string) (GenerateRequestTruncate, error) { + switch s { + case "NONE": + return GenerateRequestTruncateNone, nil + case "START": + return GenerateRequestTruncateStart, nil + case "END": + return GenerateRequestTruncateEnd, nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + var t GenerateRequestTruncate + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *ChatToolCallsChunkEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) +func (g GenerateRequestTruncate) Ptr() *GenerateRequestTruncate { + return &g } -type ChatToolCallsGenerationEvent struct { - // The text generated related to the tool calls generated - Text *string `json:"text,omitempty" url:"text,omitempty"` - ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` +type GenerateStreamEnd struct { + IsFinished bool `json:"is_finished" url:"is_finished"` + FinishReason *FinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` + Response *GenerateStreamEndResponse `json:"response,omitempty" url:"response,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ChatToolCallsGenerationEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties + rawJSON json.RawMessage } -func (c *ChatToolCallsGenerationEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolCallsGenerationEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatToolCallsGenerationEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (g *GenerateStreamEnd) GetIsFinished() bool { + if g == nil { + return false } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return g.IsFinished } -func (c *ChatToolCallsGenerationEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (g *GenerateStreamEnd) GetFinishReason() *FinishReason { + if g == nil { + return nil } - return fmt.Sprintf("%#v", c) + return g.FinishReason } -// A streamed event which contains a delta of tool plan text. -type ChatToolPlanDeltaEvent struct { - Delta *ChatToolPlanDeltaEventDelta `json:"delta,omitempty" url:"delta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (g *GenerateStreamEnd) GetResponse() *GenerateStreamEndResponse { + if g == nil { + return nil + } + return g.Response } -func (c *ChatToolPlanDeltaEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *GenerateStreamEnd) GetExtraProperties() map[string]interface{} { + return g.extraProperties } -func (c *ChatToolPlanDeltaEvent) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolPlanDeltaEvent +func (g *GenerateStreamEnd) UnmarshalJSON(data []byte) error { + type unmarshaler GenerateStreamEnd var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatToolPlanDeltaEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *g = GenerateStreamEnd(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) return nil } -func (c *ChatToolPlanDeltaEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (g *GenerateStreamEnd) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(g); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", g) } -type ChatToolPlanDeltaEventDelta struct { - Message *ChatToolPlanDeltaEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` +type GenerateStreamEndResponse struct { + Id string `json:"id" url:"id"` + Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` + Generations []*SingleGenerationInStream `json:"generations,omitempty" url:"generations,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ChatToolPlanDeltaEventDelta) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *ChatToolPlanDeltaEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolPlanDeltaEventDelta - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ChatToolPlanDeltaEventDelta(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (g *GenerateStreamEndResponse) GetId() string { + if g == nil { + return "" } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return g.Id } -func (c *ChatToolPlanDeltaEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (g *GenerateStreamEndResponse) GetPrompt() *string { + if g == nil { + return nil } - return fmt.Sprintf("%#v", c) + return g.Prompt } -type ChatToolPlanDeltaEventDeltaMessage struct { - ToolPlan *string `json:"tool_plan,omitempty" url:"tool_plan,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (g *GenerateStreamEndResponse) GetGenerations() []*SingleGenerationInStream { + if g == nil { + return nil + } + return g.Generations } -func (c *ChatToolPlanDeltaEventDeltaMessage) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *GenerateStreamEndResponse) GetExtraProperties() map[string]interface{} { + return g.extraProperties } -func (c *ChatToolPlanDeltaEventDeltaMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ChatToolPlanDeltaEventDeltaMessage +func (g *GenerateStreamEndResponse) UnmarshalJSON(data []byte) error { + type unmarshaler GenerateStreamEndResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ChatToolPlanDeltaEventDeltaMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *g = GenerateStreamEndResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) return nil } -func (c *ChatToolPlanDeltaEventDeltaMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (g *GenerateStreamEndResponse) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(g); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", g) } -type CheckApiKeyResponse struct { - Valid bool `json:"valid" url:"valid"` - OrganizationId *string `json:"organization_id,omitempty" url:"organization_id,omitempty"` - OwnerId *string `json:"owner_id,omitempty" url:"owner_id,omitempty"` +type GenerateStreamError struct { + // Refers to the nth generation. Only present when `num_generations` is greater than zero. + Index *int `json:"index,omitempty" url:"index,omitempty"` + IsFinished bool `json:"is_finished" url:"is_finished"` + FinishReason FinishReason `json:"finish_reason" url:"finish_reason"` + // Error message + Err string `json:"err" url:"err"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *CheckApiKeyResponse) GetExtraProperties() map[string]interface{} { - return c.extraProperties + rawJSON json.RawMessage } -func (c *CheckApiKeyResponse) UnmarshalJSON(data []byte) error { - type unmarshaler CheckApiKeyResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (g *GenerateStreamError) GetIndex() *int { + if g == nil { + return nil } - *c = CheckApiKeyResponse(value) + return g.Index +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (g *GenerateStreamError) GetIsFinished() bool { + if g == nil { + return false } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + return g.IsFinished } -func (c *CheckApiKeyResponse) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value +func (g *GenerateStreamError) GetFinishReason() FinishReason { + if g == nil { + return "" } - return fmt.Sprintf("%#v", c) + return g.FinishReason } -// Citation information containing sources and the text cited. -type Citation struct { - // Start index of the cited snippet in the original source text. - Start *int `json:"start,omitempty" url:"start,omitempty"` - // End index of the cited snippet in the original source text. - End *int `json:"end,omitempty" url:"end,omitempty"` - // Text snippet that is being cited. - Text *string `json:"text,omitempty" url:"text,omitempty"` - Sources []*Source `json:"sources,omitempty" url:"sources,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (g *GenerateStreamError) GetErr() string { + if g == nil { + return "" + } + return g.Err } -func (c *Citation) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *GenerateStreamError) GetExtraProperties() map[string]interface{} { + return g.extraProperties } -func (c *Citation) UnmarshalJSON(data []byte) error { - type unmarshaler Citation +func (g *GenerateStreamError) UnmarshalJSON(data []byte) error { + type unmarshaler GenerateStreamError var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = Citation(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *g = GenerateStreamError(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) return nil } -func (c *Citation) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (g *GenerateStreamError) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(g); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", g) } -// A streamed event which signifies a citation has finished streaming. -type CitationEndEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` - +type GenerateStreamEvent struct { extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *CitationEndEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *GenerateStreamEvent) GetExtraProperties() map[string]interface{} { + return g.extraProperties } -func (c *CitationEndEvent) UnmarshalJSON(data []byte) error { - type unmarshaler CitationEndEvent +func (g *GenerateStreamEvent) UnmarshalJSON(data []byte) error { + type unmarshaler GenerateStreamEvent var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = CitationEndEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *g = GenerateStreamEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) return nil } -func (c *CitationEndEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (g *GenerateStreamEvent) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(g); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", g) } -// Options for controlling citation generation. -type CitationOptions struct { - // Defaults to `"accurate"`. - // Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. - // - // **Note**: `command-r7b-12-2024` only supports `"fast"` and `"off"` modes. Its default is `"fast"`. - Mode *CitationOptionsMode `json:"mode,omitempty" url:"mode,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} +// One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. +// +// If `GENERATION` is selected, the token likelihoods will only be provided for generated text. +// +// WARNING: `ALL` is deprecated, and will be removed in a future release. +type GenerateStreamRequestReturnLikelihoods string -func (c *CitationOptions) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} +const ( + GenerateStreamRequestReturnLikelihoodsGeneration GenerateStreamRequestReturnLikelihoods = "GENERATION" + GenerateStreamRequestReturnLikelihoodsAll GenerateStreamRequestReturnLikelihoods = "ALL" + GenerateStreamRequestReturnLikelihoodsNone GenerateStreamRequestReturnLikelihoods = "NONE" +) -func (c *CitationOptions) UnmarshalJSON(data []byte) error { - type unmarshaler CitationOptions - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = CitationOptions(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func NewGenerateStreamRequestReturnLikelihoodsFromString(s string) (GenerateStreamRequestReturnLikelihoods, error) { + switch s { + case "GENERATION": + return GenerateStreamRequestReturnLikelihoodsGeneration, nil + case "ALL": + return GenerateStreamRequestReturnLikelihoodsAll, nil + case "NONE": + return GenerateStreamRequestReturnLikelihoodsNone, nil } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil + var t GenerateStreamRequestReturnLikelihoods + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c *CitationOptions) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) +func (g GenerateStreamRequestReturnLikelihoods) Ptr() *GenerateStreamRequestReturnLikelihoods { + return &g } -// Defaults to `"accurate"`. -// Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. +// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. +// +// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. // -// **Note**: `command-r7b-12-2024` only supports `"fast"` and `"off"` modes. Its default is `"fast"`. -type CitationOptionsMode string +// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. +type GenerateStreamRequestTruncate string const ( - CitationOptionsModeFast CitationOptionsMode = "FAST" - CitationOptionsModeAccurate CitationOptionsMode = "ACCURATE" - CitationOptionsModeOff CitationOptionsMode = "OFF" + GenerateStreamRequestTruncateNone GenerateStreamRequestTruncate = "NONE" + GenerateStreamRequestTruncateStart GenerateStreamRequestTruncate = "START" + GenerateStreamRequestTruncateEnd GenerateStreamRequestTruncate = "END" ) -func NewCitationOptionsModeFromString(s string) (CitationOptionsMode, error) { +func NewGenerateStreamRequestTruncateFromString(s string) (GenerateStreamRequestTruncate, error) { switch s { - case "FAST": - return CitationOptionsModeFast, nil - case "ACCURATE": - return CitationOptionsModeAccurate, nil - case "OFF": - return CitationOptionsModeOff, nil + case "NONE": + return GenerateStreamRequestTruncateNone, nil + case "START": + return GenerateStreamRequestTruncateStart, nil + case "END": + return GenerateStreamRequestTruncateEnd, nil } - var t CitationOptionsMode + var t GenerateStreamRequestTruncate return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (c CitationOptionsMode) Ptr() *CitationOptionsMode { - return &c +func (g GenerateStreamRequestTruncate) Ptr() *GenerateStreamRequestTruncate { + return &g } -// A streamed event which signifies a citation has been created. -type CitationStartEvent struct { - Index *int `json:"index,omitempty" url:"index,omitempty"` - Delta *CitationStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` +type GenerateStreamText struct { + // A segment of text of the generation. + Text string `json:"text" url:"text"` + // Refers to the nth generation. Only present when `num_generations` is greater than zero, and only when text responses are being streamed. + Index *int `json:"index,omitempty" url:"index,omitempty"` + IsFinished bool `json:"is_finished" url:"is_finished"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *CitationStartEvent) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *GenerateStreamText) GetText() string { + if g == nil { + return "" + } + return g.Text +} + +func (g *GenerateStreamText) GetIndex() *int { + if g == nil { + return nil + } + return g.Index +} + +func (g *GenerateStreamText) GetIsFinished() bool { + if g == nil { + return false + } + return g.IsFinished +} + +func (g *GenerateStreamText) GetExtraProperties() map[string]interface{} { + return g.extraProperties } -func (c *CitationStartEvent) UnmarshalJSON(data []byte) error { - type unmarshaler CitationStartEvent +func (g *GenerateStreamText) UnmarshalJSON(data []byte) error { + type unmarshaler GenerateStreamText var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = CitationStartEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *g = GenerateStreamText(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) return nil } -func (c *CitationStartEvent) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (g *GenerateStreamText) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(g); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", g) } -type CitationStartEventDelta struct { - Message *CitationStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` +// Response in content type stream when `stream` is `true` in the request parameters. Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse. +type GenerateStreamedResponse struct { + EventType string + TextGeneration *GenerateStreamText + StreamEnd *GenerateStreamEnd + StreamError *GenerateStreamError +} - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (g *GenerateStreamedResponse) GetEventType() string { + if g == nil { + return "" + } + return g.EventType } -func (c *CitationStartEventDelta) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *GenerateStreamedResponse) GetTextGeneration() *GenerateStreamText { + if g == nil { + return nil + } + return g.TextGeneration } -func (c *CitationStartEventDelta) UnmarshalJSON(data []byte) error { - type unmarshaler CitationStartEventDelta - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (g *GenerateStreamedResponse) GetStreamEnd() *GenerateStreamEnd { + if g == nil { + return nil } - *c = CitationStartEventDelta(value) + return g.StreamEnd +} - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err +func (g *GenerateStreamedResponse) GetStreamError() *GenerateStreamError { + if g == nil { + return nil } - c.extraProperties = extraProperties + return g.StreamError +} - c._rawJSON = json.RawMessage(data) +func (g *GenerateStreamedResponse) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + EventType string `json:"event_type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + g.EventType = unmarshaler.EventType + if unmarshaler.EventType == "" { + return fmt.Errorf("%T did not include discriminant event_type", g) + } + switch unmarshaler.EventType { + case "text-generation": + value := new(GenerateStreamText) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + g.TextGeneration = value + case "stream-end": + value := new(GenerateStreamEnd) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + g.StreamEnd = value + case "stream-error": + value := new(GenerateStreamError) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + g.StreamError = value + } return nil } -func (c *CitationStartEventDelta) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } +func (g GenerateStreamedResponse) MarshalJSON() ([]byte, error) { + if err := g.validate(); err != nil { + return nil, err } - if value, err := core.StringifyJSON(c); err == nil { - return value + if g.TextGeneration != nil { + return internal.MarshalJSONWithExtraProperty(g.TextGeneration, "event_type", "text-generation") } - return fmt.Sprintf("%#v", c) + if g.StreamEnd != nil { + return internal.MarshalJSONWithExtraProperty(g.StreamEnd, "event_type", "stream-end") + } + if g.StreamError != nil { + return internal.MarshalJSONWithExtraProperty(g.StreamError, "event_type", "stream-error") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", g) } -type CitationStartEventDeltaMessage struct { - Citations *Citation `json:"citations,omitempty" url:"citations,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +type GenerateStreamedResponseVisitor interface { + VisitTextGeneration(*GenerateStreamText) error + VisitStreamEnd(*GenerateStreamEnd) error + VisitStreamError(*GenerateStreamError) error } -func (c *CitationStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *GenerateStreamedResponse) Accept(visitor GenerateStreamedResponseVisitor) error { + if g.TextGeneration != nil { + return visitor.VisitTextGeneration(g.TextGeneration) + } + if g.StreamEnd != nil { + return visitor.VisitStreamEnd(g.StreamEnd) + } + if g.StreamError != nil { + return visitor.VisitStreamError(g.StreamError) + } + return fmt.Errorf("type %T does not define a non-empty union type", g) } -func (c *CitationStartEventDeltaMessage) UnmarshalJSON(data []byte) error { - type unmarshaler CitationStartEventDeltaMessage - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (g *GenerateStreamedResponse) validate() error { + if g == nil { + return fmt.Errorf("type %T is nil", g) } - *c = CitationStartEventDeltaMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err + var fields []string + if g.TextGeneration != nil { + fields = append(fields, "text-generation") + } + if g.StreamEnd != nil { + fields = append(fields, "stream-end") + } + if g.StreamError != nil { + fields = append(fields, "stream-error") + } + if len(fields) == 0 { + if g.EventType != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", g, g.EventType) + } + return fmt.Errorf("type %T is empty", g) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", g, fields) + } + if g.EventType != "" { + field := fields[0] + if g.EventType != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + g, + g.EventType, + g, + ) + } } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) return nil } -func (c *CitationStartEventDeltaMessage) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } +type Generation struct { + Id string `json:"id" url:"id"` + // Prompt used for generations. + Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` + // List of generated results + Generations []*SingleGeneration `json:"generations,omitempty" url:"generations,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (g *Generation) GetId() string { + if g == nil { + return "" } - if value, err := core.StringifyJSON(c); err == nil { - return value + return g.Id +} + +func (g *Generation) GetPrompt() *string { + if g == nil { + return nil } - return fmt.Sprintf("%#v", c) + return g.Prompt } -type ClassifyDataMetrics struct { - LabelMetrics []*LabelMetric `json:"label_metrics,omitempty" url:"label_metrics,omitempty"` +func (g *Generation) GetGenerations() []*SingleGeneration { + if g == nil { + return nil + } + return g.Generations +} - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (g *Generation) GetMeta() *ApiMeta { + if g == nil { + return nil + } + return g.Meta } -func (c *ClassifyDataMetrics) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (g *Generation) GetExtraProperties() map[string]interface{} { + return g.extraProperties } -func (c *ClassifyDataMetrics) UnmarshalJSON(data []byte) error { - type unmarshaler ClassifyDataMetrics +func (g *Generation) UnmarshalJSON(data []byte) error { + type unmarshaler Generation var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ClassifyDataMetrics(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *g = Generation(value) + extraProperties, err := internal.ExtractExtraProperties(data, *g) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + g.extraProperties = extraProperties + g.rawJSON = json.RawMessage(data) return nil } -func (c *ClassifyDataMetrics) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (g *Generation) String() string { + if len(g.rawJSON) > 0 { + if value, err := internal.StringifyJSON(g.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(g); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", g) } -type ClassifyExample struct { - Text *string `json:"text,omitempty" url:"text,omitempty"` - Label *string `json:"label,omitempty" url:"label,omitempty"` +type Image struct { + // Width of the image in pixels + Width int64 `json:"width" url:"width"` + // Height of the image in pixels + Height int64 `json:"height" url:"height"` + // Format of the image + Format string `json:"format" url:"format"` + // Bit depth of the image + BitDepth int64 `json:"bit_depth" url:"bit_depth"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (c *ClassifyExample) GetExtraProperties() map[string]interface{} { - return c.extraProperties +func (i *Image) GetWidth() int64 { + if i == nil { + return 0 + } + return i.Width } -func (c *ClassifyExample) UnmarshalJSON(data []byte) error { - type unmarshaler ClassifyExample +func (i *Image) GetHeight() int64 { + if i == nil { + return 0 + } + return i.Height +} + +func (i *Image) GetFormat() string { + if i == nil { + return "" + } + return i.Format +} + +func (i *Image) GetBitDepth() int64 { + if i == nil { + return 0 + } + return i.BitDepth +} + +func (i *Image) GetExtraProperties() map[string]interface{} { + return i.extraProperties +} + +func (i *Image) UnmarshalJSON(data []byte) error { + type unmarshaler Image var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *c = ClassifyExample(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) + *i = Image(value) + extraProperties, err := internal.ExtractExtraProperties(data, *i) if err != nil { return err } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) + i.extraProperties = extraProperties + i.rawJSON = json.RawMessage(data) return nil } -func (c *ClassifyExample) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { +func (i *Image) String() string { + if len(i.rawJSON) > 0 { + if value, err := internal.StringifyJSON(i.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(c); err == nil { + if value, err := internal.StringifyJSON(i); err == nil { return value } - return fmt.Sprintf("%#v", c) + return fmt.Sprintf("%#v", i) } -// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. -// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. -// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. -type ClassifyRequestTruncate string - -const ( - ClassifyRequestTruncateNone ClassifyRequestTruncate = "NONE" - ClassifyRequestTruncateStart ClassifyRequestTruncate = "START" - ClassifyRequestTruncateEnd ClassifyRequestTruncate = "END" -) - -func NewClassifyRequestTruncateFromString(s string) (ClassifyRequestTruncate, error) { - switch s { - case "NONE": - return ClassifyRequestTruncateNone, nil - case "START": - return ClassifyRequestTruncateStart, nil - case "END": - return ClassifyRequestTruncateEnd, nil - } - var t ClassifyRequestTruncate - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c ClassifyRequestTruncate) Ptr() *ClassifyRequestTruncate { - return &c -} - -type ClassifyResponse struct { - Id string `json:"id" url:"id"` - Classifications []*ClassifyResponseClassificationsItem `json:"classifications,omitempty" url:"classifications,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ClassifyResponse) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *ClassifyResponse) UnmarshalJSON(data []byte) error { - type unmarshaler ClassifyResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ClassifyResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *ClassifyResponse) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -type ClassifyResponseClassificationsItem struct { - Id string `json:"id" url:"id"` - // The input text that was classified - Input *string `json:"input,omitempty" url:"input,omitempty"` - // The predicted label for the associated query (only filled for single-label models) - Prediction *string `json:"prediction,omitempty" url:"prediction,omitempty"` - // An array containing the predicted labels for the associated query (only filled for single-label classification) - Predictions []string `json:"predictions,omitempty" url:"predictions,omitempty"` - // The confidence score for the top predicted class (only filled for single-label classification) - Confidence *float64 `json:"confidence,omitempty" url:"confidence,omitempty"` - // An array containing the confidence scores of all the predictions in the same order - Confidences []float64 `json:"confidences,omitempty" url:"confidences,omitempty"` - // A map containing each label and its confidence score according to the classifier. All the confidence scores add up to 1 for single-label classification. For multi-label classification the label confidences are independent of each other, so they don't have to sum up to 1. - Labels map[string]*ClassifyResponseClassificationsItemLabelsValue `json:"labels,omitempty" url:"labels,omitempty"` - // The type of classification performed - ClassificationType ClassifyResponseClassificationsItemClassificationType `json:"classification_type" url:"classification_type"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ClassifyResponseClassificationsItem) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *ClassifyResponseClassificationsItem) UnmarshalJSON(data []byte) error { - type unmarshaler ClassifyResponseClassificationsItem - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ClassifyResponseClassificationsItem(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *ClassifyResponseClassificationsItem) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -// The type of classification performed -type ClassifyResponseClassificationsItemClassificationType string - -const ( - ClassifyResponseClassificationsItemClassificationTypeSingleLabel ClassifyResponseClassificationsItemClassificationType = "single-label" - ClassifyResponseClassificationsItemClassificationTypeMultiLabel ClassifyResponseClassificationsItemClassificationType = "multi-label" -) - -func NewClassifyResponseClassificationsItemClassificationTypeFromString(s string) (ClassifyResponseClassificationsItemClassificationType, error) { - switch s { - case "single-label": - return ClassifyResponseClassificationsItemClassificationTypeSingleLabel, nil - case "multi-label": - return ClassifyResponseClassificationsItemClassificationTypeMultiLabel, nil - } - var t ClassifyResponseClassificationsItemClassificationType - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c ClassifyResponseClassificationsItemClassificationType) Ptr() *ClassifyResponseClassificationsItemClassificationType { - return &c -} - -type ClassifyResponseClassificationsItemLabelsValue struct { - Confidence *float64 `json:"confidence,omitempty" url:"confidence,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ClassifyResponseClassificationsItemLabelsValue) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *ClassifyResponseClassificationsItemLabelsValue) UnmarshalJSON(data []byte) error { - type unmarshaler ClassifyResponseClassificationsItemLabelsValue - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ClassifyResponseClassificationsItemLabelsValue(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *ClassifyResponseClassificationsItemLabelsValue) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -type ClientClosedRequestErrorBody struct { - Data *string `json:"data,omitempty" url:"data,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ClientClosedRequestErrorBody) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *ClientClosedRequestErrorBody) UnmarshalJSON(data []byte) error { - type unmarshaler ClientClosedRequestErrorBody - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ClientClosedRequestErrorBody(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *ClientClosedRequestErrorBody) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -// One of the Cohere API endpoints that the model can be used with. -type CompatibleEndpoint string - -const ( - CompatibleEndpointChat CompatibleEndpoint = "chat" - CompatibleEndpointEmbed CompatibleEndpoint = "embed" - CompatibleEndpointClassify CompatibleEndpoint = "classify" - CompatibleEndpointSummarize CompatibleEndpoint = "summarize" - CompatibleEndpointRerank CompatibleEndpoint = "rerank" - CompatibleEndpointRate CompatibleEndpoint = "rate" - CompatibleEndpointGenerate CompatibleEndpoint = "generate" -) - -func NewCompatibleEndpointFromString(s string) (CompatibleEndpoint, error) { - switch s { - case "chat": - return CompatibleEndpointChat, nil - case "embed": - return CompatibleEndpointEmbed, nil - case "classify": - return CompatibleEndpointClassify, nil - case "summarize": - return CompatibleEndpointSummarize, nil - case "rerank": - return CompatibleEndpointRerank, nil - case "rate": - return CompatibleEndpointRate, nil - case "generate": - return CompatibleEndpointGenerate, nil - } - var t CompatibleEndpoint - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c CompatibleEndpoint) Ptr() *CompatibleEndpoint { - return &c -} - -// A connector allows you to integrate data sources with the '/chat' endpoint to create grounded generations with citations to the data source. -// documents to help answer users. -type Connector struct { - // The unique identifier of the connector (used in both `/connectors` & `/chat` endpoints). - // This is automatically created from the name of the connector upon registration. - Id string `json:"id" url:"id"` - // The organization to which this connector belongs. This is automatically set to - // the organization of the user who created the connector. - OrganizationId *string `json:"organization_id,omitempty" url:"organization_id,omitempty"` - // A human-readable name for the connector. - Name string `json:"name" url:"name"` - // A description of the connector. - Description *string `json:"description,omitempty" url:"description,omitempty"` - // The URL of the connector that will be used to search for documents. - Url *string `json:"url,omitempty" url:"url,omitempty"` - // The UTC time at which the connector was created. - CreatedAt time.Time `json:"created_at" url:"created_at"` - // The UTC time at which the connector was last updated. - UpdatedAt time.Time `json:"updated_at" url:"updated_at"` - // A list of fields to exclude from the prompt (fields remain in the document). - Excludes []string `json:"excludes,omitempty" url:"excludes,omitempty"` - // The type of authentication/authorization used by the connector. Possible values: [oauth, service_auth] - AuthType *string `json:"auth_type,omitempty" url:"auth_type,omitempty"` - // The OAuth 2.0 configuration for the connector. - Oauth *ConnectorOAuth `json:"oauth,omitempty" url:"oauth,omitempty"` - // The OAuth status for the user making the request. One of ["valid", "expired", ""]. Empty string (field is omitted) means the user has not authorized the connector yet. - AuthStatus *ConnectorAuthStatus `json:"auth_status,omitempty" url:"auth_status,omitempty"` - // Whether the connector is active or not. - Active *bool `json:"active,omitempty" url:"active,omitempty"` - // Whether a chat request should continue or not if the request to this connector fails. - ContinueOnFailure *bool `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *Connector) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *Connector) UnmarshalJSON(data []byte) error { - type embed Connector - var unmarshaler = struct { - embed - CreatedAt *core.DateTime `json:"created_at"` - UpdatedAt *core.DateTime `json:"updated_at"` - }{ - embed: embed(*c), - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - *c = Connector(unmarshaler.embed) - c.CreatedAt = unmarshaler.CreatedAt.Time() - c.UpdatedAt = unmarshaler.UpdatedAt.Time() - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *Connector) MarshalJSON() ([]byte, error) { - type embed Connector - var marshaler = struct { - embed - CreatedAt *core.DateTime `json:"created_at"` - UpdatedAt *core.DateTime `json:"updated_at"` - }{ - embed: embed(*c), - CreatedAt: core.NewDateTime(c.CreatedAt), - UpdatedAt: core.NewDateTime(c.UpdatedAt), - } - return json.Marshal(marshaler) -} - -func (c *Connector) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -// The OAuth status for the user making the request. One of ["valid", "expired", ""]. Empty string (field is omitted) means the user has not authorized the connector yet. -type ConnectorAuthStatus string - -const ( - ConnectorAuthStatusValid ConnectorAuthStatus = "valid" - ConnectorAuthStatusExpired ConnectorAuthStatus = "expired" -) - -func NewConnectorAuthStatusFromString(s string) (ConnectorAuthStatus, error) { - switch s { - case "valid": - return ConnectorAuthStatusValid, nil - case "expired": - return ConnectorAuthStatusExpired, nil - } - var t ConnectorAuthStatus - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (c ConnectorAuthStatus) Ptr() *ConnectorAuthStatus { - return &c -} - -type ConnectorOAuth struct { - // The OAuth 2.0 client ID. This field is encrypted at rest. - ClientId *string `json:"client_id,omitempty" url:"client_id,omitempty"` - // The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response. - ClientSecret *string `json:"client_secret,omitempty" url:"client_secret,omitempty"` - // The OAuth 2.0 /authorize endpoint to use when users authorize the connector. - AuthorizeUrl string `json:"authorize_url" url:"authorize_url"` - // The OAuth 2.0 /token endpoint to use when users authorize the connector. - TokenUrl string `json:"token_url" url:"token_url"` - // The OAuth scopes to request when users authorize the connector. - Scope *string `json:"scope,omitempty" url:"scope,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *ConnectorOAuth) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *ConnectorOAuth) UnmarshalJSON(data []byte) error { - type unmarshaler ConnectorOAuth - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = ConnectorOAuth(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *ConnectorOAuth) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -// A Content block which contains information about the content type and the content itself. -type Content struct { - Type string - Text *TextContent -} - -func (c *Content) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - c.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", c) - } - switch unmarshaler.Type { - case "text": - value := new(TextContent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - c.Text = value - } - return nil -} - -func (c Content) MarshalJSON() ([]byte, error) { - if c.Text != nil { - return core.MarshalJSONWithExtraProperty(c.Text, "type", "text") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", c) -} - -type ContentVisitor interface { - VisitText(*TextContent) error -} - -func (c *Content) Accept(visitor ContentVisitor) error { - if c.Text != nil { - return visitor.VisitText(c.Text) - } - return fmt.Errorf("type %T does not define a non-empty union type", c) -} - -type CreateConnectorOAuth struct { - // The OAuth 2.0 client ID. This fields is encrypted at rest. - ClientId *string `json:"client_id,omitempty" url:"client_id,omitempty"` - // The OAuth 2.0 client Secret. This field is encrypted at rest and never returned in a response. - ClientSecret *string `json:"client_secret,omitempty" url:"client_secret,omitempty"` - // The OAuth 2.0 /authorize endpoint to use when users authorize the connector. - AuthorizeUrl *string `json:"authorize_url,omitempty" url:"authorize_url,omitempty"` - // The OAuth 2.0 /token endpoint to use when users authorize the connector. - TokenUrl *string `json:"token_url,omitempty" url:"token_url,omitempty"` - // The OAuth scopes to request when users authorize the connector. - Scope *string `json:"scope,omitempty" url:"scope,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *CreateConnectorOAuth) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *CreateConnectorOAuth) UnmarshalJSON(data []byte) error { - type unmarshaler CreateConnectorOAuth - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = CreateConnectorOAuth(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *CreateConnectorOAuth) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -type CreateConnectorResponse struct { - Connector *Connector `json:"connector,omitempty" url:"connector,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *CreateConnectorResponse) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *CreateConnectorResponse) UnmarshalJSON(data []byte) error { - type unmarshaler CreateConnectorResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = CreateConnectorResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *CreateConnectorResponse) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -type CreateConnectorServiceAuth struct { - Type AuthTokenType `json:"type" url:"type"` - // The token that will be used in the HTTP Authorization header when making requests to the connector. This field is encrypted at rest and never returned in a response. - Token string `json:"token" url:"token"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *CreateConnectorServiceAuth) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *CreateConnectorServiceAuth) UnmarshalJSON(data []byte) error { - type unmarshaler CreateConnectorServiceAuth - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = CreateConnectorServiceAuth(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *CreateConnectorServiceAuth) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -// Response from creating an embed job. -type CreateEmbedJobResponse struct { - JobId string `json:"job_id" url:"job_id"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (c *CreateEmbedJobResponse) GetExtraProperties() map[string]interface{} { - return c.extraProperties -} - -func (c *CreateEmbedJobResponse) UnmarshalJSON(data []byte) error { - type unmarshaler CreateEmbedJobResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *c = CreateEmbedJobResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *c) - if err != nil { - return err - } - c.extraProperties = extraProperties - - c._rawJSON = json.RawMessage(data) - return nil -} - -func (c *CreateEmbedJobResponse) String() string { - if len(c._rawJSON) > 0 { - if value, err := core.StringifyJSON(c._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(c); err == nil { - return value - } - return fmt.Sprintf("%#v", c) -} - -type Dataset struct { - // The dataset ID - Id string `json:"id" url:"id"` - // The name of the dataset - Name string `json:"name" url:"name"` - // The creation date - CreatedAt time.Time `json:"created_at" url:"created_at"` - // The last update date - UpdatedAt time.Time `json:"updated_at" url:"updated_at"` - DatasetType DatasetType `json:"dataset_type" url:"dataset_type"` - ValidationStatus DatasetValidationStatus `json:"validation_status" url:"validation_status"` - // Errors found during validation - ValidationError *string `json:"validation_error,omitempty" url:"validation_error,omitempty"` - // the avro schema of the dataset - Schema *string `json:"schema,omitempty" url:"schema,omitempty"` - RequiredFields []string `json:"required_fields,omitempty" url:"required_fields,omitempty"` - PreserveFields []string `json:"preserve_fields,omitempty" url:"preserve_fields,omitempty"` - // the underlying files that make up the dataset - DatasetParts []*DatasetPart `json:"dataset_parts,omitempty" url:"dataset_parts,omitempty"` - // warnings found during validation - ValidationWarnings []string `json:"validation_warnings,omitempty" url:"validation_warnings,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (d *Dataset) GetExtraProperties() map[string]interface{} { - return d.extraProperties -} - -func (d *Dataset) UnmarshalJSON(data []byte) error { - type embed Dataset - var unmarshaler = struct { - embed - CreatedAt *core.DateTime `json:"created_at"` - UpdatedAt *core.DateTime `json:"updated_at"` - }{ - embed: embed(*d), - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - *d = Dataset(unmarshaler.embed) - d.CreatedAt = unmarshaler.CreatedAt.Time() - d.UpdatedAt = unmarshaler.UpdatedAt.Time() - - extraProperties, err := core.ExtractExtraProperties(data, *d) - if err != nil { - return err - } - d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *Dataset) MarshalJSON() ([]byte, error) { - type embed Dataset - var marshaler = struct { - embed - CreatedAt *core.DateTime `json:"created_at"` - UpdatedAt *core.DateTime `json:"updated_at"` - }{ - embed: embed(*d), - CreatedAt: core.NewDateTime(d.CreatedAt), - UpdatedAt: core.NewDateTime(d.UpdatedAt), - } - return json.Marshal(marshaler) -} - -func (d *Dataset) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - -type DatasetPart struct { - // The dataset part ID - Id string `json:"id" url:"id"` - // The name of the dataset part - Name string `json:"name" url:"name"` - // The download url of the file - Url *string `json:"url,omitempty" url:"url,omitempty"` - // The index of the file - Index *int `json:"index,omitempty" url:"index,omitempty"` - // The size of the file in bytes - SizeBytes *int `json:"size_bytes,omitempty" url:"size_bytes,omitempty"` - // The number of rows in the file - NumRows *int `json:"num_rows,omitempty" url:"num_rows,omitempty"` - // The download url of the original file - OriginalUrl *string `json:"original_url,omitempty" url:"original_url,omitempty"` - // The first few rows of the parsed file - Samples []string `json:"samples,omitempty" url:"samples,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (d *DatasetPart) GetExtraProperties() map[string]interface{} { - return d.extraProperties -} - -func (d *DatasetPart) UnmarshalJSON(data []byte) error { - type unmarshaler DatasetPart - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *d = DatasetPart(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) - if err != nil { - return err - } - d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *DatasetPart) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - -// The type of the dataset -type DatasetType string - -const ( - DatasetTypeEmbedInput DatasetType = "embed-input" - DatasetTypeEmbedResult DatasetType = "embed-result" - DatasetTypeClusterResult DatasetType = "cluster-result" - DatasetTypeClusterOutliers DatasetType = "cluster-outliers" - DatasetTypeRerankerFinetuneInput DatasetType = "reranker-finetune-input" - DatasetTypeSingleLabelClassificationFinetuneInput DatasetType = "single-label-classification-finetune-input" - DatasetTypeChatFinetuneInput DatasetType = "chat-finetune-input" - DatasetTypeMultiLabelClassificationFinetuneInput DatasetType = "multi-label-classification-finetune-input" -) - -func NewDatasetTypeFromString(s string) (DatasetType, error) { - switch s { - case "embed-input": - return DatasetTypeEmbedInput, nil - case "embed-result": - return DatasetTypeEmbedResult, nil - case "cluster-result": - return DatasetTypeClusterResult, nil - case "cluster-outliers": - return DatasetTypeClusterOutliers, nil - case "reranker-finetune-input": - return DatasetTypeRerankerFinetuneInput, nil - case "single-label-classification-finetune-input": - return DatasetTypeSingleLabelClassificationFinetuneInput, nil - case "chat-finetune-input": - return DatasetTypeChatFinetuneInput, nil - case "multi-label-classification-finetune-input": - return DatasetTypeMultiLabelClassificationFinetuneInput, nil - } - var t DatasetType - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (d DatasetType) Ptr() *DatasetType { - return &d -} - -// The validation status of the dataset -type DatasetValidationStatus string - -const ( - DatasetValidationStatusUnknown DatasetValidationStatus = "unknown" - DatasetValidationStatusQueued DatasetValidationStatus = "queued" - DatasetValidationStatusProcessing DatasetValidationStatus = "processing" - DatasetValidationStatusFailed DatasetValidationStatus = "failed" - DatasetValidationStatusValidated DatasetValidationStatus = "validated" - DatasetValidationStatusSkipped DatasetValidationStatus = "skipped" -) - -func NewDatasetValidationStatusFromString(s string) (DatasetValidationStatus, error) { - switch s { - case "unknown": - return DatasetValidationStatusUnknown, nil - case "queued": - return DatasetValidationStatusQueued, nil - case "processing": - return DatasetValidationStatusProcessing, nil - case "failed": - return DatasetValidationStatusFailed, nil - case "validated": - return DatasetValidationStatusValidated, nil - case "skipped": - return DatasetValidationStatusSkipped, nil - } - var t DatasetValidationStatus - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (d DatasetValidationStatus) Ptr() *DatasetValidationStatus { - return &d -} - -type DeleteConnectorResponse = map[string]interface{} - -type DetokenizeResponse struct { - // A string representing the list of tokens. - Text string `json:"text" url:"text"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (d *DetokenizeResponse) GetExtraProperties() map[string]interface{} { - return d.extraProperties -} - -func (d *DetokenizeResponse) UnmarshalJSON(data []byte) error { - type unmarshaler DetokenizeResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *d = DetokenizeResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) - if err != nil { - return err - } - d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *DetokenizeResponse) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - -// Relevant information that could be used by the model to generate a more accurate reply. -// The content of each document are generally short (should be under 300 words). Metadata should be used to provide additional information, both the key name and the value will be -// passed to the model. -type Document struct { - // A relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. - Data map[string]string `json:"data,omitempty" url:"data,omitempty"` - // Unique identifier for this document which will be referenced in citations. If not provided an ID will be automatically generated. - Id *string `json:"id,omitempty" url:"id,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (d *Document) GetExtraProperties() map[string]interface{} { - return d.extraProperties -} - -func (d *Document) UnmarshalJSON(data []byte) error { - type unmarshaler Document - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *d = Document(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) - if err != nil { - return err - } - d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *Document) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - -// Document content. -type DocumentContent struct { - Document *Document `json:"document,omitempty" url:"document,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (d *DocumentContent) GetExtraProperties() map[string]interface{} { - return d.extraProperties -} - -func (d *DocumentContent) UnmarshalJSON(data []byte) error { - type unmarshaler DocumentContent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *d = DocumentContent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) - if err != nil { - return err - } - d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *DocumentContent) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - -// A document source object containing the unique identifier of the document and the document itself. -type DocumentSource struct { - // The unique identifier of the document - Id *string `json:"id,omitempty" url:"id,omitempty"` - Document map[string]interface{} `json:"document,omitempty" url:"document,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (d *DocumentSource) GetExtraProperties() map[string]interface{} { - return d.extraProperties -} - -func (d *DocumentSource) UnmarshalJSON(data []byte) error { - type unmarshaler DocumentSource - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *d = DocumentSource(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) - if err != nil { - return err - } - d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) - return nil -} - -func (d *DocumentSource) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(d); err == nil { - return value - } - return fmt.Sprintf("%#v", d) -} - -type EmbedByTypeResponse struct { - Id string `json:"id" url:"id"` - // An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array. - Embeddings *EmbedByTypeResponseEmbeddings `json:"embeddings,omitempty" url:"embeddings,omitempty"` - // The text entries for which embeddings were returned. - Texts []string `json:"texts,omitempty" url:"texts,omitempty"` - // The image entries for which embeddings were returned. - Images []*Image `json:"images,omitempty" url:"images,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (e *EmbedByTypeResponse) GetExtraProperties() map[string]interface{} { - return e.extraProperties -} - -func (e *EmbedByTypeResponse) UnmarshalJSON(data []byte) error { - type unmarshaler EmbedByTypeResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *e = EmbedByTypeResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *e) - if err != nil { - return err - } - e.extraProperties = extraProperties - - e._rawJSON = json.RawMessage(data) - return nil -} - -func (e *EmbedByTypeResponse) String() string { - if len(e._rawJSON) > 0 { - if value, err := core.StringifyJSON(e._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(e); err == nil { - return value - } - return fmt.Sprintf("%#v", e) -} - -// An object with different embedding types. The length of each embedding type array will be the same as the length of the original `texts` array. -type EmbedByTypeResponseEmbeddings struct { - // An array of float embeddings. - Float [][]float64 `json:"float,omitempty" url:"float,omitempty"` - // An array of signed int8 embeddings. Each value is between -128 and 127. - Int8 [][]int `json:"int8,omitempty" url:"int8,omitempty"` - // An array of unsigned int8 embeddings. Each value is between 0 and 255. - Uint8 [][]int `json:"uint8,omitempty" url:"uint8,omitempty"` - // An array of packed signed binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between -128 and 127. - Binary [][]int `json:"binary,omitempty" url:"binary,omitempty"` - // An array of packed unsigned binary embeddings. The length of each binary embedding is 1/8 the length of the float embeddings of the provided model. Each value is between 0 and 255. - Ubinary [][]int `json:"ubinary,omitempty" url:"ubinary,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (e *EmbedByTypeResponseEmbeddings) GetExtraProperties() map[string]interface{} { - return e.extraProperties -} - -func (e *EmbedByTypeResponseEmbeddings) UnmarshalJSON(data []byte) error { - type unmarshaler EmbedByTypeResponseEmbeddings - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *e = EmbedByTypeResponseEmbeddings(value) - - extraProperties, err := core.ExtractExtraProperties(data, *e) - if err != nil { - return err - } - e.extraProperties = extraProperties - - e._rawJSON = json.RawMessage(data) - return nil -} - -func (e *EmbedByTypeResponseEmbeddings) String() string { - if len(e._rawJSON) > 0 { - if value, err := core.StringifyJSON(e._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(e); err == nil { - return value - } - return fmt.Sprintf("%#v", e) -} - -type EmbedFloatsResponse struct { - Id string `json:"id" url:"id"` - // An array of embeddings, where each embedding is an array of floats. The length of the `embeddings` array will be the same as the length of the original `texts` array. - Embeddings [][]float64 `json:"embeddings,omitempty" url:"embeddings,omitempty"` - // The text entries for which embeddings were returned. - Texts []string `json:"texts,omitempty" url:"texts,omitempty"` - // The image entries for which embeddings were returned. - Images []*Image `json:"images,omitempty" url:"images,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (e *EmbedFloatsResponse) GetExtraProperties() map[string]interface{} { - return e.extraProperties -} - -func (e *EmbedFloatsResponse) UnmarshalJSON(data []byte) error { - type unmarshaler EmbedFloatsResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *e = EmbedFloatsResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *e) - if err != nil { - return err - } - e.extraProperties = extraProperties - - e._rawJSON = json.RawMessage(data) - return nil -} - -func (e *EmbedFloatsResponse) String() string { - if len(e._rawJSON) > 0 { - if value, err := core.StringifyJSON(e._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(e); err == nil { - return value - } - return fmt.Sprintf("%#v", e) -} - -// Specifies the type of input passed to the model. Required for embedding models v3 and higher. -// -// - `"search_document"`: Used for embeddings stored in a vector database for search use-cases. -// - `"search_query"`: Used for embeddings of search queries run against a vector DB to find relevant documents. -// - `"classification"`: Used for embeddings passed through a text classifier. -// - `"clustering"`: Used for the embeddings run through a clustering algorithm. -// - `"image"`: Used for embeddings with image input. -type EmbedInputType string - -const ( - EmbedInputTypeSearchDocument EmbedInputType = "search_document" - EmbedInputTypeSearchQuery EmbedInputType = "search_query" - EmbedInputTypeClassification EmbedInputType = "classification" - EmbedInputTypeClustering EmbedInputType = "clustering" - EmbedInputTypeImage EmbedInputType = "image" -) - -func NewEmbedInputTypeFromString(s string) (EmbedInputType, error) { - switch s { - case "search_document": - return EmbedInputTypeSearchDocument, nil - case "search_query": - return EmbedInputTypeSearchQuery, nil - case "classification": - return EmbedInputTypeClassification, nil - case "clustering": - return EmbedInputTypeClustering, nil - case "image": - return EmbedInputTypeImage, nil - } - var t EmbedInputType - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (e EmbedInputType) Ptr() *EmbedInputType { - return &e -} - -type EmbedJob struct { - // ID of the embed job - JobId string `json:"job_id" url:"job_id"` - // The name of the embed job - Name *string `json:"name,omitempty" url:"name,omitempty"` - // The status of the embed job - Status EmbedJobStatus `json:"status" url:"status"` - // The creation date of the embed job - CreatedAt time.Time `json:"created_at" url:"created_at"` - // ID of the input dataset - InputDatasetId string `json:"input_dataset_id" url:"input_dataset_id"` - // ID of the resulting output dataset - OutputDatasetId *string `json:"output_dataset_id,omitempty" url:"output_dataset_id,omitempty"` - // ID of the model used to embed - Model string `json:"model" url:"model"` - // The truncation option used - Truncate EmbedJobTruncate `json:"truncate" url:"truncate"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (e *EmbedJob) GetExtraProperties() map[string]interface{} { - return e.extraProperties -} - -func (e *EmbedJob) UnmarshalJSON(data []byte) error { - type embed EmbedJob - var unmarshaler = struct { - embed - CreatedAt *core.DateTime `json:"created_at"` - }{ - embed: embed(*e), - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - *e = EmbedJob(unmarshaler.embed) - e.CreatedAt = unmarshaler.CreatedAt.Time() - - extraProperties, err := core.ExtractExtraProperties(data, *e) - if err != nil { - return err - } - e.extraProperties = extraProperties - - e._rawJSON = json.RawMessage(data) - return nil -} - -func (e *EmbedJob) MarshalJSON() ([]byte, error) { - type embed EmbedJob - var marshaler = struct { - embed - CreatedAt *core.DateTime `json:"created_at"` - }{ - embed: embed(*e), - CreatedAt: core.NewDateTime(e.CreatedAt), - } - return json.Marshal(marshaler) -} - -func (e *EmbedJob) String() string { - if len(e._rawJSON) > 0 { - if value, err := core.StringifyJSON(e._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(e); err == nil { - return value - } - return fmt.Sprintf("%#v", e) -} - -// The status of the embed job -type EmbedJobStatus string - -const ( - EmbedJobStatusProcessing EmbedJobStatus = "processing" - EmbedJobStatusComplete EmbedJobStatus = "complete" - EmbedJobStatusCancelling EmbedJobStatus = "cancelling" - EmbedJobStatusCancelled EmbedJobStatus = "cancelled" - EmbedJobStatusFailed EmbedJobStatus = "failed" -) - -func NewEmbedJobStatusFromString(s string) (EmbedJobStatus, error) { - switch s { - case "processing": - return EmbedJobStatusProcessing, nil - case "complete": - return EmbedJobStatusComplete, nil - case "cancelling": - return EmbedJobStatusCancelling, nil - case "cancelled": - return EmbedJobStatusCancelled, nil - case "failed": - return EmbedJobStatusFailed, nil - } - var t EmbedJobStatus - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (e EmbedJobStatus) Ptr() *EmbedJobStatus { - return &e -} - -// The truncation option used -type EmbedJobTruncate string - -const ( - EmbedJobTruncateStart EmbedJobTruncate = "START" - EmbedJobTruncateEnd EmbedJobTruncate = "END" -) - -func NewEmbedJobTruncateFromString(s string) (EmbedJobTruncate, error) { - switch s { - case "START": - return EmbedJobTruncateStart, nil - case "END": - return EmbedJobTruncateEnd, nil - } - var t EmbedJobTruncate - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (e EmbedJobTruncate) Ptr() *EmbedJobTruncate { - return &e -} - -// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. -// -// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. -// -// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. -type EmbedRequestTruncate string - -const ( - EmbedRequestTruncateNone EmbedRequestTruncate = "NONE" - EmbedRequestTruncateStart EmbedRequestTruncate = "START" - EmbedRequestTruncateEnd EmbedRequestTruncate = "END" -) - -func NewEmbedRequestTruncateFromString(s string) (EmbedRequestTruncate, error) { - switch s { - case "NONE": - return EmbedRequestTruncateNone, nil - case "START": - return EmbedRequestTruncateStart, nil - case "END": - return EmbedRequestTruncateEnd, nil - } - var t EmbedRequestTruncate - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (e EmbedRequestTruncate) Ptr() *EmbedRequestTruncate { - return &e -} - -type EmbedResponse struct { - ResponseType string - EmbeddingsFloats *EmbedFloatsResponse - EmbeddingsByType *EmbedByTypeResponse -} - -func (e *EmbedResponse) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - ResponseType string `json:"response_type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - e.ResponseType = unmarshaler.ResponseType - if unmarshaler.ResponseType == "" { - return fmt.Errorf("%T did not include discriminant response_type", e) - } - switch unmarshaler.ResponseType { - case "embeddings_floats": - value := new(EmbedFloatsResponse) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - e.EmbeddingsFloats = value - case "embeddings_by_type": - value := new(EmbedByTypeResponse) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - e.EmbeddingsByType = value - } - return nil -} - -func (e EmbedResponse) MarshalJSON() ([]byte, error) { - if e.EmbeddingsFloats != nil { - return core.MarshalJSONWithExtraProperty(e.EmbeddingsFloats, "response_type", "embeddings_floats") - } - if e.EmbeddingsByType != nil { - return core.MarshalJSONWithExtraProperty(e.EmbeddingsByType, "response_type", "embeddings_by_type") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", e) -} - -type EmbedResponseVisitor interface { - VisitEmbeddingsFloats(*EmbedFloatsResponse) error - VisitEmbeddingsByType(*EmbedByTypeResponse) error -} - -func (e *EmbedResponse) Accept(visitor EmbedResponseVisitor) error { - if e.EmbeddingsFloats != nil { - return visitor.VisitEmbeddingsFloats(e.EmbeddingsFloats) - } - if e.EmbeddingsByType != nil { - return visitor.VisitEmbeddingsByType(e.EmbeddingsByType) - } - return fmt.Errorf("type %T does not define a non-empty union type", e) -} - -type EmbeddingType string - -const ( - EmbeddingTypeFloat EmbeddingType = "float" - EmbeddingTypeInt8 EmbeddingType = "int8" - EmbeddingTypeUint8 EmbeddingType = "uint8" - EmbeddingTypeBinary EmbeddingType = "binary" - EmbeddingTypeUbinary EmbeddingType = "ubinary" -) - -func NewEmbeddingTypeFromString(s string) (EmbeddingType, error) { - switch s { - case "float": - return EmbeddingTypeFloat, nil - case "int8": - return EmbeddingTypeInt8, nil - case "uint8": - return EmbeddingTypeUint8, nil - case "binary": - return EmbeddingTypeBinary, nil - case "ubinary": - return EmbeddingTypeUbinary, nil - } - var t EmbeddingType - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (e EmbeddingType) Ptr() *EmbeddingType { - return &e -} - -type FinetuneDatasetMetrics struct { - // The number of tokens of valid examples that can be used for training. - TrainableTokenCount *int64 `json:"trainable_token_count,omitempty" url:"trainable_token_count,omitempty"` - // The overall number of examples. - TotalExamples *int64 `json:"total_examples,omitempty" url:"total_examples,omitempty"` - // The number of training examples. - TrainExamples *int64 `json:"train_examples,omitempty" url:"train_examples,omitempty"` - // The size in bytes of all training examples. - TrainSizeBytes *int64 `json:"train_size_bytes,omitempty" url:"train_size_bytes,omitempty"` - // Number of evaluation examples. - EvalExamples *int64 `json:"eval_examples,omitempty" url:"eval_examples,omitempty"` - // The size in bytes of all eval examples. - EvalSizeBytes *int64 `json:"eval_size_bytes,omitempty" url:"eval_size_bytes,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (f *FinetuneDatasetMetrics) GetExtraProperties() map[string]interface{} { - return f.extraProperties -} - -func (f *FinetuneDatasetMetrics) UnmarshalJSON(data []byte) error { - type unmarshaler FinetuneDatasetMetrics - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *f = FinetuneDatasetMetrics(value) - - extraProperties, err := core.ExtractExtraProperties(data, *f) - if err != nil { - return err - } - f.extraProperties = extraProperties - - f._rawJSON = json.RawMessage(data) - return nil -} - -func (f *FinetuneDatasetMetrics) String() string { - if len(f._rawJSON) > 0 { - if value, err := core.StringifyJSON(f._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(f); err == nil { - return value - } - return fmt.Sprintf("%#v", f) -} - -type FinishReason string - -const ( - FinishReasonComplete FinishReason = "COMPLETE" - FinishReasonStopSequence FinishReason = "STOP_SEQUENCE" - FinishReasonError FinishReason = "ERROR" - FinishReasonErrorToxic FinishReason = "ERROR_TOXIC" - FinishReasonErrorLimit FinishReason = "ERROR_LIMIT" - FinishReasonUserCancel FinishReason = "USER_CANCEL" - FinishReasonMaxTokens FinishReason = "MAX_TOKENS" -) - -func NewFinishReasonFromString(s string) (FinishReason, error) { - switch s { - case "COMPLETE": - return FinishReasonComplete, nil - case "STOP_SEQUENCE": - return FinishReasonStopSequence, nil - case "ERROR": - return FinishReasonError, nil - case "ERROR_TOXIC": - return FinishReasonErrorToxic, nil - case "ERROR_LIMIT": - return FinishReasonErrorLimit, nil - case "USER_CANCEL": - return FinishReasonUserCancel, nil - case "MAX_TOKENS": - return FinishReasonMaxTokens, nil - } - var t FinishReason - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (f FinishReason) Ptr() *FinishReason { - return &f -} - -type GatewayTimeoutErrorBody struct { - Data *string `json:"data,omitempty" url:"data,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GatewayTimeoutErrorBody) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GatewayTimeoutErrorBody) UnmarshalJSON(data []byte) error { - type unmarshaler GatewayTimeoutErrorBody - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GatewayTimeoutErrorBody(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GatewayTimeoutErrorBody) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -// One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. -// -// If `GENERATION` is selected, the token likelihoods will only be provided for generated text. -// -// WARNING: `ALL` is deprecated, and will be removed in a future release. -type GenerateRequestReturnLikelihoods string - -const ( - GenerateRequestReturnLikelihoodsGeneration GenerateRequestReturnLikelihoods = "GENERATION" - GenerateRequestReturnLikelihoodsAll GenerateRequestReturnLikelihoods = "ALL" - GenerateRequestReturnLikelihoodsNone GenerateRequestReturnLikelihoods = "NONE" -) - -func NewGenerateRequestReturnLikelihoodsFromString(s string) (GenerateRequestReturnLikelihoods, error) { - switch s { - case "GENERATION": - return GenerateRequestReturnLikelihoodsGeneration, nil - case "ALL": - return GenerateRequestReturnLikelihoodsAll, nil - case "NONE": - return GenerateRequestReturnLikelihoodsNone, nil - } - var t GenerateRequestReturnLikelihoods - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (g GenerateRequestReturnLikelihoods) Ptr() *GenerateRequestReturnLikelihoods { - return &g -} - -// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. -// -// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. -// -// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. -type GenerateRequestTruncate string - -const ( - GenerateRequestTruncateNone GenerateRequestTruncate = "NONE" - GenerateRequestTruncateStart GenerateRequestTruncate = "START" - GenerateRequestTruncateEnd GenerateRequestTruncate = "END" -) - -func NewGenerateRequestTruncateFromString(s string) (GenerateRequestTruncate, error) { - switch s { - case "NONE": - return GenerateRequestTruncateNone, nil - case "START": - return GenerateRequestTruncateStart, nil - case "END": - return GenerateRequestTruncateEnd, nil - } - var t GenerateRequestTruncate - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (g GenerateRequestTruncate) Ptr() *GenerateRequestTruncate { - return &g -} - -type GenerateStreamEnd struct { - IsFinished bool `json:"is_finished" url:"is_finished"` - FinishReason *FinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` - Response *GenerateStreamEndResponse `json:"response,omitempty" url:"response,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GenerateStreamEnd) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GenerateStreamEnd) UnmarshalJSON(data []byte) error { - type unmarshaler GenerateStreamEnd - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GenerateStreamEnd(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GenerateStreamEnd) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -type GenerateStreamEndResponse struct { - Id string `json:"id" url:"id"` - Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` - Generations []*SingleGenerationInStream `json:"generations,omitempty" url:"generations,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GenerateStreamEndResponse) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GenerateStreamEndResponse) UnmarshalJSON(data []byte) error { - type unmarshaler GenerateStreamEndResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GenerateStreamEndResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GenerateStreamEndResponse) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -type GenerateStreamError struct { - // Refers to the nth generation. Only present when `num_generations` is greater than zero. - Index *int `json:"index,omitempty" url:"index,omitempty"` - IsFinished bool `json:"is_finished" url:"is_finished"` - FinishReason FinishReason `json:"finish_reason" url:"finish_reason"` - // Error message - Err string `json:"err" url:"err"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GenerateStreamError) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GenerateStreamError) UnmarshalJSON(data []byte) error { - type unmarshaler GenerateStreamError - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GenerateStreamError(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GenerateStreamError) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -type GenerateStreamEvent struct { - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GenerateStreamEvent) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GenerateStreamEvent) UnmarshalJSON(data []byte) error { - type unmarshaler GenerateStreamEvent - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GenerateStreamEvent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GenerateStreamEvent) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -// One of `GENERATION|NONE` to specify how and if the token likelihoods are returned with the response. Defaults to `NONE`. -// -// If `GENERATION` is selected, the token likelihoods will only be provided for generated text. -// -// WARNING: `ALL` is deprecated, and will be removed in a future release. -type GenerateStreamRequestReturnLikelihoods string - -const ( - GenerateStreamRequestReturnLikelihoodsGeneration GenerateStreamRequestReturnLikelihoods = "GENERATION" - GenerateStreamRequestReturnLikelihoodsAll GenerateStreamRequestReturnLikelihoods = "ALL" - GenerateStreamRequestReturnLikelihoodsNone GenerateStreamRequestReturnLikelihoods = "NONE" -) - -func NewGenerateStreamRequestReturnLikelihoodsFromString(s string) (GenerateStreamRequestReturnLikelihoods, error) { - switch s { - case "GENERATION": - return GenerateStreamRequestReturnLikelihoodsGeneration, nil - case "ALL": - return GenerateStreamRequestReturnLikelihoodsAll, nil - case "NONE": - return GenerateStreamRequestReturnLikelihoodsNone, nil - } - var t GenerateStreamRequestReturnLikelihoods - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (g GenerateStreamRequestReturnLikelihoods) Ptr() *GenerateStreamRequestReturnLikelihoods { - return &g -} - -// One of `NONE|START|END` to specify how the API will handle inputs longer than the maximum token length. -// -// Passing `START` will discard the start of the input. `END` will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. -// -// If `NONE` is selected, when the input exceeds the maximum input token length an error will be returned. -type GenerateStreamRequestTruncate string - -const ( - GenerateStreamRequestTruncateNone GenerateStreamRequestTruncate = "NONE" - GenerateStreamRequestTruncateStart GenerateStreamRequestTruncate = "START" - GenerateStreamRequestTruncateEnd GenerateStreamRequestTruncate = "END" -) - -func NewGenerateStreamRequestTruncateFromString(s string) (GenerateStreamRequestTruncate, error) { - switch s { - case "NONE": - return GenerateStreamRequestTruncateNone, nil - case "START": - return GenerateStreamRequestTruncateStart, nil - case "END": - return GenerateStreamRequestTruncateEnd, nil - } - var t GenerateStreamRequestTruncate - return "", fmt.Errorf("%s is not a valid %T", s, t) -} - -func (g GenerateStreamRequestTruncate) Ptr() *GenerateStreamRequestTruncate { - return &g -} - -type GenerateStreamText struct { - // A segment of text of the generation. - Text string `json:"text" url:"text"` - // Refers to the nth generation. Only present when `num_generations` is greater than zero, and only when text responses are being streamed. - Index *int `json:"index,omitempty" url:"index,omitempty"` - IsFinished bool `json:"is_finished" url:"is_finished"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GenerateStreamText) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GenerateStreamText) UnmarshalJSON(data []byte) error { - type unmarshaler GenerateStreamText - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GenerateStreamText(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GenerateStreamText) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -// Response in content type stream when `stream` is `true` in the request parameters. Generation tokens are streamed with the GenerationStream response. The final response is of type GenerationFinalResponse. -type GenerateStreamedResponse struct { - EventType string - TextGeneration *GenerateStreamText - StreamEnd *GenerateStreamEnd - StreamError *GenerateStreamError -} - -func (g *GenerateStreamedResponse) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - EventType string `json:"event_type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - g.EventType = unmarshaler.EventType - if unmarshaler.EventType == "" { - return fmt.Errorf("%T did not include discriminant event_type", g) - } - switch unmarshaler.EventType { - case "text-generation": - value := new(GenerateStreamText) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - g.TextGeneration = value - case "stream-end": - value := new(GenerateStreamEnd) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - g.StreamEnd = value - case "stream-error": - value := new(GenerateStreamError) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - g.StreamError = value - } - return nil -} - -func (g GenerateStreamedResponse) MarshalJSON() ([]byte, error) { - if g.TextGeneration != nil { - return core.MarshalJSONWithExtraProperty(g.TextGeneration, "event_type", "text-generation") - } - if g.StreamEnd != nil { - return core.MarshalJSONWithExtraProperty(g.StreamEnd, "event_type", "stream-end") - } - if g.StreamError != nil { - return core.MarshalJSONWithExtraProperty(g.StreamError, "event_type", "stream-error") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", g) -} - -type GenerateStreamedResponseVisitor interface { - VisitTextGeneration(*GenerateStreamText) error - VisitStreamEnd(*GenerateStreamEnd) error - VisitStreamError(*GenerateStreamError) error -} - -func (g *GenerateStreamedResponse) Accept(visitor GenerateStreamedResponseVisitor) error { - if g.TextGeneration != nil { - return visitor.VisitTextGeneration(g.TextGeneration) - } - if g.StreamEnd != nil { - return visitor.VisitStreamEnd(g.StreamEnd) - } - if g.StreamError != nil { - return visitor.VisitStreamError(g.StreamError) - } - return fmt.Errorf("type %T does not define a non-empty union type", g) -} - -type Generation struct { - Id string `json:"id" url:"id"` - // Prompt used for generations. - Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` - // List of generated results - Generations []*SingleGeneration `json:"generations,omitempty" url:"generations,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *Generation) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *Generation) UnmarshalJSON(data []byte) error { - type unmarshaler Generation - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = Generation(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *Generation) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -type GetConnectorResponse struct { - Connector *Connector `json:"connector,omitempty" url:"connector,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GetConnectorResponse) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GetConnectorResponse) UnmarshalJSON(data []byte) error { - type unmarshaler GetConnectorResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GetConnectorResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GetConnectorResponse) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -// Contains information about the model and which API endpoints it can be used with. -type GetModelResponse struct { - // Specify this name in the `model` parameter of API requests to use your chosen model. - Name *string `json:"name,omitempty" url:"name,omitempty"` - // The API endpoints that the model is compatible with. - Endpoints []CompatibleEndpoint `json:"endpoints,omitempty" url:"endpoints,omitempty"` - // Whether the model has been fine-tuned or not. - Finetuned *bool `json:"finetuned,omitempty" url:"finetuned,omitempty"` - // The maximum number of tokens that the model can process in a single request. Note that not all of these tokens are always available due to special tokens and preambles that Cohere has added by default. - ContextLength *float64 `json:"context_length,omitempty" url:"context_length,omitempty"` - // Public URL to the tokenizer's configuration file. - TokenizerUrl *string `json:"tokenizer_url,omitempty" url:"tokenizer_url,omitempty"` - // The API endpoints that the model is default to. - DefaultEndpoints []CompatibleEndpoint `json:"default_endpoints,omitempty" url:"default_endpoints,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (g *GetModelResponse) GetExtraProperties() map[string]interface{} { - return g.extraProperties -} - -func (g *GetModelResponse) UnmarshalJSON(data []byte) error { - type unmarshaler GetModelResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *g = GetModelResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *g) - if err != nil { - return err - } - g.extraProperties = extraProperties - - g._rawJSON = json.RawMessage(data) - return nil -} - -func (g *GetModelResponse) String() string { - if len(g._rawJSON) > 0 { - if value, err := core.StringifyJSON(g._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(g); err == nil { - return value - } - return fmt.Sprintf("%#v", g) -} - -type Image struct { - // Width of the image in pixels - Width int64 `json:"width" url:"width"` - // Height of the image in pixels - Height int64 `json:"height" url:"height"` - // Format of the image - Format string `json:"format" url:"format"` - // Bit depth of the image - BitDepth int64 `json:"bit_depth" url:"bit_depth"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (i *Image) GetExtraProperties() map[string]interface{} { - return i.extraProperties -} - -func (i *Image) UnmarshalJSON(data []byte) error { - type unmarshaler Image - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *i = Image(value) - - extraProperties, err := core.ExtractExtraProperties(data, *i) - if err != nil { - return err - } - i.extraProperties = extraProperties - - i._rawJSON = json.RawMessage(data) - return nil -} - -func (i *Image) String() string { - if len(i._rawJSON) > 0 { - if value, err := core.StringifyJSON(i._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(i); err == nil { - return value - } - return fmt.Sprintf("%#v", i) -} - -type JsonResponseFormat struct { - // A JSON schema object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information. - // Example (required name and age object): - // - // ```json - // - // { - // "type": "object", - // "properties": { - // "name": { "type": "string" }, - // "age": { "type": "integer" } - // }, - // "required": ["name", "age"] - // } - // - // ``` - // - // **Note**: This field must not be specified when the `type` is set to `"text"`. - Schema map[string]interface{} `json:"schema,omitempty" url:"schema,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (j *JsonResponseFormat) GetExtraProperties() map[string]interface{} { - return j.extraProperties -} - -func (j *JsonResponseFormat) UnmarshalJSON(data []byte) error { - type unmarshaler JsonResponseFormat - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *j = JsonResponseFormat(value) - - extraProperties, err := core.ExtractExtraProperties(data, *j) - if err != nil { - return err - } - j.extraProperties = extraProperties - - j._rawJSON = json.RawMessage(data) - return nil -} - -func (j *JsonResponseFormat) String() string { - if len(j._rawJSON) > 0 { - if value, err := core.StringifyJSON(j._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(j); err == nil { - return value - } - return fmt.Sprintf("%#v", j) -} - -type JsonResponseFormatV2 struct { - // A [JSON schema](https://json-schema.org/overview/what-is-jsonschema) object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information. - // Example (required name and age object): - // - // ```json - // - // { - // "type": "object", - // "properties": { - // "name": { "type": "string" }, - // "age": { "type": "integer" } - // }, - // "required": ["name", "age"] - // } - // - // ``` - // - // **Note**: This field must not be specified when the `type` is set to `"text"`. - JsonSchema map[string]interface{} `json:"json_schema,omitempty" url:"json_schema,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (j *JsonResponseFormatV2) GetExtraProperties() map[string]interface{} { - return j.extraProperties -} - -func (j *JsonResponseFormatV2) UnmarshalJSON(data []byte) error { - type unmarshaler JsonResponseFormatV2 - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *j = JsonResponseFormatV2(value) - - extraProperties, err := core.ExtractExtraProperties(data, *j) - if err != nil { - return err - } - j.extraProperties = extraProperties - - j._rawJSON = json.RawMessage(data) - return nil -} - -func (j *JsonResponseFormatV2) String() string { - if len(j._rawJSON) > 0 { - if value, err := core.StringifyJSON(j._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(j); err == nil { - return value - } - return fmt.Sprintf("%#v", j) -} - -type LabelMetric struct { - // Total number of examples for this label - TotalExamples *int64 `json:"total_examples,omitempty" url:"total_examples,omitempty"` - // value of the label - Label *string `json:"label,omitempty" url:"label,omitempty"` - // samples for this label - Samples []string `json:"samples,omitempty" url:"samples,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (l *LabelMetric) GetExtraProperties() map[string]interface{} { - return l.extraProperties -} - -func (l *LabelMetric) UnmarshalJSON(data []byte) error { - type unmarshaler LabelMetric - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *l = LabelMetric(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) - if err != nil { - return err - } - l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) - return nil -} - -func (l *LabelMetric) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(l); err == nil { - return value - } - return fmt.Sprintf("%#v", l) -} - -type ListConnectorsResponse struct { - Connectors []*Connector `json:"connectors,omitempty" url:"connectors,omitempty"` - // Total number of connectors. - TotalCount *float64 `json:"total_count,omitempty" url:"total_count,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (l *ListConnectorsResponse) GetExtraProperties() map[string]interface{} { - return l.extraProperties -} - -func (l *ListConnectorsResponse) UnmarshalJSON(data []byte) error { - type unmarshaler ListConnectorsResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *l = ListConnectorsResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) - if err != nil { - return err - } - l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) - return nil -} - -func (l *ListConnectorsResponse) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(l); err == nil { - return value - } - return fmt.Sprintf("%#v", l) -} - -type ListEmbedJobResponse struct { - EmbedJobs []*EmbedJob `json:"embed_jobs,omitempty" url:"embed_jobs,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (l *ListEmbedJobResponse) GetExtraProperties() map[string]interface{} { - return l.extraProperties -} - -func (l *ListEmbedJobResponse) UnmarshalJSON(data []byte) error { - type unmarshaler ListEmbedJobResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *l = ListEmbedJobResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) - if err != nil { - return err - } - l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) - return nil -} - -func (l *ListEmbedJobResponse) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(l); err == nil { - return value - } - return fmt.Sprintf("%#v", l) -} - -type ListModelsResponse struct { - Models []*GetModelResponse `json:"models,omitempty" url:"models,omitempty"` - // A token to retrieve the next page of results. Provide in the page_token parameter of the next request. - NextPageToken *string `json:"next_page_token,omitempty" url:"next_page_token,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (l *ListModelsResponse) GetExtraProperties() map[string]interface{} { - return l.extraProperties -} - -func (l *ListModelsResponse) UnmarshalJSON(data []byte) error { - type unmarshaler ListModelsResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *l = ListModelsResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) - if err != nil { - return err - } - l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) - return nil -} - -func (l *ListModelsResponse) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(l); err == nil { - return value - } - return fmt.Sprintf("%#v", l) -} - -type LogprobItem struct { - // The text chunk for which the log probabilities was calculated. - Text *string `json:"text,omitempty" url:"text,omitempty"` - // The token ids of each token used to construct the text chunk. - TokenIds []int `json:"token_ids,omitempty" url:"token_ids,omitempty"` - // The log probability of each token used to construct the text chunk. - Logprobs []float64 `json:"logprobs,omitempty" url:"logprobs,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (l *LogprobItem) GetExtraProperties() map[string]interface{} { - return l.extraProperties -} - -func (l *LogprobItem) UnmarshalJSON(data []byte) error { - type unmarshaler LogprobItem - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *l = LogprobItem(value) - - extraProperties, err := core.ExtractExtraProperties(data, *l) - if err != nil { - return err - } - l.extraProperties = extraProperties - - l._rawJSON = json.RawMessage(data) - return nil -} - -func (l *LogprobItem) String() string { - if len(l._rawJSON) > 0 { - if value, err := core.StringifyJSON(l._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(l); err == nil { - return value - } - return fmt.Sprintf("%#v", l) -} - -type Message struct { - Role string - Chatbot *ChatMessage - System *ChatMessage - User *ChatMessage - Tool *ToolMessage -} - -func (m *Message) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Role string `json:"role"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - m.Role = unmarshaler.Role - if unmarshaler.Role == "" { - return fmt.Errorf("%T did not include discriminant role", m) - } - switch unmarshaler.Role { - case "CHATBOT": - value := new(ChatMessage) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - m.Chatbot = value - case "SYSTEM": - value := new(ChatMessage) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - m.System = value - case "USER": - value := new(ChatMessage) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - m.User = value - case "TOOL": - value := new(ToolMessage) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - m.Tool = value - } - return nil -} - -func (m Message) MarshalJSON() ([]byte, error) { - if m.Chatbot != nil { - return core.MarshalJSONWithExtraProperty(m.Chatbot, "role", "CHATBOT") - } - if m.System != nil { - return core.MarshalJSONWithExtraProperty(m.System, "role", "SYSTEM") - } - if m.User != nil { - return core.MarshalJSONWithExtraProperty(m.User, "role", "USER") - } - if m.Tool != nil { - return core.MarshalJSONWithExtraProperty(m.Tool, "role", "TOOL") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", m) -} - -type MessageVisitor interface { - VisitChatbot(*ChatMessage) error - VisitSystem(*ChatMessage) error - VisitUser(*ChatMessage) error - VisitTool(*ToolMessage) error -} - -func (m *Message) Accept(visitor MessageVisitor) error { - if m.Chatbot != nil { - return visitor.VisitChatbot(m.Chatbot) - } - if m.System != nil { - return visitor.VisitSystem(m.System) - } - if m.User != nil { - return visitor.VisitUser(m.User) - } - if m.Tool != nil { - return visitor.VisitTool(m.Tool) - } - return fmt.Errorf("type %T does not define a non-empty union type", m) -} - -type Metrics struct { - FinetuneDatasetMetrics *FinetuneDatasetMetrics `json:"finetune_dataset_metrics,omitempty" url:"finetune_dataset_metrics,omitempty"` - EmbedData *MetricsEmbedData `json:"embed_data,omitempty" url:"embed_data,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (m *Metrics) GetExtraProperties() map[string]interface{} { - return m.extraProperties -} - -func (m *Metrics) UnmarshalJSON(data []byte) error { - type unmarshaler Metrics - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *m = Metrics(value) - - extraProperties, err := core.ExtractExtraProperties(data, *m) - if err != nil { - return err - } - m.extraProperties = extraProperties - - m._rawJSON = json.RawMessage(data) - return nil -} - -func (m *Metrics) String() string { - if len(m._rawJSON) > 0 { - if value, err := core.StringifyJSON(m._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(m); err == nil { - return value - } - return fmt.Sprintf("%#v", m) -} - -type MetricsEmbedData struct { - // the fields in the dataset - Fields []*MetricsEmbedDataFieldsItem `json:"fields,omitempty" url:"fields,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (m *MetricsEmbedData) GetExtraProperties() map[string]interface{} { - return m.extraProperties -} - -func (m *MetricsEmbedData) UnmarshalJSON(data []byte) error { - type unmarshaler MetricsEmbedData - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *m = MetricsEmbedData(value) - - extraProperties, err := core.ExtractExtraProperties(data, *m) - if err != nil { - return err - } - m.extraProperties = extraProperties - - m._rawJSON = json.RawMessage(data) - return nil -} - -func (m *MetricsEmbedData) String() string { - if len(m._rawJSON) > 0 { - if value, err := core.StringifyJSON(m._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(m); err == nil { - return value - } - return fmt.Sprintf("%#v", m) -} - -type MetricsEmbedDataFieldsItem struct { - // the name of the field - Name *string `json:"name,omitempty" url:"name,omitempty"` - // the number of times the field appears in the dataset - Count *float64 `json:"count,omitempty" url:"count,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (m *MetricsEmbedDataFieldsItem) GetExtraProperties() map[string]interface{} { - return m.extraProperties -} - -func (m *MetricsEmbedDataFieldsItem) UnmarshalJSON(data []byte) error { - type unmarshaler MetricsEmbedDataFieldsItem - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *m = MetricsEmbedDataFieldsItem(value) - - extraProperties, err := core.ExtractExtraProperties(data, *m) - if err != nil { - return err - } - m.extraProperties = extraProperties - - m._rawJSON = json.RawMessage(data) - return nil -} - -func (m *MetricsEmbedDataFieldsItem) String() string { - if len(m._rawJSON) > 0 { - if value, err := core.StringifyJSON(m._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(m); err == nil { - return value - } - return fmt.Sprintf("%#v", m) -} - -type NonStreamedChatResponse struct { - // Contents of the reply generated by the model. - Text string `json:"text" url:"text"` - // Unique identifier for the generated reply. Useful for submitting feedback. - GenerationId *string `json:"generation_id,omitempty" url:"generation_id,omitempty"` - // Unique identifier for the response. - ResponseId *string `json:"response_id,omitempty" url:"response_id,omitempty"` - // Inline citations for the generated reply. - Citations []*ChatCitation `json:"citations,omitempty" url:"citations,omitempty"` - // Documents seen by the model when generating the reply. - Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` - // Denotes that a search for documents is required during the RAG flow. - IsSearchRequired *bool `json:"is_search_required,omitempty" url:"is_search_required,omitempty"` - // Generated search queries, meant to be used as part of the RAG flow. - SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty" url:"search_queries,omitempty"` - // Documents retrieved from each of the conducted searches. - SearchResults []*ChatSearchResult `json:"search_results,omitempty" url:"search_results,omitempty"` - FinishReason *FinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` - ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` - // A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`. - ChatHistory []*Message `json:"chat_history,omitempty" url:"chat_history,omitempty"` - // The prompt that was used. Only present when `return_prompt` in the request is set to true. - Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (n *NonStreamedChatResponse) GetExtraProperties() map[string]interface{} { - return n.extraProperties -} - -func (n *NonStreamedChatResponse) UnmarshalJSON(data []byte) error { - type unmarshaler NonStreamedChatResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *n = NonStreamedChatResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *n) - if err != nil { - return err - } - n.extraProperties = extraProperties - - n._rawJSON = json.RawMessage(data) - return nil -} - -func (n *NonStreamedChatResponse) String() string { - if len(n._rawJSON) > 0 { - if value, err := core.StringifyJSON(n._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(n); err == nil { - return value - } - return fmt.Sprintf("%#v", n) -} - -type NotImplementedErrorBody struct { - Data *string `json:"data,omitempty" url:"data,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (n *NotImplementedErrorBody) GetExtraProperties() map[string]interface{} { - return n.extraProperties -} - -func (n *NotImplementedErrorBody) UnmarshalJSON(data []byte) error { - type unmarshaler NotImplementedErrorBody - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *n = NotImplementedErrorBody(value) - - extraProperties, err := core.ExtractExtraProperties(data, *n) - if err != nil { - return err - } - n.extraProperties = extraProperties - - n._rawJSON = json.RawMessage(data) - return nil -} - -func (n *NotImplementedErrorBody) String() string { - if len(n._rawJSON) > 0 { - if value, err := core.StringifyJSON(n._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(n); err == nil { - return value - } - return fmt.Sprintf("%#v", n) -} - -type OAuthAuthorizeResponse struct { - // The OAuth 2.0 redirect url. Redirect the user to this url to authorize the connector. - RedirectUrl *string `json:"redirect_url,omitempty" url:"redirect_url,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (o *OAuthAuthorizeResponse) GetExtraProperties() map[string]interface{} { - return o.extraProperties -} - -func (o *OAuthAuthorizeResponse) UnmarshalJSON(data []byte) error { - type unmarshaler OAuthAuthorizeResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *o = OAuthAuthorizeResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *o) - if err != nil { - return err - } - o.extraProperties = extraProperties - - o._rawJSON = json.RawMessage(data) - return nil -} - -func (o *OAuthAuthorizeResponse) String() string { - if len(o._rawJSON) > 0 { - if value, err := core.StringifyJSON(o._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(o); err == nil { - return value - } - return fmt.Sprintf("%#v", o) -} - -type ParseInfo struct { - Separator *string `json:"separator,omitempty" url:"separator,omitempty"` - Delimiter *string `json:"delimiter,omitempty" url:"delimiter,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (p *ParseInfo) GetExtraProperties() map[string]interface{} { - return p.extraProperties -} - -func (p *ParseInfo) UnmarshalJSON(data []byte) error { - type unmarshaler ParseInfo - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *p = ParseInfo(value) - - extraProperties, err := core.ExtractExtraProperties(data, *p) - if err != nil { - return err - } - p.extraProperties = extraProperties - - p._rawJSON = json.RawMessage(data) - return nil -} - -func (p *ParseInfo) String() string { - if len(p._rawJSON) > 0 { - if value, err := core.StringifyJSON(p._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(p); err == nil { - return value - } - return fmt.Sprintf("%#v", p) -} - -type RerankDocument = map[string]string - -type RerankRequestDocumentsItem struct { - String string - RerankDocument RerankDocument -} - -func (r *RerankRequestDocumentsItem) UnmarshalJSON(data []byte) error { - var valueString string - if err := json.Unmarshal(data, &valueString); err == nil { - r.String = valueString - return nil - } - var valueRerankDocument RerankDocument - if err := json.Unmarshal(data, &valueRerankDocument); err == nil { - r.RerankDocument = valueRerankDocument - return nil - } - return fmt.Errorf("%s cannot be deserialized as a %T", data, r) -} - -func (r RerankRequestDocumentsItem) MarshalJSON() ([]byte, error) { - if r.String != "" { - return json.Marshal(r.String) - } - if r.RerankDocument != nil { - return json.Marshal(r.RerankDocument) - } - return nil, fmt.Errorf("type %T does not include a non-empty union type", r) -} - -type RerankRequestDocumentsItemVisitor interface { - VisitString(string) error - VisitRerankDocument(RerankDocument) error -} - -func (r *RerankRequestDocumentsItem) Accept(visitor RerankRequestDocumentsItemVisitor) error { - if r.String != "" { - return visitor.VisitString(r.String) - } - if r.RerankDocument != nil { - return visitor.VisitRerankDocument(r.RerankDocument) - } - return fmt.Errorf("type %T does not include a non-empty union type", r) -} - -type RerankResponse struct { - Id *string `json:"id,omitempty" url:"id,omitempty"` - // An ordered list of ranked documents - Results []*RerankResponseResultsItem `json:"results,omitempty" url:"results,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (r *RerankResponse) GetExtraProperties() map[string]interface{} { - return r.extraProperties -} - -func (r *RerankResponse) UnmarshalJSON(data []byte) error { - type unmarshaler RerankResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *r = RerankResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *r) - if err != nil { - return err - } - r.extraProperties = extraProperties - - r._rawJSON = json.RawMessage(data) - return nil -} - -func (r *RerankResponse) String() string { - if len(r._rawJSON) > 0 { - if value, err := core.StringifyJSON(r._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(r); err == nil { - return value - } - return fmt.Sprintf("%#v", r) -} - -type RerankResponseResultsItem struct { - // If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in - Document *RerankResponseResultsItemDocument `json:"document,omitempty" url:"document,omitempty"` - // Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance) - Index int `json:"index" url:"index"` - // Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45 - RelevanceScore float64 `json:"relevance_score" url:"relevance_score"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (r *RerankResponseResultsItem) GetExtraProperties() map[string]interface{} { - return r.extraProperties -} - -func (r *RerankResponseResultsItem) UnmarshalJSON(data []byte) error { - type unmarshaler RerankResponseResultsItem - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *r = RerankResponseResultsItem(value) - - extraProperties, err := core.ExtractExtraProperties(data, *r) - if err != nil { - return err - } - r.extraProperties = extraProperties - - r._rawJSON = json.RawMessage(data) - return nil -} - -func (r *RerankResponseResultsItem) String() string { - if len(r._rawJSON) > 0 { - if value, err := core.StringifyJSON(r._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(r); err == nil { - return value - } - return fmt.Sprintf("%#v", r) -} - -// If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in -type RerankResponseResultsItemDocument struct { - // The text of the document to rerank - Text string `json:"text" url:"text"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (r *RerankResponseResultsItemDocument) GetExtraProperties() map[string]interface{} { - return r.extraProperties -} - -func (r *RerankResponseResultsItemDocument) UnmarshalJSON(data []byte) error { - type unmarshaler RerankResponseResultsItemDocument - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *r = RerankResponseResultsItemDocument(value) - - extraProperties, err := core.ExtractExtraProperties(data, *r) - if err != nil { - return err - } - r.extraProperties = extraProperties - - r._rawJSON = json.RawMessage(data) - return nil -} - -func (r *RerankResponseResultsItemDocument) String() string { - if len(r._rawJSON) > 0 { - if value, err := core.StringifyJSON(r._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(r); err == nil { - return value - } - return fmt.Sprintf("%#v", r) -} - -type RerankerDataMetrics struct { - // The number of training queries. - NumTrainQueries *int64 `json:"num_train_queries,omitempty" url:"num_train_queries,omitempty"` - // The sum of all relevant passages of valid training examples. - NumTrainRelevantPassages *int64 `json:"num_train_relevant_passages,omitempty" url:"num_train_relevant_passages,omitempty"` - // The sum of all hard negatives of valid training examples. - NumTrainHardNegatives *int64 `json:"num_train_hard_negatives,omitempty" url:"num_train_hard_negatives,omitempty"` - // The number of evaluation queries. - NumEvalQueries *int64 `json:"num_eval_queries,omitempty" url:"num_eval_queries,omitempty"` - // The sum of all relevant passages of valid eval examples. - NumEvalRelevantPassages *int64 `json:"num_eval_relevant_passages,omitempty" url:"num_eval_relevant_passages,omitempty"` - // The sum of all hard negatives of valid eval examples. - NumEvalHardNegatives *int64 `json:"num_eval_hard_negatives,omitempty" url:"num_eval_hard_negatives,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (r *RerankerDataMetrics) GetExtraProperties() map[string]interface{} { - return r.extraProperties -} - -func (r *RerankerDataMetrics) UnmarshalJSON(data []byte) error { - type unmarshaler RerankerDataMetrics - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *r = RerankerDataMetrics(value) - - extraProperties, err := core.ExtractExtraProperties(data, *r) - if err != nil { - return err - } - r.extraProperties = extraProperties - - r._rawJSON = json.RawMessage(data) - return nil -} - -func (r *RerankerDataMetrics) String() string { - if len(r._rawJSON) > 0 { - if value, err := core.StringifyJSON(r._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(r); err == nil { - return value - } - return fmt.Sprintf("%#v", r) -} - -// Configuration for forcing the model output to adhere to the specified format. Supported on [Command R 03-2024](https://docs.cohere.com/docs/command-r), [Command R+ 04-2024](https://docs.cohere.com/docs/command-r-plus) and newer models. -// -// The model can be forced into outputting JSON objects (with up to 5 levels of nesting) by setting `{ "type": "json_object" }`. -// -// A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure. -// -// **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length. -// **Limitation**: The parameter is not supported in RAG mode (when any of `connectors`, `documents`, `tools`, `tool_results` are provided). -type ResponseFormat struct { - Type string - Text *TextResponseFormat - JsonObject *JsonResponseFormat -} - -func (r *ResponseFormat) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - r.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", r) - } - switch unmarshaler.Type { - case "text": - value := new(TextResponseFormat) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - r.Text = value - case "json_object": - value := new(JsonResponseFormat) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - r.JsonObject = value - } - return nil -} - -func (r ResponseFormat) MarshalJSON() ([]byte, error) { - if r.Text != nil { - return core.MarshalJSONWithExtraProperty(r.Text, "type", "text") - } - if r.JsonObject != nil { - return core.MarshalJSONWithExtraProperty(r.JsonObject, "type", "json_object") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", r) -} - -type ResponseFormatVisitor interface { - VisitText(*TextResponseFormat) error - VisitJsonObject(*JsonResponseFormat) error -} - -func (r *ResponseFormat) Accept(visitor ResponseFormatVisitor) error { - if r.Text != nil { - return visitor.VisitText(r.Text) - } - if r.JsonObject != nil { - return visitor.VisitJsonObject(r.JsonObject) - } - return fmt.Errorf("type %T does not define a non-empty union type", r) -} - -// Configuration for forcing the model output to adhere to the specified format. Supported on [Command R](https://docs.cohere.com/v2/docs/command-r), [Command R+](https://docs.cohere.com/v2/docs/command-r-plus) and newer models. -// -// The model can be forced into outputting JSON objects by setting `{ "type": "json_object" }`. -// -// A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure. -// -// **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length. -// -// **Note**: When `json_schema` is not specified, the generated object can have up to 5 layers of nesting. -// -// **Limitation**: The parameter is not supported when used in combinations with the `documents` or `tools` parameters. -type ResponseFormatV2 struct { - Type string - Text *TextResponseFormatV2 - JsonObject *JsonResponseFormatV2 -} - -func (r *ResponseFormatV2) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - r.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", r) - } - switch unmarshaler.Type { - case "text": - value := new(TextResponseFormatV2) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - r.Text = value - case "json_object": - value := new(JsonResponseFormatV2) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - r.JsonObject = value - } - return nil -} - -func (r ResponseFormatV2) MarshalJSON() ([]byte, error) { - if r.Text != nil { - return core.MarshalJSONWithExtraProperty(r.Text, "type", "text") - } - if r.JsonObject != nil { - return core.MarshalJSONWithExtraProperty(r.JsonObject, "type", "json_object") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", r) -} - -type ResponseFormatV2Visitor interface { - VisitText(*TextResponseFormatV2) error - VisitJsonObject(*JsonResponseFormatV2) error -} - -func (r *ResponseFormatV2) Accept(visitor ResponseFormatV2Visitor) error { - if r.Text != nil { - return visitor.VisitText(r.Text) - } - if r.JsonObject != nil { - return visitor.VisitJsonObject(r.JsonObject) - } - return fmt.Errorf("type %T does not define a non-empty union type", r) -} - -type SingleGeneration struct { - Id string `json:"id" url:"id"` - Text string `json:"text" url:"text"` - // Refers to the nth generation. Only present when `num_generations` is greater than zero. - Index *int `json:"index,omitempty" url:"index,omitempty"` - Likelihood *float64 `json:"likelihood,omitempty" url:"likelihood,omitempty"` - // Only returned if `return_likelihoods` is set to `GENERATION` or `ALL`. The likelihood refers to the average log-likelihood of the entire specified string, which is useful for [evaluating the performance of your model](likelihood-eval), especially if you've created a [custom model](https://docs.cohere.com/docs/training-custom-models). Individual token likelihoods provide the log-likelihood of each token. The first token will not have a likelihood. - TokenLikelihoods []*SingleGenerationTokenLikelihoodsItem `json:"token_likelihoods,omitempty" url:"token_likelihoods,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (s *SingleGeneration) GetExtraProperties() map[string]interface{} { - return s.extraProperties -} - -func (s *SingleGeneration) UnmarshalJSON(data []byte) error { - type unmarshaler SingleGeneration - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *s = SingleGeneration(value) - - extraProperties, err := core.ExtractExtraProperties(data, *s) - if err != nil { - return err - } - s.extraProperties = extraProperties - - s._rawJSON = json.RawMessage(data) - return nil -} - -func (s *SingleGeneration) String() string { - if len(s._rawJSON) > 0 { - if value, err := core.StringifyJSON(s._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(s); err == nil { - return value - } - return fmt.Sprintf("%#v", s) -} - -type SingleGenerationInStream struct { - Id string `json:"id" url:"id"` - // Full text of the generation. - Text string `json:"text" url:"text"` - // Refers to the nth generation. Only present when `num_generations` is greater than zero. - Index *int `json:"index,omitempty" url:"index,omitempty"` - FinishReason FinishReason `json:"finish_reason" url:"finish_reason"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (s *SingleGenerationInStream) GetExtraProperties() map[string]interface{} { - return s.extraProperties -} - -func (s *SingleGenerationInStream) UnmarshalJSON(data []byte) error { - type unmarshaler SingleGenerationInStream - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *s = SingleGenerationInStream(value) - - extraProperties, err := core.ExtractExtraProperties(data, *s) - if err != nil { - return err - } - s.extraProperties = extraProperties +type JsonResponseFormat struct { + // A JSON schema object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information. + // Example (required name and age object): + // ```json + // + // { + // "type": "object", + // "properties": { + // "name": {"type": "string"}, + // "age": {"type": "integer"} + // }, + // "required": ["name", "age"] + // } + // + // ``` + // + // **Note**: This field must not be specified when the `type` is set to `"text"`. + Schema map[string]interface{} `json:"schema,omitempty" url:"schema,omitempty"` - s._rawJSON = json.RawMessage(data) - return nil + extraProperties map[string]interface{} + rawJSON json.RawMessage } -func (s *SingleGenerationInStream) String() string { - if len(s._rawJSON) > 0 { - if value, err := core.StringifyJSON(s._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(s); err == nil { - return value +func (j *JsonResponseFormat) GetSchema() map[string]interface{} { + if j == nil { + return nil } - return fmt.Sprintf("%#v", s) -} - -type SingleGenerationTokenLikelihoodsItem struct { - Token string `json:"token" url:"token"` - Likelihood float64 `json:"likelihood" url:"likelihood"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + return j.Schema } -func (s *SingleGenerationTokenLikelihoodsItem) GetExtraProperties() map[string]interface{} { - return s.extraProperties +func (j *JsonResponseFormat) GetExtraProperties() map[string]interface{} { + return j.extraProperties } -func (s *SingleGenerationTokenLikelihoodsItem) UnmarshalJSON(data []byte) error { - type unmarshaler SingleGenerationTokenLikelihoodsItem +func (j *JsonResponseFormat) UnmarshalJSON(data []byte) error { + type unmarshaler JsonResponseFormat var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *s = SingleGenerationTokenLikelihoodsItem(value) - - extraProperties, err := core.ExtractExtraProperties(data, *s) + *j = JsonResponseFormat(value) + extraProperties, err := internal.ExtractExtraProperties(data, *j) if err != nil { return err } - s.extraProperties = extraProperties - - s._rawJSON = json.RawMessage(data) + j.extraProperties = extraProperties + j.rawJSON = json.RawMessage(data) return nil } -func (s *SingleGenerationTokenLikelihoodsItem) String() string { - if len(s._rawJSON) > 0 { - if value, err := core.StringifyJSON(s._rawJSON); err == nil { +func (j *JsonResponseFormat) String() string { + if len(j.rawJSON) > 0 { + if value, err := internal.StringifyJSON(j.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(s); err == nil { + if value, err := internal.StringifyJSON(j); err == nil { return value } - return fmt.Sprintf("%#v", s) -} - -// A source object containing information about the source of the data cited. -type Source struct { - Type string - Tool *ToolSource - Document *DocumentSource + return fmt.Sprintf("%#v", j) } -func (s *Source) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - s.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", s) - } - switch unmarshaler.Type { - case "tool": - value := new(ToolSource) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.Tool = value - case "document": - value := new(DocumentSource) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.Document = value - } - return nil +type Message struct { + Role string + Chatbot *ChatMessage + System *ChatMessage + User *ChatMessage + Tool *ToolMessage } -func (s Source) MarshalJSON() ([]byte, error) { - if s.Tool != nil { - return core.MarshalJSONWithExtraProperty(s.Tool, "type", "tool") +func (m *Message) GetRole() string { + if m == nil { + return "" } - if s.Document != nil { - return core.MarshalJSONWithExtraProperty(s.Document, "type", "document") - } - return nil, fmt.Errorf("type %T does not define a non-empty union type", s) -} - -type SourceVisitor interface { - VisitTool(*ToolSource) error - VisitDocument(*DocumentSource) error + return m.Role } -func (s *Source) Accept(visitor SourceVisitor) error { - if s.Tool != nil { - return visitor.VisitTool(s.Tool) - } - if s.Document != nil { - return visitor.VisitDocument(s.Document) +func (m *Message) GetChatbot() *ChatMessage { + if m == nil { + return nil } - return fmt.Errorf("type %T does not define a non-empty union type", s) -} - -// StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request). -type StreamedChatResponse struct { - EventType string - StreamStart *ChatStreamStartEvent - SearchQueriesGeneration *ChatSearchQueriesGenerationEvent - SearchResults *ChatSearchResultsEvent - TextGeneration *ChatTextGenerationEvent - CitationGeneration *ChatCitationGenerationEvent - ToolCallsGeneration *ChatToolCallsGenerationEvent - StreamEnd *ChatStreamEndEvent - ToolCallsChunk *ChatToolCallsChunkEvent - Debug *ChatDebugEvent + return m.Chatbot } -func (s *StreamedChatResponse) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - EventType string `json:"event_type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - s.EventType = unmarshaler.EventType - if unmarshaler.EventType == "" { - return fmt.Errorf("%T did not include discriminant event_type", s) - } - switch unmarshaler.EventType { - case "stream-start": - value := new(ChatStreamStartEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.StreamStart = value - case "search-queries-generation": - value := new(ChatSearchQueriesGenerationEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.SearchQueriesGeneration = value - case "search-results": - value := new(ChatSearchResultsEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.SearchResults = value - case "text-generation": - value := new(ChatTextGenerationEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.TextGeneration = value - case "citation-generation": - value := new(ChatCitationGenerationEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.CitationGeneration = value - case "tool-calls-generation": - value := new(ChatToolCallsGenerationEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ToolCallsGeneration = value - case "stream-end": - value := new(ChatStreamEndEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.StreamEnd = value - case "tool-calls-chunk": - value := new(ChatToolCallsChunkEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ToolCallsChunk = value - case "debug": - value := new(ChatDebugEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.Debug = value +func (m *Message) GetSystem() *ChatMessage { + if m == nil { + return nil } - return nil + return m.System } -func (s StreamedChatResponse) MarshalJSON() ([]byte, error) { - if s.StreamStart != nil { - return core.MarshalJSONWithExtraProperty(s.StreamStart, "event_type", "stream-start") - } - if s.SearchQueriesGeneration != nil { - return core.MarshalJSONWithExtraProperty(s.SearchQueriesGeneration, "event_type", "search-queries-generation") - } - if s.SearchResults != nil { - return core.MarshalJSONWithExtraProperty(s.SearchResults, "event_type", "search-results") - } - if s.TextGeneration != nil { - return core.MarshalJSONWithExtraProperty(s.TextGeneration, "event_type", "text-generation") - } - if s.CitationGeneration != nil { - return core.MarshalJSONWithExtraProperty(s.CitationGeneration, "event_type", "citation-generation") - } - if s.ToolCallsGeneration != nil { - return core.MarshalJSONWithExtraProperty(s.ToolCallsGeneration, "event_type", "tool-calls-generation") - } - if s.StreamEnd != nil { - return core.MarshalJSONWithExtraProperty(s.StreamEnd, "event_type", "stream-end") - } - if s.ToolCallsChunk != nil { - return core.MarshalJSONWithExtraProperty(s.ToolCallsChunk, "event_type", "tool-calls-chunk") - } - if s.Debug != nil { - return core.MarshalJSONWithExtraProperty(s.Debug, "event_type", "debug") +func (m *Message) GetUser() *ChatMessage { + if m == nil { + return nil } - return nil, fmt.Errorf("type %T does not define a non-empty union type", s) -} - -type StreamedChatResponseVisitor interface { - VisitStreamStart(*ChatStreamStartEvent) error - VisitSearchQueriesGeneration(*ChatSearchQueriesGenerationEvent) error - VisitSearchResults(*ChatSearchResultsEvent) error - VisitTextGeneration(*ChatTextGenerationEvent) error - VisitCitationGeneration(*ChatCitationGenerationEvent) error - VisitToolCallsGeneration(*ChatToolCallsGenerationEvent) error - VisitStreamEnd(*ChatStreamEndEvent) error - VisitToolCallsChunk(*ChatToolCallsChunkEvent) error - VisitDebug(*ChatDebugEvent) error + return m.User } -func (s *StreamedChatResponse) Accept(visitor StreamedChatResponseVisitor) error { - if s.StreamStart != nil { - return visitor.VisitStreamStart(s.StreamStart) - } - if s.SearchQueriesGeneration != nil { - return visitor.VisitSearchQueriesGeneration(s.SearchQueriesGeneration) - } - if s.SearchResults != nil { - return visitor.VisitSearchResults(s.SearchResults) - } - if s.TextGeneration != nil { - return visitor.VisitTextGeneration(s.TextGeneration) - } - if s.CitationGeneration != nil { - return visitor.VisitCitationGeneration(s.CitationGeneration) - } - if s.ToolCallsGeneration != nil { - return visitor.VisitToolCallsGeneration(s.ToolCallsGeneration) - } - if s.StreamEnd != nil { - return visitor.VisitStreamEnd(s.StreamEnd) - } - if s.ToolCallsChunk != nil { - return visitor.VisitToolCallsChunk(s.ToolCallsChunk) - } - if s.Debug != nil { - return visitor.VisitDebug(s.Debug) +func (m *Message) GetTool() *ToolMessage { + if m == nil { + return nil } - return fmt.Errorf("type %T does not define a non-empty union type", s) + return m.Tool } -// StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request). -type StreamedChatResponseV2 struct { - Type string - MessageStart *ChatMessageStartEvent - ContentStart *ChatContentStartEvent - ContentDelta *ChatContentDeltaEvent - ContentEnd *ChatContentEndEvent - ToolPlanDelta *ChatToolPlanDeltaEvent - ToolCallStart *ChatToolCallStartEvent - ToolCallDelta *ChatToolCallDeltaEvent - ToolCallEnd *ChatToolCallEndEvent - CitationStart *CitationStartEvent - CitationEnd *CitationEndEvent - MessageEnd *ChatMessageEndEvent - Debug *ChatDebugEvent -} - -func (s *StreamedChatResponseV2) UnmarshalJSON(data []byte) error { +func (m *Message) UnmarshalJSON(data []byte) error { var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err - } - s.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", s) + Role string `json:"role"` } - switch unmarshaler.Type { - case "message-start": - value := new(ChatMessageStartEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.MessageStart = value - case "content-start": - value := new(ChatContentStartEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ContentStart = value - case "content-delta": - value := new(ChatContentDeltaEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ContentDelta = value - case "content-end": - value := new(ChatContentEndEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ContentEnd = value - case "tool-plan-delta": - value := new(ChatToolPlanDeltaEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ToolPlanDelta = value - case "tool-call-start": - value := new(ChatToolCallStartEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ToolCallStart = value - case "tool-call-delta": - value := new(ChatToolCallDeltaEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ToolCallDelta = value - case "tool-call-end": - value := new(ChatToolCallEndEvent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.ToolCallEnd = value - case "citation-start": - value := new(CitationStartEvent) + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + m.Role = unmarshaler.Role + if unmarshaler.Role == "" { + return fmt.Errorf("%T did not include discriminant role", m) + } + switch unmarshaler.Role { + case "CHATBOT": + value := new(ChatMessage) if err := json.Unmarshal(data, &value); err != nil { return err } - s.CitationStart = value - case "citation-end": - value := new(CitationEndEvent) + m.Chatbot = value + case "SYSTEM": + value := new(ChatMessage) if err := json.Unmarshal(data, &value); err != nil { return err } - s.CitationEnd = value - case "message-end": - value := new(ChatMessageEndEvent) + m.System = value + case "USER": + value := new(ChatMessage) if err := json.Unmarshal(data, &value); err != nil { return err } - s.MessageEnd = value - case "debug": - value := new(ChatDebugEvent) + m.User = value + case "TOOL": + value := new(ToolMessage) if err := json.Unmarshal(data, &value); err != nil { return err } - s.Debug = value + m.Tool = value } return nil } -func (s StreamedChatResponseV2) MarshalJSON() ([]byte, error) { - if s.MessageStart != nil { - return core.MarshalJSONWithExtraProperty(s.MessageStart, "type", "message-start") - } - if s.ContentStart != nil { - return core.MarshalJSONWithExtraProperty(s.ContentStart, "type", "content-start") - } - if s.ContentDelta != nil { - return core.MarshalJSONWithExtraProperty(s.ContentDelta, "type", "content-delta") - } - if s.ContentEnd != nil { - return core.MarshalJSONWithExtraProperty(s.ContentEnd, "type", "content-end") - } - if s.ToolPlanDelta != nil { - return core.MarshalJSONWithExtraProperty(s.ToolPlanDelta, "type", "tool-plan-delta") - } - if s.ToolCallStart != nil { - return core.MarshalJSONWithExtraProperty(s.ToolCallStart, "type", "tool-call-start") - } - if s.ToolCallDelta != nil { - return core.MarshalJSONWithExtraProperty(s.ToolCallDelta, "type", "tool-call-delta") - } - if s.ToolCallEnd != nil { - return core.MarshalJSONWithExtraProperty(s.ToolCallEnd, "type", "tool-call-end") +func (m Message) MarshalJSON() ([]byte, error) { + if err := m.validate(); err != nil { + return nil, err } - if s.CitationStart != nil { - return core.MarshalJSONWithExtraProperty(s.CitationStart, "type", "citation-start") + if m.Chatbot != nil { + return internal.MarshalJSONWithExtraProperty(m.Chatbot, "role", "CHATBOT") } - if s.CitationEnd != nil { - return core.MarshalJSONWithExtraProperty(s.CitationEnd, "type", "citation-end") + if m.System != nil { + return internal.MarshalJSONWithExtraProperty(m.System, "role", "SYSTEM") } - if s.MessageEnd != nil { - return core.MarshalJSONWithExtraProperty(s.MessageEnd, "type", "message-end") + if m.User != nil { + return internal.MarshalJSONWithExtraProperty(m.User, "role", "USER") } - if s.Debug != nil { - return core.MarshalJSONWithExtraProperty(s.Debug, "type", "debug") + if m.Tool != nil { + return internal.MarshalJSONWithExtraProperty(m.Tool, "role", "TOOL") } - return nil, fmt.Errorf("type %T does not define a non-empty union type", s) + return nil, fmt.Errorf("type %T does not define a non-empty union type", m) } -type StreamedChatResponseV2Visitor interface { - VisitMessageStart(*ChatMessageStartEvent) error - VisitContentStart(*ChatContentStartEvent) error - VisitContentDelta(*ChatContentDeltaEvent) error - VisitContentEnd(*ChatContentEndEvent) error - VisitToolPlanDelta(*ChatToolPlanDeltaEvent) error - VisitToolCallStart(*ChatToolCallStartEvent) error - VisitToolCallDelta(*ChatToolCallDeltaEvent) error - VisitToolCallEnd(*ChatToolCallEndEvent) error - VisitCitationStart(*CitationStartEvent) error - VisitCitationEnd(*CitationEndEvent) error - VisitMessageEnd(*ChatMessageEndEvent) error - VisitDebug(*ChatDebugEvent) error +type MessageVisitor interface { + VisitChatbot(*ChatMessage) error + VisitSystem(*ChatMessage) error + VisitUser(*ChatMessage) error + VisitTool(*ToolMessage) error } -func (s *StreamedChatResponseV2) Accept(visitor StreamedChatResponseV2Visitor) error { - if s.MessageStart != nil { - return visitor.VisitMessageStart(s.MessageStart) +func (m *Message) Accept(visitor MessageVisitor) error { + if m.Chatbot != nil { + return visitor.VisitChatbot(m.Chatbot) } - if s.ContentStart != nil { - return visitor.VisitContentStart(s.ContentStart) + if m.System != nil { + return visitor.VisitSystem(m.System) } - if s.ContentDelta != nil { - return visitor.VisitContentDelta(s.ContentDelta) + if m.User != nil { + return visitor.VisitUser(m.User) } - if s.ContentEnd != nil { - return visitor.VisitContentEnd(s.ContentEnd) + if m.Tool != nil { + return visitor.VisitTool(m.Tool) } - if s.ToolPlanDelta != nil { - return visitor.VisitToolPlanDelta(s.ToolPlanDelta) + return fmt.Errorf("type %T does not define a non-empty union type", m) +} + +func (m *Message) validate() error { + if m == nil { + return fmt.Errorf("type %T is nil", m) } - if s.ToolCallStart != nil { - return visitor.VisitToolCallStart(s.ToolCallStart) + var fields []string + if m.Chatbot != nil { + fields = append(fields, "CHATBOT") } - if s.ToolCallDelta != nil { - return visitor.VisitToolCallDelta(s.ToolCallDelta) + if m.System != nil { + fields = append(fields, "SYSTEM") } - if s.ToolCallEnd != nil { - return visitor.VisitToolCallEnd(s.ToolCallEnd) + if m.User != nil { + fields = append(fields, "USER") } - if s.CitationStart != nil { - return visitor.VisitCitationStart(s.CitationStart) + if m.Tool != nil { + fields = append(fields, "TOOL") } - if s.CitationEnd != nil { - return visitor.VisitCitationEnd(s.CitationEnd) + if len(fields) == 0 { + if m.Role != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", m, m.Role) + } + return fmt.Errorf("type %T is empty", m) } - if s.MessageEnd != nil { - return visitor.VisitMessageEnd(s.MessageEnd) + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", m, fields) } - if s.Debug != nil { - return visitor.VisitDebug(s.Debug) + if m.Role != "" { + field := fields[0] + if m.Role != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + m, + m.Role, + m, + ) + } } - return fmt.Errorf("type %T does not define a non-empty union type", s) + return nil } -// One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text. -type SummarizeRequestExtractiveness string +type NonStreamedChatResponse struct { + // Contents of the reply generated by the model. + Text string `json:"text" url:"text"` + // Unique identifier for the generated reply. Useful for submitting feedback. + GenerationId *string `json:"generation_id,omitempty" url:"generation_id,omitempty"` + // Unique identifier for the response. + ResponseId *string `json:"response_id,omitempty" url:"response_id,omitempty"` + // Inline citations for the generated reply. + Citations []*ChatCitation `json:"citations,omitempty" url:"citations,omitempty"` + // Documents seen by the model when generating the reply. + Documents []ChatDocument `json:"documents,omitempty" url:"documents,omitempty"` + // Denotes that a search for documents is required during the RAG flow. + IsSearchRequired *bool `json:"is_search_required,omitempty" url:"is_search_required,omitempty"` + // Generated search queries, meant to be used as part of the RAG flow. + SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty" url:"search_queries,omitempty"` + // Documents retrieved from each of the conducted searches. + SearchResults []*ChatSearchResult `json:"search_results,omitempty" url:"search_results,omitempty"` + FinishReason *FinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` + ToolCalls []*ToolCall `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` + // A list of previous messages between the user and the model, meant to give the model conversational context for responding to the user's `message`. + ChatHistory []*Message `json:"chat_history,omitempty" url:"chat_history,omitempty"` + // The prompt that was used. Only present when `return_prompt` in the request is set to true. + Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` -const ( - SummarizeRequestExtractivenessLow SummarizeRequestExtractiveness = "low" - SummarizeRequestExtractivenessMedium SummarizeRequestExtractiveness = "medium" - SummarizeRequestExtractivenessHigh SummarizeRequestExtractiveness = "high" -) + extraProperties map[string]interface{} + rawJSON json.RawMessage +} -func NewSummarizeRequestExtractivenessFromString(s string) (SummarizeRequestExtractiveness, error) { - switch s { - case "low": - return SummarizeRequestExtractivenessLow, nil - case "medium": - return SummarizeRequestExtractivenessMedium, nil - case "high": - return SummarizeRequestExtractivenessHigh, nil +func (n *NonStreamedChatResponse) GetText() string { + if n == nil { + return "" } - var t SummarizeRequestExtractiveness - return "", fmt.Errorf("%s is not a valid %T", s, t) + return n.Text } -func (s SummarizeRequestExtractiveness) Ptr() *SummarizeRequestExtractiveness { - return &s +func (n *NonStreamedChatResponse) GetGenerationId() *string { + if n == nil { + return nil + } + return n.GenerationId } -// One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text. -type SummarizeRequestFormat string +func (n *NonStreamedChatResponse) GetResponseId() *string { + if n == nil { + return nil + } + return n.ResponseId +} -const ( - SummarizeRequestFormatParagraph SummarizeRequestFormat = "paragraph" - SummarizeRequestFormatBullets SummarizeRequestFormat = "bullets" -) +func (n *NonStreamedChatResponse) GetCitations() []*ChatCitation { + if n == nil { + return nil + } + return n.Citations +} -func NewSummarizeRequestFormatFromString(s string) (SummarizeRequestFormat, error) { - switch s { - case "paragraph": - return SummarizeRequestFormatParagraph, nil - case "bullets": - return SummarizeRequestFormatBullets, nil +func (n *NonStreamedChatResponse) GetDocuments() []ChatDocument { + if n == nil { + return nil } - var t SummarizeRequestFormat - return "", fmt.Errorf("%s is not a valid %T", s, t) + return n.Documents } -func (s SummarizeRequestFormat) Ptr() *SummarizeRequestFormat { - return &s +func (n *NonStreamedChatResponse) GetIsSearchRequired() *bool { + if n == nil { + return nil + } + return n.IsSearchRequired } -// One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text. -type SummarizeRequestLength string +func (n *NonStreamedChatResponse) GetSearchQueries() []*ChatSearchQuery { + if n == nil { + return nil + } + return n.SearchQueries +} -const ( - SummarizeRequestLengthShort SummarizeRequestLength = "short" - SummarizeRequestLengthMedium SummarizeRequestLength = "medium" - SummarizeRequestLengthLong SummarizeRequestLength = "long" -) +func (n *NonStreamedChatResponse) GetSearchResults() []*ChatSearchResult { + if n == nil { + return nil + } + return n.SearchResults +} -func NewSummarizeRequestLengthFromString(s string) (SummarizeRequestLength, error) { - switch s { - case "short": - return SummarizeRequestLengthShort, nil - case "medium": - return SummarizeRequestLengthMedium, nil - case "long": - return SummarizeRequestLengthLong, nil +func (n *NonStreamedChatResponse) GetFinishReason() *FinishReason { + if n == nil { + return nil } - var t SummarizeRequestLength - return "", fmt.Errorf("%s is not a valid %T", s, t) + return n.FinishReason } -func (s SummarizeRequestLength) Ptr() *SummarizeRequestLength { - return &s +func (n *NonStreamedChatResponse) GetToolCalls() []*ToolCall { + if n == nil { + return nil + } + return n.ToolCalls } -type SummarizeResponse struct { - // Generated ID for the summary - Id *string `json:"id,omitempty" url:"id,omitempty"` - // Generated summary for the text - Summary *string `json:"summary,omitempty" url:"summary,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` +func (n *NonStreamedChatResponse) GetChatHistory() []*Message { + if n == nil { + return nil + } + return n.ChatHistory +} - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (n *NonStreamedChatResponse) GetPrompt() *string { + if n == nil { + return nil + } + return n.Prompt } -func (s *SummarizeResponse) GetExtraProperties() map[string]interface{} { - return s.extraProperties +func (n *NonStreamedChatResponse) GetMeta() *ApiMeta { + if n == nil { + return nil + } + return n.Meta } -func (s *SummarizeResponse) UnmarshalJSON(data []byte) error { - type unmarshaler SummarizeResponse +func (n *NonStreamedChatResponse) GetExtraProperties() map[string]interface{} { + return n.extraProperties +} + +func (n *NonStreamedChatResponse) UnmarshalJSON(data []byte) error { + type unmarshaler NonStreamedChatResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *s = SummarizeResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *s) + *n = NonStreamedChatResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *n) if err != nil { return err } - s.extraProperties = extraProperties - - s._rawJSON = json.RawMessage(data) + n.extraProperties = extraProperties + n.rawJSON = json.RawMessage(data) return nil } -func (s *SummarizeResponse) String() string { - if len(s._rawJSON) > 0 { - if value, err := core.StringifyJSON(s._rawJSON); err == nil { +func (n *NonStreamedChatResponse) String() string { + if len(n.rawJSON) > 0 { + if value, err := internal.StringifyJSON(n.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(s); err == nil { + if value, err := internal.StringifyJSON(n); err == nil { return value } - return fmt.Sprintf("%#v", s) + return fmt.Sprintf("%#v", n) } -// A message from the system. -type SystemMessage struct { - Content *SystemMessageContent `json:"content,omitempty" url:"content,omitempty"` +type RerankDocument = map[string]string - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} +type RerankRequestDocumentsItem struct { + String string + RerankDocument RerankDocument -func (s *SystemMessage) GetExtraProperties() map[string]interface{} { - return s.extraProperties + typ string } -func (s *SystemMessage) UnmarshalJSON(data []byte) error { - type unmarshaler SystemMessage - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (r *RerankRequestDocumentsItem) GetString() string { + if r == nil { + return "" } - *s = SystemMessage(value) + return r.String +} - extraProperties, err := core.ExtractExtraProperties(data, *s) - if err != nil { - return err +func (r *RerankRequestDocumentsItem) GetRerankDocument() RerankDocument { + if r == nil { + return nil } - s.extraProperties = extraProperties + return r.RerankDocument +} - s._rawJSON = json.RawMessage(data) - return nil +func (r *RerankRequestDocumentsItem) UnmarshalJSON(data []byte) error { + var valueString string + if err := json.Unmarshal(data, &valueString); err == nil { + r.typ = "String" + r.String = valueString + return nil + } + var valueRerankDocument RerankDocument + if err := json.Unmarshal(data, &valueRerankDocument); err == nil { + r.typ = "RerankDocument" + r.RerankDocument = valueRerankDocument + return nil + } + return fmt.Errorf("%s cannot be deserialized as a %T", data, r) } -func (s *SystemMessage) String() string { - if len(s._rawJSON) > 0 { - if value, err := core.StringifyJSON(s._rawJSON); err == nil { - return value - } +func (r RerankRequestDocumentsItem) MarshalJSON() ([]byte, error) { + if r.typ == "String" || r.String != "" { + return json.Marshal(r.String) } - if value, err := core.StringifyJSON(s); err == nil { - return value + if r.typ == "RerankDocument" || r.RerankDocument != nil { + return json.Marshal(r.RerankDocument) } - return fmt.Sprintf("%#v", s) + return nil, fmt.Errorf("type %T does not include a non-empty union type", r) } -type SystemMessageContent struct { - String string - SystemMessageContentItemList []*SystemMessageContentItem +type RerankRequestDocumentsItemVisitor interface { + VisitString(string) error + VisitRerankDocument(RerankDocument) error } -func (s *SystemMessageContent) UnmarshalJSON(data []byte) error { - var valueString string - if err := json.Unmarshal(data, &valueString); err == nil { - s.String = valueString - return nil +func (r *RerankRequestDocumentsItem) Accept(visitor RerankRequestDocumentsItemVisitor) error { + if r.typ == "String" || r.String != "" { + return visitor.VisitString(r.String) } - var valueSystemMessageContentItemList []*SystemMessageContentItem - if err := json.Unmarshal(data, &valueSystemMessageContentItemList); err == nil { - s.SystemMessageContentItemList = valueSystemMessageContentItemList - return nil + if r.typ == "RerankDocument" || r.RerankDocument != nil { + return visitor.VisitRerankDocument(r.RerankDocument) } - return fmt.Errorf("%s cannot be deserialized as a %T", data, s) + return fmt.Errorf("type %T does not include a non-empty union type", r) +} + +type RerankResponse struct { + Id *string `json:"id,omitempty" url:"id,omitempty"` + // An ordered list of ranked documents + Results []*RerankResponseResultsItem `json:"results,omitempty" url:"results,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage } -func (s SystemMessageContent) MarshalJSON() ([]byte, error) { - if s.String != "" { - return json.Marshal(s.String) - } - if s.SystemMessageContentItemList != nil { - return json.Marshal(s.SystemMessageContentItemList) +func (r *RerankResponse) GetId() *string { + if r == nil { + return nil } - return nil, fmt.Errorf("type %T does not include a non-empty union type", s) + return r.Id } -type SystemMessageContentVisitor interface { - VisitString(string) error - VisitSystemMessageContentItemList([]*SystemMessageContentItem) error +func (r *RerankResponse) GetResults() []*RerankResponseResultsItem { + if r == nil { + return nil + } + return r.Results } -func (s *SystemMessageContent) Accept(visitor SystemMessageContentVisitor) error { - if s.String != "" { - return visitor.VisitString(s.String) - } - if s.SystemMessageContentItemList != nil { - return visitor.VisitSystemMessageContentItemList(s.SystemMessageContentItemList) +func (r *RerankResponse) GetMeta() *ApiMeta { + if r == nil { + return nil } - return fmt.Errorf("type %T does not include a non-empty union type", s) + return r.Meta } -type SystemMessageContentItem struct { - Type string - Text *TextContent +func (r *RerankResponse) GetExtraProperties() map[string]interface{} { + return r.extraProperties } -func (s *SystemMessageContentItem) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` - } - if err := json.Unmarshal(data, &unmarshaler); err != nil { +func (r *RerankResponse) UnmarshalJSON(data []byte) error { + type unmarshaler RerankResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { return err } - s.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", s) - } - switch unmarshaler.Type { - case "text": - value := new(TextContent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - s.Text = value + *r = RerankResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *r) + if err != nil { + return err } + r.extraProperties = extraProperties + r.rawJSON = json.RawMessage(data) return nil } -func (s SystemMessageContentItem) MarshalJSON() ([]byte, error) { - if s.Text != nil { - return core.MarshalJSONWithExtraProperty(s.Text, "type", "text") +func (r *RerankResponse) String() string { + if len(r.rawJSON) > 0 { + if value, err := internal.StringifyJSON(r.rawJSON); err == nil { + return value + } } - return nil, fmt.Errorf("type %T does not define a non-empty union type", s) + if value, err := internal.StringifyJSON(r); err == nil { + return value + } + return fmt.Sprintf("%#v", r) } -type SystemMessageContentItemVisitor interface { - VisitText(*TextContent) error +type RerankResponseResultsItem struct { + // If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in + Document *RerankResponseResultsItemDocument `json:"document,omitempty" url:"document,omitempty"` + // Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance) + Index int `json:"index" url:"index"` + // Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45 + RelevanceScore float64 `json:"relevance_score" url:"relevance_score"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage } -func (s *SystemMessageContentItem) Accept(visitor SystemMessageContentItemVisitor) error { - if s.Text != nil { - return visitor.VisitText(s.Text) +func (r *RerankResponseResultsItem) GetDocument() *RerankResponseResultsItemDocument { + if r == nil { + return nil } - return fmt.Errorf("type %T does not define a non-empty union type", s) + return r.Document } -// Text content of the message. -type TextContent struct { - Text string `json:"text" url:"text"` +func (r *RerankResponseResultsItem) GetIndex() int { + if r == nil { + return 0 + } + return r.Index +} - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (r *RerankResponseResultsItem) GetRelevanceScore() float64 { + if r == nil { + return 0 + } + return r.RelevanceScore } -func (t *TextContent) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (r *RerankResponseResultsItem) GetExtraProperties() map[string]interface{} { + return r.extraProperties } -func (t *TextContent) UnmarshalJSON(data []byte) error { - type unmarshaler TextContent +func (r *RerankResponseResultsItem) UnmarshalJSON(data []byte) error { + type unmarshaler RerankResponseResultsItem var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *t = TextContent(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) + *r = RerankResponseResultsItem(value) + extraProperties, err := internal.ExtractExtraProperties(data, *r) if err != nil { return err } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + r.extraProperties = extraProperties + r.rawJSON = json.RawMessage(data) return nil } -func (t *TextContent) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { +func (r *RerankResponseResultsItem) String() string { + if len(r.rawJSON) > 0 { + if value, err := internal.StringifyJSON(r.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(r); err == nil { return value } - return fmt.Sprintf("%#v", t) + return fmt.Sprintf("%#v", r) } -type TextResponseFormat struct { +// If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in +type RerankResponseResultsItemDocument struct { + // The text of the document to rerank + Text string `json:"text" url:"text"` + extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (t *TextResponseFormat) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (r *RerankResponseResultsItemDocument) GetText() string { + if r == nil { + return "" + } + return r.Text } -func (t *TextResponseFormat) UnmarshalJSON(data []byte) error { - type unmarshaler TextResponseFormat +func (r *RerankResponseResultsItemDocument) GetExtraProperties() map[string]interface{} { + return r.extraProperties +} + +func (r *RerankResponseResultsItemDocument) UnmarshalJSON(data []byte) error { + type unmarshaler RerankResponseResultsItemDocument var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *t = TextResponseFormat(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) + *r = RerankResponseResultsItemDocument(value) + extraProperties, err := internal.ExtractExtraProperties(data, *r) if err != nil { return err } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + r.extraProperties = extraProperties + r.rawJSON = json.RawMessage(data) return nil } -func (t *TextResponseFormat) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { +func (r *RerankResponseResultsItemDocument) String() string { + if len(r.rawJSON) > 0 { + if value, err := internal.StringifyJSON(r.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(r); err == nil { return value } - return fmt.Sprintf("%#v", t) + return fmt.Sprintf("%#v", r) } -type TextResponseFormatV2 struct { - extraProperties map[string]interface{} - _rawJSON json.RawMessage +// Configuration for forcing the model output to adhere to the specified format. Supported on [Command R 03-2024](https://docs.cohere.com/docs/command-r), [Command R+ 04-2024](https://docs.cohere.com/docs/command-r-plus) and newer models. +// +// The model can be forced into outputting JSON objects (with up to 5 levels of nesting) by setting `{ "type": "json_object" }`. +// +// A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure. +// +// **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length. +// **Limitation**: The parameter is not supported in RAG mode (when any of `connectors`, `documents`, `tools`, `tool_results` are provided). +type ResponseFormat struct { + Type string + Text *TextResponseFormat + JsonObject *JsonResponseFormat } -func (t *TextResponseFormatV2) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (r *ResponseFormat) GetType() string { + if r == nil { + return "" + } + return r.Type } -func (t *TextResponseFormatV2) UnmarshalJSON(data []byte) error { - type unmarshaler TextResponseFormatV2 - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (r *ResponseFormat) GetText() *TextResponseFormat { + if r == nil { + return nil } - *t = TextResponseFormatV2(value) + return r.Text +} - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err +func (r *ResponseFormat) GetJsonObject() *JsonResponseFormat { + if r == nil { + return nil } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) - return nil + return r.JsonObject } -func (t *TextResponseFormatV2) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } +func (r *ResponseFormat) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` } - if value, err := core.StringifyJSON(t); err == nil { - return value + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err } - return fmt.Sprintf("%#v", t) + r.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", r) + } + switch unmarshaler.Type { + case "text": + value := new(TextResponseFormat) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + r.Text = value + case "json_object": + value := new(JsonResponseFormat) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + r.JsonObject = value + } + return nil } -type TokenizeResponse struct { - // An array of tokens, where each token is an integer. - Tokens []int `json:"tokens,omitempty" url:"tokens,omitempty"` - TokenStrings []string `json:"token_strings,omitempty" url:"token_strings,omitempty"` - Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (r ResponseFormat) MarshalJSON() ([]byte, error) { + if err := r.validate(); err != nil { + return nil, err + } + if r.Text != nil { + return internal.MarshalJSONWithExtraProperty(r.Text, "type", "text") + } + if r.JsonObject != nil { + return internal.MarshalJSONWithExtraProperty(r.JsonObject, "type", "json_object") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", r) } -func (t *TokenizeResponse) GetExtraProperties() map[string]interface{} { - return t.extraProperties +type ResponseFormatVisitor interface { + VisitText(*TextResponseFormat) error + VisitJsonObject(*JsonResponseFormat) error } -func (t *TokenizeResponse) UnmarshalJSON(data []byte) error { - type unmarshaler TokenizeResponse - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (r *ResponseFormat) Accept(visitor ResponseFormatVisitor) error { + if r.Text != nil { + return visitor.VisitText(r.Text) } - *t = TokenizeResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err + if r.JsonObject != nil { + return visitor.VisitJsonObject(r.JsonObject) } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) - return nil + return fmt.Errorf("type %T does not define a non-empty union type", r) } -func (t *TokenizeResponse) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value +func (r *ResponseFormat) validate() error { + if r == nil { + return fmt.Errorf("type %T is nil", r) + } + var fields []string + if r.Text != nil { + fields = append(fields, "text") + } + if r.JsonObject != nil { + fields = append(fields, "json_object") + } + if len(fields) == 0 { + if r.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", r, r.Type) } + return fmt.Errorf("type %T is empty", r) } - if value, err := core.StringifyJSON(t); err == nil { - return value + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", r, fields) } - return fmt.Sprintf("%#v", t) + if r.Type != "" { + field := fields[0] + if r.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + r, + r.Type, + r, + ) + } + } + return nil } -type TooManyRequestsErrorBody struct { - Data *string `json:"data,omitempty" url:"data,omitempty"` +type SingleGeneration struct { + Id string `json:"id" url:"id"` + Text string `json:"text" url:"text"` + // Refers to the nth generation. Only present when `num_generations` is greater than zero. + Index *int `json:"index,omitempty" url:"index,omitempty"` + Likelihood *float64 `json:"likelihood,omitempty" url:"likelihood,omitempty"` + // Only returned if `return_likelihoods` is set to `GENERATION` or `ALL`. The likelihood refers to the average log-likelihood of the entire specified string, which is useful for [evaluating the performance of your model](likelihood-eval), especially if you've created a [custom model](https://docs.cohere.com/docs/training-custom-models). Individual token likelihoods provide the log-likelihood of each token. The first token will not have a likelihood. + TokenLikelihoods []*SingleGenerationTokenLikelihoodsItem `json:"token_likelihoods,omitempty" url:"token_likelihoods,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (t *TooManyRequestsErrorBody) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s *SingleGeneration) GetId() string { + if s == nil { + return "" + } + return s.Id } -func (t *TooManyRequestsErrorBody) UnmarshalJSON(data []byte) error { - type unmarshaler TooManyRequestsErrorBody - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (s *SingleGeneration) GetText() string { + if s == nil { + return "" } - *t = TooManyRequestsErrorBody(value) + return s.Text +} - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err +func (s *SingleGeneration) GetIndex() *int { + if s == nil { + return nil } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) - return nil + return s.Index } -func (t *TooManyRequestsErrorBody) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(t); err == nil { - return value +func (s *SingleGeneration) GetLikelihood() *float64 { + if s == nil { + return nil } - return fmt.Sprintf("%#v", t) + return s.Likelihood } -type Tool struct { - // The name of the tool to be called. Valid names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit. - Name string `json:"name" url:"name"` - // The description of what the tool does, the model uses the description to choose when and how to call the function. - Description string `json:"description" url:"description"` - // The input parameters of the tool. Accepts a dictionary where the key is the name of the parameter and the value is the parameter spec. Valid parameter names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit. - // - // ``` - // - // { - // "my_param": { - // "description": , - // "type": , // any python data type, such as 'str', 'bool' - // "required": - // } - // } - // - // ``` - ParameterDefinitions map[string]*ToolParameterDefinitionsValue `json:"parameter_definitions,omitempty" url:"parameter_definitions,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (s *SingleGeneration) GetTokenLikelihoods() []*SingleGenerationTokenLikelihoodsItem { + if s == nil { + return nil + } + return s.TokenLikelihoods } -func (t *Tool) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s *SingleGeneration) GetExtraProperties() map[string]interface{} { + return s.extraProperties } -func (t *Tool) UnmarshalJSON(data []byte) error { - type unmarshaler Tool +func (s *SingleGeneration) UnmarshalJSON(data []byte) error { + type unmarshaler SingleGeneration var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *t = Tool(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) + *s = SingleGeneration(value) + extraProperties, err := internal.ExtractExtraProperties(data, *s) if err != nil { return err } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + s.extraProperties = extraProperties + s.rawJSON = json.RawMessage(data) return nil } -func (t *Tool) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { +func (s *SingleGeneration) String() string { + if len(s.rawJSON) > 0 { + if value, err := internal.StringifyJSON(s.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(s); err == nil { return value } - return fmt.Sprintf("%#v", t) + return fmt.Sprintf("%#v", s) } -// Contains the tool calls generated by the model. Use it to invoke your tools. -type ToolCall struct { - // Name of the tool to call. - Name string `json:"name" url:"name"` - // The name and value of the parameters to use when invoking a tool. - Parameters map[string]interface{} `json:"parameters,omitempty" url:"parameters,omitempty"` +type SingleGenerationInStream struct { + Id string `json:"id" url:"id"` + // Full text of the generation. + Text string `json:"text" url:"text"` + // Refers to the nth generation. Only present when `num_generations` is greater than zero. + Index *int `json:"index,omitempty" url:"index,omitempty"` + FinishReason FinishReason `json:"finish_reason" url:"finish_reason"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (t *ToolCall) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s *SingleGenerationInStream) GetId() string { + if s == nil { + return "" + } + return s.Id } -func (t *ToolCall) UnmarshalJSON(data []byte) error { - type unmarshaler ToolCall +func (s *SingleGenerationInStream) GetText() string { + if s == nil { + return "" + } + return s.Text +} + +func (s *SingleGenerationInStream) GetIndex() *int { + if s == nil { + return nil + } + return s.Index +} + +func (s *SingleGenerationInStream) GetFinishReason() FinishReason { + if s == nil { + return "" + } + return s.FinishReason +} + +func (s *SingleGenerationInStream) GetExtraProperties() map[string]interface{} { + return s.extraProperties +} + +func (s *SingleGenerationInStream) UnmarshalJSON(data []byte) error { + type unmarshaler SingleGenerationInStream var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *t = ToolCall(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) + *s = SingleGenerationInStream(value) + extraProperties, err := internal.ExtractExtraProperties(data, *s) if err != nil { return err } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + s.extraProperties = extraProperties + s.rawJSON = json.RawMessage(data) return nil } -func (t *ToolCall) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { +func (s *SingleGenerationInStream) String() string { + if len(s.rawJSON) > 0 { + if value, err := internal.StringifyJSON(s.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(s); err == nil { return value } - return fmt.Sprintf("%#v", t) + return fmt.Sprintf("%#v", s) } -// Contains the chunk of the tool call generation in the stream. -type ToolCallDelta struct { - // Name of the tool call - Name *string `json:"name,omitempty" url:"name,omitempty"` - // Index of the tool call generated - Index *float64 `json:"index,omitempty" url:"index,omitempty"` - // Chunk of the tool parameters - Parameters *string `json:"parameters,omitempty" url:"parameters,omitempty"` - // Chunk of the tool plan text - Text *string `json:"text,omitempty" url:"text,omitempty"` +type SingleGenerationTokenLikelihoodsItem struct { + Token string `json:"token" url:"token"` + Likelihood float64 `json:"likelihood" url:"likelihood"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (t *ToolCallDelta) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s *SingleGenerationTokenLikelihoodsItem) GetToken() string { + if s == nil { + return "" + } + return s.Token } -func (t *ToolCallDelta) UnmarshalJSON(data []byte) error { - type unmarshaler ToolCallDelta +func (s *SingleGenerationTokenLikelihoodsItem) GetLikelihood() float64 { + if s == nil { + return 0 + } + return s.Likelihood +} + +func (s *SingleGenerationTokenLikelihoodsItem) GetExtraProperties() map[string]interface{} { + return s.extraProperties +} + +func (s *SingleGenerationTokenLikelihoodsItem) UnmarshalJSON(data []byte) error { + type unmarshaler SingleGenerationTokenLikelihoodsItem var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *t = ToolCallDelta(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) + *s = SingleGenerationTokenLikelihoodsItem(value) + extraProperties, err := internal.ExtractExtraProperties(data, *s) if err != nil { return err } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + s.extraProperties = extraProperties + s.rawJSON = json.RawMessage(data) return nil } -func (t *ToolCallDelta) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { +func (s *SingleGenerationTokenLikelihoodsItem) String() string { + if len(s.rawJSON) > 0 { + if value, err := internal.StringifyJSON(s.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(s); err == nil { return value } - return fmt.Sprintf("%#v", t) + return fmt.Sprintf("%#v", s) } -// An array of tool calls to be made. -type ToolCallV2 struct { - Id *string `json:"id,omitempty" url:"id,omitempty"` - Type *string `json:"type,omitempty" url:"type,omitempty"` - Function *ToolCallV2Function `json:"function,omitempty" url:"function,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +// StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request). +type StreamedChatResponse struct { + EventType string + StreamStart *ChatStreamStartEvent + SearchQueriesGeneration *ChatSearchQueriesGenerationEvent + SearchResults *ChatSearchResultsEvent + TextGeneration *ChatTextGenerationEvent + CitationGeneration *ChatCitationGenerationEvent + ToolCallsGeneration *ChatToolCallsGenerationEvent + StreamEnd *ChatStreamEndEvent + ToolCallsChunk *ChatToolCallsChunkEvent + Debug *ChatDebugEvent } -func (t *ToolCallV2) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s *StreamedChatResponse) GetEventType() string { + if s == nil { + return "" + } + return s.EventType } -func (t *ToolCallV2) UnmarshalJSON(data []byte) error { - type unmarshaler ToolCallV2 - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (s *StreamedChatResponse) GetStreamStart() *ChatStreamStartEvent { + if s == nil { + return nil } - *t = ToolCallV2(value) + return s.StreamStart +} - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err +func (s *StreamedChatResponse) GetSearchQueriesGeneration() *ChatSearchQueriesGenerationEvent { + if s == nil { + return nil } - t.extraProperties = extraProperties + return s.SearchQueriesGeneration +} - t._rawJSON = json.RawMessage(data) - return nil +func (s *StreamedChatResponse) GetSearchResults() *ChatSearchResultsEvent { + if s == nil { + return nil + } + return s.SearchResults } -func (t *ToolCallV2) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } +func (s *StreamedChatResponse) GetTextGeneration() *ChatTextGenerationEvent { + if s == nil { + return nil } - if value, err := core.StringifyJSON(t); err == nil { - return value + return s.TextGeneration +} + +func (s *StreamedChatResponse) GetCitationGeneration() *ChatCitationGenerationEvent { + if s == nil { + return nil } - return fmt.Sprintf("%#v", t) + return s.CitationGeneration } -type ToolCallV2Function struct { - Name *string `json:"name,omitempty" url:"name,omitempty"` - Arguments *string `json:"arguments,omitempty" url:"arguments,omitempty"` +func (s *StreamedChatResponse) GetToolCallsGeneration() *ChatToolCallsGenerationEvent { + if s == nil { + return nil + } + return s.ToolCallsGeneration +} - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (s *StreamedChatResponse) GetStreamEnd() *ChatStreamEndEvent { + if s == nil { + return nil + } + return s.StreamEnd } -func (t *ToolCallV2Function) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s *StreamedChatResponse) GetToolCallsChunk() *ChatToolCallsChunkEvent { + if s == nil { + return nil + } + return s.ToolCallsChunk } -func (t *ToolCallV2Function) UnmarshalJSON(data []byte) error { - type unmarshaler ToolCallV2Function - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (s *StreamedChatResponse) GetDebug() *ChatDebugEvent { + if s == nil { + return nil } - *t = ToolCallV2Function(value) + return s.Debug +} - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { +func (s *StreamedChatResponse) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + EventType string `json:"event_type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { return err } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + s.EventType = unmarshaler.EventType + if unmarshaler.EventType == "" { + return fmt.Errorf("%T did not include discriminant event_type", s) + } + switch unmarshaler.EventType { + case "stream-start": + value := new(ChatStreamStartEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.StreamStart = value + case "search-queries-generation": + value := new(ChatSearchQueriesGenerationEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.SearchQueriesGeneration = value + case "search-results": + value := new(ChatSearchResultsEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.SearchResults = value + case "text-generation": + value := new(ChatTextGenerationEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.TextGeneration = value + case "citation-generation": + value := new(ChatCitationGenerationEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.CitationGeneration = value + case "tool-calls-generation": + value := new(ChatToolCallsGenerationEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ToolCallsGeneration = value + case "stream-end": + value := new(ChatStreamEndEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.StreamEnd = value + case "tool-calls-chunk": + value := new(ChatToolCallsChunkEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ToolCallsChunk = value + case "debug": + value := new(ChatDebugEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.Debug = value + } return nil } -func (t *ToolCallV2Function) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } +func (s StreamedChatResponse) MarshalJSON() ([]byte, error) { + if err := s.validate(); err != nil { + return nil, err } - if value, err := core.StringifyJSON(t); err == nil { - return value + if s.StreamStart != nil { + return internal.MarshalJSONWithExtraProperty(s.StreamStart, "event_type", "stream-start") } - return fmt.Sprintf("%#v", t) + if s.SearchQueriesGeneration != nil { + return internal.MarshalJSONWithExtraProperty(s.SearchQueriesGeneration, "event_type", "search-queries-generation") + } + if s.SearchResults != nil { + return internal.MarshalJSONWithExtraProperty(s.SearchResults, "event_type", "search-results") + } + if s.TextGeneration != nil { + return internal.MarshalJSONWithExtraProperty(s.TextGeneration, "event_type", "text-generation") + } + if s.CitationGeneration != nil { + return internal.MarshalJSONWithExtraProperty(s.CitationGeneration, "event_type", "citation-generation") + } + if s.ToolCallsGeneration != nil { + return internal.MarshalJSONWithExtraProperty(s.ToolCallsGeneration, "event_type", "tool-calls-generation") + } + if s.StreamEnd != nil { + return internal.MarshalJSONWithExtraProperty(s.StreamEnd, "event_type", "stream-end") + } + if s.ToolCallsChunk != nil { + return internal.MarshalJSONWithExtraProperty(s.ToolCallsChunk, "event_type", "tool-calls-chunk") + } + if s.Debug != nil { + return internal.MarshalJSONWithExtraProperty(s.Debug, "event_type", "debug") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", s) } -// A content block which contains information about the content of a tool result -type ToolContent struct { - Type string - Text *TextContent - Document *DocumentContent +type StreamedChatResponseVisitor interface { + VisitStreamStart(*ChatStreamStartEvent) error + VisitSearchQueriesGeneration(*ChatSearchQueriesGenerationEvent) error + VisitSearchResults(*ChatSearchResultsEvent) error + VisitTextGeneration(*ChatTextGenerationEvent) error + VisitCitationGeneration(*ChatCitationGenerationEvent) error + VisitToolCallsGeneration(*ChatToolCallsGenerationEvent) error + VisitStreamEnd(*ChatStreamEndEvent) error + VisitToolCallsChunk(*ChatToolCallsChunkEvent) error + VisitDebug(*ChatDebugEvent) error } -func (t *ToolContent) UnmarshalJSON(data []byte) error { - var unmarshaler struct { - Type string `json:"type"` +func (s *StreamedChatResponse) Accept(visitor StreamedChatResponseVisitor) error { + if s.StreamStart != nil { + return visitor.VisitStreamStart(s.StreamStart) + } + if s.SearchQueriesGeneration != nil { + return visitor.VisitSearchQueriesGeneration(s.SearchQueriesGeneration) + } + if s.SearchResults != nil { + return visitor.VisitSearchResults(s.SearchResults) + } + if s.TextGeneration != nil { + return visitor.VisitTextGeneration(s.TextGeneration) + } + if s.CitationGeneration != nil { + return visitor.VisitCitationGeneration(s.CitationGeneration) + } + if s.ToolCallsGeneration != nil { + return visitor.VisitToolCallsGeneration(s.ToolCallsGeneration) } - if err := json.Unmarshal(data, &unmarshaler); err != nil { - return err + if s.StreamEnd != nil { + return visitor.VisitStreamEnd(s.StreamEnd) } - t.Type = unmarshaler.Type - if unmarshaler.Type == "" { - return fmt.Errorf("%T did not include discriminant type", t) + if s.ToolCallsChunk != nil { + return visitor.VisitToolCallsChunk(s.ToolCallsChunk) } - switch unmarshaler.Type { - case "text": - value := new(TextContent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - t.Text = value - case "document": - value := new(DocumentContent) - if err := json.Unmarshal(data, &value); err != nil { - return err - } - t.Document = value + if s.Debug != nil { + return visitor.VisitDebug(s.Debug) } - return nil + return fmt.Errorf("type %T does not define a non-empty union type", s) } -func (t ToolContent) MarshalJSON() ([]byte, error) { - if t.Text != nil { - return core.MarshalJSONWithExtraProperty(t.Text, "type", "text") +func (s *StreamedChatResponse) validate() error { + if s == nil { + return fmt.Errorf("type %T is nil", s) } - if t.Document != nil { - return core.MarshalJSONWithExtraProperty(t.Document, "type", "document") + var fields []string + if s.StreamStart != nil { + fields = append(fields, "stream-start") } - return nil, fmt.Errorf("type %T does not define a non-empty union type", t) -} - -type ToolContentVisitor interface { - VisitText(*TextContent) error - VisitDocument(*DocumentContent) error -} - -func (t *ToolContent) Accept(visitor ToolContentVisitor) error { - if t.Text != nil { - return visitor.VisitText(t.Text) + if s.SearchQueriesGeneration != nil { + fields = append(fields, "search-queries-generation") } - if t.Document != nil { - return visitor.VisitDocument(t.Document) + if s.SearchResults != nil { + fields = append(fields, "search-results") } - return fmt.Errorf("type %T does not define a non-empty union type", t) -} - -// Represents tool result in the chat history. -type ToolMessage struct { - ToolResults []*ToolResult `json:"tool_results,omitempty" url:"tool_results,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (t *ToolMessage) GetExtraProperties() map[string]interface{} { - return t.extraProperties -} - -func (t *ToolMessage) UnmarshalJSON(data []byte) error { - type unmarshaler ToolMessage - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err + if s.TextGeneration != nil { + fields = append(fields, "text-generation") } - *t = ToolMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err + if s.CitationGeneration != nil { + fields = append(fields, "citation-generation") } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) - return nil -} - -func (t *ToolMessage) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } + if s.ToolCallsGeneration != nil { + fields = append(fields, "tool-calls-generation") } - if value, err := core.StringifyJSON(t); err == nil { - return value + if s.StreamEnd != nil { + fields = append(fields, "stream-end") } - return fmt.Sprintf("%#v", t) -} - -// A message with Tool outputs. -type ToolMessageV2 struct { - // The id of the associated tool call that has provided the given content - ToolCallId string `json:"tool_call_id" url:"tool_call_id"` - // Outputs from a tool. The content should formatted as a JSON object string, or a list of tool content blocks - Content *ToolMessageV2Content `json:"content,omitempty" url:"content,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (t *ToolMessageV2) GetExtraProperties() map[string]interface{} { - return t.extraProperties -} - -func (t *ToolMessageV2) UnmarshalJSON(data []byte) error { - type unmarshaler ToolMessageV2 - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err + if s.ToolCallsChunk != nil { + fields = append(fields, "tool-calls-chunk") } - *t = ToolMessageV2(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err + if s.Debug != nil { + fields = append(fields, "debug") } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) - return nil -} - -func (t *ToolMessageV2) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value + if len(fields) == 0 { + if s.EventType != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", s, s.EventType) } + return fmt.Errorf("type %T is empty", s) } - if value, err := core.StringifyJSON(t); err == nil { - return value - } - return fmt.Sprintf("%#v", t) -} - -// Outputs from a tool. The content should formatted as a JSON object string, or a list of tool content blocks -type ToolMessageV2Content struct { - String string - ToolContentList []*ToolContent -} - -func (t *ToolMessageV2Content) UnmarshalJSON(data []byte) error { - var valueString string - if err := json.Unmarshal(data, &valueString); err == nil { - t.String = valueString - return nil + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", s, fields) } - var valueToolContentList []*ToolContent - if err := json.Unmarshal(data, &valueToolContentList); err == nil { - t.ToolContentList = valueToolContentList - return nil + if s.EventType != "" { + field := fields[0] + if s.EventType != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + s, + s.EventType, + s, + ) + } } - return fmt.Errorf("%s cannot be deserialized as a %T", data, t) + return nil } -func (t ToolMessageV2Content) MarshalJSON() ([]byte, error) { - if t.String != "" { - return json.Marshal(t.String) - } - if t.ToolContentList != nil { - return json.Marshal(t.ToolContentList) - } - return nil, fmt.Errorf("type %T does not include a non-empty union type", t) -} +// One of `low`, `medium`, `high`, or `auto`, defaults to `auto`. Controls how close to the original text the summary is. `high` extractiveness summaries will lean towards reusing sentences verbatim, while `low` extractiveness summaries will tend to paraphrase more. If `auto` is selected, the best option will be picked based on the input text. +type SummarizeRequestExtractiveness string -type ToolMessageV2ContentVisitor interface { - VisitString(string) error - VisitToolContentList([]*ToolContent) error -} +const ( + SummarizeRequestExtractivenessLow SummarizeRequestExtractiveness = "low" + SummarizeRequestExtractivenessMedium SummarizeRequestExtractiveness = "medium" + SummarizeRequestExtractivenessHigh SummarizeRequestExtractiveness = "high" +) -func (t *ToolMessageV2Content) Accept(visitor ToolMessageV2ContentVisitor) error { - if t.String != "" { - return visitor.VisitString(t.String) - } - if t.ToolContentList != nil { - return visitor.VisitToolContentList(t.ToolContentList) +func NewSummarizeRequestExtractivenessFromString(s string) (SummarizeRequestExtractiveness, error) { + switch s { + case "low": + return SummarizeRequestExtractivenessLow, nil + case "medium": + return SummarizeRequestExtractivenessMedium, nil + case "high": + return SummarizeRequestExtractivenessHigh, nil } - return fmt.Errorf("type %T does not include a non-empty union type", t) -} - -type ToolParameterDefinitionsValue struct { - // The description of the parameter. - Description *string `json:"description,omitempty" url:"description,omitempty"` - // The type of the parameter. Must be a valid Python type. - Type string `json:"type" url:"type"` - // Denotes whether the parameter is always present (required) or not. Defaults to not required. - Required *bool `json:"required,omitempty" url:"required,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + var t SummarizeRequestExtractiveness + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (t *ToolParameterDefinitionsValue) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s SummarizeRequestExtractiveness) Ptr() *SummarizeRequestExtractiveness { + return &s } -func (t *ToolParameterDefinitionsValue) UnmarshalJSON(data []byte) error { - type unmarshaler ToolParameterDefinitionsValue - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *t = ToolParameterDefinitionsValue(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err - } - t.extraProperties = extraProperties +// One of `paragraph`, `bullets`, or `auto`, defaults to `auto`. Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points. If `auto` is selected, the best option will be picked based on the input text. +type SummarizeRequestFormat string - t._rawJSON = json.RawMessage(data) - return nil -} +const ( + SummarizeRequestFormatParagraph SummarizeRequestFormat = "paragraph" + SummarizeRequestFormatBullets SummarizeRequestFormat = "bullets" +) -func (t *ToolParameterDefinitionsValue) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(t); err == nil { - return value +func NewSummarizeRequestFormatFromString(s string) (SummarizeRequestFormat, error) { + switch s { + case "paragraph": + return SummarizeRequestFormatParagraph, nil + case "bullets": + return SummarizeRequestFormatBullets, nil } - return fmt.Sprintf("%#v", t) + var t SummarizeRequestFormat + return "", fmt.Errorf("%s is not a valid %T", s, t) } -type ToolResult struct { - Call *ToolCall `json:"call,omitempty" url:"call,omitempty"` - Outputs []map[string]interface{} `json:"outputs,omitempty" url:"outputs,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (s SummarizeRequestFormat) Ptr() *SummarizeRequestFormat { + return &s } -func (t *ToolResult) GetExtraProperties() map[string]interface{} { - return t.extraProperties -} +// One of `short`, `medium`, `long`, or `auto` defaults to `auto`. Indicates the approximate length of the summary. If `auto` is selected, the best option will be picked based on the input text. +type SummarizeRequestLength string -func (t *ToolResult) UnmarshalJSON(data []byte) error { - type unmarshaler ToolResult - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *t = ToolResult(value) +const ( + SummarizeRequestLengthShort SummarizeRequestLength = "short" + SummarizeRequestLengthMedium SummarizeRequestLength = "medium" + SummarizeRequestLengthLong SummarizeRequestLength = "long" +) - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err +func NewSummarizeRequestLengthFromString(s string) (SummarizeRequestLength, error) { + switch s { + case "short": + return SummarizeRequestLengthShort, nil + case "medium": + return SummarizeRequestLengthMedium, nil + case "long": + return SummarizeRequestLengthLong, nil } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) - return nil + var t SummarizeRequestLength + return "", fmt.Errorf("%s is not a valid %T", s, t) } -func (t *ToolResult) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(t); err == nil { - return value - } - return fmt.Sprintf("%#v", t) +func (s SummarizeRequestLength) Ptr() *SummarizeRequestLength { + return &s } -type ToolSource struct { - // The unique identifier of the document - Id *string `json:"id,omitempty" url:"id,omitempty"` - ToolOutput map[string]interface{} `json:"tool_output,omitempty" url:"tool_output,omitempty"` +type SummarizeResponse struct { + // Generated ID for the summary + Id *string `json:"id,omitempty" url:"id,omitempty"` + // Generated summary for the text + Summary *string `json:"summary,omitempty" url:"summary,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (t *ToolSource) GetExtraProperties() map[string]interface{} { - return t.extraProperties + rawJSON json.RawMessage } -func (t *ToolSource) UnmarshalJSON(data []byte) error { - type unmarshaler ToolSource - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *t = ToolSource(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) - if err != nil { - return err +func (s *SummarizeResponse) GetId() *string { + if s == nil { + return nil } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) - return nil + return s.Id } -func (t *ToolSource) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(t); err == nil { - return value +func (s *SummarizeResponse) GetSummary() *string { + if s == nil { + return nil } - return fmt.Sprintf("%#v", t) + return s.Summary } -type ToolV2 struct { - Type *string `json:"type,omitempty" url:"type,omitempty"` - // The function to be executed. - Function *ToolV2Function `json:"function,omitempty" url:"function,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (s *SummarizeResponse) GetMeta() *ApiMeta { + if s == nil { + return nil + } + return s.Meta } -func (t *ToolV2) GetExtraProperties() map[string]interface{} { - return t.extraProperties +func (s *SummarizeResponse) GetExtraProperties() map[string]interface{} { + return s.extraProperties } -func (t *ToolV2) UnmarshalJSON(data []byte) error { - type unmarshaler ToolV2 +func (s *SummarizeResponse) UnmarshalJSON(data []byte) error { + type unmarshaler SummarizeResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *t = ToolV2(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) + *s = SummarizeResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *s) if err != nil { return err } - t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + s.extraProperties = extraProperties + s.rawJSON = json.RawMessage(data) return nil } -func (t *ToolV2) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { +func (s *SummarizeResponse) String() string { + if len(s.rawJSON) > 0 { + if value, err := internal.StringifyJSON(s.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(s); err == nil { return value } - return fmt.Sprintf("%#v", t) + return fmt.Sprintf("%#v", s) } -// The function to be executed. -type ToolV2Function struct { - // The name of the function. - Name *string `json:"name,omitempty" url:"name,omitempty"` - // The description of the function. - Description *string `json:"description,omitempty" url:"description,omitempty"` - // The parameters of the function as a JSON schema. - Parameters map[string]interface{} `json:"parameters,omitempty" url:"parameters,omitempty"` - +type TextResponseFormat struct { extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (t *ToolV2Function) GetExtraProperties() map[string]interface{} { +func (t *TextResponseFormat) GetExtraProperties() map[string]interface{} { return t.extraProperties } -func (t *ToolV2Function) UnmarshalJSON(data []byte) error { - type unmarshaler ToolV2Function +func (t *TextResponseFormat) UnmarshalJSON(data []byte) error { + type unmarshaler TextResponseFormat var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *t = ToolV2Function(value) - - extraProperties, err := core.ExtractExtraProperties(data, *t) + *t = TextResponseFormat(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } t.extraProperties = extraProperties - - t._rawJSON = json.RawMessage(data) + t.rawJSON = json.RawMessage(data) return nil } -func (t *ToolV2Function) String() string { - if len(t._rawJSON) > 0 { - if value, err := core.StringifyJSON(t._rawJSON); err == nil { +func (t *TextResponseFormat) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(t); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } return fmt.Sprintf("%#v", t) } -type UnprocessableEntityErrorBody struct { - Data *string `json:"data,omitempty" url:"data,omitempty"` +type TokenizeResponse struct { + // An array of tokens, where each token is an integer. + Tokens []int `json:"tokens,omitempty" url:"tokens,omitempty"` + TokenStrings []string `json:"token_strings,omitempty" url:"token_strings,omitempty"` + Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (t *TokenizeResponse) GetTokens() []int { + if t == nil { + return nil + } + return t.Tokens +} + +func (t *TokenizeResponse) GetTokenStrings() []string { + if t == nil { + return nil + } + return t.TokenStrings } -func (u *UnprocessableEntityErrorBody) GetExtraProperties() map[string]interface{} { - return u.extraProperties +func (t *TokenizeResponse) GetMeta() *ApiMeta { + if t == nil { + return nil + } + return t.Meta +} + +func (t *TokenizeResponse) GetExtraProperties() map[string]interface{} { + return t.extraProperties } -func (u *UnprocessableEntityErrorBody) UnmarshalJSON(data []byte) error { - type unmarshaler UnprocessableEntityErrorBody +func (t *TokenizeResponse) UnmarshalJSON(data []byte) error { + type unmarshaler TokenizeResponse var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *u = UnprocessableEntityErrorBody(value) - - extraProperties, err := core.ExtractExtraProperties(data, *u) + *t = TokenizeResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } - u.extraProperties = extraProperties - - u._rawJSON = json.RawMessage(data) + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) return nil } -func (u *UnprocessableEntityErrorBody) String() string { - if len(u._rawJSON) > 0 { - if value, err := core.StringifyJSON(u._rawJSON); err == nil { +func (t *TokenizeResponse) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(u); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } - return fmt.Sprintf("%#v", u) + return fmt.Sprintf("%#v", t) } -type UpdateConnectorResponse struct { - Connector *Connector `json:"connector,omitempty" url:"connector,omitempty"` +type Tool struct { + // The name of the tool to be called. Valid names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit. + Name string `json:"name" url:"name"` + // The description of what the tool does, the model uses the description to choose when and how to call the function. + Description string `json:"description" url:"description"` + // The input parameters of the tool. Accepts a dictionary where the key is the name of the parameter and the value is the parameter spec. Valid parameter names contain only the characters `a-z`, `A-Z`, `0-9`, `_` and must not begin with a digit. + // ``` + // + // { + // "my_param": { + // "description": , + // "type": , // any python data type, such as 'str', 'bool' + // "required": + // } + // } + // + // ``` + ParameterDefinitions map[string]*ToolParameterDefinitionsValue `json:"parameter_definitions,omitempty" url:"parameter_definitions,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage } -func (u *UpdateConnectorResponse) GetExtraProperties() map[string]interface{} { - return u.extraProperties +func (t *Tool) GetName() string { + if t == nil { + return "" + } + return t.Name +} + +func (t *Tool) GetDescription() string { + if t == nil { + return "" + } + return t.Description +} + +func (t *Tool) GetParameterDefinitions() map[string]*ToolParameterDefinitionsValue { + if t == nil { + return nil + } + return t.ParameterDefinitions } -func (u *UpdateConnectorResponse) UnmarshalJSON(data []byte) error { - type unmarshaler UpdateConnectorResponse +func (t *Tool) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *Tool) UnmarshalJSON(data []byte) error { + type unmarshaler Tool var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *u = UpdateConnectorResponse(value) - - extraProperties, err := core.ExtractExtraProperties(data, *u) + *t = Tool(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } - u.extraProperties = extraProperties - - u._rawJSON = json.RawMessage(data) + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) return nil } -func (u *UpdateConnectorResponse) String() string { - if len(u._rawJSON) > 0 { - if value, err := core.StringifyJSON(u._rawJSON); err == nil { +func (t *Tool) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(u); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } - return fmt.Sprintf("%#v", u) + return fmt.Sprintf("%#v", t) } -type Usage struct { - BilledUnits *UsageBilledUnits `json:"billed_units,omitempty" url:"billed_units,omitempty"` - Tokens *UsageTokens `json:"tokens,omitempty" url:"tokens,omitempty"` +// Contains the tool calls generated by the model. Use it to invoke your tools. +type ToolCall struct { + // Name of the tool to call. + Name string `json:"name" url:"name"` + // The name and value of the parameters to use when invoking a tool. + Parameters map[string]interface{} `json:"parameters,omitempty" url:"parameters,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (t *ToolCall) GetName() string { + if t == nil { + return "" + } + return t.Name +} + +func (t *ToolCall) GetParameters() map[string]interface{} { + if t == nil { + return nil + } + return t.Parameters } -func (u *Usage) GetExtraProperties() map[string]interface{} { - return u.extraProperties +func (t *ToolCall) GetExtraProperties() map[string]interface{} { + return t.extraProperties } -func (u *Usage) UnmarshalJSON(data []byte) error { - type unmarshaler Usage +func (t *ToolCall) UnmarshalJSON(data []byte) error { + type unmarshaler ToolCall var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *u = Usage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *u) + *t = ToolCall(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } - u.extraProperties = extraProperties - - u._rawJSON = json.RawMessage(data) + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) return nil } -func (u *Usage) String() string { - if len(u._rawJSON) > 0 { - if value, err := core.StringifyJSON(u._rawJSON); err == nil { +func (t *ToolCall) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(u); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } - return fmt.Sprintf("%#v", u) + return fmt.Sprintf("%#v", t) } -type UsageBilledUnits struct { - // The number of billed input tokens. - InputTokens *float64 `json:"input_tokens,omitempty" url:"input_tokens,omitempty"` - // The number of billed output tokens. - OutputTokens *float64 `json:"output_tokens,omitempty" url:"output_tokens,omitempty"` - // The number of billed search units. - SearchUnits *float64 `json:"search_units,omitempty" url:"search_units,omitempty"` - // The number of billed classifications units. - Classifications *float64 `json:"classifications,omitempty" url:"classifications,omitempty"` +// Contains the chunk of the tool call generation in the stream. +type ToolCallDelta struct { + // Name of the tool call + Name *string `json:"name,omitempty" url:"name,omitempty"` + // Index of the tool call generated + Index *float64 `json:"index,omitempty" url:"index,omitempty"` + // Chunk of the tool parameters + Parameters *string `json:"parameters,omitempty" url:"parameters,omitempty"` + // Chunk of the tool plan text + Text *string `json:"text,omitempty" url:"text,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (u *UsageBilledUnits) GetExtraProperties() map[string]interface{} { - return u.extraProperties + rawJSON json.RawMessage } -func (u *UsageBilledUnits) UnmarshalJSON(data []byte) error { - type unmarshaler UsageBilledUnits - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err +func (t *ToolCallDelta) GetName() *string { + if t == nil { + return nil } - *u = UsageBilledUnits(value) + return t.Name +} - extraProperties, err := core.ExtractExtraProperties(data, *u) - if err != nil { - return err +func (t *ToolCallDelta) GetIndex() *float64 { + if t == nil { + return nil } - u.extraProperties = extraProperties - - u._rawJSON = json.RawMessage(data) - return nil + return t.Index } -func (u *UsageBilledUnits) String() string { - if len(u._rawJSON) > 0 { - if value, err := core.StringifyJSON(u._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(u); err == nil { - return value +func (t *ToolCallDelta) GetParameters() *string { + if t == nil { + return nil } - return fmt.Sprintf("%#v", u) + return t.Parameters } -type UsageTokens struct { - // The number of tokens used as input to the model. - InputTokens *float64 `json:"input_tokens,omitempty" url:"input_tokens,omitempty"` - // The number of tokens produced by the model. - OutputTokens *float64 `json:"output_tokens,omitempty" url:"output_tokens,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage +func (t *ToolCallDelta) GetText() *string { + if t == nil { + return nil + } + return t.Text } -func (u *UsageTokens) GetExtraProperties() map[string]interface{} { - return u.extraProperties +func (t *ToolCallDelta) GetExtraProperties() map[string]interface{} { + return t.extraProperties } -func (u *UsageTokens) UnmarshalJSON(data []byte) error { - type unmarshaler UsageTokens +func (t *ToolCallDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ToolCallDelta var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *u = UsageTokens(value) - - extraProperties, err := core.ExtractExtraProperties(data, *u) + *t = ToolCallDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } - u.extraProperties = extraProperties - - u._rawJSON = json.RawMessage(data) + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) return nil } -func (u *UsageTokens) String() string { - if len(u._rawJSON) > 0 { - if value, err := core.StringifyJSON(u._rawJSON); err == nil { +func (t *ToolCallDelta) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(u); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } - return fmt.Sprintf("%#v", u) + return fmt.Sprintf("%#v", t) } -// A message from the user. -type UserMessage struct { - // The content of the message. This can be a string or a list of content blocks. - // If a string is provided, it will be treated as a text content block. - Content *UserMessageContent `json:"content,omitempty" url:"content,omitempty"` +// Represents tool result in the chat history. +type ToolMessage struct { + ToolResults []*ToolResult `json:"tool_results,omitempty" url:"tool_results,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (t *ToolMessage) GetToolResults() []*ToolResult { + if t == nil { + return nil + } + return t.ToolResults } -func (u *UserMessage) GetExtraProperties() map[string]interface{} { - return u.extraProperties +func (t *ToolMessage) GetExtraProperties() map[string]interface{} { + return t.extraProperties } -func (u *UserMessage) UnmarshalJSON(data []byte) error { - type unmarshaler UserMessage +func (t *ToolMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ToolMessage var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *u = UserMessage(value) - - extraProperties, err := core.ExtractExtraProperties(data, *u) + *t = ToolMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } - u.extraProperties = extraProperties - - u._rawJSON = json.RawMessage(data) + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) return nil } -func (u *UserMessage) String() string { - if len(u._rawJSON) > 0 { - if value, err := core.StringifyJSON(u._rawJSON); err == nil { +func (t *ToolMessage) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(u); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } - return fmt.Sprintf("%#v", u) + return fmt.Sprintf("%#v", t) } -// The content of the message. This can be a string or a list of content blocks. -// If a string is provided, it will be treated as a text content block. -type UserMessageContent struct { - String string - ContentList []*Content +type ToolParameterDefinitionsValue struct { + // The description of the parameter. + Description *string `json:"description,omitempty" url:"description,omitempty"` + // The type of the parameter. Must be a valid Python type. + Type string `json:"type" url:"type"` + // Denotes whether the parameter is always present (required) or not. Defaults to not required. + Required *bool `json:"required,omitempty" url:"required,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage } -func (u *UserMessageContent) UnmarshalJSON(data []byte) error { - var valueString string - if err := json.Unmarshal(data, &valueString); err == nil { - u.String = valueString +func (t *ToolParameterDefinitionsValue) GetDescription() *string { + if t == nil { return nil } - var valueContentList []*Content - if err := json.Unmarshal(data, &valueContentList); err == nil { - u.ContentList = valueContentList - return nil - } - return fmt.Errorf("%s cannot be deserialized as a %T", data, u) + return t.Description } -func (u UserMessageContent) MarshalJSON() ([]byte, error) { - if u.String != "" { - return json.Marshal(u.String) +func (t *ToolParameterDefinitionsValue) GetType() string { + if t == nil { + return "" } - if u.ContentList != nil { - return json.Marshal(u.ContentList) - } - return nil, fmt.Errorf("type %T does not include a non-empty union type", u) -} - -type UserMessageContentVisitor interface { - VisitString(string) error - VisitContentList([]*Content) error + return t.Type } -func (u *UserMessageContent) Accept(visitor UserMessageContentVisitor) error { - if u.String != "" { - return visitor.VisitString(u.String) - } - if u.ContentList != nil { - return visitor.VisitContentList(u.ContentList) +func (t *ToolParameterDefinitionsValue) GetRequired() *bool { + if t == nil { + return nil } - return fmt.Errorf("type %T does not include a non-empty union type", u) -} - -// the underlying files that make up the dataset -type DatasetsCreateResponseDatasetPartsItem struct { - // the name of the dataset part - Name *string `json:"name,omitempty" url:"name,omitempty"` - // the number of rows in the dataset part - NumRows *float64 `json:"num_rows,omitempty" url:"num_rows,omitempty"` - Samples []string `json:"samples,omitempty" url:"samples,omitempty"` - // the kind of dataset part - PartKind *string `json:"part_kind,omitempty" url:"part_kind,omitempty"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + return t.Required } -func (d *DatasetsCreateResponseDatasetPartsItem) GetExtraProperties() map[string]interface{} { - return d.extraProperties +func (t *ToolParameterDefinitionsValue) GetExtraProperties() map[string]interface{} { + return t.extraProperties } -func (d *DatasetsCreateResponseDatasetPartsItem) UnmarshalJSON(data []byte) error { - type unmarshaler DatasetsCreateResponseDatasetPartsItem +func (t *ToolParameterDefinitionsValue) UnmarshalJSON(data []byte) error { + type unmarshaler ToolParameterDefinitionsValue var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *d = DatasetsCreateResponseDatasetPartsItem(value) - - extraProperties, err := core.ExtractExtraProperties(data, *d) + *t = ToolParameterDefinitionsValue(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } - d.extraProperties = extraProperties - - d._rawJSON = json.RawMessage(data) + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) return nil } -func (d *DatasetsCreateResponseDatasetPartsItem) String() string { - if len(d._rawJSON) > 0 { - if value, err := core.StringifyJSON(d._rawJSON); err == nil { +func (t *ToolParameterDefinitionsValue) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(d); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } - return fmt.Sprintf("%#v", d) + return fmt.Sprintf("%#v", t) } -type V2RerankResponseResultsItem struct { - // If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in - Document *V2RerankResponseResultsItemDocument `json:"document,omitempty" url:"document,omitempty"` - // Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance) - Index int `json:"index" url:"index"` - // Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45 - RelevanceScore float64 `json:"relevance_score" url:"relevance_score"` +type ToolResult struct { + Call *ToolCall `json:"call,omitempty" url:"call,omitempty"` + Outputs []map[string]interface{} `json:"outputs,omitempty" url:"outputs,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage -} - -func (v *V2RerankResponseResultsItem) GetExtraProperties() map[string]interface{} { - return v.extraProperties + rawJSON json.RawMessage } -func (v *V2RerankResponseResultsItem) UnmarshalJSON(data []byte) error { - type unmarshaler V2RerankResponseResultsItem - var value unmarshaler - if err := json.Unmarshal(data, &value); err != nil { - return err - } - *v = V2RerankResponseResultsItem(value) - - extraProperties, err := core.ExtractExtraProperties(data, *v) - if err != nil { - return err +func (t *ToolResult) GetCall() *ToolCall { + if t == nil { + return nil } - v.extraProperties = extraProperties - - v._rawJSON = json.RawMessage(data) - return nil + return t.Call } -func (v *V2RerankResponseResultsItem) String() string { - if len(v._rawJSON) > 0 { - if value, err := core.StringifyJSON(v._rawJSON); err == nil { - return value - } - } - if value, err := core.StringifyJSON(v); err == nil { - return value +func (t *ToolResult) GetOutputs() []map[string]interface{} { + if t == nil { + return nil } - return fmt.Sprintf("%#v", v) -} - -// If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in -type V2RerankResponseResultsItemDocument struct { - // The text of the document to rerank - Text string `json:"text" url:"text"` - - extraProperties map[string]interface{} - _rawJSON json.RawMessage + return t.Outputs } -func (v *V2RerankResponseResultsItemDocument) GetExtraProperties() map[string]interface{} { - return v.extraProperties +func (t *ToolResult) GetExtraProperties() map[string]interface{} { + return t.extraProperties } -func (v *V2RerankResponseResultsItemDocument) UnmarshalJSON(data []byte) error { - type unmarshaler V2RerankResponseResultsItemDocument +func (t *ToolResult) UnmarshalJSON(data []byte) error { + type unmarshaler ToolResult var value unmarshaler if err := json.Unmarshal(data, &value); err != nil { return err } - *v = V2RerankResponseResultsItemDocument(value) - - extraProperties, err := core.ExtractExtraProperties(data, *v) + *t = ToolResult(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) if err != nil { return err } - v.extraProperties = extraProperties - - v._rawJSON = json.RawMessage(data) + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) return nil } -func (v *V2RerankResponseResultsItemDocument) String() string { - if len(v._rawJSON) > 0 { - if value, err := core.StringifyJSON(v._rawJSON); err == nil { +func (t *ToolResult) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(v); err == nil { + if value, err := internal.StringifyJSON(t); err == nil { return value } - return fmt.Sprintf("%#v", v) + return fmt.Sprintf("%#v", t) } diff --git a/v2/client.go b/v2/client.go index 9a6b7e0..f7a78d5 100644 --- a/v2/client.go +++ b/v2/client.go @@ -3,21 +3,18 @@ package v2 import ( - bytes "bytes" context "context" - json "encoding/json" - errors "errors" v2 "github.com/cohere-ai/cohere-go/v2" core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" option "github.com/cohere-ai/cohere-go/v2/option" - io "io" http "net/http" os "os" ) type Client struct { baseURL string - caller *core.Caller + caller *internal.Caller header http.Header } @@ -28,8 +25,8 @@ func NewClient(opts ...option.RequestOption) *Client { } return &Client{ baseURL: options.BaseURL, - caller: core.NewCaller( - &core.CallerParams{ + caller: internal.NewCaller( + &internal.CallerParams{ Client: options.HTTPClient, MaxAttempts: options.MaxAttempts, }, @@ -38,7 +35,7 @@ func NewClient(opts ...option.RequestOption) *Client { } } -// Generates a text response to a user message and streams it down, token by token. To learn how to use the Chat API with streaming follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api). +// Generates a text response to a user message. To learn how to use the Chat API and RAG follow our [Text Generation guides](https://docs.cohere.com/v2/docs/chat-api). // // Follow the [Migration Guide](https://docs.cohere.com/v2/docs/migrating-v1-to-v2) for instructions on moving from API v1 to API v2. func (c *Client) ChatStream( @@ -47,123 +44,96 @@ func (c *Client) ChatStream( opts ...option.RequestOption, ) (*core.Stream[v2.StreamedChatResponseV2], error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v2/chat" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) headers.Set("Accept", "text/event-stream") - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } - streamer := core.NewStreamer[v2.StreamedChatResponseV2](c.caller) + streamer := internal.NewStreamer[v2.StreamedChatResponseV2](c.caller) return streamer.Stream( ctx, - &core.StreamParams{ + &internal.StreamParams{ URL: endpointURL, Method: http.MethodPost, - Prefix: core.DefaultSSEDataPrefix, - Terminator: core.DefaultSSETerminator, + Headers: headers, + Prefix: internal.DefaultSSEDataPrefix, + Terminator: internal.DefaultSSETerminator, MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, - Headers: headers, Client: options.HTTPClient, Request: request, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ) } @@ -177,121 +147,94 @@ func (c *Client) Chat( opts ...option.RequestOption, ) (*v2.ChatResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v2/chat" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.ChatResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -310,121 +253,94 @@ func (c *Client) Embed( opts ...option.RequestOption, ) (*v2.EmbedByTypeResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v2/embed" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.EmbedByTypeResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err @@ -439,121 +355,94 @@ func (c *Client) Rerank( opts ...option.RequestOption, ) (*v2.V2RerankResponse, error) { options := core.NewRequestOptions(opts...) - - baseURL := "https://api.cohere.com" - if c.baseURL != "" { - baseURL = c.baseURL - } - if options.BaseURL != "" { - baseURL = options.BaseURL - } + baseURL := internal.ResolveBaseURL( + options.BaseURL, + c.baseURL, + "https://api.cohere.com", + ) endpointURL := baseURL + "/v2/rerank" - - headers := core.MergeHeaders(c.header.Clone(), options.ToHeader()) - - errorDecoder := func(statusCode int, body io.Reader) error { - raw, err := io.ReadAll(body) - if err != nil { - return err - } - apiError := core.NewAPIError(statusCode, errors.New(string(raw))) - decoder := json.NewDecoder(bytes.NewReader(raw)) - switch statusCode { - case 400: - value := new(v2.BadRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 401: - value := new(v2.UnauthorizedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 403: - value := new(v2.ForbiddenError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 404: - value := new(v2.NotFoundError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 422: - value := new(v2.UnprocessableEntityError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 429: - value := new(v2.TooManyRequestsError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 499: - value := new(v2.ClientClosedRequestError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 500: - value := new(v2.InternalServerError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 501: - value := new(v2.NotImplementedError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 503: - value := new(v2.ServiceUnavailableError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - case 504: - value := new(v2.GatewayTimeoutError) - value.APIError = apiError - if err := decoder.Decode(value); err != nil { - return apiError - } - return value - } - return apiError + headers := internal.MergeHeaders( + c.header.Clone(), + options.ToHeader(), + ) + headers.Set("Content-Type", "application/json") + errorCodes := internal.ErrorCodes{ + 400: func(apiError *core.APIError) error { + return &v2.BadRequestError{ + APIError: apiError, + } + }, + 401: func(apiError *core.APIError) error { + return &v2.UnauthorizedError{ + APIError: apiError, + } + }, + 403: func(apiError *core.APIError) error { + return &v2.ForbiddenError{ + APIError: apiError, + } + }, + 404: func(apiError *core.APIError) error { + return &v2.NotFoundError{ + APIError: apiError, + } + }, + 422: func(apiError *core.APIError) error { + return &v2.UnprocessableEntityError{ + APIError: apiError, + } + }, + 429: func(apiError *core.APIError) error { + return &v2.TooManyRequestsError{ + APIError: apiError, + } + }, + 498: func(apiError *core.APIError) error { + return &v2.InvalidTokenError{ + APIError: apiError, + } + }, + 499: func(apiError *core.APIError) error { + return &v2.ClientClosedRequestError{ + APIError: apiError, + } + }, + 500: func(apiError *core.APIError) error { + return &v2.InternalServerError{ + APIError: apiError, + } + }, + 501: func(apiError *core.APIError) error { + return &v2.NotImplementedError{ + APIError: apiError, + } + }, + 503: func(apiError *core.APIError) error { + return &v2.ServiceUnavailableError{ + APIError: apiError, + } + }, + 504: func(apiError *core.APIError) error { + return &v2.GatewayTimeoutError{ + APIError: apiError, + } + }, } var response *v2.V2RerankResponse if err := c.caller.Call( ctx, - &core.CallParams{ + &internal.CallParams{ URL: endpointURL, Method: http.MethodPost, - MaxAttempts: options.MaxAttempts, Headers: headers, + MaxAttempts: options.MaxAttempts, BodyProperties: options.BodyProperties, QueryParameters: options.QueryParameters, Client: options.HTTPClient, Request: request, Response: &response, - ErrorDecoder: errorDecoder, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), }, ); err != nil { return nil, err diff --git a/v_2.go b/v_2.go index 1d2f02f..1dea41a 100644 --- a/v_2.go +++ b/v_2.go @@ -5,7 +5,7 @@ package api import ( json "encoding/json" fmt "fmt" - core "github.com/cohere-ai/cohere-go/v2/core" + internal "github.com/cohere-ai/cohere-go/v2/internal" ) type V2ChatRequest struct { @@ -260,19 +260,4490 @@ type V2RerankRequest struct { MaxTokensPerDoc *int `json:"max_tokens_per_doc,omitempty" url:"-"` } +// A message from the assistant role can contain text and tool call information. +type AssistantMessage struct { + ToolCalls []*ToolCallV2 `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` + // A chain-of-thought style reflection and plan that the model generates when working with Tools. + ToolPlan *string `json:"tool_plan,omitempty" url:"tool_plan,omitempty"` + Content *AssistantMessageContent `json:"content,omitempty" url:"content,omitempty"` + Citations []*Citation `json:"citations,omitempty" url:"citations,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (a *AssistantMessage) GetToolCalls() []*ToolCallV2 { + if a == nil { + return nil + } + return a.ToolCalls +} + +func (a *AssistantMessage) GetToolPlan() *string { + if a == nil { + return nil + } + return a.ToolPlan +} + +func (a *AssistantMessage) GetContent() *AssistantMessageContent { + if a == nil { + return nil + } + return a.Content +} + +func (a *AssistantMessage) GetCitations() []*Citation { + if a == nil { + return nil + } + return a.Citations +} + +func (a *AssistantMessage) GetExtraProperties() map[string]interface{} { + return a.extraProperties +} + +func (a *AssistantMessage) UnmarshalJSON(data []byte) error { + type unmarshaler AssistantMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *a = AssistantMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *a) + if err != nil { + return err + } + a.extraProperties = extraProperties + a.rawJSON = json.RawMessage(data) + return nil +} + +func (a *AssistantMessage) String() string { + if len(a.rawJSON) > 0 { + if value, err := internal.StringifyJSON(a.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(a); err == nil { + return value + } + return fmt.Sprintf("%#v", a) +} + +type AssistantMessageContent struct { + String string + AssistantMessageContentItemList []*AssistantMessageContentItem + + typ string +} + +func (a *AssistantMessageContent) GetString() string { + if a == nil { + return "" + } + return a.String +} + +func (a *AssistantMessageContent) GetAssistantMessageContentItemList() []*AssistantMessageContentItem { + if a == nil { + return nil + } + return a.AssistantMessageContentItemList +} + +func (a *AssistantMessageContent) UnmarshalJSON(data []byte) error { + var valueString string + if err := json.Unmarshal(data, &valueString); err == nil { + a.typ = "String" + a.String = valueString + return nil + } + var valueAssistantMessageContentItemList []*AssistantMessageContentItem + if err := json.Unmarshal(data, &valueAssistantMessageContentItemList); err == nil { + a.typ = "AssistantMessageContentItemList" + a.AssistantMessageContentItemList = valueAssistantMessageContentItemList + return nil + } + return fmt.Errorf("%s cannot be deserialized as a %T", data, a) +} + +func (a AssistantMessageContent) MarshalJSON() ([]byte, error) { + if a.typ == "String" || a.String != "" { + return json.Marshal(a.String) + } + if a.typ == "AssistantMessageContentItemList" || a.AssistantMessageContentItemList != nil { + return json.Marshal(a.AssistantMessageContentItemList) + } + return nil, fmt.Errorf("type %T does not include a non-empty union type", a) +} + +type AssistantMessageContentVisitor interface { + VisitString(string) error + VisitAssistantMessageContentItemList([]*AssistantMessageContentItem) error +} + +func (a *AssistantMessageContent) Accept(visitor AssistantMessageContentVisitor) error { + if a.typ == "String" || a.String != "" { + return visitor.VisitString(a.String) + } + if a.typ == "AssistantMessageContentItemList" || a.AssistantMessageContentItemList != nil { + return visitor.VisitAssistantMessageContentItemList(a.AssistantMessageContentItemList) + } + return fmt.Errorf("type %T does not include a non-empty union type", a) +} + +type AssistantMessageContentItem struct { + Type string + Text *TextContent +} + +func (a *AssistantMessageContentItem) GetType() string { + if a == nil { + return "" + } + return a.Type +} + +func (a *AssistantMessageContentItem) GetText() *TextContent { + if a == nil { + return nil + } + return a.Text +} + +func (a *AssistantMessageContentItem) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + a.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", a) + } + switch unmarshaler.Type { + case "text": + value := new(TextContent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + a.Text = value + } + return nil +} + +func (a AssistantMessageContentItem) MarshalJSON() ([]byte, error) { + if err := a.validate(); err != nil { + return nil, err + } + if a.Text != nil { + return internal.MarshalJSONWithExtraProperty(a.Text, "type", "text") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", a) +} + +type AssistantMessageContentItemVisitor interface { + VisitText(*TextContent) error +} + +func (a *AssistantMessageContentItem) Accept(visitor AssistantMessageContentItemVisitor) error { + if a.Text != nil { + return visitor.VisitText(a.Text) + } + return fmt.Errorf("type %T does not define a non-empty union type", a) +} + +func (a *AssistantMessageContentItem) validate() error { + if a == nil { + return fmt.Errorf("type %T is nil", a) + } + var fields []string + if a.Text != nil { + fields = append(fields, "text") + } + if len(fields) == 0 { + if a.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", a, a.Type) + } + return fmt.Errorf("type %T is empty", a) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", a, fields) + } + if a.Type != "" { + field := fields[0] + if a.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + a, + a.Type, + a, + ) + } + } + return nil +} + +// A message from the assistant role can contain text and tool call information. +type AssistantMessageResponse struct { + ToolCalls []*ToolCallV2 `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` + // A chain-of-thought style reflection and plan that the model generates when working with Tools. + ToolPlan *string `json:"tool_plan,omitempty" url:"tool_plan,omitempty"` + Content []*AssistantMessageResponseContentItem `json:"content,omitempty" url:"content,omitempty"` + Citations []*Citation `json:"citations,omitempty" url:"citations,omitempty"` + role string + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (a *AssistantMessageResponse) GetToolCalls() []*ToolCallV2 { + if a == nil { + return nil + } + return a.ToolCalls +} + +func (a *AssistantMessageResponse) GetToolPlan() *string { + if a == nil { + return nil + } + return a.ToolPlan +} + +func (a *AssistantMessageResponse) GetContent() []*AssistantMessageResponseContentItem { + if a == nil { + return nil + } + return a.Content +} + +func (a *AssistantMessageResponse) GetCitations() []*Citation { + if a == nil { + return nil + } + return a.Citations +} + +func (a *AssistantMessageResponse) Role() string { + return a.role +} + +func (a *AssistantMessageResponse) GetExtraProperties() map[string]interface{} { + return a.extraProperties +} + +func (a *AssistantMessageResponse) UnmarshalJSON(data []byte) error { + type embed AssistantMessageResponse + var unmarshaler = struct { + embed + Role string `json:"role"` + }{ + embed: embed(*a), + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + *a = AssistantMessageResponse(unmarshaler.embed) + if unmarshaler.Role != "assistant" { + return fmt.Errorf("unexpected value for literal on type %T; expected %v got %v", a, "assistant", unmarshaler.Role) + } + a.role = unmarshaler.Role + extraProperties, err := internal.ExtractExtraProperties(data, *a, "role") + if err != nil { + return err + } + a.extraProperties = extraProperties + a.rawJSON = json.RawMessage(data) + return nil +} + +func (a *AssistantMessageResponse) MarshalJSON() ([]byte, error) { + type embed AssistantMessageResponse + var marshaler = struct { + embed + Role string `json:"role"` + }{ + embed: embed(*a), + Role: "assistant", + } + return json.Marshal(marshaler) +} + +func (a *AssistantMessageResponse) String() string { + if len(a.rawJSON) > 0 { + if value, err := internal.StringifyJSON(a.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(a); err == nil { + return value + } + return fmt.Sprintf("%#v", a) +} + +type AssistantMessageResponseContentItem struct { + Type string + Text *TextContent +} + +func (a *AssistantMessageResponseContentItem) GetType() string { + if a == nil { + return "" + } + return a.Type +} + +func (a *AssistantMessageResponseContentItem) GetText() *TextContent { + if a == nil { + return nil + } + return a.Text +} + +func (a *AssistantMessageResponseContentItem) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + a.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", a) + } + switch unmarshaler.Type { + case "text": + value := new(TextContent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + a.Text = value + } + return nil +} + +func (a AssistantMessageResponseContentItem) MarshalJSON() ([]byte, error) { + if err := a.validate(); err != nil { + return nil, err + } + if a.Text != nil { + return internal.MarshalJSONWithExtraProperty(a.Text, "type", "text") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", a) +} + +type AssistantMessageResponseContentItemVisitor interface { + VisitText(*TextContent) error +} + +func (a *AssistantMessageResponseContentItem) Accept(visitor AssistantMessageResponseContentItemVisitor) error { + if a.Text != nil { + return visitor.VisitText(a.Text) + } + return fmt.Errorf("type %T does not define a non-empty union type", a) +} + +func (a *AssistantMessageResponseContentItem) validate() error { + if a == nil { + return fmt.Errorf("type %T is nil", a) + } + var fields []string + if a.Text != nil { + fields = append(fields, "text") + } + if len(fields) == 0 { + if a.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", a, a.Type) + } + return fmt.Errorf("type %T is empty", a) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", a, fields) + } + if a.Type != "" { + field := fields[0] + if a.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + a, + a.Type, + a, + ) + } + } + return nil +} + +// A streamed delta event which contains a delta of chat text content. +type ChatContentDeltaEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + Delta *ChatContentDeltaEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + Logprobs *LogprobItem `json:"logprobs,omitempty" url:"logprobs,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentDeltaEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *ChatContentDeltaEvent) GetDelta() *ChatContentDeltaEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *ChatContentDeltaEvent) GetLogprobs() *LogprobItem { + if c == nil { + return nil + } + return c.Logprobs +} + +func (c *ChatContentDeltaEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentDeltaEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentDeltaEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentDeltaEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentDeltaEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatContentDeltaEventDelta struct { + Message *ChatContentDeltaEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentDeltaEventDelta) GetMessage() *ChatContentDeltaEventDeltaMessage { + if c == nil { + return nil + } + return c.Message +} + +func (c *ChatContentDeltaEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentDeltaEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentDeltaEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentDeltaEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentDeltaEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatContentDeltaEventDeltaMessage struct { + Content *ChatContentDeltaEventDeltaMessageContent `json:"content,omitempty" url:"content,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentDeltaEventDeltaMessage) GetContent() *ChatContentDeltaEventDeltaMessageContent { + if c == nil { + return nil + } + return c.Content +} + +func (c *ChatContentDeltaEventDeltaMessage) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentDeltaEventDeltaMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentDeltaEventDeltaMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentDeltaEventDeltaMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentDeltaEventDeltaMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatContentDeltaEventDeltaMessageContent struct { + Text *string `json:"text,omitempty" url:"text,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentDeltaEventDeltaMessageContent) GetText() *string { + if c == nil { + return nil + } + return c.Text +} + +func (c *ChatContentDeltaEventDeltaMessageContent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentDeltaEventDeltaMessageContent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentDeltaEventDeltaMessageContent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentDeltaEventDeltaMessageContent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentDeltaEventDeltaMessageContent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed delta event which signifies that the content block has ended. +type ChatContentEndEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentEndEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *ChatContentEndEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentEndEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentEndEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentEndEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentEndEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed delta event which signifies that a new content block has started. +type ChatContentStartEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + Delta *ChatContentStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentStartEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *ChatContentStartEvent) GetDelta() *ChatContentStartEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *ChatContentStartEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentStartEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentStartEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentStartEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentStartEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatContentStartEventDelta struct { + Message *ChatContentStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentStartEventDelta) GetMessage() *ChatContentStartEventDeltaMessage { + if c == nil { + return nil + } + return c.Message +} + +func (c *ChatContentStartEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentStartEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentStartEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentStartEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentStartEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatContentStartEventDeltaMessage struct { + Content *ChatContentStartEventDeltaMessageContent `json:"content,omitempty" url:"content,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentStartEventDeltaMessage) GetContent() *ChatContentStartEventDeltaMessageContent { + if c == nil { + return nil + } + return c.Content +} + +func (c *ChatContentStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentStartEventDeltaMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentStartEventDeltaMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentStartEventDeltaMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentStartEventDeltaMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatContentStartEventDeltaMessageContent struct { + Text *string `json:"text,omitempty" url:"text,omitempty"` + Type *string `json:"type,omitempty" url:"type,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatContentStartEventDeltaMessageContent) GetText() *string { + if c == nil { + return nil + } + return c.Text +} + +func (c *ChatContentStartEventDeltaMessageContent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatContentStartEventDeltaMessageContent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatContentStartEventDeltaMessageContent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatContentStartEventDeltaMessageContent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatContentStartEventDeltaMessageContent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// The reason a chat request has finished. +// +// - **complete**: The model finished sending a complete message. +// - **max_tokens**: The number of generated tokens exceeded the model's context length or the value specified via the `max_tokens` parameter. +// - **stop_sequence**: One of the provided `stop_sequence` entries was reached in the model's generation. +// - **tool_call**: The model generated a Tool Call and is expecting a Tool Message in return +// - **error**: The generation failed due to an internal error +type ChatFinishReason string + +const ( + ChatFinishReasonComplete ChatFinishReason = "COMPLETE" + ChatFinishReasonStopSequence ChatFinishReason = "STOP_SEQUENCE" + ChatFinishReasonMaxTokens ChatFinishReason = "MAX_TOKENS" + ChatFinishReasonToolCall ChatFinishReason = "TOOL_CALL" + ChatFinishReasonError ChatFinishReason = "ERROR" +) + +func NewChatFinishReasonFromString(s string) (ChatFinishReason, error) { + switch s { + case "COMPLETE": + return ChatFinishReasonComplete, nil + case "STOP_SEQUENCE": + return ChatFinishReasonStopSequence, nil + case "MAX_TOKENS": + return ChatFinishReasonMaxTokens, nil + case "TOOL_CALL": + return ChatFinishReasonToolCall, nil + case "ERROR": + return ChatFinishReasonError, nil + } + var t ChatFinishReason + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (c ChatFinishReason) Ptr() *ChatFinishReason { + return &c +} + +// A streamed event which signifies that the chat message has ended. +type ChatMessageEndEvent struct { + Id *string `json:"id,omitempty" url:"id,omitempty"` + Delta *ChatMessageEndEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatMessageEndEvent) GetId() *string { + if c == nil { + return nil + } + return c.Id +} + +func (c *ChatMessageEndEvent) GetDelta() *ChatMessageEndEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *ChatMessageEndEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatMessageEndEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatMessageEndEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatMessageEndEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatMessageEndEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatMessageEndEventDelta struct { + FinishReason *ChatFinishReason `json:"finish_reason,omitempty" url:"finish_reason,omitempty"` + Usage *Usage `json:"usage,omitempty" url:"usage,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatMessageEndEventDelta) GetFinishReason() *ChatFinishReason { + if c == nil { + return nil + } + return c.FinishReason +} + +func (c *ChatMessageEndEventDelta) GetUsage() *Usage { + if c == nil { + return nil + } + return c.Usage +} + +func (c *ChatMessageEndEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatMessageEndEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ChatMessageEndEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatMessageEndEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatMessageEndEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed event which signifies that a stream has started. +type ChatMessageStartEvent struct { + // Unique identifier for the generated reply. + Id *string `json:"id,omitempty" url:"id,omitempty"` + Delta *ChatMessageStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatMessageStartEvent) GetId() *string { + if c == nil { + return nil + } + return c.Id +} + +func (c *ChatMessageStartEvent) GetDelta() *ChatMessageStartEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *ChatMessageStartEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatMessageStartEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatMessageStartEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatMessageStartEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatMessageStartEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatMessageStartEventDelta struct { + Message *ChatMessageStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatMessageStartEventDelta) GetMessage() *ChatMessageStartEventDeltaMessage { + if c == nil { + return nil + } + return c.Message +} + +func (c *ChatMessageStartEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatMessageStartEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ChatMessageStartEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatMessageStartEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatMessageStartEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatMessageStartEventDeltaMessage struct { + // The role of the message. + Role *string `json:"role,omitempty" url:"role,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatMessageStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatMessageStartEventDeltaMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ChatMessageStartEventDeltaMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatMessageStartEventDeltaMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatMessageStartEventDeltaMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// Represents a single message in the chat history from a given role. +type ChatMessageV2 struct { + Role string + User *UserMessage + Assistant *AssistantMessage + System *SystemMessage + Tool *ToolMessageV2 +} + +func (c *ChatMessageV2) GetRole() string { + if c == nil { + return "" + } + return c.Role +} + +func (c *ChatMessageV2) GetUser() *UserMessage { + if c == nil { + return nil + } + return c.User +} + +func (c *ChatMessageV2) GetAssistant() *AssistantMessage { + if c == nil { + return nil + } + return c.Assistant +} + +func (c *ChatMessageV2) GetSystem() *SystemMessage { + if c == nil { + return nil + } + return c.System +} + +func (c *ChatMessageV2) GetTool() *ToolMessageV2 { + if c == nil { + return nil + } + return c.Tool +} + +func (c *ChatMessageV2) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Role string `json:"role"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + c.Role = unmarshaler.Role + if unmarshaler.Role == "" { + return fmt.Errorf("%T did not include discriminant role", c) + } + switch unmarshaler.Role { + case "user": + value := new(UserMessage) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + c.User = value + case "assistant": + value := new(AssistantMessage) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + c.Assistant = value + case "system": + value := new(SystemMessage) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + c.System = value + case "tool": + value := new(ToolMessageV2) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + c.Tool = value + } + return nil +} + +func (c ChatMessageV2) MarshalJSON() ([]byte, error) { + if err := c.validate(); err != nil { + return nil, err + } + if c.User != nil { + return internal.MarshalJSONWithExtraProperty(c.User, "role", "user") + } + if c.Assistant != nil { + return internal.MarshalJSONWithExtraProperty(c.Assistant, "role", "assistant") + } + if c.System != nil { + return internal.MarshalJSONWithExtraProperty(c.System, "role", "system") + } + if c.Tool != nil { + return internal.MarshalJSONWithExtraProperty(c.Tool, "role", "tool") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", c) +} + +type ChatMessageV2Visitor interface { + VisitUser(*UserMessage) error + VisitAssistant(*AssistantMessage) error + VisitSystem(*SystemMessage) error + VisitTool(*ToolMessageV2) error +} + +func (c *ChatMessageV2) Accept(visitor ChatMessageV2Visitor) error { + if c.User != nil { + return visitor.VisitUser(c.User) + } + if c.Assistant != nil { + return visitor.VisitAssistant(c.Assistant) + } + if c.System != nil { + return visitor.VisitSystem(c.System) + } + if c.Tool != nil { + return visitor.VisitTool(c.Tool) + } + return fmt.Errorf("type %T does not define a non-empty union type", c) +} + +func (c *ChatMessageV2) validate() error { + if c == nil { + return fmt.Errorf("type %T is nil", c) + } + var fields []string + if c.User != nil { + fields = append(fields, "user") + } + if c.Assistant != nil { + fields = append(fields, "assistant") + } + if c.System != nil { + fields = append(fields, "system") + } + if c.Tool != nil { + fields = append(fields, "tool") + } + if len(fields) == 0 { + if c.Role != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", c, c.Role) + } + return fmt.Errorf("type %T is empty", c) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", c, fields) + } + if c.Role != "" { + field := fields[0] + if c.Role != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + c, + c.Role, + c, + ) + } + } + return nil +} + +// A list of chat messages in chronological order, representing a conversation between the user and the model. +// +// Messages can be from `User`, `Assistant`, `Tool` and `System` roles. Learn more about messages and roles in [the Chat API guide](https://docs.cohere.com/v2/docs/chat-api). +type ChatMessages = []*ChatMessageV2 + +type ChatResponse struct { + // Unique identifier for the generated reply. Useful for submitting feedback. + Id string `json:"id" url:"id"` + FinishReason ChatFinishReason `json:"finish_reason" url:"finish_reason"` + // The prompt that was used. Only present when `return_prompt` in the request is set to true. + Prompt *string `json:"prompt,omitempty" url:"prompt,omitempty"` + Message *AssistantMessageResponse `json:"message,omitempty" url:"message,omitempty"` + Usage *Usage `json:"usage,omitempty" url:"usage,omitempty"` + Logprobs []*LogprobItem `json:"logprobs,omitempty" url:"logprobs,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatResponse) GetId() string { + if c == nil { + return "" + } + return c.Id +} + +func (c *ChatResponse) GetFinishReason() ChatFinishReason { + if c == nil { + return "" + } + return c.FinishReason +} + +func (c *ChatResponse) GetPrompt() *string { + if c == nil { + return nil + } + return c.Prompt +} + +func (c *ChatResponse) GetMessage() *AssistantMessageResponse { + if c == nil { + return nil + } + return c.Message +} + +func (c *ChatResponse) GetUsage() *Usage { + if c == nil { + return nil + } + return c.Usage +} + +func (c *ChatResponse) GetLogprobs() []*LogprobItem { + if c == nil { + return nil + } + return c.Logprobs +} + +func (c *ChatResponse) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatResponse) UnmarshalJSON(data []byte) error { + type unmarshaler ChatResponse + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatResponse) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// The streamed event types +type ChatStreamEventType struct { + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatStreamEventType) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatStreamEventType) UnmarshalJSON(data []byte) error { + type unmarshaler ChatStreamEventType + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatStreamEventType(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatStreamEventType) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed event delta which signifies a delta in tool call arguments. +type ChatToolCallDeltaEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + Delta *ChatToolCallDeltaEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallDeltaEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *ChatToolCallDeltaEvent) GetDelta() *ChatToolCallDeltaEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *ChatToolCallDeltaEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallDeltaEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallDeltaEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallDeltaEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallDeltaEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolCallDeltaEventDelta struct { + Message *ChatToolCallDeltaEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallDeltaEventDelta) GetMessage() *ChatToolCallDeltaEventDeltaMessage { + if c == nil { + return nil + } + return c.Message +} + +func (c *ChatToolCallDeltaEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallDeltaEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallDeltaEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallDeltaEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallDeltaEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolCallDeltaEventDeltaMessage struct { + ToolCalls *ChatToolCallDeltaEventDeltaMessageToolCalls `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallDeltaEventDeltaMessage) GetToolCalls() *ChatToolCallDeltaEventDeltaMessageToolCalls { + if c == nil { + return nil + } + return c.ToolCalls +} + +func (c *ChatToolCallDeltaEventDeltaMessage) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallDeltaEventDeltaMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallDeltaEventDeltaMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallDeltaEventDeltaMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallDeltaEventDeltaMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolCallDeltaEventDeltaMessageToolCalls struct { + Function *ChatToolCallDeltaEventDeltaMessageToolCallsFunction `json:"function,omitempty" url:"function,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCalls) GetFunction() *ChatToolCallDeltaEventDeltaMessageToolCallsFunction { + if c == nil { + return nil + } + return c.Function +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCalls) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCalls) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallDeltaEventDeltaMessageToolCalls + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallDeltaEventDeltaMessageToolCalls(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCalls) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolCallDeltaEventDeltaMessageToolCallsFunction struct { + Arguments *string `json:"arguments,omitempty" url:"arguments,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCallsFunction) GetArguments() *string { + if c == nil { + return nil + } + return c.Arguments +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCallsFunction) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCallsFunction) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallDeltaEventDeltaMessageToolCallsFunction + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallDeltaEventDeltaMessageToolCallsFunction(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallDeltaEventDeltaMessageToolCallsFunction) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed event delta which signifies a tool call has finished streaming. +type ChatToolCallEndEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallEndEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *ChatToolCallEndEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallEndEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallEndEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallEndEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallEndEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed event delta which signifies a tool call has started streaming. +type ChatToolCallStartEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + Delta *ChatToolCallStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallStartEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *ChatToolCallStartEvent) GetDelta() *ChatToolCallStartEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *ChatToolCallStartEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallStartEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallStartEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallStartEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallStartEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolCallStartEventDelta struct { + Message *ChatToolCallStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallStartEventDelta) GetMessage() *ChatToolCallStartEventDeltaMessage { + if c == nil { + return nil + } + return c.Message +} + +func (c *ChatToolCallStartEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallStartEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallStartEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallStartEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallStartEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolCallStartEventDeltaMessage struct { + ToolCalls *ToolCallV2 `json:"tool_calls,omitempty" url:"tool_calls,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolCallStartEventDeltaMessage) GetToolCalls() *ToolCallV2 { + if c == nil { + return nil + } + return c.ToolCalls +} + +func (c *ChatToolCallStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolCallStartEventDeltaMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolCallStartEventDeltaMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolCallStartEventDeltaMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolCallStartEventDeltaMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed event which contains a delta of tool plan text. +type ChatToolPlanDeltaEvent struct { + Delta *ChatToolPlanDeltaEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolPlanDeltaEvent) GetDelta() *ChatToolPlanDeltaEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *ChatToolPlanDeltaEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolPlanDeltaEvent) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolPlanDeltaEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolPlanDeltaEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolPlanDeltaEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolPlanDeltaEventDelta struct { + Message *ChatToolPlanDeltaEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolPlanDeltaEventDelta) GetMessage() *ChatToolPlanDeltaEventDeltaMessage { + if c == nil { + return nil + } + return c.Message +} + +func (c *ChatToolPlanDeltaEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolPlanDeltaEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolPlanDeltaEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolPlanDeltaEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolPlanDeltaEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type ChatToolPlanDeltaEventDeltaMessage struct { + ToolPlan *string `json:"tool_plan,omitempty" url:"tool_plan,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *ChatToolPlanDeltaEventDeltaMessage) GetToolPlan() *string { + if c == nil { + return nil + } + return c.ToolPlan +} + +func (c *ChatToolPlanDeltaEventDeltaMessage) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *ChatToolPlanDeltaEventDeltaMessage) UnmarshalJSON(data []byte) error { + type unmarshaler ChatToolPlanDeltaEventDeltaMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = ChatToolPlanDeltaEventDeltaMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *ChatToolPlanDeltaEventDeltaMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// Citation information containing sources and the text cited. +type Citation struct { + // Start index of the cited snippet in the original source text. + Start *int `json:"start,omitempty" url:"start,omitempty"` + // End index of the cited snippet in the original source text. + End *int `json:"end,omitempty" url:"end,omitempty"` + // Text snippet that is being cited. + Text *string `json:"text,omitempty" url:"text,omitempty"` + Sources []*Source `json:"sources,omitempty" url:"sources,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *Citation) GetStart() *int { + if c == nil { + return nil + } + return c.Start +} + +func (c *Citation) GetEnd() *int { + if c == nil { + return nil + } + return c.End +} + +func (c *Citation) GetText() *string { + if c == nil { + return nil + } + return c.Text +} + +func (c *Citation) GetSources() []*Source { + if c == nil { + return nil + } + return c.Sources +} + +func (c *Citation) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *Citation) UnmarshalJSON(data []byte) error { + type unmarshaler Citation + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = Citation(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *Citation) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A streamed event which signifies a citation has finished streaming. +type CitationEndEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CitationEndEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *CitationEndEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CitationEndEvent) UnmarshalJSON(data []byte) error { + type unmarshaler CitationEndEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CitationEndEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CitationEndEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// Options for controlling citation generation. +type CitationOptions struct { + // Defaults to `"accurate"`. + // Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. + // + // **Note**: `command-r7b-12-2024` only supports `"fast"` and `"off"` modes. Its default is `"fast"`. + Mode *CitationOptionsMode `json:"mode,omitempty" url:"mode,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CitationOptions) GetMode() *CitationOptionsMode { + if c == nil { + return nil + } + return c.Mode +} + +func (c *CitationOptions) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CitationOptions) UnmarshalJSON(data []byte) error { + type unmarshaler CitationOptions + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CitationOptions(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CitationOptions) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// Defaults to `"accurate"`. +// Dictates the approach taken to generating citations as part of the RAG flow by allowing the user to specify whether they want `"accurate"` results, `"fast"` results or no results. +// +// **Note**: `command-r7b-12-2024` only supports `"fast"` and `"off"` modes. Its default is `"fast"`. +type CitationOptionsMode string + +const ( + CitationOptionsModeFast CitationOptionsMode = "FAST" + CitationOptionsModeAccurate CitationOptionsMode = "ACCURATE" + CitationOptionsModeOff CitationOptionsMode = "OFF" +) + +func NewCitationOptionsModeFromString(s string) (CitationOptionsMode, error) { + switch s { + case "FAST": + return CitationOptionsModeFast, nil + case "ACCURATE": + return CitationOptionsModeAccurate, nil + case "OFF": + return CitationOptionsModeOff, nil + } + var t CitationOptionsMode + return "", fmt.Errorf("%s is not a valid %T", s, t) +} + +func (c CitationOptionsMode) Ptr() *CitationOptionsMode { + return &c +} + +// A streamed event which signifies a citation has been created. +type CitationStartEvent struct { + Index *int `json:"index,omitempty" url:"index,omitempty"` + Delta *CitationStartEventDelta `json:"delta,omitempty" url:"delta,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CitationStartEvent) GetIndex() *int { + if c == nil { + return nil + } + return c.Index +} + +func (c *CitationStartEvent) GetDelta() *CitationStartEventDelta { + if c == nil { + return nil + } + return c.Delta +} + +func (c *CitationStartEvent) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CitationStartEvent) UnmarshalJSON(data []byte) error { + type unmarshaler CitationStartEvent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CitationStartEvent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CitationStartEvent) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type CitationStartEventDelta struct { + Message *CitationStartEventDeltaMessage `json:"message,omitempty" url:"message,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CitationStartEventDelta) GetMessage() *CitationStartEventDeltaMessage { + if c == nil { + return nil + } + return c.Message +} + +func (c *CitationStartEventDelta) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CitationStartEventDelta) UnmarshalJSON(data []byte) error { + type unmarshaler CitationStartEventDelta + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CitationStartEventDelta(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CitationStartEventDelta) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +type CitationStartEventDeltaMessage struct { + Citations *Citation `json:"citations,omitempty" url:"citations,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (c *CitationStartEventDeltaMessage) GetCitations() *Citation { + if c == nil { + return nil + } + return c.Citations +} + +func (c *CitationStartEventDeltaMessage) GetExtraProperties() map[string]interface{} { + return c.extraProperties +} + +func (c *CitationStartEventDeltaMessage) UnmarshalJSON(data []byte) error { + type unmarshaler CitationStartEventDeltaMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *c = CitationStartEventDeltaMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *c) + if err != nil { + return err + } + c.extraProperties = extraProperties + c.rawJSON = json.RawMessage(data) + return nil +} + +func (c *CitationStartEventDeltaMessage) String() string { + if len(c.rawJSON) > 0 { + if value, err := internal.StringifyJSON(c.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(c); err == nil { + return value + } + return fmt.Sprintf("%#v", c) +} + +// A Content block which contains information about the content type and the content itself. +type Content struct { + Type string + Text *TextContent +} + +func (c *Content) GetType() string { + if c == nil { + return "" + } + return c.Type +} + +func (c *Content) GetText() *TextContent { + if c == nil { + return nil + } + return c.Text +} + +func (c *Content) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + c.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", c) + } + switch unmarshaler.Type { + case "text": + value := new(TextContent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + c.Text = value + } + return nil +} + +func (c Content) MarshalJSON() ([]byte, error) { + if err := c.validate(); err != nil { + return nil, err + } + if c.Text != nil { + return internal.MarshalJSONWithExtraProperty(c.Text, "type", "text") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", c) +} + +type ContentVisitor interface { + VisitText(*TextContent) error +} + +func (c *Content) Accept(visitor ContentVisitor) error { + if c.Text != nil { + return visitor.VisitText(c.Text) + } + return fmt.Errorf("type %T does not define a non-empty union type", c) +} + +func (c *Content) validate() error { + if c == nil { + return fmt.Errorf("type %T is nil", c) + } + var fields []string + if c.Text != nil { + fields = append(fields, "text") + } + if len(fields) == 0 { + if c.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", c, c.Type) + } + return fmt.Errorf("type %T is empty", c) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", c, fields) + } + if c.Type != "" { + field := fields[0] + if c.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + c, + c.Type, + c, + ) + } + } + return nil +} + +// Relevant information that could be used by the model to generate a more accurate reply. +// The content of each document are generally short (should be under 300 words). Metadata should be used to provide additional information, both the key name and the value will be +// passed to the model. +type Document struct { + // A relevant documents that the model can cite to generate a more accurate reply. Each document is a string-string dictionary. + Data map[string]string `json:"data,omitempty" url:"data,omitempty"` + // Unique identifier for this document which will be referenced in citations. If not provided an ID will be automatically generated. + Id *string `json:"id,omitempty" url:"id,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (d *Document) GetData() map[string]string { + if d == nil { + return nil + } + return d.Data +} + +func (d *Document) GetId() *string { + if d == nil { + return nil + } + return d.Id +} + +func (d *Document) GetExtraProperties() map[string]interface{} { + return d.extraProperties +} + +func (d *Document) UnmarshalJSON(data []byte) error { + type unmarshaler Document + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *d = Document(value) + extraProperties, err := internal.ExtractExtraProperties(data, *d) + if err != nil { + return err + } + d.extraProperties = extraProperties + d.rawJSON = json.RawMessage(data) + return nil +} + +func (d *Document) String() string { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(d); err == nil { + return value + } + return fmt.Sprintf("%#v", d) +} + +// Document content. +type DocumentContent struct { + Document *Document `json:"document,omitempty" url:"document,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (d *DocumentContent) GetDocument() *Document { + if d == nil { + return nil + } + return d.Document +} + +func (d *DocumentContent) GetExtraProperties() map[string]interface{} { + return d.extraProperties +} + +func (d *DocumentContent) UnmarshalJSON(data []byte) error { + type unmarshaler DocumentContent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *d = DocumentContent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *d) + if err != nil { + return err + } + d.extraProperties = extraProperties + d.rawJSON = json.RawMessage(data) + return nil +} + +func (d *DocumentContent) String() string { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(d); err == nil { + return value + } + return fmt.Sprintf("%#v", d) +} + +// A document source object containing the unique identifier of the document and the document itself. +type DocumentSource struct { + // The unique identifier of the document + Id *string `json:"id,omitempty" url:"id,omitempty"` + Document map[string]interface{} `json:"document,omitempty" url:"document,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (d *DocumentSource) GetId() *string { + if d == nil { + return nil + } + return d.Id +} + +func (d *DocumentSource) GetDocument() map[string]interface{} { + if d == nil { + return nil + } + return d.Document +} + +func (d *DocumentSource) GetExtraProperties() map[string]interface{} { + return d.extraProperties +} + +func (d *DocumentSource) UnmarshalJSON(data []byte) error { + type unmarshaler DocumentSource + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *d = DocumentSource(value) + extraProperties, err := internal.ExtractExtraProperties(data, *d) + if err != nil { + return err + } + d.extraProperties = extraProperties + d.rawJSON = json.RawMessage(data) + return nil +} + +func (d *DocumentSource) String() string { + if len(d.rawJSON) > 0 { + if value, err := internal.StringifyJSON(d.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(d); err == nil { + return value + } + return fmt.Sprintf("%#v", d) +} + +type JsonResponseFormatV2 struct { + // A [JSON schema](https://json-schema.org/overview/what-is-jsonschema) object that the output will adhere to. There are some restrictions we have on the schema, refer to [our guide](https://docs.cohere.com/docs/structured-outputs-json#schema-constraints) for more information. + // Example (required name and age object): + // ```json + // + // { + // "type": "object", + // "properties": { + // "name": {"type": "string"}, + // "age": {"type": "integer"} + // }, + // "required": ["name", "age"] + // } + // + // ``` + // + // **Note**: This field must not be specified when the `type` is set to `"text"`. + JsonSchema map[string]interface{} `json:"json_schema,omitempty" url:"json_schema,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (j *JsonResponseFormatV2) GetJsonSchema() map[string]interface{} { + if j == nil { + return nil + } + return j.JsonSchema +} + +func (j *JsonResponseFormatV2) GetExtraProperties() map[string]interface{} { + return j.extraProperties +} + +func (j *JsonResponseFormatV2) UnmarshalJSON(data []byte) error { + type unmarshaler JsonResponseFormatV2 + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *j = JsonResponseFormatV2(value) + extraProperties, err := internal.ExtractExtraProperties(data, *j) + if err != nil { + return err + } + j.extraProperties = extraProperties + j.rawJSON = json.RawMessage(data) + return nil +} + +func (j *JsonResponseFormatV2) String() string { + if len(j.rawJSON) > 0 { + if value, err := internal.StringifyJSON(j.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(j); err == nil { + return value + } + return fmt.Sprintf("%#v", j) +} + +type LogprobItem struct { + // The text chunk for which the log probabilities was calculated. + Text *string `json:"text,omitempty" url:"text,omitempty"` + // The token ids of each token used to construct the text chunk. + TokenIds []int `json:"token_ids,omitempty" url:"token_ids,omitempty"` + // The log probability of each token used to construct the text chunk. + Logprobs []float64 `json:"logprobs,omitempty" url:"logprobs,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (l *LogprobItem) GetText() *string { + if l == nil { + return nil + } + return l.Text +} + +func (l *LogprobItem) GetTokenIds() []int { + if l == nil { + return nil + } + return l.TokenIds +} + +func (l *LogprobItem) GetLogprobs() []float64 { + if l == nil { + return nil + } + return l.Logprobs +} + +func (l *LogprobItem) GetExtraProperties() map[string]interface{} { + return l.extraProperties +} + +func (l *LogprobItem) UnmarshalJSON(data []byte) error { + type unmarshaler LogprobItem + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *l = LogprobItem(value) + extraProperties, err := internal.ExtractExtraProperties(data, *l) + if err != nil { + return err + } + l.extraProperties = extraProperties + l.rawJSON = json.RawMessage(data) + return nil +} + +func (l *LogprobItem) String() string { + if len(l.rawJSON) > 0 { + if value, err := internal.StringifyJSON(l.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(l); err == nil { + return value + } + return fmt.Sprintf("%#v", l) +} + +// Configuration for forcing the model output to adhere to the specified format. Supported on [Command R](https://docs.cohere.com/v2/docs/command-r), [Command R+](https://docs.cohere.com/v2/docs/command-r-plus) and newer models. +// +// The model can be forced into outputting JSON objects by setting `{ "type": "json_object" }`. +// +// A [JSON Schema](https://json-schema.org/) can optionally be provided, to ensure a specific structure. +// +// **Note**: When using `{ "type": "json_object" }` your `message` should always explicitly instruct the model to generate a JSON (eg: _"Generate a JSON ..."_) . Otherwise the model may end up getting stuck generating an infinite stream of characters and eventually run out of context length. +// +// **Note**: When `json_schema` is not specified, the generated object can have up to 5 layers of nesting. +// +// **Limitation**: The parameter is not supported when used in combinations with the `documents` or `tools` parameters. +type ResponseFormatV2 struct { + Type string + Text *TextResponseFormatV2 + JsonObject *JsonResponseFormatV2 +} + +func (r *ResponseFormatV2) GetType() string { + if r == nil { + return "" + } + return r.Type +} + +func (r *ResponseFormatV2) GetText() *TextResponseFormatV2 { + if r == nil { + return nil + } + return r.Text +} + +func (r *ResponseFormatV2) GetJsonObject() *JsonResponseFormatV2 { + if r == nil { + return nil + } + return r.JsonObject +} + +func (r *ResponseFormatV2) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + r.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", r) + } + switch unmarshaler.Type { + case "text": + value := new(TextResponseFormatV2) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + r.Text = value + case "json_object": + value := new(JsonResponseFormatV2) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + r.JsonObject = value + } + return nil +} + +func (r ResponseFormatV2) MarshalJSON() ([]byte, error) { + if err := r.validate(); err != nil { + return nil, err + } + if r.Text != nil { + return internal.MarshalJSONWithExtraProperty(r.Text, "type", "text") + } + if r.JsonObject != nil { + return internal.MarshalJSONWithExtraProperty(r.JsonObject, "type", "json_object") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", r) +} + +type ResponseFormatV2Visitor interface { + VisitText(*TextResponseFormatV2) error + VisitJsonObject(*JsonResponseFormatV2) error +} + +func (r *ResponseFormatV2) Accept(visitor ResponseFormatV2Visitor) error { + if r.Text != nil { + return visitor.VisitText(r.Text) + } + if r.JsonObject != nil { + return visitor.VisitJsonObject(r.JsonObject) + } + return fmt.Errorf("type %T does not define a non-empty union type", r) +} + +func (r *ResponseFormatV2) validate() error { + if r == nil { + return fmt.Errorf("type %T is nil", r) + } + var fields []string + if r.Text != nil { + fields = append(fields, "text") + } + if r.JsonObject != nil { + fields = append(fields, "json_object") + } + if len(fields) == 0 { + if r.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", r, r.Type) + } + return fmt.Errorf("type %T is empty", r) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", r, fields) + } + if r.Type != "" { + field := fields[0] + if r.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + r, + r.Type, + r, + ) + } + } + return nil +} + +// A source object containing information about the source of the data cited. +type Source struct { + Type string + Tool *ToolSource + Document *DocumentSource +} + +func (s *Source) GetType() string { + if s == nil { + return "" + } + return s.Type +} + +func (s *Source) GetTool() *ToolSource { + if s == nil { + return nil + } + return s.Tool +} + +func (s *Source) GetDocument() *DocumentSource { + if s == nil { + return nil + } + return s.Document +} + +func (s *Source) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + s.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", s) + } + switch unmarshaler.Type { + case "tool": + value := new(ToolSource) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.Tool = value + case "document": + value := new(DocumentSource) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.Document = value + } + return nil +} + +func (s Source) MarshalJSON() ([]byte, error) { + if err := s.validate(); err != nil { + return nil, err + } + if s.Tool != nil { + return internal.MarshalJSONWithExtraProperty(s.Tool, "type", "tool") + } + if s.Document != nil { + return internal.MarshalJSONWithExtraProperty(s.Document, "type", "document") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", s) +} + +type SourceVisitor interface { + VisitTool(*ToolSource) error + VisitDocument(*DocumentSource) error +} + +func (s *Source) Accept(visitor SourceVisitor) error { + if s.Tool != nil { + return visitor.VisitTool(s.Tool) + } + if s.Document != nil { + return visitor.VisitDocument(s.Document) + } + return fmt.Errorf("type %T does not define a non-empty union type", s) +} + +func (s *Source) validate() error { + if s == nil { + return fmt.Errorf("type %T is nil", s) + } + var fields []string + if s.Tool != nil { + fields = append(fields, "tool") + } + if s.Document != nil { + fields = append(fields, "document") + } + if len(fields) == 0 { + if s.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", s, s.Type) + } + return fmt.Errorf("type %T is empty", s) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", s, fields) + } + if s.Type != "" { + field := fields[0] + if s.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + s, + s.Type, + s, + ) + } + } + return nil +} + +// StreamedChatResponse is returned in streaming mode (specified with `stream=True` in the request). +type StreamedChatResponseV2 struct { + Type string + MessageStart *ChatMessageStartEvent + ContentStart *ChatContentStartEvent + ContentDelta *ChatContentDeltaEvent + ContentEnd *ChatContentEndEvent + ToolPlanDelta *ChatToolPlanDeltaEvent + ToolCallStart *ChatToolCallStartEvent + ToolCallDelta *ChatToolCallDeltaEvent + ToolCallEnd *ChatToolCallEndEvent + CitationStart *CitationStartEvent + CitationEnd *CitationEndEvent + MessageEnd *ChatMessageEndEvent + Debug *ChatDebugEvent +} + +func (s *StreamedChatResponseV2) GetType() string { + if s == nil { + return "" + } + return s.Type +} + +func (s *StreamedChatResponseV2) GetMessageStart() *ChatMessageStartEvent { + if s == nil { + return nil + } + return s.MessageStart +} + +func (s *StreamedChatResponseV2) GetContentStart() *ChatContentStartEvent { + if s == nil { + return nil + } + return s.ContentStart +} + +func (s *StreamedChatResponseV2) GetContentDelta() *ChatContentDeltaEvent { + if s == nil { + return nil + } + return s.ContentDelta +} + +func (s *StreamedChatResponseV2) GetContentEnd() *ChatContentEndEvent { + if s == nil { + return nil + } + return s.ContentEnd +} + +func (s *StreamedChatResponseV2) GetToolPlanDelta() *ChatToolPlanDeltaEvent { + if s == nil { + return nil + } + return s.ToolPlanDelta +} + +func (s *StreamedChatResponseV2) GetToolCallStart() *ChatToolCallStartEvent { + if s == nil { + return nil + } + return s.ToolCallStart +} + +func (s *StreamedChatResponseV2) GetToolCallDelta() *ChatToolCallDeltaEvent { + if s == nil { + return nil + } + return s.ToolCallDelta +} + +func (s *StreamedChatResponseV2) GetToolCallEnd() *ChatToolCallEndEvent { + if s == nil { + return nil + } + return s.ToolCallEnd +} + +func (s *StreamedChatResponseV2) GetCitationStart() *CitationStartEvent { + if s == nil { + return nil + } + return s.CitationStart +} + +func (s *StreamedChatResponseV2) GetCitationEnd() *CitationEndEvent { + if s == nil { + return nil + } + return s.CitationEnd +} + +func (s *StreamedChatResponseV2) GetMessageEnd() *ChatMessageEndEvent { + if s == nil { + return nil + } + return s.MessageEnd +} + +func (s *StreamedChatResponseV2) GetDebug() *ChatDebugEvent { + if s == nil { + return nil + } + return s.Debug +} + +func (s *StreamedChatResponseV2) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + s.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", s) + } + switch unmarshaler.Type { + case "message-start": + value := new(ChatMessageStartEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.MessageStart = value + case "content-start": + value := new(ChatContentStartEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ContentStart = value + case "content-delta": + value := new(ChatContentDeltaEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ContentDelta = value + case "content-end": + value := new(ChatContentEndEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ContentEnd = value + case "tool-plan-delta": + value := new(ChatToolPlanDeltaEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ToolPlanDelta = value + case "tool-call-start": + value := new(ChatToolCallStartEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ToolCallStart = value + case "tool-call-delta": + value := new(ChatToolCallDeltaEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ToolCallDelta = value + case "tool-call-end": + value := new(ChatToolCallEndEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.ToolCallEnd = value + case "citation-start": + value := new(CitationStartEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.CitationStart = value + case "citation-end": + value := new(CitationEndEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.CitationEnd = value + case "message-end": + value := new(ChatMessageEndEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.MessageEnd = value + case "debug": + value := new(ChatDebugEvent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.Debug = value + } + return nil +} + +func (s StreamedChatResponseV2) MarshalJSON() ([]byte, error) { + if err := s.validate(); err != nil { + return nil, err + } + if s.MessageStart != nil { + return internal.MarshalJSONWithExtraProperty(s.MessageStart, "type", "message-start") + } + if s.ContentStart != nil { + return internal.MarshalJSONWithExtraProperty(s.ContentStart, "type", "content-start") + } + if s.ContentDelta != nil { + return internal.MarshalJSONWithExtraProperty(s.ContentDelta, "type", "content-delta") + } + if s.ContentEnd != nil { + return internal.MarshalJSONWithExtraProperty(s.ContentEnd, "type", "content-end") + } + if s.ToolPlanDelta != nil { + return internal.MarshalJSONWithExtraProperty(s.ToolPlanDelta, "type", "tool-plan-delta") + } + if s.ToolCallStart != nil { + return internal.MarshalJSONWithExtraProperty(s.ToolCallStart, "type", "tool-call-start") + } + if s.ToolCallDelta != nil { + return internal.MarshalJSONWithExtraProperty(s.ToolCallDelta, "type", "tool-call-delta") + } + if s.ToolCallEnd != nil { + return internal.MarshalJSONWithExtraProperty(s.ToolCallEnd, "type", "tool-call-end") + } + if s.CitationStart != nil { + return internal.MarshalJSONWithExtraProperty(s.CitationStart, "type", "citation-start") + } + if s.CitationEnd != nil { + return internal.MarshalJSONWithExtraProperty(s.CitationEnd, "type", "citation-end") + } + if s.MessageEnd != nil { + return internal.MarshalJSONWithExtraProperty(s.MessageEnd, "type", "message-end") + } + if s.Debug != nil { + return internal.MarshalJSONWithExtraProperty(s.Debug, "type", "debug") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", s) +} + +type StreamedChatResponseV2Visitor interface { + VisitMessageStart(*ChatMessageStartEvent) error + VisitContentStart(*ChatContentStartEvent) error + VisitContentDelta(*ChatContentDeltaEvent) error + VisitContentEnd(*ChatContentEndEvent) error + VisitToolPlanDelta(*ChatToolPlanDeltaEvent) error + VisitToolCallStart(*ChatToolCallStartEvent) error + VisitToolCallDelta(*ChatToolCallDeltaEvent) error + VisitToolCallEnd(*ChatToolCallEndEvent) error + VisitCitationStart(*CitationStartEvent) error + VisitCitationEnd(*CitationEndEvent) error + VisitMessageEnd(*ChatMessageEndEvent) error + VisitDebug(*ChatDebugEvent) error +} + +func (s *StreamedChatResponseV2) Accept(visitor StreamedChatResponseV2Visitor) error { + if s.MessageStart != nil { + return visitor.VisitMessageStart(s.MessageStart) + } + if s.ContentStart != nil { + return visitor.VisitContentStart(s.ContentStart) + } + if s.ContentDelta != nil { + return visitor.VisitContentDelta(s.ContentDelta) + } + if s.ContentEnd != nil { + return visitor.VisitContentEnd(s.ContentEnd) + } + if s.ToolPlanDelta != nil { + return visitor.VisitToolPlanDelta(s.ToolPlanDelta) + } + if s.ToolCallStart != nil { + return visitor.VisitToolCallStart(s.ToolCallStart) + } + if s.ToolCallDelta != nil { + return visitor.VisitToolCallDelta(s.ToolCallDelta) + } + if s.ToolCallEnd != nil { + return visitor.VisitToolCallEnd(s.ToolCallEnd) + } + if s.CitationStart != nil { + return visitor.VisitCitationStart(s.CitationStart) + } + if s.CitationEnd != nil { + return visitor.VisitCitationEnd(s.CitationEnd) + } + if s.MessageEnd != nil { + return visitor.VisitMessageEnd(s.MessageEnd) + } + if s.Debug != nil { + return visitor.VisitDebug(s.Debug) + } + return fmt.Errorf("type %T does not define a non-empty union type", s) +} + +func (s *StreamedChatResponseV2) validate() error { + if s == nil { + return fmt.Errorf("type %T is nil", s) + } + var fields []string + if s.MessageStart != nil { + fields = append(fields, "message-start") + } + if s.ContentStart != nil { + fields = append(fields, "content-start") + } + if s.ContentDelta != nil { + fields = append(fields, "content-delta") + } + if s.ContentEnd != nil { + fields = append(fields, "content-end") + } + if s.ToolPlanDelta != nil { + fields = append(fields, "tool-plan-delta") + } + if s.ToolCallStart != nil { + fields = append(fields, "tool-call-start") + } + if s.ToolCallDelta != nil { + fields = append(fields, "tool-call-delta") + } + if s.ToolCallEnd != nil { + fields = append(fields, "tool-call-end") + } + if s.CitationStart != nil { + fields = append(fields, "citation-start") + } + if s.CitationEnd != nil { + fields = append(fields, "citation-end") + } + if s.MessageEnd != nil { + fields = append(fields, "message-end") + } + if s.Debug != nil { + fields = append(fields, "debug") + } + if len(fields) == 0 { + if s.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", s, s.Type) + } + return fmt.Errorf("type %T is empty", s) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", s, fields) + } + if s.Type != "" { + field := fields[0] + if s.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + s, + s.Type, + s, + ) + } + } + return nil +} + +// A message from the system. +type SystemMessage struct { + Content *SystemMessageContent `json:"content,omitempty" url:"content,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (s *SystemMessage) GetContent() *SystemMessageContent { + if s == nil { + return nil + } + return s.Content +} + +func (s *SystemMessage) GetExtraProperties() map[string]interface{} { + return s.extraProperties +} + +func (s *SystemMessage) UnmarshalJSON(data []byte) error { + type unmarshaler SystemMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *s = SystemMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *s) + if err != nil { + return err + } + s.extraProperties = extraProperties + s.rawJSON = json.RawMessage(data) + return nil +} + +func (s *SystemMessage) String() string { + if len(s.rawJSON) > 0 { + if value, err := internal.StringifyJSON(s.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(s); err == nil { + return value + } + return fmt.Sprintf("%#v", s) +} + +type SystemMessageContent struct { + String string + SystemMessageContentItemList []*SystemMessageContentItem + + typ string +} + +func (s *SystemMessageContent) GetString() string { + if s == nil { + return "" + } + return s.String +} + +func (s *SystemMessageContent) GetSystemMessageContentItemList() []*SystemMessageContentItem { + if s == nil { + return nil + } + return s.SystemMessageContentItemList +} + +func (s *SystemMessageContent) UnmarshalJSON(data []byte) error { + var valueString string + if err := json.Unmarshal(data, &valueString); err == nil { + s.typ = "String" + s.String = valueString + return nil + } + var valueSystemMessageContentItemList []*SystemMessageContentItem + if err := json.Unmarshal(data, &valueSystemMessageContentItemList); err == nil { + s.typ = "SystemMessageContentItemList" + s.SystemMessageContentItemList = valueSystemMessageContentItemList + return nil + } + return fmt.Errorf("%s cannot be deserialized as a %T", data, s) +} + +func (s SystemMessageContent) MarshalJSON() ([]byte, error) { + if s.typ == "String" || s.String != "" { + return json.Marshal(s.String) + } + if s.typ == "SystemMessageContentItemList" || s.SystemMessageContentItemList != nil { + return json.Marshal(s.SystemMessageContentItemList) + } + return nil, fmt.Errorf("type %T does not include a non-empty union type", s) +} + +type SystemMessageContentVisitor interface { + VisitString(string) error + VisitSystemMessageContentItemList([]*SystemMessageContentItem) error +} + +func (s *SystemMessageContent) Accept(visitor SystemMessageContentVisitor) error { + if s.typ == "String" || s.String != "" { + return visitor.VisitString(s.String) + } + if s.typ == "SystemMessageContentItemList" || s.SystemMessageContentItemList != nil { + return visitor.VisitSystemMessageContentItemList(s.SystemMessageContentItemList) + } + return fmt.Errorf("type %T does not include a non-empty union type", s) +} + +type SystemMessageContentItem struct { + Type string + Text *TextContent +} + +func (s *SystemMessageContentItem) GetType() string { + if s == nil { + return "" + } + return s.Type +} + +func (s *SystemMessageContentItem) GetText() *TextContent { + if s == nil { + return nil + } + return s.Text +} + +func (s *SystemMessageContentItem) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + s.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", s) + } + switch unmarshaler.Type { + case "text": + value := new(TextContent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + s.Text = value + } + return nil +} + +func (s SystemMessageContentItem) MarshalJSON() ([]byte, error) { + if err := s.validate(); err != nil { + return nil, err + } + if s.Text != nil { + return internal.MarshalJSONWithExtraProperty(s.Text, "type", "text") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", s) +} + +type SystemMessageContentItemVisitor interface { + VisitText(*TextContent) error +} + +func (s *SystemMessageContentItem) Accept(visitor SystemMessageContentItemVisitor) error { + if s.Text != nil { + return visitor.VisitText(s.Text) + } + return fmt.Errorf("type %T does not define a non-empty union type", s) +} + +func (s *SystemMessageContentItem) validate() error { + if s == nil { + return fmt.Errorf("type %T is nil", s) + } + var fields []string + if s.Text != nil { + fields = append(fields, "text") + } + if len(fields) == 0 { + if s.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", s, s.Type) + } + return fmt.Errorf("type %T is empty", s) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", s, fields) + } + if s.Type != "" { + field := fields[0] + if s.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + s, + s.Type, + s, + ) + } + } + return nil +} + +// Text content of the message. +type TextContent struct { + Text string `json:"text" url:"text"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *TextContent) GetText() string { + if t == nil { + return "" + } + return t.Text +} + +func (t *TextContent) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *TextContent) UnmarshalJSON(data []byte) error { + type unmarshaler TextContent + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = TextContent(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *TextContent) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +type TextResponseFormatV2 struct { + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *TextResponseFormatV2) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *TextResponseFormatV2) UnmarshalJSON(data []byte) error { + type unmarshaler TextResponseFormatV2 + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = TextResponseFormatV2(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *TextResponseFormatV2) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +// An array of tool calls to be made. +type ToolCallV2 struct { + Id *string `json:"id,omitempty" url:"id,omitempty"` + Type *string `json:"type,omitempty" url:"type,omitempty"` + Function *ToolCallV2Function `json:"function,omitempty" url:"function,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *ToolCallV2) GetId() *string { + if t == nil { + return nil + } + return t.Id +} + +func (t *ToolCallV2) GetFunction() *ToolCallV2Function { + if t == nil { + return nil + } + return t.Function +} + +func (t *ToolCallV2) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *ToolCallV2) UnmarshalJSON(data []byte) error { + type unmarshaler ToolCallV2 + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolCallV2(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolCallV2) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +type ToolCallV2Function struct { + Name *string `json:"name,omitempty" url:"name,omitempty"` + Arguments *string `json:"arguments,omitempty" url:"arguments,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *ToolCallV2Function) GetName() *string { + if t == nil { + return nil + } + return t.Name +} + +func (t *ToolCallV2Function) GetArguments() *string { + if t == nil { + return nil + } + return t.Arguments +} + +func (t *ToolCallV2Function) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *ToolCallV2Function) UnmarshalJSON(data []byte) error { + type unmarshaler ToolCallV2Function + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolCallV2Function(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolCallV2Function) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +// A content block which contains information about the content of a tool result +type ToolContent struct { + Type string + Text *TextContent + Document *DocumentContent +} + +func (t *ToolContent) GetType() string { + if t == nil { + return "" + } + return t.Type +} + +func (t *ToolContent) GetText() *TextContent { + if t == nil { + return nil + } + return t.Text +} + +func (t *ToolContent) GetDocument() *DocumentContent { + if t == nil { + return nil + } + return t.Document +} + +func (t *ToolContent) UnmarshalJSON(data []byte) error { + var unmarshaler struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &unmarshaler); err != nil { + return err + } + t.Type = unmarshaler.Type + if unmarshaler.Type == "" { + return fmt.Errorf("%T did not include discriminant type", t) + } + switch unmarshaler.Type { + case "text": + value := new(TextContent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + t.Text = value + case "document": + value := new(DocumentContent) + if err := json.Unmarshal(data, &value); err != nil { + return err + } + t.Document = value + } + return nil +} + +func (t ToolContent) MarshalJSON() ([]byte, error) { + if err := t.validate(); err != nil { + return nil, err + } + if t.Text != nil { + return internal.MarshalJSONWithExtraProperty(t.Text, "type", "text") + } + if t.Document != nil { + return internal.MarshalJSONWithExtraProperty(t.Document, "type", "document") + } + return nil, fmt.Errorf("type %T does not define a non-empty union type", t) +} + +type ToolContentVisitor interface { + VisitText(*TextContent) error + VisitDocument(*DocumentContent) error +} + +func (t *ToolContent) Accept(visitor ToolContentVisitor) error { + if t.Text != nil { + return visitor.VisitText(t.Text) + } + if t.Document != nil { + return visitor.VisitDocument(t.Document) + } + return fmt.Errorf("type %T does not define a non-empty union type", t) +} + +func (t *ToolContent) validate() error { + if t == nil { + return fmt.Errorf("type %T is nil", t) + } + var fields []string + if t.Text != nil { + fields = append(fields, "text") + } + if t.Document != nil { + fields = append(fields, "document") + } + if len(fields) == 0 { + if t.Type != "" { + return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", t, t.Type) + } + return fmt.Errorf("type %T is empty", t) + } + if len(fields) > 1 { + return fmt.Errorf("type %T defines values for %s, but only one value is allowed", t, fields) + } + if t.Type != "" { + field := fields[0] + if t.Type != field { + return fmt.Errorf( + "type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match", + t, + t.Type, + t, + ) + } + } + return nil +} + +// A message with Tool outputs. +type ToolMessageV2 struct { + // The id of the associated tool call that has provided the given content + ToolCallId string `json:"tool_call_id" url:"tool_call_id"` + // Outputs from a tool. The content should formatted as a JSON object string, or a list of tool content blocks + Content *ToolMessageV2Content `json:"content,omitempty" url:"content,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *ToolMessageV2) GetToolCallId() string { + if t == nil { + return "" + } + return t.ToolCallId +} + +func (t *ToolMessageV2) GetContent() *ToolMessageV2Content { + if t == nil { + return nil + } + return t.Content +} + +func (t *ToolMessageV2) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *ToolMessageV2) UnmarshalJSON(data []byte) error { + type unmarshaler ToolMessageV2 + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolMessageV2(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolMessageV2) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +// Outputs from a tool. The content should formatted as a JSON object string, or a list of tool content blocks +type ToolMessageV2Content struct { + String string + ToolContentList []*ToolContent + + typ string +} + +func (t *ToolMessageV2Content) GetString() string { + if t == nil { + return "" + } + return t.String +} + +func (t *ToolMessageV2Content) GetToolContentList() []*ToolContent { + if t == nil { + return nil + } + return t.ToolContentList +} + +func (t *ToolMessageV2Content) UnmarshalJSON(data []byte) error { + var valueString string + if err := json.Unmarshal(data, &valueString); err == nil { + t.typ = "String" + t.String = valueString + return nil + } + var valueToolContentList []*ToolContent + if err := json.Unmarshal(data, &valueToolContentList); err == nil { + t.typ = "ToolContentList" + t.ToolContentList = valueToolContentList + return nil + } + return fmt.Errorf("%s cannot be deserialized as a %T", data, t) +} + +func (t ToolMessageV2Content) MarshalJSON() ([]byte, error) { + if t.typ == "String" || t.String != "" { + return json.Marshal(t.String) + } + if t.typ == "ToolContentList" || t.ToolContentList != nil { + return json.Marshal(t.ToolContentList) + } + return nil, fmt.Errorf("type %T does not include a non-empty union type", t) +} + +type ToolMessageV2ContentVisitor interface { + VisitString(string) error + VisitToolContentList([]*ToolContent) error +} + +func (t *ToolMessageV2Content) Accept(visitor ToolMessageV2ContentVisitor) error { + if t.typ == "String" || t.String != "" { + return visitor.VisitString(t.String) + } + if t.typ == "ToolContentList" || t.ToolContentList != nil { + return visitor.VisitToolContentList(t.ToolContentList) + } + return fmt.Errorf("type %T does not include a non-empty union type", t) +} + +type ToolSource struct { + // The unique identifier of the document + Id *string `json:"id,omitempty" url:"id,omitempty"` + ToolOutput map[string]interface{} `json:"tool_output,omitempty" url:"tool_output,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *ToolSource) GetId() *string { + if t == nil { + return nil + } + return t.Id +} + +func (t *ToolSource) GetToolOutput() map[string]interface{} { + if t == nil { + return nil + } + return t.ToolOutput +} + +func (t *ToolSource) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *ToolSource) UnmarshalJSON(data []byte) error { + type unmarshaler ToolSource + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolSource(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolSource) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +type ToolV2 struct { + Type *string `json:"type,omitempty" url:"type,omitempty"` + // The function to be executed. + Function *ToolV2Function `json:"function,omitempty" url:"function,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *ToolV2) GetFunction() *ToolV2Function { + if t == nil { + return nil + } + return t.Function +} + +func (t *ToolV2) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *ToolV2) UnmarshalJSON(data []byte) error { + type unmarshaler ToolV2 + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolV2(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolV2) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +// The function to be executed. +type ToolV2Function struct { + // The name of the function. + Name *string `json:"name,omitempty" url:"name,omitempty"` + // The description of the function. + Description *string `json:"description,omitempty" url:"description,omitempty"` + // The parameters of the function as a JSON schema. + Parameters map[string]interface{} `json:"parameters,omitempty" url:"parameters,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (t *ToolV2Function) GetName() *string { + if t == nil { + return nil + } + return t.Name +} + +func (t *ToolV2Function) GetDescription() *string { + if t == nil { + return nil + } + return t.Description +} + +func (t *ToolV2Function) GetParameters() map[string]interface{} { + if t == nil { + return nil + } + return t.Parameters +} + +func (t *ToolV2Function) GetExtraProperties() map[string]interface{} { + return t.extraProperties +} + +func (t *ToolV2Function) UnmarshalJSON(data []byte) error { + type unmarshaler ToolV2Function + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *t = ToolV2Function(value) + extraProperties, err := internal.ExtractExtraProperties(data, *t) + if err != nil { + return err + } + t.extraProperties = extraProperties + t.rawJSON = json.RawMessage(data) + return nil +} + +func (t *ToolV2Function) String() string { + if len(t.rawJSON) > 0 { + if value, err := internal.StringifyJSON(t.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(t); err == nil { + return value + } + return fmt.Sprintf("%#v", t) +} + +type Usage struct { + BilledUnits *UsageBilledUnits `json:"billed_units,omitempty" url:"billed_units,omitempty"` + Tokens *UsageTokens `json:"tokens,omitempty" url:"tokens,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (u *Usage) GetBilledUnits() *UsageBilledUnits { + if u == nil { + return nil + } + return u.BilledUnits +} + +func (u *Usage) GetTokens() *UsageTokens { + if u == nil { + return nil + } + return u.Tokens +} + +func (u *Usage) GetExtraProperties() map[string]interface{} { + return u.extraProperties +} + +func (u *Usage) UnmarshalJSON(data []byte) error { + type unmarshaler Usage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *u = Usage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *u) + if err != nil { + return err + } + u.extraProperties = extraProperties + u.rawJSON = json.RawMessage(data) + return nil +} + +func (u *Usage) String() string { + if len(u.rawJSON) > 0 { + if value, err := internal.StringifyJSON(u.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(u); err == nil { + return value + } + return fmt.Sprintf("%#v", u) +} + +type UsageBilledUnits struct { + // The number of billed input tokens. + InputTokens *float64 `json:"input_tokens,omitempty" url:"input_tokens,omitempty"` + // The number of billed output tokens. + OutputTokens *float64 `json:"output_tokens,omitempty" url:"output_tokens,omitempty"` + // The number of billed search units. + SearchUnits *float64 `json:"search_units,omitempty" url:"search_units,omitempty"` + // The number of billed classifications units. + Classifications *float64 `json:"classifications,omitempty" url:"classifications,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (u *UsageBilledUnits) GetInputTokens() *float64 { + if u == nil { + return nil + } + return u.InputTokens +} + +func (u *UsageBilledUnits) GetOutputTokens() *float64 { + if u == nil { + return nil + } + return u.OutputTokens +} + +func (u *UsageBilledUnits) GetSearchUnits() *float64 { + if u == nil { + return nil + } + return u.SearchUnits +} + +func (u *UsageBilledUnits) GetClassifications() *float64 { + if u == nil { + return nil + } + return u.Classifications +} + +func (u *UsageBilledUnits) GetExtraProperties() map[string]interface{} { + return u.extraProperties +} + +func (u *UsageBilledUnits) UnmarshalJSON(data []byte) error { + type unmarshaler UsageBilledUnits + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *u = UsageBilledUnits(value) + extraProperties, err := internal.ExtractExtraProperties(data, *u) + if err != nil { + return err + } + u.extraProperties = extraProperties + u.rawJSON = json.RawMessage(data) + return nil +} + +func (u *UsageBilledUnits) String() string { + if len(u.rawJSON) > 0 { + if value, err := internal.StringifyJSON(u.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(u); err == nil { + return value + } + return fmt.Sprintf("%#v", u) +} + +type UsageTokens struct { + // The number of tokens used as input to the model. + InputTokens *float64 `json:"input_tokens,omitempty" url:"input_tokens,omitempty"` + // The number of tokens produced by the model. + OutputTokens *float64 `json:"output_tokens,omitempty" url:"output_tokens,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (u *UsageTokens) GetInputTokens() *float64 { + if u == nil { + return nil + } + return u.InputTokens +} + +func (u *UsageTokens) GetOutputTokens() *float64 { + if u == nil { + return nil + } + return u.OutputTokens +} + +func (u *UsageTokens) GetExtraProperties() map[string]interface{} { + return u.extraProperties +} + +func (u *UsageTokens) UnmarshalJSON(data []byte) error { + type unmarshaler UsageTokens + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *u = UsageTokens(value) + extraProperties, err := internal.ExtractExtraProperties(data, *u) + if err != nil { + return err + } + u.extraProperties = extraProperties + u.rawJSON = json.RawMessage(data) + return nil +} + +func (u *UsageTokens) String() string { + if len(u.rawJSON) > 0 { + if value, err := internal.StringifyJSON(u.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(u); err == nil { + return value + } + return fmt.Sprintf("%#v", u) +} + +// A message from the user. +type UserMessage struct { + // The content of the message. This can be a string or a list of content blocks. + // If a string is provided, it will be treated as a text content block. + Content *UserMessageContent `json:"content,omitempty" url:"content,omitempty"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (u *UserMessage) GetContent() *UserMessageContent { + if u == nil { + return nil + } + return u.Content +} + +func (u *UserMessage) GetExtraProperties() map[string]interface{} { + return u.extraProperties +} + +func (u *UserMessage) UnmarshalJSON(data []byte) error { + type unmarshaler UserMessage + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *u = UserMessage(value) + extraProperties, err := internal.ExtractExtraProperties(data, *u) + if err != nil { + return err + } + u.extraProperties = extraProperties + u.rawJSON = json.RawMessage(data) + return nil +} + +func (u *UserMessage) String() string { + if len(u.rawJSON) > 0 { + if value, err := internal.StringifyJSON(u.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(u); err == nil { + return value + } + return fmt.Sprintf("%#v", u) +} + +// The content of the message. This can be a string or a list of content blocks. +// If a string is provided, it will be treated as a text content block. +type UserMessageContent struct { + String string + ContentList []*Content + + typ string +} + +func (u *UserMessageContent) GetString() string { + if u == nil { + return "" + } + return u.String +} + +func (u *UserMessageContent) GetContentList() []*Content { + if u == nil { + return nil + } + return u.ContentList +} + +func (u *UserMessageContent) UnmarshalJSON(data []byte) error { + var valueString string + if err := json.Unmarshal(data, &valueString); err == nil { + u.typ = "String" + u.String = valueString + return nil + } + var valueContentList []*Content + if err := json.Unmarshal(data, &valueContentList); err == nil { + u.typ = "ContentList" + u.ContentList = valueContentList + return nil + } + return fmt.Errorf("%s cannot be deserialized as a %T", data, u) +} + +func (u UserMessageContent) MarshalJSON() ([]byte, error) { + if u.typ == "String" || u.String != "" { + return json.Marshal(u.String) + } + if u.typ == "ContentList" || u.ContentList != nil { + return json.Marshal(u.ContentList) + } + return nil, fmt.Errorf("type %T does not include a non-empty union type", u) +} + +type UserMessageContentVisitor interface { + VisitString(string) error + VisitContentList([]*Content) error +} + +func (u *UserMessageContent) Accept(visitor UserMessageContentVisitor) error { + if u.typ == "String" || u.String != "" { + return visitor.VisitString(u.String) + } + if u.typ == "ContentList" || u.ContentList != nil { + return visitor.VisitContentList(u.ContentList) + } + return fmt.Errorf("type %T does not include a non-empty union type", u) +} + type V2ChatRequestDocumentsItem struct { String string Document *Document + + typ string +} + +func (v *V2ChatRequestDocumentsItem) GetString() string { + if v == nil { + return "" + } + return v.String +} + +func (v *V2ChatRequestDocumentsItem) GetDocument() *Document { + if v == nil { + return nil + } + return v.Document } func (v *V2ChatRequestDocumentsItem) UnmarshalJSON(data []byte) error { var valueString string if err := json.Unmarshal(data, &valueString); err == nil { + v.typ = "String" v.String = valueString return nil } valueDocument := new(Document) if err := json.Unmarshal(data, &valueDocument); err == nil { + v.typ = "Document" v.Document = valueDocument return nil } @@ -280,10 +4751,10 @@ func (v *V2ChatRequestDocumentsItem) UnmarshalJSON(data []byte) error { } func (v V2ChatRequestDocumentsItem) MarshalJSON() ([]byte, error) { - if v.String != "" { + if v.typ == "String" || v.String != "" { return json.Marshal(v.String) } - if v.Document != nil { + if v.typ == "Document" || v.Document != nil { return json.Marshal(v.Document) } return nil, fmt.Errorf("type %T does not include a non-empty union type", v) @@ -295,10 +4766,10 @@ type V2ChatRequestDocumentsItemVisitor interface { } func (v *V2ChatRequestDocumentsItem) Accept(visitor V2ChatRequestDocumentsItemVisitor) error { - if v.String != "" { + if v.typ == "String" || v.String != "" { return visitor.VisitString(v.String) } - if v.Document != nil { + if v.typ == "Document" || v.Document != nil { return visitor.VisitDocument(v.Document) } return fmt.Errorf("type %T does not include a non-empty union type", v) @@ -340,16 +4811,34 @@ func (v V2ChatRequestSafetyMode) Ptr() *V2ChatRequestSafetyMode { type V2ChatStreamRequestDocumentsItem struct { String string Document *Document + + typ string +} + +func (v *V2ChatStreamRequestDocumentsItem) GetString() string { + if v == nil { + return "" + } + return v.String +} + +func (v *V2ChatStreamRequestDocumentsItem) GetDocument() *Document { + if v == nil { + return nil + } + return v.Document } func (v *V2ChatStreamRequestDocumentsItem) UnmarshalJSON(data []byte) error { var valueString string if err := json.Unmarshal(data, &valueString); err == nil { + v.typ = "String" v.String = valueString return nil } valueDocument := new(Document) if err := json.Unmarshal(data, &valueDocument); err == nil { + v.typ = "Document" v.Document = valueDocument return nil } @@ -357,10 +4846,10 @@ func (v *V2ChatStreamRequestDocumentsItem) UnmarshalJSON(data []byte) error { } func (v V2ChatStreamRequestDocumentsItem) MarshalJSON() ([]byte, error) { - if v.String != "" { + if v.typ == "String" || v.String != "" { return json.Marshal(v.String) } - if v.Document != nil { + if v.typ == "Document" || v.Document != nil { return json.Marshal(v.Document) } return nil, fmt.Errorf("type %T does not include a non-empty union type", v) @@ -372,10 +4861,10 @@ type V2ChatStreamRequestDocumentsItemVisitor interface { } func (v *V2ChatStreamRequestDocumentsItem) Accept(visitor V2ChatStreamRequestDocumentsItemVisitor) error { - if v.String != "" { + if v.typ == "String" || v.String != "" { return visitor.VisitString(v.String) } - if v.Document != nil { + if v.typ == "Document" || v.Document != nil { return visitor.VisitDocument(v.Document) } return fmt.Errorf("type %T does not include a non-empty union type", v) @@ -451,7 +4940,28 @@ type V2RerankResponse struct { Meta *ApiMeta `json:"meta,omitempty" url:"meta,omitempty"` extraProperties map[string]interface{} - _rawJSON json.RawMessage + rawJSON json.RawMessage +} + +func (v *V2RerankResponse) GetId() *string { + if v == nil { + return nil + } + return v.Id +} + +func (v *V2RerankResponse) GetResults() []*V2RerankResponseResultsItem { + if v == nil { + return nil + } + return v.Results +} + +func (v *V2RerankResponse) GetMeta() *ApiMeta { + if v == nil { + return nil + } + return v.Meta } func (v *V2RerankResponse) GetExtraProperties() map[string]interface{} { @@ -465,24 +4975,135 @@ func (v *V2RerankResponse) UnmarshalJSON(data []byte) error { return err } *v = V2RerankResponse(value) + extraProperties, err := internal.ExtractExtraProperties(data, *v) + if err != nil { + return err + } + v.extraProperties = extraProperties + v.rawJSON = json.RawMessage(data) + return nil +} + +func (v *V2RerankResponse) String() string { + if len(v.rawJSON) > 0 { + if value, err := internal.StringifyJSON(v.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(v); err == nil { + return value + } + return fmt.Sprintf("%#v", v) +} + +type V2RerankResponseResultsItem struct { + // If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in + Document *V2RerankResponseResultsItemDocument `json:"document,omitempty" url:"document,omitempty"` + // Corresponds to the index in the original list of documents to which the ranked document belongs. (i.e. if the first value in the `results` object has an `index` value of 3, it means in the list of documents passed in, the document at `index=3` had the highest relevance) + Index int `json:"index" url:"index"` + // Relevance scores are normalized to be in the range `[0, 1]`. Scores close to `1` indicate a high relevance to the query, and scores closer to `0` indicate low relevance. It is not accurate to assume a score of 0.9 means the document is 2x more relevant than a document with a score of 0.45 + RelevanceScore float64 `json:"relevance_score" url:"relevance_score"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (v *V2RerankResponseResultsItem) GetDocument() *V2RerankResponseResultsItemDocument { + if v == nil { + return nil + } + return v.Document +} + +func (v *V2RerankResponseResultsItem) GetIndex() int { + if v == nil { + return 0 + } + return v.Index +} + +func (v *V2RerankResponseResultsItem) GetRelevanceScore() float64 { + if v == nil { + return 0 + } + return v.RelevanceScore +} - extraProperties, err := core.ExtractExtraProperties(data, *v) +func (v *V2RerankResponseResultsItem) GetExtraProperties() map[string]interface{} { + return v.extraProperties +} + +func (v *V2RerankResponseResultsItem) UnmarshalJSON(data []byte) error { + type unmarshaler V2RerankResponseResultsItem + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *v = V2RerankResponseResultsItem(value) + extraProperties, err := internal.ExtractExtraProperties(data, *v) if err != nil { return err } v.extraProperties = extraProperties + v.rawJSON = json.RawMessage(data) + return nil +} + +func (v *V2RerankResponseResultsItem) String() string { + if len(v.rawJSON) > 0 { + if value, err := internal.StringifyJSON(v.rawJSON); err == nil { + return value + } + } + if value, err := internal.StringifyJSON(v); err == nil { + return value + } + return fmt.Sprintf("%#v", v) +} + +// If `return_documents` is set as `false` this will return none, if `true` it will return the documents passed in +type V2RerankResponseResultsItemDocument struct { + // The text of the document to rerank + Text string `json:"text" url:"text"` + + extraProperties map[string]interface{} + rawJSON json.RawMessage +} + +func (v *V2RerankResponseResultsItemDocument) GetText() string { + if v == nil { + return "" + } + return v.Text +} - v._rawJSON = json.RawMessage(data) +func (v *V2RerankResponseResultsItemDocument) GetExtraProperties() map[string]interface{} { + return v.extraProperties +} + +func (v *V2RerankResponseResultsItemDocument) UnmarshalJSON(data []byte) error { + type unmarshaler V2RerankResponseResultsItemDocument + var value unmarshaler + if err := json.Unmarshal(data, &value); err != nil { + return err + } + *v = V2RerankResponseResultsItemDocument(value) + extraProperties, err := internal.ExtractExtraProperties(data, *v) + if err != nil { + return err + } + v.extraProperties = extraProperties + v.rawJSON = json.RawMessage(data) return nil } -func (v *V2RerankResponse) String() string { - if len(v._rawJSON) > 0 { - if value, err := core.StringifyJSON(v._rawJSON); err == nil { +func (v *V2RerankResponseResultsItemDocument) String() string { + if len(v.rawJSON) > 0 { + if value, err := internal.StringifyJSON(v.rawJSON); err == nil { return value } } - if value, err := core.StringifyJSON(v); err == nil { + if value, err := internal.StringifyJSON(v); err == nil { return value } return fmt.Sprintf("%#v", v)