mirror of
https://github.com/duanhf2012/origin.git
synced 2026-02-03 22:45:13 +08:00
规范重命名mysqlmodule模块名称
This commit is contained in:
468
sysmodule/mysqlmodule/mysqlmodule.go
Normal file
468
sysmodule/mysqlmodule/mysqlmodule.go
Normal file
@@ -0,0 +1,468 @@
|
||||
package mysqlmondule
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/duanhf2012/origin/log"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/duanhf2012/origin/service"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
type SyncFun func()
|
||||
|
||||
type DBExecute struct {
|
||||
syncExecuteFun chan SyncFun
|
||||
syncExecuteExit chan bool
|
||||
}
|
||||
|
||||
type PingExecute struct {
|
||||
tickerPing *time.Ticker
|
||||
pintExit chan bool
|
||||
}
|
||||
|
||||
// DBModule ...
|
||||
type MySQLModule struct {
|
||||
service.Module
|
||||
db *sql.DB
|
||||
url string
|
||||
username string
|
||||
password string
|
||||
dbname string
|
||||
slowDuration time.Duration
|
||||
pingCoroutine PingExecute
|
||||
waitGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
// Tx ...
|
||||
type Tx struct {
|
||||
tx *sql.Tx
|
||||
slowDuration time.Duration
|
||||
}
|
||||
|
||||
// DBResult ...
|
||||
type DBResult struct {
|
||||
LastInsertID int64
|
||||
RowsAffected int64
|
||||
|
||||
rowNum int
|
||||
RowInfo map[string][]interface{} //map[fieldname][row]sql.NullString
|
||||
}
|
||||
|
||||
type DataSetList struct {
|
||||
dataSetList []DBResult
|
||||
currentDataSetIdx int32
|
||||
tag string
|
||||
blur bool
|
||||
}
|
||||
|
||||
|
||||
type dbControl interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
|
||||
func (m *MySQLModule) Init( url string, userName string, password string, dbname string,maxConn int) error {
|
||||
m.url = url
|
||||
m.username = userName
|
||||
m.password = password
|
||||
m.dbname = dbname
|
||||
m.pingCoroutine = PingExecute{tickerPing : time.NewTicker(5*time.Second), pintExit : make(chan bool, 1)}
|
||||
|
||||
return m.connect(maxConn)
|
||||
}
|
||||
|
||||
func (m *MySQLModule) SetQuerySlowTime(slowDuration time.Duration) {
|
||||
m.slowDuration = slowDuration
|
||||
}
|
||||
|
||||
func (m *MySQLModule) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
|
||||
return query(m.slowDuration, m.db,strQuery,args...)
|
||||
}
|
||||
|
||||
// Exec ...
|
||||
func (m *MySQLModule) Exec(strSql string, args ...interface{}) (*DBResult, error) {
|
||||
return exec(m.slowDuration, m.db,strSql,args...)
|
||||
}
|
||||
|
||||
// Begin starts a transaction.
|
||||
func (m *MySQLModule) Begin() (*Tx, error) {
|
||||
var txDBModule Tx
|
||||
txDb, err := m.db.Begin()
|
||||
if err != nil {
|
||||
log.Error("Begin error:%s", err.Error())
|
||||
return &txDBModule, err
|
||||
}
|
||||
txDBModule.slowDuration = m.slowDuration
|
||||
txDBModule.tx = txDb
|
||||
return &txDBModule, nil
|
||||
}
|
||||
|
||||
// Rollback aborts the transaction.
|
||||
func (slf *Tx) Rollback() error {
|
||||
return slf.tx.Rollback()
|
||||
}
|
||||
|
||||
// Commit commits the transaction.
|
||||
func (slf *Tx) Commit() error {
|
||||
return slf.tx.Commit()
|
||||
}
|
||||
|
||||
// QueryEx executes a query that return rows.
|
||||
func (slf *Tx) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
|
||||
return query(slf.slowDuration,slf.tx,strQuery,args...)
|
||||
}
|
||||
|
||||
// Exec executes a query that doesn't return rows.
|
||||
func (slf *Tx) Exec(strSql string, args ...interface{}) (*DBResult, error) {
|
||||
return exec(slf.slowDuration,slf.tx,strSql,args...)
|
||||
}
|
||||
|
||||
// Connect ...
|
||||
func (m *MySQLModule) connect(maxConn int) error {
|
||||
cmd := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true&loc=%s&readTimeout=30s&timeout=15s&writeTimeout=30s",
|
||||
m.username,
|
||||
m.password,
|
||||
m.url,
|
||||
m.dbname,
|
||||
url.QueryEscape(time.Local.String()))
|
||||
|
||||
db, err := sql.Open("mysql", cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return err
|
||||
}
|
||||
m.db = db
|
||||
db.SetMaxOpenConns(maxConn)
|
||||
db.SetMaxIdleConns(maxConn)
|
||||
db.SetConnMaxLifetime(time.Second * 90)
|
||||
|
||||
go m.runPing()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQLModule) runPing() {
|
||||
for {
|
||||
select {
|
||||
case <-m.pingCoroutine.pintExit:
|
||||
log.Error("RunPing stopping %s...", fmt.Sprintf("%T", m))
|
||||
return
|
||||
case <-m.pingCoroutine.tickerPing.C:
|
||||
if m.db != nil {
|
||||
m.db.Ping()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkArgs(args ...interface{}) error {
|
||||
for _, val := range args {
|
||||
if reflect.TypeOf(val).Kind() == reflect.String {
|
||||
retVal := val.(string)
|
||||
if strings.Contains(retVal, "-") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "#") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "&") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "=") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "%") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "'") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "delete ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "truncate ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), " or ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "from ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "set ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkSlow(slowDuration time.Duration,Time time.Duration) bool {
|
||||
if slowDuration != 0 && Time >=slowDuration {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func query(slowDuration time.Duration,db dbControl,strQuery string, args ...interface{}) (*DataSetList, error) {
|
||||
datasetList := DataSetList{}
|
||||
datasetList.tag = "json"
|
||||
datasetList.blur = true
|
||||
|
||||
if checkArgs(args) != nil {
|
||||
log.Error("CheckArgs is error :%s", strQuery)
|
||||
return &datasetList, fmt.Errorf("CheckArgs is error!")
|
||||
}
|
||||
|
||||
if db == nil {
|
||||
log.Error("cannot connect database:%s", strQuery)
|
||||
return &datasetList, fmt.Errorf("cannot connect database!")
|
||||
}
|
||||
|
||||
TimeFuncStart := time.Now()
|
||||
rows, err := db.Query(strQuery, args...)
|
||||
timeFuncPass := time.Since(TimeFuncStart)
|
||||
|
||||
if checkSlow(slowDuration,timeFuncPass) {
|
||||
log.Error("DBModule QueryEx Time %s , Query :%s , args :%+v", timeFuncPass, strQuery, args)
|
||||
}
|
||||
if err != nil {
|
||||
log.Error("Query:%s(%v)", strQuery, err)
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
return &datasetList, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for {
|
||||
dbResult := DBResult{}
|
||||
//取出当前结果集所有行
|
||||
for rows.Next() {
|
||||
if dbResult.RowInfo == nil {
|
||||
dbResult.RowInfo = make(map[string][]interface{})
|
||||
}
|
||||
//RowInfo map[string][][]sql.NullString //map[fieldname][row][column]sql.NullString
|
||||
colField, err := rows.Columns()
|
||||
if err != nil {
|
||||
return &datasetList, err
|
||||
}
|
||||
count := len(colField)
|
||||
valuePtrs := make([]interface{}, count)
|
||||
for i := 0; i < count; i++ {
|
||||
valuePtrs[i] = &sql.NullString{}
|
||||
}
|
||||
rows.Scan(valuePtrs...)
|
||||
|
||||
for idx, fieldname := range colField {
|
||||
fieldRowData := dbResult.RowInfo[strings.ToLower(fieldname)]
|
||||
fieldRowData = append(fieldRowData, valuePtrs[idx])
|
||||
dbResult.RowInfo[strings.ToLower(fieldname)] = fieldRowData
|
||||
}
|
||||
dbResult.rowNum += 1
|
||||
}
|
||||
|
||||
datasetList.dataSetList = append(datasetList.dataSetList, dbResult)
|
||||
//取下一个结果集
|
||||
hasRet := rows.NextResultSet()
|
||||
|
||||
if hasRet == false {
|
||||
if rows.Err() != nil {
|
||||
log.Error( "Query:%s(%+v)", strQuery, rows)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &datasetList, nil
|
||||
}
|
||||
|
||||
func exec(slowDuration time.Duration,db dbControl,strSql string, args ...interface{}) (*DBResult, error) {
|
||||
ret := &DBResult{}
|
||||
if db == nil {
|
||||
log.Error("cannot connect database:%s", strSql)
|
||||
return ret, fmt.Errorf("cannot connect database!")
|
||||
}
|
||||
|
||||
if checkArgs(args) != nil {
|
||||
log.Error("CheckArgs is error :%s", strSql)
|
||||
return ret, fmt.Errorf("CheckArgs is error!")
|
||||
}
|
||||
|
||||
TimeFuncStart := time.Now()
|
||||
res, err := db.Exec(strSql, args...)
|
||||
timeFuncPass := time.Since(TimeFuncStart)
|
||||
if checkSlow(slowDuration,timeFuncPass) {
|
||||
log.Error("DBModule QueryEx Time %s , Query :%s , args :%+v", timeFuncPass, strSql, args)
|
||||
}
|
||||
if err != nil {
|
||||
log.Error("Exec:%s(%v)", strSql, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret.LastInsertID, _ = res.LastInsertId()
|
||||
ret.RowsAffected, _ = res.RowsAffected()
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (ds *DataSetList) UnMarshal(args ...interface{}) error {
|
||||
if len(ds.dataSetList) != len(args) {
|
||||
return errors.New(fmt.Sprintf("Data set len(%d,%d) is not equal to args!", len(ds.dataSetList), len(args)))
|
||||
}
|
||||
|
||||
for _, out := range args {
|
||||
v := reflect.ValueOf(out)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return errors.New("interface must be a pointer")
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return errors.New("interface must be a pointer")
|
||||
}
|
||||
|
||||
if v.Elem().Kind() == reflect.Struct {
|
||||
err := ds.rowData2interface(0, ds.dataSetList[ds.currentDataSetIdx].RowInfo, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if v.Elem().Kind() == reflect.Slice {
|
||||
err := ds.slice2interface(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ds.currentDataSetIdx = ds.currentDataSetIdx + 1
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ds *DataSetList) slice2interface(in interface{}) error {
|
||||
length := ds.dataSetList[ds.currentDataSetIdx].rowNum
|
||||
if length == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(in).Elem()
|
||||
newv := reflect.MakeSlice(v.Type(), 0, length)
|
||||
v.Set(newv)
|
||||
v.SetLen(length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
idxv := v.Index(i)
|
||||
if idxv.Kind() == reflect.Ptr {
|
||||
newObj := reflect.New(idxv.Type().Elem())
|
||||
v.Index(i).Set(newObj)
|
||||
idxv = newObj
|
||||
} else {
|
||||
idxv = idxv.Addr()
|
||||
}
|
||||
|
||||
err := ds.rowData2interface(i, ds.dataSetList[ds.currentDataSetIdx].RowInfo, idxv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ds *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{}, v reflect.Value) error {
|
||||
t := v.Type()
|
||||
val := v.Elem()
|
||||
typ := t.Elem()
|
||||
|
||||
if !val.IsValid() {
|
||||
return errors.New("Incorrect data type!")
|
||||
}
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
value := val.Field(i)
|
||||
kind := value.Kind()
|
||||
tag := typ.Field(i).Tag.Get(ds.tag)
|
||||
if tag == "" {
|
||||
tag = typ.Field(i).Name
|
||||
}
|
||||
|
||||
if tag != "" && tag != "-" {
|
||||
vtag := strings.ToLower(tag)
|
||||
columnData, ok := m[vtag]
|
||||
if ok == false {
|
||||
if !ds.blur {
|
||||
return fmt.Errorf("Cannot find filed name %s!", vtag)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if len(columnData) <= rowIdx {
|
||||
return fmt.Errorf("datasource column is error %s", tag)
|
||||
}
|
||||
meta := columnData[rowIdx].(*sql.NullString)
|
||||
if !ok {
|
||||
if !ds.blur {
|
||||
return fmt.Errorf("No corresponding field was found in the result set %s!", tag)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !value.CanSet() {
|
||||
return errors.New("Struct fields do not have read or write permissions!")
|
||||
}
|
||||
|
||||
if meta.Valid == false {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(meta.String) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.String:
|
||||
value.SetString(meta.String)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
f, err := strconv.ParseFloat(meta.String, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetFloat(f)
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
integer64, err := strconv.ParseInt(meta.String, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetInt(integer64)
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
integer64, err := strconv.ParseUint(meta.String, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetUint(integer64)
|
||||
case reflect.Bool:
|
||||
b, err := strconv.ParseBool(meta.String)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetBool(b)
|
||||
default:
|
||||
return errors.New("The database map has unrecognized data types!")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
48
sysmodule/mysqlmodule/mysqlmodule_test.go
Normal file
48
sysmodule/mysqlmodule/mysqlmodule_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package mysqlmondule
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_Example(t *testing.T) {
|
||||
//初始化连接
|
||||
test := MySQLModule{}
|
||||
err := test.Init("127.0.0.1:3306","root","123456","dbname",10)
|
||||
if err !=nil {
|
||||
fmt.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
//设置1秒慢查询监控
|
||||
test.SetQuerySlowTime(1*time.Second)
|
||||
|
||||
//查询语句
|
||||
//存储过程可以call sp_querytest(?,?,?)方式
|
||||
result,err := test.Query("select Id,Name from UserInfo where SvrAreaId=?",18)
|
||||
if err!=nil {
|
||||
fmt.Print(err)
|
||||
}else{
|
||||
dbRet := []struct {
|
||||
Id int `json:"Id"`
|
||||
Name string `json:"Name"`
|
||||
}{}
|
||||
|
||||
//从结构集中返序列化数据到结构体切片中,UnMarshal可以支持多个结果集
|
||||
err = result.UnMarshal(&dbRet)
|
||||
if err !=nil {
|
||||
fmt.Print(err)
|
||||
}
|
||||
}
|
||||
|
||||
//exec sql
|
||||
//存储过程可以call sp_updatetest(?,?,?)方式
|
||||
execResult,execErr := test.Exec("update UserInfo set Name=? where Id=?","nickname",1000)
|
||||
if execErr!=nil {
|
||||
fmt.Print(execErr)
|
||||
}else{
|
||||
fmt.Print(*execResult)
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user