// Command main starts the knowledge-base HTTP service. It serves the HTML
// pages, exposes routes that manage Qdrant collections and points, and embeds
// posted text, fetched URLs and uploaded documents into a collection.
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/joho/godotenv"
	uuid "github.com/satori/go.uuid"
	"github.com/tidwall/gjson"

	"knowledge/controller"
	"knowledge/models"
	"knowledge/utils"
)

// chunkRunes is the maximum number of runes per chunk when a long text is
// split before embedding.
const chunkRunes = 200

func main() {
	if err := godotenv.Load(".env"); err != nil {
		log.Fatalf("Error loading .env file: %v", err)
	}

	// Read configuration from the environment.
	utils.QdrantBase = os.Getenv("QDRANT_BASE")
	utils.QdrantPort = os.Getenv("QDRANT_PORT")
	mysqlServer := os.Getenv("MYSQL_SERVER")
	mysqlPort := os.Getenv("MYSQL_PORT")
	mysqlDbname := os.Getenv("MYSQL_DATABASE")
	mysqlUsername := os.Getenv("MYSQL_USERNAME")
	mysqlPassword := os.Getenv("MYSQL_PASSWORD")
	chatModel := os.Getenv("CHAT_MODEL")

	// Initialise MySQL.
	models.Connect(mysqlServer, mysqlPort, mysqlDbname, mysqlUsername, mysqlPassword)

	// Create the Gin engine.
	router := gin.Default()

	// CORS middleware: allow any origin and answer preflight requests directly.
	router.Use(func(c *gin.Context) {
		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,DELETE,PUT")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Origin")
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	})

	router.Static("/pannel", "pannel/")
	router.Static("/static", "pannel/static/")
	router.LoadHTMLGlob("templates/*")

	router.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "welcome.html", gin.H{})
	})
	// Home page.
	router.GET("/:collectName/index", collectPage("index.html"))
	// Admin page.
	router.GET("/:collectName/qdrant", collectPage("qdrant.html"))
	// Chat page.
	router.GET("/:collectName/chat", collectPage("chat.html"))

	// Create a collection unless it already exists.
	router.POST("/collect/:collectName", func(c *gin.Context) {
		collectName := c.Param("collectName")
		// The lookup error is deliberately not checked: any response whose
		// status is not "ok" is treated as "collection missing" and triggers
		// creation.
		collectInfo, _ := utils.GetCollection(collectName)
		if gjson.Get(string(collectInfo), "status").String() == "ok" {
			c.Writer.Write(collectInfo)
			return
		}
		res, err := utils.PutCollection(collectName)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(res))
	})

	// Delete a collection.
	router.GET("/delCollect", func(c *gin.Context) {
		collectName := c.Query("collectName")
		info, err := utils.DeleteCollection(collectName)
		if err != nil {
			// Report the failure like the other handlers instead of dropping it.
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(info))
	})

	// Collection info.
	router.GET("/:collectName/info", func(c *gin.Context) {
		list, err := utils.GetCollection(c.Param("collectName"))
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// Collection list.
	router.GET("/collects", func(c *gin.Context) {
		list, err := utils.GetCollections()
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// Point list of a collection (up to 1000 points).
	router.GET("/:collectName/points", func(c *gin.Context) {
		list, err := utils.GetPoints(c.Param("collectName"), 1000, nil)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// Point list of a single file.
	router.GET("/:collectName/filePoints", func(c *gin.Context) {
		collectName := c.Param("collectName")
		points := filePointIDs(collectName, c.Query("id"))
		list, err := utils.GetPointsByIds(collectName, points)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write(list)
	})

	// Delete a point. The id is sent as an integer when it parses as one and
	// as a string otherwise.
	router.GET("/:collectName/delPoints", func(c *gin.Context) {
		collectName := c.Param("collectName")
		id := c.Query("id")
		var points interface{}
		if n, convErr := strconv.Atoi(id); convErr != nil {
			points = []string{id}
		} else {
			points = []int{n}
		}
		list, err := utils.DeletePoints(collectName, points)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(list))
	})

	// Delete a file: its points in the vector store and its database rows.
	router.GET("/:collectName/delFile", func(c *gin.Context) {
		collectName := c.Param("collectName")
		id := c.Query("id")
		points := filePointIDs(collectName, id)
		list, err := utils.DeletePoints(collectName, points)
		// The database rows are removed even when deleting the points failed,
		// so a file can always be removed from the list; the failure is still
		// reported to the client below.
		models.DelFile("collect_name = ? and id = ?", collectName, id)
		models.DelFilePoints("collect_name = ? and file_id = ?", collectName, id)
		if err != nil {
			c.Writer.Write([]byte(err.Error()))
			return
		}
		c.Writer.Write([]byte(list))
	})

	// Streaming search; the backend is selected by CHAT_MODEL.
	searchHandler := controller.PostChatGPT
	if chatModel == "CHATGLM-6B" {
		searchHandler = controller.PostChatGLM
	}
	router.Any("/:collectName/searchStream", searchHandler)

	// Train on a single piece of text.
	router.POST("/:collectName/training", handleTraining)
	// Fetch a URL and train on its text.
	router.POST("/:collectName/uploadUrl", handleUploadURL)
	// Upload a document and train on its text.
	router.POST("/:collectName/uploadDoc", handleUploadDoc)

	// File list.
	router.GET("/:collectName/fileList", func(c *gin.Context) {
		list := models.FindFileList(1, 1000, "collect_name = ? ", c.Param("collectName"))
		c.JSON(200, gin.H{
			"code":   200,
			"result": list,
		})
	})

	// Start the server.
	if err := router.Run(":8083"); err != nil {
		panic(err)
	}
}

// collectPage returns a handler that renders tmpl with the collectName path
// parameter as its only template variable.
func collectPage(tmpl string) gin.HandlerFunc {
	return func(c *gin.Context) {
		c.HTML(http.StatusOK, tmpl, gin.H{
			"collectName": c.Param("collectName"),
		})
	}
}

// openAICredentials returns the OpenAI base URL and key: the OPENAI_API_BASE
// and OPENAI_KEY environment values, each overridden by the gptUrl / gptSecret
// form field when that field is non-empty.
func openAICredentials(c *gin.Context) (baseURL, key string) {
	baseURL = os.Getenv("OPENAI_API_BASE")
	key = os.Getenv("OPENAI_KEY")
	if v := c.PostForm("gptUrl"); v != "" {
		baseURL = v
	}
	if v := c.PostForm("gptSecret"); v != "" {
		key = v
	}
	return baseURL, key
}

// filePointIDs returns the point IDs recorded in the database for the given
// file of a collection.
func filePointIDs(collectName, fileID string) []string {
	filePoints := models.FindAiFilePoint(1, 100000, "collect_name = ? and file_id = ?", collectName, fileID)
	// Kept non-nil even when empty, as before; presumably the callee
	// serialises it, where nil and empty slices may differ.
	ids := make([]string, 0, len(filePoints))
	for _, item := range filePoints {
		ids = append(ids, item.PointsId)
	}
	return ids
}

// addFilePoint records that pointID belongs to the given file of a collection.
func addFilePoint(fileID, collectName, pointID string) {
	aiFilePoint := &models.AiFilePoints{
		FileId:      fileID,
		CollectName: collectName,
		PointsId:    pointID,
	}
	aiFilePoint.AddAiFilePoint()
}

// handleTraining embeds the posted "content" and stores it in the collection
// under the posted "id" (an integer when it parses as one, the raw string
// otherwise, or a fresh UUID when absent).
func handleTraining(c *gin.Context) {
	openaiUrl, openaiKey := openAICredentials(c)
	collectName := c.Param("collectName")

	// The lookup error is deliberately not checked: any response whose status
	// is not "ok" is reported as a missing collection.
	collectInfo, _ := utils.GetCollection(collectName)
	if gjson.Get(string(collectInfo), "status").String() != "ok" {
		c.Writer.Write([]byte("集合不存在"))
		return
	}

	var pointId interface{}
	if id := c.PostForm("id"); id == "" {
		pointId = uuid.NewV4().String()
	} else if n, convErr := strconv.Atoi(id); convErr == nil {
		pointId = n
	} else {
		pointId = id
	}

	// Embed and store the content.
	res, err := Train(openaiUrl, openaiKey, pointId, collectName, c.PostForm("content"), "")
	if err != nil {
		// Log only real failures and tell the client what went wrong, as the
		// other handlers do, instead of answering with an empty body.
		log.Println(err)
		c.Writer.Write([]byte(err.Error()))
		return
	}
	c.Writer.Write([]byte(res))
}

// handleUploadURL fetches the posted "url", strips its HTML, splits the text
// into chunks and trains on every chunk, recording the URL as a file.
//
// NOTE(review): the URL comes straight from the client and is fetched
// server-side; confirm that utils.HttpGet restricts reachable hosts.
func handleUploadURL(c *gin.Context) {
	openaiUrl, openaiKey := openAICredentials(c)
	collectName := c.Param("collectName")
	pageURL := c.PostForm("url")
	if pageURL == "" || collectName == "" {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "参数错误!",
		})
		return
	}

	resp := utils.HttpGet(pageURL)
	if resp == "" {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "请求数据错误!",
		})
		return
	}
	htmlContent := utils.TrimHtml(resp)

	// Record the URL as a file.
	files := &models.AiFile{
		FileName:    pageURL,
		CollectName: collectName,
	}
	fileID := fmt.Sprintf("%d", files.AddAiFile())

	for _, chunk := range SplitTextByLength(htmlContent, chunkRunes) {
		pointId := uuid.NewV4().String()
		if _, err := Train(openaiUrl, openaiKey, pointId, collectName, chunk, pageURL); err != nil {
			// Do not record a point that was never stored (same policy as
			// the document upload handler).
			log.Println(err)
			continue
		}
		addFilePoint(fileID, collectName, pointId)
	}
	c.JSON(200, gin.H{
		"code": 200,
	})
}

