diff --git a/example.config.yaml b/example.config.yaml index f5030804..2524ac7e 100644 --- a/example.config.yaml +++ b/example.config.yaml @@ -19,7 +19,7 @@ domain: "http://127.0.0.1:8080" # 内调llm,用于绘图时文本转tags llm: - baseUrl: "http://127.0.0.1:8081" + base-url: "http://127.0.0.1:8081" model: "bing" token: "xxx" @@ -31,7 +31,7 @@ lmsys: [49 , 109] # 参数替换:{{prompt}}、{{negative_prompt}}、{{sampler}}、{{style}}、{{seed}} hf: # animagine-xl-3.1: -# baseUrl: https://boboiazumi-animagine-xl-3-1.hf.space +# base-url: https://boboiazumi-animagine-xl-3-1.hf.space # fn: [5, 61] # data: '[ # "{{prompt}}", @@ -54,9 +54,9 @@ hf: # 0.65 # ]' # dalle-4k: -# baseUrl: https://mukaist-dalle-4k.hf.space +# base-url: https://mukaist-dalle-4k.hf.space # dalle-3-xl: -# baseUrl: https://ehristoforu-dalle-3-xl-lora-v2.hf.space +# base-url: https://ehristoforu-dalle-3-xl-lora-v2.hf.space # gemini 自定义安全设置 google: @@ -65,7 +65,7 @@ google: # threshold: BLOCK_NONE bing: - baseUrl: "https://edgeservices.bing.com/edgesvc" + base-url: "https://edgeservices.bing.com/edgesvc" claude: pad: 0 @@ -113,12 +113,14 @@ coze: # validate: xxx@gmail.com interpreter: - baseUrl: http://127.0.0.1:8000 - echoCode: false + base-url: http://127.0.0.1:8000 + echo-code: false ws: true custom-llm: - baseUrl: http://127.0.0.1:8080 + - base-url: http://127.0.0.1:8080/v1 + prefix: custom + use-proxies: false # toolCall 默认配置化; 在 flags 关闭时也可用 toolCall: diff --git a/internal/gin.handler/basic.go b/internal/gin.handler/basic.go index 05e54dd3..68b5add8 100644 --- a/internal/gin.handler/basic.go +++ b/internal/gin.handler/basic.go @@ -48,6 +48,7 @@ func Bind(port int, version, proxies string) { route.POST("/v1/object/completions", completions) route.POST("/proxies/v1/chat/completions", completions) route.POST("/v1/embeddings", embedding) + route.POST("/proxies/v1/embeddings", embedding) route.POST("v1/images/generations", generations) route.POST("v1/object/generations", generations) route.POST("proxies/v1/images/generations", 
generations) diff --git a/internal/plugin/hf/adapter.go b/internal/plugin/hf/adapter.go index be70129d..3a90692f 100644 --- a/internal/plugin/hf/adapter.go +++ b/internal/plugin/hf/adapter.go @@ -208,7 +208,7 @@ func completeTagsGenerator(ctx *gin.Context, content string) (string, error) { proxies = ctx.GetString("proxies") model = pkg.Config.GetString("llm.model") cookie = pkg.Config.GetString("llm.token") - baseUrl = pkg.Config.GetString("llm.baseUrl") + baseUrl = pkg.Config.GetString("llm.base-url") ) c := regexp.MustCompile("]+)\"\\s?/>") diff --git a/internal/plugin/hf/fetch.go b/internal/plugin/hf/fetch.go index 8ecc013c..fe174156 100644 --- a/internal/plugin/hf/fetch.go +++ b/internal/plugin/hf/fetch.go @@ -183,7 +183,7 @@ func Ox002(ctx *gin.Context, model, message string) (value string, err error) { baseUrl = "https://mukaist-dalle-4k.hf.space" ) - if u := pkg.Config.GetString("hf.dalle-4k.baseUrl"); u != "" { + if u := pkg.Config.GetString("hf.dalle-4k.base-url"); u != "" { baseUrl = u } @@ -282,7 +282,7 @@ func Ox003(ctx *gin.Context, message string) (value string, err error) { domain = fmt.Sprintf("http://127.0.0.1:%d", ctx.GetInt("port")) } - if u := pkg.Config.GetString("hf.dalle-3-xl.baseUrl"); u != "" { + if u := pkg.Config.GetString("hf.dalle-3-xl.base-url"); u != "" { baseUrl = u } @@ -409,7 +409,7 @@ func Ox004(ctx *gin.Context, model, samples, message string) (value string, err domain = fmt.Sprintf("http://127.0.0.1:%d", ctx.GetInt("port")) } - if u := pkg.Config.GetString("hf.animagine-xl-3.1.baseUrl"); u != "" { + if u := pkg.Config.GetString("hf.animagine-xl-3.1.base-url"); u != "" { baseUrl = u } diff --git a/internal/plugin/llm/bing/adapter.go b/internal/plugin/llm/bing/adapter.go index f2d77390..e576dbaa 100644 --- a/internal/plugin/llm/bing/adapter.go +++ b/internal/plugin/llm/bing/adapter.go @@ -68,7 +68,7 @@ func (API) Completion(ctx *gin.Context) { completion = common.GetGinCompletion(ctx) matchers = common.GetGinMatchers(ctx) - 
baseUrl = pkg.Config.GetString("bing.baseUrl") + baseUrl = pkg.Config.GetString("bing.base-url") ) if cookie == "xxx" { diff --git a/internal/plugin/llm/bing/toolcall.go b/internal/plugin/llm/bing/toolcall.go index aeaf9696..b1dd9630 100644 --- a/internal/plugin/llm/bing/toolcall.go +++ b/internal/plugin/llm/bing/toolcall.go @@ -17,7 +17,7 @@ func completeToolCalls(ctx *gin.Context, cookie, proxies string, completion pkg. logger.Infof("completeTools ...") var ( - baseUrl = pkg.Config.GetString("bing.baseUrl") + baseUrl = pkg.Config.GetString("bing.base-url") echo = ctx.GetBool(vars.GinEcho) ) diff --git a/internal/plugin/llm/coze/websdk.go b/internal/plugin/llm/coze/websdk.go index 8f9bb536..9fa8b4ac 100644 --- a/internal/plugin/llm/coze/websdk.go +++ b/internal/plugin/llm/coze/websdk.go @@ -131,7 +131,7 @@ func runTasks(opts ...map[string]interface{}) { // 重置任务函数 func loopTasks() { s5 := 5 * time.Second - baseUrl := pkg.Config.GetString("serverless.baseUrl") + baseUrl := pkg.Config.GetString("serverless.base-url") if baseUrl == "" { baseUrl = "http://127.0.0.1:" + pkg.Config.GetString("you.helper") } @@ -208,7 +208,7 @@ func loopTasks() { // 初始任务函数 func initTasks(opts ...*obj) (exec bool) { time.Sleep(6 * time.Second) // 等待程序启动就绪 - baseUrl := pkg.Config.GetString("serverless.baseUrl") + baseUrl := pkg.Config.GetString("serverless.base-url") if baseUrl == "" { baseUrl = "http://127.0.0.1:" + pkg.Config.GetString("you.helper") } diff --git a/internal/plugin/llm/interpreter/fetch.go b/internal/plugin/llm/interpreter/fetch.go index fef44168..c6b4ebbd 100644 --- a/internal/plugin/llm/interpreter/fetch.go +++ b/internal/plugin/llm/interpreter/fetch.go @@ -15,7 +15,7 @@ import ( func fetch(ctx *gin.Context, proxies string, completion pkg.ChatCompletion) (response *http.Response, tokens int, err error) { var ( - baseUrl = pkg.Config.GetString("interpreter.baseUrl") + baseUrl = pkg.Config.GetString("interpreter.base-url") ) tokens, message, err := mergeMessages(ctx, 
proxies, baseUrl, completion) diff --git a/internal/plugin/llm/interpreter/socket.io.go b/internal/plugin/llm/interpreter/socket.io.go index e7c64f92..4b1e392d 100644 --- a/internal/plugin/llm/interpreter/socket.io.go +++ b/internal/plugin/llm/interpreter/socket.io.go @@ -16,7 +16,7 @@ func completionWS(ctx *gin.Context) { } var ( - baseUrl = pkg.Config.GetString("interpreter.baseUrl") + baseUrl = pkg.Config.GetString("interpreter.base-url") proxies = ctx.GetString("proxies") completion = common.GetGinCompletion(ctx) matchers = common.GetGinMatchers(ctx) diff --git a/internal/plugin/llm/v1/adapter.go b/internal/plugin/llm/v1/adapter.go index 0151f973..96a7a6bf 100644 --- a/internal/plugin/llm/v1/adapter.go +++ b/internal/plugin/llm/v1/adapter.go @@ -6,7 +6,6 @@ import ( "chatgpt-adapter/internal/plugin" "chatgpt-adapter/logger" "chatgpt-adapter/pkg" - "io" "net/http" "strings" @@ -17,14 +16,41 @@ import ( var ( Adapter = API{} Model = "custom" + schema = make([]map[string]interface{}, 0) + key = "__custom-url__" + upKey = "__custom-proxies__" + modKey = "__custom-model__" ) type API struct { plugin.BaseAdapter } -func (API) Match(_ *gin.Context, model string) bool { - return strings.HasPrefix(model, "custom/") +func init() { + common.AddInitialized(func() { + llm := pkg.Config.Get("custom-llm") + if slice, ok := llm.([]interface{}); ok { + for _, it := range slice { + item, o := it.(map[string]interface{}) + if !o { + continue + } + schema = append(schema, item) + } + } + }) +} + +func (API) Match(ctx *gin.Context, model string) bool { + for _, it := range schema { + if prefix, ok := it["prefix"].(string); ok && strings.HasPrefix(model, prefix+"/") { + ctx.Set(key, it["base-url"]) + ctx.Set(upKey, it["use-proxies"] == true || it["use-proxies"] == "true") + ctx.Set(modKey, model[len(prefix)+1:]) + return true + } + } + return false } func (API) Models() []plugin.Model { @@ -33,7 +59,7 @@ func (API) Models() []plugin.Model { Id: "custom", Object: "model", Created: 1686935002, - By: 
"lmsys-adapter", + By: "custom-adapter", }, } } @@ -76,35 +102,33 @@ label: func (API) Embedding(ctx *gin.Context) { embedding := common.GetGinEmbedding(ctx) - embedding.Model = embedding.Model[7:] + embedding.Model = ctx.GetString(modKey) var ( - token = ctx.GetString("token") - proxies = ctx.GetString("proxies") - baseUrl = pkg.Config.GetString("custom-llm.baseUrl") - useProxy = pkg.Config.GetBool("custom-llm.useProxy") + token = ctx.GetString("token") + proxies = ctx.GetString("proxies") + baseUrl = ctx.GetString(key) ) - if !useProxy { + if !ctx.GetBool(upKey) { proxies = "" } + resp, err := emit.ClientBuilder(plugin.HTTPClient). Proxies(proxies). Context(common.GetGinContext(ctx)). - POST(baseUrl+"/v1/embeddings"). + POST(baseUrl+"/embeddings"). Header("Authorization", "Bearer "+token). JHeader(). - Body(embedding).DoC(emit.Status(http.StatusOK)) + Body(embedding).DoC(emit.Status(http.StatusOK), emit.IsJSON) if err != nil { - ctx.JSON(http.StatusBadGateway, gin.H{ - "error": "can't send request to upstream", - }) + response.Error(ctx, http.StatusBadGateway, err) + return } - ctx.Header("Content-Type", "application/json; charset=utf-8") - content, err := io.ReadAll(resp.Body) + + obj, err := emit.ToMap(resp) if err != nil { - ctx.JSON(http.StatusBadGateway, gin.H{ - "error": "can't read from upstream", - }) + response.Error(ctx, http.StatusBadGateway, err) + return } - ctx.Writer.Write(content) - ctx.Writer.Flush() + + ctx.JSON(http.StatusOK, obj) } diff --git a/internal/plugin/llm/v1/fetch.go b/internal/plugin/llm/v1/fetch.go index 3ca7c0ad..7d695073 100644 --- a/internal/plugin/llm/v1/fetch.go +++ b/internal/plugin/llm/v1/fetch.go @@ -4,6 +4,7 @@ import ( "chatgpt-adapter/internal/common" "chatgpt-adapter/internal/plugin" "chatgpt-adapter/pkg" + "encoding/json" "github.com/bincooo/emit.io" "github.com/gin-gonic/gin" "net/http" @@ -11,11 +12,10 @@ import ( func fetch(ctx *gin.Context, proxies, token string, completion pkg.ChatCompletion) (response 
*http.Response, err error) { var ( - baseUrl = pkg.Config.GetString("custom-llm.baseUrl") - useProxy = pkg.Config.GetBool("custom-llm.useProxy") + baseUrl = ctx.GetString(key) ) - if !useProxy { + if !ctx.GetBool(upKey) { proxies = "" } @@ -38,13 +38,37 @@ func fetch(ctx *gin.Context, proxies, token string, completion pkg.ChatCompletio ctx.Set(ginTokens, token) completion.Stream = true + completion.Model = ctx.GetString(modKey) + obj, err := toMap(completion) + if err != nil { + return nil, err + } + + if completion.TopK == 0 { + delete(obj, "top_k") + } + response, err = emit.ClientBuilder(plugin.HTTPClient). Proxies(proxies). Context(common.GetGinContext(ctx)). - POST(baseUrl+"/v1/chat/completions"). + POST(baseUrl+"/chat/completions"). Header("Authorization", "Bearer "+token). JHeader(). - Body(completion). + Body(obj). DoC(emit.Status(http.StatusOK), emit.IsSTREAM) return } + +func toMap(obj interface{}) (mo map[string]interface{}, err error) { + if obj == nil { + return + } + + bytes, err := json.Marshal(obj) + if err != nil { + return + } + + err = json.Unmarshal(bytes, &mo) + return +} diff --git a/pkg/model.go b/pkg/model.go index 708c2766..1b085c76 100644 --- a/pkg/model.go +++ b/pkg/model.go @@ -6,16 +6,16 @@ import ( ) type ChatCompletion struct { - System string `json:"system"` + System string `json:"system,omitempty"` Messages []Keyv[interface{}] `json:"messages"` Tools []Keyv[interface{}] `json:"tools"` Model string `json:"model"` MaxTokens int `json:"max_tokens"` StopSequences []string `json:"stop"` Temperature float32 `json:"temperature"` - TopK int `json:"topK"` - TopP float32 `json:"topP"` - Stream bool `json:"stream"` + TopK int `json:"top_k,omitempty"` + TopP float32 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` ToolChoice interface{} `json:"tool_choice"` }