kefu/knowledge/main.go

430 lines
11 KiB
Go
Raw Normal View History

2024-12-10 02:50:12 +00:00
package main
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/joho/godotenv"
uuid "github.com/satori/go.uuid"
"github.com/tidwall/gjson"
"io/ioutil"
"knowledge/controller"
"knowledge/models"
"knowledge/utils"
"log"
"net/http"
"os"
"path"
"strconv"
"strings"
)
// main wires up and runs the knowledge-base HTTP service.
//
// It loads settings from the .env file, connects to MySQL, registers the page,
// collection, point, file and training routes on a Gin engine and listens on
// :8083. The process exits if .env cannot be loaded and panics if the server
// fails to start.
func main() {
	if err := godotenv.Load(".env"); err != nil {
		log.Fatalf("Error loading .env file: %v", err)
	}

	// Read environment variables.
	utils.QdrantBase = os.Getenv("QDRANT_BASE")
	utils.QdrantPort = os.Getenv("QDRANT_PORT")
	mysqlServer := os.Getenv("MYSQL_SERVER")
	mysqlPort := os.Getenv("MYSQL_PORT")
	mysqlDbname := os.Getenv("MYSQL_DATABASE")
	mysqlUsername := os.Getenv("MYSQL_USERNAME")
	mysqlPassword := os.Getenv("MYSQL_PASSWORD")
	chatModel := os.Getenv("CHAT_MODEL")

	// Initialize MySQL.
	models.Connect(mysqlServer, mysqlPort, mysqlDbname, mysqlUsername, mysqlPassword)

	// openAIConfig returns the embedding endpoint and key for a request: the
	// environment defaults, each overridden by the optional "gptUrl" /
	// "gptSecret" form fields when they are non-empty.
	openAIConfig := func(c *gin.Context) (apiBase, apiKey string) {
		apiKey = os.Getenv("OPENAI_KEY")
		apiBase = os.Getenv("OPENAI_API_BASE")
		if gptUrl := c.PostForm("gptUrl"); gptUrl != "" {
			apiBase = gptUrl
		}
		if gptSecret := c.PostForm("gptSecret"); gptSecret != "" {
			apiKey = gptSecret
		}
		return apiBase, apiKey
	}

	// renderPage builds a handler that renders tmpl with the collection name
	// taken from the route.
	renderPage := func(tmpl string) gin.HandlerFunc {
		return func(c *gin.Context) {
			c.HTML(http.StatusOK, tmpl, gin.H{
				"collectName": c.Param("collectName"),
			})
		}
	}

	// Create the Gin engine.
	router := gin.Default()

	// CORS middleware: allow any origin and answer preflight requests directly.
	router.Use(func(c *gin.Context) {
		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,DELETE,PUT")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Origin")
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	})

	router.Static("/pannel", "pannel/")
	router.Static("/static", "pannel/static/")
	router.LoadHTMLGlob("templates/*")
	router.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "welcome.html", gin.H{})
	})

	// Home page, admin page and chat page of a collection.
	router.GET("/:collectName/index", renderPage("index.html"))
	router.GET("/:collectName/qdrant", renderPage("qdrant.html"))
	router.GET("/:collectName/chat", renderPage("chat.html"))

	// Create a collection (or return the existing one).
	router.POST("/collect/:collectName", func(c *gin.Context) {
		collectName := c.Param("collectName")
		// The lookup error is intentionally ignored: any answer whose status
		// is not "ok" is treated as "collection does not exist yet".
		collectInfo, _ := utils.GetCollection(collectName)
		if gjson.Get(string(collectInfo), "status").String() == "ok" {
			c.Writer.Write(collectInfo)
			return
		}
		res, err := utils.PutCollection(collectName)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(res))
	})

	// Delete a collection.
	router.GET("/delCollect", func(c *gin.Context) {
		collectName := c.Query("collectName")
		info, err := utils.DeleteCollection(collectName)
		if err != nil {
			// Report the failure the same way the other handlers do instead
			// of silently returning whatever body came back.
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(info))
	})

	// Query a single collection.
	router.GET("/:collectName/info", func(c *gin.Context) {
		collectName := c.Param("collectName")
		list, err := utils.GetCollection(collectName)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// List all collections.
	router.GET("/collects", func(c *gin.Context) {
		list, err := utils.GetCollections()
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// List the points of a collection (requests up to 1000).
	router.GET("/:collectName/points", func(c *gin.Context) {
		collectName := c.Param("collectName")
		list, err := utils.GetPoints(collectName, 1000, nil)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// List the points that belong to one uploaded file.
	router.GET("/:collectName/filePoints", func(c *gin.Context) {
		collectName := c.Param("collectName")
		id := c.Query("id")
		filePoints := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, id)
		// Kept non-nil so an empty result is still passed on as an empty list.
		points := make([]string, 0, len(filePoints))
		for _, item := range filePoints {
			points = append(points, item.PointsId)
		}
		list, err := utils.GetPointsByIds(collectName, points)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// Delete a single point; the id is sent as an int when it parses as one,
	// otherwise as a string.
	router.GET("/:collectName/delPoints", func(c *gin.Context) {
		collectName := c.Param("collectName")
		id := c.Query("id")
		var points interface{}
		if n, convErr := strconv.Atoi(id); convErr == nil {
			points = []int{n}
		} else {
			points = []string{id}
		}
		list, err := utils.DeletePoints(collectName, points)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(list))
	})

	// Delete an uploaded file together with all of its points.
	router.GET("/:collectName/delFile", func(c *gin.Context) {
		collectName := c.Param("collectName")
		id := c.Query("id")
		filePoints := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, id)
		points := make([]string, 0, len(filePoints))
		for _, item := range filePoints {
			points = append(points, item.PointsId)
		}
		list, err := utils.DeletePoints(collectName, points)
		// The database rows are removed even when the vector deletion failed;
		// the failure is still reported to the caller below.
		models.DelFile("collect_name = ? and id = ?", collectName, id)
		models.DelFilePoints("collect_name = ? and file_id = ?", collectName, id)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(list))
	})

	// Streaming search; the backing chat model is chosen by CHAT_MODEL.
	if chatModel == "CHATGLM-6B" {
		router.Any("/:collectName/searchStream", controller.PostChatGLM)
	} else {
		router.Any("/:collectName/searchStream", controller.PostChatGPT)
	}

	// Train: embed one piece of text and store it as a point.
	router.POST("/:collectName/training", func(c *gin.Context) {
		openaiUrl, openaiKey := openAIConfig(c)
		collectName := c.Param("collectName")

		// The collection must already exist; a non-"ok" status covers lookup
		// errors as well, so the error value itself is not needed.
		collectInfo, _ := utils.GetCollection(collectName)
		if gjson.Get(string(collectInfo), "status").String() != "ok" {
			c.Writer.Write([]byte("集合不存在"))
			return
		}

		// Point id: generated when absent, numeric when it parses as an int,
		// otherwise used verbatim as a string.
		id := c.PostForm("id")
		var pointId interface{}
		if id == "" {
			pointId = uuid.NewV4().String()
		} else if n, convErr := strconv.Atoi(id); convErr == nil {
			pointId = n
		} else {
			pointId = id
		}

		// Vectorize and store the content.
		content := c.PostForm("content")
		res, err := Train(openaiUrl, openaiKey, pointId, collectName, content, "")
		if err != nil {
			// Surface the failure to the client instead of an empty body.
			log.Println(err)
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(res))
	})

	// Upload a URL: fetch the page, strip the HTML, then train on 200-rune chunks.
	router.POST("/:collectName/uploadUrl", func(c *gin.Context) {
		openaiUrl, openaiKey := openAIConfig(c)
		collectName := c.Param("collectName")
		url := c.PostForm("url")
		if url == "" || collectName == "" {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "参数错误!",
			})
			return
		}
		// NOTE(review): url is caller-supplied and fetched server-side without
		// any allow-list here — confirm utils.HttpGet guards against SSRF.
		resp := utils.HttpGet(url)
		if resp == "" {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "请求数据错误!",
			})
			return
		}
		htmlContent := utils.TrimHtml(resp)

		// Persist the file record.
		files := &models.AiFile{
			FileName:    url,
			CollectName: collectName,
		}
		fileIdStr := fmt.Sprintf("%d", files.AddAiFile())

		for _, chunk := range SplitTextByLength(htmlContent, 200) {
			pointId := uuid.NewV4().String()
			if _, err := Train(openaiUrl, openaiKey, pointId, collectName, chunk, url); err != nil {
				// Do not record a point that was never stored.
				log.Println(err)
				continue
			}
			aiFilePoint := &models.AiFilePoints{
				FileId:      fileIdStr,
				CollectName: collectName,
				PointsId:    pointId,
			}
			aiFilePoint.AddAiFilePoint()
		}
		c.JSON(200, gin.H{
			"code": 200,
		})
	})

	// Upload a document (txt, pdf, docx or xlsx) and train on its content.
	router.POST("/:collectName/uploadDoc", func(c *gin.Context) {
		openaiKey := os.Getenv("OPENAI_KEY")
		openaiUrl := os.Getenv("OPENAI_API_BASE")
		collectName := c.Param("collectName")
		f, err := c.FormFile("file")
		if err != nil {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "上传失败!" + err.Error(),
			})
			return
		}

		fileExt := strings.ToLower(path.Ext(f.Filename))
		switch fileExt {
		case ".docx", ".txt", ".xlsx", ".pdf":
		default:
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "上传失败!只允许txt、pdf、docx或xlsx文件",
			})
			return
		}

		// Store the upload temporarily in the working directory.
		fileName := collectName + f.Filename
		if saveErr := c.SaveUploadedFile(f, fileName); saveErr != nil {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "上传失败!" + saveErr.Error(),
			})
			return
		}

		var (
			text   string
			chunks []string
		)
		switch fileExt {
		case ".txt":
			// Read the whole text file at once.
			var raw []byte
			raw, err = ioutil.ReadFile(fileName)
			text = string(raw)
		case ".docx":
			text, err = utils.ReadDocxAll(fileName)
		case ".pdf":
			text, err = utils.ReadPdfAll(fileName)
		default:
			chunks, err = utils.ReadExcelAll(fileName)
		}
		if err != nil {
			log.Println(err)
		}

		// The temporary file is no longer needed once its content is read.
		if removeErr := os.Remove(fileName); removeErr != nil {
			log.Println("Remove error:", fileName, removeErr)
		}

		// An empty result takes precedence over any read error in the reply.
		if text == "" && len(chunks) == 0 {
			err = errors.New("上传失败!读取数据为空")
		}
		if err != nil {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  err.Error(),
			})
			return
		}

		// Persist the file record.
		files := &models.AiFile{
			FileName:    f.Filename,
			CollectName: collectName,
		}
		fileIdStr := fmt.Sprintf("%d", files.AddAiFile())

		// Spreadsheets arrive pre-chunked; everything else is split by length.
		if fileExt != ".xlsx" {
			chunks = SplitTextByLength(text, 200)
		}
		for _, chunk := range chunks {
			pointId := uuid.NewV4().String()
			log.Println("上传数据:" + chunk)
			if _, trainErr := Train(openaiUrl, openaiKey, pointId, collectName, chunk, f.Filename); trainErr != nil {
				log.Println(trainErr)
				continue
			}
			aiFilePoint := &models.AiFilePoints{
				FileId:      fileIdStr,
				CollectName: collectName,
				PointsId:    pointId,
			}
			aiFilePoint.AddAiFilePoint()
		}
		c.JSON(200, gin.H{
			"code": 200,
		})
	})

	// List the files of a collection (first page, up to 1000 rows).
	router.GET("/:collectName/fileList", func(c *gin.Context) {
		collectName := c.Param("collectName")
		list := models.FindFileList(1, 1000, "collect_name = ? ", collectName)
		c.JSON(200, gin.H{
			"code":   200,
			"result": list,
		})
	})

	// Start the server.
	if err := router.Run(":8083"); err != nil {
		panic(err)
	}
}
// Train embeds content and stores it as one point in a collection.
//
// openaiUrl and openaiKey configure the embedding client. pointId is the
// identifier of the point (callers in this file pass an int or a string).
// collectName is the target collection, content is the text to embed
// (surrounding whitespace is trimmed first) and url is stored next to the
// text in the point payload.
//
// It returns the result of utils.PutPoints. An error is returned when content
// is blank, when the embedding request fails, or when the embedding response
// contains no data; in the last case the raw response is part of the error.
func Train(openaiUrl string, openaiKey string, pointId interface{}, collectName, content, url string) (string, error) {
	str := strings.TrimSpace(content)
	if str == "" {
		// Nothing to embed; avoid a pointless remote call.
		return "", errors.New("content is empty")
	}

	gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
	response, err := gpt.GetEmbedding(str, "text-embedding-ada-002")
	if err != nil {
		return "", err
	}

	var embeddingResponse utils.EmbeddingResponse
	// A decode error is only fatal when no embedding was recovered:
	// json.Unmarshal keeps filling the fields it can, so a partially
	// mismatching response may still carry usable data.
	decodeErr := json.Unmarshal([]byte(response), &embeddingResponse)
	if len(embeddingResponse.Data) == 0 {
		if decodeErr != nil {
			return "", fmt.Errorf("decode embedding response: %v: %s", decodeErr, response)
		}
		return "", errors.New(response)
	}

	points := []map[string]interface{}{
		{
			"id":      pointId,
			"payload": map[string]interface{}{"text": str, "url": url},
			"vector":  embeddingResponse.Data[0].Embedding,
		},
	}
	return utils.PutPoints(collectName, points)
}
// SplitTextByLength splits text into consecutive chunks of at most length
// runes each, so multi-byte characters are never cut in half.
//
// An empty text yields nil. A non-positive length means "do not split" and
// yields the whole text as a single chunk instead of looping forever or
// panicking.
func SplitTextByLength(text string, length int) []string {
	runes := []rune(text)
	if len(runes) == 0 {
		return nil
	}
	if length <= 0 {
		return []string{string(runes)}
	}

	blocks := make([]string, 0, len(runes)/length+1)
	// Consume the input from the front; comparing against the remaining
	// length avoids any index arithmetic that could overflow for huge lengths.
	for len(runes) > length {
		blocks = append(blocks, string(runes[:length]))
		runes = runes[length:]
	}
	return append(blocks, string(runes))
}