Skip to content

Commit

Permalink
feat: option from env
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Jan 6, 2025
1 parent 07c4cb5 commit 2b0050d
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 87 deletions.
123 changes: 63 additions & 60 deletions service/aiproxy/common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,81 @@ import (
"math"
"os"
"slices"
"strconv"
"sync/atomic"
"time"

"github.com/labring/sealos/service/aiproxy/common/env"
)

var (
DebugEnabled, _ = strconv.ParseBool(os.Getenv("DEBUG"))
DebugSQLEnabled, _ = strconv.ParseBool(os.Getenv("DEBUG_SQL"))
DebugEnabled = env.Bool("DEBUG", false)
DebugSQLEnabled = env.Bool("DEBUG_SQL", false)
)

var (
// 暂停服务
disableServe atomic.Bool
// log detail 存储时间(小时)
DisableAutoMigrateDB = env.Bool("DISABLE_AUTO_MIGRATE_DB", false)
OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
AdminKey = os.Getenv("ADMIN_KEY")
)

var (
disableServe atomic.Bool
logDetailStorageHours int64 = 3 * 24
internalToken atomic.Value
)

var (
// 重试次数
retryTimes atomic.Int64
// 是否开启模型错误率自动封禁
retryTimes atomic.Int64
enableModelErrorAutoBan atomic.Bool
// 模型错误率自动封禁
modelErrorAutoBanRate = math.Float64bits(0.5)
// 模型类型超时时间,单位秒
timeoutWithModelType atomic.Value
modelErrorAutoBanRate = math.Float64bits(0.5)
timeoutWithModelType atomic.Value
disableModelConfig atomic.Bool
)

disableModelConfig atomic.Bool
var (
defaultChannelModels atomic.Value
defaultChannelModelMapping atomic.Value
groupMaxTokenNum atomic.Int64
groupConsumeLevelRatio atomic.Value
)

var geminiSafetySetting atomic.Value

var billingEnabled atomic.Bool

func init() {
timeoutWithModelType.Store(make(map[int]int64))
defaultChannelModels.Store(make(map[int][]string))
defaultChannelModelMapping.Store(make(map[int]map[string]string))
groupConsumeLevelRatio.Store(make(map[float64]float64))
geminiSafetySetting.Store("BLOCK_NONE")
billingEnabled.Store(true)
internalToken.Store(os.Getenv("INTERNAL_TOKEN"))
}

func GetDisableModelConfig() bool {
return disableModelConfig.Load()
}

func SetDisableModelConfig(disabled bool) {
disabled = env.Bool("DISABLE_MODEL_CONFIG", disabled)
disableModelConfig.Store(disabled)
}

func GetRetryTimes() int64 {
return retryTimes.Load()
}

func SetRetryTimes(times int64) {
times = env.Int64("RETRY_TIMES", times)
retryTimes.Store(times)
}

func GetEnableModelErrorAutoBan() bool {
return enableModelErrorAutoBan.Load()
}

func SetEnableModelErrorAutoBan(enabled bool) {
enabled = env.Bool("ENABLE_MODEL_ERROR_AUTO_BAN", enabled)
enableModelErrorAutoBan.Store(enabled)
}

Expand All @@ -61,22 +87,16 @@ func GetModelErrorAutoBanRate() float64 {
}

func SetModelErrorAutoBanRate(rate float64) {
rate = env.Float64("MODEL_ERROR_AUTO_BAN_RATE", rate)
atomic.StoreUint64(&modelErrorAutoBanRate, math.Float64bits(rate))
}

func SetRetryTimes(times int64) {
retryTimes.Store(times)
}

func init() {
timeoutWithModelType.Store(make(map[int]int64))
}

func GetTimeoutWithModelType() map[int]int64 {
return timeoutWithModelType.Load().(map[int]int64)
}

func SetTimeoutWithModelType(timeout map[int]int64) {
timeout = env.JSON("TIMEOUT_WITH_MODEL_TYPE", timeout)
timeoutWithModelType.Store(timeout)
}

Expand All @@ -85,6 +105,7 @@ func GetLogDetailStorageHours() int64 {
}

func SetLogDetailStorageHours(hours int64) {
hours = env.Int64("LOG_DETAIL_STORAGE_HOURS", hours)
atomic.StoreInt64(&logDetailStorageHours, hours)
}

Expand All @@ -93,36 +114,16 @@ func GetDisableServe() bool {
}

func SetDisableServe(disabled bool) {
disabled = env.Bool("DISABLE_SERVE", disabled)
disableServe.Store(disabled)
}

var DisableAutoMigrateDB = os.Getenv("DISABLE_AUTO_MIGRATE_DB") == "true"

var RateLimitKeyExpirationDuration = 20 * time.Minute

var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)

var AdminKey = env.String("ADMIN_KEY", "")

var (
defaultChannelModels atomic.Value
defaultChannelModelMapping atomic.Value
groupMaxTokenNum atomic.Int32
// group消费金额对应的rpm/tpm乘数,使用map[float64]float64
groupConsumeLevelRatio atomic.Value
)

func init() {
defaultChannelModels.Store(make(map[int][]string))
defaultChannelModelMapping.Store(make(map[int]map[string]string))
groupConsumeLevelRatio.Store(make(map[float64]float64))
}

