258 lines
6.2 KiB
Go
258 lines
6.2 KiB
Go
package utils
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/dgrijalva/jwt-go"
|
||
"io/ioutil"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
//chatGLM v3版本
|
||
type ChatGLM struct {
|
||
ApiKey string
|
||
Token string
|
||
BaseUrl string
|
||
}
|
||
|
||
func NewChatGLM(apiKey string) (*ChatGLM, error) {
|
||
|
||
token, err := generateChatGLMToken(apiKey, 3600*24*365)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
glm := &ChatGLM{
|
||
ApiKey: apiKey,
|
||
Token: token,
|
||
BaseUrl: "https://open.bigmodel.cn/api/paas/v3/model-api",
|
||
}
|
||
return glm, nil
|
||
}
|
||
func generateChatGLMToken(apiKey string, expSeconds int64) (string, error) {
|
||
idSecret := strings.Split(apiKey, ".")
|
||
if len(idSecret) != 2 {
|
||
return "", fmt.Errorf("invalid apikey")
|
||
}
|
||
|
||
id := idSecret[0]
|
||
secret := []byte(idSecret[1])
|
||
|
||
header := map[string]interface{}{
|
||
"alg": "HS256",
|
||
"sign_type": "SIGN",
|
||
}
|
||
|
||
payload := jwt.MapClaims{
|
||
"api_key": id,
|
||
"exp": time.Now().Add(time.Second * time.Duration(expSeconds)).Unix(),
|
||
"timestamp": time.Now().Unix(),
|
||
}
|
||
|
||
// Step 3: Generate JWT token
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||
token.Header = header
|
||
tokenString, err := token.SignedString(secret)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return tokenString, nil
|
||
}
|
||
func (this *ChatGLM) ChatGLM6B() (string, error) {
|
||
url := this.BaseUrl + "/chatglm_6b/invoke"
|
||
|
||
// 构建请求参数
|
||
params := map[string]interface{}{
|
||
"prompt": []map[string]string{{"role": "user", "content": "你好"}},
|
||
"temperature": 0.95,
|
||
"top_p": 0.7,
|
||
"incremental": false,
|
||
}
|
||
|
||
// 创建HTTP请求的body
|
||
// 创建HTTP请求的body
|
||
jsonParams, err := json.Marshal(params)
|
||
requestBody := bytes.NewBuffer(jsonParams)
|
||
|
||
// 创建POST请求
|
||
req, err := http.NewRequest("POST", url, requestBody)
|
||
if err != nil {
|
||
fmt.Println("创建请求失败:", err)
|
||
return "", err
|
||
}
|
||
|
||
// 设置请求头
|
||
req.Header.Set("Content-Type", "application/json")
|
||
//req.Header.Set("Accept", "text/event-stream")
|
||
req.Header.Set("Authorization", "Bearer "+this.Token)
|
||
// 发送请求
|
||
client := http.Client{}
|
||
response, err := client.Do(req)
|
||
if err != nil {
|
||
fmt.Println("发送请求失败:", err)
|
||
return "", err
|
||
}
|
||
|
||
defer response.Body.Close()
|
||
|
||
// 读取响应
|
||
responseBody, err := ioutil.ReadAll(response.Body)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return string(responseBody), nil
|
||
}
|
||
|
||
/**
|
||
prompt
|
||
list
|
||
调用对话模型时,将当前对话信息列表作为提示输入给模型; 按照 {"role": "user", "content": "你好"} 的键值对形式进行传参; 总长度超过模型最长输入限制后会自动截断,需按时间由旧到新排序
|
||
|
||
temperature
|
||
float
|
||
否
|
||
采样温度,控制输出的随机性,必须为正数
|
||
取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95
|
||
值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定
|
||
建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数
|
||
|
||
top_p
|
||
float
|
||
否
|
||
用温度取样的另一种方法,称为核取样
|
||
取值范围是:(0.0,1.0);开区间,不能等于 0 或 1,默认值为 0.7
|
||
模型考虑具有 top_p 概率质量的令牌的结果。所以 0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取tokens
|
||
建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数
|
||
|
||
request_id
|
||
string
|
||
否
|
||
由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。
|
||
|
||
incremental
|
||
boolean
|
||
否
|
||
SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回
|
||
- true 为增量返回
|
||
- false 为全量返回
|
||
*/
|
||
func (this *ChatGLM) StreamChatGLM6B(prompt []map[string]string, temperature, topP float64, incremental bool) (*bufio.Reader, error) {
|
||
url := this.BaseUrl + "/chatglm_6b/sse-invoke"
|
||
|
||
// 构建请求参数
|
||
params := map[string]interface{}{
|
||
"prompt": prompt,
|
||
"temperature": temperature,
|
||
"top_p": topP,
|
||
"incremental": incremental,
|
||
}
|
||
|
||
// 创建HTTP请求的body
|
||
// 创建HTTP请求的body
|
||
jsonParams, err := json.Marshal(params)
|
||
requestBody := bytes.NewBuffer(jsonParams)
|
||
|
||
// 创建POST请求
|
||
req, err := http.NewRequest("POST", url, requestBody)
|
||
if err != nil {
|
||
fmt.Println("创建请求失败:", err)
|
||
return nil, err
|
||
}
|
||
|
||
// 设置请求头
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "text/event-stream")
|
||
req.Header.Set("Authorization", "Bearer "+this.Token)
|
||
// 发送请求
|
||
client := http.Client{}
|
||
response, err := client.Do(req)
|
||
if err != nil {
|
||
fmt.Println("发送请求失败:", err)
|
||
return nil, err
|
||
}
|
||
|
||
//defer response.Body.Close()
|
||
|
||
// 读取响应
|
||
// 读取响应体数据
|
||
reader := bufio.NewReader(response.Body)
|
||
return reader, nil
|
||
//res := ""
|
||
//meta := ""
|
||
//for {
|
||
// line, err := reader.ReadString('\n')
|
||
// if err != nil {
|
||
// break
|
||
// }
|
||
// // 处理每行数据
|
||
// line = strings.TrimSpace(line)
|
||
// if line == "" {
|
||
// continue
|
||
// }
|
||
//
|
||
// // 根据冒号分割每行数据的键值对
|
||
// parts := strings.SplitN(line, ":", 2)
|
||
// if len(parts) != 2 {
|
||
// break
|
||
// }
|
||
//
|
||
// key := strings.TrimSpace(parts[0])
|
||
// value := strings.TrimSpace(parts[1])
|
||
//
|
||
// // 根据键的不同处理不同的字段
|
||
// switch key {
|
||
// case "event":
|
||
// if value == "finish" {
|
||
// break
|
||
// }
|
||
// case "data":
|
||
// // 设置Event的数据
|
||
// res = value
|
||
// fmt.Println(value)
|
||
// case "meta":
|
||
// // 解析JSON格式的元数据
|
||
// meta = value
|
||
// }
|
||
//
|
||
//}
|
||
//return fmt.Sprintf("{\"content\":\"%s\",\"meta\":%s}", res, meta), nil
|
||
}
|
||
func (this *ChatGLM) StreamRecv(reader *bufio.Reader) (string, error) {
|
||
waitForData:
|
||
line, err := reader.ReadString('\n')
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
// 处理每行数据
|
||
line = strings.TrimSpace(line)
|
||
if line == "" {
|
||
goto waitForData
|
||
}
|
||
|
||
// 根据冒号分割每行数据的键值对
|
||
parts := strings.SplitN(line, ":", 2)
|
||
if len(parts) != 2 {
|
||
return "", errors.New("数据格式错误")
|
||
}
|
||
|
||
key := strings.TrimSpace(parts[0])
|
||
value := strings.TrimSpace(parts[1])
|
||
|
||
// 根据键的不同处理不同的字段
|
||
switch key {
|
||
case "data":
|
||
// 设置Event的数据
|
||
return value, nil
|
||
//case "meta":
|
||
// // 解析JSON格式的元数据
|
||
// return value, nil
|
||
}
|
||
goto waitForData
|
||
//return "", errors.New("finish")
|
||
}
|