kefu/tools/rsa.go

118 lines
3.1 KiB
Go

package tools
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"strings"
)
type Rsa struct {
privateKey string
publicKey string
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
}
func NewRsa(publicKey, privateKey string) *Rsa {
rsaObj := &Rsa{
privateKey: privateKey,
publicKey: publicKey,
}
rsaObj.init() //初始化,如果存在公钥私钥,将其解析
return rsaObj
}
//初始化
func (r *Rsa) init() {
if r.privateKey != "" {
//将私钥解码
block, _ := pem.Decode([]byte(r.privateKey))
//pkcs1 //判断是否包含 BEGIN RSA 字符串,这个是由下面生成的时候定义的
if strings.Index(r.privateKey, "BEGIN RSA") > 0 {
//解析私钥
r.rsaPrivateKey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
} else { //pkcs8
//解析私钥
privateKey, _ := x509.ParsePKCS8PrivateKey(block.Bytes)
//转换格式 类型断言
r.rsaPrivateKey = privateKey.(*rsa.PrivateKey)
}
}
if r.publicKey != "" {
//将公钥解码 解析 转换格式
block, _ := pem.Decode([]byte(r.publicKey))
publicKey, _ := x509.ParsePKIXPublicKey(block.Bytes)
r.rsaPublicKey = publicKey.(*rsa.PublicKey)
}
}
//Encrypt 加密
func (r *Rsa) Encrypt(data []byte) ([]byte, error) {
// blockLength = 密钥长度 = 一次能加密的明文长度
// "/8" 将bit转为bytes
// "-11" 为 PKCS#1 建议的 padding 占用了 11 个字节
blockLength := r.rsaPublicKey.N.BitLen()/8 - 11
//如果明文长度不大于密钥长度,可以直接加密
if len(data) <= blockLength {
//对明文进行加密
return rsa.EncryptPKCS1v15(rand.Reader, r.rsaPublicKey, []byte(data))
}
//否则分段加密
//创建一个新的缓冲区
buffer := bytes.NewBufferString("")
pages := len(data) / blockLength //切分为多少块
//循环加密
for i := 0; i <= pages; i++ {
start := i * blockLength
end := (i + 1) * blockLength
if i == pages { //最后一页的判断
if start == len(data) {
continue
}
end = len(data)
}
//分段加密
chunk, err := rsa.EncryptPKCS1v15(rand.Reader, r.rsaPublicKey, data[start:end])
if err != nil {
return nil, err
}
//写入缓冲区
buffer.Write(chunk)
}
//读取缓冲区内容并返回,即返回加密结果
return buffer.Bytes(), nil
}
//Decrypt 解密
func (r *Rsa) Decrypt(data []byte) ([]byte, error) {
//加密后的密文长度=密钥长度。如果密文长度大于密钥长度,说明密文非一次加密形成
//1、获取密钥长度
blockLength := r.rsaPrivateKey.N.BitLen() / 8
if len(data) <= blockLength { //一次形成的密文直接解密
return rsa.DecryptPKCS1v15(rand.Reader, r.rsaPrivateKey, data)
}
buffer := bytes.NewBufferString("")
pages := len(data) / blockLength
for i := 0; i <= pages; i++ { //循环解密
start := i * blockLength
end := (i + 1) * blockLength
if i == pages {
if start == len(data) {
continue
}
end = len(data)
}
chunk, err := rsa.DecryptPKCS1v15(rand.Reader, r.rsaPrivateKey, data[start:end])
if err != nil {
return nil, err
}
buffer.Write(chunk)
}
return buffer.Bytes(), nil
}