105 lines
2.4 KiB
Go
105 lines
2.4 KiB
Go
package kefuWework
|
||
|
||
import (
	"context"
	"fmt"
	"os"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/patrickmn/go-cache"
	gogpt "github.com/sashabaranov/go-openai"
)
|
||
|
||
// apiKey is the package-level default OpenAI API key. It is read once, at
// package initialization, from the OPENAI_API_KEY environment variable and is
// empty when that variable is unset.
//
// Secrets must not be committed to source control: the key that was
// previously embedded in this declaration should be treated as compromised
// and revoked.
var apiKey = os.Getenv("OPENAI_API_KEY")
|
||
|
||
// 这是一个可以自定义的 id,用默认值不会有问题
|
||
var userId = "orgId"
|
||
|
||
// 企业微信 token 缓存,请求频次过高可能有一些额外的问题
|
||
var conversationCache = cache.New(5*time.Minute, 5*time.Minute)
|
||
|
||
type ChatGPT struct {
|
||
client *gogpt.Client
|
||
ctx context.Context
|
||
userId string
|
||
}
|
||
|
||
func Chat(c *gin.Context) {
|
||
question := c.Query("question")
|
||
conversationId := c.Query("conversationId")
|
||
apiKey := c.Query("api_key")
|
||
ret, err := AskOnConversation(apiKey, question, conversationId, weworkConversationSize)
|
||
if err != nil {
|
||
c.JSON(500, err.Error())
|
||
return
|
||
}
|
||
c.JSON(200, ret)
|
||
}
|
||
|
||
func AskOnConversation(apiKey, question, conversationId string, size int) (string, error) {
|
||
var messages = []gogpt.ChatCompletionMessage{}
|
||
key := fmt.Sprintf("cache:conversation:%s", conversationId)
|
||
data, found := conversationCache.Get(key)
|
||
if found {
|
||
messages = data.([]gogpt.ChatCompletionMessage)
|
||
}
|
||
messages = append(messages, gogpt.ChatCompletionMessage{
|
||
Role: "system",
|
||
Content: question,
|
||
})
|
||
fmt.Println(messages)
|
||
pivot := size
|
||
if pivot > len(messages) {
|
||
pivot = len(messages)
|
||
}
|
||
messages = messages[len(messages)-pivot:]
|
||
conversationCache.Set(key, messages, 12*time.Hour)
|
||
k, userId := apiKey, userId
|
||
chat := NewGPT(k, userId)
|
||
defer chat.Close()
|
||
answer, err := chat.Chat(messages)
|
||
if err != nil {
|
||
fmt.Print(err.Error())
|
||
}
|
||
return answer, err
|
||
}
|
||
|
||
func (c *ChatGPT) Chat(messages []gogpt.ChatCompletionMessage) (answer string, err error) {
|
||
var msg = gogpt.ChatCompletionMessage{}
|
||
msg.Role = "system"
|
||
req := gogpt.ChatCompletionRequest{
|
||
Model: gogpt.GPT3Dot5Turbo,
|
||
Messages: messages,
|
||
}
|
||
resp, err := c.client.CreateChatCompletion(c.ctx, req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
answer = resp.Choices[0].Message.Content
|
||
for len(answer) > 0 {
|
||
if answer[0] == '\n' {
|
||
answer = answer[1:]
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
return answer, err
|
||
}
|
||
|
||
func NewGPT(ApiKey, UserId string) *ChatGPT {
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
go func() {
|
||
<-ctx.Done()
|
||
cancel()
|
||
}()
|
||
return &ChatGPT{
|
||
client: gogpt.NewClient(ApiKey),
|
||
ctx: ctx,
|
||
userId: UserId,
|
||
}
|
||
}
|
||
func (c *ChatGPT) Close() {
|
||
c.ctx.Done()
|
||
}
|