430 lines
11 KiB
Go
430 lines
11 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/joho/godotenv"
|
||
uuid "github.com/satori/go.uuid"
|
||
"github.com/tidwall/gjson"
|
||
"io/ioutil"
|
||
"knowledge/controller"
|
||
"knowledge/models"
|
||
"knowledge/utils"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"path"
|
||
"strconv"
|
||
"strings"
|
||
)
|
||
|
||
func main() {
|
||
err := godotenv.Load(".env")
|
||
if err != nil {
|
||
log.Fatalf("Error loading .env file: %v", err)
|
||
}
|
||
// 读取环境变量
|
||
utils.QdrantBase = os.Getenv("QDRANT_BASE")
|
||
utils.QdrantPort = os.Getenv("QDRANT_PORT")
|
||
mysqlServer := os.Getenv("MYSQL_SERVER")
|
||
mysqlPort := os.Getenv("MYSQL_PORT")
|
||
mysqlDbname := os.Getenv("MYSQL_DATABASE")
|
||
mysqlUsername := os.Getenv("MYSQL_USERNAME")
|
||
mysqlPassword := os.Getenv("MYSQL_PASSWORD")
|
||
chatModel := os.Getenv("CHAT_MODEL")
|
||
|
||
//初始化mysql
|
||
models.Connect(mysqlServer, mysqlPort, mysqlDbname, mysqlUsername, mysqlPassword)
|
||
|
||
// 创建 Gin 引擎
|
||
router := gin.Default()
|
||
//启用跨域中间件
|
||
router.Use(func(c *gin.Context) {
|
||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,DELETE,PUT")
|
||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Origin")
|
||
if c.Request.Method == "OPTIONS" {
|
||
c.AbortWithStatus(204)
|
||
return
|
||
}
|
||
c.Next()
|
||
})
|
||
router.Static("/pannel", "pannel/")
|
||
router.Static("/static", "pannel/static/")
|
||
router.LoadHTMLGlob("templates/*")
|
||
router.GET("/", func(c *gin.Context) {
|
||
c.HTML(http.StatusOK, "welcome.html", gin.H{})
|
||
})
|
||
//首页
|
||
router.GET("/:collectName/index", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
c.HTML(http.StatusOK, "index.html", gin.H{
|
||
"collectName": collectName,
|
||
})
|
||
})
|
||
//后台界面
|
||
router.GET("/:collectName/qdrant", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
c.HTML(http.StatusOK, "qdrant.html", gin.H{
|
||
"collectName": collectName,
|
||
})
|
||
})
|
||
//聊天页
|
||
router.GET("/:collectName/chat", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
c.HTML(http.StatusOK, "chat.html", gin.H{
|
||
"collectName": collectName,
|
||
})
|
||
})
|
||
//创建集合
|
||
router.POST("/collect/:collectName", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
//判断集合是否存在
|
||
collectInfo, _ := utils.GetCollection(collectName)
|
||
collectInfoStatus := gjson.Get(string(collectInfo), "status").String()
|
||
if collectInfoStatus != "ok" {
|
||
res, err := utils.PutCollection(collectName)
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write([]byte(res))
|
||
} else {
|
||
c.Writer.Write(collectInfo)
|
||
}
|
||
|
||
})
|
||
//删除集合
|
||
router.GET("/delCollect", func(c *gin.Context) {
|
||
collectName := c.Query("collectName")
|
||
info, _ := utils.DeleteCollection(collectName)
|
||
c.Writer.Write([]byte(info))
|
||
})
|
||
//查询集合
|
||
router.GET("/:collectName/info", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
list, err := utils.GetCollection(collectName)
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write(list)
|
||
})
|
||
//集合列表
|
||
router.GET("/collects", func(c *gin.Context) {
|
||
list, err := utils.GetCollections()
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write(list)
|
||
})
|
||
//向量列表
|
||
router.GET("/:collectName/points", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
list, err := utils.GetPoints(collectName, 1000, nil)
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write(list)
|
||
})
|
||
//向量列表
|
||
router.GET("/:collectName/filePoints", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
id := c.Query("id")
|
||
points := make([]string, 0)
|
||
filePoints := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, id)
|
||
for _, item := range filePoints {
|
||
points = append(points, item.PointsId)
|
||
}
|
||
|
||
list, err := utils.GetPointsByIds(collectName, points)
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write(list)
|
||
})
|
||
//删除向量
|
||
router.GET("/:collectName/delPoints", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
id := c.Query("id")
|
||
i, err := strconv.Atoi(id)
|
||
var points interface{}
|
||
if err != nil {
|
||
points = []string{id}
|
||
} else {
|
||
points = []int{i}
|
||
}
|
||
list, err := utils.DeletePoints(collectName, points)
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write([]byte(list))
|
||
})
|
||
//删除文件
|
||
router.GET("/:collectName/delFile", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
id := c.Query("id")
|
||
points := make([]string, 0)
|
||
filePoints := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, id)
|
||
for _, item := range filePoints {
|
||
points = append(points, item.PointsId)
|
||
}
|
||
list, err := utils.DeletePoints(collectName, points)
|
||
models.DelFile("collect_name = ? and id = ?", collectName, id)
|
||
models.DelFilePoints("collect_name = ? and file_id = ?", collectName, id)
|
||
if err != nil {
|
||
c.Writer.Write([]byte(err.Error()))
|
||
return
|
||
}
|
||
c.Writer.Write([]byte(list))
|
||
})
|
||
//搜索流式响应
|
||
if chatModel == "CHATGLM-6B" {
|
||
router.Any("/:collectName/searchStream", controller.PostChatGLM)
|
||
} else {
|
||
router.Any("/:collectName/searchStream", controller.PostChatGPT)
|
||
}
|
||
//训练
|
||
router.POST("/:collectName/training", func(c *gin.Context) {
|
||
openaiKey := os.Getenv("OPENAI_KEY")
|
||
openaiUrl := os.Getenv("OPENAI_API_BASE")
|
||
//gpt url
|
||
gptUrl := c.PostForm("gptUrl")
|
||
if gptUrl != "" {
|
||
openaiUrl = gptUrl
|
||
}
|
||
//gpt secret
|
||
gptSecret := c.PostForm("gptSecret")
|
||
if gptSecret != "" {
|
||
openaiKey = gptSecret
|
||
}
|
||
collectName := c.Param("collectName")
|
||
//判断集合是否存在
|
||
collectInfo, err := utils.GetCollection(collectName)
|
||
collectInfoStatus := gjson.Get(string(collectInfo), "status").String()
|
||
if collectInfoStatus != "ok" {
|
||
//utils.PutCollection(collectName)
|
||
c.Writer.Write([]byte("集合不存在"))
|
||
return
|
||
}
|
||
//判断ID是否传递,以及是否为数值型或uuid
|
||
id := c.PostForm("id")
|
||
var pointId interface{}
|
||
if id == "" {
|
||
pointId = uuid.NewV4().String()
|
||
} else {
|
||
i, err := strconv.Atoi(id)
|
||
if err != nil {
|
||
pointId = id
|
||
} else {
|
||
pointId = i
|
||
}
|
||
}
|
||
//向量化数据
|
||
content := c.PostForm("content")
|
||
res, err := Train(openaiUrl, openaiKey, pointId, collectName, content, "")
|
||
log.Println(err)
|
||
c.Writer.Write([]byte(res))
|
||
})
|
||
//上传链接
|
||
router.POST("/:collectName/uploadUrl", func(c *gin.Context) {
|
||
openaiKey := os.Getenv("OPENAI_KEY")
|
||
openaiUrl := os.Getenv("OPENAI_API_BASE")
|
||
//gpt url
|
||
gptUrl := c.PostForm("gptUrl")
|
||
if gptUrl != "" {
|
||
openaiUrl = gptUrl
|
||
}
|
||
//gpt secret
|
||
gptSecret := c.PostForm("gptSecret")
|
||
if gptSecret != "" {
|
||
openaiKey = gptSecret
|
||
}
|
||
collectName := c.Param("collectName")
|
||
url := c.PostForm("url")
|
||
if url == "" || collectName == "" {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": "参数错误!",
|
||
})
|
||
return
|
||
}
|
||
resp := utils.HttpGet(url)
|
||
if resp == "" {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": "请求数据错误!",
|
||
})
|
||
return
|
||
}
|
||
htmlContent := utils.TrimHtml(resp)
|
||
//入库
|
||
files := &models.AiFile{
|
||
FileName: url,
|
||
CollectName: collectName,
|
||
}
|
||
fileId := files.AddAiFile()
|
||
chunks := SplitTextByLength(htmlContent, 200)
|
||
for _, chunk := range chunks {
|
||
pointId := uuid.NewV4().String()
|
||
Train(openaiUrl, openaiKey, pointId, collectName, chunk, url)
|
||
//入库
|
||
aiFilePoint := &models.AiFilePoints{
|
||
FileId: fmt.Sprintf("%d", fileId),
|
||
CollectName: collectName,
|
||
PointsId: pointId,
|
||
}
|
||
aiFilePoint.AddAiFilePoint()
|
||
}
|
||
c.JSON(200, gin.H{
|
||
"code": 200,
|
||
})
|
||
})
|
||
//上传doc
|
||
router.POST("/:collectName/uploadDoc", func(c *gin.Context) {
|
||
openaiKey := os.Getenv("OPENAI_KEY")
|
||
openaiUrl := os.Getenv("OPENAI_API_BASE")
|
||
collectName := c.Param("collectName")
|
||
f, err := c.FormFile("file")
|
||
if err != nil {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": "上传失败!" + err.Error(),
|
||
})
|
||
return
|
||
} else {
|
||
|
||
fileExt := strings.ToLower(path.Ext(f.Filename))
|
||
if fileExt != ".docx" && fileExt != ".txt" && fileExt != ".xlsx" && fileExt != ".pdf" {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": "上传失败!只允许txt、pdf、docx或xlsx文件",
|
||
})
|
||
return
|
||
}
|
||
fileName := collectName + f.Filename
|
||
c.SaveUploadedFile(f, fileName)
|
||
text := ""
|
||
chunks := make([]string, 0)
|
||
if fileExt == ".txt" {
|
||
// 打开txt文件
|
||
file, _ := os.Open(fileName)
|
||
// 一次性读取整个txt文件的内容
|
||
txt, _ := ioutil.ReadAll(file)
|
||
text = string(txt)
|
||
file.Close()
|
||
} else if fileExt == ".docx" {
|
||
text, err = utils.ReadDocxAll(fileName)
|
||
} else if fileExt == ".pdf" {
|
||
text, err = utils.ReadPdfAll(fileName)
|
||
} else {
|
||
chunks, err = utils.ReadExcelAll(fileName)
|
||
}
|
||
if err != nil {
|
||
log.Println(err)
|
||
}
|
||
removeErr := os.Remove(fileName)
|
||
if removeErr != nil {
|
||
log.Println("Remove error:", fileName, removeErr)
|
||
}
|
||
if text == "" && len(chunks) == 0 {
|
||
err = errors.New("上传失败!读取数据为空")
|
||
}
|
||
if err != nil {
|
||
c.JSON(200, gin.H{
|
||
"code": 400,
|
||
"msg": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
//入库
|
||
files := &models.AiFile{
|
||
FileName: f.Filename,
|
||
CollectName: collectName,
|
||
}
|
||
fileId := files.AddAiFile()
|
||
|
||
if fileExt != ".xlsx" {
|
||
chunks = SplitTextByLength(text, 200)
|
||
}
|
||
for _, chunk := range chunks {
|
||
pointId := uuid.NewV4().String()
|
||
log.Println("上传数据:" + chunk)
|
||
_, err := Train(openaiUrl, openaiKey, pointId, collectName, chunk, f.Filename)
|
||
if err == nil {
|
||
aiFilePoint := &models.AiFilePoints{
|
||
FileId: fmt.Sprintf("%d", fileId),
|
||
CollectName: collectName,
|
||
PointsId: pointId,
|
||
}
|
||
aiFilePoint.AddAiFilePoint()
|
||
} else {
|
||
log.Println(err)
|
||
}
|
||
}
|
||
os.Remove(fileName)
|
||
c.JSON(200, gin.H{
|
||
"code": 200,
|
||
})
|
||
}
|
||
})
|
||
//文件列表
|
||
router.GET("/:collectName/fileList", func(c *gin.Context) {
|
||
collectName := c.Param("collectName")
|
||
list := models.FindFileList(1, 1000, "collect_name = ? ", collectName)
|
||
c.JSON(200, gin.H{
|
||
"code": 200,
|
||
"result": list,
|
||
})
|
||
})
|
||
// 启动服务器
|
||
if err := router.Run(":8083"); err != nil {
|
||
panic(err)
|
||
}
|
||
}
|
||
|
||
//训练功能
|
||
func Train(openaiUrl string, openaiKey string, pointId interface{}, collectName, content, url string) (string, error) {
|
||
gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
|
||
str := strings.TrimSpace(content)
|
||
response, err := gpt.GetEmbedding(str, "text-embedding-ada-002")
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
var embeddingResponse utils.EmbeddingResponse
|
||
json.Unmarshal([]byte(response), &embeddingResponse)
|
||
if len(embeddingResponse.Data) == 0 {
|
||
return "", errors.New(response)
|
||
}
|
||
|
||
points := []map[string]interface{}{
|
||
{
|
||
"id": pointId,
|
||
"payload": map[string]interface{}{"text": str, "url": url},
|
||
"vector": embeddingResponse.Data[0].Embedding,
|
||
},
|
||
}
|
||
res, err := utils.PutPoints(collectName, points)
|
||
return res, err
|
||
}
|
||
|
||
// SplitTextByLength splits text into consecutive blocks of at most length
// runes (not bytes), so multi-byte characters are never cut in half.
//
// The last block holds the remainder and may be shorter. An empty text
// yields nil. A non-positive length cannot make progress, so the whole text
// is returned as a single block instead of looping forever or panicking.
func SplitTextByLength(text string, length int) []string {
	runes := []rune(text)
	if len(runes) == 0 {
		return nil
	}
	if length <= 0 {
		return []string{string(runes)}
	}

	// Number of blocks, rounded up, so the slice is allocated once.
	count := len(runes) / length
	if len(runes)%length != 0 {
		count++
	}
	blocks := make([]string, 0, count)
	for start := 0; start < len(runes); start += length {
		end := len(runes)
		if length < end-start {
			end = start + length
		}
		blocks = append(blocks, string(runes[start:end]))
	}
	return blocks
}
|