From 7787502b44429d32662afb541091f659b6c556bd Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Wed, 10 Apr 2024 18:00:37 +0800 Subject: [PATCH 01/13] Beta61 (#62) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 --- applogic/gensokyo.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index 970b9d8..ec3d8b7 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -428,7 +428,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } else { //最后一条了 messageSSE := structs.InterfaceBody{ - Content: newPart + "\n", + Content: newPart, State: 11, } utils.SendPrivateMessageSSE(message.UserID, messageSSE) @@ -449,7 +449,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } else { //最后一条了 messageSSE := structs.InterfaceBody{ - Content: response + "\n", + Content: response, State: 11, } utils.SendPrivateMessageSSE(message.UserID, messageSSE) @@ -499,7 +499,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { //最后一条了 messageSSE := structs.InterfaceBody{ - Content: " " + "\n", + Content: " ", State: 20, PromptKeyboard: promptkeyboard, } @@ -610,7 +610,7 @@ func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage //CallbackData := GetStringById(lastMessageID) uerid := strconv.FormatInt(msg.UserID, 10) messageSSE := structs.InterfaceBody{ - Content: accumulatedMessage + "\n", + Content: accumulatedMessage, State: 1, ActionButton: 10, CallbackData: uerid, @@ -619,7 +619,7 @@ func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage } else { //SSE的前半部分 messageSSE := structs.InterfaceBody{ - Content: accumulatedMessage + "\n", + Content: accumulatedMessage, State: 1, } utils.SendPrivateMessageSSE(msg.UserID, messageSSE) From 8534650bff7c3e4eb9e2f8db161ec517bfd2367a Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Fri, 12 Apr 2024 11:46:19 +0800 Subject: [PATCH 02/13] Beta62 (#63) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 --- applogic/gensokyo.go | 2 +- config/config.go | 11 +++++++++++ template/config_template.go | 1 + utils/utils.go | 14 +++++++++++--- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index ec3d8b7..4576417 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -593,7 +593,7 @@ func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage for _, char := range response { messageBuilder.WriteRune(char) - if utils.ContainsRune(punctuations, char) { + if utils.ContainsRune(punctuations, char, msg.GroupID) { // 达到标点符号,发送累积的整个消息 if messageBuilder.Len() > 0 { accumulatedMessage := messageBuilder.String() diff --git a/config/config.go b/config/config.go index 9938065..378ba60 100644 --- a/config/config.go +++ b/config/config.go @@ -93,6 +93,7 @@ type Settings struct { UseFunctionPromptkeyboard bool `yaml:"useFunctionPromptkeyboard"` AIPromptkeyboardPath string `yaml:"AIPromptkeyboardPath"` UseAIPromptkeyboard bool `yaml:"useAIPromptkeyboard"` + SplitByPuntuationsGroup int `yaml:"splitByPuntuationsGroup"` } // LoadConfig 从文件中加载配置并初始化单例配置 @@ -341,6 +342,16 @@ func GetSplitByPuntuations() int { return 0 } +// 获取SplitByPuntuationsGroup +func GetSplitByPuntuationsGroup() int { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.SplitByPuntuationsGroup + } + return 0 +} + // 获取HunyuanType func GetHunyuanType() int { mu.Lock() diff --git a/template/config_template.go b/template/config_template.go index 6f7ff8b..a762513 100644 --- a/template/config_template.go +++ b/template/config_template.go @@ -19,6 +19,7 @@ settings: thirdA : [""] #可空 groupMessage : true #是否响应群信息 splitByPuntuations : 40 #截断率,仅在sse时有效,100则代表每句截断 + splitByPuntuationsGroup : 10 #截断率(群),仅在sse时有效,100则代表每句截断 sensitiveMode : false #是否开启敏感词替换 sensitiveModeType : 0 #0=只过滤用户输入 1=输出也进行过滤 defaultChangeWord : "*" #默认的屏蔽词替换,你可以在sensitive_words.txt的####后修改为自己需要,可以用记事本批量替换 diff --git a/utils/utils.go b/utils/utils.go index f99ee43..94a5fae 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -83,11 +83,18 @@ func GetKey(groupid int64, userid int64) string { } // 随机的分布发送 -func ContainsRune(slice []rune, value rune) bool { +func ContainsRune(slice []rune, value rune, groupid int64) bool { + var probability int + if groupid == 0 { + // 获取概率百分比 + probability = config.GetSplitByPuntuations() + } else { + // 获取概率百分比 + probability = config.GetSplitByPuntuationsGroup() + } + for _, item := range slice { if item == value { - // 获取概率百分比 - probability := config.GetSplitByPuntuations() // 将概率转换为0到1之间的浮点数 probabilityPercentage := float64(probability) / 100.0 // 生成一个0到1之间的随机浮点数 @@ -152,6 +159,7 @@ func SendGroupMessage(groupID int64, userID int64, message string) error { "user_id": userID, "message": message, }) + fmtf.Printf("发群信息请求:%v", string(requestBody)) if err != nil { return fmtf.Errorf("failed to marshal request body: %w", err) } From 518b67036225fdba114d97a08ca492728514d80d Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Fri, 12 Apr 2024 16:26:45 +0800 Subject: [PATCH 03/13] Beta63 (#64) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 --- applogic/gensokyo.go | 11 +++++++++++ go.mod | 8 +++++++- go.sum | 22 ++++++++++++++++++++++ readme.md | 2 +- utils/t2s.go | 29 +++++++++++++++++++++++++++++ 5 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 utils/t2s.go diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index 4576417..568b65b 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -352,6 +352,12 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { fmtf.Printf("消息进入替换前:%v", requestmsg) } + // 繁体转换简体 安全策略 + requestmsg, err = utils.ConvertTraditionalToSimplified(requestmsg) + if err != nil { + fmtf.Printf("繁体转换简体失败:%v", err) + } + // 替换in替换词规则 if config.GetSensitiveMode() { requestmsg = acnode.CheckWordIN(requestmsg) @@ -497,6 +503,11 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { promptkeyboard = GetPromptKeyboardAI("Q" + newmsg + "A" + response) } + // 使用acnode.CheckWordOUT()过滤promptkeyboard中的每个字符串 + for i, item := range promptkeyboard { + promptkeyboard[i] = acnode.CheckWordOUT(item) + } + //最后一条了 messageSSE := structs.InterfaceBody{ Content: " ", diff --git a/go.mod b/go.mod index ce85a5a..58cc3a4 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,12 @@ require ( require github.com/abadojack/whatlanggo v1.0.1 require ( - github.com/fsnotify/fsnotify v1.7.0 + github.com/liuzl/cedar-go v0.0.0-20170805034717-80a9c64b256d // indirect + github.com/liuzl/da v0.0.0-20180704015230-14771aad5b1d // indirect + github.com/longbridgeapp/opencc v0.3.11 // indirect +) + +require ( + github.com/fsnotify/fsnotify v1.7.0 golang.org/x/sys v0.4.0 // indirect ) diff --git a/go.sum b/go.sum index 49a2f5e..943f5f3 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,38 @@ github.com/abadojack/whatlanggo v1.0.1 h1:19N6YogDnf71CTHm3Mp2qhYfkRdyvbgwWdd2EPxJRG4= github.com/abadojack/whatlanggo v1.0.1/go.mod h1:66WiQbSbJBIlOZMsvbKe5m6pzQovxCH9B/K8tQB2uoc= +github.com/adamzy/cedar-go v0.0.0-20170805034717-80a9c64b256d/go.mod h1:PRWNwWq0yifz6XDPZu48aSld8BWwBfr2JKB2bGWiEd4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/liuzl/cedar-go v0.0.0-20170805034717-80a9c64b256d h1:qSmEGTgjkESUX5kPMSGJ4pcBUtYVDdkNzMrjQyvRvp0= +github.com/liuzl/cedar-go v0.0.0-20170805034717-80a9c64b256d/go.mod h1:x7SghIWwLVcJObXbjK7S2ENsT1cAcdJcPl7dRaSFog0= +github.com/liuzl/da v0.0.0-20180704015230-14771aad5b1d h1:hTRDIpJ1FjS9ULJuEzu69n3qTgc18eI+ztw/pJv47hs= +github.com/liuzl/da v0.0.0-20180704015230-14771aad5b1d/go.mod h1:7xD3p0XnHvJFQ3t/stEJd877CSIMkH/fACVWen5pYnc= +github.com/longbridgeapp/opencc v0.3.11 h1:MfozRXTRmchceDyVsJ/JoOsuXb7AqtjF7RUtWUa0cQo= +github.com/longbridgeapp/opencc v0.3.11/go.mod h1:jRuKtq8eLA+cZUu75XgMvkB/hFSXJbZDmij0v29lNaY= github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.839 h1:VGVFNQDaUpDsPkJrh8I9qOxHZ1yj5sJmg9ngsUvTAHM= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.839/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/readme.md b/readme.md index b8dec4c..0565e4f 100644 --- a/readme.md +++ b/readme.md @@ -46,7 +46,7 @@ AhoCorasick算法实现的超高效文本IN-Out替换规则,可大量替换n 命令行 -mlog 将当前储存的所有日志进行QA格式化,每日审验,从实际场景提炼新安全规则,不断增加安全性,第六重安全措施 -语言过滤,允许llm只接受所指定的语言,在自己擅长的领域进行防守,第七重安全措施 +语言过滤,允许llm只接受所指定的语言,自动将繁体转换为简体应用安全规则,在自己擅长的领域进行防守,第七重安全措施 提示词长度限制,用最原始的方式控制安全,阻止恶意用户构造长提示词,第八重安全措施 diff --git a/utils/t2s.go b/utils/t2s.go new file mode 100644 index 0000000..9508620 --- /dev/null +++ b/utils/t2s.go @@ -0,0 +1,29 @@ +package utils + +import ( + "log" + "sync" + + "github.com/longbridgeapp/opencc" +) + +// Global converter instance +var converter *opencc.OpenCC +var once sync.Once + +// init function to initialize the global converter +func init() { + var err error + once.Do(func() { + // Initialize the converter with the appropriate conversion configuration + converter, err = opencc.New("t2s") + if err != nil { + log.Fatalf("Failed to initialize OpenCC converter: %v", err) + } + }) +} + +// ConvertTraditionalToSimplified converts traditional Chinese to simplified Chinese. +func ConvertTraditionalToSimplified(text string) (string, error) { + return converter.Convert(text) +} From 9f334246dc0cc476218a15b1fa2a6cb938102189 Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Sat, 13 Apr 2024 20:18:54 +0800 Subject: [PATCH 04/13] Beta64 (#66) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 --- applogic/chatgpt.go | 2 +- applogic/gensokyo.go | 21 +- applogic/rwkv.go | 405 ++++++++++++++++++++++++++++++++++++ config/config.go | 176 ++++++++++++++++ main.go | 3 + template/config_template.go | 18 ++ utils/utils.go | 33 +++ 7 files changed, 651 insertions(+), 7 deletions(-) create mode 100644 applogic/rwkv.go diff --git a/applogic/chatgpt.go b/applogic/chatgpt.go index ebb12b7..511ba62 100644 --- a/applogic/chatgpt.go +++ b/applogic/chatgpt.go @@ -181,7 +181,7 @@ func (app *App) ChatHandlerChatgpt(w http.ResponseWriter, r *http.Request) { http.Error(w, fmtf.Sprintf("Failed to read response body: %v", err), http.StatusInternalServerError) return } - // fmtf.Printf("chatgpt返回:%v", string(responseBody)) + fmtf.Printf("chatgpt返回:%v", string(responseBody)) // 假设已经成功发送请求并获得响应,responseBody是响应体的字节数据 var apiResponse struct { Choices []struct { diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index 568b65b..27431e5 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -468,8 +468,13 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { //清空之前加入缓存 // 缓存省钱部分 if config.GetUseCache() { - fmtf.Printf("缓存了Q:%v,A:%v,向量ID:%v", newmsg, response, lastSelectedVectorID) - app.InsertQAEntry(newmsg, response, lastSelectedVectorID) + if response != "" { + fmtf.Printf("缓存了Q:%v,A:%v,向量ID:%v", newmsg, response, lastSelectedVectorID) + app.InsertQAEntry(newmsg, response, lastSelectedVectorID) + } else { + fmtf.Printf("缓存Q:%v时遇到问题,A为空,检查api是否存在问题", newmsg) + } + } // 清空映射中对应的累积消息 @@ -510,7 +515,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { //最后一条了 messageSSE := structs.InterfaceBody{ - Content: " ", + Content: " ", State: 20, PromptKeyboard: promptkeyboard, } @@ -592,15 +597,19 @@ func splitAndSendMessages(message structs.OnebotGroupMessage, line string, newme return } - // 处理提取出的信息 - processMessage(sseData.Response, message, newmesssage) + if sseData.Response != "\n\n" { + // 处理提取出的信息 + processMessage(sseData.Response, message, newmesssage) + } else { + fmtf.Printf("忽略llm末尾的换行符") + } } func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage string) { key := utils.GetKey(msg.GroupID, msg.UserID) // 定义中文全角和英文标点符号 - punctuations := []rune{'。', '!', '?', ',', ',', '.', '!', '?'} + punctuations := []rune{'。', '!', '?', ',', ',', '.', '!', '?', '~'} for _, char := range response { messageBuilder.WriteRune(char) diff --git a/applogic/rwkv.go b/applogic/rwkv.go new file mode 100644 index 0000000..c2bfe67 --- /dev/null +++ b/applogic/rwkv.go @@ -0,0 +1,405 @@ +package applogic + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + "sync" + + "github.com/google/uuid" + "github.com/hoshinonyaruko/gensokyo-llm/config" + "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/structs" + "github.com/hoshinonyaruko/gensokyo-llm/utils" +) + +// 用于存储每个conversationId的最后一条消息内容 +var ( + // lastResponses 存储每个真实 conversationId 的最后响应文本 + lastResponsesRwkv sync.Map + lastCompleteResponsesRwkv sync.Map // 存储每个conversationId的完整累积信息 + mutexRwkv sync.Mutex +) + +func (app *App) ChatHandlerRwkv(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + var msg structs.Message + err := json.NewDecoder(r.Body).Decode(&msg) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + msg.Role = "user" + //颠倒用户输入 + if config.GetReverseUserPrompt() { + msg.Text = utils.ReverseString(msg.Text) + } + + if msg.ConversationID == "" { + msg.ConversationID = utils.GenerateUUID() + app.createConversation(msg.ConversationID) + } + + userMessageID, err := app.addMessage(msg) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var history []structs.Message + + // 获取系统提示词 + systemPromptContent := config.SystemPrompt() + if systemPromptContent != "0" { + systemPrompt := structs.Message{ + Text: systemPromptContent, + Role: "system", + } + // 将系统提示词添加到历史信息的开始 + history = append([]structs.Message{systemPrompt}, history...) + } + + // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A + pairs := []struct { + Q string + A string + RoleQ string // 问题的角色 + RoleA string // 答案的角色 + }{ + {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, + {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, + {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, + } + + // 检查每一对Q&A是否均不为空,并追加到历史信息中 + for _, pair := range pairs { + if pair.Q != "" && pair.A != "" { + qMessage := structs.Message{ + Text: pair.Q, + Role: pair.RoleQ, + } + aMessage := structs.Message{ + Text: pair.A, + Role: pair.RoleA, + } + + // 注意追加的顺序,确保问题在答案之前 + history = append(history, qMessage, aMessage) + } + } + + // 获取历史信息 + if msg.ParentMessageID != "" { + userhistory, err := app.getHistory(msg.ConversationID, msg.ParentMessageID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // 截断历史信息 + userhistory = truncateHistoryGpt(userhistory, msg.Text) + + // 注意追加的顺序,确保问题在系统提示词之后 + // 使用...操作符来展开userhistory切片并追加到history切片 + history = append(history, userhistory...) + } + + fmtf.Printf("RWKV上下文history:%v\n", history) + + // 构建请求到RWKV API + apiURL := config.GetRwkvApiPath() + + // 构造消息历史和当前消息 + messages := []map[string]interface{}{} + for _, hMsg := range history { + messages = append(messages, map[string]interface{}{ + "role": hMsg.Role, + "content": hMsg.Text, + }) + } + messages = append(messages, map[string]interface{}{ + "role": "user", + "content": msg.Text, + }) + + // 构建请求体 + requestBody := map[string]interface{}{ + "max_tokens": config.GetRwkvMaxTokens(), + "temperature": config.GetRwkvTemperature(), + "top_p": config.GetRwkvTopP(), + "presence_penalty": config.GetRwkvPresencePenalty(), + "frequency_penalty": config.GetRwkvFrequencyPenalty(), + "penalty_decay": config.GetRwkvPenaltyDecay(), + "top_k": config.GetRwkvTopK(), + "global_penalty": config.GetRwkvGlobalPenalty(), + "model": "rwkv", + "stream": config.GetuseSse(), + "stop": config.GetRwkvStop(), + "user_name": config.GetRwkvUserName(), + "assistant_name": config.GetRwkvAssistantName(), + "system_name": config.GetRwkvSystemName(), + "presystem": config.GetRwkvPreSystem(), + "messages": messages, + } + + fmtf.Printf("rwkv requestBody :%v", requestBody) + requestBodyJSON, _ := json.Marshal(requestBody) + + // 准备HTTP请求 + client := &http.Client{} + req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(requestBodyJSON)) + if err != nil { + http.Error(w, fmtf.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) + return + } + + req.Header.Set("Content-Type", "application/json") + + // 发送请求 + resp, err := client.Do(req) + if err != nil { + http.Error(w, fmtf.Sprintf("Error sending request to ChatGPT API: %v", err), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + if !config.GetuseSse() { + // 处理响应 + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, fmtf.Sprintf("Failed to read response body: %v", err), http.StatusInternalServerError) + return + } + fmtf.Printf("rwkv 返回:%v", string(responseBody)) + // 假设已经成功发送请求并获得响应,responseBody是响应体的字节数据 + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + if err := json.Unmarshal(responseBody, &apiResponse); err != nil { + http.Error(w, fmtf.Sprintf("Error unmarshaling API response: %v", err), http.StatusInternalServerError) + return + } + + // 从API响应中获取回复文本 + responseText := "" + if len(apiResponse.Choices) > 0 { + responseText = apiResponse.Choices[0].Message.Content + } + + // 添加助理消息 + assistantMessageID, err := app.addMessage(structs.Message{ + ConversationID: msg.ConversationID, + ParentMessageID: userMessageID, + Text: responseText, + Role: "assistant", + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // 构造响应数据,包括回复文本、对话ID、消息ID,以及使用情况(用例中未计算,可根据需要添加) + responseMap := map[string]interface{}{ + "response": responseText, + "conversationId": msg.ConversationID, + "messageId": assistantMessageID, + // 在此实际使用情况中,应该有逻辑来填充totalUsage + // 此处仅为示例,根据实际情况来调整 + "details": map[string]interface{}{ + "usage": structs.UsageInfo{ + PromptTokens: 0, // 示例值,需要根据实际情况计算 + CompletionTokens: 0, // 示例值,需要根据实际情况计算 + }, + }, + } + + // 设置响应头部为JSON格式 + w.Header().Set("Content-Type", "application/json") + // 将响应数据编码为JSON并发送 + if err := json.NewEncoder(w).Encode(responseMap); err != nil { + http.Error(w, fmtf.Sprintf("Error encoding response: %v", err), http.StatusInternalServerError) + return + } + } else { + // 设置SSE相关的响应头部 + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + // 生成一个随机的UUID + randomUUID, err := uuid.NewRandom() + if err != nil { + http.Error(w, "Failed to generate UUID", http.StatusInternalServerError) + return + } + + reader := bufio.NewReader(resp.Body) + var responseTextBuilder strings.Builder + var totalUsage structs.GPTUsageInfo + if config.GetRwkvSseType() == 1 { + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break // 流结束 + } + // 处理错误 + fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("读取流数据时发生错误: %v", err)) + flusher.Flush() + continue + } + + if strings.HasPrefix(line, "data: ") { + eventDataJSON := line[5:] // 去掉"data: "前缀 + + // 解析JSON数据 + var eventData structs.GPTEventData + if err := json.Unmarshal([]byte(eventDataJSON), &eventData); err != nil { + fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("解析事件数据出错: %v", err)) + flusher.Flush() + continue + } + + // 遍历choices数组,累积所有文本内容 + for _, choice := range eventData.Choices { + responseTextBuilder.WriteString(choice.Delta.Content) + } + + // 如果存在需要发送的临时响应数据(例如,在事件流中间点) + // 注意:这里暂时省略了使用信息的处理,因为示例输出中没有包含这部分数据 + tempResponseMap := map[string]interface{}{ + "response": responseTextBuilder.String(), + "conversationId": msg.ConversationID, // 确保msg.ConversationID已经定义并初始化 + // "details" 字段留待进一步处理,如有必要 + } + tempResponseJSON, _ := json.Marshal(tempResponseMap) + fmtf.Fprintf(w, "data: %s\n\n", string(tempResponseJSON)) + flusher.Flush() + } + } + } else { + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break // 流结束 + } + fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("读取流数据时发生错误: %v", err)) + flusher.Flush() + continue + } + + if strings.HasPrefix(line, "data: ") { + eventDataJSON := line[5:] // 去掉"data: "前缀 + if eventDataJSON[1] != '{' { + fmtf.Println("非JSON数据,跳过:", eventDataJSON) + continue + } + + //fmtf.Printf("rwkv eventDataJSON:%v", eventDataJSON) + + var eventData structs.GPTEventData + if err := json.Unmarshal([]byte(eventDataJSON), &eventData); err != nil { + fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("解析事件数据出错: %v", err)) + flusher.Flush() + continue + } + + // 在修改共享资源之前锁定Mutex + mutexRwkv.Lock() + + conversationId := msg.ConversationID + randomUUID.String() + //读取完整信息 + completeResponse, _ := lastCompleteResponsesRwkv.LoadOrStore(conversationId, "") + + // 检索上一次的响应文本 + lastResponse, _ := lastResponsesRwkv.LoadOrStore(conversationId, "") + lastResponseText := lastResponse.(string) + + newContent := "" + for _, choice := range eventData.Choices { + if strings.HasPrefix(choice.Delta.Content, lastResponseText) { + // 如果新内容以旧内容开头,剔除旧内容部分,只保留新增的部分 + newContent += choice.Delta.Content[len(lastResponseText):] + } else { + // 如果新内容不以旧内容开头,可能是并发情况下的新消息,直接使用新内容 + newContent += choice.Delta.Content + } + } + + // 更新存储的完整累积信息 + updatedCompleteResponse := completeResponse.(string) + newContent + lastCompleteResponsesRwkv.Store(conversationId, updatedCompleteResponse) + + // 使用累加的新内容更新存储的最后响应状态 + if newContent != "" { + lastResponsesRwkv.Store(conversationId, newContent) + } + + // 完成修改后解锁Mutex + mutexRwkv.Unlock() + + // 发送新增的内容 + if newContent != "" { + tempResponseMap := map[string]interface{}{ + "response": newContent, + "conversationId": conversationId, + } + tempResponseJSON, _ := json.Marshal(tempResponseMap) + fmtf.Fprintf(w, "data: %s\n\n", string(tempResponseJSON)) + flusher.Flush() + } + } + } + } + //一点点奇怪的转换 + conversationId := msg.ConversationID + randomUUID.String() + completeResponse, _ := lastCompleteResponsesRwkv.LoadOrStore(conversationId, "") + // 在所有事件处理完毕后发送最终响应 + assistantMessageID, err := app.addMessage(structs.Message{ + ConversationID: msg.ConversationID, + ParentMessageID: userMessageID, + Text: completeResponse.(string), + Role: "assistant", + }) + + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // 在所有事件处理完毕后发送最终响应 + // 首先从 conversationMap 获取真实的 conversationId + if finalContent, ok := lastCompleteResponsesRwkv.Load(conversationId); ok { + finalResponseMap := map[string]interface{}{ + "response": finalContent, + "conversationId": conversationId, + "messageId": assistantMessageID, + "details": map[string]interface{}{ + "usage": totalUsage, + }, + } + finalResponseJSON, _ := json.Marshal(finalResponseMap) + fmtf.Fprintf(w, "data: %s\n\n", string(finalResponseJSON)) + flusher.Flush() + } + } + +} diff --git a/config/config.go b/config/config.go index 378ba60..356bea8 100644 --- a/config/config.go +++ b/config/config.go @@ -94,6 +94,22 @@ type Settings struct { AIPromptkeyboardPath string `yaml:"AIPromptkeyboardPath"` UseAIPromptkeyboard bool `yaml:"useAIPromptkeyboard"` SplitByPuntuationsGroup int `yaml:"splitByPuntuationsGroup"` + RwkvApiPath string `yaml:"rwkvApiPath"` + RwkvMaxTokens int `yaml:"rwkvMaxTokens"` + RwkvTemperature float64 `yaml:"rwkvTemperature"` + RwkvTopP float64 `yaml:"rwkvTopP"` + RwkvPresencePenalty float64 `yaml:"rwkvPresencePenalty"` + RwkvFrequencyPenalty float64 `yaml:"rwkvFrequencyPenalty"` + RwkvPenaltyDecay float64 `yaml:"rwkvPenaltyDecay"` + RwkvTopK int `yaml:"rwkvTopK"` + RwkvGlobalPenalty bool `yaml:"rwkvGlobalPenalty"` + RwkvStream bool `yaml:"rwkvStream"` + RwkvStop []string `yaml:"rwkvStop"` + RwkvUserName string `yaml:"rwkvUserName"` + RwkvAssistantName string `yaml:"rwkvAssistantName"` + RwkvSystemName string `yaml:"rwkvSystemName"` + RwkvPreSystem bool `yaml:"rwkvPreSystem"` + RwkvSseType int `yaml:"rwkvSseType"` } // LoadConfig 从文件中加载配置并初始化单例配置 @@ -954,3 +970,163 @@ func GetAIPromptkeyboardPath() string { } return "" } + +// 获取RWKV API路径 +func GetRwkvApiPath() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvApiPath + } + return "" +} + +// 获取RWKV最大令牌数 +func GetRwkvMaxTokens() int { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvMaxTokens + } + return 0 +} + +// 获取RwkvSseType +func GetRwkvSseType() int { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvSseType + } + return 0 +} + +// 获取RWKV温度 +func GetRwkvTemperature() float64 { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvTemperature + } + return 0.0 +} + +// 获取RWKV Top P +func GetRwkvTopP() float64 { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvTopP + } + return 0.0 +} + +// 获取RWKV存在惩罚 +func GetRwkvPresencePenalty() float64 { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvPresencePenalty + } + return 0.0 +} + +// 获取RWKV频率惩罚 +func GetRwkvFrequencyPenalty() float64 { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvFrequencyPenalty + } + return 0.0 +} + +// 获取RWKV惩罚衰减 +func GetRwkvPenaltyDecay() float64 { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvPenaltyDecay + } + return 0.0 +} + +// 获取RWKV Top K +func GetRwkvTopK() int { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvTopK + } + return 0 +} + +// 获取RWKV是否全局惩罚 +func GetRwkvGlobalPenalty() bool { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvGlobalPenalty + } + return false +} + +// 获取RWKV是否流模式 +func GetRwkvStream() bool { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvStream + } + return false +} + +// 获取RWKV停止列表 +func GetRwkvStop() []string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvStop + } + return nil +} + +// 获取RWKV用户名 +func GetRwkvUserName() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvUserName + } + return "" +} + +// 获取RWKV助手名 +func GetRwkvAssistantName() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvAssistantName + } + return "" +} + +// 获取RWKV系统名称 +func GetRwkvSystemName() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvSystemName + } + return "" +} + +// 获取RWKV是否预处理 +func GetRwkvPreSystem() bool { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.RwkvPreSystem + } + return false +} diff --git a/main.go b/main.go index be9a705..6b17a7f 100644 --- a/main.go +++ b/main.go @@ -144,6 +144,9 @@ func main() { case 2: // 如果API类型是2,使用app.chatHandlerChatGpt http.HandleFunc("/conversation", app.ChatHandlerChatgpt) + case 3: + // 如果API类型是3,使用app.chatHandlerRwkv + http.HandleFunc("/conversation", app.ChatHandlerRwkv) default: // 如果是其他值,可以选择一个默认的处理器或者记录一个错误 log.Printf("Unknown API type: %d", apiType) diff --git a/template/config_template.go b/template/config_template.go index a762513..b47ff46 100644 --- a/template/config_template.go +++ b/template/config_template.go @@ -96,6 +96,24 @@ settings: gptSafeMode : false #额外走腾讯云检查安全,但是会额外消耗P数(会给出回复,但可能跑偏)仅api2d支持 gptModeration : false #额外走腾讯云检查安全,不合规直接拦截.(和上面一样但是会直接拦截.)仅api2d支持 gptSseType : 0 #gpt的sse流式有两种形式,0是只返回新的 你 好 呀 , 我 是 一 个,1是递增 你好呀,我是一个人类 你 你好 你好呀 你好呀, 你好呀,我 你好呀,我是 + + # RWKV 模型配置文件 仅适用于对接gensokyo-discord、gensokyo-telegram等平台,国内请遵守并符合相应的api资质要求. + rwkvApiPath: "https://api.example.com/rwkv" # 符合 RWKV 标准的 API 地址 是否以流形式取决于UseSSE配置 + rwkvMaxTokens: 1024 # 最大的输出 Token 数量 + rwkvTemperature: 0.7 # 生成的随机性控制 + rwkvTopP: 0.9 # 累积概率最高的令牌进行采样的界限 + rwkvPresencePenalty: 0.0 # 当前上下文中令牌出现的频率惩罚 + rwkvFrequencyPenalty: 0.0 # 全局令牌出现的频率惩罚 + rwkvPenaltyDecay: 0.99 # 惩罚值的衰减率 + rwkvTopK: 25 # 从概率最高的K个令牌中采样 + rwkvSseType : 0 # 同gptSseType + rwkvGlobalPenalty: false # 是否在全局上应用频率惩罚 + rwkvStop: # 停止生成的标记列表 + - "\n\nUser" + rwkvUserName: "User" # 用户名称 + rwkvAssistantName: "Assistant" # 助手名称 + rwkvSystemName: "System" # 系统名称 + rwkvPreSystem: false # 是否在系统层面进行预处理 ` const Logo = ` diff --git a/utils/utils.go b/utils/utils.go index 94a5fae..c6377ff 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -274,6 +274,14 @@ func SendPrivateMessageSSE(UserID int64, message structs.InterfaceBody) error { fmtf.Printf("流式信息替换后:%v", message.Content) } + // 去除末尾的换行符 不去除会导致sse接口始终等待 + message.Content = removeTrailingCRLFs(message.Content) + + if message.Content == "" { + message.Content = " " + fmtf.Printf("过滤空SendPrivateMessageSSE,可能是llm api只发了换行符.") + } + // 构造请求体,包括InterfaceBody requestBody, err := json.Marshal(map[string]interface{}{ "user_id": UserID, @@ -317,6 +325,31 @@ func SendPrivateMessageSSE(UserID int64, message structs.InterfaceBody) error { return nil } +// removeTrailingCRLFs 移除字符串末尾的所有CRLF换行符 +func removeTrailingCRLFs(input string) string { + // 将字符串转换为字节切片 + byteMessage := []byte(input) + + // CRLF的字节表示 + crlf := []byte{'\r', '\n'} + + // 循环移除末尾的CRLF + for bytes.HasSuffix(byteMessage, crlf) { + byteMessage = bytes.TrimSuffix(byteMessage, crlf) + } + + // LFLF的字节表示 + lflf := []byte{'\n', '\n'} + + // 循环移除末尾的LFLF + for bytes.HasSuffix(byteMessage, lflf) { + byteMessage = bytes.TrimSuffix(byteMessage, lflf) + } + + // 将处理后的字节切片转换回字符串 + return string(byteMessage) +} + // ReverseString 颠倒给定字符串中的字符顺序 func ReverseString(s string) string { // // 将字符串转换为rune切片,以便处理多字节字符 From 43a22a568b5c2821ea053ea9cb3a1e28e8a95728 Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Sat, 13 Apr 2024 20:32:55 +0800 Subject: [PATCH 05/13] Beta65 (#67) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 --- utils/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/utils.go b/utils/utils.go index c6377ff..ac38584 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -435,7 +435,7 @@ func PostSensitiveMessages() error { // SendSSEPrivateMessage 分割并发送消息的核心逻辑,直接遍历字符串 func SendSSEPrivateMessage(userID int64, content string) { - punctuations := []rune{'。', '!', '?', ',', ',', '.', '!', '?'} + punctuations := []rune{'。', '!', '?', ',', ',', '.', '!', '?', '~'} splitProbability := config.GetSplitByPuntuations() var parts []string From 8ff8ea2b71c6d218b9061b5e3c81ec4517c3e1fb Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Mon, 15 Apr 2024 23:28:02 +0800 Subject: [PATCH 06/13] Beta66 (#68) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 --- applogic/gensokyo.go | 5 +- main.go | 8 +++ utils/log.go | 125 +++++++++++++++++++++++++++++++++++++++++++ utils/utils.go | 1 + 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 utils/log.go diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index 27431e5..4467a0a 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -363,7 +363,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { requestmsg = acnode.CheckWordIN(requestmsg) } - fmtf.Printf("实际请求conversation端点内容:%v\n", requestmsg) + fmtf.Printf("实际请求conversation端点内容:[%v]%v\n", message.UserID, requestmsg) requestBody, err := json.Marshal(map[string]interface{}{ "message": requestmsg, @@ -442,6 +442,9 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } else { utils.SendGroupMessage(message.GroupID, message.UserID, newPart) } + } else { + //流的最后一次是完整结束的 + fmtf.Printf("A完整信息: %s(sse完整结束)\n", response) } } else if response != "" { diff --git a/main.go b/main.go index 6b17a7f..fbde55f 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ func main() { testFlag := flag.Bool("test", false, "Run the test script, test.txt中的是虚拟信息,一行一条") ymlPath := flag.String("yml", "", "指定config.yml的路径") vFlag := flag.Bool("v", false, "Run ProcessSensitiveWordsV2") + tidyFlag := flag.Bool("tidy", false, "Run tidylog") flag.Parse() // 如果用户指定了-yml参数 @@ -176,6 +177,13 @@ func main() { } } + // 根据-tidy参数决定是否运行utils.Tidylogs() + if *tidyFlag { + utils.Tidylogs() + fmtf.Println("日志整理完毕") + return + } + http.HandleFunc("/gensokyo", app.GensokyoHandler) port := config.GetPort() portStr := fmtf.Sprintf(":%d", port) diff --git a/utils/log.go b/utils/log.go new file mode 100644 index 0000000..ac77815 --- /dev/null +++ b/utils/log.go @@ -0,0 +1,125 @@ +package utils + +import ( + "bufio" + "bytes" + "fmt" + "os" + "path/filepath" + "strings" +) + +func Tidylogs() { + logDir := "./log" + files, err := os.ReadDir(logDir) + if err != nil { + fmt.Println("Error reading log directory:", err) + return + } + + for _, file := range files { + fileName := file.Name() + if filepath.Ext(fileName) == ".log" && !strings.Contains(fileName, "-tidy") { + processLogFile(filepath.Join(logDir, fileName)) + } + } +} + +func processLogFile(filePath string) { + outputFilePath := strings.TrimSuffix(filePath, filepath.Ext(filePath)) + "-tidy.log" + + // Check if the tidy file already exists + if _, err := os.Stat(outputFilePath); err == nil { + fmt.Println("Skipping as tidy file already exists:", outputFilePath) + return // File exists, skip processing + } else if !os.IsNotExist(err) { + fmt.Println("Error checking output file:", err) + return // Some other error occurred when checking the file + } + + // Read the entire file + data, err := os.ReadFile(filePath) + if err != nil { + fmt.Println("Error reading file:", err) + return + } + + // Define newline sequences and placeholder + crlf := []byte{'\r', '\n'} + lf := []byte{'\n'} + placeholder := []byte{0xFF, 0xFE} // Safe placeholder for double newlines + + // Handle different newline formats + doubleCRLF := append(crlf, crlf...) + doubleLF := append(lf, lf...) + + // Replace double newlines with a placeholder + data = bytes.ReplaceAll(data, doubleCRLF, placeholder) + data = bytes.ReplaceAll(data, doubleLF, placeholder) + + // Remove standalone newlines + data = bytes.ReplaceAll(data, crlf, []byte{}) + data = bytes.ReplaceAll(data, lf, []byte{}) + + // Replace placeholders with a single newline (LF) + data = bytes.ReplaceAll(data, placeholder, lf) + + outputFile, err := os.Create(outputFilePath) + if err != nil { + fmt.Println("Error creating output file:", err) + return + } + defer outputFile.Close() + + // Scan through the modified content + scanner := bufio.NewScanner(bytes.NewReader(data)) + for scanner.Scan() { + line := scanner.Text() + // Process each line based on specific patterns + if strings.Contains(line, "实际请求conversation端点内容:") { + formatAndWriteQuestionLine(line, outputFile) + } + if strings.Contains(line, "A完整信息:") { + formatAndWriteAnswerLine(line, outputFile) + } + } + + if err := scanner.Err(); err != nil { + fmt.Println("Error scanning content:", err) + } +} + +func formatAndWriteQuestionLine(line string, outputFile *os.File) { + prefix := "实际请求conversation端点内容:" + startIndex := strings.Index(line, prefix) + if startIndex != -1 { + // 找到前缀后,提取从这个位置开始直到行尾的内容 + messageStart := startIndex + len(prefix) + message := line[messageStart:] // 从"实际请求conversation端点内容:"后的内容开始提取到行尾 + message = strings.TrimSpace(message) // 去除前后空格 + formattedLine := fmt.Sprintf("Q:%s\n", message) // 格式化行 + + // 写入到输出文件 + _, err := outputFile.WriteString(formattedLine) + if err != nil { + fmt.Println("Error writing to output file:", err) + } + } +} + +func formatAndWriteAnswerLine(line string, outputFile *os.File) { + prefix := "A完整信息:" + startIndex := strings.Index(line, prefix) // 查找"A完整信息:"的开始位置 + if startIndex != -1 { + // 找到"A完整信息:"后,提取从这个位置开始直到行尾的内容 + messageStart := startIndex + len(prefix) + message := line[messageStart:] // 从"A完整信息:"后的内容开始提取到行尾 + formattedLine := fmt.Sprintf("A:%s\n", strings.TrimSpace(message)) // 格式化并去除前后空白字符 + + // 写入到输出文件 + _, err := outputFile.WriteString(formattedLine) + if err != nil { + fmt.Println("Error writing to output file:", err) + } + } +} diff --git a/utils/utils.go b/utils/utils.go index ac38584..f730477 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -280,6 +280,7 @@ func SendPrivateMessageSSE(UserID int64, message structs.InterfaceBody) error { if message.Content == "" { message.Content = " " fmtf.Printf("过滤空SendPrivateMessageSSE,可能是llm api只发了换行符.") + return nil } // 构造请求体,包括InterfaceBody From 3232fd8f13007ef902d67919a71627a4e18fc535 Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:32:11 +0800 Subject: [PATCH 07/13] Beta67 (#69) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 * beta67 --- applogic/gensokyo.go | 8 ++++++-- config/config.go | 11 +++++++++++ template/config_template.go | 3 ++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index 4467a0a..71bbd19 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -407,7 +407,9 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } // 处理接收到的数据 - fmtf.Printf("Received SSE data: %s", string(line)) + if !config.GetHideExtraLogs() { + fmtf.Printf("Received SSE data: %s", string(line)) + } // 去除"data: "前缀后进行JSON解析 jsonData := strings.TrimPrefix(string(line), "data: ") @@ -485,7 +487,9 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } } else { //发送信息 - fmtf.Printf("收到流数据,切割并发送信息: %s", string(line)) + if !config.GetHideExtraLogs() { + fmtf.Printf("收到流数据,切割并发送信息: %s", string(line)) + } splitAndSendMessages(message, string(line), newmsg) } } diff --git a/config/config.go b/config/config.go index 356bea8..b88bbb1 100644 --- a/config/config.go +++ b/config/config.go @@ -110,6 +110,7 @@ type Settings struct { RwkvSystemName string `yaml:"rwkvSystemName"` RwkvPreSystem bool `yaml:"rwkvPreSystem"` RwkvSseType int `yaml:"rwkvSseType"` + HideExtraLogs bool `yaml:"hideExtraLogs"` } // LoadConfig 从文件中加载配置并初始化单例配置 @@ -1130,3 +1131,13 @@ func GetRwkvPreSystem() bool { } return false } + +// 获取隐藏日志 +func GetHideExtraLogs() bool { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.HideExtraLogs + } + return false +} diff --git a/template/config_template.go b/template/config_template.go index b47ff46..0dfe20e 100644 --- a/template/config_template.go +++ b/template/config_template.go @@ -35,7 +35,8 @@ settings: promptkeyboard : [""] #临时的promptkeyboard超过3个则随机,后期会增加一个ai生成的方式,也会是ai-agent savelogs : false #本地落地日志. noContext : false #不开启上下文 - withdrawCommand : ["撤回"] #撤回指令 + withdrawCommand : ["撤回"] #撤回指令 + hideExtraLogs : false #忽略流信息的log,提高性能 functionMode : false #是否指定本agent使用func模式(目前仅支持千帆平台),效果不好,暂时不用. functionPath : "" #调用另一个启用了func模式的gsk-llm联合工作的/conversation地址,效果不好,暂时不用. From 834d7e74300ce8b1f66e955e1558d9653cd8a05a Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Thu, 18 Apr 2024 20:04:56 +0800 Subject: [PATCH 08/13] Beta70 (#70) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 * beta67 * beta70 --- .gitignore | 5 +- applogic/chatgpt.go | 83 +++++++++------ applogic/ernie.go | 88 ++++++++++------ applogic/gensokyo.go | 72 ++++++++----- applogic/hunyuan.go | 89 +++++++++------- applogic/rwkv.go | 83 +++++++++------ applogic/vectorsensitive.go | 6 +- config/config.go | 162 +++++++++++------------------ go.mod | 26 ++++- go.sum | 71 +++++++++++++ main.go | 28 ++++- prompt/prompt.go | 202 ++++++++++++++++++++++++++++++++++++ readme.md | 32 +++++- server/server.go | 199 +++++++++++++++++++++++++++++++++++ structs/struct.go | 159 ++++++++++++++++++++++++++++ template/config_template.go | 6 +- utils/blacklist.go | 6 +- utils/utils.go | 47 +++++++-- 18 files changed, 1088 insertions(+), 276 deletions(-) create mode 100644 prompt/prompt.go create mode 100644 server/server.go diff --git a/.gitignore b/.gitignore index b705a09..8157dc6 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,7 @@ *.exe # log -log \ No newline at end of file +log + +# prompts +prompts \ No newline at end of file diff --git a/applogic/chatgpt.go b/applogic/chatgpt.go index 511ba62..6d3627c 100644 --- a/applogic/chatgpt.go +++ b/applogic/chatgpt.go @@ -11,6 +11,7 @@ import ( "github.com/hoshinonyaruko/gensokyo-llm/config" "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/prompt" "github.com/hoshinonyaruko/gensokyo-llm/structs" "github.com/hoshinonyaruko/gensokyo-llm/utils" ) @@ -37,6 +38,14 @@ func (app *App) ChatHandlerChatgpt(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + + // 读取URL参数 "prompt" + promptstr := r.URL.Query().Get("prompt") + if promptstr != "" { + // prompt 参数存在,可以根据需要进一步处理或记录 + fmtf.Printf("Received prompt parameter: %s\n", promptstr) + } + msg.Role = "user" //颠倒用户输入 if config.GetReverseUserPrompt() { @@ -56,43 +65,51 @@ func (app *App) ChatHandlerChatgpt(w http.ResponseWriter, r *http.Request) { var history []structs.Message - // 获取系统提示词 - systemPromptContent := config.SystemPrompt() - if systemPromptContent != "0" { - systemPrompt := structs.Message{ - Text: systemPromptContent, - Role: "system", + //根据是否有prompt参数 选择是否载入config.yml的prompt还是prompts文件夹的 + if promptstr == "" { + // 获取系统提示词 + systemPromptContent := config.SystemPrompt() + if systemPromptContent != "0" { + systemPrompt := structs.Message{ + Text: systemPromptContent, + Role: "system", + } + // 将系统提示词添加到历史信息的开始 + history = append([]structs.Message{systemPrompt}, history...) } - // 将系统提示词添加到历史信息的开始 - history = append([]structs.Message{systemPrompt}, history...) - } - // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A - pairs := []struct { - Q string - A string - RoleQ string // 问题的角色 - RoleA string // 答案的角色 - }{ - {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, - {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, - {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, - } + // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A + pairs := []struct { + Q string + A string + RoleQ string // 问题的角色 + RoleA string // 答案的角色 + }{ + {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, + {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, + {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, + } - // 检查每一对Q&A是否均不为空,并追加到历史信息中 - for _, pair := range pairs { - if pair.Q != "" && pair.A != "" { - qMessage := structs.Message{ - Text: pair.Q, - Role: pair.RoleQ, - } - aMessage := structs.Message{ - Text: pair.A, - Role: pair.RoleA, - } + // 检查每一对Q&A是否均不为空,并追加到历史信息中 + for _, pair := range pairs { + if pair.Q != "" && pair.A != "" { + qMessage := structs.Message{ + Text: pair.Q, + Role: pair.RoleQ, + } + aMessage := structs.Message{ + Text: pair.A, + Role: pair.RoleA, + } - // 注意追加的顺序,确保问题在答案之前 - history = append(history, qMessage, aMessage) + // 注意追加的顺序,确保问题在答案之前 + history = append(history, qMessage, aMessage) + } + } + } else { + history, err = prompt.GetMessagesFromFilename(promptstr) + if err != nil { + fmtf.Printf("prompt.GetMessagesFromFilename error: %v\n", err) } } diff --git a/applogic/ernie.go b/applogic/ernie.go index 865101b..943f523 100644 --- a/applogic/ernie.go +++ b/applogic/ernie.go @@ -12,6 +12,7 @@ import ( "github.com/hoshinonyaruko/gensokyo-llm/config" "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/prompt" "github.com/hoshinonyaruko/gensokyo-llm/structs" "github.com/hoshinonyaruko/gensokyo-llm/utils" ) @@ -30,6 +31,14 @@ func (app *App) ChatHandlerErnie(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + + // 读取URL参数 "prompt" + promptstr := r.URL.Query().Get("prompt") + if promptstr != "" { + // prompt 参数存在,可以根据需要进一步处理或记录 + fmtf.Printf("Received prompt parameter: %s\n", promptstr) + } + msg.Role = "user" //颠倒用户输入 if config.GetReverseUserPrompt() { @@ -47,33 +56,42 @@ func (app *App) ChatHandlerErnie(w http.ResponseWriter, r *http.Request) { return } - // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A var history []structs.Message - pairs := []struct { - Q string - A string - RoleQ string // 问题的角色 - RoleA string // 答案的角色 - }{ - {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, - {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, - {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, - } - // 检查每一对Q&A是否均不为空,并追加到历史信息中 - for _, pair := range pairs { - if pair.Q != "" && pair.A != "" { - qMessage := structs.Message{ - Text: pair.Q, - Role: pair.RoleQ, - } - aMessage := structs.Message{ - Text: pair.A, - Role: pair.RoleA, - } + // 是否从参数获取prompt + if promptstr == "" { + // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A + pairs := []struct { + Q string + A string + RoleQ string // 问题的角色 + RoleA string // 答案的角色 + }{ + {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, + {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, + {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, + } + + // 检查每一对Q&A是否均不为空,并追加到历史信息中 + for _, pair := range pairs { + if pair.Q != "" && pair.A != "" { + qMessage := structs.Message{ + Text: pair.Q, + Role: pair.RoleQ, + } + aMessage := structs.Message{ + Text: pair.A, + Role: pair.RoleA, + } - // 注意追加的顺序,确保问题在答案之前 - history = append(history, qMessage, aMessage) + // 注意追加的顺序,确保问题在答案之前 + history = append(history, qMessage, aMessage) + } + } + } else { + history, err = prompt.GetMessagesExcludingSystem(promptstr) + if err != nil { + fmtf.Printf("prompt.GetMessagesExcludingSystem error: %v\n", err) } } @@ -124,10 +142,22 @@ func (app *App) ChatHandlerErnie(w http.ResponseWriter, r *http.Request) { payload.Stream = true } - // 获取系统提示词,并设置system字段,如果它不为空 - systemPromptContent := config.SystemPrompt() // 确保函数名正确 - if systemPromptContent != "0" { - payload.System = systemPromptContent // 直接在请求负载中设置system字段 + // 是否从参数中获取prompt + if promptstr == "" { + // 获取系统提示词,并设置system字段,如果它不为空 + systemPromptContent := config.SystemPrompt() // 确保函数名正确 + if systemPromptContent != "0" { + payload.System = systemPromptContent // 直接在请求负载中设置system字段 + } + } else { + // 获取系统提示词,并设置system字段,如果它不为空 + systemPromptContent, err := prompt.GetFirstSystemMessage(promptstr) + if err != nil { + fmtf.Printf("prompt.GetFirstSystemMessage error: %v\n", err) + } + if systemPromptContent != "" { + payload.System = systemPromptContent // 直接在请求负载中设置system字段 + } } // 获取访问凭证和API路径 @@ -144,7 +174,7 @@ func (app *App) ChatHandlerErnie(w http.ResponseWriter, r *http.Request) { log.Fatalf("Error occurred during marshaling. Error: %s", err.Error()) } - fmtf.Printf("%v\n", string(jsonData)) + fmtf.Printf("文心一言请求:%v\n", string(jsonData)) // 创建并发送POST请求 req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index 71bbd19..cc3868c 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -84,8 +84,6 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 解析请求体到OnebotGroupMessage结构体 var message structs.OnebotGroupMessage - fmtf.Printf("收到onebotv11信息: %+v\n", string(body)) - err = json.Unmarshal(body, &message) if err != nil { fmtf.Printf("Error parsing request body: %+v\n", string(body)) @@ -93,6 +91,23 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { return } + // 读取URL参数 "prompt" + promptstr := r.URL.Query().Get("prompt") + if promptstr != "" { + // 使用 prompt 变量进行后续处理 + fmt.Printf("收到prompt参数: %s\n", promptstr) + } + + // 读取URL参数 "prompt" + selfid := r.URL.Query().Get("selfid") + if selfid != "" { + // 使用 prompt 变量进行后续处理 + fmt.Printf("收到selfid参数: %s\n", selfid) + } + + // 打印日志信息,包括prompt参数 + fmtf.Printf("收到onebotv11信息: %+v\n", string(body)) + // 打印消息和其他相关信息 fmtf.Printf("Received message: %v\n", message.Message) fmtf.Printf("Full message details: %+v\n", message) @@ -126,7 +141,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } } - if utils.BlacklistIntercept(message) { + if utils.BlacklistIntercept(message, selfid) { fmtf.Printf("userid:[%v]这位用户在黑名单中,被拦截", message.UserID) return } @@ -138,12 +153,12 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { RestoreResponse := config.GetRandomRestoreResponses() if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - utils.SendPrivateMessage(message.UserID, RestoreResponse) + utils.SendPrivateMessage(message.UserID, RestoreResponse, selfid) } else { utils.SendSSEPrivateRestoreMessage(message.UserID, RestoreResponse) } } else { - utils.SendGroupMessage(message.GroupID, message.UserID, RestoreResponse) + utils.SendGroupMessage(message.GroupID, message.UserID, RestoreResponse, selfid) } return } @@ -179,7 +194,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 进行字数拦截 if config.GetQuestionMaxLenth() != 0 { - if utils.LengthIntercept(newmsg, message) { + if utils.LengthIntercept(newmsg, message, selfid) { fmtf.Printf("字数过长,可在questionMaxLenth配置项修改,Q: %v", newmsg) // 发送响应 w.WriteHeader(http.StatusOK) @@ -190,7 +205,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 进行语言判断拦截 if len(config.GetAllowedLanguages()) > 0 { - if utils.LanguageIntercept(newmsg, message) { + if utils.LanguageIntercept(newmsg, message, selfid) { fmtf.Printf("不安全!不支持的语言,可在config.yml设置允许的语言,allowedLanguages配置项,Q: %v", newmsg) // 发送响应 w.WriteHeader(http.StatusOK) @@ -217,7 +232,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 向量安全词部分,机器人向量安全屏障 if config.GetVectorSensitiveFilter() { - ret, retstr, err := app.InterceptSensitiveContent(vector, message) + ret, retstr, err := app.InterceptSensitiveContent(vector, message, selfid) if err != nil { fmtf.Printf("Error in InterceptSensitiveContent: %v", err) // 发送响应 @@ -267,12 +282,12 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 发送响应消息 if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - utils.SendPrivateMessage(message.UserID, responseText) + utils.SendPrivateMessage(message.UserID, responseText, selfid) } else { utils.SendSSEPrivateMessage(message.UserID, responseText) } } else { - utils.SendGroupMessage(message.GroupID, message.UserID, responseText) + utils.SendGroupMessage(message.GroupID, message.UserID, responseText, selfid) } // 发送响应 w.WriteHeader(http.StatusOK) @@ -312,12 +327,12 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { if saveresponse != "" { if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - utils.SendPrivateMessage(message.UserID, saveresponse) + utils.SendPrivateMessage(message.UserID, saveresponse, selfid) } else { utils.SendSSEPrivateSafeMessage(message.UserID, saveresponse) } } else { - utils.SendGroupMessage(message.GroupID, message.UserID, saveresponse) + utils.SendGroupMessage(message.GroupID, message.UserID, saveresponse, selfid) } } // 发送响应 @@ -343,7 +358,14 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 构建并发送请求到conversation接口 port := config.GetPort() portStr := fmtf.Sprintf(":%d", port) - url := "http://127.0.0.1" + portStr + "/conversation" + + var url string + //如果promptstr不等于空,添加到参数中 + if promptstr != "" { + url = "http://127.0.0.1" + portStr + "/conversation?prompt=" + promptstr + } else { + url = "http://127.0.0.1" + portStr + "/conversation" + } // 请求模型还是使用原文请求 requestmsg := message.Message.(string) @@ -432,7 +454,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 判断消息类型,如果是私人消息或私有群消息,发送私人消息;否则,根据配置决定是否发送群消息 if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - utils.SendPrivateMessage(message.UserID, newPart) + utils.SendPrivateMessage(message.UserID, newPart, selfid) } else { //最后一条了 messageSSE := structs.InterfaceBody{ @@ -442,7 +464,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { utils.SendPrivateMessageSSE(message.UserID, messageSSE) } } else { - utils.SendGroupMessage(message.GroupID, message.UserID, newPart) + utils.SendGroupMessage(message.GroupID, message.UserID, newPart, selfid) } } else { //流的最后一次是完整结束的 @@ -456,7 +478,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { // 判断消息类型,如果是私人消息或私有群消息,发送私人消息;否则,根据配置决定是否发送群消息 if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - utils.SendPrivateMessage(message.UserID, response) + utils.SendPrivateMessage(message.UserID, response, selfid) } else { //最后一条了 messageSSE := structs.InterfaceBody{ @@ -466,7 +488,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { utils.SendPrivateMessageSSE(message.UserID, messageSSE) } } else { - utils.SendGroupMessage(message.GroupID, message.UserID, response) + utils.SendGroupMessage(message.GroupID, message.UserID, response, selfid) } } } @@ -490,7 +512,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { if !config.GetHideExtraLogs() { fmtf.Printf("收到流数据,切割并发送信息: %s", string(line)) } - splitAndSendMessages(message, string(line), newmsg) + splitAndSendMessages(message, string(line), newmsg, selfid) } } @@ -552,9 +574,9 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { if response, ok = responseData["response"].(string); ok && response != "" { // 判断消息类型,如果是私人消息或私有群消息,发送私人消息;否则,根据配置决定是否发送群消息 if message.RealMessageType == "group_private" || message.MessageType == "private" { - utils.SendPrivateMessage(message.UserID, response) + utils.SendPrivateMessage(message.UserID, response, selfid) } else { - utils.SendGroupMessage(message.GroupID, message.UserID, response) + utils.SendGroupMessage(message.GroupID, message.UserID, response, selfid) } } @@ -589,7 +611,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } -func splitAndSendMessages(message structs.OnebotGroupMessage, line string, newmesssage string) { +func splitAndSendMessages(message structs.OnebotGroupMessage, line string, newmesssage string, selfid string) { // 提取JSON部分 dataPrefix := "data: " jsonStr := strings.TrimPrefix(line, dataPrefix) @@ -606,13 +628,13 @@ func splitAndSendMessages(message structs.OnebotGroupMessage, line string, newme if sseData.Response != "\n\n" { // 处理提取出的信息 - processMessage(sseData.Response, message, newmesssage) + processMessage(sseData.Response, message, newmesssage, selfid) } else { fmtf.Printf("忽略llm末尾的换行符") } } -func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage string) { +func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage string, selfid string) { key := utils.GetKey(msg.GroupID, msg.UserID) // 定义中文全角和英文标点符号 @@ -629,7 +651,7 @@ func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage // 判断消息类型,如果是私人消息或私有群消息,发送私人消息;否则,根据配置决定是否发送群消息 if msg.RealMessageType == "group_private" || msg.MessageType == "private" { if !config.GetUsePrivateSSE() { - utils.SendPrivateMessage(msg.UserID, accumulatedMessage) + utils.SendPrivateMessage(msg.UserID, accumulatedMessage, selfid) } else { if IncrementIndex(newmesssage) == 1 { //第一条信息 @@ -653,7 +675,7 @@ func processMessage(response string, msg structs.OnebotGroupMessage, newmesssage } } } else { - utils.SendGroupMessage(msg.GroupID, msg.UserID, accumulatedMessage) + utils.SendGroupMessage(msg.GroupID, msg.UserID, accumulatedMessage, selfid) } messageBuilder.Reset() // 重置消息构建器 diff --git a/applogic/hunyuan.go b/applogic/hunyuan.go index 05d80a3..a5dfbf7 100644 --- a/applogic/hunyuan.go +++ b/applogic/hunyuan.go @@ -8,6 +8,7 @@ import ( "github.com/hoshinonyaruko/gensokyo-llm/config" "github.com/hoshinonyaruko/gensokyo-llm/fmtf" "github.com/hoshinonyaruko/gensokyo-llm/hunyuan" + "github.com/hoshinonyaruko/gensokyo-llm/prompt" "github.com/hoshinonyaruko/gensokyo-llm/structs" "github.com/hoshinonyaruko/gensokyo-llm/utils" ) @@ -27,6 +28,14 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + + // 读取URL参数 "prompt" + promptstr := r.URL.Query().Get("prompt") + if promptstr != "" { + // prompt 参数存在,可以根据需要进一步处理或记录 + fmtf.Printf("Received prompt parameter: %s\n", promptstr) + } + msg.Role = "user" //颠倒用户输入 if config.GetReverseUserPrompt() { @@ -45,46 +54,54 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) { } var history []structs.Message - // 获取系统提示词 - systemPromptContent := config.SystemPrompt() // 注意检查实际的函数名是否正确 - - // 如果系统提示词不为空,则添加到历史信息的开始 - if systemPromptContent != "0" { - systemPromptRole := "system" - systemPromptMsg := structs.Message{ - Text: systemPromptContent, - Role: systemPromptRole, + + //根据是否有prompt参数 选择是否载入config.yml的prompt还是prompts文件夹的 + if promptstr == "" { + // 获取系统提示词 + systemPromptContent := config.SystemPrompt() // 注意检查实际的函数名是否正确 + // 如果系统提示词不为空,则添加到历史信息的开始 + if systemPromptContent != "0" { + systemPromptRole := "system" + systemPromptMsg := structs.Message{ + Text: systemPromptContent, + Role: systemPromptRole, + } + // 将系统提示作为第一条消息 + history = append([]structs.Message{systemPromptMsg}, history...) } - // 将系统提示作为第一条消息 - history = append([]structs.Message{systemPromptMsg}, history...) - } - // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A - pairs := []struct { - Q string - A string - RoleQ string // 问题的角色 - RoleA string // 答案的角色 - }{ - {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, - {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, - {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, - } + // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A + pairs := []struct { + Q string + A string + RoleQ string // 问题的角色 + RoleA string // 答案的角色 + }{ + {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, + {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, + {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, + } - // 检查每一对Q&A是否均不为空,并追加到历史信息中 - for _, pair := range pairs { - if pair.Q != "" && pair.A != "" { - qMessage := structs.Message{ - Text: pair.Q, - Role: pair.RoleQ, - } - aMessage := structs.Message{ - Text: pair.A, - Role: pair.RoleA, - } + // 检查每一对Q&A是否均不为空,并追加到历史信息中 + for _, pair := range pairs { + if pair.Q != "" && pair.A != "" { + qMessage := structs.Message{ + Text: pair.Q, + Role: pair.RoleQ, + } + aMessage := structs.Message{ + Text: pair.A, + Role: pair.RoleA, + } - // 注意追加的顺序,确保问题在答案之前 - history = append(history, qMessage, aMessage) + // 注意追加的顺序,确保问题在答案之前 + history = append(history, qMessage, aMessage) + } + } + } else { + history, err = prompt.GetMessagesFromFilename(promptstr) + if err != nil { + fmtf.Printf("prompt.GetMessagesFromFilename error: %v\n", err) } } diff --git a/applogic/rwkv.go b/applogic/rwkv.go index c2bfe67..6dc00c1 100644 --- a/applogic/rwkv.go +++ b/applogic/rwkv.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/hoshinonyaruko/gensokyo-llm/config" "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/prompt" "github.com/hoshinonyaruko/gensokyo-llm/structs" "github.com/hoshinonyaruko/gensokyo-llm/utils" ) @@ -36,6 +37,14 @@ func (app *App) ChatHandlerRwkv(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + + // 读取URL参数 "prompt" + promptstr := r.URL.Query().Get("prompt") + if promptstr != "" { + // prompt 参数存在,可以根据需要进一步处理或记录 + fmtf.Printf("Received prompt parameter: %s\n", promptstr) + } + msg.Role = "user" //颠倒用户输入 if config.GetReverseUserPrompt() { @@ -55,43 +64,51 @@ func (app *App) ChatHandlerRwkv(w http.ResponseWriter, r *http.Request) { var history []structs.Message - // 获取系统提示词 - systemPromptContent := config.SystemPrompt() - if systemPromptContent != "0" { - systemPrompt := structs.Message{ - Text: systemPromptContent, - Role: "system", + //根据是否有prompt参数 选择是否载入config.yml的prompt还是prompts文件夹的 + if promptstr == "" { + // 获取系统提示词 + systemPromptContent := config.SystemPrompt() + if systemPromptContent != "0" { + systemPrompt := structs.Message{ + Text: systemPromptContent, + Role: "system", + } + // 将系统提示词添加到历史信息的开始 + history = append([]structs.Message{systemPrompt}, history...) } - // 将系统提示词添加到历史信息的开始 - history = append([]structs.Message{systemPrompt}, history...) - } - // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A - pairs := []struct { - Q string - A string - RoleQ string // 问题的角色 - RoleA string // 答案的角色 - }{ - {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, - {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, - {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, - } + // 分别获取FirstQ&A, SecondQ&A, ThirdQ&A + pairs := []struct { + Q string + A string + RoleQ string // 问题的角色 + RoleA string // 答案的角色 + }{ + {config.GetFirstQ(), config.GetFirstA(), "user", "assistant"}, + {config.GetSecondQ(), config.GetSecondA(), "user", "assistant"}, + {config.GetThirdQ(), config.GetThirdA(), "user", "assistant"}, + } - // 检查每一对Q&A是否均不为空,并追加到历史信息中 - for _, pair := range pairs { - if pair.Q != "" && pair.A != "" { - qMessage := structs.Message{ - Text: pair.Q, - Role: pair.RoleQ, - } - aMessage := structs.Message{ - Text: pair.A, - Role: pair.RoleA, - } + // 检查每一对Q&A是否均不为空,并追加到历史信息中 + for _, pair := range pairs { + if pair.Q != "" && pair.A != "" { + qMessage := structs.Message{ + Text: pair.Q, + Role: pair.RoleQ, + } + aMessage := structs.Message{ + Text: pair.A, + Role: pair.RoleA, + } - // 注意追加的顺序,确保问题在答案之前 - history = append(history, qMessage, aMessage) + // 注意追加的顺序,确保问题在答案之前 + history = append(history, qMessage, aMessage) + } + } + } else { + history, err = prompt.GetMessagesFromFilename(promptstr) + if err != nil { + fmtf.Printf("prompt.GetMessagesFromFilename error: %v\n", err) } } diff --git a/applogic/vectorsensitive.go b/applogic/vectorsensitive.go index 486e805..e5dc922 100644 --- a/applogic/vectorsensitive.go +++ b/applogic/vectorsensitive.go @@ -258,7 +258,7 @@ func (app *App) textExistsInDatabase(text string) (bool, error) { return exists, nil } -func (app *App) InterceptSensitiveContent(vector []float64, message structs.OnebotGroupMessage) (int, string, error) { +func (app *App) InterceptSensitiveContent(vector []float64, message structs.OnebotGroupMessage, selfid string) (int, string, error) { // 自定义阈值 Threshold := config.GetVertorSensitiveThreshold() @@ -283,12 +283,12 @@ func (app *App) InterceptSensitiveContent(vector []float64, message structs.Oneb if saveresponse != "" { if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - utils.SendPrivateMessage(message.UserID, saveresponse) + utils.SendPrivateMessage(message.UserID, saveresponse, selfid) } else { utils.SendSSEPrivateSafeMessage(message.UserID, saveresponse) } } else { - utils.SendGroupMessage(message.GroupID, message.UserID, saveresponse) + utils.SendGroupMessage(message.GroupID, message.UserID, saveresponse, selfid) } return 1, saveresponse, nil } diff --git a/config/config.go b/config/config.go index b88bbb1..c9775cb 100644 --- a/config/config.go +++ b/config/config.go @@ -1,12 +1,15 @@ package config import ( + "log" "math/rand" "os" "sync" "time" "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/prompt" + "github.com/hoshinonyaruko/gensokyo-llm/structs" "gopkg.in/yaml.v3" ) @@ -18,99 +21,8 @@ var ( var r = rand.New(rand.NewSource(time.Now().UnixNano())) type Config struct { - Version int `yaml:"version"` - Settings Settings `yaml:"settings"` -} - -type Settings struct { - SecretId string `yaml:"secretId"` - SecretKey string `yaml:"secretKey"` - Region string `yaml:"region"` - UseSse bool `yaml:"useSse"` - Port int `yaml:"port"` - HttpPath string `yaml:"path"` - SystemPrompt []string `yaml:"systemPrompt"` - IPWhiteList []string `yaml:"iPWhiteList"` - MaxTokensHunyuan int `yaml:"maxTokensHunyuan"` - ApiType int `yaml:"apiType"` - WenxinAccessToken string `yaml:"wenxinAccessToken"` - WenxinApiPath string `yaml:"wenxinApiPath"` - MaxTokenWenxin int `yaml:"maxTokenWenxin"` - GptModel string `yaml:"gptModel"` - GptApiPath string `yaml:"gptApiPath"` - GptToken string `yaml:"gptToken"` - MaxTokenGpt int `yaml:"maxTokenGpt"` - GptSafeMode bool `yaml:"gptSafeMode"` - GptSseType int `yaml:"gptSseType"` - Groupmessage bool `yaml:"groupMessage"` - SplitByPuntuations int `yaml:"splitByPuntuations"` - HunyuanType int `yaml:"hunyuanType"` - FirstQ []string `yaml:"firstQ"` - FirstA []string `yaml:"firstA"` - SecondQ []string `yaml:"secondQ"` - SecondA []string `yaml:"secondA"` - ThirdQ []string `yaml:"thirdQ"` - ThirdA []string `yaml:"thirdA"` - SensitiveMode bool `yaml:"sensitiveMode"` - SensitiveModeType int `yaml:"sensitiveModeType"` - DefaultChangeWord string `yaml:"defaultChangeWord"` - AntiPromptAttackPath string `yaml:"antiPromptAttackPath"` - ReverseUserPrompt bool `yaml:"reverseUserPrompt"` - IgnoreExtraTips bool `yaml:"ignoreExtraTips"` - SaveResponses []string `yaml:"saveResponses"` - RestoreCommand []string `yaml:"restoreCommand"` - RestoreResponses []string `yaml:"restoreResponses"` - UsePrivateSSE bool `yaml:"usePrivateSSE"` - Promptkeyboard []string `yaml:"promptkeyboard"` - Savelogs bool `yaml:"savelogs"` - AntiPromptLimit float64 `yaml:"antiPromptLimit"` - UseCache bool `yaml:"useCache"` - CacheThreshold int `yaml:"cacheThreshold"` - CacheChance int `yaml:"cacheChance"` - EmbeddingType int `yaml:"embeddingType"` - WenxinEmbeddingUrl string `yaml:"wenxinEmbeddingUrl"` - GptEmbeddingUrl string `yaml:"gptEmbeddingUrl"` - PrintHanming bool `yaml:"printHanming"` - CacheK float64 `yaml:"cacheK"` - CacheN int64 `yaml:"cacheN"` - PrintVector bool `yaml:"printVector"` - VToBThreshold float64 `yaml:"vToBThreshold"` - GptModeration bool `yaml:"gptModeration"` - WenxinTopp float64 `yaml:"wenxinTopp"` - WnxinPenaltyScore float64 `yaml:"wenxinPenaltyScore"` - WenxinMaxOutputTokens int `yaml:"wenxinMaxOutputTokens"` - VectorSensitiveFilter bool `yaml:"vectorSensitiveFilter"` - VertorSensitiveThreshold int `yaml:"vertorSensitiveThreshold"` - AllowedLanguages []string `yaml:"allowedLanguages"` - LanguagesResponseMessages []string `yaml:"langResponseMessages"` - QuestionMaxLenth int `yaml:"questionMaxLenth"` - QmlResponseMessages []string `yaml:"qmlResponseMessages"` - BlacklistResponseMessages []string `yaml:"blacklistResponseMessages"` - NoContext bool `yaml:"noContext"` - WithdrawCommand []string `yaml:"withdrawCommand"` - FunctionMode bool `yaml:"functionMode"` - FunctionPath string `yaml:"functionPath"` - UseFunctionPromptkeyboard bool `yaml:"useFunctionPromptkeyboard"` - AIPromptkeyboardPath string `yaml:"AIPromptkeyboardPath"` - UseAIPromptkeyboard bool `yaml:"useAIPromptkeyboard"` - SplitByPuntuationsGroup int `yaml:"splitByPuntuationsGroup"` - RwkvApiPath string `yaml:"rwkvApiPath"` - RwkvMaxTokens int `yaml:"rwkvMaxTokens"` - RwkvTemperature float64 `yaml:"rwkvTemperature"` - RwkvTopP float64 `yaml:"rwkvTopP"` - RwkvPresencePenalty float64 `yaml:"rwkvPresencePenalty"` - RwkvFrequencyPenalty float64 `yaml:"rwkvFrequencyPenalty"` - RwkvPenaltyDecay float64 `yaml:"rwkvPenaltyDecay"` - RwkvTopK int `yaml:"rwkvTopK"` - RwkvGlobalPenalty bool `yaml:"rwkvGlobalPenalty"` - RwkvStream bool `yaml:"rwkvStream"` - RwkvStop []string `yaml:"rwkvStop"` - RwkvUserName string `yaml:"rwkvUserName"` - RwkvAssistantName string `yaml:"rwkvAssistantName"` - RwkvSystemName string `yaml:"rwkvSystemName"` - RwkvPreSystem bool `yaml:"rwkvPreSystem"` - RwkvSseType int `yaml:"rwkvSseType"` - HideExtraLogs bool `yaml:"hideExtraLogs"` + Version int `yaml:"version"` + Settings structs.Settings `yaml:"settings"` } // LoadConfig 从文件中加载配置并初始化单例配置 @@ -260,13 +172,32 @@ func GetWenxinAccessToken() string { } // 获取WenxinApiPath -func GetWenxinApiPath() string { +func GetWenxinApiPath(options ...string) string { mu.Lock() defer mu.Unlock() - if instance != nil { - return instance.Settings.WenxinApiPath + + if len(options) == 0 { + if instance != nil { + return instance.Settings.WenxinApiPath + } + return "0" } - return "0" + + // 处理传入的 basename + basename := options[0] + apiPathInterface, err := prompt.GetSettingFromFilename(basename, "WenxinApiPath") + if err != nil { + log.Println("Error retrieving WenxinApiPath:", err) + return "0" + } + + apiPath, ok := apiPathInterface.(string) + if !ok { + log.Println("Type assertion failed for WenxinApiPath") + return "0" + } + + return apiPath } // 获取GetMaxTokenWenxin @@ -280,13 +211,32 @@ func GetMaxTokenWenxin() int { } // 获取GptModel -func GetGptModel() string { +func GetGptModel(options ...string) string { mu.Lock() defer mu.Unlock() - if instance != nil { - return instance.Settings.GptModel + + if len(options) == 0 { + if instance != nil { + return instance.Settings.GptModel + } + return "0" } - return "0" + + // 处理传入的 basename + basename := options[0] + gptModelInterface, err := prompt.GetSettingFromFilename(basename, "GptModel") + if err != nil { + log.Println("Error retrieving GptModel:", err) + return "0" + } + + gptModel, ok := gptModelInterface.(string) + if !ok { + fmtf.Println("Type assertion failed for GptModel") + return "0" + } + + return gptModel } // 获取GptApiPath @@ -1141,3 +1091,13 @@ func GetHideExtraLogs() bool { } return false } + +// 获取wsServerToken +func GetWSServerToken() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.WSServerToken + } + return "" +} diff --git a/go.mod b/go.mod index 58cc3a4..5b3170f 100644 --- a/go.mod +++ b/go.mod @@ -12,12 +12,36 @@ require ( require github.com/abadojack/whatlanggo v1.0.1 require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/gin-gonic/gin v1.9.1 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/gorilla/websocket v1.5.1 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect github.com/liuzl/cedar-go v0.0.0-20170805034717-80a9c64b256d // indirect github.com/liuzl/da v0.0.0-20180704015230-14771aad5b1d // indirect github.com/longbridgeapp/opencc v0.3.11 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/crypto v0.14.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/text v0.13.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect ) require ( github.com/fsnotify/fsnotify v1.7.0 - golang.org/x/sys v0.4.0 // indirect + golang.org/x/sys v0.13.0 // indirect ) diff --git a/go.sum b/go.sum index 943f5f3..c9475e2 100644 --- a/go.sum +++ b/go.sum @@ -1,32 +1,102 @@ github.com/abadojack/whatlanggo v1.0.1 h1:19N6YogDnf71CTHm3Mp2qhYfkRdyvbgwWdd2EPxJRG4= github.com/abadojack/whatlanggo v1.0.1/go.mod h1:66WiQbSbJBIlOZMsvbKe5m6pzQovxCH9B/K8tQB2uoc= github.com/adamzy/cedar-go v0.0.0-20170805034717-80a9c64b256d/go.mod h1:PRWNwWq0yifz6XDPZu48aSld8BWwBfr2JKB2bGWiEd4= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/liuzl/cedar-go v0.0.0-20170805034717-80a9c64b256d h1:qSmEGTgjkESUX5kPMSGJ4pcBUtYVDdkNzMrjQyvRvp0= github.com/liuzl/cedar-go v0.0.0-20170805034717-80a9c64b256d/go.mod h1:x7SghIWwLVcJObXbjK7S2ENsT1cAcdJcPl7dRaSFog0= github.com/liuzl/da v0.0.0-20180704015230-14771aad5b1d h1:hTRDIpJ1FjS9ULJuEzu69n3qTgc18eI+ztw/pJv47hs= github.com/liuzl/da v0.0.0-20180704015230-14771aad5b1d/go.mod h1:7xD3p0XnHvJFQ3t/stEJd877CSIMkH/fACVWen5pYnc= github.com/longbridgeapp/opencc v0.3.11 h1:MfozRXTRmchceDyVsJ/JoOsuXb7AqtjF7RUtWUa0cQo= github.com/longbridgeapp/opencc v0.3.11/go.mod h1:jRuKtq8eLA+cZUu75XgMvkB/hFSXJbZDmij0v29lNaY= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.839 h1:VGVFNQDaUpDsPkJrh8I9qOxHZ1yj5sJmg9ngsUvTAHM= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.839/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= @@ -36,3 +106,4 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/main.go b/main.go index fbde55f..5bfeb12 100644 --- a/main.go +++ b/main.go @@ -4,10 +4,13 @@ import ( "bufio" "database/sql" "flag" + "fmt" "log" "net/http" "os" + "os/signal" "path/filepath" + "syscall" _ "github.com/mattn/go-sqlite3" // 只导入,作为驱动 @@ -15,6 +18,7 @@ import ( "github.com/hoshinonyaruko/gensokyo-llm/config" "github.com/hoshinonyaruko/gensokyo-llm/fmtf" "github.com/hoshinonyaruko/gensokyo-llm/hunyuan" + "github.com/hoshinonyaruko/gensokyo-llm/server" "github.com/hoshinonyaruko/gensokyo-llm/template" "github.com/hoshinonyaruko/gensokyo-llm/utils" ) @@ -184,10 +188,32 @@ func main() { return } + // 设置路由 http.HandleFunc("/gensokyo", app.GensokyoHandler) + var wspath string + if conf.Settings.WSPath == "nil" { + wspath = "/" + } else { + wspath = "/" + conf.Settings.WSPath + } + http.HandleFunc(wspath, func(w http.ResponseWriter, r *http.Request) { + server.WsHandler(w, r, conf) + }) port := config.GetPort() portStr := fmtf.Sprintf(":%d", port) fmtf.Printf("listening on %v\n", portStr) - // 这里阻塞等待并处理请求 + + // 设置信号处理 + go func() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + fmt.Println("Shutting down server...") + server.CloseAllConnections() + os.Exit(0) + }() + + // 启动HTTP服务器 log.Fatal(http.ListenAndServe(portStr, nil)) } diff --git a/prompt/prompt.go b/prompt/prompt.go new file mode 100644 index 0000000..2fea630 --- /dev/null +++ b/prompt/prompt.go @@ -0,0 +1,202 @@ +package prompt + +import ( + "fmt" + "log" + "os" + "path/filepath" + "reflect" + "sync" + + "github.com/fsnotify/fsnotify" + + "github.com/hoshinonyaruko/gensokyo-llm/structs" + "gopkg.in/yaml.v3" +) + +type Prompt struct { + Role string `yaml:"role"` + Content string `yaml:"content"` +} + +type PromptFile struct { + Prompts []Prompt `yaml:"Prompt"` + Settings structs.Settings `yaml:"settings"` +} + +var ( + promptsCache = make(map[string]PromptFile) + lock sync.RWMutex + promptsDir = "prompts" // 定义固定的目录名 +) + +func init() { + // 通过 init 函数在包加载时就执行目录监控 + err := LoadPrompts() + if err != nil { + log.Fatal("Failed to load prompts:", err) + } +} + +// LoadPrompts 确保目录存在并尝试加载提示词文件 +func LoadPrompts() error { + // 构建目录路径 + directory := filepath.Join(".", promptsDir) + + // 尝试创建目录(如果不存在) + if _, err := os.Stat(directory); os.IsNotExist(err) { + // 目录不存在,尝试创建它 + if err := os.MkdirAll(directory, os.ModePerm); err != nil { + return err + } + } + files, err := os.ReadDir(directory) + if err != nil { + return err + } + + for _, file := range files { + if filepath.Ext(file.Name()) == ".yml" { + loadFile(filepath.Join(directory, file.Name())) + } + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write == fsnotify.Write { + loadFile(event.Name) + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Println("error:", err) + } + } + }() + + err = watcher.Add(directory) + if err != nil { + return err + } + + return nil +} + +func loadFile(filename string) { + lock.Lock() + defer lock.Unlock() + + data, err := os.ReadFile(filename) + if err != nil { + log.Println("Failed to read file:", err) + return + } + + var prompts PromptFile + err = yaml.Unmarshal(data, &prompts) + if err != nil { + log.Println("Failed to unmarshal YAML:", err) + return + } + + baseName := filepath.Base(filename) + promptsCache[baseName] = prompts +} + +func GetMessagesFromFilename(basename string) ([]structs.Message, error) { + lock.RLock() + defer lock.RUnlock() + + filename := basename + ".yml" + promptFile, exists := promptsCache[filename] + if !exists { + return nil, fmt.Errorf("no data for file: %s", filename) + } + + var history []structs.Message + for _, prompt := range promptFile.Prompts { + history = append(history, structs.Message{ + Text: prompt.Content, + Role: prompt.Role, + }) + } + + return history, nil +} + +// 返回除了 "system" 角色之外的所有消息 +func GetMessagesExcludingSystem(basename string) ([]structs.Message, error) { + lock.RLock() + defer lock.RUnlock() + + filename := basename + ".yml" + promptFile, exists := promptsCache[filename] + if !exists { + return nil, fmt.Errorf("no data for file: %s", filename) + } + + var history []structs.Message + for _, prompt := range promptFile.Prompts { + if prompt.Role != "system" && prompt.Role != "System" { + history = append(history, structs.Message{ + Text: prompt.Content, + Role: prompt.Role, + }) + } + } + + return history, nil +} + +// 返回第一条 "system" 角色的消息文本 +func GetFirstSystemMessage(basename string) (string, error) { + lock.RLock() + defer lock.RUnlock() + + filename := basename + ".yml" + promptFile, exists := promptsCache[filename] + if !exists { + return "", fmt.Errorf("no data for file: %s", filename) + } + + for _, prompt := range promptFile.Prompts { + if prompt.Role == "system" || prompt.Role == "System" { + return prompt.Content, nil + } + } + + return "", fmt.Errorf("no system message found in file: %s", filename) +} + +// GetSettingFromFilename 用于获取配置文件中的特定设置 +func GetSettingFromFilename(basename, settingName string) (interface{}, error) { + lock.RLock() + defer lock.RUnlock() + + filename := basename + ".yml" + promptFile, exists := promptsCache[filename] + if !exists { + return nil, fmt.Errorf("no data for file: %s", filename) + } + + // 使用反射获取Settings结构体中的字段 + rv := reflect.ValueOf(promptFile.Settings) + field := rv.FieldByName(settingName) + if !field.IsValid() { + return nil, fmt.Errorf("no setting with name: %s", settingName) + } + + // 返回字段的值,转换为interface{} + return field.Interface(), nil +} diff --git a/readme.md b/readme.md index 0565e4f..92fd8e1 100644 --- a/readme.md +++ b/readme.md @@ -65,7 +65,37 @@ AhoCorasick算法实现的超高效文本IN-Out替换规则,可大量替换n 支持中间件开发,在gensokyo框架层到gensokyo-llm的http请求之间,可开发中间件实现向量拓展,数据库拓展,动态修改用户问题. -## 接口调用说明 +# API接口调用说明 + +本文档提供了关于API接口的调用方法和配置文件的格式说明,帮助用户正确使用和配置。 + +## 接口支持的查询参数 + +本系统的 `conversation` 和 `gensokyo` 端点支持通过查询参数 `?prompt=xxx` 来指定特定的配置。 + +- `prompt` 参数允许用户指定位于执行文件(exe)的 `prompts` 文件夹下的配置YAML文件。使用该参数可以动态地调整API行为和返回内容。 + +## YAML配置文件格式 + +配置文件应遵循以下YAML格式。这里提供了一个示例配置文件,展示了如何定义不同角色的对话内容: + +```yaml +Prompt: + - role: "system" + content: "Welcome to the system. How can I assist you today?" + - role: "user" + content: "I need help with my account." + - role: "assistant" + content: "I can help you with that. What seems to be the problem?" + - role: "user" + content: "aaaaaaaaaa!" + - role: "assistant" + content: "ooooooooo?" +settings: + # 以下是通用配置项 和config.yml相同 + useSse: true + port: 46233 +``` ### 终结点 diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..34cfd65 --- /dev/null +++ b/server/server.go @@ -0,0 +1,199 @@ +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "sync" + + "github.com/gorilla/websocket" + "github.com/hoshinonyaruko/gensokyo-llm/config" + "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/structs" +) + +type WebSocketServerClient struct { + SelfID string + Conn *websocket.Conn +} + +// 维护所有活跃连接的切片 +var clients = []*WebSocketServerClient{} +var lock sync.Mutex +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // 允许所有跨域请求 + }, +} + +// 用于处理WebSocket连接 +func WsHandler(w http.ResponseWriter, r *http.Request, config *config.Config) { + // 从请求头或URL查询参数中提取token + tokenFromHeader := r.Header.Get("Authorization") + selfID := r.Header.Get("X-Self-ID") + fmtf.Printf("接入机器人X-Self-ID[%v]", selfID) + var token string + if strings.HasPrefix(tokenFromHeader, "Token ") { + token = strings.TrimPrefix(tokenFromHeader, "Token ") + } else if strings.HasPrefix(tokenFromHeader, "Bearer ") { + token = strings.TrimPrefix(tokenFromHeader, "Bearer ") + } else { + token = tokenFromHeader + } + if token == "" { + token = r.URL.Query().Get("access_token") + } + + // 验证token + validToken := config.Settings.WSServerToken + if validToken != "" && (token == "" || token != validToken) { + if token == "" { + log.Printf("Connection failed due to missing token. Headers: %v", r.Header) + http.Error(w, "Missing token", http.StatusUnauthorized) + } else { + log.Printf("Connection failed due to incorrect token. Headers: %v, Provided token: %s", r.Header, token) + http.Error(w, "Incorrect token", http.StatusForbidden) + } + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Failed to set websocket upgrade: %+v", err) + return + } + defer conn.Close() + + lock.Lock() + clients = append(clients, &WebSocketServerClient{ + SelfID: selfID, + Conn: conn, + }) + lock.Unlock() + + clientIP := r.RemoteAddr + log.Printf("WebSocket client connected. IP: %s", clientIP) + + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + log.Printf("Error reading message: %v", err) + break + } + + if messageType == websocket.TextMessage { + processWSMessage(p, selfID) + } + } +} + +// 处理收到的信息 +func processWSMessage(msg []byte, selfid string) { + var genericMap map[string]interface{} + if err := json.Unmarshal(msg, &genericMap); err != nil { + log.Printf("Error unmarshalling message to map: %v, Original message: %s\n", err, string(msg)) + return + } + + // Assuming there's a way to distinguish notice messages, for example, checking if notice_type exists + if noticeType, ok := genericMap["notice_type"].(string); ok && noticeType != "" { + var noticeEvent structs.NoticeEvent + if err := json.Unmarshal(msg, ¬iceEvent); err != nil { + log.Printf("Error unmarshalling notice event: %v\n", err) + return + } + fmt.Printf("Processed a notice event of type '%s' from group %d.\n", noticeEvent.NoticeType, noticeEvent.GroupID) + //进入处理流程 + + } else if postType, ok := genericMap["post_type"].(string); ok { + switch postType { + case "message": + var messageEvent structs.OnebotGroupMessage + if err := json.Unmarshal(msg, &messageEvent); err != nil { + log.Printf("Error unmarshalling message event: %v\n", err) + return + } + fmt.Printf("Processed a message event from group %d.\n", messageEvent.GroupID) + //进入处理流程 + + // 将消息事件序列化为JSON + data, err := json.Marshal(messageEvent) + if err != nil { + log.Printf("Error marshalling message event: %v\n", err) + return + } + + port := config.GetPort() + // 构造请求URL + url := "http://127.0.0.1:" + fmt.Sprint(port) + "/gensokyo?selfid=" + selfid + + // 创建POST请求 + resp, err := http.Post(url, "application/json", bytes.NewReader(data)) + if err != nil { + log.Printf("Failed to send POST request: %v\n", err) + return + } + defer resp.Body.Close() + + // 读取响应 + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("Failed to read response body: %v\n", err) + return + } + + log.Printf("Received response: %s\n", responseBody) + + case "meta_event": + var metaEvent structs.MetaEvent + if err := json.Unmarshal(msg, &metaEvent); err != nil { + log.Printf("Error unmarshalling meta event: %v\n", err) + return + } + fmt.Printf("Processed a meta event, heartbeat interval: %d.\n", metaEvent.Interval) + //进入 处理流程 + + } + } else { + log.Printf("Unknown message type or missing post type\n") + } +} + +// 发信息给client +func SendMessageBySelfID(selfID string, message map[string]interface{}) error { + lock.Lock() + defer lock.Unlock() + + for _, client := range clients { + if client.SelfID == selfID { + msgBytes, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("error marshalling message: %v", err) + } + return client.Conn.WriteMessage(websocket.TextMessage, msgBytes) + } + } + + return fmt.Errorf("no connection found for selfID: %s", selfID) +} + +func (client *WebSocketServerClient) Close() error { + return client.Conn.Close() +} + +func CloseAllConnections() { + lock.Lock() + defer lock.Unlock() + + for _, client := range clients { + err := client.Close() + if err != nil { + log.Printf("failed to close connection for selfID %s: %v", client.SelfID, err) + } + } + clients = nil // 清空切片,避免悬挂引用 +} diff --git a/structs/struct.go b/structs/struct.go index 8423a8d..86af079 100644 --- a/structs/struct.go +++ b/structs/struct.go @@ -206,3 +206,162 @@ type WXFunctionCall struct { Arguments map[string]interface{} `json:"arguments,omitempty"` Thought string `json:"thought,omitempty"` } + +type Settings struct { + SecretId string `yaml:"secretId"` + SecretKey string `yaml:"secretKey"` + Region string `yaml:"region"` + UseSse bool `yaml:"useSse"` + Port int `yaml:"port"` + HttpPath string `yaml:"path"` + SystemPrompt []string `yaml:"systemPrompt"` + IPWhiteList []string `yaml:"iPWhiteList"` + ApiType int `yaml:"apiType"` + + HunyuanType int `yaml:"hunyuanType"` + MaxTokensHunyuan int `yaml:"maxTokensHunyuan"` + + WenxinAccessToken string `yaml:"wenxinAccessToken"` + WenxinApiPath string `yaml:"wenxinApiPath"` + MaxTokenWenxin int `yaml:"maxTokenWenxin"` + WenxinTopp float64 `yaml:"wenxinTopp"` + WnxinPenaltyScore float64 `yaml:"wenxinPenaltyScore"` + WenxinMaxOutputTokens int `yaml:"wenxinMaxOutputTokens"` + WenxinEmbeddingUrl string `yaml:"wenxinEmbeddingUrl"` + + GptModel string `yaml:"gptModel"` + GptApiPath string `yaml:"gptApiPath"` + GptToken string `yaml:"gptToken"` + MaxTokenGpt int `yaml:"maxTokenGpt"` + GptSafeMode bool `yaml:"gptSafeMode"` + GptSseType int `yaml:"gptSseType"` + GptEmbeddingUrl string `yaml:"gptEmbeddingUrl"` + + Groupmessage bool `yaml:"groupMessage"` + SplitByPuntuations int `yaml:"splitByPuntuations"` + + FirstQ []string `yaml:"firstQ"` + FirstA []string `yaml:"firstA"` + SecondQ []string `yaml:"secondQ"` + SecondA []string `yaml:"secondA"` + ThirdQ []string `yaml:"thirdQ"` + ThirdA []string `yaml:"thirdA"` + + SensitiveMode bool `yaml:"sensitiveMode"` + SensitiveModeType int `yaml:"sensitiveModeType"` + DefaultChangeWord string `yaml:"defaultChangeWord"` + AntiPromptAttackPath string `yaml:"antiPromptAttackPath"` + ReverseUserPrompt bool `yaml:"reverseUserPrompt"` + IgnoreExtraTips bool `yaml:"ignoreExtraTips"` + SaveResponses []string `yaml:"saveResponses"` + RestoreCommand []string `yaml:"restoreCommand"` + RestoreResponses []string `yaml:"restoreResponses"` + UsePrivateSSE bool `yaml:"usePrivateSSE"` + Promptkeyboard []string `yaml:"promptkeyboard"` + Savelogs bool `yaml:"savelogs"` + AntiPromptLimit float64 `yaml:"antiPromptLimit"` + + UseCache bool `yaml:"useCache"` + CacheThreshold int `yaml:"cacheThreshold"` + CacheChance int `yaml:"cacheChance"` + EmbeddingType int `yaml:"embeddingType"` + + PrintHanming bool `yaml:"printHanming"` + CacheK float64 `yaml:"cacheK"` + CacheN int64 `yaml:"cacheN"` + PrintVector bool `yaml:"printVector"` + VToBThreshold float64 `yaml:"vToBThreshold"` + GptModeration bool `yaml:"gptModeration"` + + VectorSensitiveFilter bool `yaml:"vectorSensitiveFilter"` + VertorSensitiveThreshold int `yaml:"vertorSensitiveThreshold"` + AllowedLanguages []string `yaml:"allowedLanguages"` + LanguagesResponseMessages []string `yaml:"langResponseMessages"` + QuestionMaxLenth int `yaml:"questionMaxLenth"` + QmlResponseMessages []string `yaml:"qmlResponseMessages"` + BlacklistResponseMessages []string `yaml:"blacklistResponseMessages"` + NoContext bool `yaml:"noContext"` + WithdrawCommand []string `yaml:"withdrawCommand"` + FunctionMode bool `yaml:"functionMode"` + FunctionPath string `yaml:"functionPath"` + UseFunctionPromptkeyboard bool `yaml:"useFunctionPromptkeyboard"` + AIPromptkeyboardPath string `yaml:"AIPromptkeyboardPath"` + UseAIPromptkeyboard bool `yaml:"useAIPromptkeyboard"` + SplitByPuntuationsGroup int `yaml:"splitByPuntuationsGroup"` + + RwkvApiPath string `yaml:"rwkvApiPath"` + RwkvMaxTokens int `yaml:"rwkvMaxTokens"` + RwkvTemperature float64 `yaml:"rwkvTemperature"` + RwkvTopP float64 `yaml:"rwkvTopP"` + RwkvPresencePenalty float64 `yaml:"rwkvPresencePenalty"` + RwkvFrequencyPenalty float64 `yaml:"rwkvFrequencyPenalty"` + RwkvPenaltyDecay float64 `yaml:"rwkvPenaltyDecay"` + RwkvTopK int `yaml:"rwkvTopK"` + RwkvGlobalPenalty bool `yaml:"rwkvGlobalPenalty"` + RwkvStream bool `yaml:"rwkvStream"` + RwkvStop []string `yaml:"rwkvStop"` + RwkvUserName string `yaml:"rwkvUserName"` + RwkvAssistantName string `yaml:"rwkvAssistantName"` + RwkvSystemName string `yaml:"rwkvSystemName"` + RwkvPreSystem bool `yaml:"rwkvPreSystem"` + RwkvSseType int `yaml:"rwkvSseType"` + HideExtraLogs bool `yaml:"hideExtraLogs"` + + WSServerToken string `yaml:"wsServerToken"` + WSPath string `yaml:"wsPath"` +} + +type MetaEvent struct { + PostType string `json:"post_type"` + MetaEventType string `json:"meta_event_type"` + Time int64 `json:"time"` + SelfID int64 `json:"self_id"` + Interval int `json:"interval"` + Status struct { + AppEnabled bool `json:"app_enabled"` + AppGood bool `json:"app_good"` + AppInitialized bool `json:"app_initialized"` + Good bool `json:"good"` + Online bool `json:"online"` + PluginsGood *bool `json:"plugins_good"` + Stat struct { + PacketReceived int `json:"packet_received"` + PacketSent int `json:"packet_sent"` + PacketLost int `json:"packet_lost"` + MessageReceived int `json:"message_received"` + MessageSent int `json:"message_sent"` + DisconnectTimes int `json:"disconnect_times"` + LostTimes int `json:"lost_times"` + LastMessageTime int64 `json:"last_message_time"` + } `json:"stat"` + } `json:"status"` +} + +type NoticeEvent struct { + GroupID int64 `json:"group_id"` + NoticeType string `json:"notice_type"` + OperatorID int64 `json:"operator_id"` + PostType string `json:"post_type"` + SelfID int64 `json:"self_id"` + SubType string `json:"sub_type"` + Time int64 `json:"time"` + UserID int64 `json:"user_id"` +} + +type RobotStatus struct { + SelfID int64 `json:"self_id"` + Date string `json:"date"` + Online bool `json:"online"` + MessageReceived int `json:"message_received"` + MessageSent int `json:"message_sent"` + LastMessageTime int64 `json:"last_message_time"` + InvitesReceived int `json:"invites_received"` + KicksReceived int `json:"kicks_received"` + DailyDAU int `json:"daily_dau"` +} + +type OnebotActionMessage struct { + Action string `json:"action"` + Params map[string]interface{} `json:"params"` + Echo interface{} `json:"echo,omitempty"` +} diff --git a/template/config_template.go b/template/config_template.go index 0dfe20e..ff359aa 100644 --- a/template/config_template.go +++ b/template/config_template.go @@ -5,7 +5,7 @@ version: 1 settings: #通用配置项 - useSse : true + useSse : false #智能体场景开启,其他场景,比如普通onebotv11不开启 port : 46233 #本程序监听端口,支持gensokyo http上报, 请在gensokyo的反向http配置加入 post_url: ["http://127.0.0.1:port/gensokyo"] path : "http://123.123.123.123:11111" #调用gensokyo api的地址,填入 gensokyo 的 正向http地址 http_address: "0.0.0.0:46231" 对应填入 "http://127.0.0.1:46231" apiType : 0 #0=混元 1=文心(文心平台包含了N种模型...) 2=gpt @@ -38,6 +38,10 @@ settings: withdrawCommand : ["撤回"] #撤回指令 hideExtraLogs : false #忽略流信息的log,提高性能 + #Ws服务器配置 + wsServerToken : "" #ws密钥 可以由onebotv11反向ws接入 + wsPath : "nil" #设置了ws就不用设置path了,可以连接多个机器人. + functionMode : false #是否指定本agent使用func模式(目前仅支持千帆平台),效果不好,暂时不用. functionPath : "" #调用另一个启用了func模式的gsk-llm联合工作的/conversation地址,效果不好,暂时不用. useFunctionPromptkeyboard : false #使用func生成气泡,效果不好,暂时不用. diff --git a/utils/blacklist.go b/utils/blacklist.go index 789109d..6d3b1f2 100644 --- a/utils/blacklist.go +++ b/utils/blacklist.go @@ -93,7 +93,7 @@ func WatchBlacklist(filePath string) { } // BlacklistIntercept 检查用户ID是否在黑名单中,如果在,则发送预设消息 -func BlacklistIntercept(message structs.OnebotGroupMessage) bool { +func BlacklistIntercept(message structs.OnebotGroupMessage, selfid string) bool { // 检查用户ID是否在黑名单中 if IsInBlacklist(strconv.FormatInt(message.UserID, 10)) { // 获取黑名单响应消息 @@ -102,12 +102,12 @@ func BlacklistIntercept(message structs.OnebotGroupMessage) bool { // 根据消息类型发送响应 if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - SendPrivateMessage(message.UserID, responseMessage) + SendPrivateMessage(message.UserID, responseMessage, selfid) } else { SendSSEPrivateMessage(message.UserID, responseMessage) } } else { - SendGroupMessage(message.GroupID, message.UserID, responseMessage) + SendGroupMessage(message.GroupID, message.UserID, responseMessage, selfid) } fmt.Printf("userid:[%v]这位用户在黑名单中,被拦截\n", message.UserID) diff --git a/utils/utils.go b/utils/utils.go index f730477..8bafa4e 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -21,6 +21,7 @@ import ( "github.com/hoshinonyaruko/gensokyo-llm/config" "github.com/hoshinonyaruko/gensokyo-llm/fmtf" "github.com/hoshinonyaruko/gensokyo-llm/hunyuan" + "github.com/hoshinonyaruko/gensokyo-llm/server" "github.com/hoshinonyaruko/gensokyo-llm/structs" ) @@ -142,7 +143,23 @@ func ExtractEventDetails(eventData map[string]interface{}) (string, structs.Usag return responseTextBuilder.String(), totalUsage } -func SendGroupMessage(groupID int64, userID int64, message string) error { +func SendGroupMessage(groupID int64, userID int64, message string, selfid string) error { + //TODO: 用userid作为了echo,在ws收到回调信息的时候,加入到全局撤回数组,AddMessageID,实现撤回 + if selfid != "" { + // 创建消息结构体 + msg := map[string]interface{}{ + "action": "send_group_msg", + "params": map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "message": message, + }, + "echo": userID, + } + + // 发送消息 + return server.SendMessageBySelfID(selfid, msg) + } // 获取基础URL baseURL := config.GetHttpPath() // 假设config.getHttpPath()返回基础URL @@ -198,7 +215,21 @@ func SendGroupMessage(groupID int64, userID int64, message string) error { return nil } -func SendPrivateMessage(UserID int64, message string) error { +func SendPrivateMessage(UserID int64, message string, selfid string) error { + if selfid != "" { + // 创建消息结构体 + msg := map[string]interface{}{ + "action": "send_private_msg", + "params": map[string]interface{}{ + "user_id": UserID, + "message": message, + }, + "echo": UserID, + } + + // 发送消息 + return server.SendMessageBySelfID(selfid, msg) + } // 获取基础URL baseURL := config.GetHttpPath() // 假设config.getHttpPath()返回基础URL @@ -610,7 +641,7 @@ func SendSSEPrivateRestoreMessage(userID int64, RestoreResponse string) { } // LanguageIntercept 检查文本语言,如果不在允许列表中,则返回 true 并发送消息 -func LanguageIntercept(text string, message structs.OnebotGroupMessage) bool { +func LanguageIntercept(text string, message structs.OnebotGroupMessage, selfid string) bool { info := whatlanggo.Detect(text) lang := whatlanggo.LangToString(info.Lang) fmtf.Printf("LanguageIntercept:%v\n", lang) @@ -630,12 +661,12 @@ func LanguageIntercept(text string, message structs.OnebotGroupMessage) bool { // 发送响应消息 if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - SendPrivateMessage(message.UserID, responseMessage) + SendPrivateMessage(message.UserID, responseMessage, selfid) } else { SendSSEPrivateMessage(message.UserID, responseMessage) } } else { - SendGroupMessage(message.GroupID, message.UserID, responseMessage) + SendGroupMessage(message.GroupID, message.UserID, responseMessage, selfid) } return true // 拦截 @@ -678,7 +709,7 @@ func FriendlyLanguageNameCN(lang whatlanggo.Lang) string { } // LengthIntercept 检查文本长度,如果超过最大长度,则返回 true 并发送消息 -func LengthIntercept(text string, message structs.OnebotGroupMessage) bool { +func LengthIntercept(text string, message structs.OnebotGroupMessage, selfid string) bool { maxLen := config.GetQuestionMaxLenth() if len(text) > maxLen { // 长度超出限制,获取并发送响应消息 @@ -687,12 +718,12 @@ func LengthIntercept(text string, message structs.OnebotGroupMessage) bool { // 根据消息类型发送响应 if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { - SendPrivateMessage(message.UserID, responseMessage) + SendPrivateMessage(message.UserID, responseMessage, selfid) } else { SendSSEPrivateMessage(message.UserID, responseMessage) } } else { - SendGroupMessage(message.GroupID, message.UserID, responseMessage) + SendGroupMessage(message.GroupID, message.UserID, responseMessage, selfid) } return true // 拦截 From 80840de0637cf3ae1263b6109584aa89f984bf51 Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Thu, 18 Apr 2024 21:29:48 +0800 Subject: [PATCH 09/13] Beta71 (#71) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 * beta67 * beta70 * beta71 --- applogic/chatgpt.go | 15 ++++++++++++++- applogic/ernie.go | 15 ++++++++++++++- applogic/gensokyo.go | 2 +- applogic/hunyuan.go | 13 +++++++++++++ applogic/rwkv.go | 13 +++++++++++++ config/config.go | 41 ++++++++++++++++++++++++++++++++--------- 6 files changed, 87 insertions(+), 12 deletions(-) diff --git a/applogic/chatgpt.go b/applogic/chatgpt.go index 6d3627c..74ba45c 100644 --- a/applogic/chatgpt.go +++ b/applogic/chatgpt.go @@ -32,6 +32,19 @@ func (app *App) ChatHandlerChatgpt(w http.ResponseWriter, r *http.Request) { return } + // 获取访问者的IP地址 + ip := r.RemoteAddr // 注意:这可能包含端口号 + ip = strings.Split(ip, ":")[0] // 去除端口号,仅保留IP地址 + + // 获取IP白名单 + whiteList := config.IPWhiteList() + + // 检查IP是否在白名单中 + if !utils.Contains(whiteList, ip) { + http.Error(w, "Access denied", http.StatusInternalServerError) + return + } + var msg structs.Message err := json.NewDecoder(r.Body).Decode(&msg) if err != nil { @@ -132,7 +145,7 @@ func (app *App) ChatHandlerChatgpt(w http.ResponseWriter, r *http.Request) { fmtf.Printf("CLOSE-AI上下文history:%v\n", history) // 构建请求到ChatGPT API - model := config.GetGptModel() + model := config.GetGptModel(promptstr) apiURL := config.GetGptApiPath() token := config.GetGptToken() diff --git a/applogic/ernie.go b/applogic/ernie.go index 943f523..2879a67 100644 --- a/applogic/ernie.go +++ b/applogic/ernie.go @@ -25,6 +25,19 @@ func (app *App) ChatHandlerErnie(w http.ResponseWriter, r *http.Request) { return } + // 获取访问者的IP地址 + ip := r.RemoteAddr // 注意:这可能包含端口号 + ip = strings.Split(ip, ":")[0] // 去除端口号,仅保留IP地址 + + // 获取IP白名单 + whiteList := config.IPWhiteList() + + // 检查IP是否在白名单中 + if !utils.Contains(whiteList, ip) { + http.Error(w, "Access denied", http.StatusInternalServerError) + return + } + var msg structs.Message err := json.NewDecoder(r.Body).Decode(&msg) if err != nil { @@ -162,7 +175,7 @@ func (app *App) ChatHandlerErnie(w http.ResponseWriter, r *http.Request) { // 获取访问凭证和API路径 accessToken := config.GetWenxinAccessToken() - apiPath := config.GetWenxinApiPath() + apiPath := config.GetWenxinApiPath(promptstr) // 构建请求URL url := fmtf.Sprintf("%s?access_token=%s", apiPath, accessToken) diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index cc3868c..ac2c84a 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -250,7 +250,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } // 缓存省钱部分 - if config.GetUseCache() { + if config.GetUseCache(promptstr) { //fmtf.Printf("计算向量: %v", vector) cacheThreshold := config.GetCacheThreshold() // 搜索相似文本和对应的ID diff --git a/applogic/hunyuan.go b/applogic/hunyuan.go index a5dfbf7..399adc0 100644 --- a/applogic/hunyuan.go +++ b/applogic/hunyuan.go @@ -22,6 +22,19 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) { return } + // 获取访问者的IP地址 + ip := r.RemoteAddr // 注意:这可能包含端口号 + ip = strings.Split(ip, ":")[0] // 去除端口号,仅保留IP地址 + + // 获取IP白名单 + whiteList := config.IPWhiteList() + + // 检查IP是否在白名单中 + if !utils.Contains(whiteList, ip) { + http.Error(w, "Access denied", http.StatusInternalServerError) + return + } + var msg structs.Message err := json.NewDecoder(r.Body).Decode(&msg) if err != nil { diff --git a/applogic/rwkv.go b/applogic/rwkv.go index 6dc00c1..7017e8c 100644 --- a/applogic/rwkv.go +++ b/applogic/rwkv.go @@ -31,6 +31,19 @@ func (app *App) ChatHandlerRwkv(w http.ResponseWriter, r *http.Request) { return } + // 获取访问者的IP地址 + ip := r.RemoteAddr // 注意:这可能包含端口号 + ip = strings.Split(ip, ":")[0] // 去除端口号,仅保留IP地址 + + // 获取IP白名单 + whiteList := config.IPWhiteList() + + // 检查IP是否在白名单中 + if !utils.Contains(whiteList, ip) { + http.Error(w, "Access denied", http.StatusInternalServerError) + return + } + var msg structs.Message err := json.NewDecoder(r.Body).Decode(&msg) if err != nil { diff --git a/config/config.go b/config/config.go index c9775cb..9f62ca3 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "log" "math/rand" "os" @@ -176,14 +177,15 @@ func GetWenxinApiPath(options ...string) string { mu.Lock() defer mu.Unlock() - if len(options) == 0 { + // 检查是否有参数传递进来,以及是否为空字符串 + if len(options) == 0 || options[0] == "" { if instance != nil { return instance.Settings.WenxinApiPath } return "0" } - // 处理传入的 basename + // 使用传入的 basename basename := options[0] apiPathInterface, err := prompt.GetSettingFromFilename(basename, "WenxinApiPath") if err != nil { @@ -215,14 +217,15 @@ func GetGptModel(options ...string) string { mu.Lock() defer mu.Unlock() - if len(options) == 0 { + // 检查是否有参数传递进来,以及是否为空字符串 + if len(options) == 0 || options[0] == "" { if instance != nil { return instance.Settings.GptModel } return "0" } - // 处理传入的 basename + // 使用传入的 basename basename := options[0] gptModelInterface, err := prompt.GetSettingFromFilename(basename, "GptModel") if err != nil { @@ -232,7 +235,7 @@ func GetGptModel(options ...string) string { gptModel, ok := gptModelInterface.(string) if !ok { - fmtf.Println("Type assertion failed for GptModel") + fmt.Println("Type assertion failed for GptModel") return "0" } @@ -615,13 +618,33 @@ func GetAntiPromptLimit() float64 { } // 获取UseCache -func GetUseCache() bool { +func GetUseCache(options ...string) bool { mu.Lock() defer mu.Unlock() - if instance != nil { - return instance.Settings.UseCache + + // 检查是否有参数传递进来,以及是否为空字符串 + if len(options) == 0 || options[0] == "" { + if instance != nil { + return instance.Settings.UseCache + } + return false } - return false + + // 使用传入的 basename + basename := options[0] + useCacheInterface, err := prompt.GetSettingFromFilename(basename, "UseCache") + if err != nil { + log.Println("Error retrieving UseCache:", err) + return false + } + + useCache, ok := useCacheInterface.(bool) + if !ok { + log.Println("Type assertion failed for UseCache") + return false + } + + return useCache } // 获取CacheThreshold From a2f056ee856f96deb3c83f04b61a3004c1150d04 Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Thu, 18 Apr 2024 22:03:35 +0800 Subject: [PATCH 10/13] Beta72 (#72) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 * beta67 * beta70 * beta71 * beta72 * beta72 --- config/config.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/config/config.go b/config/config.go index 9f62ca3..aadf864 100644 --- a/config/config.go +++ b/config/config.go @@ -177,6 +177,11 @@ func GetWenxinApiPath(options ...string) string { mu.Lock() defer mu.Unlock() + return getWenxinApiPathInternal(options...) +} + +// 内部逻辑执行函数,不处理锁,可以安全地递归调用 +func getWenxinApiPathInternal(options ...string) string { // 检查是否有参数传递进来,以及是否为空字符串 if len(options) == 0 || options[0] == "" { if instance != nil { @@ -194,9 +199,9 @@ func GetWenxinApiPath(options ...string) string { } apiPath, ok := apiPathInterface.(string) - if !ok { - log.Println("Type assertion failed for WenxinApiPath") - return "0" + if !ok || apiPath == "" { // 检查是否断言失败或结果为空字符串 + log.Println("Type assertion failed or empty string for WenxinApiPath, fetching default") + return getWenxinApiPathInternal() // 递归调用内部函数,不传递任何参数 } return apiPath @@ -216,7 +221,11 @@ func GetMaxTokenWenxin() int { func GetGptModel(options ...string) string { mu.Lock() defer mu.Unlock() + return getGptModelInternal(options...) +} +// 内部逻辑执行函数,不处理锁,可以安全地递归调用 +func getGptModelInternal(options ...string) string { // 检查是否有参数传递进来,以及是否为空字符串 if len(options) == 0 || options[0] == "" { if instance != nil { @@ -234,9 +243,9 @@ func GetGptModel(options ...string) string { } gptModel, ok := gptModelInterface.(string) - if !ok { - fmt.Println("Type assertion failed for GptModel") - return "0" + if !ok || gptModel == "" { // 检查是否断言失败或结果为空字符串 + fmt.Println("Type assertion failed or empty string for GptModel, fetching default") + return getGptModelInternal() // 递归调用内部函数,不传递任何参数 } return gptModel From 1f68210384de248c53c7e900a6c6a8d2ba44e4f9 Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:39:22 +0800 Subject: [PATCH 11/13] Beta74 (#73) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 * beta67 * beta70 * beta71 * beta72 * beta72 * beta74 --- applogic/gensokyo.go | 22 +++++++++---- config/config.go | 11 ++++++- structs/struct.go | 1 + utils/utils.go | 78 +++++++++++++++++++++++++++++++++++++++----- 4 files changed, 97 insertions(+), 15 deletions(-) diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index ac2c84a..096636a 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -8,6 +8,7 @@ import ( "io" "math/rand" "net/http" + "net/url" "strconv" "strings" @@ -359,14 +360,23 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { port := config.GetPort() portStr := fmtf.Sprintf(":%d", port) - var url string - //如果promptstr不等于空,添加到参数中 + // 初始化URL + baseURL := "http://127.0.0.1" + portStr + "/conversation" + + // 使用net/url包来构建和编码URL + urlParams := url.Values{} if promptstr != "" { - url = "http://127.0.0.1" + portStr + "/conversation?prompt=" + promptstr - } else { - url = "http://127.0.0.1" + portStr + "/conversation" + urlParams.Add("prompt", promptstr) } + // 将查询参数编码后附加到基本URL上 + fullURL := baseURL + if len(urlParams) > 0 { + fullURL += "?" + urlParams.Encode() + } + + fmtf.Printf("Generated URL:%v", fullURL) + // 请求模型还是使用原文请求 requestmsg := message.Message.(string) @@ -399,7 +409,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { return } - resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + resp, err := http.Post(fullURL, "application/json", bytes.NewBuffer(requestBody)) if err != nil { fmtf.Printf("Error sending request to conversation interface: %v\n", err) return diff --git a/config/config.go b/config/config.go index aadf864..6e63bc3 100644 --- a/config/config.go +++ b/config/config.go @@ -176,7 +176,6 @@ func GetWenxinAccessToken() string { func GetWenxinApiPath(options ...string) string { mu.Lock() defer mu.Unlock() - return getWenxinApiPathInternal(options...) } @@ -1133,3 +1132,13 @@ func GetWSServerToken() string { } return "" } + +// 获取PathToken +func GetPathToken() string { + mu.Lock() + defer mu.Unlock() + if instance != nil { + return instance.Settings.PathToken + } + return "" +} diff --git a/structs/struct.go b/structs/struct.go index 86af079..90ee12e 100644 --- a/structs/struct.go +++ b/structs/struct.go @@ -214,6 +214,7 @@ type Settings struct { UseSse bool `yaml:"useSse"` Port int `yaml:"port"` HttpPath string `yaml:"path"` + PathToken string `yaml:"pathToken"` SystemPrompt []string `yaml:"systemPrompt"` IPWhiteList []string `yaml:"iPWhiteList"` ApiType int `yaml:"apiType"` diff --git a/utils/utils.go b/utils/utils.go index 8bafa4e..1b0d8f0 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -8,6 +8,7 @@ import ( "io" "math/rand" "net/http" + "net/url" "os" "regexp" "strconv" @@ -164,7 +165,22 @@ func SendGroupMessage(groupID int64, userID int64, message string, selfid string baseURL := config.GetHttpPath() // 假设config.getHttpPath()返回基础URL // 构建完整的URL - url := baseURL + "/send_group_msg" + baseURL = baseURL + "/send_group_msg" + + // 获取PathToken并检查其是否为空 + pathToken := config.GetPathToken() + // 使用net/url包构建URL + u, err := url.Parse(baseURL) + if err != nil { + panic("URL parsing failed: " + err.Error()) + } + + // 添加access_token参数 + query := u.Query() + if pathToken != "" { + query.Set("access_token", pathToken) + } + u.RawQuery = query.Encode() if config.GetSensitiveModeType() == 1 { message = acnode.CheckWordOUT(message) @@ -182,7 +198,7 @@ func SendGroupMessage(groupID int64, userID int64, message string, selfid string } // 发送POST请求 - resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + resp, err := http.Post(u.String(), "application/json", bytes.NewBuffer(requestBody)) if err != nil { return fmtf.Errorf("failed to send POST request: %w", err) } @@ -234,7 +250,22 @@ func SendPrivateMessage(UserID int64, message string, selfid string) error { baseURL := config.GetHttpPath() // 假设config.getHttpPath()返回基础URL // 构建完整的URL - url := baseURL + "/send_private_msg" + baseURL = baseURL + "/send_private_msg" + + // 获取PathToken并检查其是否为空 + pathToken := config.GetPathToken() + // 使用net/url包构建URL + u, err := url.Parse(baseURL) + if err != nil { + panic("URL parsing failed: " + err.Error()) + } + + // 添加access_token参数 + query := u.Query() + if pathToken != "" { + query.Set("access_token", pathToken) + } + u.RawQuery = query.Encode() if config.GetSensitiveModeType() == 1 { message = acnode.CheckWordOUT(message) @@ -251,7 +282,7 @@ func SendPrivateMessage(UserID int64, message string, selfid string) error { } // 发送POST请求 - resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + resp, err := http.Post(u.String(), "application/json", bytes.NewBuffer(requestBody)) if err != nil { return fmtf.Errorf("failed to send POST request: %w", err) } @@ -289,7 +320,23 @@ func SendPrivateMessageSSE(UserID int64, message structs.InterfaceBody) error { baseURL := config.GetHttpPath() // 假设config.GetHttpPath()返回基础URL // 构建完整的URL - url := baseURL + "/send_private_msg_sse" + baseURL = baseURL + "/send_private_msg_sse" + + // 获取PathToken并检查其是否为空 + pathToken := config.GetPathToken() + // 使用net/url包构建URL + u, err := url.Parse(baseURL) + if err != nil { + panic("URL parsing failed: " + err.Error()) + } + + // 添加access_token参数 + query := u.Query() + if pathToken != "" { + query.Set("access_token", pathToken) + } + u.RawQuery = query.Encode() + // 调试用的 if config.GetPrintHanming() { fmtf.Printf("流式信息替换前:%v", message.Content) @@ -324,7 +371,7 @@ func SendPrivateMessageSSE(UserID int64, message structs.InterfaceBody) error { } // 发送POST请求 - resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + resp, err := http.Post(u.String(), "application/json", bytes.NewBuffer(requestBody)) if err != nil { return fmtf.Errorf("failed to send POST request: %w", err) } @@ -806,7 +853,22 @@ func DeleteLatestMessage(messageType string, id int64, userid int64) error { baseURL := config.GetHttpPath() // 假设config.GetHttpPath()返回基础URL // 构建完整的URL - url := baseURL + "/delete_msg" + baseURL = baseURL + "/delete_msg" + + // 获取PathToken并检查其是否为空 + pathToken := config.GetPathToken() + // 使用net/url包构建URL + u, err := url.Parse(baseURL) + if err != nil { + panic("URL parsing failed: " + err.Error()) + } + + // 添加access_token参数 + query := u.Query() + if pathToken != "" { + query.Set("access_token", pathToken) + } + u.RawQuery = query.Encode() // 获取最新的有效消息ID messageID, valid := GetLatestValidMessageID(userid) @@ -839,5 +901,5 @@ func DeleteLatestMessage(messageType string, id int64, userid int64) error { fmtf.Printf("发送撤回请求:%v", string(requestBodyBytes)) // 发送删除消息请求 - return sendDeleteRequest(url, requestBodyBytes) + return sendDeleteRequest(u.String(), requestBodyBytes) } From 6ad461ddd456a91699b54f2681bb80d7f659af8e Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Fri, 19 Apr 2024 19:12:09 +0800 Subject: [PATCH 12/13] Beta75 (#74) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 * beta67 * beta70 * beta71 * beta72 * beta72 * beta74 * beta75 --- utils/log.go | 36 ++++++++++++++++++++++++++++++++++++ utils/utils.go | 4 ++++ 2 files changed, 40 insertions(+) diff --git a/utils/log.go b/utils/log.go index ac77815..ae4da35 100644 --- a/utils/log.go +++ b/utils/log.go @@ -82,6 +82,9 @@ func processLogFile(filePath string) { if strings.Contains(line, "A完整信息:") { formatAndWriteAnswerLine(line, outputFile) } + if strings.Contains(line, "实际发送信息:") { + formatAndWriteAnswerLineV2(line, outputFile) + } } if err := scanner.Err(); err != nil { @@ -123,3 +126,36 @@ func formatAndWriteAnswerLine(line string, outputFile *os.File) { } } } + +func formatAndWriteAnswerLineV2(line string, outputFile *os.File) { + prefix := "实际发送信息:" + infoSuffix := "INFO:" // 设置截止字符串 + + currentIndex := 0 // 当前搜索的起始位置 + for { + // 从当前索引开始查找"实际发送信息:"的开始位置 + startIndex := strings.Index(line[currentIndex:], prefix) + if startIndex == -1 { + break // 如果没有找到,退出循环 + } + startIndex += currentIndex // 调整到全局索引 + + messageStart := startIndex + len(prefix) + endIndex := strings.Index(line[messageStart:], infoSuffix) // 查找"INFO:"的开始位置 + if endIndex == -1 { + break // 如果没有找到,退出循环 + } + endIndex += messageStart // 调整到全局索引 + + message := line[messageStart:endIndex] // 截取从"实际发送信息:"到"INFO:"之前的内容 + formattedLine := fmt.Sprintf("实际发送:%s\n", strings.TrimSpace(message)) // 格式化并去除前后空白字符 + + // 写入到输出文件 + _, err := outputFile.WriteString(formattedLine) + if err != nil { + fmt.Println("Error writing to output file:", err) + } + + currentIndex = endIndex // 更新currentIndex为当前endIndex,为下一次搜索做准备 + } +} diff --git a/utils/utils.go b/utils/utils.go index 1b0d8f0..3156755 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -193,6 +193,7 @@ func SendGroupMessage(groupID int64, userID int64, message string, selfid string "message": message, }) fmtf.Printf("发群信息请求:%v", string(requestBody)) + fmtf.Printf("实际发送信息:%v", message) if err != nil { return fmtf.Errorf("failed to marshal request body: %w", err) } @@ -280,6 +281,7 @@ func SendPrivateMessage(UserID int64, message string, selfid string) error { if err != nil { return fmtf.Errorf("failed to marshal request body: %w", err) } + fmtf.Printf("实际发送信息:%v", message) // 发送POST请求 resp, err := http.Post(u.String(), "application/json", bytes.NewBuffer(requestBody)) @@ -361,6 +363,8 @@ func SendPrivateMessageSSE(UserID int64, message structs.InterfaceBody) error { return nil } + fmtf.Printf("实际发送信息:%v", message.Content) + // 构造请求体,包括InterfaceBody requestBody, err := json.Marshal(map[string]interface{}{ "user_id": UserID, From 5c2b8f5832f8add01266b5d4c5c249e6d68cccae Mon Sep 17 00:00:00 2001 From: SanaeFox <36219542+Hoshinonyaruko@users.noreply.github.com> Date: Sat, 20 Apr 2024 23:05:16 +0800 Subject: [PATCH 13/13] Beta76 (#75) * beta1 * beta2 * beta3 * beta4 * beta5 * beta6 * beta7 * beta8 * beta9 * beta10 * beta11 * beta12 * beta13 * beta14 * beta15 * beta16 * beta16 * beta19 * beta20 * beta21 * beta22 * beta23 * beta24 * beta25 * beta27 * beta28 * beta29 * beta30 * beta31 * beta33 * beta34 * beta35 * beta36 * beta37 * beta38 * beta39 * beta40 * beta41 * beta42 * beta43 * beta44 * beta45 * beta45 * beta46 * beat48 * beta49 * beta50 * beta51 * beta52 * beta53 * beta54 * beta55 * beta57 * beta58 * beta59 * beta61 * beta62 * beta63 * beta63 * beta64 * beta65 * beta66 * beta67 * beta70 * beta71 * beta72 * beta72 * beta74 * beta75 * beta76 --- applogic/chatgpt.go | 10 ++++++++-- applogic/rwkv.go | 10 ++++++++-- utils/utils.go | 6 ++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/applogic/chatgpt.go b/applogic/chatgpt.go index 74ba45c..b41f7c4 100644 --- a/applogic/chatgpt.go +++ b/applogic/chatgpt.go @@ -361,9 +361,15 @@ func (app *App) ChatHandlerChatgpt(w http.ResponseWriter, r *http.Request) { newContent := "" for _, choice := range eventData.Choices { + // 如果新内容以旧内容开头 if strings.HasPrefix(choice.Delta.Content, lastResponseText) { - // 如果新内容以旧内容开头,剔除旧内容部分,只保留新增的部分 - newContent += choice.Delta.Content[len(lastResponseText):] + // 特殊情况:当新内容和旧内容完全相同时,处理逻辑应当与新内容不以旧内容开头时相同 + if choice.Delta.Content == lastResponseText { + newContent += choice.Delta.Content + } else { + // 剔除旧内容部分,只保留新增的部分 + newContent += choice.Delta.Content[len(lastResponseText):] + } } else { // 如果新内容不以旧内容开头,可能是并发情况下的新消息,直接使用新内容 newContent += choice.Delta.Content diff --git a/applogic/rwkv.go b/applogic/rwkv.go index 7017e8c..69e19a6 100644 --- a/applogic/rwkv.go +++ b/applogic/rwkv.go @@ -365,9 +365,15 @@ func (app *App) ChatHandlerRwkv(w http.ResponseWriter, r *http.Request) { newContent := "" for _, choice := range eventData.Choices { + // 如果新内容以旧内容开头 if strings.HasPrefix(choice.Delta.Content, lastResponseText) { - // 如果新内容以旧内容开头,剔除旧内容部分,只保留新增的部分 - newContent += choice.Delta.Content[len(lastResponseText):] + // 特殊情况:当新内容和旧内容完全相同时,处理逻辑应当与新内容不以旧内容开头时相同 + if choice.Delta.Content == lastResponseText { + newContent += choice.Delta.Content + } else { + // 剔除旧内容部分,只保留新增的部分 + newContent += choice.Delta.Content[len(lastResponseText):] + } } else { // 如果新内容不以旧内容开头,可能是并发情况下的新消息,直接使用新内容 newContent += choice.Delta.Content diff --git a/utils/utils.go b/utils/utils.go index 3156755..38808ac 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -186,6 +186,9 @@ func SendGroupMessage(groupID int64, userID int64, message string, selfid string message = acnode.CheckWordOUT(message) } + // 去除末尾的换行符 不去除会导致不好看 + message = removeTrailingCRLFs(message) + // 构造请求体 requestBody, err := json.Marshal(map[string]interface{}{ "group_id": groupID, @@ -272,6 +275,9 @@ func SendPrivateMessage(UserID int64, message string, selfid string) error { message = acnode.CheckWordOUT(message) } + // 去除末尾的换行符 不去除会导致不好看 + message = removeTrailingCRLFs(message) + // 构造请求体 requestBody, err := json.Marshal(map[string]interface{}{ "user_id": UserID,