Skip to content

Commit

Permalink
estimate images. fix speech
Browse files Browse the repository at this point in the history
  • Loading branch information
sergei-bronnikov committed Aug 7, 2024
1 parent a2c7d57 commit 7c1196f
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 2 deletions.
52 changes: 52 additions & 0 deletions internal/message/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type estimator interface {
EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
EstimateEmbeddingsInputCost(model string, tks int) (float64, error)
EstimateChatCompletionPromptTokenCounts(model string, r *goopenai.ChatCompletionRequest) (int, error)
EstimateImagesCost(model, quality, resolution string) (float64, error)
}

type azureEstimator interface {
Expand Down Expand Up @@ -412,6 +413,57 @@ func (h *Handler) decorateEvent(m Message) error {
return errors.New("message data cannot be parsed as event with request and response")
}

if e.Event.Path == "/api/providers/openai/v1/images/generations" {
gir, ok := e.Request.(*goopenai.ImageRequest)
if !ok {
telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
h.log.Debug("event contains request that cannot be converted to openai image request", zap.Any("data", m.Data))
return errors.New("event request data cannot be parsed as openai image request")
}
if e.Event.Status == http.StatusOK {
cost, err := h.e.EstimateImagesCost(string(gir.Model), string(gir.Quality), string(gir.Size))
if err != nil {
telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1)
return err
}
e.Event.CostInUsd = cost
}
}

if e.Event.Path == "/api/providers/openai/v1/images/edits" {
eir, ok := e.Request.(*goopenai.ImageEditRequest)
if !ok {
telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
h.log.Debug("event contains request that cannot be converted to openai image edit request", zap.Any("data", m.Data))
return errors.New("event request data cannot be parsed as openai image edit request")
}
if e.Event.Status == http.StatusOK {
cost, err := h.e.EstimateImagesCost(string(eir.Model), "", string(eir.Size))
if err != nil {
telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1)
return err
}
e.Event.CostInUsd = cost
}
}

if e.Event.Path == "/api/providers/openai/v1/images/variations" {
vir, ok := e.Request.(*goopenai.ImageVariRequest)
if !ok {
telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
h.log.Debug("event contains request that cannot be converted to openai image variation request", zap.Any("data", m.Data))
return errors.New("event request data cannot be parsed as openai image variation request")
}
if e.Event.Status == http.StatusOK {
cost, err := h.e.EstimateImagesCost(string(vir.Model), "", string(vir.Size))
if err != nil {
telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1)
return err
}
e.Event.CostInUsd = cost
}
}

if e.Event.Path == "/api/providers/openai/v1/audio/speech" {
csr, ok := e.Request.(*goopenai.CreateSpeechRequest)
if !ok {
Expand Down
105 changes: 105 additions & 0 deletions internal/provider/openai/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"math"
"slices"
"strings"

"github.com/bricks-cloud/bricksllm/internal/util"
Expand Down Expand Up @@ -127,6 +128,18 @@ var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"finetune-babbage-002": 0.0016,
"finetune-davinci-002": 0.012,
},
"images": {
"dall-e-2": 0.02,
"dall-e-2-256": 0.016,
"dall-e-2-512": 0.018,
"dall-e-2-1024": 0.02,

"dall-e-3": 0.04,
"dall-e-3-1024-standart": 0.04,
"dall-e-3-1792-standart": 0.08,
"dall-e-3-1024-hd": 0.08,
"dall-e-3-1792-hd": 0.12,
},
}

