kefu/knowledge/main.go

430 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/joho/godotenv"
uuid "github.com/satori/go.uuid"
"github.com/tidwall/gjson"
"io/ioutil"
"knowledge/controller"
"knowledge/models"
"knowledge/utils"
"log"
"net/http"
"os"
"path"
"strconv"
"strings"
)
// main loads configuration from .env, connects to MySQL, registers the HTTP
// routes of the knowledge service on a Gin engine and serves on :8083.
//
// Routes are grouped as: static/admin pages, collection management, point
// (vector) management, streaming search, and training/upload endpoints that
// embed text through Train and record the resulting points in MySQL.
func main() {
	if err := godotenv.Load(".env"); err != nil {
		log.Fatalf("Error loading .env file: %v", err)
	}

	// Read environment variables.
	utils.QdrantBase = os.Getenv("QDRANT_BASE")
	utils.QdrantPort = os.Getenv("QDRANT_PORT")
	mysqlServer := os.Getenv("MYSQL_SERVER")
	mysqlPort := os.Getenv("MYSQL_PORT")
	mysqlDbname := os.Getenv("MYSQL_DATABASE")
	mysqlUsername := os.Getenv("MYSQL_USERNAME")
	mysqlPassword := os.Getenv("MYSQL_PASSWORD")
	chatModel := os.Getenv("CHAT_MODEL")

	// Initialize MySQL.
	models.Connect(mysqlServer, mysqlPort, mysqlDbname, mysqlUsername, mysqlPassword)

	// respond writes the error text when err is non-nil, otherwise the body.
	// This is the plain-text convention shared by the collection and point
	// endpoints.
	respond := func(c *gin.Context, body []byte, err error) {
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(body)
	}

	// badRequest sends the JSON failure envelope used by the upload endpoints
	// (HTTP 200 with an application-level code of 400).
	badRequest := func(c *gin.Context, msg string) {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  msg,
		})
	}

	// openaiConfig returns the OpenAI base URL and key from the environment,
	// each overridden by the gptUrl / gptSecret form field when non-empty.
	openaiConfig := func(c *gin.Context) (baseURL, key string) {
		key = os.Getenv("OPENAI_KEY")
		baseURL = os.Getenv("OPENAI_API_BASE")
		if v := c.PostForm("gptUrl"); v != "" {
			baseURL = v
		}
		if v := c.PostForm("gptSecret"); v != "" {
			key = v
		}
		return baseURL, key
	}

	// filePointIDs lists the point IDs recorded in MySQL for one file of a
	// collection (first page, up to 100000 rows).
	filePointIDs := func(collectName, fileID string) []string {
		ids := make([]string, 0)
		rows := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, fileID)
		for _, row := range rows {
			ids = append(ids, row.PointsId)
		}
		return ids
	}

	// page renders the named template with the collectName path parameter.
	page := func(tpl string) gin.HandlerFunc {
		return func(c *gin.Context) {
			c.HTML(http.StatusOK, tpl, gin.H{
				"collectName": c.Param("collectName"),
			})
		}
	}

	// Create the Gin engine.
	router := gin.Default()

	// Enable the CORS middleware; preflight requests are answered with 204.
	router.Use(func(c *gin.Context) {
		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,DELETE,PUT")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Origin")
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	})

	router.Static("/pannel", "pannel/")
	router.Static("/static", "pannel/static/")
	router.LoadHTMLGlob("templates/*")
	router.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "welcome.html", gin.H{})
	})

	// Home page.
	router.GET("/:collectName/index", page("index.html"))
	// Admin page.
	router.GET("/:collectName/qdrant", page("qdrant.html"))
	// Chat page.
	router.GET("/:collectName/chat", page("chat.html"))

	// Create a collection unless the lookup already reports status "ok".
	router.POST("/collect/:collectName", func(c *gin.Context) {
		collectName := c.Param("collectName")
		// The lookup error is ignored on purpose: only the reported status
		// decides whether the collection has to be created.
		collectInfo, _ := utils.GetCollection(collectName)
		if gjson.Get(string(collectInfo), "status").String() == "ok" {
			c.Writer.Write(collectInfo)
			return
		}
		res, err := utils.PutCollection(collectName)
		respond(c, []byte(res), err)
	})

	// Delete a collection.
	router.GET("/delCollect", func(c *gin.Context) {
		info, err := utils.DeleteCollection(c.Query("collectName"))
		respond(c, []byte(info), err)
	})

	// Collection details.
	router.GET("/:collectName/info", func(c *gin.Context) {
		list, err := utils.GetCollection(c.Param("collectName"))
		respond(c, list, err)
	})

	// List collections.
	router.GET("/collects", func(c *gin.Context) {
		list, err := utils.GetCollections()
		respond(c, list, err)
	})

	// List points (vectors) of a collection.
	router.GET("/:collectName/points", func(c *gin.Context) {
		list, err := utils.GetPoints(c.Param("collectName"), 1000, nil)
		respond(c, list, err)
	})

	// List the points that belong to one uploaded file.
	router.GET("/:collectName/filePoints", func(c *gin.Context) {
		collectName := c.Param("collectName")
		list, err := utils.GetPointsByIds(collectName, filePointIDs(collectName, c.Query("id")))
		respond(c, list, err)
	})

	// Delete a single point.
	router.GET("/:collectName/delPoints", func(c *gin.Context) {
		collectName := c.Param("collectName")
		id := c.Query("id")
		// An ID that parses as an integer is sent as a number, anything else
		// (e.g. a UUID) as a string.
		var points interface{}
		if n, err := strconv.Atoi(id); err != nil {
			points = []string{id}
		} else {
			points = []int{n}
		}
		list, err := utils.DeletePoints(collectName, points)
		respond(c, []byte(list), err)
	})

	// Delete a file together with its points.
	router.GET("/:collectName/delFile", func(c *gin.Context) {
		collectName := c.Param("collectName")
		id := c.Query("id")
		list, err := utils.DeletePoints(collectName, filePointIDs(collectName, id))
		// NOTE(review): the MySQL records are removed even when DeletePoints
		// failed, which keeps a file with no stored points deletable —
		// confirm this ordering is intended.
		models.DelFile("collect_name = ? and id = ?", collectName, id)
		models.DelFilePoints("collect_name = ? and file_id = ?", collectName, id)
		respond(c, []byte(list), err)
	})

	// Streaming search; the backend is chosen by CHAT_MODEL.
	if chatModel == "CHATGLM-6B" {
		router.Any("/:collectName/searchStream", controller.PostChatGLM)
	} else {
		router.Any("/:collectName/searchStream", controller.PostChatGPT)
	}

	// Train a single piece of content into an existing collection.
	router.POST("/:collectName/training", func(c *gin.Context) {
		openaiURL, openaiKey := openaiConfig(c)
		collectName := c.Param("collectName")

		// The collection must already exist; it is not created implicitly.
		collectInfo, _ := utils.GetCollection(collectName)
		if gjson.Get(string(collectInfo), "status").String() != "ok" {
			c.Writer.Write([]byte("集合不存在"))
			return
		}

		// Use the supplied ID as an integer when it parses as one, otherwise
		// as a string; generate a UUID when none was sent.
		var pointId interface{}
		if id := c.PostForm("id"); id == "" {
			pointId = uuid.NewV4().String()
		} else if n, err := strconv.Atoi(id); err != nil {
			pointId = id
		} else {
			pointId = n
		}

		// Embed and store the content.
		res, err := Train(openaiURL, openaiKey, pointId, collectName, c.PostForm("content"), "")
		if err != nil {
			log.Println(err)
		}
		respond(c, []byte(res), err)
	})

	// Fetch a URL, strip its HTML and train the text in 200-rune chunks.
	router.POST("/:collectName/uploadUrl", func(c *gin.Context) {
		openaiURL, openaiKey := openaiConfig(c)
		collectName := c.Param("collectName")
		pageURL := c.PostForm("url")
		if pageURL == "" || collectName == "" {
			badRequest(c, "参数错误!")
			return
		}
		resp := utils.HttpGet(pageURL)
		if resp == "" {
			badRequest(c, "请求数据错误!")
			return
		}
		htmlContent := utils.TrimHtml(resp)

		// Record the source in MySQL.
		files := &models.AiFile{
			FileName:    pageURL,
			CollectName: collectName,
		}
		fileId := files.AddAiFile()

		for _, chunk := range SplitTextByLength(htmlContent, 200) {
			pointId := uuid.NewV4().String()
			// Only record points that were actually stored, so the file's
			// point list never references vectors that do not exist.
			if _, trainErr := Train(openaiURL, openaiKey, pointId, collectName, chunk, pageURL); trainErr != nil {
				log.Println(trainErr)
				continue
			}
			aiFilePoint := &models.AiFilePoints{
				FileId:      fmt.Sprintf("%d", fileId),
				CollectName: collectName,
				PointsId:    pointId,
			}
			aiFilePoint.AddAiFilePoint()
		}
		c.JSON(200, gin.H{
			"code": 200,
		})
	})

	// Upload a txt/pdf/docx/xlsx document and train its contents.
	router.POST("/:collectName/uploadDoc", func(c *gin.Context) {
		openaiKey := os.Getenv("OPENAI_KEY")
		openaiUrl := os.Getenv("OPENAI_API_BASE")
		collectName := c.Param("collectName")

		f, err := c.FormFile("file")
		if err != nil {
			badRequest(c, "上传失败!"+err.Error())
			return
		}
		fileExt := strings.ToLower(path.Ext(f.Filename))
		switch fileExt {
		case ".docx", ".txt", ".xlsx", ".pdf":
		default:
			badRequest(c, "上传失败!只允许txt、pdf、docx或xlsx文件")
			return
		}

		// Stage the upload in the working directory. Only the base name of
		// the client-supplied filename is used so it cannot point elsewhere.
		fileName := collectName + path.Base(f.Filename)
		if saveErr := c.SaveUploadedFile(f, fileName); saveErr != nil {
			badRequest(c, "上传失败!"+saveErr.Error())
			return
		}

		// Text formats yield one string that is chunked later; spreadsheets
		// arrive already split into chunks.
		var (
			text   string
			chunks []string
		)
		switch fileExt {
		case ".txt":
			var raw []byte
			raw, err = ioutil.ReadFile(fileName)
			text = string(raw)
		case ".docx":
			text, err = utils.ReadDocxAll(fileName)
		case ".pdf":
			text, err = utils.ReadPdfAll(fileName)
		default:
			chunks, err = utils.ReadExcelAll(fileName)
		}
		if err != nil {
			log.Println(err)
		}

		// The staged file is no longer needed once its contents are read.
		if removeErr := os.Remove(fileName); removeErr != nil {
			log.Println("Remove error:", fileName, removeErr)
		}

		if text == "" && len(chunks) == 0 {
			err = errors.New("上传失败!读取数据为空")
		}
		if err != nil {
			badRequest(c, err.Error())
			return
		}

		// Record the file in MySQL.
		files := &models.AiFile{
			FileName:    f.Filename,
			CollectName: collectName,
		}
		fileId := files.AddAiFile()

		if fileExt != ".xlsx" {
			chunks = SplitTextByLength(text, 200)
		}
		for _, chunk := range chunks {
			pointId := uuid.NewV4().String()
			log.Println("上传数据:" + chunk)
			if _, trainErr := Train(openaiUrl, openaiKey, pointId, collectName, chunk, f.Filename); trainErr != nil {
				log.Println(trainErr)
				continue
			}
			aiFilePoint := &models.AiFilePoints{
				FileId:      fmt.Sprintf("%d", fileId),
				CollectName: collectName,
				PointsId:    pointId,
			}
			aiFilePoint.AddAiFilePoint()
		}
		c.JSON(200, gin.H{
			"code": 200,
		})
	})

	// List the files recorded for a collection.
	router.GET("/:collectName/fileList", func(c *gin.Context) {
		list := models.FindFileList(1, 1000, "collect_name = ? ", c.Param("collectName"))
		c.JSON(200, gin.H{
			"code":   200,
			"result": list,
		})
	})

	// Start the server.
	if err := router.Run(":8083"); err != nil {
		panic(err)
	}
}
// Train embeds content with the "text-embedding-ada-002" model and stores the
// resulting vector as a point in the collectName collection.
//
// content is trimmed of surrounding whitespace before embedding. The trimmed
// text and url are saved as the point payload under the "text" and "url" keys,
// and pointId is used as the point's ID unchanged. openaiUrl and openaiKey are
// passed to utils.NewChatGptTool.
//
// It returns the result of utils.PutPoints. An error is returned when the
// trimmed content is empty, the embedding request fails, the embedding
// response cannot be decoded, or the response carries no embedding data (the
// error text is then the raw response).
func Train(openaiUrl string, openaiKey string, pointId interface{}, collectName, content, url string) (string, error) {
	text := strings.TrimSpace(content)
	if text == "" {
		// Nothing to embed; avoid a pointless round trip to the API.
		return "", errors.New("content is empty")
	}

	gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
	response, err := gpt.GetEmbedding(text, "text-embedding-ada-002")
	if err != nil {
		return "", err
	}

	var embeddingResponse utils.EmbeddingResponse
	if err := json.Unmarshal([]byte(response), &embeddingResponse); err != nil {
		return "", fmt.Errorf("decoding embedding response %q: %w", response, err)
	}
	if len(embeddingResponse.Data) == 0 {
		// Surface the raw response as the error text when it has no data.
		return "", errors.New(response)
	}

	points := []map[string]interface{}{
		{
			"id":      pointId,
			"payload": map[string]interface{}{"text": text, "url": url},
			"vector":  embeddingResponse.Data[0].Embedding,
		},
	}
	return utils.PutPoints(collectName, points)
}
// SplitTextByLength splits text into consecutive blocks of at most length
// runes each; only the final block may be shorter.
//
// Splitting is done on runes, not bytes, so multi-byte characters are never
// cut in half. Empty text yields nil. A non-positive length cannot define a
// block size, so the text is returned as a single block instead of looping
// forever or slicing out of range.
func SplitTextByLength(text string, length int) []string {
	runes := []rune(text)
	if len(runes) == 0 {
		return nil
	}
	if length <= 0 {
		return []string{string(runes)}
	}

	blocks := make([]string, 0, len(runes)/length+1)
	for start := 0; start < len(runes); start += length {
		end := start + length
		if end > len(runes) {
			end = len(runes)
		}
		blocks = append(blocks, string(runes[start:end]))
	}
	return blocks
}