430 lines
11 KiB
Go
430 lines
11 KiB
Go
|
package main
|
|||
|
|
|||
|
import (
|
|||
|
"encoding/json"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"github.com/gin-gonic/gin"
|
|||
|
"github.com/joho/godotenv"
|
|||
|
uuid "github.com/satori/go.uuid"
|
|||
|
"github.com/tidwall/gjson"
|
|||
|
"io/ioutil"
|
|||
|
"knowledge/controller"
|
|||
|
"knowledge/models"
|
|||
|
"knowledge/utils"
|
|||
|
"log"
|
|||
|
"net/http"
|
|||
|
"os"
|
|||
|
"path"
|
|||
|
"strconv"
|
|||
|
"strings"
|
|||
|
)
|
|||
|
|
|||
|
func main() {
|
|||
|
err := godotenv.Load(".env")
|
|||
|
if err != nil {
|
|||
|
log.Fatalf("Error loading .env file: %v", err)
|
|||
|
}
|
|||
|
// 读取环境变量
|
|||
|
utils.QdrantBase = os.Getenv("QDRANT_BASE")
|
|||
|
utils.QdrantPort = os.Getenv("QDRANT_PORT")
|
|||
|
mysqlServer := os.Getenv("MYSQL_SERVER")
|
|||
|
mysqlPort := os.Getenv("MYSQL_PORT")
|
|||
|
mysqlDbname := os.Getenv("MYSQL_DATABASE")
|
|||
|
mysqlUsername := os.Getenv("MYSQL_USERNAME")
|
|||
|
mysqlPassword := os.Getenv("MYSQL_PASSWORD")
|
|||
|
chatModel := os.Getenv("CHAT_MODEL")
|
|||
|
|
|||
|
//初始化mysql
|
|||
|
models.Connect(mysqlServer, mysqlPort, mysqlDbname, mysqlUsername, mysqlPassword)
|
|||
|
|
|||
|
// 创建 Gin 引擎
|
|||
|
router := gin.Default()
|
|||
|
//启用跨域中间件
|
|||
|
router.Use(func(c *gin.Context) {
|
|||
|
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
|||
|
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,DELETE,PUT")
|
|||
|
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Origin")
|
|||
|
if c.Request.Method == "OPTIONS" {
|
|||
|
c.AbortWithStatus(204)
|
|||
|
return
|
|||
|
}
|
|||
|
c.Next()
|
|||
|
})
|
|||
|
router.Static("/pannel", "pannel/")
|
|||
|
router.Static("/static", "pannel/static/")
|
|||
|
router.LoadHTMLGlob("templates/*")
|
|||
|
router.GET("/", func(c *gin.Context) {
|
|||
|
c.HTML(http.StatusOK, "welcome.html", gin.H{})
|
|||
|
})
|
|||
|
//首页
|
|||
|
router.GET("/:collectName/index", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
c.HTML(http.StatusOK, "index.html", gin.H{
|
|||
|
"collectName": collectName,
|
|||
|
})
|
|||
|
})
|
|||
|
//后台界面
|
|||
|
router.GET("/:collectName/qdrant", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
c.HTML(http.StatusOK, "qdrant.html", gin.H{
|
|||
|
"collectName": collectName,
|
|||
|
})
|
|||
|
})
|
|||
|
//聊天页
|
|||
|
router.GET("/:collectName/chat", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
c.HTML(http.StatusOK, "chat.html", gin.H{
|
|||
|
"collectName": collectName,
|
|||
|
})
|
|||
|
})
|
|||
|
//创建集合
|
|||
|
router.POST("/collect/:collectName", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
//判断集合是否存在
|
|||
|
collectInfo, _ := utils.GetCollection(collectName)
|
|||
|
collectInfoStatus := gjson.Get(string(collectInfo), "status").String()
|
|||
|
if collectInfoStatus != "ok" {
|
|||
|
res, err := utils.PutCollection(collectName)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write([]byte(res))
|
|||
|
} else {
|
|||
|
c.Writer.Write(collectInfo)
|
|||
|
}
|
|||
|
|
|||
|
})
|
|||
|
//删除集合
|
|||
|
router.GET("/delCollect", func(c *gin.Context) {
|
|||
|
collectName := c.Query("collectName")
|
|||
|
info, _ := utils.DeleteCollection(collectName)
|
|||
|
c.Writer.Write([]byte(info))
|
|||
|
})
|
|||
|
//查询集合
|
|||
|
router.GET("/:collectName/info", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
list, err := utils.GetCollection(collectName)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write(list)
|
|||
|
})
|
|||
|
//集合列表
|
|||
|
router.GET("/collects", func(c *gin.Context) {
|
|||
|
list, err := utils.GetCollections()
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write(list)
|
|||
|
})
|
|||
|
//向量列表
|
|||
|
router.GET("/:collectName/points", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
list, err := utils.GetPoints(collectName, 1000, nil)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write(list)
|
|||
|
})
|
|||
|
//向量列表
|
|||
|
router.GET("/:collectName/filePoints", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
id := c.Query("id")
|
|||
|
points := make([]string, 0)
|
|||
|
filePoints := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, id)
|
|||
|
for _, item := range filePoints {
|
|||
|
points = append(points, item.PointsId)
|
|||
|
}
|
|||
|
|
|||
|
list, err := utils.GetPointsByIds(collectName, points)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write(list)
|
|||
|
})
|
|||
|
//删除向量
|
|||
|
router.GET("/:collectName/delPoints", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
id := c.Query("id")
|
|||
|
i, err := strconv.Atoi(id)
|
|||
|
var points interface{}
|
|||
|
if err != nil {
|
|||
|
points = []string{id}
|
|||
|
} else {
|
|||
|
points = []int{i}
|
|||
|
}
|
|||
|
list, err := utils.DeletePoints(collectName, points)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write([]byte(list))
|
|||
|
})
|
|||
|
//删除文件
|
|||
|
router.GET("/:collectName/delFile", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
id := c.Query("id")
|
|||
|
points := make([]string, 0)
|
|||
|
filePoints := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, id)
|
|||
|
for _, item := range filePoints {
|
|||
|
points = append(points, item.PointsId)
|
|||
|
}
|
|||
|
list, err := utils.DeletePoints(collectName, points)
|
|||
|
models.DelFile("collect_name = ? and id = ?", collectName, id)
|
|||
|
models.DelFilePoints("collect_name = ? and file_id = ?", collectName, id)
|
|||
|
if err != nil {
|
|||
|
c.Writer.Write([]byte(err.Error()))
|
|||
|
return
|
|||
|
}
|
|||
|
c.Writer.Write([]byte(list))
|
|||
|
})
|
|||
|
//搜索流式响应
|
|||
|
if chatModel == "CHATGLM-6B" {
|
|||
|
router.Any("/:collectName/searchStream", controller.PostChatGLM)
|
|||
|
} else {
|
|||
|
router.Any("/:collectName/searchStream", controller.PostChatGPT)
|
|||
|
}
|
|||
|
//训练
|
|||
|
router.POST("/:collectName/training", func(c *gin.Context) {
|
|||
|
openaiKey := os.Getenv("OPENAI_KEY")
|
|||
|
openaiUrl := os.Getenv("OPENAI_API_BASE")
|
|||
|
//gpt url
|
|||
|
gptUrl := c.PostForm("gptUrl")
|
|||
|
if gptUrl != "" {
|
|||
|
openaiUrl = gptUrl
|
|||
|
}
|
|||
|
//gpt secret
|
|||
|
gptSecret := c.PostForm("gptSecret")
|
|||
|
if gptSecret != "" {
|
|||
|
openaiKey = gptSecret
|
|||
|
}
|
|||
|
collectName := c.Param("collectName")
|
|||
|
//判断集合是否存在
|
|||
|
collectInfo, err := utils.GetCollection(collectName)
|
|||
|
collectInfoStatus := gjson.Get(string(collectInfo), "status").String()
|
|||
|
if collectInfoStatus != "ok" {
|
|||
|
//utils.PutCollection(collectName)
|
|||
|
c.Writer.Write([]byte("集合不存在"))
|
|||
|
return
|
|||
|
}
|
|||
|
//判断ID是否传递,以及是否为数值型或uuid
|
|||
|
id := c.PostForm("id")
|
|||
|
var pointId interface{}
|
|||
|
if id == "" {
|
|||
|
pointId = uuid.NewV4().String()
|
|||
|
} else {
|
|||
|
i, err := strconv.Atoi(id)
|
|||
|
if err != nil {
|
|||
|
pointId = id
|
|||
|
} else {
|
|||
|
pointId = i
|
|||
|
}
|
|||
|
}
|
|||
|
//向量化数据
|
|||
|
content := c.PostForm("content")
|
|||
|
res, err := Train(openaiUrl, openaiKey, pointId, collectName, content, "")
|
|||
|
log.Println(err)
|
|||
|
c.Writer.Write([]byte(res))
|
|||
|
})
|
|||
|
//上传链接
|
|||
|
router.POST("/:collectName/uploadUrl", func(c *gin.Context) {
|
|||
|
openaiKey := os.Getenv("OPENAI_KEY")
|
|||
|
openaiUrl := os.Getenv("OPENAI_API_BASE")
|
|||
|
//gpt url
|
|||
|
gptUrl := c.PostForm("gptUrl")
|
|||
|
if gptUrl != "" {
|
|||
|
openaiUrl = gptUrl
|
|||
|
}
|
|||
|
//gpt secret
|
|||
|
gptSecret := c.PostForm("gptSecret")
|
|||
|
if gptSecret != "" {
|
|||
|
openaiKey = gptSecret
|
|||
|
}
|
|||
|
collectName := c.Param("collectName")
|
|||
|
url := c.PostForm("url")
|
|||
|
if url == "" || collectName == "" {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": "参数错误!",
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
resp := utils.HttpGet(url)
|
|||
|
if resp == "" {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": "请求数据错误!",
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
htmlContent := utils.TrimHtml(resp)
|
|||
|
//入库
|
|||
|
files := &models.AiFile{
|
|||
|
FileName: url,
|
|||
|
CollectName: collectName,
|
|||
|
}
|
|||
|
fileId := files.AddAiFile()
|
|||
|
chunks := SplitTextByLength(htmlContent, 200)
|
|||
|
for _, chunk := range chunks {
|
|||
|
pointId := uuid.NewV4().String()
|
|||
|
Train(openaiUrl, openaiKey, pointId, collectName, chunk, url)
|
|||
|
//入库
|
|||
|
aiFilePoint := &models.AiFilePoints{
|
|||
|
FileId: fmt.Sprintf("%d", fileId),
|
|||
|
CollectName: collectName,
|
|||
|
PointsId: pointId,
|
|||
|
}
|
|||
|
aiFilePoint.AddAiFilePoint()
|
|||
|
}
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 200,
|
|||
|
})
|
|||
|
})
|
|||
|
//上传doc
|
|||
|
router.POST("/:collectName/uploadDoc", func(c *gin.Context) {
|
|||
|
openaiKey := os.Getenv("OPENAI_KEY")
|
|||
|
openaiUrl := os.Getenv("OPENAI_API_BASE")
|
|||
|
collectName := c.Param("collectName")
|
|||
|
f, err := c.FormFile("file")
|
|||
|
if err != nil {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": "上传失败!" + err.Error(),
|
|||
|
})
|
|||
|
return
|
|||
|
} else {
|
|||
|
|
|||
|
fileExt := strings.ToLower(path.Ext(f.Filename))
|
|||
|
if fileExt != ".docx" && fileExt != ".txt" && fileExt != ".xlsx" && fileExt != ".pdf" {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": "上传失败!只允许txt、pdf、docx或xlsx文件",
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
fileName := collectName + f.Filename
|
|||
|
c.SaveUploadedFile(f, fileName)
|
|||
|
text := ""
|
|||
|
chunks := make([]string, 0)
|
|||
|
if fileExt == ".txt" {
|
|||
|
// 打开txt文件
|
|||
|
file, _ := os.Open(fileName)
|
|||
|
// 一次性读取整个txt文件的内容
|
|||
|
txt, _ := ioutil.ReadAll(file)
|
|||
|
text = string(txt)
|
|||
|
file.Close()
|
|||
|
} else if fileExt == ".docx" {
|
|||
|
text, err = utils.ReadDocxAll(fileName)
|
|||
|
} else if fileExt == ".pdf" {
|
|||
|
text, err = utils.ReadPdfAll(fileName)
|
|||
|
} else {
|
|||
|
chunks, err = utils.ReadExcelAll(fileName)
|
|||
|
}
|
|||
|
if err != nil {
|
|||
|
log.Println(err)
|
|||
|
}
|
|||
|
removeErr := os.Remove(fileName)
|
|||
|
if removeErr != nil {
|
|||
|
log.Println("Remove error:", fileName, removeErr)
|
|||
|
}
|
|||
|
if text == "" && len(chunks) == 0 {
|
|||
|
err = errors.New("上传失败!读取数据为空")
|
|||
|
}
|
|||
|
if err != nil {
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 400,
|
|||
|
"msg": err.Error(),
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
//入库
|
|||
|
files := &models.AiFile{
|
|||
|
FileName: f.Filename,
|
|||
|
CollectName: collectName,
|
|||
|
}
|
|||
|
fileId := files.AddAiFile()
|
|||
|
|
|||
|
if fileExt != ".xlsx" {
|
|||
|
chunks = SplitTextByLength(text, 200)
|
|||
|
}
|
|||
|
for _, chunk := range chunks {
|
|||
|
pointId := uuid.NewV4().String()
|
|||
|
log.Println("上传数据:" + chunk)
|
|||
|
_, err := Train(openaiUrl, openaiKey, pointId, collectName, chunk, f.Filename)
|
|||
|
if err == nil {
|
|||
|
aiFilePoint := &models.AiFilePoints{
|
|||
|
FileId: fmt.Sprintf("%d", fileId),
|
|||
|
CollectName: collectName,
|
|||
|
PointsId: pointId,
|
|||
|
}
|
|||
|
aiFilePoint.AddAiFilePoint()
|
|||
|
} else {
|
|||
|
log.Println(err)
|
|||
|
}
|
|||
|
}
|
|||
|
os.Remove(fileName)
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 200,
|
|||
|
})
|
|||
|
}
|
|||
|
})
|
|||
|
//文件列表
|
|||
|
router.GET("/:collectName/fileList", func(c *gin.Context) {
|
|||
|
collectName := c.Param("collectName")
|
|||
|
list := models.FindFileList(1, 1000, "collect_name = ? ", collectName)
|
|||
|
c.JSON(200, gin.H{
|
|||
|
"code": 200,
|
|||
|
"result": list,
|
|||
|
})
|
|||
|
})
|
|||
|
// 启动服务器
|
|||
|
if err := router.Run(":8083"); err != nil {
|
|||
|
panic(err)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
//训练功能
|
|||
|
func Train(openaiUrl string, openaiKey string, pointId interface{}, collectName, content, url string) (string, error) {
|
|||
|
gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
|
|||
|
str := strings.TrimSpace(content)
|
|||
|
response, err := gpt.GetEmbedding(str, "text-embedding-ada-002")
|
|||
|
if err != nil {
|
|||
|
return "", err
|
|||
|
}
|
|||
|
var embeddingResponse utils.EmbeddingResponse
|
|||
|
json.Unmarshal([]byte(response), &embeddingResponse)
|
|||
|
if len(embeddingResponse.Data) == 0 {
|
|||
|
return "", errors.New(response)
|
|||
|
}
|
|||
|
|
|||
|
points := []map[string]interface{}{
|
|||
|
{
|
|||
|
"id": pointId,
|
|||
|
"payload": map[string]interface{}{"text": str, "url": url},
|
|||
|
"vector": embeddingResponse.Data[0].Embedding,
|
|||
|
},
|
|||
|
}
|
|||
|
res, err := utils.PutPoints(collectName, points)
|
|||
|
return res, err
|
|||
|
}
|
|||
|
|
|||
|
// SplitTextByLength splits text into consecutive blocks of at most length
// runes each, so multi-byte characters are never cut in half. The last block
// holds the remainder and may be shorter.
//
// An empty text yields a nil slice. A non-positive length cannot make
// progress, so the whole text is returned as a single block instead of
// looping forever (length == 0) or panicking (length < 0).
func SplitTextByLength(text string, length int) []string {
	runes := []rune(text)
	if len(runes) == 0 {
		return nil
	}
	if length <= 0 {
		return []string{string(runes)}
	}

	// len/length + 1 is an upper bound on the block count that cannot overflow.
	blocks := make([]string, 0, len(runes)/length+1)
	for start := 0; start < len(runes); start += length {
		end := start + length
		if end > len(runes) {
			end = len(runes)
		}
		blocks = append(blocks, string(runes[start:end]))
	}
	return blocks
}
|