258 lines
6.2 KiB
Go
258 lines
6.2 KiB
Go
|
package lib
|
|||
|
|
|||
|
import (
|
|||
|
"bufio"
|
|||
|
"bytes"
|
|||
|
"encoding/json"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"github.com/dgrijalva/jwt-go"
|
|||
|
"io/ioutil"
|
|||
|
"net/http"
|
|||
|
"strings"
|
|||
|
"time"
|
|||
|
)
|
|||
|
|
|||
|
//chatGLM v3版本
|
|||
|
type ChatGLM struct {
|
|||
|
ApiKey string
|
|||
|
Token string
|
|||
|
BaseUrl string
|
|||
|
}
|
|||
|
|
|||
|
func NewChatGLM(apiKey string) (*ChatGLM, error) {
|
|||
|
|
|||
|
token, err := generateChatGLMToken(apiKey, 3600*24*365)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
glm := &ChatGLM{
|
|||
|
ApiKey: apiKey,
|
|||
|
Token: token,
|
|||
|
BaseUrl: "https://open.bigmodel.cn/api/paas/v3/model-api",
|
|||
|
}
|
|||
|
return glm, nil
|
|||
|
}
|
|||
|
func generateChatGLMToken(apiKey string, expSeconds int64) (string, error) {
|
|||
|
idSecret := strings.Split(apiKey, ".")
|
|||
|
if len(idSecret) != 2 {
|
|||
|
return "", fmt.Errorf("invalid apikey")
|
|||
|
}
|
|||
|
|
|||
|
id := idSecret[0]
|
|||
|
secret := []byte(idSecret[1])
|
|||
|
|
|||
|
header := map[string]interface{}{
|
|||
|
"alg": "HS256",
|
|||
|
"sign_type": "SIGN",
|
|||
|
}
|
|||
|
|
|||
|
payload := jwt.MapClaims{
|
|||
|
"api_key": id,
|
|||
|
"exp": time.Now().Add(time.Second * time.Duration(expSeconds)).Unix(),
|
|||
|
"timestamp": time.Now().Unix(),
|
|||
|
}
|
|||
|
|
|||
|
// Step 3: Generate JWT token
|
|||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
|||
|
token.Header = header
|
|||
|
tokenString, err := token.SignedString(secret)
|
|||
|
if err != nil {
|
|||
|
return "", err
|
|||
|
}
|
|||
|
|
|||
|
return tokenString, nil
|
|||
|
}
|
|||
|
func (this *ChatGLM) ChatGLM6B() (string, error) {
|
|||
|
url := this.BaseUrl + "/chatglm_6b/invoke"
|
|||
|
|
|||
|
// 构建请求参数
|
|||
|
params := map[string]interface{}{
|
|||
|
"prompt": []map[string]string{{"role": "user", "content": "你好"}},
|
|||
|
"temperature": 0.95,
|
|||
|
"top_p": 0.7,
|
|||
|
"incremental": false,
|
|||
|
}
|
|||
|
|
|||
|
// 创建HTTP请求的body
|
|||
|
// 创建HTTP请求的body
|
|||
|
jsonParams, err := json.Marshal(params)
|
|||
|
requestBody := bytes.NewBuffer(jsonParams)
|
|||
|
|
|||
|
// 创建POST请求
|
|||
|
req, err := http.NewRequest("POST", url, requestBody)
|
|||
|
if err != nil {
|
|||
|
fmt.Println("创建请求失败:", err)
|
|||
|
return "", err
|
|||
|
}
|
|||
|
|
|||
|
// 设置请求头
|
|||
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
//req.Header.Set("Accept", "text/event-stream")
|
|||
|
req.Header.Set("Authorization", "Bearer "+this.Token)
|
|||
|
// 发送请求
|
|||
|
client := http.Client{}
|
|||
|
response, err := client.Do(req)
|
|||
|
if err != nil {
|
|||
|
fmt.Println("发送请求失败:", err)
|
|||
|
return "", err
|
|||
|
}
|
|||
|
|
|||
|
defer response.Body.Close()
|
|||
|
|
|||
|
// 读取响应
|
|||
|
responseBody, err := ioutil.ReadAll(response.Body)
|
|||
|
if err != nil {
|
|||
|
return "", err
|
|||
|
}
|
|||
|
return string(responseBody), nil
|
|||
|
}
|
|||
|
|
|||
|
/**
|
|||
|
prompt
|
|||
|
list
|
|||
|
调用对话模型时,将当前对话信息列表作为提示输入给模型; 按照 {"role": "user", "content": "你好"} 的键值对形式进行传参; 总长度超过模型最长输入限制后会自动截断,需按时间由旧到新排序
|
|||
|
|
|||
|
temperature
|
|||
|
float
|
|||
|
否
|
|||
|
采样温度,控制输出的随机性,必须为正数
|
|||
|
取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95
|
|||
|
值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定
|
|||
|
建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数
|
|||
|
|
|||
|
top_p
|
|||
|
float
|
|||
|
否
|
|||
|
用温度取样的另一种方法,称为核取样
|
|||
|
取值范围是:(0.0,1.0);开区间,不能等于 0 或 1,默认值为 0.7
|
|||
|
模型考虑具有 top_p 概率质量的令牌的结果。所以 0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取tokens
|
|||
|
建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数
|
|||
|
|
|||
|
request_id
|
|||
|
string
|
|||
|
否
|
|||
|
由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。
|
|||
|
|
|||
|
incremental
|
|||
|
boolean
|
|||
|
否
|
|||
|
SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回
|
|||
|
- true 为增量返回
|
|||
|
- false 为全量返回
|
|||
|
*/
|
|||
|
func (this *ChatGLM) StreamChatGLM6B(prompt []map[string]string, temperature, topP float64, incremental bool) (*bufio.Reader, error) {
|
|||
|
url := this.BaseUrl + "/chatglm_6b/sse-invoke"
|
|||
|
|
|||
|
// 构建请求参数
|
|||
|
params := map[string]interface{}{
|
|||
|
"prompt": prompt,
|
|||
|
"temperature": temperature,
|
|||
|
"top_p": topP,
|
|||
|
"incremental": incremental,
|
|||
|
}
|
|||
|
|
|||
|
// 创建HTTP请求的body
|
|||
|
// 创建HTTP请求的body
|
|||
|
jsonParams, err := json.Marshal(params)
|
|||
|
requestBody := bytes.NewBuffer(jsonParams)
|
|||
|
|
|||
|
// 创建POST请求
|
|||
|
req, err := http.NewRequest("POST", url, requestBody)
|
|||
|
if err != nil {
|
|||
|
fmt.Println("创建请求失败:", err)
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
// 设置请求头
|
|||
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
req.Header.Set("Accept", "text/event-stream")
|
|||
|
req.Header.Set("Authorization", "Bearer "+this.Token)
|
|||
|
// 发送请求
|
|||
|
client := http.Client{}
|
|||
|
response, err := client.Do(req)
|
|||
|
if err != nil {
|
|||
|
fmt.Println("发送请求失败:", err)
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
//defer response.Body.Close()
|
|||
|
|
|||
|
// 读取响应
|
|||
|
// 读取响应体数据
|
|||
|
reader := bufio.NewReader(response.Body)
|
|||
|
return reader, nil
|
|||
|
//res := ""
|
|||
|
//meta := ""
|
|||
|
//for {
|
|||
|
// line, err := reader.ReadString('\n')
|
|||
|
// if err != nil {
|
|||
|
// break
|
|||
|
// }
|
|||
|
// // 处理每行数据
|
|||
|
// line = strings.TrimSpace(line)
|
|||
|
// if line == "" {
|
|||
|
// continue
|
|||
|
// }
|
|||
|
//
|
|||
|
// // 根据冒号分割每行数据的键值对
|
|||
|
// parts := strings.SplitN(line, ":", 2)
|
|||
|
// if len(parts) != 2 {
|
|||
|
// break
|
|||
|
// }
|
|||
|
//
|
|||
|
// key := strings.TrimSpace(parts[0])
|
|||
|
// value := strings.TrimSpace(parts[1])
|
|||
|
//
|
|||
|
// // 根据键的不同处理不同的字段
|
|||
|
// switch key {
|
|||
|
// case "event":
|
|||
|
// if value == "finish" {
|
|||
|
// break
|
|||
|
// }
|
|||
|
// case "data":
|
|||
|
// // 设置Event的数据
|
|||
|
// res = value
|
|||
|
// fmt.Println(value)
|
|||
|
// case "meta":
|
|||
|
// // 解析JSON格式的元数据
|
|||
|
// meta = value
|
|||
|
// }
|
|||
|
//
|
|||
|
//}
|
|||
|
//return fmt.Sprintf("{\"content\":\"%s\",\"meta\":%s}", res, meta), nil
|
|||
|
}
|
|||
|
func (this *ChatGLM) StreamRecv(reader *bufio.Reader) (string, error) {
|
|||
|
waitForData:
|
|||
|
line, err := reader.ReadString('\n')
|
|||
|
if err != nil {
|
|||
|
return "", err
|
|||
|
}
|
|||
|
// 处理每行数据
|
|||
|
line = strings.TrimSpace(line)
|
|||
|
if line == "" {
|
|||
|
goto waitForData
|
|||
|
}
|
|||
|
|
|||
|
// 根据冒号分割每行数据的键值对
|
|||
|
parts := strings.SplitN(line, ":", 2)
|
|||
|
if len(parts) != 2 {
|
|||
|
return "", errors.New("数据格式错误")
|
|||
|
}
|
|||
|
|
|||
|
key := strings.TrimSpace(parts[0])
|
|||
|
value := strings.TrimSpace(parts[1])
|
|||
|
|
|||
|
// 根据键的不同处理不同的字段
|
|||
|
switch key {
|
|||
|
case "data":
|
|||
|
// 设置Event的数据
|
|||
|
return value, nil
|
|||
|
//case "meta":
|
|||
|
// // 解析JSON格式的元数据
|
|||
|
// return value, nil
|
|||
|
}
|
|||
|
goto waitForData
|
|||
|
//return "", errors.New("finish")
|
|||
|
}
|