Skip to content

Commit

Permalink
beta119
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoshinonyaruko committed May 17, 2024
1 parent a0d9d93 commit 81c5f60
Show file tree
Hide file tree
Showing 10 changed files with 1,257 additions and 156 deletions.
266 changes: 246 additions & 20 deletions applogic/hunyuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,22 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) {
}
}
} else {
history, err = prompt.GetMessagesFromFilename(promptstr)
if err != nil {
fmtf.Printf("prompt.GetMessagesFromFilename error: %v\n", err)
// 默认执行 正常提示词顺序
if !config.GetEnhancedQA(promptstr) {
history, err = prompt.GetMessagesFromFilename(promptstr)
if err != nil {
fmtf.Printf("prompt.GetMessagesFromFilename error: %v\n", err)
}
} else {
// 只获取系统提示词
systemMessage, err := prompt.GetFirstSystemMessageStruct(promptstr)
if err != nil {
fmt.Println("Error:", err)
} else {
// 如果找到system消息,将其添加到历史数组中
history = append(history, systemMessage)
fmt.Println("Added system message back to history.")
}
}
}

Expand All @@ -137,31 +150,38 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) {
// 获取系统级预埋的系统自定义QA对
systemHistory, err := prompt.GetMessagesExcludingSystem(promptstr)
if err != nil {
fmtf.Printf("Error getting system history: %v\n", err)
fmtf.Printf("Error getting system history: %v,promptstr[%v]\n", err, promptstr)
return
}

// 处理增强QA逻辑
if config.GetEnhancedQA(promptstr) {
// 确保系统历史与用户或助手历史数量一致,如果不足,则补足空的历史记录
// 因为最后一个成员让给当前QA,所以-1
if len(systemHistory)-2 > len(userHistory) {
difference := len(systemHistory) - len(userHistory)
systemHistory, err := prompt.GetMessagesExcludingSystem(promptstr)
if err != nil {
fmt.Printf("Error getting system history: %v\n", err)
return
}

// 计算需要补足的历史记录数量
neededHistoryCount := len(systemHistory) - 2 // 最后两条留给当前QA处理
if neededHistoryCount > len(userHistory) {
// 补足用户或助手历史
difference := neededHistoryCount - len(userHistory)
for i := 0; i < difference; i++ {
userHistory = append(userHistory, structs.Message{Text: "", Role: "user"})
userHistory = append(userHistory, structs.Message{Text: "", Role: "assistant"})
if i%2 != 0 {
userHistory = append(userHistory, structs.Message{Text: "", Role: "user"})
} else {
userHistory = append(userHistory, structs.Message{Text: "", Role: "assistant"})
}
}
}

// 如果系统历史中只有一个成员,跳过覆盖逻辑,留给后续处理
if len(systemHistory) > 1 {
// 将系统历史(除最后2个成员外)附加到相应的用户或助手历史上,采用倒序方式处理最近的记录
for i := 0; i < len(systemHistory)-2; i++ {
sysMsg := systemHistory[i]
index := len(userHistory) - len(systemHistory) + i
if index >= 0 && index < len(userHistory) && (userHistory[index].Role == "user" || userHistory[index].Role == "assistant") {
userHistory[index].Text += fmt.Sprintf(" (%s)", sysMsg.Text)
}
// 附加系统历史到用户或助手历史,除了最后两条
for i := 0; i < len(systemHistory)-2; i++ {
sysMsg := systemHistory[i]
index := len(userHistory) - neededHistoryCount + i
if index >= 0 && index < len(userHistory) {
userHistory[index].Text += fmt.Sprintf(" (%s)", sysMsg.Text)
}
}
} else {
Expand All @@ -181,6 +201,16 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) {
if config.GetHunyuanType() == 0 {
// 构建 hunyuan 请求
request := hunyuan.NewChatProRequest()
// 配置块
request.StreamModeration = new(bool)
*request.StreamModeration = config.GetHunyuanStreamModeration(promptstr)
request.Stream = new(bool)
*request.Stream = config.GetHunyuanStreamModeration(promptstr)
request.TopP = new(float64)
*request.TopP = config.GetTopPHunyuan(promptstr)
request.Temperature = new(float64)
*request.Temperature = config.GetTemperatureHunyuan(promptstr)

// 添加历史信息
for _, hMsg := range history {
content := hMsg.Text // 创建新变量
Expand Down Expand Up @@ -335,9 +365,20 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) {
fmtf.Fprintf(w, "data: %s\n\n", string(finalResponseJSON))
flusher.Flush()
}
} else {
} else if config.GetHunyuanType() == 1 {
// 构建 hunyuan 标准版请求
request := hunyuan.NewChatStdRequest()

// 配置块
request.StreamModeration = new(bool)
*request.StreamModeration = config.GetHunyuanStreamModeration(promptstr)
request.Stream = new(bool)
*request.Stream = config.GetHunyuanStreamModeration(promptstr)
request.TopP = new(float64)
*request.TopP = config.GetTopPHunyuan(promptstr)
request.Temperature = new(float64)
*request.Temperature = config.GetTemperatureHunyuan(promptstr)

// 添加历史信息
for _, hMsg := range history {
content := hMsg.Text // 创建新变量
Expand Down Expand Up @@ -493,6 +534,191 @@ func (app *App) ChatHandlerHunyuan(w http.ResponseWriter, r *http.Request) {
flusher.Flush()

}
} else if config.GetHunyuanType() == 2 || config.GetHunyuanType() == 3 || config.GetHunyuanType() == 4 || config.GetHunyuanType() == 5 {
// 构建 hunyuan 请求
request := hunyuan.NewChatCompletionsRequest()
// 添加历史信息
for _, hMsg := range history {
content := hMsg.Text // 创建新变量
role := hMsg.Role // 创建新变量
hunyuanMsg := hunyuan.Message{
Content: &content, // 引用新变量的地址
Role: &role, // 引用新变量的地址
}
request.Messages = append(request.Messages, &hunyuanMsg)
}

// 添加当前用户消息
currentUserContent := msg.Text // 创建新变量
currentUserRole := msg.Role // 创建新变量
currentUserMsg := hunyuan.Message{
Content: &currentUserContent, // 引用新变量的地址
Role: &currentUserRole, // 引用新变量的地址
}
request.Messages = append(request.Messages, &currentUserMsg)

// 获取HunyuanType并设置对应的Model
switch config.GetHunyuanType() {
case 2:
request.Model = new(string)
*request.Model = "hunyuan-lite"
case 3:
request.Model = new(string)
*request.Model = "hunyuan-standard"
case 4:
request.Model = new(string)
*request.Model = "hunyuan-standard-256K"
case 5:
request.Model = new(string)
*request.Model = "hunyuan-pro"
default:
request.Model = new(string)
*request.Model = "default-value"
}
fmtf.Printf("请求的混元模型类型:%v", *request.Model)
request.StreamModeration = new(bool)
*request.StreamModeration = config.GetHunyuanStreamModeration(promptstr)
request.Stream = new(bool)
*request.Stream = config.GetHunyuanStreamModeration(promptstr)
request.TopP = new(float64)
*request.TopP = config.GetTopPHunyuan(promptstr)
request.Temperature = new(float64)
*request.Temperature = config.GetTemperatureHunyuan(promptstr)

// 打印请求以进行调试
utils.PrintChatCompletionsRequest(request)

// 发送请求并获取响应
response, err := app.Client.ChatCompletions(request)
if err != nil {
http.Error(w, fmtf.Sprintf("hunyuanapi返回错误: %v", err), http.StatusInternalServerError)
return
}
if !config.GetuseSse(promptstr) {
// 解析响应
var responseTextBuilder strings.Builder
var totalUsage structs.UsageInfo
for event := range response.BaseSSEResponse.Events {
if event.Err != nil {
http.Error(w, fmtf.Sprintf("接收事件时发生错误: %v", event.Err), http.StatusInternalServerError)
return
}

// 解析事件数据
var eventData map[string]interface{}
if err := json.Unmarshal(event.Data, &eventData); err != nil {
http.Error(w, fmtf.Sprintf("解析事件数据出错: %v", err), http.StatusInternalServerError)
return
}

// 使用extractEventDetails函数提取信息
responseText, usageInfo := utils.ExtractEventDetails(eventData)
responseTextBuilder.WriteString(responseText)
totalUsage.PromptTokens += usageInfo.PromptTokens
totalUsage.CompletionTokens += usageInfo.CompletionTokens
}
// 现在responseTextBuilder中的内容是所有AI助手回复的组合
responseText := responseTextBuilder.String()

assistantMessageID, err := app.addMessage(structs.Message{
ConversationID: msg.ConversationID,
ParentMessageID: userMessageID,
Text: responseText,
Role: "assistant",
})

if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

// 构造响应
responseMap := map[string]interface{}{
"response": responseText,
"conversationId": msg.ConversationID,
"messageId": assistantMessageID,
"details": map[string]interface{}{
"usage": totalUsage,
},
}

json.NewEncoder(w).Encode(responseMap)
} else {
// 设置SSE相关的响应头部
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

var responseTextBuilder strings.Builder
var totalUsage structs.UsageInfo

for event := range response.BaseSSEResponse.Events {
if event.Err != nil {
fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("接收事件时发生错误: %v", event.Err))
flusher.Flush()
continue
}

// 解析事件数据和提取信息
var eventData map[string]interface{}
if err := json.Unmarshal(event.Data, &eventData); err != nil {
fmtf.Fprintf(w, "data: %s\n\n", fmtf.Sprintf("解析事件数据出错: %v", err))
flusher.Flush()
continue
}

responseText, usageInfo := utils.ExtractEventDetails(eventData)
responseTextBuilder.WriteString(responseText)
totalUsage.PromptTokens += usageInfo.PromptTokens
totalUsage.CompletionTokens += usageInfo.CompletionTokens

// 发送当前事件的响应数据,但不包含assistantMessageID
//fmtf.Printf("发送当前事件的响应数据,但不包含assistantMessageID\n")
tempResponseMap := map[string]interface{}{
"response": responseText,
"conversationId": msg.ConversationID,
"details": map[string]interface{}{
"usage": usageInfo,
},
}
tempResponseJSON, _ := json.Marshal(tempResponseMap)
fmtf.Fprintf(w, "data: %s\n\n", string(tempResponseJSON))
flusher.Flush()
}

// 处理完所有事件后,生成并发送包含assistantMessageID的最终响应
responseText := responseTextBuilder.String()
fmtf.Printf("处理完所有事件后,生成并发送包含assistantMessageID的最终响应:%v\n", responseText)
assistantMessageID, err := app.addMessage(structs.Message{
ConversationID: msg.ConversationID,
ParentMessageID: userMessageID,
Text: responseText,
Role: "assistant",
})

if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

finalResponseMap := map[string]interface{}{
"response": responseText,
"conversationId": msg.ConversationID,
"messageId": assistantMessageID,
"details": map[string]interface{}{
"usage": totalUsage,
},
}
finalResponseJSON, _ := json.Marshal(finalResponseMap)
fmtf.Fprintf(w, "data: %s\n\n", string(finalResponseJSON))
flusher.Flush()
}
}

}
Expand Down
39 changes: 23 additions & 16 deletions applogic/rwkv.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (app *App) ChatHandlerRwkv(w http.ResponseWriter, r *http.Request) {
}
} else {
// 只获取系统提示词
systemMessage, err := prompt.FindFirstSystemMessage(history)
systemMessage, err := prompt.GetFirstSystemMessageStruct(promptstr)
if err != nil {
fmt.Println("Error:", err)
} else {
Expand Down Expand Up @@ -162,25 +162,32 @@ func (app *App) ChatHandlerRwkv(w http.ResponseWriter, r *http.Request) {

// 处理增强QA逻辑
if config.GetEnhancedQA(promptstr) {
// 确保系统历史与用户或助手历史数量一致,如果不足,则补足空的历史记录
// 因为最后一个成员让给当前QA,所以-1
if len(systemHistory)-2 > len(userHistory) {
difference := len(systemHistory) - len(userHistory)
systemHistory, err := prompt.GetMessagesExcludingSystem(promptstr)
if err != nil {
fmt.Printf("Error getting system history: %v\n", err)
return
}

// 计算需要补足的历史记录数量
neededHistoryCount := len(systemHistory) - 2 // 最后两条留给当前QA处理
if neededHistoryCount > len(userHistory) {
// 补足用户或助手历史
difference := neededHistoryCount - len(userHistory)
for i := 0; i < difference; i++ {
userHistory = append(userHistory, structs.Message{Text: "", Role: "user"})
userHistory = append(userHistory, structs.Message{Text: "", Role: "assistant"})
if i%2 != 0 {
userHistory = append(userHistory, structs.Message{Text: "", Role: "user"})
} else {
userHistory = append(userHistory, structs.Message{Text: "", Role: "assistant"})
}
}
}

// 如果系统历史中只有一个成员,跳过覆盖逻辑,留给后续处理
if len(systemHistory) > 1 {
// 将系统历史(除最后2个成员外)附加到相应的用户或助手历史上,采用倒序方式处理最近的记录
for i := 0; i < len(systemHistory)-2; i++ {
sysMsg := systemHistory[i]
index := len(userHistory) - len(systemHistory) + i
if index >= 0 && index < len(userHistory) && (userHistory[index].Role == "user" || userHistory[index].Role == "assistant") {
userHistory[index].Text += fmt.Sprintf(" (%s)", sysMsg.Text)
}
// 附加系统历史到用户或助手历史,除了最后两条
for i := 0; i < len(systemHistory)-2; i++ {
sysMsg := systemHistory[i]
index := len(userHistory) - neededHistoryCount + i
if index >= 0 && index < len(userHistory) {
userHistory[index].Text += fmt.Sprintf(" (%s)", sysMsg.Text)
}
}
} else {
Expand Down
4 changes: 3 additions & 1 deletion applogic/tongyiqianwen.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (app *App) ChatHandlerTyqw(w http.ResponseWriter, r *http.Request) {
}
}
}
// TODO: msgid是空的开始第一句也要处理 插入

// 获取历史信息
if msg.ParentMessageID != "" {
userhistory, err := app.getHistory(msg.ConversationID, msg.ParentMessageID)
Expand All @@ -159,6 +159,8 @@ func (app *App) ChatHandlerTyqw(w http.ResponseWriter, r *http.Request) {
fmtf.Printf("Error getting system history: %v,promptstr[%v]\n", err, promptstr)
return
}

// 处理增强QA逻辑
if config.GetEnhancedQA(promptstr) {
systemHistory, err := prompt.GetMessagesExcludingSystem(promptstr)
if err != nil {
Expand Down
Loading

0 comments on commit 81c5f60

Please sign in to comment.