Skip to content

Commit

Permalink
Merge 652cecd into 5a7a8c2
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoshinonyaruko authored Aug 18, 2024
2 parents 5a7a8c2 + 652cecd commit 73b7c07
Show file tree
Hide file tree
Showing 11 changed files with 3,044 additions and 92 deletions.
210 changes: 210 additions & 0 deletions applogic/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,32 @@ func (app *App) EnsureCustomTableExist() error {
return nil
}

func (app *App) EnsureCustomTableExistSP() error {
createTableSQL := `
CREATE TABLE IF NOT EXISTS custom_table (
user_id TEXT PRIMARY KEY,
promptstr TEXT NOT NULL,
promptstr_stat INTEGER,
str1 TEXT,
str2 TEXT,
str3 TEXT,
str4 TEXT,
str5 TEXT,
str6 TEXT,
str7 TEXT,
str8 TEXT,
str9 TEXT,
str10 TEXT
);`

_, err := app.DB.Exec(createTableSQL)
if err != nil {
return fmt.Errorf("error creating custom_table: %w", err)
}

return nil
}

func (app *App) EnsureUserContextTableExists() error {
createTableSQL := `
CREATE TABLE IF NOT EXISTS user_context (
Expand Down Expand Up @@ -237,6 +263,39 @@ func (app *App) EnsureUserContextTableExists() error {
return nil
}

func (app *App) EnsureUserContextTableExistsSP() error {
createTableSQL := `
CREATE TABLE IF NOT EXISTS user_context (
user_id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
parent_message_id TEXT
);`

_, err := app.DB.Exec(createTableSQL)
if err != nil {
return fmt.Errorf("error creating user_context table: %w", err)
}

// 为 conversation_id 创建索引
createConvIDIndexSQL := `CREATE INDEX IF NOT EXISTS idx_user_context_conversation_id ON user_context(conversation_id);`

_, err = app.DB.Exec(createConvIDIndexSQL)
if err != nil {
return fmt.Errorf("error creating index on user_context(conversation_id): %w", err)
}

// 为 parent_message_id 创建索引
// 只有当您需要根据 parent_message_id 进行查询时才添加此索引
createParentMsgIDIndexSQL := `CREATE INDEX IF NOT EXISTS idx_user_context_parent_message_id ON user_context(parent_message_id);`

_, err = app.DB.Exec(createParentMsgIDIndexSQL)
if err != nil {
return fmt.Errorf("error creating index on user_context(parent_message_id): %w", err)
}

return nil
}

func (app *App) handleUserContext(userID int64) (string, string, error) {
var conversationID, parentMessageID string

Expand Down Expand Up @@ -265,6 +324,34 @@ func (app *App) handleUserContext(userID int64) (string, string, error) {
return conversationID, parentMessageID, nil
}

func (app *App) handleUserContextSP(userID string) (string, string, error) {
var conversationID, parentMessageID string

// 检查用户上下文是否存在
query := `SELECT conversation_id, parent_message_id FROM user_context WHERE user_id = ?`
err := app.DB.QueryRow(query, userID).Scan(&conversationID, &parentMessageID)
if err != nil {
if err == sql.ErrNoRows {
// 用户上下文不存在,创建新的
conversationID = utils.GenerateUUID() // 假设generateUUID()是一个生成UUID的函数
parentMessageID = ""

// 插入新的用户上下文
insertQuery := `INSERT INTO user_context (user_id, conversation_id, parent_message_id) VALUES (?, ?, ?)`
_, err = app.DB.Exec(insertQuery, userID, conversationID, parentMessageID)
if err != nil {
return "", "", err
}
} else {
// 查询过程中出现了其他错误
return "", "", err
}
}

// 返回conversationID和parentMessageID
return conversationID, parentMessageID, nil
}

func (app *App) migrateUserToNewContext(userID int64) error {
// 生成新的conversationID
newConversationID := utils.GenerateUUID() // 假设generateUUID()是一个生成UUID的函数
Expand All @@ -279,6 +366,20 @@ func (app *App) migrateUserToNewContext(userID int64) error {
return nil
}

func (app *App) migrateUserToNewContextSP(userID string) error {
// 生成新的conversationID
newConversationID := utils.GenerateUUID() // 假设GenerateUUID()是一个生成UUID的函数

// 更新用户上下文
updateQuery := `UPDATE user_context SET conversation_id = ?, parent_message_id = '' WHERE user_id = ?`
_, err := app.DB.Exec(updateQuery, newConversationID, userID)
if err != nil {
return err
}

return nil
}

func (app *App) updateUserContext(userID int64, parentMessageID string) error {
updateQuery := `UPDATE user_context SET parent_message_id = ? WHERE user_id = ?`
_, err := app.DB.Exec(updateQuery, parentMessageID, userID)
Expand All @@ -288,6 +389,15 @@ func (app *App) updateUserContext(userID int64, parentMessageID string) error {
return nil
}

func (app *App) updateUserContextSP(userID string, parentMessageID string) error {
updateQuery := `UPDATE user_context SET parent_message_id = ? WHERE user_id = ?`
_, err := app.DB.Exec(updateQuery, parentMessageID, userID)
if err != nil {
return err
}
return nil
}

func (app *App) updateUserContextPro(userID int64, conversationID, parentMessageID string) error {
updateQuery := `
UPDATE user_context
Expand All @@ -300,6 +410,18 @@ func (app *App) updateUserContextPro(userID int64, conversationID, parentMessage
return nil
}

func (app *App) updateUserContextProSP(userID string, conversationID, parentMessageID string) error {
updateQuery := `
UPDATE user_context
SET conversation_id = ?, parent_message_id = ?
WHERE user_id = ?;`
_, err := app.DB.Exec(updateQuery, conversationID, parentMessageID, userID)
if err != nil {
return fmt.Errorf("error updating user context: %w", err)
}
return nil
}

func (app *App) getHistory(conversationID, parentMessageID string) ([]structs.Message, error) {
// 如果不开启上下文
if config.GetNoContext() {
Expand Down Expand Up @@ -385,6 +507,20 @@ func (app *App) AddUserMemory(userID int64, conversationID, parentMessageID, con
return app.ensureMemoryLimit(userID)
}

func (app *App) AddUserMemorySP(userID string, conversationID, parentMessageID, conversationTitle string) error {
// 插入新的记忆
insertMemorySQL := `
INSERT INTO user_memories (user_id, conversation_id, parent_message_id, conversation_title)
VALUES (?, ?, ?, ?);`
_, err := app.DB.Exec(insertMemorySQL, userID, conversationID, parentMessageID, conversationTitle)
if err != nil {
return fmt.Errorf("error inserting new memory: %w", err)
}

// 检查并保持记忆数量不超过10条
return app.ensureMemoryLimitSP(userID)
}

func (app *App) updateConversationTitle(userID int64, conversationID, parentMessageID, newTitle string) error {
// 定义SQL更新语句
updateQuery := `
Expand All @@ -401,6 +537,22 @@ func (app *App) updateConversationTitle(userID int64, conversationID, parentMess
return nil
}

func (app *App) updateConversationTitleSP(userID string, conversationID, parentMessageID, newTitle string) error {
// 定义SQL更新语句
updateQuery := `
UPDATE user_memories
SET conversation_title = ?
WHERE user_id = ? AND conversation_id = ? AND parent_message_id = ?;`

// 执行SQL更新操作
_, err := app.DB.Exec(updateQuery, newTitle, userID, conversationID, parentMessageID)
if err != nil {
return fmt.Errorf("error updating conversation title: %w", err)
}

return nil
}

func (app *App) ensureMemoryLimit(userID int64) error {
// 查询当前记忆总数
countQuerySQL := `SELECT COUNT(*) FROM user_memories WHERE user_id = ?;`
Expand Down Expand Up @@ -430,6 +582,35 @@ func (app *App) ensureMemoryLimit(userID int64) error {
return nil
}

func (app *App) ensureMemoryLimitSP(userID string) error {
// 查询当前记忆总数
countQuerySQL := `SELECT COUNT(*) FROM user_memories WHERE user_id = ?;`
var count int
row := app.DB.QueryRow(countQuerySQL, userID)
err := row.Scan(&count)
if err != nil {
return fmt.Errorf("error counting memories: %w", err)
}

// 如果记忆超过5条,则删除最旧的记忆
if count > 5 {
deleteOldestMemorySQL := `
DELETE FROM user_memories
WHERE memory_id IN (
SELECT memory_id FROM user_memories
WHERE user_id = ?
ORDER BY memory_id ASC
LIMIT ?
);`
_, err := app.DB.Exec(deleteOldestMemorySQL, userID, count-5)
if err != nil {
return fmt.Errorf("error deleting old memories: %w", err)
}
}

return nil
}

func (app *App) GetUserMemories(userID int64) ([]structs.Memory, error) {
// 定义查询SQL,获取所有相关的记忆
querySQL := `
Expand Down Expand Up @@ -458,3 +639,32 @@ func (app *App) GetUserMemories(userID int64) ([]structs.Memory, error) {

return memories, nil
}

func (app *App) GetUserMemoriesSP(userID string) ([]structs.Memory, error) {
// 定义查询SQL,获取所有相关的记忆
querySQL := `
SELECT conversation_id, parent_message_id, conversation_title
FROM user_memories
WHERE user_id = ?;
`
rows, err := app.DB.Query(querySQL, userID)
if err != nil {
return nil, fmt.Errorf("error querying user memories: %w", err)
}
defer rows.Close() // 确保关闭rows以释放数据库资源

var memories []structs.Memory
for rows.Next() {
var m structs.Memory
if err := rows.Scan(&m.ConversationID, &m.ParentMessageID, &m.ConversationTitle); err != nil {
return nil, fmt.Errorf("error scanning memory: %w", err)
}
memories = append(memories, m)
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error during rows iteration: %w", err)
}

return memories, nil
}
Loading

0 comments on commit 73b7c07

Please sign in to comment.