func GetDefaultChannelModels() map[int][]string {
return defaultChannelModels.Load().(map[int][]string)
}

func SetDefaultChannelModels(models map[int][]string) {
models = env.JSON("DEFAULT_CHANNEL_MODELS", models)
for key, ms := range models {
slices.Sort(ms)
models[key] = slices.Compact(ms)
Expand All @@ -135,6 +136,7 @@ func GetDefaultChannelModelMapping() map[int]map[string]string {
}

func SetDefaultChannelModelMapping(mapping map[int]map[string]string) {
mapping = env.JSON("DEFAULT_CHANNEL_MODEL_MAPPING", mapping)
defaultChannelModelMapping.Store(mapping)
}

Expand All @@ -143,42 +145,43 @@ func GetGroupConsumeLevelRatio() map[float64]float64 {
}

func SetGroupConsumeLevelRatio(ratio map[float64]float64) {
ratio = env.JSON("GROUP_CONSUME_LEVEL_RATIO", ratio)
groupConsumeLevelRatio.Store(ratio)
}

// 那个group最多可创建的token数量,0表示不限制
func GetGroupMaxTokenNum() int32 {
// GetGroupMaxTokenNum returns max number of tokens per group, 0 means unlimited
func GetGroupMaxTokenNum() int64 {
return groupMaxTokenNum.Load()
}

func SetGroupMaxTokenNum(num int32) {
func SetGroupMaxTokenNum(num int64) {
num = env.Int64("GROUP_MAX_TOKEN_NUM", num)
groupMaxTokenNum.Store(num)
}

var geminiSafetySetting atomic.Value

func init() {
geminiSafetySetting.Store("BLOCK_NONE")
}

func GetGeminiSafetySetting() string {
return geminiSafetySetting.Load().(string)
}

func SetGeminiSafetySetting(setting string) {
setting = env.String("GEMINI_SAFETY_SETTING", setting)
geminiSafetySetting.Store(setting)
}

var billingEnabled atomic.Bool

func init() {
billingEnabled.Store(true)
}

func GetBillingEnabled() bool {
return billingEnabled.Load()
}

func SetBillingEnabled(enabled bool) {
enabled = env.Bool("BILLING_ENABLED", enabled)
billingEnabled.Store(enabled)
}

func GetInternalToken() string {
return internalToken.Load().(string)
}

func SetInternalToken(token string) {
token = env.String("INTERNAL_TOKEN", token)
internalToken.Store(token)
}
4 changes: 2 additions & 2 deletions service/aiproxy/common/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ var (
)

var (
SQLitePath = "aiproxy.db"
SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)
SQLitePath = env.String("SQLITE_PATH", "aiproxy.db")
SQLiteBusyTimeout = env.Int64("SQLITE_BUSY_TIMEOUT", 3000)
)
61 changes: 52 additions & 9 deletions service/aiproxy/common/env/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,83 @@ package env
import (
"os"
"strconv"

json "github.com/json-iterator/go"
"github.com/labring/sealos/service/aiproxy/common/conv"
log "github.com/sirupsen/logrus"
)

func Bool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
if env == "" {
return defaultValue
}
e := os.Getenv(env)
if e == "" {
return defaultValue
}
p, err := strconv.ParseBool(e)
if err != nil {
log.Errorf("invalid %s: %s", env, e)
return defaultValue
}
return os.Getenv(env) == "true"
return p
}

func Int(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
func Int64(env string, defaultValue int64) int64 {
if env == "" {
return defaultValue
}
e := os.Getenv(env)
if e == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
num, err := strconv.ParseInt(e, 10, 64)
if err != nil {
log.Errorf("invalid %s: %s", env, e)
return defaultValue
}
return num
}

func Float64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
if env == "" {
return defaultValue
}
e := os.Getenv(env)
if e == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
num, err := strconv.ParseFloat(e, 64)
if err != nil {
log.Errorf("invalid %s: %s", env, e)
return defaultValue
}
return num
}

func String(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
if env == "" {
return defaultValue
}
e := os.Getenv(env)
if e == "" {
return defaultValue
}
return e
}

func JSON[T any](env string, defaultValue T) T {
if env == "" {
return defaultValue
}
e := os.Getenv(env)
if e == "" {
return defaultValue
}
var t T
if err := json.Unmarshal(conv.StringToBytes(e), &t); err != nil {
log.Errorf("invalid %s: %s", env, e)
return defaultValue
}
return os.Getenv(env)
return t
}
3 changes: 0 additions & 3 deletions service/aiproxy/common/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ var (
func Init() {
flag.Parse()

if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")
}
if *LogDir != "" {
var err error
*LogDir, err = filepath.Abs(*LogDir)
Expand Down
3 changes: 1 addition & 2 deletions service/aiproxy/common/rpmlimit/rate-limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/config"
log "github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -187,6 +186,6 @@ func ForceRateLimit(ctx context.Context, group, model string, maxRequestNum int6

func MemoryRateLimit(_ context.Context, group, model string, maxRequestNum int64, duration time.Duration) bool {
// It's safe to call multi times.
inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
inMemoryRateLimiter.Init(3 * time.Minute)
return inMemoryRateLimiter.Request(fmt.Sprintf(groupModelRPMKey, group, model), int(maxRequestNum), duration)
}
Loading

0 comments on commit 2b0050d

Please sign in to comment.