189 lines
6.2 KiB
Go
189 lines
6.2 KiB
Go
|
package controller
|
|||
|
|
|||
|
import (
|
|||
|
"encoding/json"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"github.com/gin-gonic/gin"
|
|||
|
"github.com/tidwall/gjson"
|
|||
|
"io"
|
|||
|
"knowledge/utils"
|
|||
|
"log"
|
|||
|
"net/http"
|
|||
|
"net/url"
|
|||
|
"os"
|
|||
|
"strconv"
|
|||
|
)
|
|||
|
|
|||
|
func PostChatGPT(c *gin.Context) {
|
|||
|
// 将响应头中的Content-Type设置为text/plain,表示响应内容为文本
|
|||
|
c.Header("Content-Type", "text/html;charset=utf-8;")
|
|||
|
// 关闭输出缓冲,使得每次写入的数据能够立即发送给客户端
|
|||
|
f, ok := c.Writer.(http.Flusher)
|
|||
|
|
|||
|
keywords, _ := url.QueryUnescape(c.Query("keywords"))
|
|||
|
if keywords == "" {
|
|||
|
keywords = c.PostForm("keywords")
|
|||
|
}
|
|||
|
system, _ := url.QueryUnescape(c.Query("system"))
|
|||
|
if system == "" {
|
|||
|
system = c.PostForm("system")
|
|||
|
}
|
|||
|
if system == "" {
|
|||
|
system = "你正在扮演知识库AI机器人"
|
|||
|
}
|
|||
|
if keywords == "" {
|
|||
|
c.Writer.Write([]byte("请求参数为空"))
|
|||
|
return
|
|||
|
}
|
|||
|
//历史记录
|
|||
|
history := c.PostForm("history")
|
|||
|
//匹配不到
|
|||
|
prompt := c.PostForm("prompt")
|
|||
|
|
|||
|
if prompt != "" {
|
|||
|
prompt = fmt.Sprintf("对于与知识信息无关的问题,你应拒绝并告知用户“%s”", prompt)
|
|||
|
}
|
|||
|
openaiKey := os.Getenv("OPENAI_KEY")
|
|||
|
openaiUrl := os.Getenv("OPENAI_API_BASE")
|
|||
|
//gpt url
|
|||
|
gptUrl := c.PostForm("gptUrl")
|
|||
|
if gptUrl != "" {
|
|||
|
openaiUrl = gptUrl
|
|||
|
}
|
|||
|
//gpt secret
|
|||
|
gptSecret := c.PostForm("gptSecret")
|
|||
|
if gptSecret != "" {
|
|||
|
openaiKey = gptSecret
|
|||
|
}
|
|||
|
collectName := c.Param("collectName")
|
|||
|
log.Println(openaiUrl, openaiKey)
|
|||
|
gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
|
|||
|
message, err, urls := MakePrompt(openaiUrl, openaiKey, collectName, keywords, system, history, prompt)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
log.Printf("请求openai:%+v\n", message)
|
|||
|
stream, err := gpt.ChatGPT3Dot5TurboStream(message)
|
|||
|
if err != nil {
|
|||
|
log.Println("gpt3 stream error: ", err)
|
|||
|
return
|
|||
|
}
|
|||
|
if !ok {
|
|||
|
c.AbortWithStatus(http.StatusInternalServerError)
|
|||
|
return
|
|||
|
}
|
|||
|
for {
|
|||
|
response, err := stream.Recv()
|
|||
|
if errors.Is(err, io.EOF) {
|
|||
|
log.Println("\nStream finished", err)
|
|||
|
break
|
|||
|
} else if err != nil {
|
|||
|
log.Printf("\nStream error: %v,%v\n", err, response)
|
|||
|
break
|
|||
|
} else {
|
|||
|
data := response.Choices[0].Delta.Content
|
|||
|
log.Printf(data)
|
|||
|
c.Writer.Write([]byte(data))
|
|||
|
f.Flush()
|
|||
|
}
|
|||
|
}
|
|||
|
isShowRefer := os.Getenv("IS_SHOW_REFER")
|
|||
|
if len(urls) != 0 && isShowRefer == "yes" {
|
|||
|
log.Println(urls)
|
|||
|
urls = utils.RemoveRepByMap(urls)
|
|||
|
c.Writer.Write([]byte("\r\n\r\n参考信息"))
|
|||
|
f.Flush()
|
|||
|
i := 1
|
|||
|
for _, url := range urls {
|
|||
|
utlHtml := fmt.Sprintf("\n\n%d. %s", i, url)
|
|||
|
c.Writer.Write([]byte(utlHtml))
|
|||
|
f.Flush()
|
|||
|
i++
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
stream.Close()
|
|||
|
}
|
|||
|
|
|||
|
//组织Prompt
|
|||
|
func MakePrompt(openaiUrl, openaiKey, collectionName, keywords, system, history, promptMessage string) ([]utils.Gpt3Dot5Message, error, []string) {
|
|||
|
gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
|
|||
|
urls := make([]string, 0)
|
|||
|
|
|||
|
response, err := gpt.GetEmbedding(keywords, "text-embedding-ada-002")
|
|||
|
if err != nil {
|
|||
|
return nil, err, urls
|
|||
|
}
|
|||
|
var embeddingResponse utils.EmbeddingResponse
|
|||
|
json.Unmarshal([]byte(response), &embeddingResponse)
|
|||
|
|
|||
|
params := map[string]interface{}{"exact": false, "hnsw_ef": 128}
|
|||
|
if len(embeddingResponse.Data) == 0 {
|
|||
|
return nil, errors.New(response), urls
|
|||
|
}
|
|||
|
vector := embeddingResponse.Data[0].Embedding
|
|||
|
limit := 3
|
|||
|
points, _ := utils.SearchPoints(collectionName, params, vector, limit)
|
|||
|
log.Printf("相似搜索结果:%s\n", string(points))
|
|||
|
result := gjson.Get(string(points), "result").Array()
|
|||
|
message := make([]utils.Gpt3Dot5Message, 0)
|
|||
|
content := ""
|
|||
|
line := ""
|
|||
|
referScoreStr := os.Getenv("SHOW_REFER_SCORE")
|
|||
|
referScore := 0.85
|
|||
|
if referScoreStr != "" {
|
|||
|
referScore, _ = strconv.ParseFloat(referScoreStr, 64)
|
|||
|
}
|
|||
|
for key, row := range result {
|
|||
|
key++
|
|||
|
line = fmt.Sprintf("%d. %s\n", key, row.Get("payload.text").String())
|
|||
|
url := row.Get("payload.url").String()
|
|||
|
score := row.Get("score").Float()
|
|||
|
if score >= referScore && url != "" && key <= 3 {
|
|||
|
urls = append(urls, url)
|
|||
|
}
|
|||
|
arr := []rune(line)
|
|||
|
if len(arr) > 500 {
|
|||
|
line = string(arr[:500])
|
|||
|
}
|
|||
|
//message = append(message, utils.Gpt3Dot5Message{
|
|||
|
// Role: "assistant",
|
|||
|
// Content: line,
|
|||
|
//})
|
|||
|
content += line
|
|||
|
}
|
|||
|
if content == "" {
|
|||
|
content = "我是知识库AI机器人\n"
|
|||
|
}
|
|||
|
message = append(message, utils.Gpt3Dot5Message{
|
|||
|
Role: "system",
|
|||
|
//Content: fmt.Sprintf("%s,我会向你提问题,我提供的上下文信息是:\n\n%s\n", system, content),
|
|||
|
Content: fmt.Sprintf("%s,%s\n我提供的知识信息是:\n\n%s\n", system, promptMessage, content),
|
|||
|
//Content: fmt.Sprintf("%s,我会向你提问题,你只能基于知识信息回答问题,我提供的知识信息是:\n\n%s\n\n你必须根据自己的知识信息总结归纳后回答问题,不要写解释,不要写具体代码,对于与知识信息无关的问题或者不理解的问题等,你应拒绝并告知用户“未查询到相关信息,请提供详细的问题信息。”", system, content),
|
|||
|
//Content: fmt.Sprintf("你现在扮演知识库AI机器人。请严格根据提供的参考信息总结归纳后回答问题。对于与参考信息无关的问题或者不理解的问题等,你应拒绝并告知用户“未查询到相关信息,请提供详细的问题信息。”避免引用任何当前或过去的政治人物或事件,以及可能引起争议或分裂的历史人物或事件。",
|
|||
|
})
|
|||
|
var historySlice = make([]utils.Gpt3Dot5Message, 0)
|
|||
|
err = json.Unmarshal([]byte(history), &historySlice)
|
|||
|
if err == nil {
|
|||
|
message = append(message, historySlice...)
|
|||
|
} else {
|
|||
|
//message = append(message, utils.Gpt3Dot5Message{
|
|||
|
// Role: "user",
|
|||
|
// //Content: keywords,
|
|||
|
// Content: fmt.Sprintf("我提供的知识信息是:\n\n%s\n\n你必须根据提供的知识信息总结归纳后回答问题,不要写解释,不要写具体代码,对于与知识信息无关的问题或者不理解的问题等,你应拒绝并告知用户“未查询到相关信息,请提供详细的问题信息。”,我的问题是:%s", content, keywords),
|
|||
|
//})
|
|||
|
}
|
|||
|
//userQuestion := keywords
|
|||
|
//if promptMessage != "" {
|
|||
|
userQuestion := fmt.Sprintf("%s,%s\n我的问题是:'''\n%s\n'''", system, promptMessage, keywords)
|
|||
|
//}
|
|||
|
|
|||
|
message = append(message, utils.Gpt3Dot5Message{
|
|||
|
Role: "user",
|
|||
|
Content: userQuestion,
|
|||
|
})
|
|||
|
return message, nil, urls
|
|||
|
}
|