kefu/controller/aiProxy/ernieBotTurbo.go

100 lines
3.1 KiB
Go
Raw Normal View History

2024-12-10 02:50:12 +00:00
package aiProxy
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"kefu/lib"
"kefu/models"
"log"
"net/http"
"strings"
)
// ernieSSEContentChunk renders one OpenAI-style "chat.completion.chunk" SSE
// frame whose delta carries content.
//
// content is JSON-encoded before it is embedded in the frame, so quotes,
// backslashes and control characters coming from the model cannot break the
// JSON payload. HTML escaping is disabled so that characters such as '<' and
// '&' are emitted verbatim rather than as \u00XX escapes.
func ernieSSEContentChunk(content string) (string, error) {
	const frameFormat = "data: {\"id\":\"chatcmpl-7gUI4R3DmEst3sTiK3FtMYFBxV1NR\",\"object\":\"chat.completion.chunk\",\"created\":1690360568,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{\"content\":%s},\"finish_reason\":null}]}\n\n"

	var encoded strings.Builder
	enc := json.NewEncoder(&encoded)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(content); err != nil {
		return "", fmt.Errorf("encoding chunk content: %w", err)
	}
	// Encode appends a trailing newline, which must not end up inside the frame.
	return fmt.Sprintf(frameFormat, strings.TrimSuffix(encoded.String(), "\n")), nil
}

// PostErnieBotTurbocompletions handles a chat-completion request for the
// enterprise identified by the "entId" route parameter.
//
// It reads the "messages" array from the JSON request body, forwards only the
// last message to the ERNIE client built from that enterprise's configuration,
// and relays the reply to the caller as a server-sent event stream of
// OpenAI-style "chat.completion.chunk" frames, terminated by a "stop" chunk and
// a "[DONE]" marker.
//
// Failures that happen before streaming starts are answered with a JSON error
// (400 for an unreadable or empty request, 500 when the client or the upstream
// stream cannot be set up). Once streaming has started, a failed write to the
// caller ends the handler immediately.
func PostErnieBotTurbocompletions(c *gin.Context) {
	entId := c.Param("entId")

	rawData, err := c.GetRawData()
	if err != nil {
		log.Println("GetRawData Error", err)
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	jsonString := string(rawData)
	log.Println("PostErnieBotTurbocompletions ", jsonString)

	prompts := make([]map[string]string, 0)
	messages := gjson.Get(jsonString, "messages").String()
	if err := json.Unmarshal([]byte(messages), &prompts); err != nil {
		// Best effort: json.Unmarshal may still have decoded some entries, so
		// only an empty result is treated as fatal below.
		log.Println("PostErnieBotTurbocompletions: decoding messages:", err)
	}
	if len(prompts) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "messages must be a non-empty array"})
		return
	}
	// Only the most recent message is forwarded upstream.
	prompts = prompts[len(prompts)-1:]

	config := models.GetEntConfigsMap(entId, "ERNIEAppId", "ERNIEAPIKey", "ERNIESecretKey")
	m, err := lib.NewErnieBotTurbo(config["ERNIEAppId"], config["ERNIEAPIKey"], config["ERNIESecretKey"])
	if err != nil {
		// Details stay in the server log; the error text may contain credentials.
		log.Println("PostErnieBotTurbocompletions: creating ERNIE client:", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "ernie client initialisation failed"})
		return
	}

	responseWriter := c.Writer
	flusher, ok := responseWriter.(http.Flusher)
	if !ok {
		log.Println("Streaming not supported")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
		return
	}

	// NOTE(review): the value returned by StreamChat is opaque here; confirm in
	// lib whether it holds a resource that must be closed when streaming ends.
	res, err := m.StreamChat(prompts)
	if err != nil {
		log.Println("PostErnieBotTurbocompletions: starting stream:", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "ernie stream request failed"})
		return
	}

	// From here on the response is a server-sent event stream.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	for {
		str, err := m.StreamRecv(res)
		if err != nil {
			// Any receive error ends the relay; the stop frame is still sent.
			log.Println(err)
			break
		}
		log.Println(str)

		result := strings.Trim(gjson.Get(str, "result").String(), "\n")
		for _, line := range strings.Split(result, "\n") {
			// Each line becomes its own chunk; a blank line is sent as a
			// paragraph break (two newlines).
			if line == "" {
				line = "\n\n"
			}
			eventData, err := ernieSSEContentChunk(line)
			if err != nil {
				log.Println("PostErnieBotTurbocompletions:", err)
				continue
			}
			if _, err := responseWriter.WriteString(eventData); err != nil {
				// The caller is gone; stop relaying instead of writing into a
				// dead connection for the rest of the upstream stream.
				log.Println("PostErnieBotTurbocompletions: writing chunk:", err)
				return
			}
			// Push the chunk to the client right away.
			flusher.Flush()
		}
	}

	const stopEvent = "data: {\"id\":\"chatcmpl-7gUI4R3DmEst3sTiK3FtMYFBxV1NR\",\"object\":\"chat.completion.chunk\",\"created\":1690360568,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n"
	if _, err := responseWriter.WriteString(stopEvent); err != nil {
		log.Println("PostErnieBotTurbocompletions: writing stop event:", err)
		return
	}
	flusher.Flush()
}
// PostErnieBotTurboEmbeddings handles an embedding request for the enterprise
// identified by the "entId" route parameter.
//
// It reads the "input" field from the JSON request body, passes it as a single
// prompt to the ERNIE client built from that enterprise's configuration, and
// writes the client's result to the response unchanged.
//
// An unreadable body is answered with 400; a failure to build the client with
// 500; a failed embedding call with 502. All error responses are JSON objects
// with an "error" field.
func PostErnieBotTurboEmbeddings(c *gin.Context) {
	c.Header("Content-Type", "application/json; charset=utf-8")
	entId := c.Param("entId")

	rawData, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	config := models.GetEntConfigsMap(entId, "ERNIEAppId", "ERNIEAPIKey", "ERNIESecretKey")
	m, err := lib.NewErnieBotTurbo(config["ERNIEAppId"], config["ERNIEAPIKey"], config["ERNIESecretKey"])
	if err != nil {
		// Details stay in the server log; the error text may contain credentials.
		log.Println("PostErnieBotTurboEmbeddings: creating ERNIE client:", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "ernie client initialisation failed"})
		return
	}

	prompt := []string{gjson.Get(string(rawData), "input").String()}
	res, err := m.Embedding(prompt)
	if err != nil {
		// Do not forward an empty or partial body with a 200 status.
		log.Println("PostErnieBotTurboEmbeddings: embedding request:", err)
		c.JSON(http.StatusBadGateway, gin.H{"error": "ernie embedding request failed"})
		return
	}

	if _, err := c.Writer.Write([]byte(res)); err != nil {
		log.Println("PostErnieBotTurboEmbeddings: writing response:", err)
	}
}