// handleUploadDoc accepts an uploaded txt, pdf, docx or xlsx file, extracts its
// text, splits it into chunks (xlsx is already returned in chunks by the
// reader) and trains on every chunk, recording the upload as a file.
func handleUploadDoc(c *gin.Context) {
	openaiKey := os.Getenv("OPENAI_KEY")
	openaiUrl := os.Getenv("OPENAI_API_BASE")
	collectName := c.Param("collectName")

	f, err := c.FormFile("file")
	if err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}

	fileExt := strings.ToLower(path.Ext(f.Filename))
	switch fileExt {
	case ".docx", ".txt", ".xlsx", ".pdf":
	default:
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!只允许txt、pdf、docx或xlsx文件",
		})
		return
	}

	// The filename is client-controlled: keep only its base name so the
	// temporary copy always lands in the working directory.
	fileName := filepath.Base(collectName + f.Filename)
	if err := c.SaveUploadedFile(f, fileName); err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  "上传失败!" + err.Error(),
		})
		return
	}

	var (
		text   string
		chunks []string
	)
	switch fileExt {
	case ".txt":
		var raw []byte
		raw, err = ioutil.ReadFile(fileName)
		text = string(raw)
	case ".docx":
		text, err = utils.ReadDocxAll(fileName)
	case ".pdf":
		text, err = utils.ReadPdfAll(fileName)
	default: // .xlsx
		chunks, err = utils.ReadExcelAll(fileName)
	}
	if err != nil {
		log.Println(err)
	}

	// The temporary copy is no longer needed once it has been read.
	if removeErr := os.Remove(fileName); removeErr != nil {
		log.Println("Remove error:", fileName, removeErr)
	}

	if text == "" && len(chunks) == 0 {
		err = errors.New("上传失败!读取数据为空")
	}
	if err != nil {
		c.JSON(200, gin.H{
			"code": 400,
			"msg":  err.Error(),
		})
		return
	}

	// Record the upload as a file.
	files := &models.AiFile{
		FileName:    f.Filename,
		CollectName: collectName,
	}
	fileID := fmt.Sprintf("%d", files.AddAiFile())

	if fileExt != ".xlsx" {
		chunks = SplitTextByLength(text, chunkRunes)
	}
	for _, chunk := range chunks {
		pointId := uuid.NewV4().String()
		log.Println("上传数据:" + chunk)
		if _, err := Train(openaiUrl, openaiKey, pointId, collectName, chunk, f.Filename); err != nil {
			log.Println(err)
			continue
		}
		addFilePoint(fileID, collectName, pointId)
	}
	c.JSON(200, gin.H{
		"code": 200,
	})
}

