Skip to content

Commit

Permalink
feat: v1兼容github model
Browse files Browse the repository at this point in the history
  • Loading branch information
bincooo committed Sep 11, 2024
1 parent 8d6605a commit bb25496
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 49 deletions.
18 changes: 10 additions & 8 deletions example.config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ domain: "http://127.0.0.1:8080"

# 内调llm,用于绘图时文本转tags
llm:
baseUrl: "http://127.0.0.1:8081"
base-url: "http://127.0.0.1:8081"
model: "bing"
token: "xxx"

Expand All @@ -31,7 +31,7 @@ lmsys: [49 , 109]
# 参数替换:{{prompt}}、{{negative_prompt}}、{{sampler}}、{{style}}、{{seed}}
hf:
# animagine-xl-3.1:
# baseUrl: https://boboiazumi-animagine-xl-3-1.hf.space
# base-url: https://boboiazumi-animagine-xl-3-1.hf.space
# fn: [5, 61]
# data: '[
# "{{prompt}}",
Expand All @@ -54,9 +54,9 @@ hf:
# 0.65
# ]'
# dalle-4k:
# baseUrl: https://mukaist-dalle-4k.hf.space
# base-url: https://mukaist-dalle-4k.hf.space
# dalle-3-xl:
# baseUrl: https://ehristoforu-dalle-3-xl-lora-v2.hf.space
# base-url: https://ehristoforu-dalle-3-xl-lora-v2.hf.space

# gemini 自定义安全设置
google:
Expand All @@ -65,7 +65,7 @@ google:
# threshold: BLOCK_NONE

bing:
baseUrl: "https://edgeservices.bing.com/edgesvc"
base-url: "https://edgeservices.bing.com/edgesvc"

claude:
pad: 0
Expand Down Expand Up @@ -113,12 +113,14 @@ coze:
# validate: [email protected]

interpreter:
baseUrl: http://127.0.0.1:8000
echoCode: false
base-url: http://127.0.0.1:8000
echo-code: false
ws: true

custom-llm:
baseUrl: http://127.0.0.1:8080
- base-url: http://127.0.0.1:8080/v1
prefix: custom
use-proxies: false

