kefu/knowledge/utils/import.go

120 lines
2.7 KiB
Go

package utils
import (
"database/sql"
"errors"
"fmt"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
"io/ioutil"
"log"
"os"
"strings"
"time"
)
type DBTool struct {
SqlPath string
Username, Password, Server, Port, Database string
}
func (this *DBTool) connect() (*gorm.DB, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", this.Username, this.Password, this.Server, this.Port, this.Database)
db, err := gorm.Open("mysql", dsn)
if err != nil {
log.Println("数据库连接失败:", err)
//panic("数据库连接失败!")
return nil, err
}
db.SingularTable(true)
db.LogMode(true)
db.DB().SetMaxIdleConns(10)
db.DB().SetMaxOpenConns(100)
db.DB().SetConnMaxLifetime(59 * time.Second)
return db, nil
}
//执行查询语句
func (this *DBTool) QuerySql(sqlStr string) ([]map[string]interface{}, error) {
db, err := this.connect()
var result []map[string]interface{}
if err != nil {
return result, err
}
sqlStr = strings.TrimSpace(sqlStr)
if sqlStr == "" {
return result, errors.New("sql语句为空")
}
rows, err := db.DB().Query(sqlStr)
if err != nil {
return result, err
}
//获取列名
columns, _ := rows.Columns()
//定义一个切片,长度是字段的个数,切片里面的元素类型是sql.RawBytes
values := make([]sql.RawBytes, len(columns))
//定义一个切片,元素类型是interface{} 接口
scanArgs := make([]interface{}, len(values))
for i := range values {
//把sql.RawBytes类型的地址存进去了
scanArgs[i] = &values[i]
}
//获取字段值
for rows.Next() {
res := make(map[string]interface{})
rows.Scan(scanArgs...)
for i, col := range values {
res[columns[i]] = string(col)
}
result = append(result, res)
}
return result, nil
}
//执行增删改
func (this *DBTool) ExecuteSql(sql string) error {
db, err := this.connect()
sql = strings.TrimSpace(sql)
if sql == "" {
return errors.New("sql语句为空")
}
err = db.Exec(sql).Error
if err != nil {
log.Println("执行失败:" + err.Error())
return err
} else {
log.Println(sql, "\t success!")
}
return nil
}
//执行sql文件
func (this *DBTool) ImportSql() error {
_, err := os.Stat(this.SqlPath)
if os.IsNotExist(err) {
log.Println("数据库SQL文件不存在:", err)
return err
}
db, err := this.connect()
sqls, _ := ioutil.ReadFile(this.SqlPath)
sqlArr := strings.Split(string(sqls), ";")
for _, sql := range sqlArr {
sql = strings.TrimSpace(sql)
if sql == "" {
continue
}
err := db.Exec(sql).Error
if err != nil {
log.Println("数据库导入失败:" + err.Error())
return err
} else {
log.Println(sql, "\t success!")
}
}
return nil
}