// File: kefu/knowledge/utils/chatgpt.go
//
// (Viewer metadata: 189 lines, 4.4 KiB, Go)

package utils
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/sashabaranov/go-openai"
)
type (
	// ChatGptTool bundles a configured go-openai client with the API key and
	// base URL it was created from (see NewChatGptTool).
	ChatGptTool struct {
		Secret string         // API key; also sent as the Bearer token by GetEmbedding
		Url    string         // base URL as passed to NewChatGptTool; "azure" in it selects the Azure code paths
		Client *openai.Client // client used for chat completions and Azure embeddings
	}

	// Gpt3Dot5Message is a chat message with the same underlying type as
	// openai.ChatCompletionMessage.
	Gpt3Dot5Message openai.ChatCompletionMessage
)
func NewChatGptTool(url, key string) *ChatGptTool {
if strings.Contains(url, "azure") {
config := openai.DefaultAzureConfig(key, url)
client := openai.NewClientWithConfig(config)
return &ChatGptTool{
Secret: key,
Client: client,
Url: url,
}
}
config := openai.DefaultConfig(key)
if url != "" {
config.BaseURL = url + "/v1"
}
client := openai.NewClientWithConfig(config)
//client := openai.NewClient(secret)
return &ChatGptTool{
Secret: key,
Client: client,
Url: url,
}
}
// ChatGPT3Dot5Turbo calls the GPT-3.5 chat completion API (model
// openai.GPT3Dot5Turbo, temperature 0) with messages. It returns the content
// of the first choice.
//
// It returns an error if the request fails or if the API answers without any
// choices. Errors are logged before being returned.
func (g *ChatGptTool) ChatGPT3Dot5Turbo(messages []Gpt3Dot5Message) (string, error) {
	reqMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
	for _, m := range messages {
		// Copy only Role/Content/Name explicitly rather than converting the
		// whole struct, so only these three fields are sent.
		reqMessages = append(reqMessages, openai.ChatCompletionMessage{
			Role:    m.Role,
			Content: m.Content,
			Name:    m.Name,
		})
	}
	resp, err := g.Client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:       openai.GPT3Dot5Turbo,
			Messages:    reqMessages,
			Temperature: 0,
		},
	)
	if err != nil {
		log.Println("ChatGPT3Dot5Turbo error: ", err)
		return "", err
	}
	// Guard the index below: an empty Choices slice would otherwise panic.
	if len(resp.Choices) == 0 {
		err = fmt.Errorf("chat completion from model %q returned no choices", resp.Model)
		log.Println("ChatGPT3Dot5Turbo error: ", err)
		return "", err
	}
	return resp.Choices[0].Message.Content, nil
}
// ChatGPT3Dot5TurboStream opens a streaming chat completion for messages using
// model openai.GPT3Dot5Turbo16K, temperature 0 and at most 1800 tokens.
//
// The stream is returned unread. A typical consumer loop looks like:
//
//	for {
//		chunk, err := stream.Recv()
//		if errors.Is(err, io.EOF) {
//			break // stream finished
//		}
//		if err != nil {
//			log.Printf("stream error: %v", err)
//			break
//		}
//		log.Println(chunk.Choices[0].Delta.Content)
//	}
//
// NOTE(review): the stream is not closed here. Callers are presumably expected
// to Close it when done; confirm against the call sites.
//
// If opening the stream fails, the error is logged and returned together with
// whatever stream value the client produced.
func (g *ChatGptTool) ChatGPT3Dot5TurboStream(messages []Gpt3Dot5Message) (*openai.ChatCompletionStream, error) {
	converted := make([]openai.ChatCompletionMessage, 0, len(messages))
	for i := range messages {
		converted = append(converted, openai.ChatCompletionMessage{
			Role:    messages[i].Role,
			Content: messages[i].Content,
			Name:    messages[i].Name,
		})
	}

	stream, err := g.Client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:       openai.GPT3Dot5Turbo16K,
		MaxTokens:   1800,
		Messages:    converted,
		Stream:      true,
		Temperature: 0,
	})
	if err != nil {
		log.Println("ChatCompletionStream error: ", err)
	}
	return stream, err
}
// JSON wire types for an OpenAI-compatible embeddings endpoint.
type (
	// EmbeddingRequest is the JSON body that GetEmbedding POSTs on its
	// non-Azure path.
	EmbeddingRequest struct {
		Input string `json:"input"` // text to embed
		Model string `json:"model"` // embedding model name, passed through as given
	}

	// EmbeddingResponse describes the embeddings response JSON. It is not
	// decoded in this file; presumably callers unmarshal GetEmbedding's string
	// result into it (TODO confirm).
	EmbeddingResponse struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
			Index     int       `json:"index"`
			Object    string    `json:"object"`
		} `json:"data"`
		Model  string `json:"model"`
		Object string `json:"object"`
		Usage  struct {
			PromptTokens int `json:"prompt_tokens"`
			TotalTokens  int `json:"total_tokens"`
		} `json:"usage"`
	}
)
// embeddingDefaultBaseURL is used when the tool was built with an empty URL,
// mirroring NewChatGptTool, which keeps the library's default base in that case.
const embeddingDefaultBaseURL = "https://api.openai.com"

// embeddingHTTPClient is shared by all GetEmbedding calls so connections can be
// reused. Its timeout bounds how long a single embeddings request may hang.
var embeddingHTTPClient = &http.Client{Timeout: 60 * time.Second}

// GetEmbedding requests an embedding for input and returns the response as a
// JSON string.
//
// If g.Url contains "azure", the request goes through the go-openai client
// with the fixed model openai.AdaEmbeddingV2 (the model argument is ignored).
// The client's response struct is re-encoded as JSON.
//
// Otherwise it POSTs {"input": input, "model": model} to
// <Url>/v1/embeddings with "Authorization: Bearer <Secret>". The raw response
// body is returned. A non-2xx status is reported as an error, and the body is
// still returned so callers can inspect the API's error payload.
func (g *ChatGptTool) GetEmbedding(input string, model string) (string, error) {
	if strings.Contains(g.Url, "azure") {
		resp, err := g.Client.CreateEmbeddings(
			context.Background(),
			openai.EmbeddingRequest{
				Input: []string{input},
				Model: openai.AdaEmbeddingV2,
			})
		if err != nil {
			log.Println("CreateEmbeddings error:", err)
			return "", err
		}
		respBytes, err := json.Marshal(resp)
		if err != nil {
			log.Println("CreateEmbeddings error:", err)
			return "", err
		}
		return string(respBytes), nil
	}

	requestBody, err := json.Marshal(EmbeddingRequest{Input: input, Model: model})
	if err != nil {
		return "", err
	}

	base := strings.TrimRight(g.Url, "/") // avoid "//v1/embeddings"
	if base == "" {
		base = embeddingDefaultBaseURL
	}
	endpoint := base + "/v1/embeddings"
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestBody))
	if err != nil {
		log.Println("embeddings error:", err)
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+g.Secret)

	resp, err := embeddingHTTPClient.Do(req)
	if err != nil {
		log.Println("embeddings error:", err)
		return "", err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Println("embeddings error:", err)
		return string(body), err
	}
	// Without this check an auth/quota/model error would be returned as a
	// successful "embedding" response.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		err = fmt.Errorf("embeddings request to %s: unexpected HTTP status %s", endpoint, resp.Status)
		log.Println("embeddings error:", err)
		return string(body), err
	}
	return string(body), nil
}