mirror of
https://github.com/duanhf2012/origin.git
synced 2026-02-03 22:45:13 +08:00
466 lines
11 KiB
Go
466 lines
11 KiB
Go
package mysqlmodule
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/duanhf2012/origin/v2/log"
|
|
"net/url"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/duanhf2012/origin/v2/service"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
type (
	// SyncFun is a parameterless callback.
	SyncFun func()

	// DBExecute pairs a channel of SyncFun callbacks with a channel used to
	// signal exit. NOTE(review): not referenced elsewhere in this file.
	DBExecute struct {
		syncExecuteFun  chan SyncFun
		syncExecuteExit chan bool
	}

	// PingExecute is the state of the keep-alive ping loop run by
	// MySQLModule.runPing: the ticker that paces it and the channel that
	// stops it.
	PingExecute struct {
		tickerPing *time.Ticker
		// NOTE(review): "pint" looks like a typo for "ping"; renaming it
		// requires updating Init and runPing at the same time.
		pintExit chan bool
	}
)
|
|
|
|
type MySQLModule struct {
|
|
service.Module
|
|
db *sql.DB
|
|
url string
|
|
username string
|
|
password string
|
|
dbname string
|
|
slowDuration time.Duration
|
|
pingCoroutine PingExecute
|
|
waitGroup sync.WaitGroup
|
|
}
|
|
|
|
// Tx wraps a database transaction together with the slow-query threshold
// inherited from the MySQLModule that began it.
type Tx struct {
	tx           *sql.Tx       // underlying database/sql transaction
	slowDuration time.Duration // copied from MySQLModule.slowDuration by Begin
}
|
|
|
|
// DBResult describes the outcome of one statement or one result set.
//
// exec fills LastInsertID and RowsAffected. query fills rowNum and RowInfo,
// where RowInfo maps each lower-cased column name to that column's values in
// row order, each stored as a *sql.NullString.
type DBResult struct {
	LastInsertID, RowsAffected int64

	rowNum  int                      // number of rows collected into RowInfo
	RowInfo map[string][]interface{} // column name -> per-row *sql.NullString
}
|
|
|
|
type DataSetList struct {
|
|
dataSetList []DBResult
|
|
currentDataSetIdx int32
|
|
tag string
|
|
blur bool
|
|
}
|
|
|
|
// dbControl is the statement-running subset shared by *sql.DB and *sql.Tx,
// letting query and exec serve both pooled and transactional calls.
type dbControl interface {
	Exec(sqlText string, args ...interface{}) (sql.Result, error)
	Query(sqlText string, args ...interface{}) (*sql.Rows, error)
}

