From ed78b8985daa4b3b39301e47d543a2ab6069367f Mon Sep 17 00:00:00 2001 From: zhengguanghao Date: Tue, 17 Sep 2019 00:44:04 +0800 Subject: [PATCH 1/2] add Tx struct and some method for tx --- sysmodule/DBModule.go | 208 ++++++++++++++++++++++++++++++++++++- sysmodule/DBModule_test.go | 83 +++++++++++---- 2 files changed, 268 insertions(+), 23 deletions(-) diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go index 8d2d0fc..07381ba 100644 --- a/sysmodule/DBModule.go +++ b/sysmodule/DBModule.go @@ -34,6 +34,12 @@ type DBModule struct { syncCoroutineNum int } +// Tx ... +type Tx struct { + tx *sql.Tx + PrintTime time.Duration +} + // DBResult ... type DBResult struct { Err error @@ -476,7 +482,7 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e hasRet := rows.NextResultSet() if hasRet == false { - if rows.Err()!= nil { + if rows.Err() != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) } break @@ -728,3 +734,203 @@ func (slf *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{} } return nil } + +func (slf *DBModule) GetTx() (*Tx, error) { + var txDBMoudule Tx + txdb, err := slf.db.Begin() + if err != nil { + return &txDBMoudule, err + } + txDBMoudule.tx = txdb + return &txDBMoudule, nil +} + +func (slf *Tx) Rollback() error { + return slf.tx.Rollback() +} + +func (slf *Tx) Commit() error { + return slf.tx.Commit() +} + +func (slf *Tx) CheckArgs(args ...interface{}) error { + for _, val := range args { + if reflect.TypeOf(val).Kind() == reflect.String { + retVal := val.(string) + if strings.Contains(retVal, "-") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "#") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "&") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "=") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "%") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "'") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "delete ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "truncate ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), " or ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "from ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "set ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + } + } + + return nil +} + +func (slf *Tx) Query(query string, args ...interface{}) DBResult { + if slf.CheckArgs(args) != nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + ret.Err = fmt.Errorf("CheckArgs is error!") + return ret + } + + if slf.tx == nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + ret.Err = fmt.Errorf("cannot connect database!") + return ret + } + + rows, err := slf.tx.Query(query, args...) + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err) + } + + return DBResult{ + Err: err, + res: rows, + tag: "json", + blur: true, + } +} + +func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool { + if slf.PrintTime != 0 && Time >= slf.PrintTime { + return true + } + return false +} + +func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) { + datasetList := DataSetList{} + datasetList.tag = "json" + datasetList.blur = true + + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + return &datasetList, fmt.Errorf("CheckArgs is error!") + } + + if slf.tx == nil { + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + return &datasetList, fmt.Errorf("cannot connect database!") + } + + TimeFuncStart := time.Now() + rows, err := slf.tx.Query(query, args...) + TimeFuncPass := time.Since(TimeFuncStart) + + if slf.IsPrintTimeLog(TimeFuncPass) { + service.GetLogger().Printf(service.LEVER_INFO, "DBModule Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + } + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err) + if rows != nil { + rows.Close() + } + return &datasetList, err + } + defer rows.Close() + + for { + dbResult := DBResultEx{} + //取出当前结果集所有行 + for rows.Next() { + if dbResult.RowInfo == nil { + dbResult.RowInfo = make(map[string][]interface{}) + } + //RowInfo map[string][][]sql.NullString //map[fieldname][row][column]sql.NullString + colField, err := rows.Columns() + if err != nil { + return &datasetList, err + } + count := len(colField) + valuePtrs := make([]interface{}, count) + for i := 0; i < count; i++ { + valuePtrs[i] = &sql.NullString{} + } + rows.Scan(valuePtrs...) + + for idx, fieldname := range colField { + fieldRowData := dbResult.RowInfo[strings.ToLower(fieldname)] + fieldRowData = append(fieldRowData, valuePtrs[idx]) + dbResult.RowInfo[strings.ToLower(fieldname)] = fieldRowData + } + dbResult.rowNum += 1 + } + + datasetList.dataSetList = append(datasetList.dataSetList, dbResult) + //取下一个结果集 + hasRet := rows.NextResultSet() + + if hasRet == false { + if rows.Err() != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) + } + break + } + } + + return &datasetList, nil +} + +// Exec ... +func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) { + ret := &DBResultEx{} + if slf.tx == nil { + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + return ret, fmt.Errorf("cannot connect database!") + } + + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + //return ret, fmt.Errorf("cannot connect database!") + return ret, fmt.Errorf("CheckArgs is error!") + } + + TimeFuncStart := time.Now() + res, err := slf.tx.Exec(query, args...) + TimeFuncPass := time.Since(TimeFuncStart) + if slf.IsPrintTimeLog(TimeFuncPass) { + service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + } + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err) + return nil, err + } + + ret.LastInsertID, _ = res.LastInsertId() + ret.RowsAffected, _ = res.RowsAffected() + + return ret, nil +} diff --git a/sysmodule/DBModule_test.go b/sysmodule/DBModule_test.go index ada9031..1eba7f1 100644 --- a/sysmodule/DBModule_test.go +++ b/sysmodule/DBModule_test.go @@ -1,6 +1,7 @@ package sysmodule_test import ( + "fmt" "sync" "testing" @@ -13,38 +14,76 @@ func TestDBModule(t *testing.T) { db.ExitChan = make(chan bool) db.WaitGroup = new(sync.WaitGroup) - db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB") + // db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB") + db.Init(100, "127.0.0.1:3306", "root", "zgh50221", "rebort_message") db.OnInit() - res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1") + tx, err := db.GetTx() if err != nil { - t.Error(err) + fmt.Println("err 1", err) + return + } + res, err := tx.QueryEx("select id as Id, info_type as InfoType, info_type_Name as InfoTypeName from tbl_info_type where id >= 1") + if err != nil { + fmt.Println("err 2", err) + tx.Rollback() + return } out := []struct { - Addtime int64 `json:"addtime"` - Tname string `json:"tname"` - Uuid string `json:"uuid,omitempty"` - AAAA string `json:"xxx"` + Id int64 + InfoType string + InfoTypeName string }{} err = res.UnMarshal(&out) if err != nil { - t.Error(err) + fmt.Println("err 3", err) + tx.Rollback() + return } - - sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1") - res, err = sres.Get(2000) + fmt.Println(out) + _, err = tx.Exec("insert into tbl_info_type(info_type, info_type_name) VALUES (?, ?)", "4", "weibo") if err != nil { - t.Error(err) + fmt.Println("err 4", err) + tx.Rollback() + return } - - out2 := []struct { - Addtime int64 `json:"addtime"` - Tname string `json:"tname"` - Uuid string `json:"uuid,omitempty"` - AAAA string `json:"xxx"` - }{} - - err = res.UnMarshal(&out2) + _, err = tx.Exec("update tbl_info_type set info_types = ? Where id = ?", "5", 0) if err != nil { - t.Error(err) + fmt.Println("err 4", err) + tx.Rollback() + return } + + tx.Commit() + // res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1") + // if err != nil { + // t.Error(err) + // } + // out := []struct { + // Addtime int64 `json:"addtime"` + // Tname string `json:"tname"` + // Uuid string `json:"uuid,omitempty"` + // AAAA string `json:"xxx"` + // }{} + // err = res.UnMarshal(&out) + // if err != nil { + // t.Error(err) + // } + + // sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1") + // res, err = sres.Get(2000) + // if err != nil { + // t.Error(err) + // } + + // out2 := []struct { + // Addtime int64 `json:"addtime"` + // Tname string `json:"tname"` + // Uuid string `json:"uuid,omitempty"` + // AAAA string `json:"xxx"` + // }{} + + // err = res.UnMarshal(&out2) + // if err != nil { + // t.Error(err) + // } } From 834f3ff062787205bdfd94169c2b803000b78436 Mon Sep 17 00:00:00 2001 From: zhengguanghao Date: Tue, 17 Sep 2019 10:06:33 +0800 Subject: [PATCH 2/2] add comment --- sysmodule/DBModule.go | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go index 07381ba..450d7d3 100644 --- a/sysmodule/DBModule.go +++ b/sysmodule/DBModule.go @@ -735,24 +735,29 @@ func (slf *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{} return nil } -func (slf *DBModule) GetTx() (*Tx, error) { +// Begin starts a transaction. +func (slf *DBModule) Begin() (*Tx, error) { var txDBMoudule Tx txdb, err := slf.db.Begin() if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Begin error:%s", err.Error()) return &txDBMoudule, err } txDBMoudule.tx = txdb return &txDBMoudule, nil } +// Rollback aborts the transaction. func (slf *Tx) Rollback() error { return slf.tx.Rollback() } +// Commit commits the transaction. func (slf *Tx) Commit() error { return slf.tx.Commit() } +// CheckArgs... func (slf *Tx) CheckArgs(args ...interface{}) error { for _, val := range args { if reflect.TypeOf(val).Kind() == reflect.String { @@ -796,6 +801,7 @@ func (slf *Tx) CheckArgs(args ...interface{}) error { return nil } +// Query executes a query that returns rows, typically a SELECT. func (slf *Tx) Query(query string, args ...interface{}) DBResult { if slf.CheckArgs(args) != nil { ret := DBResult{} @@ -813,7 +819,7 @@ func (slf *Tx) Query(query string, args ...interface{}) DBResult { rows, err := slf.tx.Query(query, args...) if err != nil { - service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err) + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err) } return DBResult{ @@ -824,6 +830,7 @@ func (slf *Tx) Query(query string, args ...interface{}) DBResult { } } +// IsPrintTimeLog... func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool { if slf.PrintTime != 0 && Time >= slf.PrintTime { return true @@ -831,6 +838,7 @@ func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool { return false } +// QueryEx executes a query that return rows. func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) { datasetList := DataSetList{} datasetList.tag = "json" @@ -851,7 +859,7 @@ func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) TimeFuncPass := time.Since(TimeFuncStart) if slf.IsPrintTimeLog(TimeFuncPass) { - service.GetLogger().Printf(service.LEVER_INFO, "DBModule Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + service.GetLogger().Printf(service.LEVER_INFO, "Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err) @@ -895,7 +903,7 @@ func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) if hasRet == false { if rows.Err() != nil { - service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%+v)", query, rows) } break } @@ -904,7 +912,7 @@ func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) return &datasetList, nil } -// Exec ... +// Exec executes a query that doesn't return rows. func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) { ret := &DBResultEx{} if slf.tx == nil { @@ -922,10 +930,10 @@ func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) { res, err := slf.tx.Exec(query, args...) TimeFuncPass := time.Since(TimeFuncStart) if slf.IsPrintTimeLog(TimeFuncPass) { - service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + service.GetLogger().Printf(service.LEVER_INFO, "Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { - service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err) + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Exec:%s(%v)", query, err) return nil, err }