add Tx struct and some method for tx

This commit is contained in:
zhengguanghao
2019-09-17 00:44:04 +08:00
parent 1aa5a780f3
commit ed78b8985d
2 changed files with 268 additions and 23 deletions

View File

@@ -34,6 +34,12 @@ type DBModule struct {
syncCoroutineNum int
}
// Tx ...
type Tx struct {
tx *sql.Tx
PrintTime time.Duration
}
// DBResult ...
type DBResult struct {
Err error
@@ -476,7 +482,7 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e
hasRet := rows.NextResultSet()
if hasRet == false {
if rows.Err()!= nil {
if rows.Err() != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows)
}
break
@@ -728,3 +734,203 @@ func (slf *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{}
}
return nil
}
func (slf *DBModule) GetTx() (*Tx, error) {
var txDBMoudule Tx
txdb, err := slf.db.Begin()
if err != nil {
return &txDBMoudule, err
}
txDBMoudule.tx = txdb
return &txDBMoudule, nil
}
func (slf *Tx) Rollback() error {
return slf.tx.Rollback()
}
func (slf *Tx) Commit() error {
return slf.tx.Commit()
}
func (slf *Tx) CheckArgs(args ...interface{}) error {
for _, val := range args {
if reflect.TypeOf(val).Kind() == reflect.String {
retVal := val.(string)
if strings.Contains(retVal, "-") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "#") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "&") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "=") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "%") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "'") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "delete ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "truncate ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), " or ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "from ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "set ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
}
}
return nil
}
func (slf *Tx) Query(query string, args ...interface{}) DBResult {
if slf.CheckArgs(args) != nil {
ret := DBResult{}
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
ret.Err = fmt.Errorf("CheckArgs is error!")
return ret
}
if slf.tx == nil {
ret := DBResult{}
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
ret.Err = fmt.Errorf("cannot connect database!")
return ret
}
rows, err := slf.tx.Query(query, args...)
if err != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err)
}
return DBResult{
Err: err,
res: rows,
tag: "json",
blur: true,
}
}
func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool {
if slf.PrintTime != 0 && Time >= slf.PrintTime {
return true
}
return false
}
func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) {
datasetList := DataSetList{}
datasetList.tag = "json"
datasetList.blur = true
if slf.CheckArgs(args) != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
return &datasetList, fmt.Errorf("CheckArgs is error!")
}
if slf.tx == nil {
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
return &datasetList, fmt.Errorf("cannot connect database!")
}
TimeFuncStart := time.Now()
rows, err := slf.tx.Query(query, args...)
TimeFuncPass := time.Since(TimeFuncStart)
if slf.IsPrintTimeLog(TimeFuncPass) {
service.GetLogger().Printf(service.LEVER_INFO, "DBModule Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args)
}
if err != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err)
if rows != nil {
rows.Close()
}
return &datasetList, err
}
defer rows.Close()
for {
dbResult := DBResultEx{}
//取出当前结果集所有行
for rows.Next() {
if dbResult.RowInfo == nil {
dbResult.RowInfo = make(map[string][]interface{})
}
//RowInfo map[string][][]sql.NullString //map[fieldname][row][column]sql.NullString
colField, err := rows.Columns()
if err != nil {
return &datasetList, err
}
count := len(colField)
valuePtrs := make([]interface{}, count)
for i := 0; i < count; i++ {
valuePtrs[i] = &sql.NullString{}
}
rows.Scan(valuePtrs...)
for idx, fieldname := range colField {
fieldRowData := dbResult.RowInfo[strings.ToLower(fieldname)]
fieldRowData = append(fieldRowData, valuePtrs[idx])
dbResult.RowInfo[strings.ToLower(fieldname)] = fieldRowData
}
dbResult.rowNum += 1
}
datasetList.dataSetList = append(datasetList.dataSetList, dbResult)
//取下一个结果集
hasRet := rows.NextResultSet()
if hasRet == false {
if rows.Err() != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows)
}
break
}
}
return &datasetList, nil
}
// Exec ...
func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) {
ret := &DBResultEx{}
if slf.tx == nil {
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
return ret, fmt.Errorf("cannot connect database!")
}
if slf.CheckArgs(args) != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
//return ret, fmt.Errorf("cannot connect database!")
return ret, fmt.Errorf("CheckArgs is error!")
}
TimeFuncStart := time.Now()
res, err := slf.tx.Exec(query, args...)
TimeFuncPass := time.Since(TimeFuncStart)
if slf.IsPrintTimeLog(TimeFuncPass) {
service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args)
}
if err != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err)
return nil, err
}
ret.LastInsertID, _ = res.LastInsertId()
ret.RowsAffected, _ = res.RowsAffected()
return ret, nil
}

View File

@@ -1,6 +1,7 @@
package sysmodule_test
import (
"fmt"
"sync"
"testing"
@@ -13,38 +14,76 @@ func TestDBModule(t *testing.T) {
db.ExitChan = make(chan bool)
db.WaitGroup = new(sync.WaitGroup)
db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB")
// db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB")
db.Init(100, "127.0.0.1:3306", "root", "zgh50221", "rebort_message")
db.OnInit()
res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1")
tx, err := db.GetTx()
if err != nil {
t.Error(err)
fmt.Println("err 1", err)
return
}
res, err := tx.QueryEx("select id as Id, info_type as InfoType, info_type_Name as InfoTypeName from tbl_info_type where id >= 1")
if err != nil {
fmt.Println("err 2", err)
tx.Rollback()
return
}
out := []struct {
Addtime int64 `json:"addtime"`
Tname string `json:"tname"`
Uuid string `json:"uuid,omitempty"`
AAAA string `json:"xxx"`
Id int64
InfoType string
InfoTypeName string
}{}
err = res.UnMarshal(&out)
if err != nil {
t.Error(err)
fmt.Println("err 3", err)
tx.Rollback()
return
}
sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1")
res, err = sres.Get(2000)
fmt.Println(out)
_, err = tx.Exec("insert into tbl_info_type(info_type, info_type_name) VALUES (?, ?)", "4", "weibo")
if err != nil {
t.Error(err)
fmt.Println("err 4", err)
tx.Rollback()
return
}
out2 := []struct {
Addtime int64 `json:"addtime"`
Tname string `json:"tname"`
Uuid string `json:"uuid,omitempty"`
AAAA string `json:"xxx"`
}{}
err = res.UnMarshal(&out2)
_, err = tx.Exec("update tbl_info_type set info_types = ? Where id = ?", "5", 0)
if err != nil {
t.Error(err)
fmt.Println("err 4", err)
tx.Rollback()
return
}
tx.Commit()
// res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1")
// if err != nil {
// t.Error(err)
// }
// out := []struct {
// Addtime int64 `json:"addtime"`
// Tname string `json:"tname"`
// Uuid string `json:"uuid,omitempty"`
// AAAA string `json:"xxx"`
// }{}
// err = res.UnMarshal(&out)
// if err != nil {
// t.Error(err)
// }
// sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1")
// res, err = sres.Get(2000)
// if err != nil {
// t.Error(err)
// }
// out2 := []struct {
// Addtime int64 `json:"addtime"`
// Tname string `json:"tname"`
// Uuid string `json:"uuid,omitempty"`
// AAAA string `json:"xxx"`
// }{}
// err = res.UnMarshal(&out2)
// if err != nil {
// t.Error(err)
// }
}