100 lines
3.1 KiB
Go
100 lines
3.1 KiB
Go
package aiProxy
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tidwall/gjson"
|
||
"kefu/lib"
|
||
"kefu/models"
|
||
"log"
|
||
"net/http"
|
||
"strings"
|
||
)
|
||
|
||
func PostErnieBotTurbocompletions(c *gin.Context) {
|
||
entId := c.Param("entId")
|
||
// 获取原始请求体数据
|
||
rawData, err := c.GetRawData()
|
||
log.Println("PostErnieBotTurbocompletions ", string(rawData))
|
||
if err != nil {
|
||
log.Println("GetRawData Error", err)
|
||
return
|
||
}
|
||
|
||
// 设置响应头,指定为SSE格式
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
|
||
// 创建一个只写的响应流
|
||
responseWriter := c.Writer
|
||
flusher, ok := responseWriter.(http.Flusher)
|
||
if !ok {
|
||
log.Println("Streaming not supported")
|
||
return
|
||
}
|
||
|
||
// 将原始数据转换为字符串
|
||
jsonString := string(rawData)
|
||
messages := gjson.Get(jsonString, "messages").String()
|
||
config := models.GetEntConfigsMap(entId, "ERNIEAppId", "ERNIEAPIKey", "ERNIESecretKey")
|
||
|
||
//AppID := "35662533"
|
||
//APIKey := "Iq1FfkOQIGtMtZqRFxOrvq6T"
|
||
//SecretKey := "qbzsoFAUSl8UGt1GkGSDSjENtqsjrOTC"
|
||
m, _ := lib.NewErnieBotTurbo(config["ERNIEAppId"], config["ERNIEAPIKey"], config["ERNIESecretKey"])
|
||
prompts := make([]map[string]string, 0)
|
||
json.Unmarshal([]byte(messages), &prompts)
|
||
if len(prompts) > 1 {
|
||
prompts = prompts[len(prompts)-1:]
|
||
}
|
||
//log.Println(prompts)
|
||
res, _ := m.StreamChat(prompts)
|
||
for {
|
||
str, err := m.StreamRecv(res)
|
||
if err != nil {
|
||
log.Println(err)
|
||
break
|
||
}
|
||
result := strings.Trim(gjson.Get(str, "result").String(), "\n")
|
||
lines := strings.Split(result, "\n")
|
||
|
||
for _, line := range lines {
|
||
if line == "" {
|
||
line = "\\n\\n"
|
||
}
|
||
eventData := fmt.Sprintf("data: {\"id\":\"chatcmpl-7gUI4R3DmEst3sTiK3FtMYFBxV1NR\",\"object\":\"chat.completion.chunk\",\"created\":1690360568,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"%s\"},\"finish_reason\":null}]}\n\n", line)
|
||
// 构造SSE事件数据
|
||
// 将事件数据写入响应流
|
||
_, err = responseWriter.WriteString(eventData)
|
||
if err != nil {
|
||
break
|
||
}
|
||
// 刷新响应流,确保数据被发送到客户端
|
||
flusher.Flush()
|
||
}
|
||
log.Println(str, err)
|
||
}
|
||
eventData := "data: {\"id\":\"chatcmpl-7gUI4R3DmEst3sTiK3FtMYFBxV1NR\",\"object\":\"chat.completion.chunk\",\"created\":1690360568,\"model\":\"gpt-3.5-turbo-0613\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\ndata: [DONE]\n\n"
|
||
_, err = responseWriter.WriteString(eventData)
|
||
flusher.Flush()
|
||
}
|
||
func PostErnieBotTurboEmbeddings(c *gin.Context) {
|
||
c.Header("Content-Type", "application/json; charset=utf-8")
|
||
entId := c.Param("entId")
|
||
// 获取原始请求体数据
|
||
rawData, err := c.GetRawData()
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
config := models.GetEntConfigsMap(entId, "ERNIEAppId", "ERNIEAPIKey", "ERNIESecretKey")
|
||
|
||
m, _ := lib.NewErnieBotTurbo(config["ERNIEAppId"], config["ERNIEAPIKey"], config["ERNIESecretKey"])
|
||
prompt := []string{gjson.Get(string(rawData), "input").String()}
|
||
res, err := m.Embedding(prompt)
|
||
c.Writer.Write([]byte(res))
|
||
}
|