// Compile-time checks that both statement targets satisfy dbControl.
var (
	_ dbControl = (*sql.DB)(nil)
	_ dbControl = (*sql.Tx)(nil)
)
|
|
|
|
func (m *MySQLModule) Init(url string, userName string, password string, dbname string, maxConn int) error {
|
|
m.url = url
|
|
m.username = userName
|
|
m.password = password
|
|
m.dbname = dbname
|
|
m.pingCoroutine = PingExecute{tickerPing: time.NewTicker(5 * time.Second), pintExit: make(chan bool, 1)}
|
|
|
|
return m.connect(maxConn)
|
|
}
|
|
|
|
func (m *MySQLModule) SetQuerySlowTime(slowDuration time.Duration) {
|
|
m.slowDuration = slowDuration
|
|
}
|
|
|
|
func (m *MySQLModule) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
|
|
return query(m.slowDuration, m.db, strQuery, args...)
|
|
}
|
|
|
|
// Exec ...
|
|
func (m *MySQLModule) Exec(strSql string, args ...interface{}) (*DBResult, error) {
|
|
return exec(m.slowDuration, m.db, strSql, args...)
|
|
}
|
|
|
|
// Begin starts a transaction.
|
|
func (m *MySQLModule) Begin() (*Tx, error) {
|
|
var txDBModule Tx
|
|
txDb, err := m.db.Begin()
|
|
if err != nil {
|
|
log.Error("Begin error", log.ErrorField("err",err))
|
|
return &txDBModule, err
|
|
}
|
|
txDBModule.slowDuration = m.slowDuration
|
|
txDBModule.tx = txDb
|
|
return &txDBModule, nil
|
|
}
|
|
|
|
// Rollback aborts the transaction.
|
|
func (slf *Tx) Rollback() error {
|
|
return slf.tx.Rollback()
|
|
}
|
|
|
|
// Commit commits the transaction.
|
|
func (slf *Tx) Commit() error {
|
|
return slf.tx.Commit()
|
|
}
|
|
|
|
// Query executes a query that return rows.
|
|
func (slf *Tx) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
|
|
return query(slf.slowDuration, slf.tx, strQuery, args...)
|
|
}
|
|
|
|
// Exec executes a query that doesn't return rows.
|
|
func (slf *Tx) Exec(strSql string, args ...interface{}) (*DBResult, error) {
|
|
return exec(slf.slowDuration, slf.tx, strSql, args...)
|
|
}
|
|
|
|
// Connect ...
|
|
func (m *MySQLModule) connect(maxConn int) error {
|
|
cmd := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true&loc=%s&readTimeout=30s&timeout=15s&writeTimeout=30s",
|
|
m.username,
|
|
m.password,
|
|
m.url,
|
|
m.dbname,
|
|
url.QueryEscape(time.Local.String()))
|
|
|
|
db, err := sql.Open("mysql", cmd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = db.Ping()
|
|
if err != nil {
|
|
db.Close()
|
|
return err
|
|
}
|
|
m.db = db
|
|
db.SetMaxOpenConns(maxConn)
|
|
db.SetMaxIdleConns(maxConn)
|
|
db.SetConnMaxLifetime(time.Second * 90)
|
|
|
|
go m.runPing()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MySQLModule) runPing() {
|
|
for {
|
|
select {
|
|
case <-m.pingCoroutine.pintExit:
|
|
log.Error("RunPing stopping",log.String("url", m.url),log.String("dbname", m.dbname))
|
|
return
|
|
case <-m.pingCoroutine.tickerPing.C:
|
|
if m.db != nil {
|
|
m.db.Ping()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkArgs rejects string arguments containing characters or keywords that
// commonly appear in SQL injection payloads.
//
// Every argument whose kind is string (including named string types) is
// inspected; nil and non-string arguments are accepted. The error names the
// first offending value.
//
// NOTE(review): query and exec call checkArgs(args) without spreading, so
// this currently receives a single []interface{} value and always accepts
// it. The statements are parameterized, and spreading would reject many
// legitimate values (dates, UUIDs, apostrophes).
func checkArgs(args ...interface{}) error {
	// Matched case-sensitively against the raw value.
	forbiddenChars := []string{"-", "#", "&", "=", "%", "'"}
	// Matched against the lower-cased value.
	forbiddenWords := []string{"delete ", "truncate ", " or ", "from ", "set "}

	for _, arg := range args {
		rv := reflect.ValueOf(arg)
		// reflect.ValueOf(nil) has Kind Invalid, so nil arguments (SQL NULL)
		// are skipped instead of panicking as reflect.TypeOf(nil).Kind() did.
		if rv.Kind() != reflect.String {
			continue
		}
		// rv.String() also works for named string types, where a plain
		// arg.(string) assertion would panic.
		str := rv.String()
		for _, c := range forbiddenChars {
			if strings.Contains(str, c) {
				return fmt.Errorf("CheckArgs is error arg is %+v", str)
			}
		}
		lower := strings.ToLower(str)
		for _, w := range forbiddenWords {
			if strings.Contains(lower, w) {
				return fmt.Errorf("CheckArgs is error arg is %+v", str)
			}
		}
	}
	return nil
}
|
|
|
|
// checkSlow reports whether elapsed reaches the slow-statement threshold
// slowDuration. A zero or negative threshold disables the check; previously
// a negative threshold flagged every statement as slow.
func checkSlow(slowDuration time.Duration, elapsed time.Duration) bool {
	return slowDuration > 0 && elapsed >= slowDuration
}
|
|
|
|
func query(slowDuration time.Duration, db dbControl, strQuery string, args ...interface{}) (*DataSetList, error) {
|
|
datasetList := DataSetList{}
|
|
datasetList.tag = "json"
|
|
datasetList.blur = true
|
|
|
|
if checkArgs(args) != nil {
|
|
log.Error("CheckArgs is error",log.String("sql",strQuery))
|
|
return &datasetList, fmt.Errorf("checkArgs is error")
|
|
}
|
|
|
|
if db == nil {
|
|
log.Error("cannot connect database",log.String("sql", strQuery))
|
|
return &datasetList, fmt.Errorf("cannot connect database")
|
|
}
|
|
|
|
TimeFuncStart := time.Now()
|
|
rows, err := db.Query(strQuery, args...)
|
|
timeFuncPass := time.Since(TimeFuncStart)
|
|
|
|
if checkSlow(slowDuration, timeFuncPass) {
|
|
log.Error("Query slow",log.Int64("time_ms",timeFuncPass.Milliseconds()),log.String("sql", strQuery), log.Any("args",args))
|
|
}
|
|
if err != nil {
|
|
log.Error("Query error", log.String("sql",strQuery),log.ErrorField("err",err))
|
|
if rows != nil {
|
|
rows.Close()
|
|
}
|
|
return &datasetList, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for {
|
|
dbResult := DBResult{}
|
|
//取出当前结果集所有行
|
|
for rows.Next() {
|
|
if dbResult.RowInfo == nil {
|
|
dbResult.RowInfo = make(map[string][]interface{})
|
|
}
|
|
|
|
colField, err := rows.Columns()
|
|
if err != nil {
|
|
return &datasetList, err
|
|
}
|
|
count := len(colField)
|
|
valuePtrs := make([]interface{}, count)
|
|
for i := 0; i < count; i++ {
|
|
valuePtrs[i] = &sql.NullString{}
|
|
}
|
|
rows.Scan(valuePtrs...)
|
|
|
|
for idx, fieldName := range colField {
|
|
fieldRowData := dbResult.RowInfo[strings.ToLower(fieldName)]
|
|
fieldRowData = append(fieldRowData, valuePtrs[idx])
|
|
dbResult.RowInfo[strings.ToLower(fieldName)] = fieldRowData
|
|
}
|
|
dbResult.rowNum += 1
|
|
}
|
|
|
|
datasetList.dataSetList = append(datasetList.dataSetList, dbResult)
|
|
//取下一个结果集
|
|
hasRet := rows.NextResultSet()
|
|
|
|
if hasRet == false {
|
|
if rowErr :=rows.Err();rowErr != nil {
|
|
log.Error("NextResultSet error", log.String("sql",strQuery), log.ErrorField("err",rowErr))
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
return &datasetList, nil
|
|
}
|
|
|
|
func exec(slowDuration time.Duration, db dbControl, strSql string, args ...interface{}) (*DBResult, error) {
|
|
ret := &DBResult{}
|
|
if db == nil {
|
|
log.Error("cannot connect database", log.String("sql",strSql))
|
|
return ret, fmt.Errorf("cannot connect database")
|
|
}
|
|
|
|
if checkArgs(args) != nil {
|
|
log.Error("CheckArgs is error", log.String("sql",strSql))
|
|
return ret, fmt.Errorf("checkArgs is error")
|
|
}
|
|
|
|
TimeFuncStart := time.Now()
|
|
res, err := db.Exec(strSql, args...)
|
|
timeFuncPass := time.Since(TimeFuncStart)
|
|
if checkSlow(slowDuration, timeFuncPass) {
|
|
log.Error("Exec slow",log.Int64("time_ms",timeFuncPass.Milliseconds()),log.String("sql",strSql),log.Any("args",args) )
|
|
}
|
|
if err != nil {
|
|
log.Error("Exec error",log.String("sql",strSql),log.ErrorField("err", err))
|
|
return nil, err
|
|
}
|
|
|
|
ret.LastInsertID, _ = res.LastInsertId()
|
|
ret.RowsAffected, _ = res.RowsAffected()
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (ds *DataSetList) UnMarshal(args ...interface{}) error {
|
|
if len(ds.dataSetList) != len(args) {
|
|
return errors.New(fmt.Sprintf("Data set len(%d,%d) is not equal to args!", len(ds.dataSetList), len(args)))
|
|
}
|
|
|
|
for _, out := range args {
|
|
v := reflect.ValueOf(out)
|
|
if v.Kind() != reflect.Ptr {
|
|
return errors.New("interface must be a pointer")
|
|
}
|
|
|
|
if v.Kind() != reflect.Ptr {
|
|
return errors.New("interface must be a pointer")
|
|
}
|
|
|
|
if v.Elem().Kind() == reflect.Struct {
|
|
err := ds.rowData2interface(0, ds.dataSetList[ds.currentDataSetIdx].RowInfo, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if v.Elem().Kind() == reflect.Slice {
|
|
err := ds.slice2interface(out)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
ds.currentDataSetIdx = ds.currentDataSetIdx + 1
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ds *DataSetList) slice2interface(in interface{}) error {
|
|
length := ds.dataSetList[ds.currentDataSetIdx].rowNum
|
|
if length == 0 {
|
|
return nil
|
|
}
|
|
|
|
v := reflect.ValueOf(in).Elem()
|
|
newV := reflect.MakeSlice(v.Type(), 0, length)
|
|
v.Set(newV)
|
|
v.SetLen(length)
|
|
|
|
for i := 0; i < length; i++ {
|
|
idxV := v.Index(i)
|
|
if idxV.Kind() == reflect.Ptr {
|
|
newObj := reflect.New(idxV.Type().Elem())
|
|
v.Index(i).Set(newObj)
|
|
idxV = newObj
|
|
} else {
|
|
idxV = idxV.Addr()
|
|
}
|
|
|
|
err := ds.rowData2interface(i, ds.dataSetList[ds.currentDataSetIdx].RowInfo, idxV)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ds *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{}, v reflect.Value) error {
|
|
t := v.Type()
|
|
val := v.Elem()
|
|
typ := t.Elem()
|
|
|
|
if !val.IsValid() {
|
|
return errors.New("incorrect data type")
|
|
}
|
|
|
|
for i := 0; i < val.NumField(); i++ {
|
|
value := val.Field(i)
|
|
kind := value.Kind()
|
|
tag := typ.Field(i).Tag.Get(ds.tag)
|
|
if tag == "" {
|
|
tag = typ.Field(i).Name
|
|
}
|
|
|
|
if tag != "" && tag != "-" {
|
|
vTag := strings.ToLower(tag)
|
|
columnData, ok := m[vTag]
|
|
if ok == false {
|
|
if !ds.blur {
|
|
return fmt.Errorf("cannot find filed name %s", vTag)
|
|
}
|
|
continue
|
|
}
|
|
if len(columnData) <= rowIdx {
|
|
return fmt.Errorf("datasource column is error %s", tag)
|
|
}
|
|
meta := columnData[rowIdx].(*sql.NullString)
|
|
if !ok {
|
|
if !ds.blur {
|
|
return fmt.Errorf("no corresponding field was found in the result set %s", tag)
|
|
}
|
|
continue
|
|
}
|
|
if !value.CanSet() {
|
|
return errors.New("struct fields do not have read or write permissions")
|
|
}
|
|
|
|
if meta.Valid == false {
|
|
continue
|
|
}
|
|
|
|
if len(meta.String) == 0 {
|
|
continue
|
|
}
|
|
|
|
switch kind {
|
|
case reflect.String:
|
|
value.SetString(meta.String)
|
|
case reflect.Float32, reflect.Float64:
|
|
f, err := strconv.ParseFloat(meta.String, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetFloat(f)
|
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
|
integer64, err := strconv.ParseInt(meta.String, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetInt(integer64)
|
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
|
integer64, err := strconv.ParseUint(meta.String, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetUint(integer64)
|
|
case reflect.Bool:
|
|
b, err := strconv.ParseBool(meta.String)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetBool(b)
|
|
default:
|
|
return errors.New("the database map has unrecognized data types")
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|