Skip to content

Commit

Permalink
feat: group model list adjusted tpm rpm
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Jan 6, 2025
1 parent 00cf94d commit 3e7dad5
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 51 deletions.
14 changes: 7 additions & 7 deletions service/aiproxy/common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ var (
defaultChannelModels atomic.Value
defaultChannelModelMapping atomic.Value
groupMaxTokenNum atomic.Int32
// group消费金额对应的rpm乘数,使用map[float64]float64
groupConsumeLevelRpmRatio atomic.Value
// group消费金额对应的rpm/tpm乘数,使用map[float64]float64
groupConsumeLevelRatio atomic.Value
)

// init seeds the package-level atomic.Value holders with empty maps, so a
// getter that type-asserts the loaded value (for example
// groupConsumeLevelRatio.Load().(map[float64]float64)) never sees a nil
// interface before the first Set call.
// NOTE(review): this listing appears to show both the pre-rename
// (groupConsumeLevelRpmRatio) and post-rename (groupConsumeLevelRatio) Store
// lines of the commit — confirm against the actual file.
func init() {
defaultChannelModels.Store(make(map[int][]string))
defaultChannelModelMapping.Store(make(map[int]map[string]string))
groupConsumeLevelRpmRatio.Store(make(map[float64]float64))
groupConsumeLevelRatio.Store(make(map[float64]float64))
}

