353 lines
8.5 KiB
Go
353 lines
8.5 KiB
Go
|
package lib
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"github.com/sashabaranov/go-openai"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"log"
|
||
|
"mime/multipart"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
type ChatGptTool struct {
|
||
|
Secret string
|
||
|
Url string
|
||
|
Client *openai.Client
|
||
|
}
|
||
|
type Gpt3Dot5Message openai.ChatCompletionMessage
|
||
|
|
||
|
func NewChatGptTool(url, key string) *ChatGptTool {
|
||
|
if strings.Contains(url, "azure") {
|
||
|
config := openai.DefaultAzureConfig(key, url)
|
||
|
client := openai.NewClientWithConfig(config)
|
||
|
return &ChatGptTool{
|
||
|
Secret: key,
|
||
|
Client: client,
|
||
|
Url: url,
|
||
|
}
|
||
|
}
|
||
|
config := openai.DefaultConfig(key)
|
||
|
if url != "" {
|
||
|
url = strings.TrimRight(url, "/")
|
||
|
if !strings.Contains(url, "/v1") {
|
||
|
url = url + "/v1"
|
||
|
}
|
||
|
config.BaseURL = url
|
||
|
}
|
||
|
client := openai.NewClientWithConfig(config)
|
||
|
//client := openai.NewClient(secret)
|
||
|
return &ChatGptTool{
|
||
|
Secret: key,
|
||
|
Client: client,
|
||
|
Url: url,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/*
|
||
|
*
|
||
|
调用gpt3.5接口
|
||
|
*/
|
||
|
func (this *ChatGptTool) ChatGPT3Dot5Turbo(messages []Gpt3Dot5Message, bigModelName string) (string, error) {
|
||
|
if bigModelName == "" {
|
||
|
bigModelName = "gpt-3.5-turbo-16k"
|
||
|
}
|
||
|
reqMessages := make([]openai.ChatCompletionMessage, 0)
|
||
|
for _, row := range messages {
|
||
|
reqMessage := openai.ChatCompletionMessage{
|
||
|
Role: row.Role,
|
||
|
Content: row.Content,
|
||
|
Name: row.Name,
|
||
|
}
|
||
|
reqMessages = append(reqMessages, reqMessage)
|
||
|
}
|
||
|
resp, err := this.Client.CreateChatCompletion(
|
||
|
context.Background(),
|
||
|
openai.ChatCompletionRequest{
|
||
|
Model: bigModelName,
|
||
|
MaxTokens: 1800,
|
||
|
Messages: reqMessages,
|
||
|
Temperature: 0,
|
||
|
},
|
||
|
)
|
||
|
|
||
|
if err != nil {
|
||
|
log.Println("ChatGPT3Dot5Turbo error: ", err)
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
return resp.Choices[0].Message.Content, nil
|
||
|
}
|
||
|
|
||
|
/*
|
||
|
*
|
||
|
调用gpt3.5流式接口
|
||
|
*/
|
||
|
func (this *ChatGptTool) ChatGPT3Dot5TurboStream(messages []Gpt3Dot5Message, bigModelName string) (*openai.ChatCompletionStream, error) {
|
||
|
if bigModelName == "" {
|
||
|
bigModelName = "gpt-3.5-turbo-16k"
|
||
|
}
|
||
|
c := this.Client
|
||
|
ctx := context.Background()
|
||
|
reqMessages := make([]openai.ChatCompletionMessage, 0)
|
||
|
for _, row := range messages {
|
||
|
reqMessage := openai.ChatCompletionMessage{
|
||
|
Role: row.Role,
|
||
|
Content: row.Content,
|
||
|
Name: row.Name,
|
||
|
}
|
||
|
reqMessages = append(reqMessages, reqMessage)
|
||
|
}
|
||
|
req := openai.ChatCompletionRequest{
|
||
|
Model: bigModelName,
|
||
|
MaxTokens: 1800,
|
||
|
Messages: reqMessages,
|
||
|
Stream: true,
|
||
|
Temperature: 0,
|
||
|
}
|
||
|
stream, err := c.CreateChatCompletionStream(ctx, req)
|
||
|
if err != nil {
|
||
|
fmt.Printf("ChatCompletionStream error: %v\n", err)
|
||
|
return stream, err
|
||
|
}
|
||
|
|
||
|
//for {
|
||
|
// response, err := stream.Recv()
|
||
|
// if errors.Is(err, io.EOF) {
|
||
|
// log.Println("\nStream finished")
|
||
|
// break
|
||
|
// } else if err != nil {
|
||
|
// log.Printf("\nStream error: %v\n", err)
|
||
|
// break
|
||
|
// } else {
|
||
|
// log.Println(response.Choices[0].Delta.Content, err)
|
||
|
// }
|
||
|
//}
|
||
|
return stream, nil
|
||
|
}
|
||
|
|
||
|
// MyChatGptStream is a line-oriented reader over a streaming HTTP chat
// completion response. Per the original author it implements the
// ChatGptStream interface (declared elsewhere; not visible here).
type MyChatGptStream struct {
	// scanner yields the response body one line at a time.
	scanner *bufio.Scanner
	// response owns the body that scanner reads; Close releases it.
	response *http.Response
}
|
||
|
|
||
|
// Close 实现了 ChatGptStream 接口中的 Close 方法
|
||
|
func (s *MyChatGptStream) Close() {
|
||
|
s.response.Body.Close()
|
||
|
}
|
||
|
|
||
|
// Recv 实现了 ChatGptStream 接口中的 Recv 方法
|
||
|
func (s *MyChatGptStream) Recv() (openai.ChatCompletionStreamResponse, error) {
|
||
|
var respJson openai.ChatCompletionStreamResponse
|
||
|
s.scanner.Scan()
|
||
|
line := s.scanner.Text()
|
||
|
//log.Println(line)
|
||
|
if strings.Contains(line, "[DONE]") {
|
||
|
return respJson, io.EOF
|
||
|
}
|
||
|
|
||
|
line = strings.TrimSpace(strings.Replace(line, "data: ", "", 1))
|
||
|
if line == "" {
|
||
|
return respJson, nil
|
||
|
}
|
||
|
|
||
|
err := json.Unmarshal([]byte(line), &respJson)
|
||
|
if err != nil {
|
||
|
return respJson, io.EOF
|
||
|
}
|
||
|
//respContent, ok := respJson["choices"].([]interface{})[0].(map[string]interface{})["delta"].(map[string]interface{})["content"].(string)
|
||
|
//if !ok {
|
||
|
// return "", io.EOF
|
||
|
//}
|
||
|
return respJson, nil
|
||
|
}
|
||
|
|
||
|
// http接口请求openai流式响应
|
||
|
func (this *ChatGptTool) ChatGPTHttpStream(messages []Gpt3Dot5Message, bigModelName, visitorId string) (*MyChatGptStream, error) {
|
||
|
if bigModelName == "" {
|
||
|
bigModelName = "gpt-3.5-turbo-16k"
|
||
|
}
|
||
|
url := this.Url + "/chat/completions"
|
||
|
data := map[string]interface{}{
|
||
|
"stream": true,
|
||
|
"model": bigModelName,
|
||
|
}
|
||
|
if strings.Contains(this.Secret, "fastgpt-") {
|
||
|
data["chatId"] = visitorId
|
||
|
}
|
||
|
data["messages"] = make([]map[string]interface{}, 0)
|
||
|
for _, row := range messages {
|
||
|
reqMessage := map[string]interface{}{
|
||
|
"role": row.Role,
|
||
|
"content": row.Content,
|
||
|
}
|
||
|
data["messages"] = append(data["messages"].([]map[string]interface{}), reqMessage)
|
||
|
}
|
||
|
jsonData, err := json.Marshal(data)
|
||
|
if err != nil {
|
||
|
log.Println("Error marshaling JSON:", err)
|
||
|
return nil, err
|
||
|
}
|
||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||
|
if err != nil {
|
||
|
log.Println("Error creating request:", err)
|
||
|
return nil, err
|
||
|
}
|
||
|
req.Header.Set("Content-Type", "application/json")
|
||
|
req.Header.Set("Authorization", "Bearer "+this.Secret)
|
||
|
client := &http.Client{}
|
||
|
resp, err := client.Do(req)
|
||
|
|
||
|
if err != nil || resp.StatusCode != 200 {
|
||
|
log.Println("Error making request:", err)
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
scanner := bufio.NewScanner(resp.Body)
|
||
|
return &MyChatGptStream{scanner: scanner, response: resp}, nil
|
||
|
}
|
||
|
|
||
|
// JSON wire types for an OpenAI-compatible /embeddings endpoint.
type (
	// EmbeddingRequest is the JSON body POSTed by GetEmbedding on the
	// non-Azure path.
	EmbeddingRequest struct {
		Input string `json:"input"` // text to embed
		Model string `json:"model"` // embedding model name
	}

	// EmbeddingResponse matches the JSON shape of an embeddings response.
	// GetEmbedding returns the raw body, so this type is only a convenience
	// for callers that want to decode it.
	EmbeddingResponse struct {
		// Data holds the returned embedding vectors.
		Data []struct {
			Embedding []float64 `json:"embedding"`
			Index     int       `json:"index"`
			Object    string    `json:"object"`
		} `json:"data"`
		Model  string `json:"model"`
		Object string `json:"object"`
		// Usage reports the token counts for the request.
		Usage struct {
			PromptTokens int `json:"prompt_tokens"`
			TotalTokens  int `json:"total_tokens"`
		} `json:"usage"`
	}
)
|
||
|
|
||
|
func (this *ChatGptTool) GetEmbedding(input string, model string) (string, error) {
|
||
|
if strings.Contains(this.Url, "azure") {
|
||
|
resp, err := this.Client.CreateEmbeddings(
|
||
|
context.Background(),
|
||
|
openai.EmbeddingRequest{
|
||
|
Input: []string{input},
|
||
|
Model: openai.AdaEmbeddingV2,
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
log.Println("CreateEmbeddings error:", err)
|
||
|
return "", err
|
||
|
}
|
||
|
respStr, _ := json.Marshal(resp)
|
||
|
return string(respStr), nil
|
||
|
//vectors := resp.Data[0].Embedding // []float32 with 1536 dimensions
|
||
|
}
|
||
|
// 构建请求体
|
||
|
requestBody := EmbeddingRequest{
|
||
|
Input: input,
|
||
|
Model: model,
|
||
|
}
|
||
|
requestBodyBytes, err := json.Marshal(requestBody)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
// 构建 HTTP 请求
|
||
|
|
||
|
if !strings.Contains(this.Url, "/v1") {
|
||
|
this.Url = this.Url + "/v1"
|
||
|
}
|
||
|
url := this.Url + "/embeddings"
|
||
|
req, err := http.NewRequest("POST", url, bytes.NewReader(requestBodyBytes))
|
||
|
if err != nil {
|
||
|
log.Println("embeddings error:", err)
|
||
|
return "", err
|
||
|
}
|
||
|
req.Header.Set("Content-Type", "application/json")
|
||
|
req.Header.Set("Authorization", "Bearer "+this.Secret)
|
||
|
|
||
|
// 发送请求并获取响应
|
||
|
client := http.Client{}
|
||
|
resp, err := client.Do(req)
|
||
|
if err != nil {
|
||
|
log.Println("embeddings error:", err)
|
||
|
return "", err
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
|
||
|
// 解析响应体
|
||
|
responseBodyBytes, err := ioutil.ReadAll(resp.Body)
|
||
|
if err != nil {
|
||
|
log.Println("embeddings error:", err)
|
||
|
return string(responseBodyBytes), err
|
||
|
}
|
||
|
return string(responseBodyBytes), nil
|
||
|
}
|
||
|
|
||
|
// 语音转文本
|
||
|
func (this *ChatGptTool) GetWhisper(filePath string) (string, error) {
|
||
|
|
||
|
// 构建 HTTP 请求
|
||
|
|
||
|
if !strings.Contains(this.Url, "/v1") {
|
||
|
this.Url = this.Url + "/v1"
|
||
|
}
|
||
|
// Replace 'TOKEN' with your actual API token
|
||
|
token := this.Secret
|
||
|
file := filePath // Replace with the actual file path
|
||
|
|
||
|
// Create a buffer to store the request body
|
||
|
var requestBody bytes.Buffer
|
||
|
writer := multipart.NewWriter(&requestBody)
|
||
|
|
||
|
// Add the 'file' parameter
|
||
|
filePart, err := writer.CreateFormFile("file", file)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
fileToUpload, err := os.Open(file)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
defer fileToUpload.Close()
|
||
|
if _, err = io.Copy(filePart, fileToUpload); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
// Add the 'model' parameter
|
||
|
writer.WriteField("model", "whisper-1")
|
||
|
|
||
|
// Close the multipart writer
|
||
|
writer.Close()
|
||
|
|
||
|
// Create the HTTP request
|
||
|
req, err := http.NewRequest("POST", this.Url+"/audio/transcriptions", &requestBody)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||
|
|
||
|
// Perform the request
|
||
|
client := &http.Client{}
|
||
|
resp, err := client.Do(req)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
// 解析响应体
|
||
|
responseBodyBytes, err := ioutil.ReadAll(resp.Body)
|
||
|
if err != nil {
|
||
|
log.Println("whisper error:", err)
|
||
|
return string(responseBodyBytes), err
|
||
|
}
|
||
|
return string(responseBodyBytes), nil
|
||
|
}
|