package controller

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"strconv"

	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"

	"knowledge/utils"
)

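// PostChatGPT streams a knowledge-base chat completion back to the client.
// It builds the prompt via MakePrompt, forwards it to the OpenAI chat API,
// flushes each streamed chunk as it arrives, and optionally appends the
// reference URLs of the matched documents.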
func PostChatGPT(c *gin.Context) {
	// The answer is streamed back as HTML text, one chunk per write.
	c.Header("Content-Type", "text/html;charset=utf-8;")
	// Grab the http.Flusher so every chunk can be pushed to the client immediately
	// instead of waiting for the response buffer to fill.
	f, ok := c.Writer.(http.Flusher)
	if !ok {
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	keywords, _ := url.QueryUnescape(c.Query("keywords"))
	if keywords == "" {
		keywords = c.PostForm("keywords")
	}
	system, _ := url.QueryUnescape(c.Query("system"))
	if system == "" {
		system = c.PostForm("system")
	}
	if system == "" {
		// Default system role: "You are acting as the knowledge-base AI bot".
		system = "你正在扮演知识库AI机器人"
	}
	if keywords == "" {
		// "Request parameter is empty"
		c.Writer.Write([]byte("请求参数为空"))
		return
	}
	// Chat history supplied by the client as a JSON array of messages.
	history := c.PostForm("history")
	// Fallback reply used when the question cannot be matched against the knowledge base.
	prompt := c.PostForm("prompt")

	if prompt != "" {
		// "For questions unrelated to the knowledge information, refuse and reply to the user with: %s"
		prompt = fmt.Sprintf("对于与知识信息无关的问题,你应拒绝并告知用户“%s”", prompt)
	}
	openaiKey := os.Getenv("OPENAI_KEY")
	openaiUrl := os.Getenv("OPENAI_API_BASE")
	// Optional per-request override of the GPT endpoint.
	gptUrl := c.PostForm("gptUrl")
	if gptUrl != "" {
		openaiUrl = gptUrl
	}
	// Optional per-request override of the GPT API key.
	gptSecret := c.PostForm("gptSecret")
	if gptSecret != "" {
		openaiKey = gptSecret
	}
	collectName := c.Param("collectName")
	// Note: this logs the API key along with the endpoint.
	log.Println(openaiUrl, openaiKey)
	gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
	message, err, urls := MakePrompt(openaiUrl, openaiKey, collectName, keywords, system, history, prompt)
	if err != nil {
		c.Writer.Write([]byte(err.Error()))
		return
	}
	log.Printf("requesting openai: %+v\n", message)
	stream, err := gpt.ChatGPT3Dot5TurboStream(message)
	if err != nil {
		log.Println("gpt3 stream error: ", err)
		return
	}
	defer stream.Close()

	for {
		response, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			log.Println("\nStream finished", err)
			break
		} else if err != nil {
			log.Printf("\nStream error: %v,%v\n", err, response)
			break
		}
		// Forward each streamed delta to the client and flush it immediately.
		data := response.Choices[0].Delta.Content
		log.Print(data)
		c.Writer.Write([]byte(data))
		f.Flush()
	}

	// Optionally append the reference URLs of the matched documents.
	isShowRefer := os.Getenv("IS_SHOW_REFER")
	if len(urls) != 0 && isShowRefer == "yes" {
		log.Println(urls)
		urls = utils.RemoveRepByMap(urls)
		// "Reference information"
		c.Writer.Write([]byte("\r\n\r\n参考信息"))
		f.Flush()
		for i, u := range urls {
			refLine := fmt.Sprintf("\n\n%d. %s", i+1, u)
			c.Writer.Write([]byte(refLine))
			f.Flush()
		}
	}
}

// MakePrompt builds the chat prompt for the knowledge-base conversation.
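// It embeds the keywords, runs a similarity search against the vector
// collection, and assembles the system message, chat history, and user
// question, returning the messages together with the reference URLs of the
// matched documents.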
func MakePrompt(openaiUrl, openaiKey, collectionName, keywords, system, history, promptMessage string) ([]utils.Gpt3Dot5Message, error, []string) {
	gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
	urls := make([]string, 0)

	// Embed the user's keywords so they can be compared against the stored vectors.
	response, err := gpt.GetEmbedding(keywords, "text-embedding-ada-002")
	if err != nil {
		return nil, err, urls
	}
	var embeddingResponse utils.EmbeddingResponse
	if err := json.Unmarshal([]byte(response), &embeddingResponse); err != nil {
		return nil, err, urls
	}

	params := map[string]interface{}{"exact": false, "hnsw_ef": 128}
	if len(embeddingResponse.Data) == 0 {
		// The embedding API returned no data; surface its raw response as the error.
		return nil, errors.New(response), urls
	}
	vector := embeddingResponse.Data[0].Embedding
	limit := 3
	points, _ := utils.SearchPoints(collectionName, params, vector, limit)
	log.Printf("similarity search result: %s\n", string(points))
	result := gjson.Get(string(points), "result").Array()
	message := make([]utils.Gpt3Dot5Message, 0)
	content := ""
	line := ""
	// Minimum similarity score a hit must reach before its URL is shown as a reference.
	referScoreStr := os.Getenv("SHOW_REFER_SCORE")
	referScore := 0.85
	if referScoreStr != "" {
		referScore, _ = strconv.ParseFloat(referScoreStr, 64)
	}
	for key, row := range result {
		key++
		line = fmt.Sprintf("%d. %s\n", key, row.Get("payload.text").String())
		url := row.Get("payload.url").String()
		score := row.Get("score").Float()
		// Only collect reference URLs for the top hits that pass the score threshold.
		if score >= referScore && url != "" && key <= 3 {
			urls = append(urls, url)
		}
		// Truncate each matched chunk to 500 runes to keep the prompt compact.
		arr := []rune(line)
		if len(arr) > 500 {
			line = string(arr[:500])
		}
		content += line
	}
	if content == "" {
		// "I am the knowledge-base AI bot"
		content = "我是知识库AI机器人\n"
	}
	// System message: role description, refusal instruction, and the retrieved knowledge.
	// "%s, %s\nThe knowledge information I provide is:\n\n%s\n"
	message = append(message, utils.Gpt3Dot5Message{
		Role:    "system",
		Content: fmt.Sprintf("%s,%s\n我提供的知识信息是:\n\n%s\n", system, promptMessage, content),
	})
	// Append the previous conversation turns if the client sent a valid history.
	var historySlice = make([]utils.Gpt3Dot5Message, 0)
	if err := json.Unmarshal([]byte(history), &historySlice); err == nil {
		message = append(message, historySlice...)
	}
	// User message: role description, refusal instruction, and the question itself.
	// "%s, %s\nMy question is: '''\n%s\n'''"
	userQuestion := fmt.Sprintf("%s,%s\n我的问题是:'''\n%s\n'''", system, promptMessage, keywords)

	message = append(message, utils.Gpt3Dot5Message{
		Role:    "user",
		Content: userQuestion,
	})
	return message, nil, urls
}