// kefu/controller/ai/chatStream.go

package ai
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/sashabaranov/go-openai"
"io"
"kefu/lib"
"kefu/models"
"kefu/tools"
"kefu/types"
"log"
"net/http"
)
// PostChatStream streams a chat-completion reply for the logged-in agent back
// to the client as raw text chunks, flushing after each chunk.
//
// Form fields:
//   - content:    the new user prompt.
//   - model:      model name, forwarded unchanged to the configured
//     OpenAI-compatible endpoint.
//   - collect_id: optional conversation id; when present, up to 100 stored
//     messages of that conversation (filtered by the caller's kefu_name and
//     ent_id) are sent as history before the new prompt.
//
// The endpoint URL and key come from the enterprise config keys chatGPTUrl and
// chatGPTSecret. Because the response is a text stream rather than JSON,
// failures are reported to the client as plain-text messages.
func PostChatStream(c *gin.Context) {
	c.Header("Content-Type", "text/html;charset=utf-8;")
	content := c.PostForm("content")
	bigModel := c.PostForm("model")
	collectId := c.PostForm("collect_id")
	kefuName, _ := c.Get("kefu_name")
	// ent_id is expected to be set by upstream auth middleware; check the
	// assertion rather than letting a missing/non-string value panic.
	entIdValue, _ := c.Get("ent_id")
	entId, ok := entIdValue.(string)
	if !ok {
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}

	// Look up the model endpoint credentials configured for this enterprise.
	config := models.GetEntConfigsMap(entId, "chatGPTUrl", "chatGPTSecret", "ERNIEAppId", "ERNIEAPIKey", "ERNIESecretKey")
	if config["chatGPTUrl"] == "" || config["chatGPTSecret"] == "" {
		c.Writer.Write([]byte("请先配置大模型URL和KEY"))
		return
	}
	gpt := lib.NewChatGptTool(config["chatGPTUrl"], config["chatGPTSecret"])

	// Build the prompt: stored conversation history first, then the new prompt.
	messages := make([]lib.Gpt3Dot5Message, 0, 1)
	if collectId != "" {
		history := models.FindAigcSessionMessage(0, 100, "kefu_name = ? and ent_id = ? and collect_id = ?", kefuName, entId, collectId)
		messages = make([]lib.Gpt3Dot5Message, 0, len(history)+1)
		for _, m := range history {
			// "ask" messages map to the user role; every other stored
			// message is replayed as an assistant reply.
			role := "assistant"
			if m.MsgType == "ask" {
				role = "user"
			}
			messages = append(messages, lib.Gpt3Dot5Message{Role: role, Content: m.Content})
		}
	}
	messages = append(messages, lib.Gpt3Dot5Message{Role: "user", Content: content})

	// Log only a summary: the messages themselves are user conversation content.
	log.Printf("chat stream: model=%q, %d messages", bigModel, len(messages))

	flusher, canFlush := c.Writer.(http.Flusher)
	flush := func() {
		if canFlush {
			flusher.Flush()
		}
	}

	// The model name is forwarded as-is; routing to a specific model is left
	// to the configured OpenAI-compatible endpoint.
	var (
		stream *openai.ChatCompletionStream
		err    error
	)
	stream, err = gpt.ChatGPT3Dot5TurboStream(messages, bigModel)
	if err != nil {
		c.Writer.Write([]byte("模型错误"))
		flush()
		log.Println("gpt stream error: ", err)
		return
	}
	defer stream.Close()

	done := c.Request.Context().Done()
	for {
		// Stop pulling tokens from upstream once the client has gone away.
		select {
		case <-done:
			log.Println("Stream aborted: client disconnected")
			return
		default:
		}

		response, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			log.Println("Stream finished", recvErr)
			return
		}
		if recvErr != nil {
			log.Println("Stream error:", recvErr, response)
			return
		}
		// OpenAI-compatible backends may emit chunks with an empty choices
		// array (e.g. usage-only chunks); indexing Choices[0] would panic.
		if len(response.Choices) == 0 {
			continue
		}
		c.Writer.Write([]byte(response.Choices[0].Delta.Content))
		flush()
	}
}
// PostSaveChatStream stores one chat message for the logged-in agent and
// returns, as JSON, the conversation id it was stored under.
//
// Form fields: collect_id, content, kefu_avatar, ai_avatar, msg_type.
// When collect_id does not name an existing conversation owned by the current
// enterprise and agent, a new one is created via CreateAigcSessionCollect
// (seeded with content) and its id is returned, so the client can send it
// with later messages.
func PostSaveChatStream(c *gin.Context) {
	collectId := c.PostForm("collect_id")
	content := c.PostForm("content")
	kefuAvatar := c.PostForm("kefu_avatar")
	aiAvatar := c.PostForm("ai_avatar")
	msgType := c.PostForm("msg_type")

	// ent_id / kefu_name are expected to be set by upstream auth middleware;
	// check the assertions rather than letting a missing value panic.
	entIdValue, _ := c.Get("ent_id")
	kefuNameValue, _ := c.Get("kefu_name")
	entId, entOk := entIdValue.(string)
	kefuName, kefuOk := kefuNameValue.(string)
	if !entOk || !kefuOk {
		c.AbortWithStatus(http.StatusUnauthorized)
		return
	}

	// Scope the lookup to the caller, mirroring the kefu_name/ent_id filter
	// PostChatStream applies to history, so a client cannot attach messages
	// to (or probe for) another agent's conversation by guessing its id.
	// NOTE(review): assumes the collect table has ent_id/kefu_name columns
	// populated by CreateAigcSessionCollect — confirm against the model.
	collect := models.FindAigcSessionCollect("id = ? and ent_id = ? and kefu_name = ?", collectId, entId, kefuName)
	if collect.ID == 0 {
		collectUintId := models.CreateAigcSessionCollect(content, entId, kefuName)
		collectId = tools.Int2Str(collectUintId)
	}

	message := models.AigcSessionMessage{
		EntId:      entId,
		KefuName:   kefuName,
		Content:    content,
		KefuAvatar: kefuAvatar,
		AiAvatar:   aiAvatar,
		MsgType:    msgType,
		CollectId:  tools.Str2Uint(collectId),
	}
	message.CreateAigcSessionMessage()

	c.JSON(200, gin.H{
		"code": types.ApiCode.SUCCESS,
		"msg":  types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
		"result": gin.H{
			"collect_id": collectId,
		},
	})
}