Files
origin/sysmodule/mysqlmodule/mysqlmodule.go
2024-04-09 16:40:13 +08:00

469 lines
11 KiB
Go

package mysqlmodule
import (
"database/sql"
"errors"
"fmt"
"github.com/duanhf2012/origin/v2/log"
"net/url"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/duanhf2012/origin/v2/service"
_ "github.com/go-sql-driver/mysql"
)
type SyncFun func()
type DBExecute struct {
syncExecuteFun chan SyncFun
syncExecuteExit chan bool
}
type PingExecute struct {
tickerPing *time.Ticker
pintExit chan bool
}
// DBModule ...
type MySQLModule struct {
service.Module
db *sql.DB
url string
username string
password string
dbname string
slowDuration time.Duration
pingCoroutine PingExecute
waitGroup sync.WaitGroup
}
// Tx ...
type Tx struct {
tx *sql.Tx
slowDuration time.Duration
}
// DBResult ...
type DBResult struct {
LastInsertID int64
RowsAffected int64
rowNum int
RowInfo map[string][]interface{} //map[fieldname][row]sql.NullString
}
type DataSetList struct {
dataSetList []DBResult
currentDataSetIdx int32
tag string
blur bool
}
type dbControl interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
}
func (m *MySQLModule) Init( url string, userName string, password string, dbname string,maxConn int) error {
m.url = url
m.username = userName
m.password = password
m.dbname = dbname
m.pingCoroutine = PingExecute{tickerPing : time.NewTicker(5*time.Second), pintExit : make(chan bool, 1)}
return m.connect(maxConn)
}
func (m *MySQLModule) SetQuerySlowTime(slowDuration time.Duration) {
m.slowDuration = slowDuration
}
func (m *MySQLModule) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
return query(m.slowDuration, m.db,strQuery,args...)
}
// Exec ...
func (m *MySQLModule) Exec(strSql string, args ...interface{}) (*DBResult, error) {
return exec(m.slowDuration, m.db,strSql,args...)
}
// Begin starts a transaction.
func (m *MySQLModule) Begin() (*Tx, error) {
var txDBModule Tx
txDb, err := m.db.Begin()
if err != nil {
log.Error("Begin error:%s", err.Error())
return &txDBModule, err
}
txDBModule.slowDuration = m.slowDuration
txDBModule.tx = txDb
return &txDBModule, nil
}
// Rollback aborts the transaction.
func (slf *Tx) Rollback() error {
return slf.tx.Rollback()
}
// Commit commits the transaction.
func (slf *Tx) Commit() error {
return slf.tx.Commit()
}
// QueryEx executes a query that return rows.
func (slf *Tx) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
return query(slf.slowDuration,slf.tx,strQuery,args...)
}
// Exec executes a query that doesn't return rows.
func (slf *Tx) Exec(strSql string, args ...interface{}) (*DBResult, error) {
return exec(slf.slowDuration,slf.tx,strSql,args...)
}
// Connect ...
func (m *MySQLModule) connect(maxConn int) error {
cmd := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true&loc=%s&readTimeout=30s&timeout=15s&writeTimeout=30s",
m.username,
m.password,
m.url,
m.dbname,
url.QueryEscape(time.Local.String()))
db, err := sql.Open("mysql", cmd)
if err != nil {
return err
}
err = db.Ping()
if err != nil {
db.Close()
return err
}
m.db = db
db.SetMaxOpenConns(maxConn)
db.SetMaxIdleConns(maxConn)
db.SetConnMaxLifetime(time.Second * 90)
go m.runPing()
return nil
}
func (m *MySQLModule) runPing() {
for {
select {
case <-m.pingCoroutine.pintExit:
log.Error("RunPing stopping %s...", fmt.Sprintf("%T", m))
return
case <-m.pingCoroutine.tickerPing.C:
if m.db != nil {
m.db.Ping()
}
}
}
}
func checkArgs(args ...interface{}) error {
for _, val := range args {
if reflect.TypeOf(val).Kind() == reflect.String {
retVal := val.(string)
if strings.Contains(retVal, "-") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "#") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "&") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "=") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "%") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "'") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "delete ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "truncate ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), " or ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "from ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "set ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
}
}
return nil
}
func checkSlow(slowDuration time.Duration,Time time.Duration) bool {
if slowDuration != 0 && Time >=slowDuration {
return true
}
return false
}
func query(slowDuration time.Duration,db dbControl,strQuery string, args ...interface{}) (*DataSetList, error) {
datasetList := DataSetList{}
datasetList.tag = "json"
datasetList.blur = true
if checkArgs(args) != nil {
log.Error("CheckArgs is error :%s", strQuery)
return &datasetList, fmt.Errorf("CheckArgs is error!")
}
if db == nil {
log.Error("cannot connect database:%s", strQuery)
return &datasetList, fmt.Errorf("cannot connect database!")
}
TimeFuncStart := time.Now()
rows, err := db.Query(strQuery, args...)
timeFuncPass := time.Since(TimeFuncStart)
if checkSlow(slowDuration,timeFuncPass) {
log.Error("DBModule QueryEx Time %s , Query :%s , args :%+v", timeFuncPass, strQuery, args)
}
if err != nil {
log.Error("Query:%s(%v)", strQuery, err)
if rows != nil {
rows.Close()
}
return &datasetList, err
}
defer rows.Close()
for {
dbResult := DBResult{}
//取出当前结果集所有行
for rows.Next() {
if dbResult.RowInfo == nil {
dbResult.RowInfo = make(map[string][]interface{})
}
//RowInfo map[string][][]sql.NullString //map[fieldname][row][column]sql.NullString
colField, err := rows.Columns()
if err != nil {
return &datasetList, err
}
count := len(colField)
valuePtrs := make([]interface{}, count)
for i := 0; i < count; i++ {
valuePtrs[i] = &sql.NullString{}
}
rows.Scan(valuePtrs...)
for idx, fieldname := range colField {
fieldRowData := dbResult.RowInfo[strings.ToLower(fieldname)]
fieldRowData = append(fieldRowData, valuePtrs[idx])
dbResult.RowInfo[strings.ToLower(fieldname)] = fieldRowData
}
dbResult.rowNum += 1
}
datasetList.dataSetList = append(datasetList.dataSetList, dbResult)
//取下一个结果集
hasRet := rows.NextResultSet()
if hasRet == false {
if rows.Err() != nil {
log.Error( "Query:%s(%+v)", strQuery, rows)
}
break
}
}
return &datasetList, nil
}
func exec(slowDuration time.Duration,db dbControl,strSql string, args ...interface{}) (*DBResult, error) {
ret := &DBResult{}
if db == nil {
log.Error("cannot connect database:%s", strSql)
return ret, fmt.Errorf("cannot connect database!")
}
if checkArgs(args) != nil {
log.Error("CheckArgs is error :%s", strSql)
return ret, fmt.Errorf("CheckArgs is error!")
}
TimeFuncStart := time.Now()
res, err := db.Exec(strSql, args...)
timeFuncPass := time.Since(TimeFuncStart)
if checkSlow(slowDuration,timeFuncPass) {
log.Error("DBModule QueryEx Time %s , Query :%s , args :%+v", timeFuncPass, strSql, args)
}
if err != nil {
log.Error("Exec:%s(%v)", strSql, err)
return nil, err
}
ret.LastInsertID, _ = res.LastInsertId()
ret.RowsAffected, _ = res.RowsAffected()
return ret, nil
}
func (ds *DataSetList) UnMarshal(args ...interface{}) error {
if len(ds.dataSetList) != len(args) {
return errors.New(fmt.Sprintf("Data set len(%d,%d) is not equal to args!", len(ds.dataSetList), len(args)))
}
for _, out := range args {
v := reflect.ValueOf(out)
if v.Kind() != reflect.Ptr {
return errors.New("interface must be a pointer")
}
if v.Kind() != reflect.Ptr {
return errors.New("interface must be a pointer")
}
if v.Elem().Kind() == reflect.Struct {
err := ds.rowData2interface(0, ds.dataSetList[ds.currentDataSetIdx].RowInfo, v)
if err != nil {
return err
}
}
if v.Elem().Kind() == reflect.Slice {
err := ds.slice2interface(out)
if err != nil {
return err
}
}
ds.currentDataSetIdx = ds.currentDataSetIdx + 1
}
return nil
}
func (ds *DataSetList) slice2interface(in interface{}) error {
length := ds.dataSetList[ds.currentDataSetIdx].rowNum
if length == 0 {
return nil
}
v := reflect.ValueOf(in).Elem()
newv := reflect.MakeSlice(v.Type(), 0, length)
v.Set(newv)
v.SetLen(length)
for i := 0; i < length; i++ {
idxv := v.Index(i)
if idxv.Kind() == reflect.Ptr {
newObj := reflect.New(idxv.Type().Elem())
v.Index(i).Set(newObj)
idxv = newObj
} else {
idxv = idxv.Addr()
}
err := ds.rowData2interface(i, ds.dataSetList[ds.currentDataSetIdx].RowInfo, idxv)
if err != nil {
return err
}
}
return nil
}
func (ds *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{}, v reflect.Value) error {
t := v.Type()
val := v.Elem()
typ := t.Elem()
if !val.IsValid() {
return errors.New("Incorrect data type!")
}
for i := 0; i < val.NumField(); i++ {
value := val.Field(i)
kind := value.Kind()
tag := typ.Field(i).Tag.Get(ds.tag)
if tag == "" {
tag = typ.Field(i).Name
}
if tag != "" && tag != "-" {
vtag := strings.ToLower(tag)
columnData, ok := m[vtag]
if ok == false {
if !ds.blur {
return fmt.Errorf("Cannot find filed name %s!", vtag)
}
continue
}
if len(columnData) <= rowIdx {
return fmt.Errorf("datasource column is error %s", tag)
}
meta := columnData[rowIdx].(*sql.NullString)
if !ok {
if !ds.blur {
return fmt.Errorf("No corresponding field was found in the result set %s!", tag)
}
continue
}
if !value.CanSet() {
return errors.New("Struct fields do not have read or write permissions!")
}
if meta.Valid == false {
continue
}
if len(meta.String) == 0 {
continue
}
switch kind {
case reflect.String:
value.SetString(meta.String)
case reflect.Float32, reflect.Float64:
f, err := strconv.ParseFloat(meta.String, 64)
if err != nil {
return err
}
value.SetFloat(f)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
integer64, err := strconv.ParseInt(meta.String, 10, 64)
if err != nil {
return err
}
value.SetInt(integer64)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
integer64, err := strconv.ParseUint(meta.String, 10, 64)
if err != nil {
return err
}
value.SetUint(integer64)
case reflect.Bool:
b, err := strconv.ParseBool(meta.String)
if err != nil {
return err
}
value.SetBool(b)
default:
return errors.New("The database map has unrecognized data types!")
}
}
}
return nil
}