kefu/lib/glmv3.go

258 lines
6.2 KiB
Go
Raw Normal View History

2024-12-10 02:50:12 +00:00
package lib
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/dgrijalva/jwt-go"
"io/ioutil"
"net/http"
"strings"
"time"
)
//chatGLM v3版本
type ChatGLM struct {
ApiKey string
Token string
BaseUrl string
}
func NewChatGLM(apiKey string) (*ChatGLM, error) {
token, err := generateChatGLMToken(apiKey, 3600*24*365)
if err != nil {
return nil, err
}
glm := &ChatGLM{
ApiKey: apiKey,
Token: token,
BaseUrl: "https://open.bigmodel.cn/api/paas/v3/model-api",
}
return glm, nil
}
func generateChatGLMToken(apiKey string, expSeconds int64) (string, error) {
idSecret := strings.Split(apiKey, ".")
if len(idSecret) != 2 {
return "", fmt.Errorf("invalid apikey")
}
id := idSecret[0]
secret := []byte(idSecret[1])
header := map[string]interface{}{
"alg": "HS256",
"sign_type": "SIGN",
}
payload := jwt.MapClaims{
"api_key": id,
"exp": time.Now().Add(time.Second * time.Duration(expSeconds)).Unix(),
"timestamp": time.Now().Unix(),
}
// Step 3: Generate JWT token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token.Header = header
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}
func (this *ChatGLM) ChatGLM6B() (string, error) {
url := this.BaseUrl + "/chatglm_6b/invoke"
// 构建请求参数
params := map[string]interface{}{
"prompt": []map[string]string{{"role": "user", "content": "你好"}},
"temperature": 0.95,
"top_p": 0.7,
"incremental": false,
}
// 创建HTTP请求的body
// 创建HTTP请求的body
jsonParams, err := json.Marshal(params)
requestBody := bytes.NewBuffer(jsonParams)
// 创建POST请求
req, err := http.NewRequest("POST", url, requestBody)
if err != nil {
fmt.Println("创建请求失败:", err)
return "", err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
//req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Authorization", "Bearer "+this.Token)
// 发送请求
client := http.Client{}
response, err := client.Do(req)
if err != nil {
fmt.Println("发送请求失败:", err)
return "", err
}
defer response.Body.Close()
// 读取响应
responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
return "", err
}
return string(responseBody), nil
}
/**
prompt
list
调用对话模型时将当前对话信息列表作为提示输入给模型; 按照 {"role": "user", "content": "你好"} 的键值对形式进行传参; 总长度超过模型最长输入限制后会自动截断需按时间由旧到新排序
temperature
float
采样温度控制输出的随机性必须为正数
取值范围是(0.0,1.0]不能等于 0,默认值为 0.95
值越大会使输出更随机更具创造性值越小输出会更加稳定或确定
建议您根据应用场景调整 top_p temperature 参数但不要同时调整两个参数
top_p
float
用温度取样的另一种方法称为核取样
取值范围是(0.0,1.0)开区间不能等于 0 1默认值为 0.7
模型考虑具有 top_p 概率质量的令牌的结果所以 0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取tokens
建议您根据应用场景调整 top_p temperature 参数但不要同时调整两个参数
request_id
string
由用户端传参需保证唯一性用于区分每次请求的唯一标识用户端不传时平台会默认生成
incremental
boolean
SSE接口调用时用于控制每次返回内容方式是增量还是全量不提供此参数时默认为增量返回
- true 为增量返回
- false 为全量返回
*/
func (this *ChatGLM) StreamChatGLM6B(prompt []map[string]string, temperature, topP float64, incremental bool) (*bufio.Reader, error) {
url := this.BaseUrl + "/chatglm_6b/sse-invoke"
// 构建请求参数
params := map[string]interface{}{
"prompt": prompt,
"temperature": temperature,
"top_p": topP,
"incremental": incremental,
}
// 创建HTTP请求的body
// 创建HTTP请求的body
jsonParams, err := json.Marshal(params)
requestBody := bytes.NewBuffer(jsonParams)
// 创建POST请求
req, err := http.NewRequest("POST", url, requestBody)
if err != nil {
fmt.Println("创建请求失败:", err)
return nil, err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Authorization", "Bearer "+this.Token)
// 发送请求
client := http.Client{}
response, err := client.Do(req)
if err != nil {
fmt.Println("发送请求失败:", err)
return nil, err
}
//defer response.Body.Close()
// 读取响应
// 读取响应体数据
reader := bufio.NewReader(response.Body)
return reader, nil
//res := ""
//meta := ""
//for {
// line, err := reader.ReadString('\n')
// if err != nil {
// break
// }
// // 处理每行数据
// line = strings.TrimSpace(line)
// if line == "" {
// continue
// }
//
// // 根据冒号分割每行数据的键值对
// parts := strings.SplitN(line, ":", 2)
// if len(parts) != 2 {
// break
// }
//
// key := strings.TrimSpace(parts[0])
// value := strings.TrimSpace(parts[1])
//
// // 根据键的不同处理不同的字段
// switch key {
// case "event":
// if value == "finish" {
// break
// }
// case "data":
// // 设置Event的数据
// res = value
// fmt.Println(value)
// case "meta":
// // 解析JSON格式的元数据
// meta = value
// }
//
//}
//return fmt.Sprintf("{\"content\":\"%s\",\"meta\":%s}", res, meta), nil
}
func (this *ChatGLM) StreamRecv(reader *bufio.Reader) (string, error) {
waitForData:
line, err := reader.ReadString('\n')
if err != nil {
return "", err
}
// 处理每行数据
line = strings.TrimSpace(line)
if line == "" {
goto waitForData
}
// 根据冒号分割每行数据的键值对
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
return "", errors.New("数据格式错误")
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
// 根据键的不同处理不同的字段
switch key {
case "data":
// 设置Event的数据
return value, nil
//case "meta":
// // 解析JSON格式的元数据
// return value, nil
}
goto waitForData
//return "", errors.New("finish")
}