func GetDefaultChannelModels() map[int][]string {
Expand All @@ -128,12 +128,12 @@ func SetDefaultChannelModelMapping(mapping map[int]map[string]string) {
defaultChannelModelMapping.Store(mapping)
}

// GetGroupConsumeLevelRatio returns the map currently held in the atomic
// value. The map is handed out as stored (no copy is made), and the type
// assertion relies on a map[float64]float64 always having been stored.
// NOTE(review): the first signature/return pair below looks like the
// pre-rename version of the same getter shown by the diff — confirm.
func GetGroupConsumeLevelRpmRatio() map[float64]float64 {
return groupConsumeLevelRpmRatio.Load().(map[float64]float64)
func GetGroupConsumeLevelRatio() map[float64]float64 {
return groupConsumeLevelRatio.Load().(map[float64]float64)
}

// SetGroupConsumeLevelRatio replaces the stored map with ratio. The map is
// stored as given (no copy is made).
// NOTE(review): the first signature/Store pair below looks like the
// pre-rename version of the same setter shown by the diff — confirm.
func SetGroupConsumeLevelRpmRatio(ratio map[float64]float64) {
groupConsumeLevelRpmRatio.Store(ratio)
func SetGroupConsumeLevelRatio(ratio map[float64]float64) {
groupConsumeLevelRatio.Store(ratio)
}

// Maximum number of tokens each group may create; 0 means no limit.
Expand Down
20 changes: 20 additions & 0 deletions service/aiproxy/controller/dashboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,23 @@ func GetGroupDashboard(c *gin.Context) {

middleware.SuccessResponse(c, dashboards)
}

// GetGroupDashboardModels responds with every enabled model config after it
// has been adjusted for the group named by the "group" path parameter.
//
// An empty parameter or a failed group lookup produces an error response;
// otherwise the adjusted configs are sent in the same order as the enabled
// model list.
func GetGroupDashboardModels(c *gin.Context) {
	groupID := c.Param("group")
	if groupID == "" {
		middleware.ErrorResponse(c, http.StatusOK, "invalid parameter")
		return
	}

	cachedGroup, err := model.CacheGetGroup(groupID)
	if err != nil {
		middleware.ErrorResponse(c, http.StatusOK, "failed to get group")
		return
	}

	enabled := model.LoadModelCaches().EnabledModelConfigs
	adjusted := make([]*model.ModelConfig, 0, len(enabled))
	for _, cfg := range enabled {
		adjusted = append(adjusted, middleware.GetGroupAdjustedModelConfig(cachedGroup, cfg))
	}
	middleware.SuccessResponse(c, adjusted)
}
96 changes: 56 additions & 40 deletions service/aiproxy/middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,71 +18,87 @@ import (
log "github.com/sirupsen/logrus"
)

// calculateGroupConsumeLevelRatio picks the ratio configured for the highest
// consume level that usedAmount has reached (usedAmount >= level).
//
// It returns 1 when no levels are configured, when usedAmount is below every
// configured level, or when the selected ratio is not positive.
// NOTE(review): lines naming "RpmRatio" appear to be the pre-rename versions
// of the adjacent lines shown by the diff — confirm against the actual file.
func calculateGroupConsumeLevelRpmRatio(usedAmount float64) float64 {
v := config.GetGroupConsumeLevelRpmRatio()
func calculateGroupConsumeLevelRatio(usedAmount float64) float64 {
v := config.GetGroupConsumeLevelRatio()
if len(v) == 0 {
return 1
}
// -1 is the "nothing selected yet" starting point for the max search.
var maxConsumeLevel float64 = -1
var groupConsumeLevelRpmRatio float64
var groupConsumeLevelRatio float64
for consumeLevel, ratio := range v {
// Skip levels the group has not reached yet.
if usedAmount < consumeLevel {
continue
}
// Map iteration order is unspecified, so track the highest reached level
// explicitly instead of relying on order.
if consumeLevel > maxConsumeLevel {
maxConsumeLevel = consumeLevel
groupConsumeLevelRpmRatio = ratio
groupConsumeLevelRatio = ratio
}
}
// Also covers the case where no level was reached (ratio still zero).
if groupConsumeLevelRpmRatio <= 0 {
groupConsumeLevelRpmRatio = 1
if groupConsumeLevelRatio <= 0 {
groupConsumeLevelRatio = 1
}
return groupConsumeLevelRpmRatio
return groupConsumeLevelRatio
}

// getGroupPMRatio returns the group's RPM ratio and TPM ratio, in that order,
// substituting 1 for any ratio that is not positive.
// NOTE(review): this span interleaves the removed getGroupRPMRatio and the
// removed opening of the old checkGroupModelRPMAndTPM with the new function's
// lines, as rendered by the diff — confirm against the actual file.
func getGroupRPMRatio(group *model.GroupCache) float64 {
func getGroupPMRatio(group *model.GroupCache) (float64, float64) {
groupRPMRatio := group.RPMRatio
if groupRPMRatio <= 0 {
groupRPMRatio = 1
}
return groupRPMRatio
}

func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestModel string, modelRPM int64, modelTPM int64) error {
if group.RPM != nil && group.RPM[requestModel] > 0 {
modelRPM = group.RPM[requestModel]
}
if group.TPM != nil && group.TPM[requestModel] > 0 {
modelTPM = group.TPM[requestModel]
groupTPMRatio := group.TPMRatio
if groupTPMRatio <= 0 {
groupTPMRatio = 1
}
return groupRPMRatio, groupTPMRatio
}

if modelRPM <= 0 && modelTPM <= 0 {
return nil
}

groupConsumeLevelRpmRatio := calculateGroupConsumeLevelRpmRatio(group.UsedAmount)
groupRPMRatio := getGroupRPMRatio(group)

adjustedModelRPM := int64(float64(modelRPM) * groupRPMRatio * groupConsumeLevelRpmRatio)
// GetGroupAdjustedModelConfig returns the model config as it applies to the
// given group, with RPM and TPM adjusted for that group.
//
// A positive per-model entry in group.RPM / group.TPM replaces the model's own
// limit. A positive limit is then multiplied by the group's RPM or TPM ratio
// and by the ratio of the group's consume level. A limit that is not positive
// is left untouched, because the rate-limit checks only enforce limits that
// are greater than zero.
//
// A positive limit never scales down to zero: float-to-int conversion
// truncates toward zero, so without the clamp a small limit combined with a
// ratio below 1 would become 0 and be treated as "no limit".
//
// mc is never modified. When an adjusted value differs from mc, a shallow copy
// carrying the new limits is returned; otherwise mc itself is returned.
func GetGroupAdjustedModelConfig(group *model.GroupCache, mc *model.ModelConfig) *model.ModelConfig {
	rpm := mc.RPM
	tpm := mc.TPM
	if group.RPM != nil && group.RPM[mc.Model] > 0 {
		rpm = group.RPM[mc.Model]
	}
	if group.TPM != nil && group.TPM[mc.Model] > 0 {
		tpm = group.TPM[mc.Model]
	}

	rpmRatio, tpmRatio := getGroupPMRatio(group)
	groupConsumeLevelRatio := calculateGroupConsumeLevelRatio(group.UsedAmount)

	// scale applies ratio to a positive limit and keeps the result at 1 or
	// above, so a configured limit cannot silently turn into "unlimited".
	scale := func(limit int64, ratio float64) int64 {
		if limit <= 0 {
			return limit
		}
		scaled := int64(float64(limit) * ratio)
		if scaled < 1 {
			return 1
		}
		return scaled
	}
	rpm = scale(rpm, rpmRatio*groupConsumeLevelRatio)
	tpm = scale(tpm, tpmRatio*groupConsumeLevelRatio)

	if rpm != mc.RPM || tpm != mc.TPM {
		newMc := *mc
		newMc.RPM = rpm
		newMc.TPM = tpm
		return &newMc
	}
	return mc
}

ok := rpmlimit.ForceRateLimit(
c.Request.Context(),
group.ID,
requestModel,
adjustedModelRPM,
time.Minute,
)
func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, mc *model.ModelConfig) error {
adjustedModelConfig := GetGroupAdjustedModelConfig(group, mc)

if !ok {
return fmt.Errorf("group (%s) is requesting too frequently", group.ID)
if adjustedModelConfig.RPM > 0 {
ok := rpmlimit.ForceRateLimit(
c.Request.Context(),
group.ID,
mc.Model,
adjustedModelConfig.RPM,
time.Minute,
)
if !ok {
return fmt.Errorf("group (%s) is requesting too frequently", group.ID)
}
}

if modelTPM > 0 {
tpm, err := model.CacheGetGroupModelTPM(group.ID, requestModel)
if adjustedModelConfig.TPM > 0 {
tpm, err := model.CacheGetGroupModelTPM(group.ID, mc.Model)
if err != nil {
log.Errorf("get group model tpm (%s:%s) error: %s", group.ID, requestModel, err.Error())
log.Errorf("get group model tpm (%s:%s) error: %s", group.ID, mc.Model, err.Error())
// ignore error
return nil
}

if tpm >= modelTPM {
if tpm >= adjustedModelConfig.TPM {
return fmt.Errorf("group (%s) tpm is too high", group.ID)
}
}
Expand Down Expand Up @@ -139,13 +155,13 @@ func distribute(c *gin.Context, mode int) {
return
}

if err := checkGroupModelRPMAndTPM(c, group, requestModel, mc.RPM, mc.TPM); err != nil {
if err := checkGroupModelRPMAndTPM(c, group, mc); err != nil {
errMsg := err.Error()
consume.AsyncConsume(
nil,
http.StatusTooManyRequests,
nil,
NewMetaByContext(c, nil, requestModel, mode),
NewMetaByContext(c, nil, mc.Model, mode),
0,
0,
errMsg,
Expand Down
8 changes: 4 additions & 4 deletions service/aiproxy/model/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ func initOptionMap() error {
optionMap["DefaultChannelModelMapping"] = conv.BytesToString(defaultChannelModelMappingJSON)
optionMap["GeminiSafetySetting"] = config.GetGeminiSafetySetting()
optionMap["GroupMaxTokenNum"] = strconv.FormatInt(int64(config.GetGroupMaxTokenNum()), 10)
groupConsumeLevelRpmRatioJSON, err := json.Marshal(config.GetGroupConsumeLevelRpmRatio())
groupConsumeLevelRatioJSON, err := json.Marshal(config.GetGroupConsumeLevelRatio())
if err != nil {
return err
}
optionMap["GroupConsumeLevelRpmRatio"] = conv.BytesToString(groupConsumeLevelRpmRatioJSON)
optionMap["GroupConsumeLevelRatio"] = conv.BytesToString(groupConsumeLevelRatioJSON)

optionKeys = make([]string, 0, len(optionMap))
for key := range optionMap {
Expand Down Expand Up @@ -284,7 +284,7 @@ func updateOption(key string, value string, isInit bool) (err error) {
}
}
config.SetTimeoutWithModelType(newTimeoutWithModelType)
case "GroupConsumeLevelRpmRatio":
case "GroupConsumeLevelRatio":
var newGroupRpmRatio map[float64]float64
err := json.Unmarshal(conv.StringToBytes(value), &newGroupRpmRatio)
if err != nil {
Expand All @@ -298,7 +298,7 @@ func updateOption(key string, value string, isInit bool) (err error) {
return errors.New("rpm ratio must be greater than 0")
}
}
config.SetGroupConsumeLevelRpmRatio(newGroupRpmRatio)
config.SetGroupConsumeLevelRatio(newGroupRpmRatio)
default:
return ErrUnknownOptionKey
}
Expand Down
1 change: 1 addition & 0 deletions service/aiproxy/router/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func SetAPIRouter(router *gin.Engine) {
{
dashboardRoute.GET("/", controller.GetDashboard)
dashboardRoute.GET("/:group", controller.GetGroupDashboard)
dashboardRoute.GET("/:group/models", controller.GetGroupDashboardModels)
}

groupsRoute := apiRouter.Group("/groups")
Expand Down

0 comments on commit 3e7dad5

Please sign in to comment.