kefu/controller/qdrant.go

196 lines
4.5 KiB
Go
Raw Normal View History

2024-12-10 02:50:12 +00:00
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/tealeg/xlsx"
"github.com/tidwall/gjson"
"io/ioutil"
"kefu/lib"
"kefu/models"
"kefu/tools"
"kefu/types"
"log"
"net/url"
"os"
"path"
"strings"
)
// QdrantCollect is the request form bound by PostQdrantCollect (form, JSON,
// URI or XML field names as given in the tags).
type QdrantCollect struct {
// Id is bound from the "id" field; PostQdrantCollect does not read it.
Id uint `form:"id" json:"id" uri:"id" xml:"id"`
// CateName is the name of the collection to create; binding fails if it is absent.
CateName string `form:"cate_name" json:"cate_name" uri:"cate_name" xml:"cate_name" binding:"required"`
}
// GetCollects lists the Qdrant collections.
//
// When the "BaseGPTKnowledge" config entry is empty it answers with code
// FAILED and the INVALID message. Otherwise it calls lib.GetCollections and
// returns the "result.collections" value of that reply, serialized as a JSON
// string, in the "result" field. If lib.GetCollections fails, the raw error
// text is written as the response body instead of a JSON envelope.
func GetCollects(c *gin.Context) {
	if models.FindConfig("BaseGPTKnowledge") == "" {
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  types.ApiCode.GetMessage(types.ApiCode.INVALID),
		})
		return
	}

	raw, err := lib.GetCollections()
	if err != nil {
		c.Writer.Write([]byte(err.Error()))
		return
	}

	collections := gjson.Get(string(raw), "result.collections")
	c.JSON(200, gin.H{
		"code":   types.ApiCode.SUCCESS,
		"msg":    types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
		"result": collections.String(),
	})
}
// PostQdrantCollect creates the Qdrant collection named by the form field
// "cate_name" unless it already exists.
//
// Existence is decided by the "status" field of lib.GetCollection's reply:
// "ok" means the collection exists and that reply is echoed back unchanged.
// Any other reply triggers lib.PutCollection, whose reply (or error text) is
// written as the response body. A form that fails to bind gets a JSON
// FAILED/INVALID payload with the binding error in "result".
// NOTE(review): gin's Bind aborts with HTTP 400 on failure, so the 200 below
// is presumably ignored for the status line — confirm.
func PostQdrantCollect(c *gin.Context) {
	var form QdrantCollect
	if err := c.Bind(&form); err != nil {
		c.JSON(200, gin.H{
			"code":   types.ApiCode.FAILED,
			"msg":    types.ApiCode.GetMessage(types.ApiCode.INVALID),
			"result": err.Error(),
		})
		return
	}

	name := form.CateName
	// The lookup error is intentionally ignored: any reply whose status is not
	// "ok" (including an empty one) is treated as "collection missing".
	existing, _ := lib.GetCollection(name)
	if gjson.Get(string(existing), "status").String() == "ok" {
		c.Writer.Write(existing)
		return
	}

	created, err := lib.PutCollection(name)
	if err != nil {
		c.Writer.Write([]byte(err.Error()))
		return
	}
	c.Writer.Write([]byte(created))
}
// GetDelCollect deletes the Qdrant collection named by the "collectName"
// query parameter and writes lib.DeleteCollection's reply as the body.
//
// It answers with a JSON FAILED/INVALID payload when the parameter is missing
// or the "BaseGPTKnowledge" config entry is empty. If the deletion fails, the
// error text is written instead, as the other collection handlers do.
// Previously the error was discarded and whatever partial reply came back
// was sent as if the delete had succeeded.
func GetDelCollect(c *gin.Context) {
	collectName := c.Query("collectName")
	api := models.FindConfig("BaseGPTKnowledge")
	if api == "" || collectName == "" {
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  types.ApiCode.GetMessage(types.ApiCode.INVALID),
		})
		return
	}
	info, err := lib.DeleteCollection(collectName)
	if err != nil {
		c.Writer.Write([]byte(err.Error()))
		return
	}
	c.Writer.Write([]byte(info))
}
// PostUploadCollect imports an uploaded document into a collection's
// knowledge base.
//
// The query parameter "collect" names the target collection. The multipart
// field "file" carries the document, which must be .txt or .xlsx:
//   - .txt: the whole file content becomes a single text;
//   - .xlsx: every non-empty row of every sheet becomes one text, built from
//     the row's cell values concatenated without a separator.
//
// Each text is POSTed as form field "content" to
// "<BaseGPTKnowledge>/<collect>/training". Training failures are only logged,
// so once the file has been parsed the handler reports SUCCESS.
func PostUploadCollect(c *gin.Context) {
	collectName := c.Query("collect")
	api := models.FindConfig("BaseGPTKnowledge")
	if api == "" || collectName == "" {
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  types.ApiCode.GetMessage(types.ApiCode.INVALID),
		})
		return
	}
	f, err := c.FormFile("file")
	if err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}

	// Accept only the formats that are actually parsed below. Previously .docx
	// passed this check, was then silently dropped, and the request still
	// reported success.
	fileExt := strings.ToLower(path.Ext(f.Filename))
	if fileExt != ".txt" && fileExt != ".xlsx" {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!只允许txt或xlsx文件",
		})
		return
	}

	// Stage the upload in a fresh temp file rather than "<collect><filename>"
	// in the working directory. "collect" is caller-controlled (it could
	// contain "../"), and concurrent uploads with the same name would clobber
	// each other.
	tmp, err := ioutil.TempFile("", "qdrant-upload-*"+fileExt)
	if err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}
	fileName := tmp.Name()
	tmp.Close() // only the path is needed; SaveUploadedFile recreates the file
	// Remove the staged file on every exit path, including parse failures.
	// A removal failure is logged instead of aborting the import.
	defer func() {
		if err := os.Remove(fileName); err != nil {
			log.Println(err)
		}
	}()
	if err := c.SaveUploadedFile(f, fileName); err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}

	var texts []string
	switch fileExt {
	case ".txt":
		content, err := ioutil.ReadFile(fileName)
		if err != nil {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "上传失败!" + err.Error(),
			})
			return
		}
		texts = append(texts, string(content))
	case ".xlsx":
		xlFile, err := xlsx.OpenFile(fileName)
		if err != nil {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "读取excel失败" + err.Error(),
			})
			return
		}
		for _, sheet := range xlFile.Sheets {
			for _, row := range sheet.Rows {
				var line strings.Builder
				for _, cell := range row.Cells {
					line.WriteString(cell.Value)
				}
				if line.Len() == 0 {
					continue // skip blank rows
				}
				texts = append(texts, line.String())
			}
		}
	}

	// The collection name is escaped so that characters like "/" or "?" in
	// the caller-supplied value cannot change the target route.
	trainURL := fmt.Sprintf("%s/%s/training", api, url.PathEscape(collectName))
	for _, text := range texts {
		log.Println(text)
		data := url.Values{}
		data.Set("content", text)
		// Best effort: a failed chunk is logged and the remaining ones are still sent.
		if _, err := tools.PostForm(trainURL, data); err != nil {
			log.Println(err)
		}
	}
	c.JSON(200, gin.H{
		"code": types.ApiCode.SUCCESS,
		"msg":  types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
	})
}