// Train embeds content with the "text-embedding-ada-002" model and stores the
// resulting vector in collectName under pointId, with the trimmed text and url
// as payload.
//
// It returns the response of the point upsert. An error is returned when the
// embedding request fails, when the embedding response contains no data (the
// raw response is then the error text), or when the upsert fails.
func Train(openaiUrl string, openaiKey string, pointId interface{}, collectName, content, url string) (string, error) {
	gpt := utils.NewChatGptTool(openaiUrl, openaiKey)
	str := strings.TrimSpace(content)
	response, err := gpt.GetEmbedding(str, "text-embedding-ada-002")
	if err != nil {
		return "", err
	}

	var embeddingResponse utils.EmbeddingResponse
	// A decode error is fatal only when no embedding was obtained: Unmarshal
	// may report an error for one field after filling in the rest.
	decodeErr := json.Unmarshal([]byte(response), &embeddingResponse)
	if len(embeddingResponse.Data) == 0 {
		if decodeErr != nil {
			return "", fmt.Errorf("decoding embedding response: %w: %s", decodeErr, response)
		}
		return "", errors.New(response)
	}

	points := []map[string]interface{}{
		{
			"id":      pointId,
			"payload": map[string]interface{}{"text": str, "url": url},
			"vector":  embeddingResponse.Data[0].Embedding,
		},
	}
	return utils.PutPoints(collectName, points)
}

// SplitTextByLength splits text into consecutive blocks of at most length
// runes (not bytes), so multi-byte characters are never cut in half.
//
// It returns nil for empty text. A non-positive length cannot be used as a
// step, so the whole text is returned as a single block in that case.
func SplitTextByLength(text string, length int) []string {
	runes := []rune(text)
	if len(runes) == 0 {
		return nil
	}
	if length <= 0 {
		return []string{text}
	}

	blocks := make([]string, 0, (len(runes)+length-1)/length)
	for i := 0; i < len(runes); i += length {
		j := i + length
		if j > len(runes) {
			j = len(runes)
		}
		blocks = append(blocks, string(runes[i:j]))
	}
	return blocks
}