Skip to content

Commit

Permalink
Merge b27f585 into 3b6c2ed
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoshinonyaruko authored Apr 25, 2024
2 parents 3b6c2ed + b27f585 commit 9bae728
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 6 deletions.
26 changes: 26 additions & 0 deletions applogic/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,32 @@ func (app *App) EnsureQATableExist() error {
return nil
}

func (app *App) EnsureCustomTableExist() error {
createTableSQL := `
CREATE TABLE IF NOT EXISTS custom_table (
user_id INTEGER PRIMARY KEY,
promptstr TEXT NOT NULL,
promptstr_stat INTEGER,
str1 TEXT,
str2 TEXT,
str3 TEXT,
str4 TEXT,
str5 TEXT,
str6 TEXT,
str7 TEXT,
str8 TEXT,
str9 TEXT,
str10 TEXT
);`

_, err := app.DB.Exec(createTableSQL)
if err != nil {
return fmt.Errorf("error creating custom_table: %w", err)
}

return nil
}

func (app *App) EnsureUserContextTableExists() error {
createTableSQL := `
CREATE TABLE IF NOT EXISTS user_context (
Expand Down
73 changes: 67 additions & 6 deletions applogic/gensokyo.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,68 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) {
return
}

// 读取URL参数 "prompt"
promptstr := r.URL.Query().Get("prompt")
if promptstr != "" {
// 使用 prompt 变量进行后续处理
fmt.Printf("收到prompt参数: %s\n", promptstr)
// 从数据库读取用户的剧情存档
CustomRecord, err := app.FetchCustomRecord(message.UserID)
if err != nil {
fmt.Printf("app.FetchCustomRecord 出错: %s\n", err)
}

var promptstr string
if CustomRecord != nil {
// 提示词参数
if CustomRecord.PromptStr == "" {
// 读取URL参数 "prompt"
promptstr = r.URL.Query().Get("prompt")
if promptstr != "" {
// 使用 prompt 变量进行后续处理
fmt.Printf("收到prompt参数: %s\n", promptstr)
}
} else {
promptstr = CustomRecord.PromptStr
fmt.Printf("刷新prompt参数: %s,newPromptStrStat:%d\n", promptstr, CustomRecord.PromptStrStat-1)
newPromptStrStat := CustomRecord.PromptStrStat - 1
err = app.InsertCustomTableRecord(message.UserID, promptstr, newPromptStrStat)
if err != nil {
fmt.Printf("app.InsertCustomTableRecord 出错: %s\n", err)
}
}

// 提示词之间流转 达到信号量
markType := config.GetPromptMarkType(promptstr)
if (markType == 0 || markType == 1) && (CustomRecord.PromptStrStat-1 <= 0) {
PromptMarks := config.GetPromptMarks(promptstr)
if len(PromptMarks) != 0 {
randomIndex := rand.Intn(len(PromptMarks))
newPromptStr := PromptMarks[randomIndex]

// 如果 markType 是 1,提取 "aaa" 部分
if markType == 1 {
parts := strings.Split(newPromptStr, ":")
if len(parts) > 0 {
newPromptStr = parts[0] // 取冒号前的部分作为新的提示词
}
}

// 刷新新的提示词给用户目前的状态
// 获取新的信号长度
PromptMarksLength := config.GetPromptMarksLength(newPromptStr)

app.InsertCustomTableRecord(message.UserID, newPromptStr, PromptMarksLength)
fmt.Printf("流转prompt参数: %s,newPromptStrStat:%d\n", newPromptStr, PromptMarksLength)
}
}
} else {
// 读取URL参数 "prompt"
promptstr = r.URL.Query().Get("prompt")
if promptstr != "" {
// 使用 prompt 变量进行后续处理
fmt.Printf("收到prompt参数: %s\n", promptstr)
}
PromptMarksLength := config.GetPromptMarksLength(promptstr)
err = app.InsertCustomTableRecord(message.UserID, promptstr, PromptMarksLength)
if err != nil {
fmt.Printf("app.InsertCustomTableRecord 出错: %s\n", err)
}
}

// 读取URL参数 "selfid"
Expand Down Expand Up @@ -168,6 +225,8 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) {
} else {
utils.SendGroupMessage(message.GroupID, message.UserID, RestoreResponse, selfid)
}
// 处理故事情节的重置
app.deleteCustomRecord(message.UserID)
return
}

Expand Down Expand Up @@ -519,7 +578,9 @@ func (app *App) GensokyoHandler(w http.ResponseWriter, r *http.Request) {
}
}
}
//清空之前加入缓存
// 处理故事模式
app.ProcessAnswer(message.UserID, response, promptstr)
// 清空之前加入缓存
// 缓存省钱部分 这里默认不被覆盖,如果主配置开了缓存,始终缓存.
if config.GetUseCache() {
if response != "" {
Expand Down
164 changes: 164 additions & 0 deletions applogic/promptstr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package applogic

import (
"database/sql"
"fmt"
"strconv"
"strings"

"github.com/hoshinonyaruko/gensokyo-llm/config"
"github.com/hoshinonyaruko/gensokyo-llm/structs"
)

func (app *App) InsertCustomTableRecord(userID int64, promptStr string, promptStrStat int, strs ...string) error {
// 构建 SQL 语句,使用 UPSERT 逻辑
sqlStr := `
INSERT INTO custom_table (user_id, promptstr, promptstr_stat, str1, str2, str3, str4, str5, str6, str7, str8, str9, str10)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
promptstr = excluded.promptstr,
promptstr_stat = excluded.promptstr_stat`

// 为每个非nil str构建更新部分
updateParts := make([]string, 10)
params := make([]interface{}, 13)
params[0] = userID
params[1] = promptStr
params[2] = promptStrStat

for i, str := range strs {
if i < 10 {
params[i+3] = str
if str != "" { // 只更新非空的str字段
fieldName := fmt.Sprintf("str%d", i+1)
updateParts[i] = fmt.Sprintf("%s = excluded.%s", fieldName, fieldName)
}
}
}

// 添加非空更新字段到SQL语句
nonEmptyUpdates := []string{}
for _, part := range updateParts {
if part != "" {
nonEmptyUpdates = append(nonEmptyUpdates, part)
}
}
if len(nonEmptyUpdates) > 0 {
sqlStr += ", " + strings.Join(nonEmptyUpdates, ", ")
}

sqlStr += ";" // 结束 SQL 语句

// 填充剩余的nil值
for j := len(strs) + 3; j < 13; j++ {
params[j] = nil
}

// 执行 SQL 操作
_, err := app.DB.Exec(sqlStr, params...)
if err != nil {
return fmt.Errorf("error inserting or updating record in custom_table: %w", err)
}

return nil
}

func (app *App) FetchCustomRecord(userID int64, fields ...string) (*structs.CustomRecord, error) {
// Default fields now include promptstr_stat
queryFields := "user_id, promptstr, promptstr_stat"
if len(fields) > 0 {
queryFields += ", " + strings.Join(fields, ", ")
}

// Construct the SQL query string
queryStr := fmt.Sprintf("SELECT %s FROM custom_table WHERE user_id = ?", queryFields)

row := app.DB.QueryRow(queryStr, userID)
var record structs.CustomRecord
// Initialize scan parameters including the new promptstr_stat
scanArgs := []interface{}{&record.UserID, &record.PromptStr, &record.PromptStrStat}
for i := 0; i < len(fields); i++ {
idx := fieldIndex(fields[i])
if idx >= 0 {
scanArgs = append(scanArgs, &record.Strs[idx])
}
}

err := row.Scan(scanArgs...)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil // No record found
}
return nil, fmt.Errorf("error scanning custom_table record: %w", err)
}

return &record, nil
}

func (app *App) deleteCustomRecord(userID int64) error {
deleteSQL := `DELETE FROM custom_table WHERE user_id = ?;`

_, err := app.DB.Exec(deleteSQL, userID)
if err != nil {
return fmt.Errorf("error deleting record from custom_table: %w", err)
}

return nil
}

// Helper function to get index from field name
func fieldIndex(field string) int {
if strings.HasPrefix(field, "str") && len(field) > 3 {
if idx, err := strconv.Atoi(field[3:]); err == nil && idx >= 1 && idx <= 10 {
return idx - 1
}
}
return -1
}

func (app *App) ProcessAnswer(userID int64, answer string, promptStr string) {
// 根据 promptStr 获取 PromptMarkType
markType := config.GetPromptMarkType(promptStr)

// 如果 markType 是 0,则不执行任何操作
if markType == 0 {
return
}

// 如果 markType 是 1,执行以下操作
if markType == 1 {
// 获取 PromptMarks
PromptMarks := config.GetPromptMarks(promptStr)

for _, mark := range PromptMarks {
// 提取冒号右侧的文本,并转换为数组
parts := strings.Split(mark, ":")
if len(parts) < 2 {
continue
}
codes := strings.Split(parts[1], "-")

// 检查 answer 是否包含数组中的任意一个成员
for _, code := range codes {
if strings.Contains(answer, code) {
// 当找到匹配时,构建新的 promptStr
newPromptStr := parts[0]

// 获取 PromptMarksLength
PromptMarksLength := config.GetPromptMarksLength(newPromptStr)

// 插入记录到自定义表
err := app.InsertCustomTableRecord(userID, newPromptStr, PromptMarksLength)
if err != nil {
fmt.Println("Error inserting custom table record:", err)
return
}

// 输出结果
fmt.Printf("type1=流转prompt参数: %s, newPromptStrStat: %d\n", newPromptStr, PromptMarksLength)
return // 停止循环
}
}
}
}
}
102 changes: 102 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1402,3 +1402,105 @@ func getStandardGptApiInternal(options ...string) bool {

return standardGptApi
}

// 获取 PromptMarkType
func GetPromptMarkType(options ...string) int {
mu.Lock()
defer mu.Unlock()
return getPromptMarkTypeInternal(options...)
}

// 内部逻辑执行函数,不处理锁,可以安全地递归调用
func getPromptMarkTypeInternal(options ...string) int {
// 检查是否有参数传递进来,以及是否为空字符串
if len(options) == 0 || options[0] == "" {
if instance != nil {
return instance.Settings.PromptMarkType
}
return 0 // 默认返回 0 或一个合理的默认值
}

// 使用传入的 basename
basename := options[0]
promptMarkTypeInterface, err := prompt.GetSettingFromFilename(basename, "PromptMarkType")
if err != nil {
log.Println("Error retrieving PromptMarkType:", err)
return getPromptMarkTypeInternal() // 递归调用内部函数,不传递任何参数
}

promptMarkType, ok := promptMarkTypeInterface.(int)
if !ok { // 检查是否断言失败
fmt.Println("Type assertion failed for PromptMarkType, fetching default")
return getPromptMarkTypeInternal() // 递归调用内部函数,不传递任何参数
}

return promptMarkType
}

// 获取 PromptMarksLength
func GetPromptMarksLength(options ...string) int {
mu.Lock()
defer mu.Unlock()
return getPromptMarksLengthInternal(options...)
}

// 内部逻辑执行函数,不处理锁,可以安全地递归调用
func getPromptMarksLengthInternal(options ...string) int {
// 检查是否有参数传递进来,以及是否为空字符串
if len(options) == 0 || options[0] == "" {
if instance != nil {
return instance.Settings.PromptMarksLength
}
return 0 // 默认返回 0 或一个合理的默认值
}

// 使用传入的 basename
basename := options[0]
promptMarksLengthInterface, err := prompt.GetSettingFromFilename(basename, "PromptMarksLength")
if err != nil {
log.Println("Error retrieving PromptMarksLength:", err)
return getPromptMarksLengthInternal() // 递归调用内部函数,不传递任何参数
}

promptMarksLength, ok := promptMarksLengthInterface.(int)
if !ok { // 检查是否断言失败
fmt.Println("Type assertion failed for PromptMarksLength, fetching default")
return getPromptMarksLengthInternal() // 递归调用内部函数,不传递任何参数
}

return promptMarksLength
}

// 获取 PromptMarks
func GetPromptMarks(options ...string) []string {
mu.Lock()
defer mu.Unlock()
return getPromptMarksInternal(options...)
}

// 内部逻辑执行函数,不处理锁,可以安全地递归调用
func getPromptMarksInternal(options ...string) []string {
// 检查是否有参数传递进来,以及是否为空字符串
if len(options) == 0 || options[0] == "" {
if instance != nil {
return instance.Settings.PromptMarks
}
return nil // 如果实例或设置未定义,返回nil
}

// 使用传入的 basename
basename := options[0]
promptMarksInterface, err := prompt.GetSettingFromFilename(basename, "PromptMarks")
if err != nil {
log.Println("Error retrieving PromptMarks:", err)
return getPromptMarksInternal() // 递归调用内部函数,不传递任何参数
}

promptMarks, ok := promptMarksInterface.([]string)
if !ok { // 检查是否断言失败
fmt.Println("Type assertion failed for PromptMarks, fetching default")
return getPromptMarksInternal() // 递归调用内部函数,不传递任何参数
}

return promptMarks
}
Loading

0 comments on commit 9bae728

Please sign in to comment.