345 lines
10 KiB
Go
345 lines
10 KiB
Go
package ws
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/tidwall/gjson"
|
||
"kefu/lib"
|
||
"kefu/models"
|
||
"kefu/tools"
|
||
"log"
|
||
)
|
||
|
||
// GptRevQuestion asks the LLM to produce a follow-up ("reverse") question for
// the visitor, based on recent conversation history and the visitor's latest
// question.
//
// It is a no-op returning "" unless the enterprise (entId) has
// chatGPTRevQuestion, chatGPTUrl and chatGPTSecret configured. The
// chatGPTRevQuestion text is used as the system prompt and is also prefixed to
// the final user turn. BigModelName selects the model (default
// "gpt-3.5-turbo-16k") and QdrantHistoryNum the history page size (default 5).
//
// On a non-empty model reply the text is pushed to the visitor, forwarded to the
// agent asynchronously, stored via models.CreateMessage as a "kefu"/"read"
// message, and returned. An empty reply (including a failed call) returns "".
func GptRevQuestion(entId, visitorId string, kefuInfo models.User, question string) string {
	config := models.GetEntConfigsMap(entId, "chatGPTUrl", "chatGPTSecret", "chatGPTRevQuestion", "QdrantHistoryNum", "BigModelName")
	revPrompt := config["chatGPTRevQuestion"]
	if revPrompt == "" || config["chatGPTUrl"] == "" || config["chatGPTSecret"] == "" {
		return ""
	}

	// Model name, falling back to the historical default.
	bigModelName := config["BigModelName"]
	if bigModelName == "" {
		bigModelName = "gpt-3.5-turbo-16k"
	}

	// How many stored messages to replay as context.
	num := 5
	if config["QdrantHistoryNum"] != "" {
		// A negative value would wrap around in uint() below and pull the whole
		// conversation into the prompt, so only non-negative values are honoured.
		if n := tools.Str2Int(config["QdrantHistoryNum"]); n >= 0 {
			num = n
		}
	}
	messages := models.FindMessageByQueryPage(1, uint(num), "visitor_id = ?", visitorId)

	gptMessages := make([]lib.Gpt3Dot5Message, 0, len(messages)+2)
	gptMessages = append(gptMessages, lib.Gpt3Dot5Message{Role: "system", Content: revPrompt})
	// Replay history from the end of the page back to index 1. Index 0 is
	// skipped; presumably the page is newest-first and entry 0 is the question
	// being answered, which is appended separately below. TODO confirm ordering.
	for i := len(messages) - 1; i >= 1; i-- {
		role := "assistant"
		if messages[i].MesType == "visitor" {
			role = "user"
		}
		gptMessages = append(gptMessages, lib.Gpt3Dot5Message{Role: role, Content: messages[i].Content})
	}
	gptMessages = append(gptMessages, lib.Gpt3Dot5Message{Role: "user", Content: revPrompt + question})

	gpt := lib.NewChatGptTool(config["chatGPTUrl"], config["chatGPTSecret"])
	log.Println("反向提问调用GPT请求参数:", gptMessages)
	content, err := gpt.ChatGPT3Dot5Turbo(gptMessages, bigModelName)
	if err != nil {
		// Best-effort feature: record the failure instead of swallowing it.
		log.Println("反向提问调用GPT失败:", err)
	}
	if content == "" {
		return ""
	}

	VisitorMessage(visitorId, content, kefuInfo)
	go KefuMessage(visitorId, content, kefuInfo)
	models.CreateMessage(kefuInfo.Name, visitorId, content, "kefu", entId, "read")
	return content
}
|
||
|
||
// Gpt3Knowledge answers a visitor's question with the enterprise's configured
// LLM and streams the reply to the visitor as it is generated.
//
// Behaviour is driven by the enterprise config loaded for entId:
//   - AI replies require QdrantAIStatus == "true" plus chatGPTUrl and
//     chatGPTSecret; otherwise "" is returned immediately.
//   - With BaiduCheckApiKey and BaiduCheckSecretKey set, both the question and
//     the final answer are moderated with Baidu text review. A rejected
//     question is answered with a fixed notice. BaiduCheckDisableStream shows a
//     placeholder instead of the partial answer while streaming.
//   - BigModelName (default "gpt-3.5-turbo-16k") selects the model; "coze"
//     sends the request to a Coze bot instead, using chatGPTUrl as the bot ID.
//   - RobotName / RobotAvator override the agent's nickname and avatar on the
//     AI messages. chatGPTPrompt is prefixed to the question, and
//     chatGPTSystem (plus the current time) becomes the system prompt.
//   - QdrantAICollect, QdrantScore, EmbeddingModelName and SearchEmptyInterrupt
//     are handed to GptEmbedding for knowledge-base grounding.
//
// The final reply is forwarded to the agent and stored only when the visitor
// had an entry in ClientList at the time of the call. Returns the reply text,
// or "" when AI is disabled or the embedding/stream request cannot be started.
func Gpt3Knowledge(entId, visitorId string, kefuInfo models.User, content string) string {
	config := models.GetEntConfigsMap(entId, "chatGPTUrl", "chatGPTSecret", "QdrantAIStatus", "QdrantAICollect", "chatGPTSystem", "chatGPTPrompt",
		"RobotName", "RobotAvator", "QdrantScore", "QdrantHistoryNum",
		"BaiduCheckApiKey", "BaiduCheckSecretKey", "BaiduCheckDisableStream",
		"BigModelName", "EmbeddingModelName", "SearchEmptyInterrupt",
		"chatGPTHistory")
	lib.QdrantBase = models.FindConfig("QdrantBase")
	lib.QdrantPort = models.FindConfig("QdrantPort")
	// Whether the visitor currently has a client entry; only then is the final
	// reply forwarded to the agent and persisted.
	_, online := ClientList[visitorId]

	// AI reply not enabled or not configured.
	if config["QdrantAIStatus"] != "true" || config["chatGPTUrl"] == "" || config["chatGPTSecret"] == "" {
		return ""
	}

	// Model name, falling back to the historical default.
	bigModelName := config["BigModelName"]
	if bigModelName == "" {
		bigModelName = "gpt-3.5-turbo-16k"
	}

	// Moderate the visitor's question with Baidu text review when configured.
	baiduKey, baiduSecret := config["BaiduCheckApiKey"], config["BaiduCheckSecretKey"]
	isBaiduCheck := baiduKey != "" && baiduSecret != ""
	if isBaiduCheck && gptKnowledgeBaiduRejects(baiduKey, baiduSecret, content) {
		notice := "对不起,内容包含敏感信息!"
		VisitorMessage(visitorId, notice, kefuInfo)
		return notice
	}

	kefuInfo.Nickname = tools.Ifelse(config["RobotName"] != "", config["RobotName"], kefuInfo.Nickname).(string)
	kefuInfo.Avator = tools.Ifelse(config["RobotAvator"] != "", config["RobotAvator"], kefuInfo.Avator).(string)
	openaiUrl := config["chatGPTUrl"]
	openaiKey := config["chatGPTSecret"]
	collectName := config["QdrantAICollect"]
	embeddingModelName := config["EmbeddingModelName"]
	if promptMessage := config["chatGPTPrompt"]; promptMessage != "" {
		content = promptMessage + content
	}
	system := fmt.Sprintf("%s\n当前时间:%s", config["chatGPTSystem"], tools.GetNowTime())

	// Conversation history, in OpenAI format and (for Coze) in Coze format.
	num := 5
	if config["QdrantHistoryNum"] != "" {
		// A negative value would wrap around in uint() below and pull the whole
		// conversation into the prompt, so only non-negative values are honoured.
		if n := tools.Str2Int(config["QdrantHistoryNum"]); n >= 0 {
			num = n
		}
	}
	messages := models.FindMessageByQueryPage(1, uint(num), "visitor_id = ?", visitorId)
	gptMessages := make([]lib.Gpt3Dot5Message, 0, len(messages))
	history := make([]lib.CozeCompletionMessage, 0)
	// Replay history from the end of the page back to index 1. Index 0 is
	// skipped; presumably the page is newest-first and entry 0 is the question
	// being answered, which is sent separately. TODO confirm ordering.
	for i := len(messages) - 1; i >= 1; i-- {
		reqContent := messages[i].Content
		role := "assistant"
		if messages[i].MesType == "visitor" {
			role = "user"
		}
		gptMessages = append(gptMessages, lib.Gpt3Dot5Message{Role: role, Content: reqContent})
		if bigModelName == "coze" {
			item := lib.CozeCompletionMessage{Role: role, Content: reqContent, ContentType: "text"}
			if role == "assistant" {
				item.Type = "answer"
			}
			history = append(history, item)
		}
	}

	// Coze bot: single non-streaming request.
	if bigModelName == "coze" {
		coze := &lib.Coze{
			BOT_ID:  config["chatGPTUrl"],
			API_KEY: config["chatGPTSecret"],
		}
		reply, err := coze.ChatCoze(visitorId, visitorId, content, history)
		if err != nil {
			log.Println("调用扣子智能体失败:", err)
		}
		if online && reply != "" {
			VisitorMessage(visitorId, reply, kefuInfo)
			go KefuMessage(visitorId, reply, kefuInfo)
			models.CreateMessage(kefuInfo.Name, visitorId, reply, "kefu", entId, "read")
		}
		return reply
	}

	gpt := lib.NewChatGptTool(openaiUrl, openaiKey)
	score := 0.78
	if config["QdrantScore"] != "" {
		score = tools.Str2Float64(config["QdrantScore"])
	}
	searchEmptyInterrupt := config["SearchEmptyInterrupt"] == "true"
	message, err, _ := GptEmbedding(openaiUrl, openaiKey, embeddingModelName, collectName, content, system, score, searchEmptyInterrupt, gptMessages)
	if err != nil {
		log.Println(err)
		return ""
	}
	log.Println("调用GPT请求参数:", message)
	stream, err := gpt.ChatGPTHttpStream(message, bigModelName, visitorId)
	if err != nil {
		log.Println(err)
		return ""
	}

	// All chunks are sent under one message ID so the visitor sees a single,
	// progressively updated message.
	msgId := tools.Now()
	hideStream := isBaiduCheck && config["BaiduCheckDisableStream"] != ""
	messageContent := ""
	for {
		response, err := stream.Recv()
		if err != nil {
			log.Println("调用GPT响应流:", err)
			break
		}
		if len(response.Choices) < 1 {
			continue
		}
		data := response.Choices[0].Delta.Content
		if data == "" {
			continue
		}
		log.Println(data)
		messageContent += data
		shown := messageContent
		if hideStream {
			// The unreviewed text must not reach the visitor yet.
			shown = "正在回复内容..."
		}
		VisitorMessageSameMsgId(uint(msgId), visitorId, shown, kefuInfo)
	}
	stream.Close()

	// Record the raw (pre-moderation) exchange.
	go llmLog(entId, kefuInfo.Name, bigModelName, messageContent, message)

	if !online || messageContent == "" {
		return messageContent
	}
	if isBaiduCheck {
		if gptKnowledgeBaiduRejects(baiduKey, baiduSecret, messageContent) {
			messageContent = "对不起,AI回复内容包含敏感信息!"
		}
		// Replace the streamed text (or placeholder) with the reviewed answer.
		VisitorMessageSameMsgId(uint(msgId), visitorId, messageContent, kefuInfo)
	}
	go KefuMessage(visitorId, messageContent, kefuInfo)
	models.CreateMessage(kefuInfo.Name, visitorId, messageContent, "kefu", entId, "read")
	return messageContent
}

