package utils

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strings"
	"time"

	"github.com/sashabaranov/go-openai"
)

// ChatGptTool bundles a go-openai client with the credentials and base URL it
// was built from. Secret and Url are kept alongside the client because
// GetEmbedding issues a raw HTTP request for non-Azure endpoints.
type ChatGptTool struct {
	Secret string
	Url    string
	Client *openai.Client
}

// Gpt3Dot5Message is a chat message with the same fields as
// openai.ChatCompletionMessage. Only Role, Content and Name are forwarded to
// the API by the methods in this file.
type Gpt3Dot5Message openai.ChatCompletionMessage

// NewChatGptTool builds a ChatGptTool for the given base url and API key.
//
// If url contains "azure" the client is configured with
// openai.DefaultAzureConfig. Otherwise openai.DefaultConfig is used and, when
// url is non-empty, the base URL is overridden with url + "/v1". An empty url
// leaves the library's default base URL in place.
func NewChatGptTool(url, key string) *ChatGptTool {
	var config openai.ClientConfig
	if strings.Contains(url, "azure") {
		config = openai.DefaultAzureConfig(key, url)
	} else {
		config = openai.DefaultConfig(key)
		if url != "" {
			config.BaseURL = url + "/v1"
		}
	}

	return &ChatGptTool{
		Secret: key,
		Url:    url,
		Client: openai.NewClientWithConfig(config),
	}
}

// toChatCompletionMessages converts messages into the request type expected by
// go-openai, copying only the Role, Content and Name fields.
func toChatCompletionMessages(messages []Gpt3Dot5Message) []openai.ChatCompletionMessage {
	reqMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
	for _, row := range messages {
		reqMessages = append(reqMessages, openai.ChatCompletionMessage{
			Role:    row.Role,
			Content: row.Content,
			Name:    row.Name,
		})
	}
	return reqMessages
}

// ChatGPT3Dot5Turbo sends messages to the gpt-3.5-turbo chat completion
// endpoint and returns the content of the first choice.
//
// It returns an error if the request fails or if the response contains no
// choices.
func (t *ChatGptTool) ChatGPT3Dot5Turbo(messages []Gpt3Dot5Message) (string, error) {
	resp, err := t.Client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model:    openai.GPT3Dot5Turbo,
			Messages: toChatCompletionMessages(messages),
			// NOTE(review): go-openai appears to tag Temperature with
			// omitempty, so a literal 0 may be dropped from the request and the
			// server default used instead - confirm against the library version
			// in go.mod.
			Temperature: 0,
		},
	)
	if err != nil {
		log.Println("ChatGPT3Dot5Turbo error: ", err)
		return "", err
	}
	// Guard the index below: an empty Choices slice would otherwise panic.
	if len(resp.Choices) == 0 {
		err = errors.New("chat completion returned no choices")
		log.Println("ChatGPT3Dot5Turbo error: ", err)
		return "", err
	}
	return resp.Choices[0].Message.Content, nil
}

// ChatGPT3Dot5TurboStream opens a streaming chat completion against the
// gpt-3.5-turbo-16k model (MaxTokens 1800) and returns the stream for the
// caller to consume. On error the returned stream is nil.
//
// Typical consumption (the caller should also close the stream when done; see
// the go-openai ChatCompletionStream documentation):
//
//	for {
//		response, err := stream.Recv()
//		if errors.Is(err, io.EOF) {
//			break // stream finished
//		} else if err != nil {
//			break // stream error
//		}
//		_ = response.Choices[0].Delta.Content
//	}
func (t *ChatGptTool) ChatGPT3Dot5TurboStream(messages []Gpt3Dot5Message) (*openai.ChatCompletionStream, error) {
	req := openai.ChatCompletionRequest{
		Model:     openai.GPT3Dot5Turbo16K,
		MaxTokens: 1800,
		Messages:  toChatCompletionMessages(messages),
		Stream:    true,
		// NOTE(review): see the Temperature note on ChatGPT3Dot5Turbo's request;
		// a zero value may be omitted from the serialized request - confirm.
		Temperature: 0,
	}

	stream, err := t.Client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		log.Println("ChatCompletionStream error: ", err)
		return nil, err
	}
	return stream, nil
}

// EmbeddingRequest is the JSON body sent to the /v1/embeddings endpoint by
// GetEmbedding for non-Azure endpoints.
type EmbeddingRequest struct {
	Input string `json:"input"`
	Model string `json:"model"`
}

// EmbeddingResponse mirrors the JSON document returned by the /v1/embeddings
// endpoint. GetEmbedding returns that document as a string; callers can
// unmarshal it into this type.
type EmbeddingResponse struct {
	Data []struct {
		Embedding []float64 `json:"embedding"`
		Index     int       `json:"index"`
		Object    string    `json:"object"`
	} `json:"data"`
	Model  string `json:"model"`
	Object string `json:"object"`
	Usage  struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`
}

// GetEmbedding requests an embedding for input and returns the response as a
// JSON string.
//
// When the configured Url contains "azure" the request goes through the
// go-openai client and the client's response struct is re-marshaled to JSON.
// NOTE(review): that branch always uses openai.AdaEmbeddingV2 and ignores the
// model argument - confirm this is intended.
//
// Otherwise a POST is sent to <Url>/v1/embeddings with the given model and the
// raw response body is returned. An empty Url falls back to
// https://api.openai.com, matching the default that NewChatGptTool leaves in
// place for the client. A non-2xx status is reported as an error that includes
// the status line and response body.
func (t *ChatGptTool) GetEmbedding(input string, model string) (string, error) {
	if strings.Contains(t.Url, "azure") {
		resp, err := t.Client.CreateEmbeddings(
			context.Background(),
			openai.EmbeddingRequest{
				Input: []string{input},
				Model: openai.AdaEmbeddingV2,
			})
		if err != nil {
			log.Println("CreateEmbeddings error:", err)
			return "", err
		}
		respBytes, err := json.Marshal(resp)
		if err != nil {
			log.Println("CreateEmbeddings error:", err)
			return "", err
		}
		return string(respBytes), nil
	}

	const (
		// defaultBaseURL is used when the tool was created with an empty url.
		defaultBaseURL = "https://api.openai.com"
		// requestTimeout bounds the whole HTTP exchange so a stalled server
		// cannot block the caller forever.
		requestTimeout = 60 * time.Second
	)

	// Build the request body.
	requestBodyBytes, err := json.Marshal(EmbeddingRequest{
		Input: input,
		Model: model,
	})
	if err != nil {
		return "", err
	}

	// Build the HTTP request.
	baseURL := t.Url
	if baseURL == "" {
		baseURL = defaultBaseURL
	}
	req, err := http.NewRequest(http.MethodPost, baseURL+"/v1/embeddings", bytes.NewReader(requestBodyBytes))
	if err != nil {
		log.Println("embeddings error:", err)
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+t.Secret)

	// Send the request. A zero Transport means the shared default transport is
	// used, so connections are still pooled across calls.
	client := http.Client{Timeout: requestTimeout}
	resp, err := client.Do(req)
	if err != nil {
		log.Println("embeddings error:", err)
		return "", err
	}
	defer resp.Body.Close()

	// Read the response body.
	responseBodyBytes, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Println("embeddings error:", err)
		return "", err
	}

	// Surface API failures as errors instead of handing the error payload back
	// as if it were an embedding.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		err = fmt.Errorf("embeddings request: unexpected status %s: %s", resp.Status, responseBodyBytes)
		log.Println("embeddings error:", err)
		return "", err
	}
	return string(responseBodyBytes), nil
}