kefu/lib/chatgpt.go

353 lines
8.5 KiB
Go

package lib
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/sashabaranov/go-openai"
"io"
"io/ioutil"
"log"
"mime/multipart"
"net/http"
"os"
"strings"
)
// ChatGptTool wraps an OpenAI-compatible chat API endpoint. It keeps the
// raw key and base URL for the hand-rolled HTTP calls (streaming,
// embeddings, whisper) alongside a configured go-openai client.
type ChatGptTool struct {
Secret string // API key; also sent as the Bearer token in raw HTTP requests
Url string // base URL of the API endpoint (mutated in places to append "/v1")
Client *openai.Client // go-openai client (OpenAI or Azure flavor, chosen in NewChatGptTool)
}
// Gpt3Dot5Message is a local alias of openai.ChatCompletionMessage so that
// callers can build chat messages without importing go-openai directly.
type Gpt3Dot5Message openai.ChatCompletionMessage
// NewChatGptTool builds a ChatGptTool for the given endpoint and API key.
// A URL containing "azure" selects the Azure OpenAI client configuration;
// any other non-empty URL is normalized (trailing slashes stripped, "/v1"
// appended when missing) and installed as the client's base URL.
func NewChatGptTool(url, key string) *ChatGptTool {
	if strings.Contains(url, "azure") {
		azureCfg := openai.DefaultAzureConfig(key, url)
		return &ChatGptTool{
			Secret: key,
			Client: openai.NewClientWithConfig(azureCfg),
			Url:    url,
		}
	}
	cfg := openai.DefaultConfig(key)
	if url != "" {
		url = strings.TrimRight(url, "/")
		if !strings.Contains(url, "/v1") {
			url += "/v1"
		}
		cfg.BaseURL = url
	}
	return &ChatGptTool{
		Secret: key,
		Client: openai.NewClientWithConfig(cfg),
		Url:    url,
	}
}
// ChatGPT3Dot5Turbo sends messages to the chat-completion API and returns
// the content of the first choice. An empty bigModelName defaults to
// "gpt-3.5-turbo-16k". MaxTokens 1800 and Temperature 0 are the original
// hard-coded request settings.
func (this *ChatGptTool) ChatGPT3Dot5Turbo(messages []Gpt3Dot5Message, bigModelName string) (string, error) {
	if bigModelName == "" {
		bigModelName = "gpt-3.5-turbo-16k"
	}
	reqMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
	for _, row := range messages {
		reqMessages = append(reqMessages, openai.ChatCompletionMessage{
			Role:    row.Role,
			Content: row.Content,
			Name:    row.Name,
		})
	}
	resp, err := this.Client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:       bigModelName,
			MaxTokens:   1800,
			Messages:    reqMessages,
			Temperature: 0,
		},
	)
	if err != nil {
		log.Println("ChatGPT3Dot5Turbo error: ", err)
		return "", err
	}
	// Guard against an empty Choices slice: indexing [0] unconditionally
	// panics on a malformed or content-filtered API response.
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("ChatGPT3Dot5Turbo: response contained no choices")
	}
	return resp.Choices[0].Message.Content, nil
}
// ChatGPT3Dot5TurboStream opens a streaming chat-completion request and
// returns the stream for the caller to consume and Close. An empty
// bigModelName defaults to "gpt-3.5-turbo-16k".
func (this *ChatGptTool) ChatGPT3Dot5TurboStream(messages []Gpt3Dot5Message, bigModelName string) (*openai.ChatCompletionStream, error) {
	if bigModelName == "" {
		bigModelName = "gpt-3.5-turbo-16k"
	}
	chatMsgs := make([]openai.ChatCompletionMessage, 0)
	for _, m := range messages {
		chatMsgs = append(chatMsgs, openai.ChatCompletionMessage{
			Role:    m.Role,
			Content: m.Content,
			Name:    m.Name,
		})
	}
	request := openai.ChatCompletionRequest{
		Model:       bigModelName,
		MaxTokens:   1800,
		Messages:    chatMsgs,
		Stream:      true,
		Temperature: 0,
	}
	stream, err := this.Client.CreateChatCompletionStream(context.Background(), request)
	if err != nil {
		fmt.Printf("ChatCompletionStream error: %v\n", err)
		return stream, err
	}
	return stream, nil
}
// MyChatGptStream adapts a raw SSE ("data: ...") HTTP response into a
// Recv/Close stream shaped like the go-openai chat-completion stream.
type MyChatGptStream struct {
scanner *bufio.Scanner // line scanner over the response body (one SSE event per line)
response *http.Response // underlying HTTP response; its body is released by Close
}
// Close releases the underlying HTTP response body. Callers must invoke it
// when done receiving so the connection is not leaked.
func (s *MyChatGptStream) Close() {
s.response.Body.Close()
}
// Recv reads the next SSE line from the response and decodes it into a
// ChatCompletionStreamResponse. It returns io.EOF on the "[DONE]" sentinel,
// when the scanner is exhausted or fails, or when a payload cannot be
// decoded. Blank keep-alive lines yield a zero-value response with a nil
// error, so callers must tolerate responses with empty Choices.
func (s *MyChatGptStream) Recv() (openai.ChatCompletionStreamResponse, error) {
	var respJson openai.ChatCompletionStreamResponse
	// Scan returns false once the body is exhausted or a read error occurs.
	// The original ignored this and would then return empty responses with
	// a nil error forever, leaving callers in an infinite loop.
	if !s.scanner.Scan() {
		return respJson, io.EOF
	}
	line := s.scanner.Text()
	if strings.Contains(line, "[DONE]") {
		return respJson, io.EOF
	}
	line = strings.TrimSpace(strings.Replace(line, "data: ", "", 1))
	if line == "" {
		// SSE event separator / keep-alive line: nothing to decode yet.
		return respJson, nil
	}
	if err := json.Unmarshal([]byte(line), &respJson); err != nil {
		// Undecodable payloads are treated as end-of-stream (preserved from
		// the original; the concrete decode error is intentionally dropped).
		return respJson, io.EOF
	}
	return respJson, nil
}
// ChatGPTHttpStream issues a raw streaming chat-completion request over
// HTTP and returns a MyChatGptStream for reading the SSE events. An empty
// bigModelName defaults to "gpt-3.5-turbo-16k". When the secret contains
// "fastgpt-" (FastGPT-style key), visitorId is forwarded as "chatId" for
// session continuity.
func (this *ChatGptTool) ChatGPTHttpStream(messages []Gpt3Dot5Message, bigModelName, visitorId string) (*MyChatGptStream, error) {
	if bigModelName == "" {
		bigModelName = "gpt-3.5-turbo-16k"
	}
	url := this.Url + "/chat/completions"
	reqMsgs := make([]map[string]interface{}, 0, len(messages))
	for _, row := range messages {
		reqMsgs = append(reqMsgs, map[string]interface{}{
			"role":    row.Role,
			"content": row.Content,
		})
	}
	data := map[string]interface{}{
		"stream":   true,
		"model":    bigModelName,
		"messages": reqMsgs,
	}
	if strings.Contains(this.Secret, "fastgpt-") {
		data["chatId"] = visitorId
	}
	jsonData, err := json.Marshal(data)
	if err != nil {
		log.Println("Error marshaling JSON:", err)
		return nil, err
	}
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		log.Println("Error creating request:", err)
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+this.Secret)
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		log.Println("Error making request:", err)
		return nil, err
	}
	// The original returned (nil, nil) on a non-200 status — a nil stream
	// with no error (caller nil-dereference) and a leaked response body.
	// Surface an explicit error and close the body instead.
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close()
		err = fmt.Errorf("chat completions request returned status %d", resp.StatusCode)
		log.Println("Error making request:", err)
		return nil, err
	}
	return &MyChatGptStream{scanner: bufio.NewScanner(resp.Body), response: resp}, nil
}
// EmbeddingRequest is the JSON body POSTed to the /v1/embeddings endpoint.
type EmbeddingRequest struct {
Input string `json:"input"`
Model string `json:"model"`
}
// EmbeddingResponse mirrors the OpenAI /v1/embeddings response payload.
// Declared for callers decoding GetEmbedding's raw JSON result.
type EmbeddingResponse struct {
Data []struct {
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object"`
} `json:"data"`
Model string `json:"model"`
Object string `json:"object"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// GetEmbedding returns the raw JSON embedding response for input. On an
// "azure" endpoint it goes through the go-openai client; note the model
// parameter is ignored there and AdaEmbeddingV2 is always used (preserved
// from the original). Otherwise it POSTs an EmbeddingRequest directly to
// <Url>/v1/embeddings. The response body is returned verbatim, even for
// non-2xx statuses, so callers must inspect the JSON for API errors.
func (this *ChatGptTool) GetEmbedding(input string, model string) (string, error) {
	if strings.Contains(this.Url, "azure") {
		resp, err := this.Client.CreateEmbeddings(
			context.Background(),
			openai.EmbeddingRequest{
				Input: []string{input},
				Model: openai.AdaEmbeddingV2,
			})
		if err != nil {
			log.Println("CreateEmbeddings error:", err)
			return "", err
		}
		// The original discarded this marshal error with `_`.
		respStr, err := json.Marshal(resp)
		if err != nil {
			log.Println("CreateEmbeddings marshal error:", err)
			return "", err
		}
		return string(respStr), nil
	}
	// Build the request body.
	requestBody := EmbeddingRequest{
		Input: input,
		Model: model,
	}
	requestBodyBytes, err := json.Marshal(requestBody)
	if err != nil {
		return "", err
	}
	// Normalize the endpoint in a local variable instead of writing the
	// "/v1" suffix back onto this.Url as the original did — mutating the
	// shared struct is unsafe if the tool is used concurrently.
	baseUrl := this.Url
	if !strings.Contains(baseUrl, "/v1") {
		baseUrl = baseUrl + "/v1"
	}
	url := baseUrl + "/embeddings"
	req, err := http.NewRequest("POST", url, bytes.NewReader(requestBodyBytes))
	if err != nil {
		log.Println("embeddings error:", err)
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+this.Secret)
	// Send the request and read the full response body.
	client := http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		log.Println("embeddings error:", err)
		return "", err
	}
	defer resp.Body.Close()
	responseBodyBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Println("embeddings error:", err)
		return string(responseBodyBytes), err
	}
	return string(responseBodyBytes), nil
}
// GetWhisper transcribes the audio file at filePath via the
// /v1/audio/transcriptions endpoint (model "whisper-1") and returns the
// raw JSON response body. The body is returned verbatim even for non-2xx
// statuses, so callers must inspect it for API errors.
func (this *ChatGptTool) GetWhisper(filePath string) (string, error) {
	// Normalize the endpoint in a local variable instead of writing the
	// "/v1" suffix back onto this.Url as the original did — mutating the
	// shared struct is unsafe if the tool is used concurrently.
	baseUrl := this.Url
	if !strings.Contains(baseUrl, "/v1") {
		baseUrl = baseUrl + "/v1"
	}
	// Build the multipart form body: the audio file plus the model name.
	var requestBody bytes.Buffer
	writer := multipart.NewWriter(&requestBody)
	filePart, err := writer.CreateFormFile("file", filePath)
	if err != nil {
		return "", err
	}
	fileToUpload, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer fileToUpload.Close()
	if _, err = io.Copy(filePart, fileToUpload); err != nil {
		return "", err
	}
	// The original silently ignored both of these errors; a failed Close
	// means the multipart terminator was never written and the upload body
	// is truncated.
	if err = writer.WriteField("model", "whisper-1"); err != nil {
		return "", err
	}
	if err = writer.Close(); err != nil {
		return "", err
	}
	// Create and send the HTTP request.
	req, err := http.NewRequest("POST", baseUrl+"/audio/transcriptions", &requestBody)
	if err != nil {
		return "", err
	}
	req.Header.Set("Authorization", "Bearer "+this.Secret)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	// Read and return the raw response body.
	responseBodyBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Println("whisper error:", err)
		return string(responseBodyBytes), err
	}
	return string(responseBodyBytes), nil
}