type tokenCounter interface {
Expand Down Expand Up @@ -292,6 +305,98 @@ func (ce *CostEstimator) EstimateCompletionsStreamCostWithTokenCounts(model stri
return tks, cost, nil
}

func (ce *CostEstimator) EstimateImagesCost(model, quality, resolution string) (float64, error) {
simpleRes, err := convertResToSimple(resolution)
if err != nil {
return 0, err
}
var normalizedModel string
switch model {
case "dall-e-2":
normalizedModel, err = prepareDallE2Model(simpleRes, model)
if err != nil {
return 0, err
}
case "dall-e-3":
normalizedModel, err = prepareDallE3Model(quality, simpleRes, model)
if err != nil {
return 0, err
}
default:
return 0, errors.New("model is not present in the images cost map")
}

costMap, ok := ce.tokenCostMap["images"]
if !ok {
return 0, errors.New("images cost map is not provided")
}
cost, ok := costMap[normalizedModel]
if !ok {
return 0, errors.New("model is not present in the images cost map")
}
return cost, nil
}

var allowedDallE2Resolutions = []string{"256", "512", "1024"}
var allowedDallE3Resolutions = []string{"1024", "1792"}
var allowedDallE3Qualities = []string{"standart", "hd"}

func convertResToSimple(resolution string) (string, error) {
if resolution == "" {
return "", nil
}
if strings.Contains(resolution, "1792") {
return "1792", nil
}
if strings.Contains(resolution, "1024") {
return "1024", nil
}
if strings.Contains(resolution, "512") {
return "512", nil
}
if strings.Contains(resolution, "256") {
return "256", nil
}
return "", errors.New("resolution is not valid")
}

func prepareDallE2Model(resolution, model string) (string, error) {
if resolution == "" {
return model, nil
}
if slices.Contains(allowedDallE2Resolutions, resolution) {
return fmt.Sprintf("%s-%s", model, resolution), nil
}
return "", errors.New("resolution is not valid")
}

func prepareDallE3Model(quality, resolution, model string) (string, error) {
preparedQuality, err := prepareDallE3Quality(quality)
if err != nil {
return "", err
}
if resolution == "" && quality == "" {
return model, nil
}
if resolution == "" {
return fmt.Sprintf("%s-%s-%s", model, "1024", preparedQuality), nil
}
if slices.Contains(allowedDallE3Resolutions, resolution) {
return fmt.Sprintf("%s-%s-%s", model, resolution, preparedQuality), nil
}
return "", errors.New("resolution is not valid")
}

func prepareDallE3Quality(quality string) (string, error) {
if quality != "" && !slices.Contains(allowedDallE3Qualities, quality) {
return "", errors.New("quality is not valid")
}
if quality == "" {
return "standart", nil
}
return quality, nil
}

func (ce *CostEstimator) EstimateTranscriptionCost(secs float64, model string) (float64, error) {
costMap, ok := ce.tokenCostMap["audio"]
if !ok {
Expand Down
17 changes: 17 additions & 0 deletions internal/server/web/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
logError(logWithCid, "error when unmarshalling create image request", prod, err)
return
}
enrichedEvent.Request = ir

c.Set("model", ir.Model)

Expand All @@ -759,6 +760,14 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
}

if c.FullPath() == "/api/providers/openai/v1/images/edits" && c.Request.Method == http.MethodPost {
ier := &goopenai.ImageEditRequest{}
err := json.Unmarshal(body, ier)
if err != nil {
logError(logWithCid, "error when unmarshalling edit image request", prod, err)
return
}
enrichedEvent.Request = ier

prompt := c.PostForm("model")
model := c.PostForm("model")
size := c.PostForm("size")
Expand All @@ -779,6 +788,14 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
}

if c.FullPath() == "/api/providers/openai/v1/images/variations" && c.Request.Method == http.MethodPost {
ir := &goopenai.ImageVariRequest{}
err := json.Unmarshal(body, ir)
if err != nil {
logError(logWithCid, "error when unmarshalling image variations request", prod, err)
return
}
enrichedEvent.Request = ir

model := c.PostForm("model")
size := c.PostForm("size")
user := c.PostForm("user")
Expand Down
4 changes: 2 additions & 2 deletions internal/storage/postgresql/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -946,8 +946,8 @@ func (s *Store) GetEventsV2(req *event.EventRequest) (*event.EventResponse, erro
}

func isJSON(str string) bool {
var js json.RawMessage
return json.Unmarshal([]byte(str), &js) == nil
var js json.RawMessage
return json.Unmarshal([]byte(str), &js) == nil
}

func (s *Store) InsertEvent(e *event.Event) error {
Expand Down

0 comments on commit 7c1196f

Please sign in to comment.