kefu/knowledge/utils/glmv3.go

258 lines
6.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package utils
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/dgrijalva/jwt-go"
"io/ioutil"
"net/http"
"strings"
"time"
)
//chatGLM v3版本
type ChatGLM struct {
ApiKey string
Token string
BaseUrl string
}
func NewChatGLM(apiKey string) (*ChatGLM, error) {
token, err := generateChatGLMToken(apiKey, 3600*24*365)
if err != nil {
return nil, err
}
glm := &ChatGLM{
ApiKey: apiKey,
Token: token,
BaseUrl: "https://open.bigmodel.cn/api/paas/v3/model-api",
}
return glm, nil
}
func generateChatGLMToken(apiKey string, expSeconds int64) (string, error) {
idSecret := strings.Split(apiKey, ".")
if len(idSecret) != 2 {
return "", fmt.Errorf("invalid apikey")
}
id := idSecret[0]
secret := []byte(idSecret[1])
header := map[string]interface{}{
"alg": "HS256",
"sign_type": "SIGN",
}
payload := jwt.MapClaims{
"api_key": id,
"exp": time.Now().Add(time.Second * time.Duration(expSeconds)).Unix(),
"timestamp": time.Now().Unix(),
}
// Step 3: Generate JWT token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token.Header = header
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}
func (this *ChatGLM) ChatGLM6B() (string, error) {
url := this.BaseUrl + "/chatglm_6b/invoke"
// 构建请求参数
params := map[string]interface{}{
"prompt": []map[string]string{{"role": "user", "content": "你好"}},
"temperature": 0.95,
"top_p": 0.7,
"incremental": false,
}
// 创建HTTP请求的body
// 创建HTTP请求的body
jsonParams, err := json.Marshal(params)
requestBody := bytes.NewBuffer(jsonParams)
// 创建POST请求
req, err := http.NewRequest("POST", url, requestBody)
if err != nil {
fmt.Println("创建请求失败:", err)
return "", err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
//req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Authorization", "Bearer "+this.Token)
// 发送请求
client := http.Client{}
response, err := client.Do(req)
if err != nil {
fmt.Println("发送请求失败:", err)
return "", err
}
defer response.Body.Close()
// 读取响应
responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
return "", err
}
return string(responseBody), nil
}
/**
prompt
list
调用对话模型时,将当前对话信息列表作为提示输入给模型; 按照 {"role": "user", "content": "你好"} 的键值对形式进行传参; 总长度超过模型最长输入限制后会自动截断,需按时间由旧到新排序
temperature
float
采样温度,控制输出的随机性,必须为正数
取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95
值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定
建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数
top_p
float
用温度取样的另一种方法,称为核取样
取值范围是:(0.0,1.0);开区间,不能等于 0 或 1默认值为 0.7
模型考虑具有 top_p 概率质量的令牌的结果。所以 0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取tokens
建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数
request_id
string
由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。
incremental
boolean
SSE接口调用时用于控制每次返回内容方式是增量还是全量不提供此参数时默认为增量返回
- true 为增量返回
- false 为全量返回
*/
func (this *ChatGLM) StreamChatGLM6B(prompt []map[string]string, temperature, topP float64, incremental bool) (*bufio.Reader, error) {
url := this.BaseUrl + "/chatglm_6b/sse-invoke"
// 构建请求参数
params := map[string]interface{}{
"prompt": prompt,
"temperature": temperature,
"top_p": topP,
"incremental": incremental,
}
// 创建HTTP请求的body
// 创建HTTP请求的body
jsonParams, err := json.Marshal(params)
requestBody := bytes.NewBuffer(jsonParams)
// 创建POST请求
req, err := http.NewRequest("POST", url, requestBody)
if err != nil {
fmt.Println("创建请求失败:", err)
return nil, err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Authorization", "Bearer "+this.Token)
// 发送请求
client := http.Client{}
response, err := client.Do(req)
if err != nil {
fmt.Println("发送请求失败:", err)
return nil, err
}
//defer response.Body.Close()
// 读取响应
// 读取响应体数据
reader := bufio.NewReader(response.Body)
return reader, nil
//res := ""
//meta := ""
//for {
// line, err := reader.ReadString('\n')
// if err != nil {
// break
// }
// // 处理每行数据
// line = strings.TrimSpace(line)
// if line == "" {
// continue
// }
//
// // 根据冒号分割每行数据的键值对
// parts := strings.SplitN(line, ":", 2)
// if len(parts) != 2 {
// break
// }
//
// key := strings.TrimSpace(parts[0])
// value := strings.TrimSpace(parts[1])
//
// // 根据键的不同处理不同的字段
// switch key {
// case "event":
// if value == "finish" {
// break
// }
// case "data":
// // 设置Event的数据
// res = value
// fmt.Println(value)
// case "meta":
// // 解析JSON格式的元数据
// meta = value
// }
//
//}
//return fmt.Sprintf("{\"content\":\"%s\",\"meta\":%s}", res, meta), nil
}
func (this *ChatGLM) StreamRecv(reader *bufio.Reader) (string, error) {
waitForData:
line, err := reader.ReadString('\n')
if err != nil {
return "", err
}
// 处理每行数据
line = strings.TrimSpace(line)
if line == "" {
goto waitForData
}
// 根据冒号分割每行数据的键值对
parts := strings.SplitN(line, ":", 2)
if len(parts) != 2 {
return "", errors.New("数据格式错误")
}
key := strings.TrimSpace(parts[0])
value := strings.TrimSpace(parts[1])
// 根据键的不同处理不同的字段
switch key {
case "data":
// 设置Event的数据
return value, nil
//case "meta":
// // 解析JSON格式的元数据
// return value, nil
}
goto waitForData
//return "", errors.New("finish")
}