diff --git a/applogic/app.go b/applogic/app.go index 4593e70..c17a8d1 100644 --- a/applogic/app.go +++ b/applogic/app.go @@ -178,6 +178,32 @@ func (app *App) EnsureQATableExist() error { return nil } +func (app *App) EnsureCustomTableExist() error { + createTableSQL := ` + CREATE TABLE IF NOT EXISTS custom_table ( + user_id INTEGER PRIMARY KEY, + promptstr TEXT NOT NULL, + promptstr_stat INTEGER, + str1 TEXT, + str2 TEXT, + str3 TEXT, + str4 TEXT, + str5 TEXT, + str6 TEXT, + str7 TEXT, + str8 TEXT, + str9 TEXT, + str10 TEXT + );` + + _, err := app.DB.Exec(createTableSQL) + if err != nil { + return fmt.Errorf("error creating custom_table: %w", err) + } + + return nil +} + func (app *App) EnsureUserContextTableExists() error { createTableSQL := ` CREATE TABLE IF NOT EXISTS user_context ( diff --git a/applogic/gensokyo.go b/applogic/gensokyo.go index a76193e..8230dce 100644 --- a/applogic/gensokyo.go +++ b/applogic/gensokyo.go @@ -92,11 +92,68 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { return } - // 读取URL参数 "prompt" - promptstr := r.URL.Query().Get("prompt") - if promptstr != "" { - // 使用 prompt 变量进行后续处理 - fmt.Printf("收到prompt参数: %s\n", promptstr) + // 从数据库读取用户的剧情存档 + CustomRecord, err := app.FetchCustomRecord(message.UserID) + if err != nil { + fmt.Printf("app.FetchCustomRecord 出错: %s\n", err) + } + + var promptstr string + if CustomRecord != nil { + // 提示词参数 + if CustomRecord.PromptStr == "" { + // 读取URL参数 "prompt" + promptstr = r.URL.Query().Get("prompt") + if promptstr != "" { + // 使用 prompt 变量进行后续处理 + fmt.Printf("收到prompt参数: %s\n", promptstr) + } + } else { + promptstr = CustomRecord.PromptStr + fmt.Printf("刷新prompt参数: %s,newPromptStrStat:%d\n", promptstr, CustomRecord.PromptStrStat-1) + newPromptStrStat := CustomRecord.PromptStrStat - 1 + err = app.InsertCustomTableRecord(message.UserID, promptstr, newPromptStrStat) + if err != nil { + fmt.Printf("app.InsertCustomTableRecord 出错: %s\n", err) + } + } + + // 提示词之间流转 达到信号量 + markType := config.GetPromptMarkType(promptstr) + if (markType == 0 || markType == 1) && (CustomRecord.PromptStrStat-1 <= 0) { + PromptMarks := config.GetPromptMarks(promptstr) + if len(PromptMarks) != 0 { + randomIndex := rand.Intn(len(PromptMarks)) + newPromptStr := PromptMarks[randomIndex] + + // 如果 markType 是 1,提取 "aaa" 部分 + if markType == 1 { + parts := strings.Split(newPromptStr, ":") + if len(parts) > 0 { + newPromptStr = parts[0] // 取冒号前的部分作为新的提示词 + } + } + + // 刷新新的提示词给用户目前的状态 + // 获取新的信号长度 + PromptMarksLength := config.GetPromptMarksLength(newPromptStr) + + app.InsertCustomTableRecord(message.UserID, newPromptStr, PromptMarksLength) + fmt.Printf("流转prompt参数: %s,newPromptStrStat:%d\n", newPromptStr, PromptMarksLength) + } + } + } else { + // 读取URL参数 "prompt" + promptstr = r.URL.Query().Get("prompt") + if promptstr != "" { + // 使用 prompt 变量进行后续处理 + fmt.Printf("收到prompt参数: %s\n", promptstr) + } + PromptMarksLength := config.GetPromptMarksLength(promptstr) + err = app.InsertCustomTableRecord(message.UserID, promptstr, PromptMarksLength) + if err != nil { + fmt.Printf("app.InsertCustomTableRecord 出错: %s\n", err) + } } // 读取URL参数 "selfid" @@ -168,6 +225,8 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } else { utils.SendGroupMessage(message.GroupID, message.UserID, RestoreResponse, selfid) } + // 处理故事情节的重置 + app.deleteCustomRecord(message.UserID) return } @@ -519,7 +578,9 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) { } } } - //清空之前加入缓存 + // 处理故事模式 + app.ProcessAnswer(message.UserID, response, promptstr) + // 清空之前加入缓存 // 缓存省钱部分 这里默认不被覆盖,如果主配置开了缓存,始终缓存. if config.GetUseCache() { if response != "" { diff --git a/applogic/promptstr.go b/applogic/promptstr.go new file mode 100644 index 0000000..8832df0 --- /dev/null +++ b/applogic/promptstr.go @@ -0,0 +1,164 @@ +package applogic + +import ( + "database/sql" + "fmt" + "strconv" + "strings" + + "github.com/hoshinonyaruko/gensokyo-llm/config" + "github.com/hoshinonyaruko/gensokyo-llm/structs" +) + +func (app *App) InsertCustomTableRecord(userID int64, promptStr string, promptStrStat int, strs ...string) error { + // 构建 SQL 语句,使用 UPSERT 逻辑 + sqlStr := ` + INSERT INTO custom_table (user_id, promptstr, promptstr_stat, str1, str2, str3, str4, str5, str6, str7, str8, str9, str10) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(user_id) DO UPDATE SET + promptstr = excluded.promptstr, + promptstr_stat = excluded.promptstr_stat` + + // 为每个非nil str构建更新部分 + updateParts := make([]string, 10) + params := make([]interface{}, 13) + params[0] = userID + params[1] = promptStr + params[2] = promptStrStat + + for i, str := range strs { + if i < 10 { + params[i+3] = str + if str != "" { // 只更新非空的str字段 + fieldName := fmt.Sprintf("str%d", i+1) + updateParts[i] = fmt.Sprintf("%s = excluded.%s", fieldName, fieldName) + } + } + } + + // 添加非空更新字段到SQL语句 + nonEmptyUpdates := []string{} + for _, part := range updateParts { + if part != "" { + nonEmptyUpdates = append(nonEmptyUpdates, part) + } + } + if len(nonEmptyUpdates) > 0 { + sqlStr += ", " + strings.Join(nonEmptyUpdates, ", ") + } + + sqlStr += ";" // 结束 SQL 语句 + + // 填充剩余的nil值 + for j := len(strs) + 3; j < 13; j++ { + params[j] = nil + } + + // 执行 SQL 操作 + _, err := app.DB.Exec(sqlStr, params...) + if err != nil { + return fmt.Errorf("error inserting or updating record in custom_table: %w", err) + } + + return nil +} + +func (app *App) FetchCustomRecord(userID int64, fields ...string) (*structs.CustomRecord, error) { + // Default fields now include promptstr_stat + queryFields := "user_id, promptstr, promptstr_stat" + if len(fields) > 0 { + queryFields += ", " + strings.Join(fields, ", ") + } + + // Construct the SQL query string + queryStr := fmt.Sprintf("SELECT %s FROM custom_table WHERE user_id = ?", queryFields) + + row := app.DB.QueryRow(queryStr, userID) + var record structs.CustomRecord + // Initialize scan parameters including the new promptstr_stat + scanArgs := []interface{}{&record.UserID, &record.PromptStr, &record.PromptStrStat} + for i := 0; i < len(fields); i++ { + idx := fieldIndex(fields[i]) + if idx >= 0 { + scanArgs = append(scanArgs, &record.Strs[idx]) + } + } + + err := row.Scan(scanArgs...) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil // No record found + } + return nil, fmt.Errorf("error scanning custom_table record: %w", err) + } + + return &record, nil +} + +func (app *App) deleteCustomRecord(userID int64) error { + deleteSQL := `DELETE FROM custom_table WHERE user_id = ?;` + + _, err := app.DB.Exec(deleteSQL, userID) + if err != nil { + return fmt.Errorf("error deleting record from custom_table: %w", err) + } + + return nil +} + +// Helper function to get index from field name +func fieldIndex(field string) int { + if strings.HasPrefix(field, "str") && len(field) > 3 { + if idx, err := strconv.Atoi(field[3:]); err == nil && idx >= 1 && idx <= 10 { + return idx - 1 + } + } + return -1 +} + +func (app *App) ProcessAnswer(userID int64, answer string, promptStr string) { + // 根据 promptStr 获取 PromptMarkType + markType := config.GetPromptMarkType(promptStr) + + // 如果 markType 是 0,则不执行任何操作 + if markType == 0 { + return + } + + // 如果 markType 是 1,执行以下操作 + if markType == 1 { + // 获取 PromptMarks + PromptMarks := config.GetPromptMarks(promptStr) + + for _, mark := range PromptMarks { + // 提取冒号右侧的文本,并转换为数组 + parts := strings.Split(mark, ":") + if len(parts) < 2 { + continue + } + codes := strings.Split(parts[1], "-") + + // 检查 answer 是否包含数组中的任意一个成员 + for _, code := range codes { + if strings.Contains(answer, code) { + // 当找到匹配时,构建新的 promptStr + newPromptStr := parts[0] + + // 获取 PromptMarksLength + PromptMarksLength := config.GetPromptMarksLength(newPromptStr) + + // 插入记录到自定义表 + err := app.InsertCustomTableRecord(userID, newPromptStr, PromptMarksLength) + if err != nil { + fmt.Println("Error inserting custom table record:", err) + return + } + + // 输出结果 + fmt.Printf("type1=流转prompt参数: %s, newPromptStrStat: %d\n", newPromptStr, PromptMarksLength) + return // 停止循环 + } + } + } + } +} diff --git a/config/config.go b/config/config.go index 50983e2..a77a06b 100644 --- a/config/config.go +++ b/config/config.go @@ -1402,3 +1402,105 @@ func getStandardGptApiInternal(options ...string) bool { return standardGptApi } + +// 获取 PromptMarkType +func GetPromptMarkType(options ...string) int { + mu.Lock() + defer mu.Unlock() + return getPromptMarkTypeInternal(options...) +} + +// 内部逻辑执行函数,不处理锁,可以安全地递归调用 +func getPromptMarkTypeInternal(options ...string) int { + // 检查是否有参数传递进来,以及是否为空字符串 + if len(options) == 0 || options[0] == "" { + if instance != nil { + return instance.Settings.PromptMarkType + } + return 0 // 默认返回 0 或一个合理的默认值 + } + + // 使用传入的 basename + basename := options[0] + promptMarkTypeInterface, err := prompt.GetSettingFromFilename(basename, "PromptMarkType") + if err != nil { + log.Println("Error retrieving PromptMarkType:", err) + return getPromptMarkTypeInternal() // 递归调用内部函数,不传递任何参数 + } + + promptMarkType, ok := promptMarkTypeInterface.(int) + if !ok { // 检查是否断言失败 + fmt.Println("Type assertion failed for PromptMarkType, fetching default") + return getPromptMarkTypeInternal() // 递归调用内部函数,不传递任何参数 + } + + return promptMarkType +} + +// 获取 PromptMarksLength +func GetPromptMarksLength(options ...string) int { + mu.Lock() + defer mu.Unlock() + return getPromptMarksLengthInternal(options...) +} + +// 内部逻辑执行函数,不处理锁,可以安全地递归调用 +func getPromptMarksLengthInternal(options ...string) int { + // 检查是否有参数传递进来,以及是否为空字符串 + if len(options) == 0 || options[0] == "" { + if instance != nil { + return instance.Settings.PromptMarksLength + } + return 0 // 默认返回 0 或一个合理的默认值 + } + + // 使用传入的 basename + basename := options[0] + promptMarksLengthInterface, err := prompt.GetSettingFromFilename(basename, "PromptMarksLength") + if err != nil { + log.Println("Error retrieving PromptMarksLength:", err) + return getPromptMarksLengthInternal() // 递归调用内部函数,不传递任何参数 + } + + promptMarksLength, ok := promptMarksLengthInterface.(int) + if !ok { // 检查是否断言失败 + fmt.Println("Type assertion failed for PromptMarksLength, fetching default") + return getPromptMarksLengthInternal() // 递归调用内部函数,不传递任何参数 + } + + return promptMarksLength +} + +// 获取 PromptMarks +func GetPromptMarks(options ...string) []string { + mu.Lock() + defer mu.Unlock() + return getPromptMarksInternal(options...) +} + +// 内部逻辑执行函数,不处理锁,可以安全地递归调用 +func getPromptMarksInternal(options ...string) []string { + // 检查是否有参数传递进来,以及是否为空字符串 + if len(options) == 0 || options[0] == "" { + if instance != nil { + return instance.Settings.PromptMarks + } + return nil // 如果实例或设置未定义,返回nil + } + + // 使用传入的 basename + basename := options[0] + promptMarksInterface, err := prompt.GetSettingFromFilename(basename, "PromptMarks") + if err != nil { + log.Println("Error retrieving PromptMarks:", err) + return getPromptMarksInternal() // 递归调用内部函数,不传递任何参数 + } + + promptMarks, ok := promptMarksInterface.([]string) + if !ok { // 检查是否断言失败 + fmt.Println("Type assertion failed for PromptMarks, fetching default") + return getPromptMarksInternal() // 递归调用内部函数,不传递任何参数 + } + + return promptMarks +} diff --git a/main.go b/main.go index 918beec..2cdc5da 100644 --- a/main.go +++ b/main.go @@ -125,6 +125,12 @@ func main() { log.Fatalf("Failed to ensure SensitiveWordsTable table exists: %v", err) } + // 故事模式存档 + err = app.EnsureCustomTableExist() + if err != nil { + log.Fatalf("Failed to ensure CustomTableExist table exists: %v", err) + } + // 加载 拦截词 err = app.ProcessSensitiveWords() if err != nil { diff --git a/structs/struct.go b/structs/struct.go index b406550..acc8aca 100644 --- a/structs/struct.go +++ b/structs/struct.go @@ -314,6 +314,10 @@ type Settings struct { WSServerToken string `yaml:"wsServerToken"` WSPath string `yaml:"wsPath"` + + PromptMarkType int `yaml:"promptMarkType"` + PromptMarksLength int `yaml:"promptMarksLength"` + PromptMarks []string `yaml:"promptMarks"` } type MetaEvent struct { @@ -370,3 +374,10 @@ type OnebotActionMessage struct { Params map[string]interface{} `json:"params"` Echo interface{} `json:"echo,omitempty"` } + +type CustomRecord struct { + UserID int64 + PromptStr string + PromptStrStat int // New integer field for storing promptstr_stat + Strs [10]string // Array to store str1 to str10 +} diff --git a/template/config_template.go b/template/config_template.go index 1478689..c2d6b17 100644 --- a/template/config_template.go +++ b/template/config_template.go @@ -78,6 +78,11 @@ settings: vectorSensitiveFilter : false #是否开启向量拦截词,请放在同目录下的vector_sensitive.txt中 一行一个,可以是句子。 命令行参数 -test 会用test.exe中的内容跑测试脚本。 vertorSensitiveThreshold : 200 #汉明距离,满足距离代表向量含义相近,可给出拦截. + #多配置覆盖,切换条件等设置 + promptMarkType : 0 #0=多个里随机一个,promptMarksLength达到时触发 1=按条件触发,promptMarksLength达到时也触发.条件格式aaaa:xxx-xxx-xxxx-xxx,aaa是promptmark中的yml,xxx是标记,识别到用户和模型说出标记就会触发这个支线(需要自行写好提示词,让llm能根据条件说出.) + promptMarksLength : 2 #promptMarkType=0时,多少轮开始切换上下文. + promptMarks : [] #prompts文件夹内的文件,一个代表一个配置文件,当promptMarkType为0是,直接是prompts文件夹内的yml名字,当为1时,格式在上面. + #混元配置项 secretId : "" #腾讯云账号(右上角)-访问管理-访问密钥,生成获取 secretKey : ""