196 lines
4.5 KiB
Go
196 lines
4.5 KiB
Go
|
package controller
|
|||
|
|
|||
|
import (
|
|||
|
"fmt"
|
|||
|
"github.com/gin-gonic/gin"
|
|||
|
"github.com/tealeg/xlsx"
|
|||
|
"github.com/tidwall/gjson"
|
|||
|
"io/ioutil"
|
|||
|
"kefu/lib"
|
|||
|
"kefu/models"
|
|||
|
"kefu/tools"
|
|||
|
"kefu/types"
|
|||
|
"log"
|
|||
|
"net/url"
|
|||
|
"os"
|
|||
|
"path"
|
|||
|
"strings"
|
|||
|
)
|
|||
|
|
|||
|
type QdrantCollect struct {
|
|||
|
Id uint `form:"id" json:"id" uri:"id" xml:"id"`
|
|||
|
CateName string `form:"cate_name" json:"cate_name" uri:"cate_name" xml:"cate_name" binding:"required"`
|
|||
|
}
|
|||
|
|
|||
|
// 集合列表
|
|||
|
func GetCollects(c *gin.Context) {
|
|||
|
api := models.FindConfig("BaseGPTKnowledge")
|
|||
|
if api == "" {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": types.ApiCode.FAILED,
|
|||
|
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
ret, err := lib.GetCollections()
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
//c.Writer.Write(list)
|
|||
|
//info := tools.Get(fmt.Sprintf("%s/collects", api))
|
|||
|
list := gjson.Get(string(ret), "result.collections").String()
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": types.ApiCode.SUCCESS,
|
|||
|
"msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
|||
|
"result": list,
|
|||
|
})
|
|||
|
}
|
|||
|
|
|||
|
// 编辑qdrant
|
|||
|
func PostQdrantCollect(c *gin.Context) {
|
|||
|
var form QdrantCollect
|
|||
|
err := c.Bind(&form)
|
|||
|
//api := models.FindConfig("BaseGPTKnowledge")
|
|||
|
if err != nil {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": types.ApiCode.FAILED,
|
|||
|
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
|||
|
"result": err.Error(),
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
collectName := form.CateName
|
|||
|
//判断集合是否存在
|
|||
|
collectInfo, _ := lib.GetCollection(collectName)
|
|||
|
collectInfoStatus := gjson.Get(string(collectInfo), "status").String()
|
|||
|
if collectInfoStatus != "ok" {
|
|||
|
res, err := lib.PutCollection(collectName)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write([]byte(res))
|
|||
|
} else {
|
|||
|
c.Writer.Write(collectInfo)
|
|||
|
}
|
|||
|
//tools.PostForm(fmt.Sprintf("%s/collect/%s", api, form.CateName), nil)
|
|||
|
//c.JSON(200, gin.H{
|
|||
|
// "code": types.ApiCode.SUCCESS,
|
|||
|
// "msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
|||
|
//})
|
|||
|
}
|
|||
|
|
|||
|
// 删除集合
|
|||
|
func GetDelCollect(c *gin.Context) {
|
|||
|
collectName := c.Query("collectName")
|
|||
|
api := models.FindConfig("BaseGPTKnowledge")
|
|||
|
if api == "" || collectName == "" {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": types.ApiCode.FAILED,
|
|||
|
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
info, _ := lib.DeleteCollection(collectName)
|
|||
|
c.Writer.Write([]byte(info))
|
|||
|
|
|||
|
//c.JSON(200, gin.H{
|
|||
|
// "code": types.ApiCode.SUCCESS,
|
|||
|
// "msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
|||
|
// "result": info,
|
|||
|
//})
|
|||
|
}
|
|||
|
|
|||
|
// 上传文档
|
|||
|
func PostUploadCollect(c *gin.Context) {
|
|||
|
collectName := c.Query("collect")
|
|||
|
api := models.FindConfig("BaseGPTKnowledge")
|
|||
|
if api == "" || collectName == "" {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": types.ApiCode.FAILED,
|
|||
|
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
f, err := c.FormFile("file")
|
|||
|
if err != nil {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": "上传失败!" + err.Error(),
|
|||
|
})
|
|||
|
return
|
|||
|
} else {
|
|||
|
|
|||
|
fileExt := strings.ToLower(path.Ext(f.Filename))
|
|||
|
if fileExt != ".docx" && fileExt != ".txt" && fileExt != ".xlsx" {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": "上传失败!只允许txt或xlsx文件",
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
fileName := collectName + f.Filename
|
|||
|
c.SaveUploadedFile(f, fileName)
|
|||
|
texts := make([]string, 0)
|
|||
|
if fileExt == ".txt" {
|
|||
|
// 打开txt文件
|
|||
|
file, _ := os.Open(fileName)
|
|||
|
// 一次性读取整个txt文件的内容
|
|||
|
txt, _ := ioutil.ReadAll(file)
|
|||
|
texts = append(texts, string(txt))
|
|||
|
file.Close()
|
|||
|
} else if fileExt == ".xlsx" {
|
|||
|
// 打开 Excel 文件
|
|||
|
xlFile, err := xlsx.OpenFile(fileName)
|
|||
|
if err != nil {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": "读取excel失败:" + err.Error(),
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
// 遍历每个 Sheet
|
|||
|
for _, sheet := range xlFile.Sheets {
|
|||
|
// 遍历每行数据
|
|||
|
for _, row := range sheet.Rows {
|
|||
|
line := ""
|
|||
|
// 遍历每个单元格
|
|||
|
for _, cell := range row.Cells {
|
|||
|
// 输出单元格的值
|
|||
|
line += cell.Value
|
|||
|
}
|
|||
|
if line == "" {
|
|||
|
continue
|
|||
|
}
|
|||
|
texts = append(texts, line)
|
|||
|
}
|
|||
|
|
|||
|
}
|
|||
|
}
|
|||
|
err := os.Remove(fileName)
|
|||
|
if err != nil {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": err.Error(),
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
os.Remove(fileName)
|
|||
|
for _, text := range texts {
|
|||
|
log.Println(text)
|
|||
|
data := url.Values{}
|
|||
|
data.Set("content", text)
|
|||
|
_, err := tools.PostForm(fmt.Sprintf("%s/%s/training", api, collectName), data)
|
|||
|
if err != nil {
|
|||
|
log.Println(err)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": types.ApiCode.SUCCESS,
|
|||
|
"msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
|||
|
})
|
|||
|
}
|
|||
|
}
|