// gptKnowledgeBaiduRejects runs Baidu text moderation on text and reports
// whether it was judged non-compliant (conclusion other than "合规"). The
// moderation is best-effort: if the Check call itself fails, the text is
// treated as compliant.
func gptKnowledgeBaiduRejects(apiKey, secretKey, text string) bool {
	b := lib.BaiduCheck{
		API_KEY:    apiKey,
		SECRET_KEY: secretKey,
	}
	r, err := b.Check(text)
	log.Println("百度审核结果:", text, r)
	return err == nil && gjson.Get(r, "conclusion").String() != "合规"
}
|
||
|
||
// GptEmbedding assembles the chat messages for a question and can ground the
// system prompt in a Qdrant knowledge-base search.
//
// When collectionName is empty, the result is [system, history..., user(keywords)]
// and no remote call is made.
//
// Otherwise:
//   - keywords are embedded with embeddingModelName (default
//     "text-embedding-ada-002");
//   - lib.SearchPoints is queried on collectionName with a limit of 10, with
//     score passed through (presumably a similarity threshold; confirm in lib);
//   - each hit's payload.text is appended to the system prompt as a numbered
//     list.
//
// Return values, in order:
//   - the messages;
//   - an error;
//   - the payload.url values of the first three hits, passed through
//     tools.RemoveDuplicateStrings.
//
// The error is non-nil when:
//   - the embedding call fails;
//   - the embedding response has no data (the raw response becomes the error
//     text);
//   - the search yields no hits and searchEmptyInterrupt is set.
//
// A failed Qdrant search is only logged; processing continues with whatever
// SearchPoints returned.
func GptEmbedding(openaiUrl, openaiKey, embeddingModelName, collectionName, keywords, system string, score float64, searchEmptyInterrupt bool, history []lib.Gpt3Dot5Message) ([]lib.Gpt3Dot5Message, error, []string) {
	links := make([]string, 0)

	// buildMessages lays out the prompt: system, prior turns, then the question.
	buildMessages := func(systemPrompt string) []lib.Gpt3Dot5Message {
		out := make([]lib.Gpt3Dot5Message, 0, len(history)+2)
		out = append(out, lib.Gpt3Dot5Message{Role: "system", Content: systemPrompt})
		out = append(out, history...)
		return append(out, lib.Gpt3Dot5Message{Role: "user", Content: keywords})
	}

	// No knowledge base configured: plain prompt, no remote calls.
	if collectionName == "" {
		return buildMessages(system), nil, links
	}

	client := lib.NewChatGptTool(openaiUrl, openaiKey)
	if embeddingModelName == "" {
		embeddingModelName = "text-embedding-ada-002"
	}
	raw, err := client.GetEmbedding(keywords, embeddingModelName)
	if err != nil {
		return nil, err, links
	}

	var embedded lib.EmbeddingResponse
	// A decode failure leaves Data empty, which is reported just below with the
	// raw response body as the error text.
	json.Unmarshal([]byte(raw), &embedded)
	searchParams := map[string]interface{}{"exact": false, "hnsw_ef": 128}
	if len(embedded.Data) == 0 {
		return nil, errors.New(raw), links
	}

	hitsJSON, err := lib.SearchPoints(collectionName, searchParams, embedded.Data[0].Embedding, 10, score)
	if err != nil {
		log.Println(err)
	}
	log.Println("qdrant相似搜索结果:", string(hitsJSON))
	hits := gjson.Get(string(hitsJSON), "result").Array()
	if len(hits) == 0 && searchEmptyInterrupt {
		return make([]lib.Gpt3Dot5Message, 0), errors.New("搜索知识库为空,中断AI回复"), links
	}

	knowledge := ""
	for i, hit := range hits {
		rank := i + 1
		knowledge += fmt.Sprintf("%d. %s\n", rank, hit.Get("payload.text").String())
		// Only the top three hits contribute reference links.
		if link := hit.Get("payload.url").String(); link != "" && rank <= 3 {
			links = append(links, link)
		}
	}

	return buildMessages(fmt.Sprintf("%s\n%s\n", system, knowledge)), nil, tools.RemoveDuplicateStrings(links)
}
|
||
|
||
// 记录LLM日志
|
||
func llmLog(entId, kefuName, modelName, answer string, questions []lib.Gpt3Dot5Message) {
|
||
q := ""
|
||
for _, question := range questions {
|
||
q += question.Role + ":" + question.Content + "\n"
|
||
}
|
||
llm := models.LlmLog{
|
||
EntID: entId,
|
||
KefuName: kefuName,
|
||
ModelName: modelName,
|
||
Question: q,
|
||
Answer: answer,
|
||
}
|
||
llm.AddLlmLog()
|
||
}
|