diff --git a/service/aiproxy/common/consume/consume.go b/service/aiproxy/common/consume/consume.go index 7ba0d3d2ef6..139de8675b1 100644 --- a/service/aiproxy/common/consume/consume.go +++ b/service/aiproxy/common/consume/consume.go @@ -19,7 +19,6 @@ func Wait() { } func AsyncConsume( - ctx context.Context, postGroupConsumer balance.PostGroupConsumer, code int, usage *relaymodel.Usage, @@ -41,7 +40,17 @@ func AsyncConsume( } }() - go Consume(ctx, postGroupConsumer, code, usage, meta, inputPrice, outputPrice, content, requestDetail) + go Consume( + context.Background(), + postGroupConsumer, + code, + usage, + meta, + inputPrice, + outputPrice, + content, + requestDetail, + ) } func Consume( diff --git a/service/aiproxy/common/rpmlimit/rate-limit.go b/service/aiproxy/common/rpmlimit/rate-limit.go index 3c186d707c8..56abc90eced 100644 --- a/service/aiproxy/common/rpmlimit/rate-limit.go +++ b/service/aiproxy/common/rpmlimit/rate-limit.go @@ -16,60 +16,89 @@ const ( groupModelRPMKey = "group_model_rpm:%s:%s" ) -// 1. 使用Redis列表存储请求时间戳 -// 2. 列表长度代表当前窗口内的请求数 -// 3. 如果请求数未达到限制,直接添加新请求并返回成功 -// 4. 如果达到限制,则检查最老的请求是否已经过期 -// 5. 如果最老的请求已过期,最多移除3个过期请求并添加新请求,否则拒绝新请求 -// 6. 通过EXPIRE命令设置键的过期时间,自动清理过期数据 -var luaScript = ` +var pushRequestScript = ` local key = KEYS[1] -local max_requests = tonumber(ARGV[1]) -local window = tonumber(ARGV[2]) -local current_time = tonumber(ARGV[3]) +local window = tonumber(ARGV[1]) +local current_time = tonumber(ARGV[2]) +local cutoff = current_time - window + +local page_size = 100 +local remove_count = 0 -local count = redis.call('LLEN', key) - -if count < max_requests then - redis.call('LPUSH', key, current_time) - redis.call('PEXPIRE', key, window) - return 1 -else - local removed = 0 - for i = 1, 3 do - local oldest = redis.call('LINDEX', key, -1) - if current_time - tonumber(oldest) >= window then - redis.call('RPOP', key) - removed = removed + 1 +while true do + local timestamps = redis.call('LRANGE', key, remove_count, remove_count + page_size - 1) + if #timestamps == 0 then + break + end + + local found_non_expired = false + for i = 1, #timestamps do + local timestamp = tonumber(timestamps[i]) + if timestamp < cutoff then + remove_count = remove_count + 1 else + found_non_expired = true break end end - if removed > 0 then - redis.call('LPUSH', key, current_time) - redis.call('PEXPIRE', key, window) - return 1 - else - return 0 + + if found_non_expired then + break end end + +if remove_count > 0 then + redis.call('LTRIM', key, remove_count, -1) +end + +redis.call('LPUSH', key, current_time) + +redis.call('PEXPIRE', key, window) + +return redis.call('LLEN', key) ` -var getRPMSumLuaScript = ` +var getRequestCountScript = ` local pattern = ARGV[1] local window = tonumber(ARGV[2]) local current_time = tonumber(ARGV[3]) +local cutoff = current_time - window +local page_size = 100 local keys = redis.call('KEYS', pattern) local total = 0 for _, key in ipairs(keys) do - local timestamps = redis.call('LRANGE', key, 0, -1) - for _, ts in ipairs(timestamps) do - if current_time - tonumber(ts) < window then - total = total + 1 + local remove_count = 0 + + while true do + local timestamps = redis.call('LRANGE', key, remove_count, remove_count + page_size - 1) + if #timestamps == 0 then + break + end + + local found_non_expired = false + for i = 1, #timestamps do + local timestamp = tonumber(timestamps[i]) + if timestamp < cutoff then + remove_count = remove_count + 1 + else + found_non_expired = true + break + end + end + + if found_non_expired then + break end end + + if remove_count > 0 then + redis.call('LTRIM', key, remove_count, -1) + end + + local total_count = redis.call('LLEN', key) + total = total + total_count end return total @@ -93,24 +122,35 @@ func GetRPM(ctx context.Context, group, model string) (int64, error) { rdb := common.RDB currentTime := time.Now().UnixMilli() - result, err := rdb.Eval(ctx, getRPMSumLuaScript, []string{}, pattern, time.Minute.Milliseconds(), currentTime).Int64() + result, err := rdb.Eval( + ctx, + getRequestCountScript, + []string{}, + pattern, + time.Minute.Milliseconds(), + currentTime, + ).Int64() if err != nil { return 0, err } - return result, nil } func redisRateLimitRequest(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) { rdb := common.RDB - currentTime := time.Now().UnixMilli() - result, err := rdb.Eval(ctx, luaScript, []string{ - fmt.Sprintf(groupModelRPMKey, group, model), - }, maxRequestNum, duration.Milliseconds(), currentTime).Int64() + result, err := rdb.Eval( + ctx, + pushRequestScript, + []string{ + fmt.Sprintf(groupModelRPMKey, group, model), + }, + duration.Milliseconds(), + time.Now().UnixMilli(), + ).Int64() if err != nil { return false, err } - return result == 1, nil + return result <= maxRequestNum, nil } func RateLimit(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) { diff --git a/service/aiproxy/middleware/distributor.go b/service/aiproxy/middleware/distributor.go index 5876014e06b..71ee5b30447 100644 --- a/service/aiproxy/middleware/distributor.go +++ b/service/aiproxy/middleware/distributor.go @@ -4,10 +4,13 @@ import ( "fmt" "net/http" "slices" + "strings" "time" "github.com/gin-gonic/gin" + "github.com/labring/sealos/service/aiproxy/common" "github.com/labring/sealos/service/aiproxy/common/config" + "github.com/labring/sealos/service/aiproxy/common/consume" "github.com/labring/sealos/service/aiproxy/common/ctxkey" "github.com/labring/sealos/service/aiproxy/common/rpmlimit" "github.com/labring/sealos/service/aiproxy/model" @@ -15,10 +18,6 @@ import ( log "github.com/sirupsen/logrus" ) -type ModelRequest struct { - Model string `form:"model" json:"model"` -} - func calculateGroupConsumeLevelRpmRatio(usedAmount float64) float64 { v := config.GetGroupConsumeLevelRpmRatio() var maxConsumeLevel float64 = -1 @@ -90,7 +89,13 @@ func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestMo return nil } -func Distribute(c *gin.Context) { +func NewDistribute(mode int) gin.HandlerFunc { + return func(c *gin.Context) { + distribute(c, mode) + } +} + +func distribute(c *gin.Context, mode int) { if config.GetDisableServe() { abortWithMessage(c, http.StatusServiceUnavailable, "service is under maintenance") return @@ -110,6 +115,8 @@ func Distribute(c *gin.Context) { return } + c.Set(ctxkey.OriginalModel, requestModel) + SetLogModelFields(log.Data, requestModel) mc, ok := GetModelCaches(c).ModelConfigMap[requestModel] @@ -118,7 +125,10 @@ func Distribute(c *gin.Context) { return } + c.Set(ctxkey.ModelConfig, mc) + token := GetToken(c) + if len(token.Models) == 0 || !slices.Contains(token.Models, requestModel) { abortWithMessage(c, http.StatusForbidden, @@ -130,13 +140,21 @@ func Distribute(c *gin.Context) { } if err := checkGroupModelRPMAndTPM(c, group, requestModel, mc.RPM, mc.TPM); err != nil { - abortWithMessage(c, http.StatusTooManyRequests, err.Error()) + errMsg := err.Error() + consume.AsyncConsume( + nil, + http.StatusTooManyRequests, + nil, + NewMetaByContext(c, nil, requestModel, mode), + 0, + 0, + errMsg, + nil, + ) + abortWithMessage(c, http.StatusTooManyRequests, errMsg) return } - c.Set(ctxkey.OriginalModel, requestModel) - c.Set(ctxkey.ModelConfig, mc) - c.Next() } @@ -164,3 +182,26 @@ func NewMetaByContext(c *gin.Context, channel *model.Channel, modelName string, meta.WithEndpoint(c.Request.URL.Path), ) } + +type ModelRequest struct { + Model string `form:"model" json:"model"` +} + +func getRequestModel(c *gin.Context) (string, error) { + path := c.Request.URL.Path + switch { + case strings.HasPrefix(path, "/v1/audio/transcriptions"), + strings.HasPrefix(path, "/v1/audio/translations"): + return c.Request.FormValue("model"), nil + case strings.HasPrefix(path, "/v1/engines") && strings.HasSuffix(path, "/embeddings"): + // /engines/:model/embeddings + return c.Param("model"), nil + default: + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c.Request, &modelRequest) + if err != nil { + return "", fmt.Errorf("get request model failed: %w", err) + } + return modelRequest.Model, nil + } +} diff --git a/service/aiproxy/middleware/utils.go b/service/aiproxy/middleware/utils.go index 96e5534679e..b113dbc93be 100644 --- a/service/aiproxy/middleware/utils.go +++ b/service/aiproxy/middleware/utils.go @@ -2,10 +2,8 @@ package middleware import ( "fmt" - "strings" "github.com/gin-gonic/gin" - "github.com/labring/sealos/service/aiproxy/common" "github.com/labring/sealos/service/aiproxy/relay/model" ) @@ -27,21 +25,3 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { }) c.Abort() } - -func getRequestModel(c *gin.Context) (string, error) { - path := c.Request.URL.Path - switch { - case strings.HasPrefix(path, "/v1/audio/transcriptions"), strings.HasPrefix(path, "/v1/audio/translations"): - return c.Request.FormValue("model"), nil - case strings.HasPrefix(path, "/v1/engines") && strings.HasSuffix(path, "/embeddings"): - // /engines/:model/embeddings - return c.Param("model"), nil - default: - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c.Request, &modelRequest) - if err != nil { - return "", fmt.Errorf("get request model failed: %w", err) - } - return modelRequest.Model, nil - } -} diff --git a/service/aiproxy/relay/controller/handle.go b/service/aiproxy/relay/controller/handle.go index 99b90696b0d..e000266dfdc 100644 --- a/service/aiproxy/relay/controller/handle.go +++ b/service/aiproxy/relay/controller/handle.go @@ -1,7 +1,6 @@ package controller import ( - "context" "errors" "fmt" "net/http" @@ -37,7 +36,6 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err) errMsg := fmt.Sprintf("get group (%s) balance failed", meta.Group.ID) consume.AsyncConsume( - context.Background(), nil, http.StatusInternalServerError, nil, @@ -68,7 +66,6 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa } } consume.AsyncConsume( - context.Background(), nil, http.StatusBadRequest, nil, @@ -104,7 +101,6 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa } consume.AsyncConsume( - context.Background(), postGroupConsumer, respErr.StatusCode, usage, @@ -119,7 +115,6 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa // 6. Post consume consume.AsyncConsume( - context.Background(), postGroupConsumer, http.StatusOK, usage, diff --git a/service/aiproxy/relay/meta/meta.go b/service/aiproxy/relay/meta/meta.go index 4e1ab6b0a0d..5fd49e10891 100644 --- a/service/aiproxy/relay/meta/meta.go +++ b/service/aiproxy/relay/meta/meta.go @@ -81,6 +81,7 @@ func NewMeta( values: make(map[string]any), Mode: mode, OriginModel: modelName, + ActualModel: modelName, RequestAt: time.Now(), ModelConfig: modelConfig, } @@ -89,7 +90,9 @@ func NewMeta( opt(&meta) } - meta.Reset(channel) + if channel != nil { + meta.Reset(channel) + } return &meta } diff --git a/service/aiproxy/router/relay.go b/service/aiproxy/router/relay.go index f33925bc528..513374c4a55 100644 --- a/service/aiproxy/router/relay.go +++ b/service/aiproxy/router/relay.go @@ -24,20 +24,66 @@ func SetRelayRouter(router *gin.Engine) { dashboardRouter.GET("/billing/usage", controller.GetUsage) } relayRouter := v1Router.Group("") - relayRouter.Use(middleware.Distribute) { - relayRouter.POST("/completions", controller.NewRelay(relaymode.Completions)) - relayRouter.POST("/chat/completions", controller.NewRelay(relaymode.ChatCompletions)) - relayRouter.POST("/edits", controller.NewRelay(relaymode.Edits)) - relayRouter.POST("/images/generations", controller.NewRelay(relaymode.ImagesGenerations)) + relayRouter.POST( + "/completions", + middleware.NewDistribute(relaymode.Completions), + controller.NewRelay(relaymode.Completions), + ) + + relayRouter.POST( + "/chat/completions", + middleware.NewDistribute(relaymode.ChatCompletions), + controller.NewRelay(relaymode.ChatCompletions), + ) + relayRouter.POST( + "/edits", + middleware.NewDistribute(relaymode.Edits), + controller.NewRelay(relaymode.Edits), + ) + relayRouter.POST( + "/images/generations", + middleware.NewDistribute(relaymode.ImagesGenerations), + controller.NewRelay(relaymode.ImagesGenerations), + ) + relayRouter.POST( + "/embeddings", + middleware.NewDistribute(relaymode.Embeddings), + controller.NewRelay(relaymode.Embeddings), + ) + relayRouter.POST( + "/engines/:model/embeddings", + middleware.NewDistribute(relaymode.Embeddings), + controller.NewRelay(relaymode.Embeddings), + ) + relayRouter.POST( + "/audio/transcriptions", + middleware.NewDistribute(relaymode.AudioTranscription), + controller.NewRelay(relaymode.AudioTranscription), + ) + relayRouter.POST( + "/audio/translations", + middleware.NewDistribute(relaymode.AudioTranslation), + controller.NewRelay(relaymode.AudioTranslation), + ) + relayRouter.POST( + "/audio/speech", + middleware.NewDistribute(relaymode.AudioSpeech), + controller.NewRelay(relaymode.AudioSpeech), + ) + relayRouter.POST( + "/rerank", + middleware.NewDistribute(relaymode.Rerank), + controller.NewRelay(relaymode.Rerank), + ) + relayRouter.POST( + "/moderations", + middleware.NewDistribute(relaymode.Moderations), + controller.NewRelay(relaymode.Moderations), + ) + relayRouter.POST("/images/edits", controller.RelayNotImplemented) relayRouter.POST("/images/variations", controller.RelayNotImplemented) - relayRouter.POST("/embeddings", controller.NewRelay(relaymode.Embeddings)) - relayRouter.POST("/engines/:model/embeddings", controller.NewRelay(relaymode.Embeddings)) - relayRouter.POST("/audio/transcriptions", controller.NewRelay(relaymode.AudioTranscription)) - relayRouter.POST("/audio/translations", controller.NewRelay(relaymode.AudioTranslation)) - relayRouter.POST("/audio/speech", controller.NewRelay(relaymode.AudioSpeech)) - relayRouter.POST("/rerank", controller.NewRelay(relaymode.Rerank)) relayRouter.GET("/files", controller.RelayNotImplemented) relayRouter.POST("/files", controller.RelayNotImplemented) relayRouter.DELETE("/files/:id", controller.RelayNotImplemented) @@ -49,7 +95,6 @@ func SetRelayRouter(router *gin.Engine) { relayRouter.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented) relayRouter.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented) relayRouter.DELETE("/models/:model", controller.RelayNotImplemented) - relayRouter.POST("/moderations", controller.NewRelay(relaymode.Moderations)) relayRouter.POST("/assistants", controller.RelayNotImplemented) relayRouter.GET("/assistants/:id", controller.RelayNotImplemented) relayRouter.POST("/assistants/:id", controller.RelayNotImplemented)