diff --git a/.gitignore b/.gitignore index d4db4b7..8ef422f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # specific -config.yml +*.yml *.sqlite *.txt diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index 64ee051..0e52fa4 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -100,7 +100,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { switch msg := message.Message.(type) { case string: // message.Message是一个string - fmtf.Printf("Received string message: %s\n", msg) + fmtf.Printf("userid:[%v]Received string message: %s\n", message.UserID, msg) //是否过滤群信息 if !config.GetGroupmessage() { @@ -125,6 +125,11 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } } + if utils.BlacklistIntercept(message) { + fmtf.Printf("userid:[%v]这位用户在黑名单中,被拦截", message.UserID) + return + } + //处理重置指令 if isResetCommand { fmtf.Println("处理重置操作") @@ -134,55 +139,7 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { if !config.GetUsePrivateSSE() { utils.SendPrivateMessage(message.UserID, RestoreResponse) } else { - - // 将字符串转换为rune切片,以正确处理多字节字符 - runes := []rune(RestoreResponse) - - // 计算每部分应该包含的rune数量 - partLength := len(runes) / 3 - - // 初始化用于存储分割结果的切片 - parts := make([]string, 3) - - // 按字符分割字符串 - for i := 0; i < 3; i++ { - if i < 2 { // 前两部分 - start := i * partLength - end := start + partLength - parts[i] = string(runes[start:end]) - } else { // 最后一部分,包含所有剩余的字符 - start := i * partLength - parts[i] = string(runes[start:]) - } - } - - // 开头 - messageSSE := structs.InterfaceBody{ - Content: parts[0], - State: 1, - } - - utils.SendPrivateMessageSSE(message.UserID, messageSSE) - - //中间 - messageSSE = structs.InterfaceBody{ - Content: parts[1], - State: 11, - } - utils.SendPrivateMessageSSE(message.UserID, messageSSE) - - // 从配置中获取promptkeyboard - promptkeyboard := config.GetPromptkeyboard() - - // 创建InterfaceBody结构体实例 - messageSSE = structs.InterfaceBody{ - Content: parts[2], // 假设空格字符串是期望的内容 - State: 20, // 假设的状态码 - PromptKeyboard: promptkeyboard, // 使用更新后的promptkeyboard - } - - // 发送SSE私人消息 - utils.SendPrivateMessageSSE(message.UserID, messageSSE) + utils.SendSSEPrivateRestoreMessage(message.UserID, RestoreResponse) } } else { utils.SendGroupMessage(message.GroupID, RestoreResponse) @@ -282,6 +239,10 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { responseText, err := app.GetRandomAnswer(similarTexts[0]) if err == nil { fmtf.Printf("缓存命中,Q:%v,A:%v\n", newmsg, responseText) + //加入上下文 + if app.AddSingleContext(message, responseText) { + fmtf.Printf("缓存加入上下文成功") + } // 发送响应消息 if message.RealMessageType == "group_private" || message.MessageType == "private" { if !config.GetUsePrivateSSE() { diff --git a/applogic/singlecontext.go b/applogic/singlecontext.go new file mode 100644 index 0000000..7064aed --- /dev/null +++ b/applogic/singlecontext.go @@ -0,0 +1,50 @@ +package applogic + +import ( + "time" + + "github.com/hoshinonyaruko/gensokyo-llm/fmtf" + "github.com/hoshinonyaruko/gensokyo-llm/structs" +) + +// 直接根据缓存来储存上下文 +// 其实向量缓存是一个单轮的QA缓存,因为这个项目很初步,很显然无法应对上下文场景的缓存 +// 通过这种方式,将每次缓存的内容也加入上下文,可能会有一个初步的效果提升. +func (app *App) AddSingleContext(message structs.OnebotGroupMessage, responseText string) bool { + // 请求conversation api 增加当前用户上下文 + conversationID, parentMessageID, err := app.handleUserContext(message.UserID) + if err != nil { + fmtf.Printf("error in AddSingleContext app.handleUserContex :%v", err) + return false + } + + // 构造用户消息并添加到上下文 + userMessage := structs.Message{ + ConversationID: conversationID, + ParentMessageID: parentMessageID, + Text: message.Message.(string), + Role: "user", + CreatedAt: time.Now().Format(time.RFC3339), + } + userMessageID, err := app.addMessage(userMessage) + if err != nil { + fmtf.Printf("error in AddSingleContext app.addMessage(userMessage) :%v", err) + return false + } + + // 构造助理消息并添加到上下文 + assistantMessage := structs.Message{ + ConversationID: conversationID, + ParentMessageID: userMessageID, + Text: responseText, + Role: "assistant", + CreatedAt: time.Now().Format(time.RFC3339), + } + _, err = app.addMessage(assistantMessage) + if err != nil { + fmtf.Printf("error in AddSingleContext app.addMessage(assistantMessage) :%v", err) + return false + } + + return true +} diff --git a/config/config.go b/config/config.go index b0a16c2..134439b 100644 --- a/config/config.go +++ b/config/config.go @@ -85,6 +85,7 @@ type Settings struct { LanguagesResponseMessages []string `yaml:"langResponseMessages"` QuestionMaxLenth int `yaml:"questionMaxLenth"` QmlResponseMessages []string `yaml:"qmlResponseMessages"` + BlacklistResponseMessages []string `yaml:"blacklistResponseMessages"` } // LoadConfig 从文件中加载配置并初始化单例配置 @@ -850,3 +851,19 @@ func GetQmlResponseMessages() string { } return "" // 如果列表为空,返回空字符串 } + +// BlacklistResponseMessages 返回语言拦截响应消息列表 +func GetBlacklistResponseMessages() string { + mu.Lock() + defer mu.Unlock() + if instance != nil && len(instance.Settings.BlacklistResponseMessages) > 0 { + // 如果列表中只有一个消息,直接返回这个消息 + if len(instance.Settings.BlacklistResponseMessages) == 1 { + return instance.Settings.BlacklistResponseMessages[0] + } + // 如果有多个消息,随机选择一个返回 + index := rand.Intn(len(instance.Settings.BlacklistResponseMessages)) + return instance.Settings.BlacklistResponseMessages[index] + } + return "" // 如果列表为空,返回空字符串 +} diff --git a/go.mod b/go.mod index 32c4542..97f89f9 100644 --- a/go.mod +++ b/go.mod @@ -10,3 +10,8 @@ require ( ) require github.com/abadojack/whatlanggo v1.0.1 + +require ( + github.com/fsnotify/fsnotify v1.7.0 // indirect + golang.org/x/sys v0.4.0 // indirect +) diff --git a/go.sum b/go.sum index 646cdb2..49a2f5e 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,15 @@ github.com/abadojack/whatlanggo v1.0.1 h1:19N6YogDnf71CTHm3Mp2qhYfkRdyvbgwWdd2EPxJRG4= github.com/abadojack/whatlanggo v1.0.1/go.mod h1:66WiQbSbJBIlOZMsvbKe5m6pzQovxCH9B/K8tQB2uoc= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.839 h1:VGVFNQDaUpDsPkJrh8I9qOxHZ1yj5sJmg9ngsUvTAHM= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.839/go.mod h1:r5r4xbfxSaeR04b166HGsBa/R4U3SueirEUpXGuw+Q0= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index 50fc0a9..21077dc 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "os" + "path/filepath" _ "github.com/mattn/go-sqlite3" // 只导入,作为驱动 @@ -19,25 +20,41 @@ import ( ) func main() { - testFlag := flag.Bool("test", false, "Run the test script,test.txt中的是虚拟信息,一行一条") + testFlag := flag.Bool("test", false, "Run the test script, test.txt中的是虚拟信息,一行一条") + ymlPath := flag.String("yml", "", "指定config.yml的路径") flag.Parse() - if _, err := os.Stat("config.yml"); os.IsNotExist(err) { + // 如果用户指定了-yml参数 + configFilePath := "config.yml" // 默认配置文件路径 + if *ymlPath != "" { + configFilePath = *ymlPath + } - // 将修改后的配置写入 config.yml - err = os.WriteFile("config.yml", []byte(template.ConfigTemplate), 0644) - if err != nil { - fmtf.Println("Error writing config.yml:", err) + // 检查配置文件是否存在 + if _, err := os.Stat(configFilePath); os.IsNotExist(err) { + if *ymlPath == "" { + // 用户没有指定-yml参数,按照默认行为处理 + err = os.WriteFile(configFilePath, []byte(template.ConfigTemplate), 0644) + if err != nil { + fmtf.Println("Error writing config.yml:", err) + return + } + fmtf.Println("请配置config.yml然后再次运行.") + fmtf.Print("按下 Enter 继续...") + bufio.NewReader(os.Stdin).ReadBytes('\n') + os.Exit(0) + } else { + // 用户指定了-yml参数,但指定的文件不存在 + fmtf.Println("指定的配置文件不存在:", *ymlPath) return } - - fmtf.Println("请配置config.yml然后再次运行.") - fmtf.Print("按下 Enter 继续...") - bufio.NewReader(os.Stdin).ReadBytes('\n') - os.Exit(0) + } else { + if *ymlPath != "" { + fmtf.Println("载入成功:", *ymlPath) + } } // 加载配置 - conf, err := config.LoadConfig("config.yml") + conf, err := config.LoadConfig(configFilePath) if err != nil { log.Fatalf("error: %v", err) } @@ -125,6 +142,21 @@ func main() { log.Printf("Unknown API type: %d", apiType) } + exePath, err := os.Executable() + if err != nil { + log.Fatal(err) + } + exeDir := filepath.Dir(exePath) + blacklistPath := filepath.Join(exeDir, "blacklist.txt") + + // 载入黑名单 + if err := utils.LoadBlacklist(blacklistPath); err != nil { + log.Fatalf("Failed to load blacklist: %v", err) + } + + // 启动黑名单文件变动监听 + go utils.WatchBlacklist(blacklistPath) + http.HandleFunc("/gensokyo", app.GensokyoHandler) port := config.GetPort() portStr := fmtf.Sprintf(":%d", port) diff --git a/template/config_template.go b/template/config_template.go index 01b4be0..c1d27c7 100644 --- a/template/config_template.go +++ b/template/config_template.go @@ -34,10 +34,11 @@ settings: promptkeyboard : [""] #临时的promptkeyboard超过3个则随机,后期会增加一个ai生成的方式,也会是ai-agent savelogs : false #本地落地日志. #语言过滤 - allowedLanguages : ["Cmn"] #根据自身安全实力,酌情过滤,cmn代表中文,小写字母,[]空数组代表不限制. + allowedLanguages : ["cmn"] #根据自身安全实力,酌情过滤,cmn代表中文,小写字母,[]空数组代表不限制. langResponseMessages : ["抱歉,我不会**这个语言呢","我不会**这门语言,请使用中文和我对话吧"] #定型文,**会自动替换为检测到的语言 questionMaxLenth : 100 #最大问题字数. 0代表不限制 qmlResponseMessages : ["问题太长了,缩短问题试试吧"] #最大问题长度回复. + blacklistResponseMessages : ["目前正在维护中...请稍候再试吧"] #黑名单回复,将userid丢入blacklist.txt 一行一个 #向量缓存(省钱-酌情调整参数)(进阶!!)需要有一定的调试能力,数据库调优能力,计算和数据测试能力. #不同种类的向量,维度和模型不同,所以请一开始决定好使用的向量,或者自行将数据库备份\对应,不同种类向量没有互相检索的能力。 diff --git a/utils/blacklist.go b/utils/blacklist.go new file mode 100644 index 0000000..b1f7423 --- /dev/null +++ b/utils/blacklist.go @@ -0,0 +1,117 @@ +package utils + +import ( + "bufio" + "fmt" + "log" + "os" + "strconv" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/hoshinonyaruko/gensokyo-llm/config" + "github.com/hoshinonyaruko/gensokyo-llm/structs" +) + +var blacklist = make(map[string]bool) +var mu sync.RWMutex + +// LoadBlacklist 从给定的文件路径载入黑名单ID。 +// 如果文件不存在,则创建该文件。 +func LoadBlacklist(filePath string) error { + file, err := os.Open(filePath) + if err != nil { + if os.IsNotExist(err) { + // 如果文件不存在,则创建一个新文件 + file, err = os.Create(filePath) + if err != nil { + return err // 创建文件失败,返回错误 + } + } else { + return err // 打开文件失败,且原因不是文件不存在 + } + } + defer file.Close() + + scanner := bufio.NewScanner(file) + mu.Lock() + defer mu.Unlock() + blacklist = make(map[string]bool) // 重置黑名单 + + for scanner.Scan() { + blacklist[scanner.Text()] = true + } + + return scanner.Err() +} + +// isInBlacklist 检查给定的ID是否在黑名单中。 +func IsInBlacklist(id string) bool { + mu.RLock() + defer mu.RUnlock() + _, exists := blacklist[id] + return exists +} + +// watchBlacklist 监控黑名单文件的变动并动态更新。 +func WatchBlacklist(filePath string) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Fatal("Error creating watcher:", err) + } + defer watcher.Close() + + done := make(chan bool) + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write == fsnotify.Write { + fmt.Println("Detected update to blacklist, reloading...") + err := LoadBlacklist(filePath) + if err != nil { + log.Printf("Error reloading blacklist: %v", err) + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Println("Watcher error:", err) + } + } + }() + + err = watcher.Add(filePath) + if err != nil { + log.Fatal("Error adding watcher to file:", err) + } + <-done // Keep the watcher alive +} + +// BlacklistIntercept 检查用户ID是否在黑名单中,如果在,则发送预设消息 +func BlacklistIntercept(message structs.OnebotGroupMessage) bool { + // 检查用户ID是否在黑名单中 + if IsInBlacklist(strconv.FormatInt(message.UserID, 10)) { + // 获取黑名单响应消息 + responseMessage := config.GetBlacklistResponseMessages() + + // 根据消息类型发送响应 + if message.RealMessageType == "group_private" || message.MessageType == "private" { + if !config.GetUsePrivateSSE() { + SendPrivateMessage(message.UserID, responseMessage) + } else { + SendSSEPrivateMessage(message.UserID, responseMessage) + } + } else { + SendGroupMessage(message.GroupID, responseMessage) + } + + fmt.Printf("userid:[%v]这位用户在黑名单中,被拦截\n", message.UserID) + return true // 拦截 + } + return false // 用户ID不在黑名单中,不拦截 +} diff --git a/utils/utils.go b/utils/utils.go index f5c66c3..ce14969 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -11,6 +11,7 @@ import ( "os" "regexp" "strings" + "time" "github.com/abadojack/whatlanggo" "github.com/google/uuid" @@ -303,8 +304,12 @@ func PostSensitiveMessages() error { results = append(results, string(responseBody)) } - // 将HTTP响应结果保存到test_result.txt文件中 - return os.WriteFile("test_result.txt", []byte(strings.Join(results, "\n")), 0644) + // 使用当前时间戳生成文件名 + currentTime := time.Now() + fileName := "test_result_" + currentTime.Format("20060102_150405") + ".txt" + + // 将HTTP响应结果保存到指定的文件中 + return os.WriteFile(fileName, []byte(strings.Join(results, "\n")), 0644) } // SendSSEPrivateMessage 分割并发送消息的核心逻辑,直接遍历字符串 @@ -430,6 +435,58 @@ func SendSSEPrivateSafeMessage(userID int64, saveresponse string) { SendPrivateMessageSSE(userID, messageSSE) } +// SendSSEPrivateRestoreMessage 分割并发送重置消息的核心逻辑,直接遍历字符串 +func SendSSEPrivateRestoreMessage(userID int64, RestoreResponse string) { + // 将字符串转换为rune切片,以正确处理多字节字符 + runes := []rune(RestoreResponse) + + // 计算每部分应该包含的rune数量 + partLength := len(runes) / 3 + + // 初始化用于存储分割结果的切片 + parts := make([]string, 3) + + // 按字符分割字符串 + for i := 0; i < 3; i++ { + if i < 2 { // 前两部分 + start := i * partLength + end := start + partLength + parts[i] = string(runes[start:end]) + } else { // 最后一部分,包含所有剩余的字符 + start := i * partLength + parts[i] = string(runes[start:]) + } + } + + // 开头 + messageSSE := structs.InterfaceBody{ + Content: parts[0], + State: 1, + } + + SendPrivateMessageSSE(userID, messageSSE) + + //中间 + messageSSE = structs.InterfaceBody{ + Content: parts[1], + State: 11, + } + SendPrivateMessageSSE(userID, messageSSE) + + // 从配置中获取promptkeyboard + promptkeyboard := config.GetPromptkeyboard() + + // 创建InterfaceBody结构体实例 + messageSSE = structs.InterfaceBody{ + Content: parts[2], // 假设空格字符串是期望的内容 + State: 20, // 假设的状态码 + PromptKeyboard: promptkeyboard, // 使用更新后的promptkeyboard + } + + // 发送SSE私人消息 + SendPrivateMessageSSE(userID, messageSSE) +} + // LanguageIntercept 检查文本语言,如果不在允许列表中,则返回 true 并发送消息 func LanguageIntercept(text string, message structs.OnebotGroupMessage) bool { info := whatlanggo.Detect(text)