kefu/knowledge/controller/gpt.go

189 lines
6.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"io"
"knowledge/utils"
"log"
"net/http"
"net/url"
"os"
"strconv"
)
func PostChatGPT(c *gin.Context) {
// 将响应头中的Content-Type设置为text/plain表示响应内容为文本
c.Header("Content-Type", "text/html;charset=utf-8;")
// 关闭输出缓冲,使得每次写入的数据能够立即发送给客户端
f, ok := c.Writer.(http.Flusher)
keywords, _ := url.QueryUnescape(c.Query("keywords"))
if keywords == "" {
keywords = c.PostForm("keywords")
}
system, _ := url.QueryUnescape(c.Query("system"))
if system == "" {
system = c.PostForm("system")
}
if system == "" {
system = "你正在扮演知识库AI机器人"
}
if keywords == "" {
c.Writer.Write([]byte("请求参数为空"))
return
}
//历史记录
history := c.PostForm("history")
//匹配不到
prompt := c.PostForm("prompt")
if prompt != "" {
prompt = fmt.Sprintf("对于与知识信息无关的问题,你应拒绝并告知用户“%s”", prompt)
}
openaiKey := os.Getenv("OPENAI_KEY")
openaiUrl := os.Getenv("OPENAI_API_BASE")
//gpt url
gptUrl := c.PostForm("gptUrl")
if gptUrl != "" {
openaiUrl = gptUrl
}
//gpt secret
gptSecret := c.PostForm("gptSecret")
if gptSecret != "" {
openaiKey = gptSecret
}
collectName := c.Param("collectName")
log.Println(openaiUrl, openaiKey)
gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
message, err, urls := MakePrompt(openaiUrl, openaiKey, collectName, keywords, system, history, prompt)
if err != nil {
c.Writer.Write([]byte(err.Error()))
return
}
log.Printf("请求openai%+v\n", message)
stream, err := gpt.ChatGPT3Dot5TurboStream(message)
if err != nil {
log.Println("gpt3 stream error: ", err)
return
}
if !ok {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
log.Println("\nStream finished", err)
break
} else if err != nil {
log.Printf("\nStream error: %v,%v\n", err, response)
break
} else {
data := response.Choices[0].Delta.Content
log.Printf(data)
c.Writer.Write([]byte(data))
f.Flush()
}
}
isShowRefer := os.Getenv("IS_SHOW_REFER")
if len(urls) != 0 && isShowRefer == "yes" {
log.Println(urls)
urls = utils.RemoveRepByMap(urls)
c.Writer.Write([]byte("\r\n\r\n参考信息"))
f.Flush()
i := 1
for _, url := range urls {
utlHtml := fmt.Sprintf("\n\n%d. %s", i, url)
c.Writer.Write([]byte(utlHtml))
f.Flush()
i++
}
}
stream.Close()
}
//组织Prompt
func MakePrompt(openaiUrl, openaiKey, collectionName, keywords, system, history, promptMessage string) ([]utils.Gpt3Dot5Message, error, []string) {
gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
urls := make([]string, 0)
response, err := gpt.GetEmbedding(keywords, "text-embedding-ada-002")
if err != nil {
return nil, err, urls
}
var embeddingResponse utils.EmbeddingResponse
json.Unmarshal([]byte(response), &embeddingResponse)
params := map[string]interface{}{"exact": false, "hnsw_ef": 128}
if len(embeddingResponse.Data) == 0 {
return nil, errors.New(response), urls
}
vector := embeddingResponse.Data[0].Embedding
limit := 3
points, _ := utils.SearchPoints(collectionName, params, vector, limit)
log.Printf("相似搜索结果:%s\n", string(points))
result := gjson.Get(string(points), "result").Array()
message := make([]utils.Gpt3Dot5Message, 0)
content := ""
line := ""
referScoreStr := os.Getenv("SHOW_REFER_SCORE")
referScore := 0.85
if referScoreStr != "" {
referScore, _ = strconv.ParseFloat(referScoreStr, 64)
}
for key, row := range result {
key++
line = fmt.Sprintf("%d. %s\n", key, row.Get("payload.text").String())
url := row.Get("payload.url").String()
score := row.Get("score").Float()
if score >= referScore && url != "" && key <= 3 {
urls = append(urls, url)
}
arr := []rune(line)
if len(arr) > 500 {
line = string(arr[:500])
}
//message = append(message, utils.Gpt3Dot5Message{
// Role: "assistant",
// Content: line,
//})
content += line
}
if content == "" {
content = "我是知识库AI机器人\n"
}
message = append(message, utils.Gpt3Dot5Message{
Role: "system",
//Content: fmt.Sprintf("%s我会向你提问题我提供的上下文信息是\n\n%s\n", system, content),
Content: fmt.Sprintf("%s%s\n我提供的知识信息是\n\n%s\n", system, promptMessage, content),
//Content: fmt.Sprintf("%s我会向你提问题你只能基于知识信息回答问题我提供的知识信息是\n\n%s\n\n你必须根据自己的知识信息总结归纳后回答问题不要写解释不要写具体代码对于与知识信息无关的问题或者不理解的问题等你应拒绝并告知用户“未查询到相关信息请提供详细的问题信息。”", system, content),
//Content: fmt.Sprintf("你现在扮演知识库AI机器人。请严格根据提供的参考信息总结归纳后回答问题。对于与参考信息无关的问题或者不理解的问题等你应拒绝并告知用户“未查询到相关信息请提供详细的问题信息。”避免引用任何当前或过去的政治人物或事件以及可能引起争议或分裂的历史人物或事件。",
})
var historySlice = make([]utils.Gpt3Dot5Message, 0)
err = json.Unmarshal([]byte(history), &historySlice)
if err == nil {
message = append(message, historySlice...)
} else {
//message = append(message, utils.Gpt3Dot5Message{
// Role: "user",
// //Content: keywords,
// Content: fmt.Sprintf("我提供的知识信息是:\n\n%s\n\n你必须根据提供的知识信息总结归纳后回答问题不要写解释不要写具体代码对于与知识信息无关的问题或者不理解的问题等你应拒绝并告知用户“未查询到相关信息请提供详细的问题信息。”我的问题是%s", content, keywords),
//})
}
//userQuestion := keywords
//if promptMessage != "" {
userQuestion := fmt.Sprintf("%s%s\n我的问题是'''\n%s\n'''", system, promptMessage, keywords)
//}
message = append(message, utils.Gpt3Dot5Message{
Role: "user",
Content: userQuestion,
})
return message, nil, urls
}