kefu/controller/qdrant.go

196 lines
4.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/tealeg/xlsx"
"github.com/tidwall/gjson"
"io/ioutil"
"kefu/lib"
"kefu/models"
"kefu/tools"
"kefu/types"
"log"
"net/url"
"os"
"path"
"strings"
)
// QdrantCollect is the request payload bound by PostQdrantCollect when
// creating (or looking up) a collection.
type QdrantCollect struct {
	// Id is bound from the "id" field; no handler in this file reads it.
	Id uint `form:"id" json:"id" uri:"id" xml:"id"`
	// CateName is the collection name. The `binding:"required"` tag makes
	// binding fail when it is missing or empty.
	CateName string `form:"cate_name" json:"cate_name" uri:"cate_name" xml:"cate_name" binding:"required"`
}
// GetCollects lists all collections (集合列表).
//
// It responds with a JSON envelope {code, msg, result}. On success, result is
// the raw JSON text of the "result.collections" value in the
// lib.GetCollections response: a string, not a decoded array.
//
// When the "BaseGPTKnowledge" config is empty, or listing the collections
// fails, code is FAILED and no result is sent.
func GetCollects(c *gin.Context) {
	api := models.FindConfig("BaseGPTKnowledge")
	if api == "" {
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  types.ApiCode.GetMessage(types.ApiCode.INVALID),
		})
		return
	}
	ret, err := lib.GetCollections()
	if err != nil {
		// Report failures in the same JSON envelope as the other branches,
		// instead of a bare text body that a JSON-parsing client cannot read.
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  err.Error(),
		})
		return
	}
	list := gjson.Get(string(ret), "result.collections").String()
	c.JSON(200, gin.H{
		"code":   types.ApiCode.SUCCESS,
		"msg":    types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
		"result": list,
	})
}
// PostQdrantCollect ensures a collection exists, creating it if needed
// (编辑qdrant).
//
// The collection name comes from the bound QdrantCollect.CateName. The
// collection is looked up first:
//   - If the lookup's "status" field is "ok", the lookup response is written
//     to the client unchanged.
//   - Otherwise the collection is created via lib.PutCollection and that
//     response is written unchanged.
//
// A binding failure yields the usual JSON envelope with code FAILED. A
// creation failure writes the error text as the body.
func PostQdrantCollect(c *gin.Context) {
	var form QdrantCollect
	// ShouldBind rather than Bind: on failure Bind already aborts and writes a
	// 400 header, so the c.JSON(200, ...) below would only trigger gin's
	// "headers were already written" warning. The client would then get a 400
	// instead of the 200 + JSON error envelope this handler intends to send.
	if err := c.ShouldBind(&form); err != nil {
		c.JSON(200, gin.H{
			"code":   types.ApiCode.FAILED,
			"msg":    types.ApiCode.GetMessage(types.ApiCode.INVALID),
			"result": err.Error(),
		})
		return
	}
	collectName := form.CateName
	// The lookup error is ignored, so a failed lookup is treated like
	// "not found". Presumably intentional, since the create call below reports
	// real failures. NOTE(review): confirm.
	collectInfo, _ := lib.GetCollection(collectName)
	if gjson.Get(string(collectInfo), "status").String() == "ok" {
		c.Writer.Write(collectInfo)
		return
	}
	res, err := lib.PutCollection(collectName)
	if err != nil {
		c.Writer.Write([]byte(err.Error()))
		return
	}
	c.Writer.Write([]byte(res))
}
// GetDelCollect deletes the collection named by the "collectName" query
// parameter (删除集合).
//
// On success, the response from lib.DeleteCollection is written to the
// client unchanged. A JSON envelope with code FAILED is returned instead when:
//   - the "BaseGPTKnowledge" config is empty,
//   - the query parameter is missing, or
//   - the delete call returns an error.
func GetDelCollect(c *gin.Context) {
	collectName := c.Query("collectName")
	api := models.FindConfig("BaseGPTKnowledge")
	if api == "" || collectName == "" {
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  types.ApiCode.GetMessage(types.ApiCode.INVALID),
		})
		return
	}
	info, err := lib.DeleteCollection(collectName)
	if err != nil {
		// Previously the error was discarded, and the (possibly empty) result
		// was written as if the delete had succeeded.
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  err.Error(),
		})
		return
	}
	c.Writer.Write([]byte(info))
}
// PostUploadCollect uploads a document and trains it into a knowledge-base
// collection (上传文档).
//
// Inputs:
//   - Query parameter "collect": the target collection.
//   - Multipart field "file": the document.
//
// Only .txt and .xlsx files are accepted:
//   - .txt: the whole file becomes one text.
//   - .xlsx: every row of every sheet becomes one text. The row's cell values
//     are concatenated with no separator, and rows that end up empty are
//     skipped.
//
// Each text is POSTed as form field "content" to
// "<BaseGPTKnowledge>/<collect>/training". Failures of individual training
// requests are only logged; the handler still responds with SUCCESS.
func PostUploadCollect(c *gin.Context) {
	collectName := c.Query("collect")
	api := models.FindConfig("BaseGPTKnowledge")
	if api == "" || collectName == "" {
		c.JSON(200, gin.H{
			"code": types.ApiCode.FAILED,
			"msg":  types.ApiCode.GetMessage(types.ApiCode.INVALID),
		})
		return
	}
	f, err := c.FormFile("file")
	if err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}
	fileExt := strings.ToLower(path.Ext(f.Filename))
	// .docx used to pass this check but was never parsed, so such uploads
	// reported success without training anything. Accept only what is read below.
	if fileExt != ".txt" && fileExt != ".xlsx" {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!只允许txt或xlsx文件",
		})
		return
	}

	// Stage the upload in a uniquely named temp file. The path must not be
	// built from the untrusted "collect" value, which may contain "../", and a
	// unique name also prevents concurrent uploads of the same name from
	// colliding.
	tmp, err := ioutil.TempFile("", "qdrant-upload-*"+fileExt)
	if err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}
	fileName := tmp.Name()
	tmp.Close()
	// Clean up on every path, including the early error returns below.
	defer func() {
		if err := os.Remove(fileName); err != nil {
			log.Println(err)
		}
	}()
	if err := c.SaveUploadedFile(f, fileName); err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}

	var texts []string
	switch fileExt {
	case ".txt":
		content, err := ioutil.ReadFile(fileName)
		if err != nil {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "上传失败!" + err.Error(),
			})
			return
		}
		texts = append(texts, string(content))
	case ".xlsx":
		xlFile, err := xlsx.OpenFile(fileName)
		if err != nil {
			c.JSON(200, gin.H{
				"code": 400,
				"msg":  "读取excel失败" + err.Error(),
			})
			return
		}
		for _, sheet := range xlFile.Sheets {
			for _, row := range sheet.Rows {
				var line strings.Builder
				for _, cell := range row.Cells {
					line.WriteString(cell.Value)
				}
				if line.Len() == 0 {
					continue
				}
				texts = append(texts, line.String())
			}
		}
	}

	for _, text := range texts {
		log.Println(text)
		data := url.Values{}
		data.Set("content", text)
		if _, err := tools.PostForm(fmt.Sprintf("%s/%s/training", api, collectName), data); err != nil {
			log.Println(err)
		}
	}
	c.JSON(200, gin.H{
		"code": types.ApiCode.SUCCESS,
		"msg":  types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
	})
}