From 7001d142131b1fe82a94435e7b720be56e0b16ff Mon Sep 17 00:00:00 2001 From: Spike Lu Date: Tue, 6 Feb 2024 10:48:52 -0800 Subject: [PATCH] integrate event driven architecture --- cmd/bricksllm/main.go | 26 +- cmd/tool/main.go | 269 ------------------ internal/server/web/proxy/anthropic.go | 97 +++---- .../server/web/proxy/azure_chat_completion.go | 56 ++-- internal/server/web/proxy/azure_embedding.go | 27 +- internal/server/web/proxy/custom_provider.go | 26 +- internal/server/web/proxy/middleware.go | 200 ++++++------- internal/server/web/proxy/proxy.go | 86 +++--- internal/server/web/proxy/route.go | 20 +- 9 files changed, 286 insertions(+), 521 deletions(-) delete mode 100644 cmd/tool/main.go diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go index c8dba0d..69d3354 100644 --- a/cmd/bricksllm/main.go +++ b/cmd/bricksllm/main.go @@ -14,6 +14,7 @@ import ( "github.com/bricks-cloud/bricksllm/internal/config" "github.com/bricks-cloud/bricksllm/internal/logger/zap" "github.com/bricks-cloud/bricksllm/internal/manager" + "github.com/bricks-cloud/bricksllm/internal/message" "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" "github.com/bricks-cloud/bricksllm/internal/provider/azure" "github.com/bricks-cloud/bricksllm/internal/provider/custom" @@ -171,10 +172,23 @@ func main() { log.Sugar().Fatalf("error connecting to api redis cache: %v", err) } + accessRedisCache := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), + Password: cfg.RedisPassword, + DB: 4, + }) + + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := apiRedisCache.Ping(ctx).Err(); err != nil { + log.Sugar().Fatalf("error connecting to api redis cache: %v", err) + } + rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) + accessCache := redisStorage.NewAccessCache(accessRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) m := manager.NewManager(store) krm := manager.NewReportingManager(costStorage, store, store) @@ -209,7 +223,16 @@ func main() { c := cache.NewCache(apiCache) - ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, rlm, cfg.ProxyTimeout) + messageBus := message.NewMessageBus() + eventMessageChan := make(chan message.Message) + messageBus.Subscribe("event", eventMessageChan) + + handler := message.NewHandler(rec, log, ace, ce, aoe, v, m, rlm, accessCache) + + eventConsumer := message.NewConsumer(eventMessageChan, log, 4, handler.HandleEventWithRequestAndResponse) + eventConsumer.StartEventMessageConsumers() + + ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache) if err != nil { log.Sugar().Fatalf("error creating proxy http server: %v", err) } @@ -220,6 +243,7 @@ func main() { signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit + eventConsumer.Stop() memStore.Stop() psMemStore.Stop() cpMemStore.Stop() diff --git a/cmd/tool/main.go b/cmd/tool/main.go deleted file mode 100644 index 11158b8..0000000 --- a/cmd/tool/main.go +++ /dev/null @@ -1,269 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "os" - "os/signal" - "syscall" - "time" - - auth "github.com/bricks-cloud/bricksllm/internal/authenticator" - "github.com/bricks-cloud/bricksllm/internal/cache" - "github.com/bricks-cloud/bricksllm/internal/config" - logger "github.com/bricks-cloud/bricksllm/internal/logger/zap" - "github.com/bricks-cloud/bricksllm/internal/manager" - "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" - "github.com/bricks-cloud/bricksllm/internal/provider/azure" - "github.com/bricks-cloud/bricksllm/internal/provider/custom" - "github.com/bricks-cloud/bricksllm/internal/provider/openai" - "github.com/bricks-cloud/bricksllm/internal/recorder" - "github.com/bricks-cloud/bricksllm/internal/server/web/admin" - "github.com/bricks-cloud/bricksllm/internal/server/web/proxy" - "github.com/bricks-cloud/bricksllm/internal/stats" - "github.com/bricks-cloud/bricksllm/internal/storage/memdb" - "github.com/bricks-cloud/bricksllm/internal/storage/postgresql" - redisStorage "github.com/bricks-cloud/bricksllm/internal/storage/redis" - "github.com/bricks-cloud/bricksllm/internal/validator" - "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" -) - -func main() { - modePtr := flag.String("m", "dev", "select the mode that bricksllm runs in") - privacyPtr := flag.String("p", "strict", "select the privacy mode that bricksllm runs in") - flag.Parse() - - log := logger.NewZapLogger(*modePtr) - - gin.SetMode(gin.ReleaseMode) - - cfg, err := config.ParseEnvVariables() - if err != nil { - log.Sugar().Fatalf("cannot parse environment variables: %v", err) - } - - err = stats.InitializeClient(cfg.StatsProvider) - if err != nil { - log.Sugar().Fatalf("cannot connect to telemetry provider: %v", err) - } - - store, err := postgresql.NewStore( - fmt.Sprintf("postgresql:///%s?sslmode=%s&user=%s&password=%s&host=%s&port=%s", cfg.PostgresqlDbName, cfg.PostgresqlSslMode, cfg.PostgresqlUsername, cfg.PostgresqlPassword, cfg.PostgresqlHosts, cfg.PostgresqlPort), - cfg.PostgresqlWriteTimeout, - cfg.PostgresqlReadTimeout, - ) - - if err != nil { - log.Sugar().Fatalf("cannot connect to postgresql: %v", err) - } - - err = store.CreateCustomProvidersTable() - if err != nil { - log.Sugar().Fatalf("error creating custom providers table: %v", err) - } - - err = store.CreateRoutesTable() - if err != nil { - log.Sugar().Fatalf("error creating routes table: %v", err) - } - - err = store.CreateKeysTable() - if err != nil { - log.Sugar().Fatalf("error creating keys table: %v", err) - } - - err = store.AlterKeysTable() - if err != nil { - log.Sugar().Fatalf("error altering keys table: %v", err) - } - - err = store.CreateEventsTable() - if err != nil { - log.Sugar().Fatalf("error creating events table: %v", err) - } - - err = store.AlterEventsTable() - if err != nil { - log.Sugar().Fatalf("error altering events table: %v", err) - } - - err = store.CreateProviderSettingsTable() - if err != nil { - log.Sugar().Fatalf("error creating provider settings table: %v", err) - } - - err = store.AlterProviderSettingsTable() - if err != nil { - log.Sugar().Fatalf("error altering provider settings table: %v", err) - } - - memStore, err := memdb.NewMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize memdb: %v", err) - } - memStore.Listen() - - psMemStore, err := memdb.NewProviderSettingsMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize provider settings memdb: %v", err) - } - psMemStore.Listen() - - cpMemStore, err := memdb.NewCustomProvidersMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize custom providers memdb: %v", err) - } - cpMemStore.Listen() - - rMemStore, err := memdb.NewRoutesMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize routes memdb: %v", err) - } - rMemStore.Listen() - - rateLimitRedisCache := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 0, - }) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := rateLimitRedisCache.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to rate limit redis cache: %v", err) - } - - costLimitRedisCache := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 1, - }) - - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := costLimitRedisCache.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to cost limit redis cache: %v", err) - } - - costRedisStorage := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 2, - }) - - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := costRedisStorage.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to cost limit redis storage: %v", err) - } - - apiRedisCache := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 3, - }) - - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := apiRedisCache.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to api redis cache: %v", err) - } - - rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - - m := manager.NewManager(store) - krm := manager.NewReportingManager(costStorage, store, store) - psm := manager.NewProviderSettingsManager(store, psMemStore) - cpm := manager.NewCustomProvidersManager(store, cpMemStore) - rm := manager.NewRouteManager(store, store, rMemStore, psMemStore) - - as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm, cfg.AdminPass) - if err != nil { - log.Sugar().Fatalf("error creating admin http server: %v", err) - } - as.Run() - - tc := openai.NewTokenCounter() - custom.NewTokenCounter() - atc, err := anthropic.NewTokenCounter() - if err != nil { - log.Sugar().Fatalf("error creating anthropic token counter: %v", err) - } - - ae := anthropic.NewCostEstimator(atc) - - ce := openai.NewCostEstimator(openai.OpenAiPerThousandTokenCost, tc) - v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage) - rec := recorder.NewRecorder(costStorage, costLimitCache, ce, store) - rlm := manager.NewRateLimitManager(rateLimitCache) - a := auth.NewAuthenticator(psm, memStore, rm) - - c := cache.NewCache(apiCache) - - aoe := azure.NewCostEstimator() - - ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ae, aoe, v, rec, rlm, cfg.ProxyTimeout) - if err != nil { - log.Sugar().Fatalf("error creating proxy http server: %v", err) - } - - ps.Run() - - quit := make(chan os.Signal) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - - memStore.Stop() - psMemStore.Stop() - cpMemStore.Stop() - - log.Sugar().Info("shutting down server...") - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := as.Shutdown(ctx); err != nil { - log.Sugar().Debugf("admin server shutdown: %v", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := ps.Shutdown(ctx); err != nil { - log.Sugar().Debugf("proxy server shutdown: %v", err) - } - - select { - case <-ctx.Done(): - log.Sugar().Infof("timeout of 5 seconds") - } - - err = store.DropKeysTable() - if err != nil { - log.Sugar().Fatalf("error dropping keys table: %v", err) - } - - err = store.DropEventsTable() - if err != nil { - log.Sugar().Fatalf("error dropping events table: %v", err) - } - - err = store.DropCustomProvidersTable() - if err != nil { - log.Sugar().Fatalf("error dropping custom providers table: %v", err) - } - - err = store.DropProviderSettingsTable() - if err != nil { - log.Sugar().Fatalf("error dropping provider settings table: %v", err) - } - - err = store.DropRoutesTable() - if err != nil { - log.Sugar().Fatalf("error dropping routes table: %v", err) - } - - log.Sugar().Infof("server exited") -} diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go index a9aab8a..7f6a607 100644 --- a/internal/server/web/proxy/anthropic.go +++ b/internal/server/web/proxy/anthropic.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/bricks-cloud/bricksllm/internal/key" "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" "github.com/bricks-cloud/bricksllm/internal/stats" "github.com/gin-gonic/gin" @@ -61,13 +60,13 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km return } - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } copyHttpHeaders(c.Request, req) @@ -96,7 +95,7 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km } } - model := c.GetString("model") + // model := c.GetString("model") if !isStreaming && res.StatusCode == http.StatusOK { dur := time.Now().Sub(start) @@ -109,8 +108,8 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km return } - var cost float64 = 0 - completionTokens := 0 + // var cost float64 = 0 + // completionTokens := 0 completionRes := &anthropic.CompletionResponse{} stats.Incr("bricksllm.proxy.get_completion_handler.success", nil, 1) stats.Timing("bricksllm.proxy.get_completion_handler.success_latency", dur, nil, 1) @@ -120,27 +119,29 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km logError(log, "error when unmarshalling anthropic http completion response body", prod, cid, err) } - if err == nil { - logCompletionResponse(log, bytes, prod, private, cid) - completionTokens = e.Count(completionRes.Completion) - completionTokens += anthropicCompletionMagicNum - promptTokens := c.GetInt("promptTokenCount") - cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens) - if err != nil { - stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1) - logError(log, "error when estimating anthropic cost", prod, cid, err) - } - - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording anthropic spend", prod, cid, err) - } - } - - c.Set("costInUsd", cost) - c.Set("completionTokenCount", completionTokens) + c.Set("content", completionRes.Completion) + + // if err == nil { + // logCompletionResponse(log, bytes, prod, private, cid) + // completionTokens = e.Count(completionRes.Completion) + // completionTokens += anthropicCompletionMagicNum + // promptTokens := c.GetInt("promptTokenCount") + // cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1) + // logError(log, "error when estimating anthropic cost", prod, cid, err) + // } + + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording anthropic spend", prod, cid, err) + // } + // } + + // c.Set("costInUsd", cost) + // c.Set("completionTokenCount", completionTokens) c.Data(res.StatusCode, "application/json", bytes) return @@ -163,24 +164,24 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km } buffer := bufio.NewReader(res.Body) - var totalCost float64 = 0 + // var totalCost float64 = 0 content := "" - defer func() { - tks := e.Count(content) - model := c.GetString("model") - cost, err := e.EstimateCompletionCost(model, tks) - if err != nil { - stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1) - logError(log, "error when estimating anthropic completion stream cost", prod, cid, err) - } - - estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") - totalCost = cost + estimatedPromptCost - - c.Set("costInUsd", totalCost) - c.Set("completionTokenCount", tks+anthropicCompletionMagicNum) - }() + // defer func() { + // tks := e.Count(content) + // model := c.GetString("model") + // cost, err := e.EstimateCompletionCost(model, tks) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1) + // logError(log, "error when estimating anthropic completion stream cost", prod, cid, err) + // } + + // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") + // totalCost = cost + estimatedPromptCost + + // c.Set("costInUsd", totalCost) + // c.Set("completionTokenCount", tks+anthropicCompletionMagicNum) + // }() stats.Incr("bricksllm.proxy.get_completion_handler.streaming_requests", nil, 1) diff --git a/internal/server/web/proxy/azure_chat_completion.go b/internal/server/web/proxy/azure_chat_completion.go index 970237a..2807071 100644 --- a/internal/server/web/proxy/azure_chat_completion.go +++ b/internal/server/web/proxy/azure_chat_completion.go @@ -11,7 +11,6 @@ import ( "net/http" "time" - "github.com/bricks-cloud/bricksllm/internal/key" "github.com/bricks-cloud/bricksllm/internal/stats" "github.com/gin-gonic/gin" goopenai "github.com/sashabaranov/go-openai" @@ -36,13 +35,6 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS } cid := c.GetString(correlationId) - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } ctx, cancel := context.WithTimeout(context.Background(), timeOut) defer cancel() @@ -111,12 +103,12 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS logError(log, "error when estimating azure openai cost", prod, cid, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording azure openai spend", prod, cid, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording azure openai spend", prod, cid, err) + // } } c.Set("costInUsd", cost) @@ -145,8 +137,8 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS } buffer := bufio.NewReader(res.Body) - var totalCost float64 = 0 - var totalTokens int = 0 + // var totalCost float64 = 0 + // var totalTokens int = 0 content := "" model := "" @@ -155,24 +147,26 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS c.Set("model", model) } - tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) - logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) - } + c.Set("content", content) - estimatedPromptTokenCounts := c.GetInt("promptTokenCount") - promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) - logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) - } + // tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) + // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) + // } + + // estimatedPromptTokenCounts := c.GetInt("promptTokenCount") + // promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) + // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) + // } - totalCost = cost + promptCost - totalTokens += tks + // totalCost = cost + promptCost + // totalTokens += tks - c.Set("costInUsd", totalCost) - c.Set("completionTokenCount", totalTokens) + // c.Set("costInUsd", totalCost) + // c.Set("completionTokenCount", totalTokens) }() stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.streaming_requests", nil, 1) diff --git a/internal/server/web/proxy/azure_embedding.go b/internal/server/web/proxy/azure_embedding.go index c3c2830..f9ac0de 100644 --- a/internal/server/web/proxy/azure_embedding.go +++ b/internal/server/web/proxy/azure_embedding.go @@ -7,7 +7,6 @@ import ( "net/http" "time" - "github.com/bricks-cloud/bricksllm/internal/key" "github.com/bricks-cloud/bricksllm/internal/stats" "github.com/gin-gonic/gin" goopenai "github.com/sashabaranov/go-openai" @@ -23,13 +22,13 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti } cid := c.GetString(correlationId) - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } ctx, cancel := context.WithTimeout(context.Background(), timeOut) defer cancel() @@ -111,12 +110,12 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti logError(log, "error when estimating azure openai cost for embedding", prod, cid, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording azure openai spend for embedding", prod, cid, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording azure openai spend for embedding", prod, cid, err) + // } } } diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go index bbc8291..1f9007f 100644 --- a/internal/server/web/proxy/custom_provider.go +++ b/internal/server/web/proxy/custom_provider.go @@ -120,12 +120,14 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c return } - tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation) - if err != nil { - logError(log, "error when counting tokens for custom provider completion response", prod, cid, err) - } + c.Set("response", bytes) + + // tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation) + // if err != nil { + // logError(log, "error when counting tokens for custom provider completion response", prod, cid, err) + // } - c.Set("completionTokenCount", tks) + // c.Set("completionTokenCount", tks) c.Data(res.StatusCode, "application/json", bytes) return } @@ -149,13 +151,15 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c buffer := bufio.NewReader(res.Body) aggregated := "" defer func() { - tks, err := custom.Count(aggregated) - if err != nil { - stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1) - logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err) - } + c.Set("content", aggregated) + + // tks, err := custom.Count(aggregated) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1) + // logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err) + // } - c.Set("completionTokenCount", tks) + // c.Set("completionTokenCount", tks) }() stats.Incr("bricksllm.proxy.get_custom_provider_handler.streaming_requests", nil, 1) diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go index fa5916d..e8b268d 100644 --- a/internal/server/web/proxy/middleware.go +++ b/internal/server/web/proxy/middleware.go @@ -11,6 +11,7 @@ import ( "github.com/bricks-cloud/bricksllm/internal/event" "github.com/bricks-cloud/bricksllm/internal/key" + "github.com/bricks-cloud/bricksllm/internal/message" "github.com/bricks-cloud/bricksllm/internal/provider" "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" "github.com/bricks-cloud/bricksllm/internal/route" @@ -72,6 +73,10 @@ type rateLimitManager interface { Increment(keyId string, timeUnit key.TimeUnit) error } +type accessCache interface { + GetAccessStatus(key string) bool +} + type encrypter interface { Encrypt(secret string) string } @@ -93,7 +98,40 @@ type notFoundError interface { NotFound() } -func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, r recorder, prefix string) gin.HandlerFunc { +type publisher interface { + Publish(message.Message) +} + +func getProvider(c *gin.Context) string { + existing := c.GetString("provider") + if len(existing) != 0 { + return existing + } + + parts := strings.Split(c.FullPath(), "/") + + spaceRemoved := []string{} + + for _, part := range parts { + if len(part) != 0 { + spaceRemoved = append(spaceRemoved, part) + } + } + + if strings.HasPrefix(c.FullPath(), "/api/providers/") { + if len(spaceRemoved) >= 3 { + return spaceRemoved[2] + } + } + + if strings.HasPrefix(c.FullPath(), "/api/custom/providers/") { + return c.Param("provider") + } + + return "" +} + +func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, pub publisher, prefix string, ac accessCache) gin.HandlerFunc { return func(c *gin.Context) { if c == nil || c.Request == nil { JSON(c, http.StatusInternalServerError, "[BricksLLM] request is empty") @@ -110,17 +148,12 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage c.Set(correlationId, cid) start := time.Now() - selectedProvider := "openai" + enrichedEvent := &event.EventWithRequestAndContent{} customId := c.Request.Header.Get("X-CUSTOM-EVENT-ID") defer func() { dur := time.Now().Sub(start) latency := int(dur.Milliseconds()) - raw, exists := c.Get("key") - var kc *key.ResponseKey - if exists { - kc = raw.(*key.ResponseKey) - } if !prod { log.Sugar().Infof("%s | %d | %s | %s | %dms", prefix, c.Writer.Status(), c.Request.Method, c.FullPath(), latency) @@ -129,16 +162,14 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage keyId := "" tags := []string{} - if kc != nil { - keyId = kc.KeyId - tags = kc.Tags + if enrichedEvent.Key != nil { + keyId = enrichedEvent.Key.KeyId + tags = enrichedEvent.Key.Tags } stats.Timing("bricksllm.proxy.get_middleware.proxy_latency_in_ms", dur, nil, 1) - if len(c.GetString("provider")) != 0 { - selectedProvider = c.GetString("provider") - } + selectedProvider := getProvider(c) if prod { log.Info("response to proxy", @@ -173,12 +204,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage CustomId: customId, } - err := r.RecordEvent(evt) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.record_event_error", nil, 1) - - logError(log, "error when recording openai event", prod, cid, err) + enrichedEvent.Event = evt + content := c.GetString("content") + if len(content) != 0 { + enrichedEvent.Content = content } + + pub.Publish(message.Message{ + Type: "event", + Data: enrichedEvent, + }) }() if len(c.FullPath()) == 0 { @@ -189,6 +224,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage } kc, settings, err := a.AuthenticateHttpRequest(c.Request) + enrichedEvent.Key = kc _, ok := err.(notAuthorizedError) if ok { stats.Incr("bricksllm.proxy.get_middleware.authentication_error", nil, 1) @@ -236,12 +272,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage c.Request.Body = io.NopCloser(bytes.NewReader(body)) } - var cost float64 = 0 + // var cost float64 = 0 if c.FullPath() == "/api/providers/anthropic/v1/complete" { logCompletionRequest(log, body, prod, private, cid) - selectedProvider = "anthropic" cr := &anthropic.CompletionRequest{} err = json.Unmarshal(body, cr) if err != nil { @@ -249,19 +284,21 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } - tks := ae.Count(cr.Prompt) - tks += anthropicPromptMagicNum - c.Set("promptTokenCount", tks) + enrichedEvent.Request = cr - model := cr.Model - cost, err = ae.EstimatePromptCost(model, tks) - if err != nil { - logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err) - } + // tks := ae.Count(cr.Prompt) + // tks += anthropicPromptMagicNum + // c.Set("promptTokenCount", tks) + + // model := cr.Model + // cost, err = ae.EstimatePromptCost(model, tks) + // if err != nil { + // logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err) + // } if cr.Stream { c.Set("stream", cr.Stream) - c.Set("estimatedPromptCostInUsd", cost) + // c.Set("estimatedPromptCostInUsd", cost) } if len(cr.Model) != 0 { @@ -288,17 +325,22 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } - selectedProvider = cp.Provider - c.Set("provider", cp) c.Set("route_config", rc) - tks, err := countTokensFromJson(body, rc.RequestPromptLocation) - if err != nil { - logError(log, "error when counting tokens for custom provider request", prod, cid, err) + enrichedEvent.Request = body + + customResponse, ok := c.Get("response") + if ok { + enrichedEvent.Response = customResponse } - c.Set("promptTokenCount", tks) + // tks, err := countTokensFromJson(body, rc.RequestPromptLocation) + // if err != nil { + // logError(log, "error when counting tokens for custom provider request", prod, cid, err) + // } + + // c.Set("promptTokenCount", tks) result := gjson.Get(string(body), rc.StreamLocation) @@ -344,12 +386,15 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage if !rc.ShouldRunEmbeddings() { ccr := &goopenai.ChatCompletionRequest{} + err = json.Unmarshal(body, ccr) if err != nil { logError(log, "error when unmarshalling route chat completion request", prod, cid, err) return } + enrichedEvent.Request = ccr + logRequest(log, prod, private, cid, ccr) if ccr.Stream { @@ -366,8 +411,6 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage } if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" { - selectedProvider = "azure" - ccr := &goopenai.ChatCompletionRequest{} err = json.Unmarshal(body, ccr) if err != nil { @@ -375,23 +418,23 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } + enrichedEvent.Request = ccr + logRequest(log, prod, private, cid, ccr) - tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1) - logError(log, "error when estimating prompt cost", prod, cid, err) - } + // tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1) + // logError(log, "error when estimating prompt cost", prod, cid, err) + // } if ccr.Stream { c.Set("stream", true) - c.Set("promptTokenCount", tks) + // c.Set("promptTokenCount", tks) } } if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/embeddings" { - selectedProvider = "azure" - er := &goopenai.EmbeddingRequest{} err = json.Unmarshal(body, er) if err != nil { @@ -404,11 +447,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage logEmbeddingRequest(log, prod, private, cid, er) - cost, err = aoe.EstimateEmbeddingsCost(er) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1) - logError(log, "error when estimating azure openai embeddings cost", prod, cid, err) - } + // cost, err = aoe.EstimateEmbeddingsCost(er) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1) + // logError(log, "error when estimating azure openai embeddings cost", prod, cid, err) + // } } if c.FullPath() == "/api/providers/openai/v1/chat/completions" { @@ -419,6 +462,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } + enrichedEvent.Request = ccr + c.Set("model", ccr.Model) logRequest(log, prod, private, cid, ccr) @@ -445,16 +490,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } - c.Set("model", er.Model.String()) + c.Set("model", string(er.Model)) c.Set("encoding_format", string(er.EncodingFormat)) logEmbeddingRequest(log, prod, private, cid, er) - cost, err = e.EstimateEmbeddingsCost(er) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1) - logError(log, "error when estimating embeddings cost", prod, cid, err) - } + // cost, err = e.EstimateEmbeddingsCost(er) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1) + // logError(log, "error when estimating embeddings cost", prod, cid, err) + // } } if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost { @@ -735,48 +780,13 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage logRetrieveFileContentRequest(log, body, prod, cid, fid) } - err = v.Validate(kc, cost) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.validation_error", nil, 1) - - if _, ok := err.(expirationError); ok { - stats.Incr("bricksllm.proxy.get_middleware.key_expired", nil, 1) - - truePtr := true - _, err = ks.UpdateKey(kc.KeyId, &key.UpdateKey{ - Revoked: &truePtr, - RevokedReason: "Key has expired or exceeded set spend limit", - }) - - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.update_key_error", nil, 1) - log.Sugar().Debugf("error when updating revoking the api key %s: %v", kc.KeyId, err) - } - - JSON(c, http.StatusUnauthorized, "[BricksLLM] key has expired") - c.Abort() - return - } - - if _, ok := err.(rateLimitError); ok { - stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1) - JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests") - c.Abort() - return - } - - logError(log, "error when validating api key", prod, cid, err) + if ac.GetAccessStatus(kc.KeyId) { + stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1) + JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests") + c.Abort() return } - if len(kc.RateLimitUnit) != 0 { - if err := rlm.Increment(kc.KeyId, kc.RateLimitUnit); err != nil { - stats.Incr("bricksllm.proxy.get_middleware.rate_limit_increment_error", nil, 1) - - logError(log, "error when incrementing rate limit counter", prod, cid, err) - } - } - c.Next() } } diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go index dbcb5eb..221dac8 100644 --- a/internal/server/web/proxy/proxy.go +++ b/internal/server/web/proxy/proxy.go @@ -39,7 +39,7 @@ type ProxyServer struct { } type recorder interface { - RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error + // RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error RecordEvent(e *event.Event) error } @@ -55,12 +55,12 @@ type CustomProvidersManager interface { GetCustomProviderFromMem(name string) *custom.Provider } -func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, rlm rateLimitManager, timeOut time.Duration) (*ProxyServer, error) { +func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeOut time.Duration, ac accessCache) (*ProxyServer, error) { router := gin.New() prod := mode == "production" private := privacyMode == "strict" - router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, r, "proxy")) + router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, pub, "proxy", ac)) client := http.Client{} @@ -942,13 +942,13 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan return } - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } id := c.GetString(correlationId) @@ -1032,12 +1032,12 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan logError(log, "error when estimating openai cost for embedding", prod, id, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording openai spend for embedding", prod, id, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording openai spend for embedding", prod, id, err) + // } } } @@ -1085,13 +1085,13 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin } cid := c.GetString(correlationId) - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } ctx, cancel := context.WithTimeout(context.Background(), timeOut) defer cancel() @@ -1161,12 +1161,12 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin logError(log, "error when estimating openai cost", prod, cid, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording openai spend", prod, cid, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording openai spend", prod, cid, err) + // } } c.Set("costInUsd", cost) @@ -1195,22 +1195,24 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin } buffer := bufio.NewReader(res.Body) - var totalCost float64 = 0 - var totalTokens int = 0 + // var totalCost float64 = 0 + // var totalTokens int = 0 content := "" defer func() { - tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content) - if err != nil { - stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) - logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err) - } + c.Set("content", content) + + // tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) + // logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err) + // } - estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") - totalCost = cost + estimatedPromptCost - totalTokens += tks + // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") + // totalCost = cost + estimatedPromptCost + // totalTokens += tks - c.Set("costInUsd", totalCost) - c.Set("completionTokenCount", totalTokens) + // c.Set("costInUsd", totalCost) + // c.Set("completionTokenCount", totalTokens) }() stats.Incr("bricksllm.proxy.get_chat_completion_handler.streaming_requests", nil, 1) @@ -1522,7 +1524,7 @@ func logEmbeddingRequest(log *zap.Logger, prod, private bool, id string, r *goop if prod { fields := []zapcore.Field{ zap.String(correlationId, id), - zap.String("model", r.Model.String()), + zap.String("model", string(r.Model)), zap.String("encoding_format", string(r.EncodingFormat)), zap.String("user", r.User), } diff --git a/internal/server/web/proxy/route.go b/internal/server/web/proxy/route.go index 13357a9..90c46af 100644 --- a/internal/server/web/proxy/route.go +++ b/internal/server/web/proxy/route.go @@ -231,12 +231,12 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo cost = ecost } - micros := int64(cost * 1000000) + // micros := int64(cost * 1000000) - err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - return err - } + // err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // return err + // } } if !runEmbeddings { @@ -262,11 +262,11 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo } } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - return err - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // return err + // } } return nil