196 lines
4.6 KiB
Go
196 lines
4.6 KiB
Go
package ai
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
uuid "github.com/satori/go.uuid"
|
|
"io/ioutil"
|
|
"kefu/common"
|
|
"kefu/models"
|
|
"kefu/tools"
|
|
"log"
|
|
"os"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
// 上传文档
|
|
func PostUploadDoc(c *gin.Context) {
|
|
collectNameInter, _ := c.Get("collect_name")
|
|
collectName := collectNameInter.(string)
|
|
openaiUrl, _ := c.Get("openai_url")
|
|
openaiKey, _ := c.Get("openai_key")
|
|
|
|
f, err := c.FormFile("file")
|
|
if err != nil {
|
|
c.JSON(200, gin.H{
|
|
"code": 400,
|
|
"msg": "上传失败!" + err.Error(),
|
|
})
|
|
return
|
|
} else {
|
|
maxSize := 15 * 1024 * 1024
|
|
uploadMaxSize := models.FindConfig("UploadMaxSize")
|
|
if uploadMaxSize != "" {
|
|
uploadMaxSizeInt, _ := strconv.Atoi(uploadMaxSize)
|
|
maxSize = uploadMaxSizeInt * 1024 * 1024
|
|
}
|
|
if f.Size >= int64(maxSize) {
|
|
c.JSON(200, gin.H{
|
|
"code": 400,
|
|
"msg": "上传失败,文件大小超限!",
|
|
})
|
|
return
|
|
}
|
|
|
|
fileExt := strings.ToLower(path.Ext(f.Filename))
|
|
if fileExt != ".docx" && fileExt != ".txt" && fileExt != ".xlsx" && fileExt != ".pdf" {
|
|
c.JSON(200, gin.H{
|
|
"code": 400,
|
|
"msg": "上传失败!只允许txt、pdf、docx或xlsx文件",
|
|
})
|
|
return
|
|
}
|
|
|
|
fileName := f.Filename
|
|
fildDir := fmt.Sprintf("%sai/%s/", common.Upload, collectName)
|
|
isExist, _ := tools.IsFileExist(fildDir)
|
|
if !isExist {
|
|
err := os.MkdirAll(fildDir, os.ModePerm)
|
|
if err != nil {
|
|
c.JSON(200, gin.H{
|
|
"code": 400,
|
|
"msg": "上传失败!" + err.Error(),
|
|
})
|
|
return
|
|
}
|
|
}
|
|
filepath := fmt.Sprintf("%s%s", fildDir, fileName)
|
|
c.SaveUploadedFile(f, filepath)
|
|
//path := "/" + filepath
|
|
//上传到阿里云oss
|
|
//oss, err := lib.NewOssLib()
|
|
//if err == nil {
|
|
// dstUrl, err := oss.Upload(filepath, filepath)
|
|
// if err == nil {
|
|
// path = dstUrl
|
|
// }
|
|
//}
|
|
|
|
text := ""
|
|
chunks := make([]string, 0)
|
|
if fileExt == ".txt" {
|
|
// 打开txt文件
|
|
file, _ := os.Open(filepath)
|
|
// 一次性读取整个txt文件的内容
|
|
txt, _ := ioutil.ReadAll(file)
|
|
text = string(txt)
|
|
file.Close()
|
|
} else if fileExt == ".docx" {
|
|
text, err = tools.ReadDocxAll(filepath)
|
|
} else if fileExt == ".pdf" {
|
|
text, err = tools.ReadPdfAll(filepath)
|
|
} else {
|
|
chunks, err = tools.ReadExcelAll(filepath)
|
|
}
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
//removeErr := os.Remove(filepath)
|
|
//if removeErr != nil {
|
|
// log.Println("Remove error:", filepath, removeErr)
|
|
//}
|
|
if text == "" && len(chunks) == 0 {
|
|
err = errors.New("上传失败!读取数据为空")
|
|
}
|
|
if err != nil {
|
|
c.JSON(200, gin.H{
|
|
"code": 400,
|
|
"msg": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
|
|
fileSize := tools.Int2Str(utf8.RuneCountInString(text))
|
|
if fileExt != ".xlsx" {
|
|
chunks = SplitTextByLength(text, 500)
|
|
} else {
|
|
num := 0
|
|
for _, chunk := range chunks {
|
|
num += utf8.RuneCountInString(chunk)
|
|
}
|
|
fileSize = tools.Int2Str(num)
|
|
}
|
|
//入库
|
|
files := &models.AiFile{
|
|
FileName: f.Filename,
|
|
CollectName: collectName,
|
|
FileSize: fileSize,
|
|
}
|
|
fileId := files.AddAiFile()
|
|
for _, chunk := range chunks {
|
|
pointId := uuid.NewV4().String()
|
|
log.Println("上传数据:" + chunk)
|
|
_, err := Train(openaiUrl.(string), openaiKey.(string), pointId, collectName, chunk, tools.Int2Str(fileId), f.Filename, "")
|
|
if err == nil {
|
|
aiFilePoint := &models.AiFilePoints{
|
|
FileId: fmt.Sprintf("%d", fileId),
|
|
CollectName: collectName,
|
|
PointsId: pointId,
|
|
}
|
|
aiFilePoint.AddAiFilePoint()
|
|
} else {
|
|
log.Println(err)
|
|
}
|
|
}
|
|
os.Remove(fileName)
|
|
c.JSON(200, gin.H{
|
|
"code": 200,
|
|
})
|
|
}
|
|
}
|
|
|
|
// 上传文档
|
|
func PostUploadUrl(c *gin.Context) {
|
|
collectNameInter, _ := c.Get("collect_name")
|
|
collectName := collectNameInter.(string)
|
|
openaiUrl, _ := c.Get("openai_url")
|
|
openaiKey, _ := c.Get("openai_key")
|
|
url := c.PostForm("url")
|
|
resp := tools.Get(url)
|
|
if resp == "" {
|
|
c.JSON(200, gin.H{
|
|
"code": 400,
|
|
"msg": "请求数据错误!",
|
|
})
|
|
return
|
|
}
|
|
htmlContent := tools.TrimHtml(resp)
|
|
//入库
|
|
files := &models.AiFile{
|
|
FileName: url,
|
|
CollectName: collectName,
|
|
FileSize: tools.Int2Str(utf8.RuneCountInString(htmlContent)),
|
|
}
|
|
fileId := files.AddAiFile()
|
|
chunks := SplitTextByLength(htmlContent, 200)
|
|
for _, chunk := range chunks {
|
|
pointId := uuid.NewV4().String()
|
|
log.Println("上传网页数据:" + chunk)
|
|
Train(openaiUrl.(string), openaiKey.(string), pointId, collectName, chunk, tools.Int2Str(fileId), "", url)
|
|
//入库
|
|
aiFilePoint := &models.AiFilePoints{
|
|
FileId: fmt.Sprintf("%d", fileId),
|
|
CollectName: collectName,
|
|
PointsId: pointId,
|
|
}
|
|
aiFilePoint.AddAiFilePoint()
|
|
}
|
|
c.JSON(200, gin.H{
|
|
"code": 200,
|
|
})
|
|
}
|