52 lines
1.2 KiB
Go
52 lines
1.2 KiB
Go
package ai
|
|
|
|
import (
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strings"

	"kefu/lib"
)
|
|
|
|
// 训练功能
|
|
func Train(openaiUrl string, openaiKey string, pointId interface{}, collectName, content, fileId, title, url string) (string, error) {
|
|
|
|
gpt := lib.NewChatGptTool(openaiUrl, openaiKey)
|
|
str := strings.TrimSpace(content)
|
|
response, err := gpt.GetEmbedding(str, "text-embedding-ada-002")
|
|
if err != nil {
|
|
log.Println("向量接口失败:", err)
|
|
return "", err
|
|
}
|
|
var embeddingResponse lib.EmbeddingResponse
|
|
json.Unmarshal([]byte(response), &embeddingResponse)
|
|
if len(embeddingResponse.Data) == 0 {
|
|
log.Println("向量接口失败:", response)
|
|
return "", errors.New(response)
|
|
}
|
|
|
|
points := []map[string]interface{}{
|
|
{
|
|
"id": pointId,
|
|
"payload": map[string]interface{}{"text": str, "title": title, "url": url, "fileid": fileId},
|
|
"vector": embeddingResponse.Data[0].Embedding,
|
|
},
|
|
}
|
|
res, err := lib.PutPoints(collectName, points)
|
|
return res, err
|
|
}
|
|
|
|
// SplitTextByLength splits text into consecutive chunks of at most length
// runes each.
//
// Splitting is done on runes rather than bytes, so multi-byte characters are
// never cut in half. Every chunk except possibly the last contains exactly
// length runes. It returns nil when text is empty or when length is not
// positive, since no valid chunking exists for a non-positive chunk size.
func SplitTextByLength(text string, length int) []string {
	runes := []rune(text)
	// Without this guard, length == 0 never advances the loop and a negative
	// length produces an out-of-range slice expression.
	if length <= 0 || len(runes) == 0 {
		return nil
	}

	blocks := make([]string, 0, len(runes)/length+1)
	for start := 0; start < len(runes); {
		end := len(runes)
		if length < end-start {
			end = start + length
		}
		blocks = append(blocks, string(runes[start:end]))
		start = end
	}
	return blocks
}
|