kefu/lib/ERNIE-Bot-turbo.go

180 lines
4.1 KiB
Go

package lib
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/tidwall/gjson"
"io/ioutil"
"net/http"
"strings"
)
// ErnieBotTurbo is a client for Baidu's Wenxin Yiyan (ERNIE-Bot-turbo) API.
//
// AppId, ApiKey and SecretKey hold the application credentials. ApiKey and
// SecretKey are exchanged for AccessToken, which is then appended to API
// request URLs as the access_token query parameter.
type ErnieBotTurbo struct {
	AppId             string
	ApiKey, SecretKey string
	AccessToken       string
}
// NewErnieBotTurbo creates a client for the given credentials and fetches an
// access token right away via GenerateAccessToken.
//
// The token returned by GenerateAccessToken is stored in AccessToken. The
// client pointer is returned even when the token request fails, so callers
// receive a non-nil client together with the error.
func NewErnieBotTurbo(appId, apiKey, secretKey string) (*ErnieBotTurbo, error) {
	bot := new(ErnieBotTurbo)
	bot.AppId, bot.ApiKey, bot.SecretKey = appId, apiKey, secretKey

	token, err := bot.GenerateAccessToken()
	bot.AccessToken = token
	return bot, err
}
// GenerateAccessToken exchanges ApiKey and SecretKey for an access token via
// the client_credentials grant at https://aip.baidubce.com/oauth/2.0/token.
//
// On success the token is stored in AccessToken and returned. On failure
// AccessToken is left untouched and a non-nil error is returned; transport
// errors are wrapped so errors.Is/errors.As still see the underlying cause.
func (this *ErnieBotTurbo) GenerateAccessToken() (string, error) {
	req, err := http.NewRequest("GET", "https://aip.baidubce.com/oauth/2.0/token", nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	// Build the query through url.Values so credentials containing reserved
	// characters (+, &, =, %) are escaped instead of corrupting the query.
	query := req.URL.Query()
	query.Set("grant_type", "client_credentials")
	query.Set("client_id", this.ApiKey)
	query.Set("client_secret", this.SecretKey)
	req.URL.RawQuery = query.Encode()

	client := http.Client{}
	response, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("发送请求失败: %w", err)
	}
	defer response.Body.Close()

	responseBody, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return "", fmt.Errorf("读取响应失败: %w", err)
	}
	body := string(responseBody)

	accessToken := gjson.Get(body, "access_token").String()
	if accessToken == "" {
		// NOTE(review): Baidu's OAuth endpoint is expected to explain failures in
		// an "error_description" field (e.g. bad credentials); surface it when
		// present instead of discarding it.
		if desc := gjson.Get(body, "error_description").String(); desc != "" {
			return "", fmt.Errorf("获取access_token失败: %s", desc)
		}
		return "", errors.New("获取access_token失败")
	}
	this.AccessToken = accessToken
	return accessToken, nil
}
// StreamChat sends messages to the ERNIE-Bot-turbo (eb-instant) chat endpoint
// with "stream": true and returns a buffered reader over the raw streaming
// response. Pass the reader to StreamRecv to pull out each "data:" payload.
//
// messages is marshalled as-is into the "messages" request field (presumably
// entries like {"role": ..., "content": ...} — confirm against the API docs).
//
// The response body is closed automatically as soon as a read from it fails,
// including io.EOF at the end of the stream. A caller that abandons the stream
// early still holds the connection open, because a *bufio.Reader cannot be
// closed.
func (this *ErnieBotTurbo) StreamChat(messages []map[string]string) (*bufio.Reader, error) {
	const endpoint = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"

	jsonParams, err := json.Marshal(map[string]interface{}{
		"messages": messages,
		"stream":   true,
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求参数失败: %w", err)
	}

	req, err := http.NewRequest("POST", endpoint, bytes.NewReader(jsonParams))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	// Attach the token through url.Values: concatenating it into a Sprintf
	// format string mangled any '%' and left reserved characters unescaped.
	query := req.URL.Query()
	query.Set("access_token", this.AccessToken)
	req.URL.RawQuery = query.Encode()
	req.Header.Set("Content-Type", "application/json")

	client := http.Client{}
	response, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}
	// The body cannot be closed here (the caller still has to read it), and the
	// caller cannot close it through a *bufio.Reader, so close it on end of
	// stream instead.
	return bufio.NewReader(&ernieStreamBody{body: response.Body}), nil
}

// ernieStreamBody wraps a streaming response body and closes it as soon as a
// Read returns an error (io.EOF included). The first error is remembered and
// returned again on later reads, so a drained stream keeps reporting that
// error rather than a "read on closed body" error.
type ernieStreamBody struct {
	body interface {
		Read(p []byte) (int, error)
		Close() error
	}
	err error
}

// Read forwards to the wrapped body and closes it on the first read error.
func (s *ernieStreamBody) Read(p []byte) (int, error) {
	if s.err != nil {
		return 0, s.err
	}
	n, err := s.body.Read(p)
	if err != nil {
		s.err = err
		// The stream is finished either way; a Close error adds nothing useful.
		_ = s.body.Close()
	}
	return n, err
}
// StreamRecv returns the payload of the next "data:" line read from a
// server-sent-event stream, such as the reader returned by StreamChat.
//
// Lines are trimmed of surrounding whitespace. Blank lines and lines whose
// field name is anything other than "data" (e.g. "id:", "event:", or ":"
// comments) are skipped. A non-blank line that contains no colon at all
// produces a format error. Read errors, including io.EOF at the end of the
// stream, are returned unchanged with an empty payload; a final line without
// a trailing newline is therefore discarded.
func (this *ErnieBotTurbo) StreamRecv(reader *bufio.Reader) (string, error) {
	for {
		raw, err := reader.ReadString('\n')
		if err != nil {
			return "", err
		}

		raw = strings.TrimSpace(raw)
		if raw == "" {
			continue
		}

		sep := strings.Index(raw, ":")
		if sep < 0 {
			return "", errors.New("数据格式错误")
		}
		if field := strings.TrimSpace(raw[:sep]); field == "data" {
			return strings.TrimSpace(raw[sep+1:]), nil
		}
		// Any other field (meta, id, event, comment) is ignored; keep reading.
	}
}
// Embedding posts the given input texts to the embedding-v1 endpoint and
// returns the raw response body as a string; decoding is left to the caller.
//
// The HTTP status code is not inspected: whatever body the server sends
// (including an API error payload) is returned with a nil error. Transport
// failures are wrapped so errors.Is/errors.As still see the underlying cause.
func (this *ErnieBotTurbo) Embedding(input []string) (string, error) {
	const endpoint = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"

	jsonParams, err := json.Marshal(map[string]interface{}{
		"input": input,
	})
	if err != nil {
		return "", fmt.Errorf("序列化请求参数失败: %w", err)
	}

	req, err := http.NewRequest("POST", endpoint, bytes.NewReader(jsonParams))
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	// Attach the token through url.Values: concatenating it into a Sprintf
	// format string mangled any '%' and left reserved characters unescaped.
	query := req.URL.Query()
	query.Set("access_token", this.AccessToken)
	req.URL.RawQuery = query.Encode()
	req.Header.Set("Content-Type", "application/json")

	client := http.Client{}
	response, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("发送请求失败: %w", err)
	}
	defer response.Body.Close()

	responseBody, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return "", err
	}
	return string(responseBody), nil
}