196 lines
4.5 KiB
Go
196 lines
4.5 KiB
Go
package controller
|
||
|
||
import (
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tealeg/xlsx"
|
||
"github.com/tidwall/gjson"
|
||
"io/ioutil"
|
||
"kefu/lib"
|
||
"kefu/models"
|
||
"kefu/tools"
|
||
"kefu/types"
|
||
"log"
|
||
"net/url"
|
||
"os"
|
||
"path"
|
||
"strings"
|
||
)
|
||
|
||
type QdrantCollect struct {
|
||
Id uint `form:"id" json:"id" uri:"id" xml:"id"`
|
||
CateName string `form:"cate_name" json:"cate_name" uri:"cate_name" xml:"cate_name" binding:"required"`
|
||
}
|
||
|
||
// 集合列表
|
||
func GetCollects(c *gin.Context) {
|
||
api := models.FindConfig("BaseGPTKnowledge")
|
||
if api == "" {
|
||
c.JSON(200, gin.H{
|
||
"code": types.ApiCode.FAILED,
|
||
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
||
})
|
||
return
|
||
}
|
||
ret, err := lib.GetCollections()
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
//c.Writer.Write(list)
|
||
//info := tools.Get(fmt.Sprintf("%s/collects", api))
|
||
list := gjson.Get(string(ret), "result.collections").String()
|
||
c.JSON(200, gin.H{
|
||
"code": types.ApiCode.SUCCESS,
|
||
"msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
||
"result": list,
|
||
})
|
||
}
|
||
|
||
// 编辑qdrant
|
||
func PostQdrantCollect(c *gin.Context) {
|
||
var form QdrantCollect
|
||
err := c.Bind(&form)
|
||
//api := models.FindConfig("BaseGPTKnowledge")
|
||
if err != nil {
|
||
c.JSON(200, gin.H{
|
||
"code": types.ApiCode.FAILED,
|
||
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
||
"result": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
collectName := form.CateName
|
||
//判断集合是否存在
|
||
collectInfo, _ := lib.GetCollection(collectName)
|
||
collectInfoStatus := gjson.Get(string(collectInfo), "status").String()
|
||
if collectInfoStatus != "ok" {
|
||
res, err := lib.PutCollection(collectName)
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write([]byte(res))
|
||
} else {
|
||
c.Writer.Write(collectInfo)
|
||
}
|
||
//tools.PostForm(fmt.Sprintf("%s/collect/%s", api, form.CateName), nil)
|
||
//c.JSON(200, gin.H{
|
||
// "code": types.ApiCode.SUCCESS,
|
||
// "msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
||
//})
|
||
}
|
||
|
||
// 删除集合
|
||
func GetDelCollect(c *gin.Context) {
|
||
collectName := c.Query("collectName")
|
||
api := models.FindConfig("BaseGPTKnowledge")
|
||
if api == "" || collectName == "" {
|
||
c.JSON(200, gin.H{
|
||
"code": types.ApiCode.FAILED,
|
||
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
||
})
|
||
return
|
||
}
|
||
info, _ := lib.DeleteCollection(collectName)
|
||
c.Writer.Write([]byte(info))
|
||
|
||
//c.JSON(200, gin.H{
|
||
// "code": types.ApiCode.SUCCESS,
|
||
// "msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
||
// "result": info,
|
||
//})
|
||
}
|
||
|
||
// 上传文档
|
||
func PostUploadCollect(c *gin.Context) {
|
||
collectName := c.Query("collect")
|
||
api := models.FindConfig("BaseGPTKnowledge")
|
||
if api == "" || collectName == "" {
|
||
c.JSON(200, gin.H{
|
||
"code": types.ApiCode.FAILED,
|
||
"msg": types.ApiCode.GetMessage(types.ApiCode.INVALID),
|
||
})
|
||
return
|
||
}
|
||
f, err := c.FormFile("file")
|
||
if err != nil {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": "上传失败!" + err.Error(),
|
||
})
|
||
return
|
||
} else {
|
||
|
||
fileExt := strings.ToLower(path.Ext(f.Filename))
|
||
if fileExt != ".docx" && fileExt != ".txt" && fileExt != ".xlsx" {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": "上传失败!只允许txt或xlsx文件",
|
||
})
|
||
return
|
||
}
|
||
fileName := collectName + f.Filename
|
||
c.SaveUploadedFile(f, fileName)
|
||
texts := make([]string, 0)
|
||
if fileExt == ".txt" {
|
||
// 打开txt文件
|
||
file, _ := os.Open(fileName)
|
||
// 一次性读取整个txt文件的内容
|
||
txt, _ := ioutil.ReadAll(file)
|
||
texts = append(texts, string(txt))
|
||
file.Close()
|
||
} else if fileExt == ".xlsx" {
|
||
// 打开 Excel 文件
|
||
xlFile, err := xlsx.OpenFile(fileName)
|
||
if err != nil {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": "读取excel失败:" + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
// 遍历每个 Sheet
|
||
for _, sheet := range xlFile.Sheets {
|
||
// 遍历每行数据
|
||
for _, row := range sheet.Rows {
|
||
line := ""
|
||
// 遍历每个单元格
|
||
for _, cell := range row.Cells {
|
||
// 输出单元格的值
|
||
line += cell.Value
|
||
}
|
||
if line == "" {
|
||
continue
|
||
}
|
||
texts = append(texts, line)
|
||
}
|
||
|
||
}
|
||
}
|
||
err := os.Remove(fileName)
|
||
if err != nil {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
os.Remove(fileName)
|
||
for _, text := range texts {
|
||
log.Println(text)
|
||
data := url.Values{}
|
||
data.Set("content", text)
|
||
_, err := tools.PostForm(fmt.Sprintf("%s/%s/training", api, collectName), data)
|
||
if err != nil {
|
||
log.Println(err)
|
||
}
|
||
}
|
||
|
||
c.JSON(200, gin.H{
|
||
"code": types.ApiCode.SUCCESS,
|
||
"msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
||
})
|
||
}
|
||
}
|