# toolCall 默认配置化; 在 flags 关闭时也可用
toolCall:
Expand Down
1 change: 1 addition & 0 deletions internal/gin.handler/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func Bind(port int, version, proxies string) {
route.POST("/v1/object/completions", completions)
route.POST("/proxies/v1/chat/completions", completions)
route.POST("/v1/embeddings", embedding)
route.POST("/proxies/v1/embeddings", embedding)
route.POST("v1/images/generations", generations)
route.POST("v1/object/generations", generations)
route.POST("proxies/v1/images/generations", generations)
Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/hf/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func completeTagsGenerator(ctx *gin.Context, content string) (string, error) {
proxies = ctx.GetString("proxies")
model = pkg.Config.GetString("llm.model")
cookie = pkg.Config.GetString("llm.token")
baseUrl = pkg.Config.GetString("llm.baseUrl")
baseUrl = pkg.Config.GetString("llm.base-url")
)

c := regexp.MustCompile("<tag content=\"([^>]+)\"\\s?/>")
Expand Down
6 changes: 3 additions & 3 deletions internal/plugin/hf/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func Ox002(ctx *gin.Context, model, message string) (value string, err error) {
baseUrl = "https://mukaist-dalle-4k.hf.space"
)

if u := pkg.Config.GetString("hf.dalle-4k.baseUrl"); u != "" {
if u := pkg.Config.GetString("hf.dalle-4k.base-url"); u != "" {
baseUrl = u
}

Expand Down Expand Up @@ -282,7 +282,7 @@ func Ox003(ctx *gin.Context, message string) (value string, err error) {
domain = fmt.Sprintf("http://127.0.0.1:%d", ctx.GetInt("port"))
}

if u := pkg.Config.GetString("hf.dalle-3-xl.baseUrl"); u != "" {
if u := pkg.Config.GetString("hf.dalle-3-xl.base-url"); u != "" {
baseUrl = u
}

Expand Down Expand Up @@ -409,7 +409,7 @@ func Ox004(ctx *gin.Context, model, samples, message string) (value string, err
domain = fmt.Sprintf("http://127.0.0.1:%d", ctx.GetInt("port"))
}

if u := pkg.Config.GetString("hf.animagine-xl-3.1.baseUrl"); u != "" {
if u := pkg.Config.GetString("hf.animagine-xl-3.1.base-url"); u != "" {
baseUrl = u
}

Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/llm/bing/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (API) Completion(ctx *gin.Context) {
completion = common.GetGinCompletion(ctx)
matchers = common.GetGinMatchers(ctx)

baseUrl = pkg.Config.GetString("bing.baseUrl")
baseUrl = pkg.Config.GetString("bing.base-url")
)

if cookie == "xxx" {
Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/llm/bing/toolcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func completeToolCalls(ctx *gin.Context, cookie, proxies string, completion pkg.
logger.Infof("completeTools ...")

var (
baseUrl = pkg.Config.GetString("bing.baseUrl")
baseUrl = pkg.Config.GetString("bing.base-url")
echo = ctx.GetBool(vars.GinEcho)
)

Expand Down
4 changes: 2 additions & 2 deletions internal/plugin/llm/coze/websdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func runTasks(opts ...map[string]interface{}) {
// 重置任务函数
func loopTasks() {
s5 := 5 * time.Second
baseUrl := pkg.Config.GetString("serverless.baseUrl")
baseUrl := pkg.Config.GetString("serverless.base-url")
if baseUrl == "" {
baseUrl = "http://127.0.0.1:" + pkg.Config.GetString("you.helper")
}
Expand Down Expand Up @@ -208,7 +208,7 @@ func loopTasks() {
// 初始任务函数
func initTasks(opts ...*obj) (exec bool) {
time.Sleep(6 * time.Second) // 等待程序启动就绪
baseUrl := pkg.Config.GetString("serverless.baseUrl")
baseUrl := pkg.Config.GetString("serverless.base-url")
if baseUrl == "" {
baseUrl = "http://127.0.0.1:" + pkg.Config.GetString("you.helper")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/llm/interpreter/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

func fetch(ctx *gin.Context, proxies string, completion pkg.ChatCompletion) (response *http.Response, tokens int, err error) {
var (
baseUrl = pkg.Config.GetString("interpreter.baseUrl")
baseUrl = pkg.Config.GetString("interpreter.base-url")
)

tokens, message, err := mergeMessages(ctx, proxies, baseUrl, completion)
Expand Down
2 changes: 1 addition & 1 deletion internal/plugin/llm/interpreter/socket.io.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func completionWS(ctx *gin.Context) {
}

var (
baseUrl = pkg.Config.GetString("interpreter.baseUrl")
baseUrl = pkg.Config.GetString("interpreter.base-url")
proxies = ctx.GetString("proxies")
completion = common.GetGinCompletion(ctx)
matchers = common.GetGinMatchers(ctx)
Expand Down
68 changes: 46 additions & 22 deletions internal/plugin/llm/v1/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"chatgpt-adapter/internal/plugin"
"chatgpt-adapter/logger"
"chatgpt-adapter/pkg"
"io"
"net/http"
"strings"

Expand All @@ -17,14 +16,41 @@ import (
var (
Adapter = API{}
Model = "custom"
schema = make([]map[string]interface{}, 0)
key = "__custom-url__"
upKey = "__custom-proxies__"
modKey = "__custom-model__"
)

type API struct {
plugin.BaseAdapter
}

func (API) Match(_ *gin.Context, model string) bool {
return strings.HasPrefix(model, "custom/")
func init() {
common.AddInitialized(func() {
llm := pkg.Config.Get("custom-llm")
if slice, ok := llm.([]interface{}); ok {
for _, it := range slice {
item, o := it.(map[string]interface{})
if !o {
continue
}
schema = append(schema, item)
}
}
})
}

func (API) Match(ctx *gin.Context, model string) bool {
for _, it := range schema {
if prefix, ok := it["prefix"].(string); ok && strings.HasPrefix(model, prefix+"/") {
ctx.Set(key, it["base-url"])
ctx.Set(upKey, it["use-proxies"] == "true")
ctx.Set(modKey, model[len(prefix)+1:])
return true
}
}
return false
}

func (API) Models() []plugin.Model {
Expand All @@ -33,7 +59,7 @@ func (API) Models() []plugin.Model {
Id: "custom",
Object: "model",
Created: 1686935002,
By: "lmsys-adapter",
By: "custom-adapter",
},
}
}
Expand Down Expand Up @@ -76,35 +102,33 @@ label:

func (API) Embedding(ctx *gin.Context) {
embedding := common.GetGinEmbedding(ctx)
embedding.Model = embedding.Model[7:]
embedding.Model = ctx.GetString(modKey)
var (
token = ctx.GetString("token")
proxies = ctx.GetString("proxies")
baseUrl = pkg.Config.GetString("custom-llm.baseUrl")
useProxy = pkg.Config.GetBool("custom-llm.useProxy")
token = ctx.GetString("token")
proxies = ctx.GetString("proxies")
baseUrl = ctx.GetString(key)
)
if !useProxy {
if !ctx.GetBool(upKey) {
proxies = ""
}

resp, err := emit.ClientBuilder(plugin.HTTPClient).
Proxies(proxies).
Context(common.GetGinContext(ctx)).
POST(baseUrl+"/v1/embeddings").
POST(baseUrl+"/embeddings").
Header("Authorization", "Bearer "+token).
JHeader().
Body(embedding).DoC(emit.Status(http.StatusOK))
Body(embedding).DoC(emit.Status(http.StatusOK), emit.IsJSON)
if err != nil {
ctx.JSON(http.StatusBadGateway, gin.H{
"error": "can't send request to upstream",
})
response.Error(ctx, http.StatusBadGateway, err)
return
}
ctx.Header("Content-Type", "application/json; charset=utf-8")
content, err := io.ReadAll(resp.Body)

obj, err := emit.ToMap(resp)
if err != nil {
ctx.JSON(http.StatusBadGateway, gin.H{
"error": "can't read from upstream",
})
response.Error(ctx, http.StatusBadGateway, err)
return
}
ctx.Writer.Write(content)
ctx.Writer.Flush()

ctx.JSON(http.StatusOK, obj)
}
34 changes: 29 additions & 5 deletions internal/plugin/llm/v1/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ import (
"chatgpt-adapter/internal/common"
"chatgpt-adapter/internal/plugin"
"chatgpt-adapter/pkg"
"encoding/json"
"github.com/bincooo/emit.io"
"github.com/gin-gonic/gin"
"net/http"
)

func fetch(ctx *gin.Context, proxies, token string, completion pkg.ChatCompletion) (response *http.Response, err error) {
var (
baseUrl = pkg.Config.GetString("custom-llm.baseUrl")
useProxy = pkg.Config.GetBool("custom-llm.useProxy")
baseUrl = ctx.GetString(key)
)

if !useProxy {
if !ctx.GetBool(upKey) {
proxies = ""
}

Expand All @@ -38,13 +38,37 @@ func fetch(ctx *gin.Context, proxies, token string, completion pkg.ChatCompletio
ctx.Set(ginTokens, token)

completion.Stream = true
completion.Model = ctx.GetString(modKey)
obj, err := toMap(completion)
if err != nil {
return nil, err
}

if completion.TopK == 0 {
delete(obj, "top_k")
}

response, err = emit.ClientBuilder(plugin.HTTPClient).
Proxies(proxies).
Context(common.GetGinContext(ctx)).
POST(baseUrl+"/v1/chat/completions").
POST(baseUrl+"/chat/completions").
Header("Authorization", "Bearer "+token).
JHeader().
Body(completion).
Body(obj).
DoC(emit.Status(http.StatusOK), emit.IsSTREAM)
return
}

func toMap(obj interface{}) (mo map[string]interface{}, err error) {
if obj == nil {
return
}

bytes, err := json.Marshal(obj)
if err != nil {
return
}

err = json.Unmarshal(bytes, &mo)
return
}
8 changes: 4 additions & 4 deletions pkg/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ import (
)

type ChatCompletion struct {
System string `json:"system"`
System string `json:"system,omitempty"`
Messages []Keyv[interface{}] `json:"messages"`
Tools []Keyv[interface{}] `json:"tools"`
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
StopSequences []string `json:"stop"`
Temperature float32 `json:"temperature"`
TopK int `json:"topK"`
TopP float32 `json:"topP"`
Stream bool `json:"stream"`
TopK int `json:"top_k,omitempty"`
TopP float32 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
ToolChoice interface{} `json:"tool_choice"`
}

Expand Down

0 comments on commit bb25496

Please sign in to comment.