180 lines
4.1 KiB
Go
180 lines
4.1 KiB
Go
package lib
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/tidwall/gjson"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// ErnieBotTurbo is a client for Baidu's ERNIE-Bot-turbo (Wenxin Yiyan) HTTP API.
type ErnieBotTurbo struct {
	// AppId is stored as supplied to the constructor; the methods in this
	// file do not read it.
	AppId string
	// ApiKey and SecretKey are sent as client_id and client_secret when an
	// access token is requested.
	ApiKey, SecretKey string
	// AccessToken is appended to the chat and embedding endpoint URLs as the
	// access_token query parameter.
	AccessToken string
}
|
|
|
|
func NewErnieBotTurbo(appId, apiKey, secretKey string) (*ErnieBotTurbo, error) {
|
|
|
|
m := &ErnieBotTurbo{
|
|
AppId: appId,
|
|
ApiKey: apiKey,
|
|
SecretKey: secretKey,
|
|
}
|
|
var err error
|
|
m.AccessToken, err = m.GenerateAccessToken()
|
|
if err != nil {
|
|
return m, err
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
//获取access_token
|
|
func (this *ErnieBotTurbo) GenerateAccessToken() (string, error) {
|
|
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
|
this.ApiKey,
|
|
this.SecretKey)
|
|
// 创建POST请求
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
if err != nil {
|
|
fmt.Println("创建请求失败:", err)
|
|
return "", err
|
|
}
|
|
// 发送请求
|
|
client := http.Client{}
|
|
response, err := client.Do(req)
|
|
if err != nil {
|
|
fmt.Println("发送请求失败:", err)
|
|
return "", err
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
// 读取响应
|
|
responseBody, err := ioutil.ReadAll(response.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
accessToken := gjson.Get(string(responseBody), "access_token").String()
|
|
if accessToken == "" {
|
|
return "", errors.New("获取access_token失败")
|
|
}
|
|
this.AccessToken = accessToken
|
|
return accessToken, nil
|
|
}
|
|
|
|
//流式请求接口
|
|
func (this *ErnieBotTurbo) StreamChat(messages []map[string]string) (*bufio.Reader, error) {
|
|
url := fmt.Sprintf("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + this.AccessToken)
|
|
|
|
// 构建请求参数
|
|
params := map[string]interface{}{
|
|
"messages": messages,
|
|
"stream": true,
|
|
}
|
|
|
|
// 创建HTTP请求的body
|
|
jsonParams, err := json.Marshal(params)
|
|
requestBody := bytes.NewBuffer(jsonParams)
|
|
|
|
// 创建POST请求
|
|
req, err := http.NewRequest("POST", url, requestBody)
|
|
if err != nil {
|
|
fmt.Println("创建请求失败:", err)
|
|
return nil, err
|
|
}
|
|
// 设置请求头
|
|
//req.Header.Set("Access", "text/event-stream")
|
|
// 发送请求
|
|
client := http.Client{}
|
|
response, err := client.Do(req)
|
|
if err != nil {
|
|
fmt.Println("发送请求失败:", err)
|
|
return nil, err
|
|
}
|
|
|
|
//defer response.Body.Close()
|
|
|
|
// 读取响应
|
|
// 读取响应体数据
|
|
reader := bufio.NewReader(response.Body)
|
|
return reader, nil
|
|
}
|
|
func (this *ErnieBotTurbo) StreamRecv(reader *bufio.Reader) (string, error) {
|
|
waitForData:
|
|
line, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// 处理每行数据
|
|
line = strings.TrimSpace(line)
|
|
if line == "" {
|
|
goto waitForData
|
|
}
|
|
|
|
// 根据冒号分割每行数据的键值对
|
|
parts := strings.SplitN(line, ":", 2)
|
|
if len(parts) != 2 {
|
|
return "", errors.New("数据格式错误")
|
|
}
|
|
|
|
key := strings.TrimSpace(parts[0])
|
|
value := strings.TrimSpace(parts[1])
|
|
|
|
// 根据键的不同处理不同的字段
|
|
switch key {
|
|
case "data":
|
|
// 设置Event的数据
|
|
return value, nil
|
|
//case "meta":
|
|
// // 解析JSON格式的元数据
|
|
// return value, nil
|
|
}
|
|
goto waitForData
|
|
//return "", errors.New("finish")
|
|
}
|
|
|
|
//流式请求接口
|
|
func (this *ErnieBotTurbo) Embedding(input []string) (string, error) {
|
|
url := fmt.Sprintf("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1?access_token=" + this.AccessToken)
|
|
|
|
// 构建请求参数
|
|
params := map[string]interface{}{
|
|
"input": input,
|
|
}
|
|
|
|
// 创建HTTP请求的body
|
|
jsonParams, err := json.Marshal(params)
|
|
requestBody := bytes.NewBuffer(jsonParams)
|
|
|
|
// 创建POST请求
|
|
req, err := http.NewRequest("POST", url, requestBody)
|
|
if err != nil {
|
|
fmt.Println("创建请求失败:", err)
|
|
return "", err
|
|
}
|
|
// 设置请求头
|
|
//req.Header.Set("Access", "text/event-stream")
|
|
// 发送请求
|
|
client := http.Client{}
|
|
response, err := client.Do(req)
|
|
if err != nil {
|
|
fmt.Println("发送请求失败:", err)
|
|
return "", err
|
|
}
|
|
|
|
defer response.Body.Close()
|
|
|
|
// 读取响应
|
|
responseBody, err := ioutil.ReadAll(response.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(responseBody), nil
|
|
}
|