diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go index c8dba0d..69d3354 100644 --- a/cmd/bricksllm/main.go +++ b/cmd/bricksllm/main.go @@ -14,6 +14,7 @@ import ( "github.com/bricks-cloud/bricksllm/internal/config" "github.com/bricks-cloud/bricksllm/internal/logger/zap" "github.com/bricks-cloud/bricksllm/internal/manager" + "github.com/bricks-cloud/bricksllm/internal/message" "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" "github.com/bricks-cloud/bricksllm/internal/provider/azure" "github.com/bricks-cloud/bricksllm/internal/provider/custom" @@ -171,10 +172,23 @@ func main() { log.Sugar().Fatalf("error connecting to api redis cache: %v", err) } + accessRedisCache := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), + Password: cfg.RedisPassword, + DB: 4, + }) + + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := apiRedisCache.Ping(ctx).Err(); err != nil { + log.Sugar().Fatalf("error connecting to api redis cache: %v", err) + } + rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) + accessCache := redisStorage.NewAccessCache(accessRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) m := manager.NewManager(store) krm := manager.NewReportingManager(costStorage, store, store) @@ -209,7 +223,16 @@ func main() { c := cache.NewCache(apiCache) - ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, rlm, cfg.ProxyTimeout) + messageBus := message.NewMessageBus() + eventMessageChan := make(chan message.Message) + messageBus.Subscribe("event", eventMessageChan) + + handler := message.NewHandler(rec, log, ace, ce, aoe, v, m, rlm, accessCache) + + eventConsumer := message.NewConsumer(eventMessageChan, log, 4, handler.HandleEventWithRequestAndResponse) + eventConsumer.StartEventMessageConsumers() + + ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache) if err != nil { log.Sugar().Fatalf("error creating proxy http server: %v", err) } @@ -220,6 +243,7 @@ func main() { signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit + eventConsumer.Stop() memStore.Stop() psMemStore.Stop() cpMemStore.Stop() diff --git a/cmd/tool/main.go b/cmd/tool/main.go deleted file mode 100644 index 11158b8..0000000 --- a/cmd/tool/main.go +++ /dev/null @@ -1,269 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "os" - "os/signal" - "syscall" - "time" - - auth "github.com/bricks-cloud/bricksllm/internal/authenticator" - "github.com/bricks-cloud/bricksllm/internal/cache" - "github.com/bricks-cloud/bricksllm/internal/config" - logger "github.com/bricks-cloud/bricksllm/internal/logger/zap" - "github.com/bricks-cloud/bricksllm/internal/manager" - "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" - "github.com/bricks-cloud/bricksllm/internal/provider/azure" - "github.com/bricks-cloud/bricksllm/internal/provider/custom" - "github.com/bricks-cloud/bricksllm/internal/provider/openai" - "github.com/bricks-cloud/bricksllm/internal/recorder" - "github.com/bricks-cloud/bricksllm/internal/server/web/admin" - "github.com/bricks-cloud/bricksllm/internal/server/web/proxy" - "github.com/bricks-cloud/bricksllm/internal/stats" - "github.com/bricks-cloud/bricksllm/internal/storage/memdb" - "github.com/bricks-cloud/bricksllm/internal/storage/postgresql" - redisStorage "github.com/bricks-cloud/bricksllm/internal/storage/redis" - "github.com/bricks-cloud/bricksllm/internal/validator" - "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" -) - -func main() { - modePtr := flag.String("m", "dev", "select the mode that bricksllm runs in") - privacyPtr := flag.String("p", "strict", "select the privacy mode that bricksllm runs in") - flag.Parse() - - log := logger.NewZapLogger(*modePtr) - - gin.SetMode(gin.ReleaseMode) - - cfg, err := config.ParseEnvVariables() - if err != nil { - log.Sugar().Fatalf("cannot parse environment variables: %v", err) - } - - err = stats.InitializeClient(cfg.StatsProvider) - if err != nil { - log.Sugar().Fatalf("cannot connect to telemetry provider: %v", err) - } - - store, err := postgresql.NewStore( - fmt.Sprintf("postgresql:///%s?sslmode=%s&user=%s&password=%s&host=%s&port=%s", cfg.PostgresqlDbName, cfg.PostgresqlSslMode, cfg.PostgresqlUsername, cfg.PostgresqlPassword, cfg.PostgresqlHosts, cfg.PostgresqlPort), - cfg.PostgresqlWriteTimeout, - cfg.PostgresqlReadTimeout, - ) - - if err != nil { - log.Sugar().Fatalf("cannot connect to postgresql: %v", err) - } - - err = store.CreateCustomProvidersTable() - if err != nil { - log.Sugar().Fatalf("error creating custom providers table: %v", err) - } - - err = store.CreateRoutesTable() - if err != nil { - log.Sugar().Fatalf("error creating routes table: %v", err) - } - - err = store.CreateKeysTable() - if err != nil { - log.Sugar().Fatalf("error creating keys table: %v", err) - } - - err = store.AlterKeysTable() - if err != nil { - log.Sugar().Fatalf("error altering keys table: %v", err) - } - - err = store.CreateEventsTable() - if err != nil { - log.Sugar().Fatalf("error creating events table: %v", err) - } - - err = store.AlterEventsTable() - if err != nil { - log.Sugar().Fatalf("error altering events table: %v", err) - } - - err = store.CreateProviderSettingsTable() - if err != nil { - log.Sugar().Fatalf("error creating provider settings table: %v", err) - } - - err = store.AlterProviderSettingsTable() - if err != nil { - log.Sugar().Fatalf("error altering provider settings table: %v", err) - } - - memStore, err := memdb.NewMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize memdb: %v", err) - } - memStore.Listen() - - psMemStore, err := memdb.NewProviderSettingsMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize provider settings memdb: %v", err) - } - psMemStore.Listen() - - cpMemStore, err := memdb.NewCustomProvidersMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize custom providers memdb: %v", err) - } - cpMemStore.Listen() - - rMemStore, err := memdb.NewRoutesMemDb(store, log, cfg.InMemoryDbUpdateInterval) - if err != nil { - log.Sugar().Fatalf("cannot initialize routes memdb: %v", err) - } - rMemStore.Listen() - - rateLimitRedisCache := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 0, - }) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := rateLimitRedisCache.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to rate limit redis cache: %v", err) - } - - costLimitRedisCache := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 1, - }) - - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := costLimitRedisCache.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to cost limit redis cache: %v", err) - } - - costRedisStorage := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 2, - }) - - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := costRedisStorage.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to cost limit redis storage: %v", err) - } - - apiRedisCache := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort), - Password: cfg.RedisPassword, - DB: 3, - }) - - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := apiRedisCache.Ping(ctx).Err(); err != nil { - log.Sugar().Fatalf("error connecting to api redis cache: %v", err) - } - - rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout) - - m := manager.NewManager(store) - krm := manager.NewReportingManager(costStorage, store, store) - psm := manager.NewProviderSettingsManager(store, psMemStore) - cpm := manager.NewCustomProvidersManager(store, cpMemStore) - rm := manager.NewRouteManager(store, store, rMemStore, psMemStore) - - as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm, cfg.AdminPass) - if err != nil { - log.Sugar().Fatalf("error creating admin http server: %v", err) - } - as.Run() - - tc := openai.NewTokenCounter() - custom.NewTokenCounter() - atc, err := anthropic.NewTokenCounter() - if err != nil { - log.Sugar().Fatalf("error creating anthropic token counter: %v", err) - } - - ae := anthropic.NewCostEstimator(atc) - - ce := openai.NewCostEstimator(openai.OpenAiPerThousandTokenCost, tc) - v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage) - rec := recorder.NewRecorder(costStorage, costLimitCache, ce, store) - rlm := manager.NewRateLimitManager(rateLimitCache) - a := auth.NewAuthenticator(psm, memStore, rm) - - c := cache.NewCache(apiCache) - - aoe := azure.NewCostEstimator() - - ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ae, aoe, v, rec, rlm, cfg.ProxyTimeout) - if err != nil { - log.Sugar().Fatalf("error creating proxy http server: %v", err) - } - - ps.Run() - - quit := make(chan os.Signal) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - - memStore.Stop() - psMemStore.Stop() - cpMemStore.Stop() - - log.Sugar().Info("shutting down server...") - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := as.Shutdown(ctx); err != nil { - log.Sugar().Debugf("admin server shutdown: %v", err) - } - - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := ps.Shutdown(ctx); err != nil { - log.Sugar().Debugf("proxy server shutdown: %v", err) - } - - select { - case <-ctx.Done(): - log.Sugar().Infof("timeout of 5 seconds") - } - - err = store.DropKeysTable() - if err != nil { - log.Sugar().Fatalf("error dropping keys table: %v", err) - } - - err = store.DropEventsTable() - if err != nil { - log.Sugar().Fatalf("error dropping events table: %v", err) - } - - err = store.DropCustomProvidersTable() - if err != nil { - log.Sugar().Fatalf("error dropping custom providers table: %v", err) - } - - err = store.DropProviderSettingsTable() - if err != nil { - log.Sugar().Fatalf("error dropping provider settings table: %v", err) - } - - err = store.DropRoutesTable() - if err != nil { - log.Sugar().Fatalf("error dropping routes table: %v", err) - } - - log.Sugar().Infof("server exited") -} diff --git a/go.mod b/go.mod index 1b4a791..9c500a9 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/mattn/go-colorable v0.1.13 github.com/pkoukk/tiktoken-go-loader v0.0.1 github.com/redis/go-redis/v9 v9.0.5 - github.com/sashabaranov/go-openai v1.17.7 + github.com/sashabaranov/go-openai v1.19.2 github.com/stretchr/testify v1.8.4 go.uber.org/zap v1.24.0 ) diff --git a/go.sum b/go.sum index fc57456..b71b7c1 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ github.com/sashabaranov/go-openai v1.17.1 h1:tapFKbKE8ep0/qGkKp5Q3TtxWUD7m9VIFe9 github.com/sashabaranov/go-openai v1.17.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM= github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.19.2 h1:+dkuCADSnwXV02YVJkdphY8XD9AyHLUWwk6V7LB6EL8= +github.com/sashabaranov/go-openai v1.19.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= diff --git a/internal/config/config.go b/internal/config/config.go index 784bc2b..ced3a3b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,25 +7,26 @@ import ( ) type Config struct { - PostgresqlHosts string `env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"` - PostgresqlDbName string `env:"POSTGRESQL_DB_NAME"` - PostgresqlUsername string `env:"POSTGRESQL_USERNAME"` - PostgresqlPassword string `env:"POSTGRESQL_PASSWORD"` - PostgresqlSslMode string `env:"POSTGRESQL_SSL_MODE" envDefault:"disable"` - PostgresqlPort string `env:"POSTGRESQL_PORT" envDefault:"5432"` - RedisHosts string `env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"` - RedisPort string `env:"REDIS_PORT" envDefault:"6379"` - RedisUsername string `env:"REDIS_USERNAME"` - RedisPassword string `env:"REDIS_PASSWORD"` - RedisReadTimeout time.Duration `env:"REDIS_READ_TIME_OUT" envDefault:"1s"` - RedisWriteTimeout time.Duration `env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"` - PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"2s"` - PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"1s"` - InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"` - OpenAiKey string `env:"OPENAI_API_KEY"` - StatsProvider string `env:"STATS_PROVIDER"` - AdminPass string `env:"ADMIN_PASS"` - ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"600s"` + PostgresqlHosts string `env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"` + PostgresqlDbName string `env:"POSTGRESQL_DB_NAME"` + PostgresqlUsername string `env:"POSTGRESQL_USERNAME"` + PostgresqlPassword string `env:"POSTGRESQL_PASSWORD"` + PostgresqlSslMode string `env:"POSTGRESQL_SSL_MODE" envDefault:"disable"` + PostgresqlPort string `env:"POSTGRESQL_PORT" envDefault:"5432"` + RedisHosts string `env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"` + RedisPort string `env:"REDIS_PORT" envDefault:"6379"` + RedisUsername string `env:"REDIS_USERNAME"` + RedisPassword string `env:"REDIS_PASSWORD"` + RedisReadTimeout time.Duration `env:"REDIS_READ_TIME_OUT" envDefault:"1s"` + RedisWriteTimeout time.Duration `env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"` + PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"2s"` + PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"1s"` + InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"` + OpenAiKey string `env:"OPENAI_API_KEY"` + StatsProvider string `env:"STATS_PROVIDER"` + AdminPass string `env:"ADMIN_PASS"` + ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"600s"` + NumberOfEventMessageConsumers int `env:"NUMBER_OF_EVENT_MESSAGE_CONSUMERS" envDefault:"3"` } func ParseEnvVariables() (*Config, error) { diff --git a/internal/errors/cost_limit_err.go b/internal/errors/cost_limit_err.go new file mode 100644 index 0000000..15c92de --- /dev/null +++ b/internal/errors/cost_limit_err.go @@ -0,0 +1,17 @@ +package errors + +type CostLimitError struct { + message string +} + +func NewCostLimitError(msg string) *CostLimitError { + return &CostLimitError{ + message: msg, + } +} + +func (cle *CostLimitError) Error() string { + return cle.message +} + +func (rle *CostLimitError) CostLimit() {} diff --git a/internal/event/event_with_request_and_response.go b/internal/event/event_with_request_and_response.go new file mode 100644 index 0000000..8e99f0b --- /dev/null +++ b/internal/event/event_with_request_and_response.go @@ -0,0 +1,16 @@ +package event + +import ( + "github.com/bricks-cloud/bricksllm/internal/key" + "github.com/bricks-cloud/bricksllm/internal/provider/custom" +) + +type EventWithRequestAndContent struct { + Event *Event + IsEmbeddingsRequest bool + RouteConfig *custom.RouteConfig + Request interface{} + Content string + Response interface{} + Key *key.ResponseKey +} diff --git a/internal/message/bus.go b/internal/message/bus.go new file mode 100644 index 0000000..7fe0749 --- /dev/null +++ b/internal/message/bus.go @@ -0,0 +1,23 @@ +package message + +type MessageBus struct { + Subscribers map[string][]chan<- Message +} + +func NewMessageBus() *MessageBus { + return &MessageBus{ + Subscribers: make(map[string][]chan<- Message), + } +} + +func (mb *MessageBus) Subscribe(messageType string, subscriber chan<- Message) { + mb.Subscribers[messageType] = append(mb.Subscribers[messageType], subscriber) +} + +func (mb *MessageBus) Publish(ms Message) { + subscribers := mb.Subscribers[ms.Type] + + for _, subscriber := range subscribers { + subscriber <- ms + } +} diff --git a/internal/message/consumer.go b/internal/message/consumer.go new file mode 100644 index 0000000..e8e4e02 --- /dev/null +++ b/internal/message/consumer.go @@ -0,0 +1,58 @@ +package message + +import ( + "github.com/bricks-cloud/bricksllm/internal/event" + "github.com/bricks-cloud/bricksllm/internal/key" + "go.uber.org/zap" +) + +type Consumer struct { + messageChan <-chan Message + done chan bool + log *zap.Logger + numOfEventConsumers int + handle func(Message) error +} + +type recorder interface { + RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error + RecordEvent(e *event.Event) error +} + +func NewConsumer(mc <-chan Message, log *zap.Logger, num int, handle func(Message) error) *Consumer { + return &Consumer{ + messageChan: mc, + done: make(chan bool), + log: log, + numOfEventConsumers: num, + handle: handle, + } +} + +func (c *Consumer) StartEventMessageConsumers() { + for i := 0; i < c.numOfEventConsumers; i++ { + go func() { + for { + select { + case <-c.done: + c.log.Info("event message consumer stoped...") + return + + case m := <-c.messageChan: + err := c.handle(m) + if err != nil { + continue + } + + continue + } + } + }() + } +} + +func (c *Consumer) Stop() { + c.log.Info("shutting down consumer...") + + c.done <- true +} diff --git a/internal/message/handler.go b/internal/message/handler.go new file mode 100644 index 0000000..04c5697 --- /dev/null +++ b/internal/message/handler.go @@ -0,0 +1,422 @@ +package message + +import ( + "errors" + "strings" + "time" + + "github.com/bricks-cloud/bricksllm/internal/event" + "github.com/bricks-cloud/bricksllm/internal/key" + "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" + "github.com/bricks-cloud/bricksllm/internal/provider/custom" + "github.com/bricks-cloud/bricksllm/internal/stats" + "github.com/tidwall/gjson" + "go.uber.org/zap" + + goopenai "github.com/sashabaranov/go-openai" +) + +type anthropicEstimator interface { + EstimateTotalCost(model string, promptTks, completionTks int) (float64, error) + EstimateCompletionCost(model string, tks int) (float64, error) + EstimatePromptCost(model string, tks int) (float64, error) + Count(input string) int +} + +type estimator interface { + EstimateChatCompletionPromptCostWithTokenCounts(r *goopenai.ChatCompletionRequest) (int, float64, error) + EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error) + EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error) + EstimateCompletionCost(model string, tks int) (float64, error) + EstimateTotalCost(model string, promptTks, completionTks int) (float64, error) + EstimateEmbeddingsInputCost(model string, tks int) (float64, error) + EstimateChatCompletionPromptTokenCounts(model string, r *goopenai.ChatCompletionRequest) (int, error) +} + +type azureEstimator interface { + EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error) + EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error) + EstimateCompletionCost(model string, tks int) (float64, error) + EstimatePromptCost(model string, tks int) (float64, error) + EstimateTotalCost(model string, promptTks, completionTks int) (float64, error) + EstimateEmbeddingsInputCost(model string, tks int) (float64, error) +} + +type validator interface { + Validate(k *key.ResponseKey, promptCost float64) error +} + +type keyManager interface { + UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error) +} + +type rateLimitManager interface { + Increment(keyId string, timeUnit key.TimeUnit) error +} + +type accessCache interface { + Set(key string, timeUnit key.TimeUnit) error +} + +type Handler struct { + recorder recorder + log *zap.Logger + ae anthropicEstimator + e estimator + aze azureEstimator + v validator + km keyManager + rlm rateLimitManager + ac accessCache +} + +func NewHandler(r recorder, log *zap.Logger, ae anthropicEstimator, e estimator, aze azureEstimator, v validator, km keyManager, rlm rateLimitManager, ac accessCache) *Handler { + return &Handler{ + recorder: r, + log: log, + ae: ae, + e: e, + aze: aze, + v: v, + km: km, + rlm: rlm, + ac: ac, + } +} + +func (h *Handler) HandleEvent(m Message) error { + stats.Incr("bricksllm.message.handler.handle_event.requests", nil, 1) + + e, ok := m.Data.(*event.Event) + if !ok { + stats.Incr("bricksllm.message.handler.handle_event.event_parsing_error", nil, 1) + h.log.Info("message contains data that cannot be converted to event format", zap.Any("data", m.Data)) + return errors.New("message data cannot be parsed as event") + } + + start := time.Now() + + err := h.recorder.RecordEvent(e) + if err != nil { + stats.Incr("bricksllm.message.handler.handle_event.record_event_error", nil, 1) + h.log.Sugar().Debugf("error when publishin event: %v", err) + return err + } + + stats.Timing("bricksllm.message.handler.handle_event.record_event_latency", time.Now().Sub(start), nil, 1) + stats.Incr("bricksllm.message.handler.handle_event.success", nil, 1) + + return nil +} + +const ( + anthropicPromptMagicNum int = 1 + anthropicCompletionMagicNum int = 4 +) + +func countTokensFromJson(bytes []byte, contentLoc string) (int, error) { + content := getContentFromJson(bytes, contentLoc) + return custom.Count(content) +} + +func getContentFromJson(bytes []byte, contentLoc string) string { + result := gjson.Get(string(bytes), contentLoc) + content := "" + + if len(result.Str) != 0 { + content += result.Str + } + + if result.IsArray() { + for _, val := range result.Array() { + if len(val.Str) != 0 { + content += val.Str + } + } + } + + return content +} + +type costLimitError interface { + Error() string + CostLimit() +} + +type rateLimitError interface { + Error() string + RateLimit() +} + +type expirationError interface { + Error() string + Reason() string +} + +func (h *Handler) handleValidationResult(kc *key.ResponseKey, cost float64) error { + err := h.v.Validate(kc, cost) + + if err != nil { + stats.Incr("bricksllm.message.handler.handle_validation_result.handle_validation_result", nil, 1) + + // tested + if _, ok := err.(expirationError); ok { + stats.Incr("bricksllm.message.handler.handle_validation_result.expiraton_error", nil, 1) + + truePtr := true + _, err = h.km.UpdateKey(kc.KeyId, &key.UpdateKey{ + Revoked: &truePtr, + RevokedReason: "Key has expired or exceeded set spend limit", + }) + + if err != nil { + stats.Incr("bricksllm.message.handler.handle_validation_result.update_key_error", nil, 1) + return err + } + + return nil + } + + // tested + if _, ok := err.(rateLimitError); ok { + stats.Incr("bricksllm.message.handler.handle_validation_result.rate_limit_error", nil, 1) + + err = h.ac.Set(kc.KeyId, kc.RateLimitUnit) + if err != nil { + stats.Incr("bricksllm.message.handler.handle_validation_result.set_rate_limit_error", nil, 1) + return err + } + + return nil + } + + // tested + if _, ok := err.(costLimitError); ok { + stats.Incr("bricksllm.message.handler.handle_validation_result.cost_limit_error", nil, 1) + + err = h.ac.Set(kc.KeyId, kc.CostLimitInUsdUnit) + if err != nil { + stats.Incr("bricksllm.message.handler.handle_validation_result.set_cost_limit_error", nil, 1) + return err + } + + return nil + } + + return err + } + + return nil +} + +func (h *Handler) HandleEventWithRequestAndResponse(m Message) error { + e, ok := m.Data.(*event.EventWithRequestAndContent) + if !ok { + stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.message_data_parsing_error", nil, 1) + h.log.Debug("message contains data that cannot be converted to event with request and response format", zap.Any("data", m.Data)) + return errors.New("message data cannot be parsed as event with request and response") + } + + if e.Key != nil && !e.Key.Revoked && e.Event != nil { + err := h.decorateEvent(m) + if err != nil { + stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.decorate_event_error", nil, 1) + h.log.Debug("error when decorating event", zap.Error(err)) + } + + // tested + if e.Event.CostInUsd != 0 { + micros := int64(e.Event.CostInUsd * 1000000) + err = h.recorder.RecordKeySpend(e.Event.KeyId, micros, e.Key.CostLimitInUsdUnit) + if err != nil { + stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_key_spend_error", nil, 1) + h.log.Debug("error when recording key spend", zap.Error(err)) + } + } + + // tested + if len(e.Key.RateLimitUnit) != 0 { + if err := h.rlm.Increment(e.Key.KeyId, e.Key.RateLimitUnit); err != nil { + stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.rate_limit_increment_error", nil, 1) + + h.log.Debug("error when incrementing rate limit", zap.Error(err)) + } + } + + // tested + err = h.handleValidationResult(e.Key, e.Event.CostInUsd) + if err != nil { + stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.handle_validation_result_error", nil, 1) + h.log.Debug("error when handling validation result", zap.Error(err)) + } + + } + + // tested + start := time.Now() + err := h.recorder.RecordEvent(e.Event) + if err != nil { + stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_event_error", nil, 1) + return err + } + + stats.Timing("bricksllm.message.handler.handle_event_with_request_and_response.latency", time.Now().Sub(start), nil, 1) + stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.success", nil, 1) + + return nil +} + +func (h *Handler) decorateEvent(m Message) error { + stats.Incr("bricksllm.message.handler.decorate_event.request", nil, 1) + + e, ok := m.Data.(*event.EventWithRequestAndContent) + if !ok { + stats.Incr("bricksllm.message.handler.decorate_event.message_data_parsing_error", nil, 1) + h.log.Debug("message contains data that cannot be converted to event with request and response format", zap.Any("data", m.Data)) + return errors.New("message data cannot be parsed as event with request and response") + } + + // tested + if e.Event.Provider == "anthropic" && e.Event.Path == "/api/providers/anthropic/v1/complete" { + cr, ok := e.Request.(*anthropic.CompletionRequest) + if !ok { + stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1) + h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Any("data", m.Data)) + return errors.New("event request data cannot be parsed as anthropic completon request") + } + + tks := h.ae.Count(cr.Prompt) + tks += anthropicPromptMagicNum + + model := cr.Model + cost, err := h.ae.EstimatePromptCost(model, tks) + if err != nil { + stats.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1) + h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Error(err)) + return err + } + + completiontks := h.ae.Count(e.Content) + completiontks += anthropicCompletionMagicNum + + completionCost, err := h.ae.EstimateCompletionCost(model, completiontks) + if err != nil { + stats.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1) + return err + } + + e.Event.PromptTokenCount = tks + e.Event.CompletionTokenCount = completiontks + e.Event.CostInUsd = completionCost + cost + } + + // tested + if e.Event.Provider == "azure" && e.Event.Path == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" { + ccr, ok := e.Request.(*goopenai.ChatCompletionRequest) + if !ok { + stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1) + h.log.Debug("event contains data that cannot be converted to azure openai completion request", zap.Any("data", m.Data)) + return errors.New("event request data cannot be parsed as azure openai completon request") + } + + if ccr.Stream { + tks, err := h.e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr) + if err != nil { + stats.Incr("bricksllm.message.decorate_event.estimate_chat_completion_prompt_token_counts_error", nil, 1) + return err + } + + cost, err := h.aze.EstimatePromptCost(e.Event.Model, tks) + if err != nil { + stats.Incr("bricksllm.message.decorate_event.estimate_prompt_cost_error", nil, 1) + return err + } + + completiontks, completionCost, err := h.aze.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content) + if err != nil { + stats.Incr("bricksllm.message.decorate_event.estimate_chat_completion_stream_cost_with_token_counts_error", nil, 1) + return err + } + + e.Event.PromptTokenCount = tks + e.Event.CompletionTokenCount = completiontks + e.Event.CostInUsd = cost + completionCost + } + } + + // tested + if e.Event.Provider == "openai" && e.Event.Path == "/api/providers/openai/v1/chat/completions" { + ccr, ok := e.Request.(*goopenai.ChatCompletionRequest) + if !ok { + stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1) + h.log.Debug("event contains data that cannot be converted to openai completion request", zap.Any("data", m.Data)) + return errors.New("event request data cannot be parsed as oepnai completon request") + } + + if ccr.Stream { + tks, cost, err := h.e.EstimateChatCompletionPromptCostWithTokenCounts(ccr) + if err != nil { + stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_prompt_cost_with_token_counts", nil, 1) + return err + } + + completiontks, completionCost, err := h.e.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content) + if err != nil { + stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_stream_cost_with_token_counts", nil, 1) + return err + } + + e.Event.PromptTokenCount = tks + e.Event.CompletionTokenCount = completiontks + e.Event.CostInUsd = cost + completionCost + } + } + + if strings.HasPrefix(e.Event.Path, "/api/custom/providers/:provider") && e.RouteConfig != nil { + body, ok := e.Request.([]byte) + if !ok { + stats.Incr("bricksllm.message.handler.decorate_event.event_request_custom_provider_parsing_error", nil, 1) + h.log.Debug("event contains request that cannot be converted to bytes", zap.Any("data", m.Data)) + return errors.New("event request data cannot be parsed as anthropic completon request") + } + + content, ok := e.Response.([]byte) + if !ok { + stats.Incr("bricksllm.message.handler.decorate_event.event_response_custom_provider_parsing_error", nil, 1) + h.log.Debug("event contains response that cannot be converted to bytes", zap.Any("data", m.Data)) + return errors.New("event response data cannot be converted to bytes") + } + + tks, err := countTokensFromJson(body, e.RouteConfig.RequestPromptLocation) + if err != nil { + stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1) + + return err + } + + e.Event.PromptTokenCount = tks + + result := gjson.Get(string(body), e.RouteConfig.StreamLocation) + if result.IsBool() { + completiontks, err := custom.Count(e.Content) + if err != nil { + stats.Incr("bricksllm.message.handler.decorate_event.custom_count_error", nil, 1) + return err + } + + e.Event.CompletionTokenCount = completiontks + } + + if !result.IsBool() { + completiontks, err := countTokensFromJson(content, e.RouteConfig.ResponseCompletionLocation) + if err != nil { + stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1) + return err + } + + e.Event.CompletionTokenCount = completiontks + } + } + + return nil +} diff --git a/internal/message/message.go b/internal/message/message.go new file mode 100644 index 0000000..a2e97a5 --- /dev/null +++ b/internal/message/message.go @@ -0,0 +1,6 @@ +package message + +type Message struct { + Type string + Data interface{} +} diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go index 2bec29b..b645b13 100644 --- a/internal/provider/openai/cost.go +++ b/internal/provider/openai/cost.go @@ -198,7 +198,7 @@ func (ce *CostEstimator) EstimateChatCompletionStreamCostWithTokenCounts(model s } func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error) { - if len(r.Model.String()) == 0 { + if len(string(r.Model)) == 0 { return 0, errors.New("model is not provided") } @@ -210,7 +210,7 @@ func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (f return 0, errors.New("input is not string") } - tks, err := ce.tc.Count(r.Model.String(), converted) + tks, err := ce.tc.Count(string(r.Model), converted) if err != nil { return 0, err } @@ -218,14 +218,14 @@ func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (f total += tks } - return ce.EstimateEmbeddingsInputCost(r.Model.String(), total) + return ce.EstimateEmbeddingsInputCost(string(r.Model), total) } else if input, ok := r.Input.(string); ok { - tks, err := ce.tc.Count(r.Model.String(), input) + tks, err := ce.tc.Count(string(r.Model), input) if err != nil { return 0, err } - return ce.EstimateEmbeddingsInputCost(r.Model.String(), tks) + return ce.EstimateEmbeddingsInputCost(string(r.Model), tks) } return 0, errors.New("input format is not recognized") diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go index a9aab8a..7f6a607 100644 --- a/internal/server/web/proxy/anthropic.go +++ b/internal/server/web/proxy/anthropic.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/bricks-cloud/bricksllm/internal/key" "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" "github.com/bricks-cloud/bricksllm/internal/stats" "github.com/gin-gonic/gin" @@ -61,13 +60,13 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km return } - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } copyHttpHeaders(c.Request, req) @@ -96,7 +95,7 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km } } - model := c.GetString("model") + // model := c.GetString("model") if !isStreaming && res.StatusCode == http.StatusOK { dur := time.Now().Sub(start) @@ -109,8 +108,8 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km return } - var cost float64 = 0 - completionTokens := 0 + // var cost float64 = 0 + // completionTokens := 0 completionRes := &anthropic.CompletionResponse{} stats.Incr("bricksllm.proxy.get_completion_handler.success", nil, 1) stats.Timing("bricksllm.proxy.get_completion_handler.success_latency", dur, nil, 1) @@ -120,27 +119,29 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km logError(log, "error when unmarshalling anthropic http completion response body", prod, cid, err) } - if err == nil { - logCompletionResponse(log, bytes, prod, private, cid) - completionTokens = e.Count(completionRes.Completion) - completionTokens += anthropicCompletionMagicNum - promptTokens := c.GetInt("promptTokenCount") - cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens) - if err != nil { - stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1) - logError(log, "error when estimating anthropic cost", prod, cid, err) - } - - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording anthropic spend", prod, cid, err) - } - } - - c.Set("costInUsd", cost) - c.Set("completionTokenCount", completionTokens) + c.Set("content", completionRes.Completion) + + // if err == nil { + // logCompletionResponse(log, bytes, prod, private, cid) + // completionTokens = e.Count(completionRes.Completion) + // completionTokens += anthropicCompletionMagicNum + // promptTokens := c.GetInt("promptTokenCount") + // cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1) + // logError(log, "error when estimating anthropic cost", prod, cid, err) + // } + + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording anthropic spend", prod, cid, err) + // } + // } + + // c.Set("costInUsd", cost) + // c.Set("completionTokenCount", completionTokens) c.Data(res.StatusCode, "application/json", bytes) return @@ -163,24 +164,24 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km } buffer := bufio.NewReader(res.Body) - var totalCost float64 = 0 + // var totalCost float64 = 0 content := "" - defer func() { - tks := e.Count(content) - model := c.GetString("model") - cost, err := e.EstimateCompletionCost(model, tks) - if err != nil { - stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1) - logError(log, "error when estimating anthropic completion stream cost", prod, cid, err) - } - - estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") - totalCost = cost + estimatedPromptCost - - c.Set("costInUsd", totalCost) - c.Set("completionTokenCount", tks+anthropicCompletionMagicNum) - }() + // defer func() { + // tks := e.Count(content) + // model := c.GetString("model") + // cost, err := e.EstimateCompletionCost(model, tks) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1) + // logError(log, "error when estimating anthropic completion stream cost", prod, cid, err) + // } + + // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") + // totalCost = cost + estimatedPromptCost + + // c.Set("costInUsd", totalCost) + // c.Set("completionTokenCount", tks+anthropicCompletionMagicNum) + // }() stats.Incr("bricksllm.proxy.get_completion_handler.streaming_requests", nil, 1) diff --git a/internal/server/web/proxy/azure_chat_completion.go b/internal/server/web/proxy/azure_chat_completion.go index 970237a..2807071 100644 --- a/internal/server/web/proxy/azure_chat_completion.go +++ b/internal/server/web/proxy/azure_chat_completion.go @@ -11,7 +11,6 @@ import ( "net/http" "time" - "github.com/bricks-cloud/bricksllm/internal/key" "github.com/bricks-cloud/bricksllm/internal/stats" "github.com/gin-gonic/gin" goopenai "github.com/sashabaranov/go-openai" @@ -36,13 +35,6 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS } cid := c.GetString(correlationId) - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } ctx, cancel := context.WithTimeout(context.Background(), timeOut) defer cancel() @@ -111,12 +103,12 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS logError(log, "error when estimating azure openai cost", prod, cid, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording azure openai spend", prod, cid, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording azure openai spend", prod, cid, err) + // } } c.Set("costInUsd", cost) @@ -145,8 +137,8 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS } buffer := bufio.NewReader(res.Body) - var totalCost float64 = 0 - var totalTokens int = 0 + // var totalCost float64 = 0 + // var totalTokens int = 0 content := "" model := "" @@ -155,24 +147,26 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS c.Set("model", model) } - tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) - logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) - } + c.Set("content", content) - estimatedPromptTokenCounts := c.GetInt("promptTokenCount") - promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) - logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) - } + // tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) + // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) + // } + + // estimatedPromptTokenCounts := c.GetInt("promptTokenCount") + // promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) + // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err) + // } - totalCost = cost + promptCost - totalTokens += tks + // totalCost = cost + promptCost + // totalTokens += tks - c.Set("costInUsd", totalCost) - c.Set("completionTokenCount", totalTokens) + // c.Set("costInUsd", totalCost) + // c.Set("completionTokenCount", totalTokens) }() stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.streaming_requests", nil, 1) diff --git a/internal/server/web/proxy/azure_embedding.go b/internal/server/web/proxy/azure_embedding.go index c3c2830..f9ac0de 100644 --- a/internal/server/web/proxy/azure_embedding.go +++ b/internal/server/web/proxy/azure_embedding.go @@ -7,7 +7,6 @@ import ( "net/http" "time" - "github.com/bricks-cloud/bricksllm/internal/key" "github.com/bricks-cloud/bricksllm/internal/stats" "github.com/gin-gonic/gin" goopenai "github.com/sashabaranov/go-openai" @@ -23,13 +22,13 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti } cid := c.GetString(correlationId) - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } ctx, cancel := context.WithTimeout(context.Background(), timeOut) defer cancel() @@ -111,12 +110,12 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti logError(log, "error when estimating azure openai cost for embedding", prod, cid, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording azure openai spend for embedding", prod, cid, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording azure openai spend for embedding", prod, cid, err) + // } } } diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go index bbc8291..1f9007f 100644 --- a/internal/server/web/proxy/custom_provider.go +++ b/internal/server/web/proxy/custom_provider.go @@ -120,12 +120,14 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c return } - tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation) - if err != nil { - logError(log, "error when counting tokens for custom provider completion response", prod, cid, err) - } + c.Set("response", bytes) + + // tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation) + // if err != nil { + // logError(log, "error when counting tokens for custom provider completion response", prod, cid, err) + // } - c.Set("completionTokenCount", tks) + // c.Set("completionTokenCount", tks) c.Data(res.StatusCode, "application/json", bytes) return } @@ -149,13 +151,15 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c buffer := bufio.NewReader(res.Body) aggregated := "" defer func() { - tks, err := custom.Count(aggregated) - if err != nil { - stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1) - logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err) - } + c.Set("content", aggregated) + + // tks, err := custom.Count(aggregated) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1) + // logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err) + // } - c.Set("completionTokenCount", tks) + // c.Set("completionTokenCount", tks) }() stats.Incr("bricksllm.proxy.get_custom_provider_handler.streaming_requests", nil, 1) diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go index fa5916d..e8b268d 100644 --- a/internal/server/web/proxy/middleware.go +++ b/internal/server/web/proxy/middleware.go @@ -11,6 +11,7 @@ import ( "github.com/bricks-cloud/bricksllm/internal/event" "github.com/bricks-cloud/bricksllm/internal/key" + "github.com/bricks-cloud/bricksllm/internal/message" "github.com/bricks-cloud/bricksllm/internal/provider" "github.com/bricks-cloud/bricksllm/internal/provider/anthropic" "github.com/bricks-cloud/bricksllm/internal/route" @@ -72,6 +73,10 @@ type rateLimitManager interface { Increment(keyId string, timeUnit key.TimeUnit) error } +type accessCache interface { + GetAccessStatus(key string) bool +} + type encrypter interface { Encrypt(secret string) string } @@ -93,7 +98,40 @@ type notFoundError interface { NotFound() } -func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, r recorder, prefix string) gin.HandlerFunc { +type publisher interface { + Publish(message.Message) +} + +func getProvider(c *gin.Context) string { + existing := c.GetString("provider") + if len(existing) != 0 { + return existing + } + + parts := strings.Split(c.FullPath(), "/") + + spaceRemoved := []string{} + + for _, part := range parts { + if len(part) != 0 { + spaceRemoved = append(spaceRemoved, part) + } + } + + if strings.HasPrefix(c.FullPath(), "/api/providers/") { + if len(spaceRemoved) >= 3 { + return spaceRemoved[2] + } + } + + if strings.HasPrefix(c.FullPath(), "/api/custom/providers/") { + return c.Param("provider") + } + + return "" +} + +func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, pub publisher, prefix string, ac accessCache) gin.HandlerFunc { return func(c *gin.Context) { if c == nil || c.Request == nil { JSON(c, http.StatusInternalServerError, "[BricksLLM] request is empty") @@ -110,17 +148,12 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage c.Set(correlationId, cid) start := time.Now() - selectedProvider := "openai" + enrichedEvent := &event.EventWithRequestAndContent{} customId := c.Request.Header.Get("X-CUSTOM-EVENT-ID") defer func() { dur := time.Now().Sub(start) latency := int(dur.Milliseconds()) - raw, exists := c.Get("key") - var kc *key.ResponseKey - if exists { - kc = raw.(*key.ResponseKey) - } if !prod { log.Sugar().Infof("%s | %d | %s | %s | %dms", prefix, c.Writer.Status(), c.Request.Method, c.FullPath(), latency) @@ -129,16 +162,14 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage keyId := "" tags := []string{} - if kc != nil { - keyId = kc.KeyId - tags = kc.Tags + if enrichedEvent.Key != nil { + keyId = enrichedEvent.Key.KeyId + tags = enrichedEvent.Key.Tags } stats.Timing("bricksllm.proxy.get_middleware.proxy_latency_in_ms", dur, nil, 1) - if len(c.GetString("provider")) != 0 { - selectedProvider = c.GetString("provider") - } + selectedProvider := getProvider(c) if prod { log.Info("response to proxy", @@ -173,12 +204,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage CustomId: customId, } - err := r.RecordEvent(evt) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.record_event_error", nil, 1) - - logError(log, "error when recording openai event", prod, cid, err) + enrichedEvent.Event = evt + content := c.GetString("content") + if len(content) != 0 { + enrichedEvent.Content = content } + + pub.Publish(message.Message{ + Type: "event", + Data: enrichedEvent, + }) }() if len(c.FullPath()) == 0 { @@ -189,6 +224,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage } kc, settings, err := a.AuthenticateHttpRequest(c.Request) + enrichedEvent.Key = kc _, ok := err.(notAuthorizedError) if ok { stats.Incr("bricksllm.proxy.get_middleware.authentication_error", nil, 1) @@ -236,12 +272,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage c.Request.Body = io.NopCloser(bytes.NewReader(body)) } - var cost float64 = 0 + // var cost float64 = 0 if c.FullPath() == "/api/providers/anthropic/v1/complete" { logCompletionRequest(log, body, prod, private, cid) - selectedProvider = "anthropic" cr := &anthropic.CompletionRequest{} err = json.Unmarshal(body, cr) if err != nil { @@ -249,19 +284,21 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } - tks := ae.Count(cr.Prompt) - tks += anthropicPromptMagicNum - c.Set("promptTokenCount", tks) + enrichedEvent.Request = cr - model := cr.Model - cost, err = ae.EstimatePromptCost(model, tks) - if err != nil { - logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err) - } + // tks := ae.Count(cr.Prompt) + // tks += anthropicPromptMagicNum + // c.Set("promptTokenCount", tks) + + // model := cr.Model + // cost, err = ae.EstimatePromptCost(model, tks) + // if err != nil { + // logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err) + // } if cr.Stream { c.Set("stream", cr.Stream) - c.Set("estimatedPromptCostInUsd", cost) + // c.Set("estimatedPromptCostInUsd", cost) } if len(cr.Model) != 0 { @@ -288,17 +325,22 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } - selectedProvider = cp.Provider - c.Set("provider", cp) c.Set("route_config", rc) - tks, err := countTokensFromJson(body, rc.RequestPromptLocation) - if err != nil { - logError(log, "error when counting tokens for custom provider request", prod, cid, err) + enrichedEvent.Request = body + + customResponse, ok := c.Get("response") + if ok { + enrichedEvent.Response = customResponse } - c.Set("promptTokenCount", tks) + // tks, err := countTokensFromJson(body, rc.RequestPromptLocation) + // if err != nil { + // logError(log, "error when counting tokens for custom provider request", prod, cid, err) + // } + + // c.Set("promptTokenCount", tks) result := gjson.Get(string(body), rc.StreamLocation) @@ -344,12 +386,15 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage if !rc.ShouldRunEmbeddings() { ccr := &goopenai.ChatCompletionRequest{} + err = json.Unmarshal(body, ccr) if err != nil { logError(log, "error when unmarshalling route chat completion request", prod, cid, err) return } + enrichedEvent.Request = ccr + logRequest(log, prod, private, cid, ccr) if ccr.Stream { @@ -366,8 +411,6 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage } if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" { - selectedProvider = "azure" - ccr := &goopenai.ChatCompletionRequest{} err = json.Unmarshal(body, ccr) if err != nil { @@ -375,23 +418,23 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } + enrichedEvent.Request = ccr + logRequest(log, prod, private, cid, ccr) - tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1) - logError(log, "error when estimating prompt cost", prod, cid, err) - } + // tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1) + // logError(log, "error when estimating prompt cost", prod, cid, err) + // } if ccr.Stream { c.Set("stream", true) - c.Set("promptTokenCount", tks) + // c.Set("promptTokenCount", tks) } } if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/embeddings" { - selectedProvider = "azure" - er := &goopenai.EmbeddingRequest{} err = json.Unmarshal(body, er) if err != nil { @@ -404,11 +447,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage logEmbeddingRequest(log, prod, private, cid, er) - cost, err = aoe.EstimateEmbeddingsCost(er) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1) - logError(log, "error when estimating azure openai embeddings cost", prod, cid, err) - } + // cost, err = aoe.EstimateEmbeddingsCost(er) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1) + // logError(log, "error when estimating azure openai embeddings cost", prod, cid, err) + // } } if c.FullPath() == "/api/providers/openai/v1/chat/completions" { @@ -419,6 +462,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } + enrichedEvent.Request = ccr + c.Set("model", ccr.Model) logRequest(log, prod, private, cid, ccr) @@ -445,16 +490,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage return } - c.Set("model", er.Model.String()) + c.Set("model", string(er.Model)) c.Set("encoding_format", string(er.EncodingFormat)) logEmbeddingRequest(log, prod, private, cid, er) - cost, err = e.EstimateEmbeddingsCost(er) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1) - logError(log, "error when estimating embeddings cost", prod, cid, err) - } + // cost, err = e.EstimateEmbeddingsCost(er) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1) + // logError(log, "error when estimating embeddings cost", prod, cid, err) + // } } if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost { @@ -735,48 +780,13 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage logRetrieveFileContentRequest(log, body, prod, cid, fid) } - err = v.Validate(kc, cost) - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.validation_error", nil, 1) - - if _, ok := err.(expirationError); ok { - stats.Incr("bricksllm.proxy.get_middleware.key_expired", nil, 1) - - truePtr := true - _, err = ks.UpdateKey(kc.KeyId, &key.UpdateKey{ - Revoked: &truePtr, - RevokedReason: "Key has expired or exceeded set spend limit", - }) - - if err != nil { - stats.Incr("bricksllm.proxy.get_middleware.update_key_error", nil, 1) - log.Sugar().Debugf("error when updating revoking the api key %s: %v", kc.KeyId, err) - } - - JSON(c, http.StatusUnauthorized, "[BricksLLM] key has expired") - c.Abort() - return - } - - if _, ok := err.(rateLimitError); ok { - stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1) - JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests") - c.Abort() - return - } - - logError(log, "error when validating api key", prod, cid, err) + if ac.GetAccessStatus(kc.KeyId) { + stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1) + JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests") + c.Abort() return } - if len(kc.RateLimitUnit) != 0 { - if err := rlm.Increment(kc.KeyId, kc.RateLimitUnit); err != nil { - stats.Incr("bricksllm.proxy.get_middleware.rate_limit_increment_error", nil, 1) - - logError(log, "error when incrementing rate limit counter", prod, cid, err) - } - } - c.Next() } } diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go index dbcb5eb..221dac8 100644 --- a/internal/server/web/proxy/proxy.go +++ b/internal/server/web/proxy/proxy.go @@ -39,7 +39,7 @@ type ProxyServer struct { } type recorder interface { - RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error + // RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error RecordEvent(e *event.Event) error } @@ -55,12 +55,12 @@ type CustomProvidersManager interface { GetCustomProviderFromMem(name string) *custom.Provider } -func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, rlm rateLimitManager, timeOut time.Duration) (*ProxyServer, error) { +func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeOut time.Duration, ac accessCache) (*ProxyServer, error) { router := gin.New() prod := mode == "production" private := privacyMode == "strict" - router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, r, "proxy")) + router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, pub, "proxy", ac)) client := http.Client{} @@ -942,13 +942,13 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan return } - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } id := c.GetString(correlationId) @@ -1032,12 +1032,12 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan logError(log, "error when estimating openai cost for embedding", prod, id, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording openai spend for embedding", prod, id, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording openai spend for embedding", prod, id, err) + // } } } @@ -1085,13 +1085,13 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin } cid := c.GetString(correlationId) - raw, exists := c.Get("key") - kc, ok := raw.(*key.ResponseKey) - if !exists || !ok { - stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1) - JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") - return - } + // raw, exists := c.Get("key") + // kc, ok := raw.(*key.ResponseKey) + // if !exists || !ok { + // stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1) + // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered") + // return + // } ctx, cancel := context.WithTimeout(context.Background(), timeOut) defer cancel() @@ -1161,12 +1161,12 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin logError(log, "error when estimating openai cost", prod, cid, err) } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1) - logError(log, "error when recording openai spend", prod, cid, err) - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1) + // logError(log, "error when recording openai spend", prod, cid, err) + // } } c.Set("costInUsd", cost) @@ -1195,22 +1195,24 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin } buffer := bufio.NewReader(res.Body) - var totalCost float64 = 0 - var totalTokens int = 0 + // var totalCost float64 = 0 + // var totalTokens int = 0 content := "" defer func() { - tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content) - if err != nil { - stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) - logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err) - } + c.Set("content", content) + + // tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content) + // if err != nil { + // stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1) + // logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err) + // } - estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") - totalCost = cost + estimatedPromptCost - totalTokens += tks + // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd") + // totalCost = cost + estimatedPromptCost + // totalTokens += tks - c.Set("costInUsd", totalCost) - c.Set("completionTokenCount", totalTokens) + // c.Set("costInUsd", totalCost) + // c.Set("completionTokenCount", totalTokens) }() stats.Incr("bricksllm.proxy.get_chat_completion_handler.streaming_requests", nil, 1) @@ -1522,7 +1524,7 @@ func logEmbeddingRequest(log *zap.Logger, prod, private bool, id string, r *goop if prod { fields := []zapcore.Field{ zap.String(correlationId, id), - zap.String("model", r.Model.String()), + zap.String("model", string(r.Model)), zap.String("encoding_format", string(r.EncodingFormat)), zap.String("user", r.User), } diff --git a/internal/server/web/proxy/route.go b/internal/server/web/proxy/route.go index 13357a9..90c46af 100644 --- a/internal/server/web/proxy/route.go +++ b/internal/server/web/proxy/route.go @@ -231,12 +231,12 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo cost = ecost } - micros := int64(cost * 1000000) + // micros := int64(cost * 1000000) - err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - return err - } + // err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // return err + // } } if !runEmbeddings { @@ -262,11 +262,11 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo } } - micros := int64(cost * 1000000) - err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) - if err != nil { - return err - } + // micros := int64(cost * 1000000) + // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit) + // if err != nil { + // return err + // } } return nil diff --git a/internal/storage/redis/access-cache.go b/internal/storage/redis/access-cache.go new file mode 100644 index 0000000..d8dccef --- /dev/null +++ b/internal/storage/redis/access-cache.go @@ -0,0 +1,48 @@ +package redis + +import ( + "context" + "time" + + "github.com/bricks-cloud/bricksllm/internal/key" + "github.com/redis/go-redis/v9" +) + +type AccessCache struct { + client *redis.Client + wt time.Duration + rt time.Duration +} + +func NewAccessCache(c *redis.Client, wt time.Duration, rt time.Duration) *AccessCache { + return &AccessCache{ + client: c, + wt: wt, + rt: rt, + } +} + +func (ac *AccessCache) Set(key string, timeUnit key.TimeUnit) error { + ttl, err := getCounterTtl(timeUnit) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), ac.wt) + defer cancel() + err = ac.client.Set(ctx, key, true, ttl.Sub(time.Now())).Err() + if err != nil { + return err + } + + return nil +} + +func (ac *AccessCache) GetAccessStatus(key string) bool { + ctx, cancel := context.WithTimeout(context.Background(), ac.rt) + defer cancel() + + result := ac.client.Get(ctx, key) + + return result.Err() != redis.Nil +} diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 9551261..d2fd4ba 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -58,12 +58,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error { return err } - err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit, promptCost) + err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit) if err != nil { return err } - err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd, promptCost) + err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd) if err != nil { return err } @@ -96,14 +96,14 @@ func (v *Validator) validateRateLimitOverTime(keyId string, rateLimitOverTime in return errors.New("failed to get rate limit counter") } - if c+1 > int64(rateLimitOverTime) { + if c >= int64(rateLimitOverTime) { return internal_errors.NewRateLimitError(fmt.Sprintf("key exceeded rate limit %d requests per %s", rateLimitOverTime, rateLimitUnit)) } return nil } -func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit, promptCost float64) error { +func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit) error { if costLimitOverTime == 0 { return nil } @@ -113,8 +113,8 @@ func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime fl return errors.New("failed to get cached token cost") } - if convertDollarToMicroDollars(promptCost)+cachedCost > convertDollarToMicroDollars(costLimitOverTime) { - return internal_errors.NewExpirationError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit), internal_errors.CostLimitExpiration) + if cachedCost >= convertDollarToMicroDollars(costLimitOverTime) { + return internal_errors.NewCostLimitError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit)) } return nil @@ -124,7 +124,7 @@ func convertDollarToMicroDollars(dollar float64) int64 { return int64(dollar * 1000000) } -func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCost float64) error { +func (v *Validator) validateCostLimit(keyId string, costLimit float64) error { if costLimit == 0 { return nil } @@ -134,7 +134,7 @@ func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCos return errors.New("failed to get total token cost") } - if convertDollarToMicroDollars(promptCost)+existingTotalCost > convertDollarToMicroDollars(costLimit) { + if existingTotalCost >= convertDollarToMicroDollars(costLimit) { return internal_errors.NewExpirationError(fmt.Sprintf("total cost limit: %f has been reached", costLimit), internal_errors.CostLimitExpiration) }