kefu/ws/chatgpt.go

345 lines
10 KiB
Go
Raw Normal View History

2024-12-10 02:50:12 +00:00
package ws
import (
"encoding/json"
"errors"
"fmt"
"github.com/tidwall/gjson"
"kefu/lib"
"kefu/models"
"kefu/tools"
"log"
)
// GptRevQuestion ("reverse question") asks the configured LLM to produce a
// follow-up question for the visitor from recent chat history plus question,
// then sends it to the visitor, forwards it to agents and stores it.
//
// It does nothing and returns "" unless chatGPTRevQuestion, chatGPTUrl and
// chatGPTSecret are all configured for entId. The chatGPTRevQuestion prompt is
// used as the system message and is also prepended to the final user message.
// The model defaults to "gpt-3.5-turbo-16k" when BigModelName is unset. One page
// of QdrantHistoryNum (default 5) stored messages for the visitor is replayed as
// context, excluding the row at index 0.
//
// Returns the generated text, or "" when the model produced nothing.
func GptRevQuestion(entId, visitorId string, kefuInfo models.User, question string) string {
	config := models.GetEntConfigsMap(entId, "chatGPTUrl", "chatGPTSecret", "chatGPTRevQuestion", "QdrantHistoryNum", "BigModelName")
	if config["chatGPTRevQuestion"] == "" || config["chatGPTUrl"] == "" || config["chatGPTSecret"] == "" {
		return ""
	}
	bigModelName := config["BigModelName"]
	if bigModelName == "" {
		bigModelName = "gpt-3.5-turbo-16k"
	}
	gptMessages := []lib.Gpt3Dot5Message{{
		Role:    "system",
		Content: config["chatGPTRevQuestion"],
	}}

	num := 5
	if config["QdrantHistoryNum"] != "" {
		num = tools.Str2Int(config["QdrantHistoryNum"])
	}
	if num < 0 {
		// A negative value would wrap around to an enormous uint page size.
		num = 5
	}
	messages := models.FindMessageByQueryPage(1, uint(num), "visitor_id = ?", visitorId)
	// Replay from the end of the page towards index 1. Index 0 is skipped:
	// presumably it is the newest row, i.e. the question itself, which is sent
	// separately below (verify against FindMessageByQueryPage ordering).
	for i := len(messages) - 1; i >= 1; i-- {
		role := "assistant"
		if messages[i].MesType == "visitor" {
			role = "user"
		}
		gptMessages = append(gptMessages, lib.Gpt3Dot5Message{Role: role, Content: messages[i].Content})
	}
	gptMessages = append(gptMessages, lib.Gpt3Dot5Message{
		Role:    "user",
		Content: config["chatGPTRevQuestion"] + question,
	})

	gpt := lib.NewChatGptTool(config["chatGPTUrl"], config["chatGPTSecret"])
	log.Println("反向提问调用GPT请求参数", gptMessages)
	content, err := gpt.ChatGPT3Dot5Turbo(gptMessages, bigModelName)
	if err != nil {
		// Best effort: log and fall through; an empty reply is dropped below.
		log.Println("反向提问调用GPT失败", err)
	}
	if content == "" {
		return ""
	}
	VisitorMessage(visitorId, content, kefuInfo)
	go KefuMessage(visitorId, content, kefuInfo)
	models.CreateMessage(kefuInfo.Name, visitorId, content, "kefu", entId, "read")
	return content
}
// Gpt3Knowledge answers a visitor message with the enterprise's configured LLM,
// optionally grounded in a Qdrant knowledge-base collection, and streams the
// reply to the visitor as it is generated.
//
// Despite the name, the model comes from the BigModelName setting (default
// "gpt-3.5-turbo-16k"). The special value "coze" sends the request to a Coze
// bot instead, using chatGPTUrl as the bot ID and chatGPTSecret as the API key.
// That path makes a single call with no knowledge-base lookup and no streaming.
//
// Parameters:
//   - entId: enterprise whose settings are read.
//   - visitorId: visitor being answered; also keys the chat-history lookup.
//   - kefuInfo: identity the reply is sent as; Nickname and Avator are replaced
//     by RobotName / RobotAvator when those are configured.
//   - content: the visitor's question (chatGPTPrompt, if set, is prepended).
//
// Return value:
//   - "" when AI replies are disabled (QdrantAIStatus != "true"), when
//     chatGPTUrl or chatGPTSecret is missing, or when building the prompt or
//     opening the stream fails.
//   - A fixed refusal message when Baidu moderation credentials are set and the
//     question is judged non-compliant; the model is not called in that case.
//
// The reply is persisted and forwarded to agents only if the visitor was in
// ClientList when the call started.
func Gpt3Knowledge(entId, visitorId string, kefuInfo models.User, content string) string {
	// chatGPTHistory is fetched but not used below.
	config := models.GetEntConfigsMap(entId, "chatGPTUrl", "chatGPTSecret", "QdrantAIStatus", "QdrantAICollect", "chatGPTSystem", "chatGPTPrompt",
		"RobotName", "RobotAvator", "QdrantScore", "QdrantHistoryNum",
		"BaiduCheckApiKey", "BaiduCheckSecretKey", "BaiduCheckDisableStream",
		"BigModelName", "EmbeddingModelName", "SearchEmptyInterrupt",
		"chatGPTHistory")
	// NOTE(review): package-level settings are rewritten on every call; confirm
	// this is safe when several visitors are answered concurrently.
	lib.QdrantBase = models.FindConfig("QdrantBase")
	lib.QdrantPort = models.FindConfig("QdrantPort")
	// Whether the visitor has a live connection, sampled once at the start.
	// NOTE(review): ClientList is read without a lock here; confirm it is guarded elsewhere.
	_, online := ClientList[visitorId]
	messageContent := ""
	// AI replies are switched off or the model endpoint is not configured.
	if config["QdrantAIStatus"] != "true" || config["chatGPTUrl"] == "" || config["chatGPTSecret"] == "" {
		return ""
	}
	bigModelName := config["BigModelName"]
	if bigModelName == "" {
		bigModelName = "gpt-3.5-turbo-16k"
	}

	// Moderate the visitor's question with Baidu text review when credentials are set.
	isBaiduCheck := config["BaiduCheckApiKey"] != "" && config["BaiduCheckSecretKey"] != ""
	if isBaiduCheck {
		b := lib.BaiduCheck{
			API_KEY:    config["BaiduCheckApiKey"],
			SECRET_KEY: config["BaiduCheckSecretKey"],
		}
		r, err := b.Check(content)
		log.Println("百度审核结果:", content, r)
		// A failed moderation request lets the question through (best effort).
		if conclusion := gjson.Get(r, "conclusion").String(); err == nil && conclusion != "合规" {
			messageContent = "对不起,内容包含敏感信息!"
			VisitorMessage(visitorId, messageContent, kefuInfo)
			return messageContent
		}
	}

	kefuInfo.Nickname = tools.Ifelse(config["RobotName"] != "", config["RobotName"], kefuInfo.Nickname).(string)
	kefuInfo.Avator = tools.Ifelse(config["RobotAvator"] != "", config["RobotAvator"], kefuInfo.Avator).(string)
	openaiUrl := config["chatGPTUrl"]
	openaiKey := config["chatGPTSecret"]
	collectName := config["QdrantAICollect"]
	embeddingModelName := config["EmbeddingModelName"]
	if promptMessage := config["chatGPTPrompt"]; promptMessage != "" {
		content = promptMessage + content
	}
	// An unset chatGPTSystem reads as "", leaving only the timestamp line.
	system := fmt.Sprintf("%s\n当前时间%s", config["chatGPTSystem"], tools.GetNowTime())

	// Chat history, replayed both as OpenAI-style messages and (for Coze) as Coze messages.
	num := 5
	if config["QdrantHistoryNum"] != "" {
		num = tools.Str2Int(config["QdrantHistoryNum"])
	}
	if num < 0 {
		// A negative value would wrap around to an enormous uint page size.
		num = 5
	}
	gptMessages := make([]lib.Gpt3Dot5Message, 0)
	history := make([]lib.CozeCompletionMessage, 0)
	messages := models.FindMessageByQueryPage(1, uint(num), "visitor_id = ?", visitorId)
	// Replay from the end of the page towards index 1. Index 0 is skipped:
	// presumably it is the newest row, i.e. the question being answered, which
	// is sent separately (verify against FindMessageByQueryPage ordering).
	for i := len(messages) - 1; i >= 1; i-- {
		reqContent := messages[i].Content
		if messages[i].MesType == "visitor" {
			gptMessages = append(gptMessages, lib.Gpt3Dot5Message{
				Role:    "user",
				Content: reqContent,
			})
			if bigModelName == "coze" {
				history = append(history, lib.CozeCompletionMessage{
					Role:        "user",
					Content:     reqContent,
					ContentType: "text",
				})
			}
		} else {
			gptMessages = append(gptMessages, lib.Gpt3Dot5Message{
				Role:    "assistant",
				Content: reqContent,
			})
			if bigModelName == "coze" {
				history = append(history, lib.CozeCompletionMessage{
					Role:        "assistant",
					Content:     reqContent,
					ContentType: "text",
					Type:        "answer",
				})
			}
		}
	}

	// Coze bot: single call, no knowledge-base lookup, no streaming.
	if bigModelName == "coze" {
		coze := &lib.Coze{
			BOT_ID:  config["chatGPTUrl"],
			API_KEY: config["chatGPTSecret"],
		}
		reply, err := coze.ChatCoze(visitorId, visitorId, content, history)
		if err != nil {
			log.Println("调用扣子智能体", err)
		}
		if online && reply != "" {
			VisitorMessage(visitorId, reply, kefuInfo)
			go KefuMessage(visitorId, reply, kefuInfo)
			models.CreateMessage(kefuInfo.Name, visitorId, reply, "kefu", entId, "read")
		}
		return reply
	}

	gpt := lib.NewChatGptTool(openaiUrl, openaiKey)
	score := 0.78
	if config["QdrantScore"] != "" {
		score = tools.Str2Float64(config["QdrantScore"])
	}
	searchEmptyInterrupt := config["SearchEmptyInterrupt"] == "true"
	message, err, _ := GptEmbedding(openaiUrl, openaiKey, embeddingModelName, collectName, content, system, score, searchEmptyInterrupt, gptMessages)
	if err != nil {
		log.Println(err)
		return ""
	}
	log.Println("调用GPT请求参数", message)
	stream, err := gpt.ChatGPTHttpStream(message, bigModelName, visitorId)
	if err != nil {
		log.Println(err)
		return ""
	}

	// With moderation on and BaiduCheckDisableStream set, show a placeholder
	// instead of partial text until the full answer has been reviewed.
	// NOTE(review): any non-empty value (even "false") enables this.
	hidePartial := isBaiduCheck && config["BaiduCheckDisableStream"] != ""
	// All chunks reuse one message ID (VisitorMessageSameMsgId), presumably so
	// the client keeps updating a single message.
	msgId := tools.Now()
	for {
		response, err := stream.Recv()
		if err != nil {
			// Presumably also how normal end-of-stream is signalled.
			log.Println("调用GPT响应流", err)
			break
		}
		if len(response.Choices) < 1 {
			continue
		}
		data := response.Choices[0].Delta.Content
		if data == "" {
			continue
		}
		log.Println(data)
		messageContent += data
		shown := messageContent
		if hidePartial {
			shown = "正在回复内容..."
		}
		VisitorMessageSameMsgId(uint(msgId), visitorId, shown, kefuInfo)
	}
	stream.Close()
	// AI follow-up question (disabled):
	//go GptRevQuestion(entId, visitorId, kefuInfo, content)
	// Arguments are evaluated here, so the raw (pre-moderation) answer is logged.
	go llmLog(entId, kefuInfo.Name, bigModelName, messageContent, message)
	if !online || messageContent == "" {
		return messageContent
	}
	// Moderate the model's answer; the reviewed text replaces the streamed message.
	if isBaiduCheck {
		b := lib.BaiduCheck{
			API_KEY:    config["BaiduCheckApiKey"],
			SECRET_KEY: config["BaiduCheckSecretKey"],
		}
		r, err := b.Check(messageContent)
		log.Println("百度审核结果:", messageContent, r)
		if conclusion := gjson.Get(r, "conclusion").String(); err == nil && conclusion != "合规" {
			messageContent = "对不起AI回复内容包含敏感信息"
		}
		VisitorMessageSameMsgId(uint(msgId), visitorId, messageContent, kefuInfo)
	}
	go KefuMessage(visitorId, messageContent, kefuInfo)
	models.CreateMessage(kefuInfo.Name, visitorId, messageContent, "kefu", entId, "read")
	return messageContent
}
// GptEmbedding builds the chat messages for a knowledge-base-grounded LLM call.
//
// When collectionName is empty it skips retrieval and returns
// [system, history..., user(keywords)]. Otherwise it does the following:
//   - embeds keywords with embeddingModelName (default "text-embedding-ada-002");
//   - searches the Qdrant collection for up to 10 points, passing score as the
//     search threshold;
//   - appends each hit's payload.text, as a numbered list, to the system message.
//
// Results, in order:
//   - the messages;
//   - an error if the embedding call fails, if the embedding response carries
//     no vectors, or if the search returns nothing and searchEmptyInterrupt is
//     set (messages are nil whenever the error is non-nil);
//   - the de-duplicated payload.url values of the top three hits.
//
// The (messages, error, urls) order is unconventional but kept for callers.
func GptEmbedding(openaiUrl, openaiKey, embeddingModelName, collectionName, keywords, system string, score float64, searchEmptyInterrupt bool, history []lib.Gpt3Dot5Message) ([]lib.Gpt3Dot5Message, error, []string) {
	urls := make([]string, 0)
	// build assembles [system, history..., user question].
	build := func(systemContent string) []lib.Gpt3Dot5Message {
		msgs := make([]lib.Gpt3Dot5Message, 0, len(history)+2)
		msgs = append(msgs, lib.Gpt3Dot5Message{Role: "system", Content: systemContent})
		msgs = append(msgs, history...)
		return append(msgs, lib.Gpt3Dot5Message{Role: "user", Content: keywords})
	}
	if collectionName == "" {
		return build(system), nil, urls
	}

	gpt := lib.NewChatGptTool(openaiUrl, openaiKey)
	if embeddingModelName == "" {
		embeddingModelName = "text-embedding-ada-002"
	}
	response, err := gpt.GetEmbedding(keywords, embeddingModelName)
	if err != nil {
		return nil, err, urls
	}
	var embeddingResponse lib.EmbeddingResponse
	// A decode error is tolerated as long as vectors were recovered:
	// json.Unmarshal keeps decoding past type mismatches in other fields.
	decodeErr := json.Unmarshal([]byte(response), &embeddingResponse)
	if len(embeddingResponse.Data) == 0 {
		if decodeErr != nil {
			log.Println("解析embedding响应失败", decodeErr)
		}
		// Surface the raw response body as the error text.
		return nil, errors.New(response), urls
	}

	params := map[string]interface{}{"exact": false, "hnsw_ef": 128}
	vector := embeddingResponse.Data[0].Embedding
	limit := 10
	points, err := lib.SearchPoints(collectionName, params, vector, limit, score)
	if err != nil {
		// Best effort: only logged; whatever came back is parsed below.
		log.Println(err)
	}
	log.Println("qdrant相似搜索结果", string(points))
	result := gjson.Get(string(points), "result").Array()
	if len(result) == 0 && searchEmptyInterrupt {
		return nil, errors.New("搜索知识库为空中断AI回复"), urls
	}

	content := ""
	for i, row := range result {
		rank := i + 1
		content += fmt.Sprintf("%d. %s\n", rank, row.Get("payload.text").String())
		// Only the top three hits contribute source links.
		if url := row.Get("payload.url").String(); url != "" && rank <= 3 {
			urls = append(urls, url)
		}
	}
	return build(fmt.Sprintf("%s\n%s\n", system, content)), nil, tools.RemoveDuplicateStrings(urls)
}
// llmLog records one LLM exchange via models.LlmLog.AddLlmLog.
//
// The prompt messages are flattened into "role:content" lines, one per message,
// and stored as the question. The record also carries the model's answer, the
// enterprise, the agent name and the model name.
func llmLog(entId, kefuName, modelName, answer string, questions []lib.Gpt3Dot5Message) {
	var prompt string
	for i := range questions {
		prompt += questions[i].Role + ":" + questions[i].Content + "\n"
	}
	entry := models.LlmLog{
		EntID:     entId,
		KefuName:  kefuName,
		ModelName: modelName,
		Question:  prompt,
		Answer:    answer,
	}
	entry.AddLlmLog()
}