154 lines
4.1 KiB
Go
154 lines
4.1 KiB
Go
package ai
|
|
|
|
import (
|
|
"errors"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/sashabaranov/go-openai"
|
|
"io"
|
|
"kefu/lib"
|
|
"kefu/models"
|
|
"kefu/tools"
|
|
"kefu/types"
|
|
"log"
|
|
"net/http"
|
|
)
|
|
|
|
// 聊天
|
|
func PostChatStream(c *gin.Context) {
|
|
c.Header("Content-Type", "text/html;charset=utf-8;")
|
|
content := c.PostForm("content")
|
|
bigModel := c.PostForm("model")
|
|
entId, _ := c.Get("ent_id")
|
|
kefuName, _ := c.Get("kefu_name")
|
|
collectId := c.PostForm("collect_id")
|
|
|
|
//查询key
|
|
config := models.GetEntConfigsMap(entId.(string), "chatGPTUrl", "chatGPTSecret", "ERNIEAppId", "ERNIEAPIKey", "ERNIESecretKey")
|
|
if config["chatGPTUrl"] == "" || config["chatGPTSecret"] == "" {
|
|
c.Writer.Write([]byte("请先配置大模型URL和KEY"))
|
|
return
|
|
}
|
|
gpt := lib.NewChatGptTool(config["chatGPTUrl"], config["chatGPTSecret"])
|
|
var err error
|
|
gpt3Dot5Message := make([]lib.Gpt3Dot5Message, 0)
|
|
|
|
//历史记录
|
|
if collectId != "" {
|
|
collects := models.FindAigcSessionMessage(0, 100, "kefu_name = ? and ent_id = ? and collect_id = ?", kefuName, entId, collectId)
|
|
for _, sessionMessage := range collects {
|
|
if sessionMessage.MsgType == "ask" {
|
|
item := lib.Gpt3Dot5Message{
|
|
Role: "user",
|
|
Content: sessionMessage.Content,
|
|
}
|
|
gpt3Dot5Message = append(gpt3Dot5Message, item)
|
|
} else {
|
|
item := lib.Gpt3Dot5Message{
|
|
Role: "assistant",
|
|
Content: sessionMessage.Content,
|
|
}
|
|
gpt3Dot5Message = append(gpt3Dot5Message, item)
|
|
}
|
|
}
|
|
}
|
|
|
|
gpt3Dot5Message = append(gpt3Dot5Message, lib.Gpt3Dot5Message{
|
|
Role: "user",
|
|
Content: content,
|
|
})
|
|
//调用openai
|
|
f, _ := c.Writer.(http.Flusher)
|
|
|
|
var stream *openai.ChatCompletionStream
|
|
log.Println(gpt3Dot5Message)
|
|
//if bigModel == "GPT-3.5" {
|
|
stream, err = gpt.ChatGPT3Dot5TurboStream(gpt3Dot5Message, bigModel)
|
|
//} else if bigModel == "GPT-4" {
|
|
// stream, err = gpt.ChatGPT4Stream(gpt3Dot5Message)
|
|
//} else if bigModel == "ERNIE-Bot-turbo" {
|
|
// //AppID := "35662533"
|
|
// //APIKey := "Iq1FfkOQIGtMtZqRFxOrvq6T"
|
|
// //SecretKey := "qbzsoFAUSl8UGt1GkGSDSjENtqsjrOTC"
|
|
// m, _ := lib.NewErnieBotTurbo(config["ERNIEAppId"], config["ERNIEAPIKey"], config["ERNIESecretKey"])
|
|
// prompt := []map[string]string{{"role": "user", "content": content}}
|
|
// res, _ := m.StreamChat(prompt)
|
|
// for {
|
|
// str, err := m.StreamRecv(res)
|
|
// if errors.Is(err, io.EOF) {
|
|
// log.Println("Stream finished", err)
|
|
// break
|
|
// } else if err != nil {
|
|
// c.Writer.Write([]byte("文心千帆大模型错误"))
|
|
// f.Flush()
|
|
// log.Println(err)
|
|
// break
|
|
// }
|
|
//
|
|
// //log.Println(str, err)
|
|
// result := gjson.Get(str, "result").String()
|
|
// c.Writer.Write([]byte(result))
|
|
// f.Flush()
|
|
// }
|
|
// return
|
|
//} else {
|
|
// c.Writer.Write([]byte("模型不存在"))
|
|
// f.Flush()
|
|
// return
|
|
//}
|
|
if err != nil {
|
|
c.Writer.Write([]byte("模型错误"))
|
|
f.Flush()
|
|
log.Println("gpt stream error: ", err)
|
|
return
|
|
}
|
|
|
|
for {
|
|
response, err := stream.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
log.Println("Stream finished", err)
|
|
break
|
|
} else if err != nil {
|
|
log.Println("Stream error:", err, response)
|
|
break
|
|
} else {
|
|
aiReply := response.Choices[0].Delta.Content
|
|
c.Writer.Write([]byte(aiReply))
|
|
f.Flush()
|
|
}
|
|
}
|
|
stream.Close()
|
|
}
|
|
|
|
// 聊天
|
|
func PostSaveChatStream(c *gin.Context) {
|
|
collectId := c.PostForm("collect_id")
|
|
content := c.PostForm("content")
|
|
kefuAvatar := c.PostForm("kefu_avatar")
|
|
aiAvatar := c.PostForm("ai_avatar")
|
|
entId, _ := c.Get("ent_id")
|
|
kefuName, _ := c.Get("kefu_name")
|
|
msgType := c.PostForm("msg_type")
|
|
collect := models.FindAigcSessionCollect("id = ?", collectId)
|
|
if collect.ID == 0 {
|
|
collectUintId := models.CreateAigcSessionCollect(content, entId.(string), kefuName.(string))
|
|
collectId = tools.Int2Str(collectUintId)
|
|
}
|
|
message := models.AigcSessionMessage{
|
|
EntId: entId.(string),
|
|
KefuName: kefuName.(string),
|
|
Content: content,
|
|
KefuAvatar: kefuAvatar,
|
|
AiAvatar: aiAvatar,
|
|
MsgType: msgType,
|
|
CollectId: tools.Str2Uint(collectId),
|
|
}
|
|
message.CreateAigcSessionMessage()
|
|
c.JSON(200, gin.H{
|
|
"code": types.ApiCode.SUCCESS,
|
|
"msg": types.ApiCode.GetMessage(types.ApiCode.SUCCESS),
|
|
"result": gin.H{
|
|
"collect_id": collectId,
|
|
},
|
|
})
|
|
}
|