diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go index 8d2d0fc..07381ba 100644 --- a/sysmodule/DBModule.go +++ b/sysmodule/DBModule.go @@ -34,6 +34,12 @@ type DBModule struct { syncCoroutineNum int } +// Tx ... +type Tx struct { + tx *sql.Tx + PrintTime time.Duration +} + // DBResult ... type DBResult struct { Err error @@ -476,7 +482,7 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e hasRet := rows.NextResultSet() if hasRet == false { - if rows.Err()!= nil { + if rows.Err() != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) } break @@ -728,3 +734,203 @@ func (slf *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{} } return nil } + +func (slf *DBModule) GetTx() (*Tx, error) { + var txDBMoudule Tx + txdb, err := slf.db.Begin() + if err != nil { + return &txDBMoudule, err + } + txDBMoudule.tx = txdb + return &txDBMoudule, nil +} + +func (slf *Tx) Rollback() error { + return slf.tx.Rollback() +} + +func (slf *Tx) Commit() error { + return slf.tx.Commit() +} + +func (slf *Tx) CheckArgs(args ...interface{}) error { + for _, val := range args { + if reflect.TypeOf(val).Kind() == reflect.String { + retVal := val.(string) + if strings.Contains(retVal, "-") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "#") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "&") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "=") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "%") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "'") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "delete ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "truncate ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), " or ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "from ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "set ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + } + } + + return nil +} + +func (slf *Tx) Query(query string, args ...interface{}) DBResult { + if slf.CheckArgs(args) != nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + ret.Err = fmt.Errorf("CheckArgs is error!") + return ret + } + + if slf.tx == nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + ret.Err = fmt.Errorf("cannot connect database!") + return ret + } + + rows, err := slf.tx.Query(query, args...) + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err) + } + + return DBResult{ + Err: err, + res: rows, + tag: "json", + blur: true, + } +} + +func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool { + if slf.PrintTime != 0 && Time >= slf.PrintTime { + return true + } + return false +} + +func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) { + datasetList := DataSetList{} + datasetList.tag = "json" + datasetList.blur = true + + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + return &datasetList, fmt.Errorf("CheckArgs is error!") + } + + if slf.tx == nil { + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + return &datasetList, fmt.Errorf("cannot connect database!") + } + + TimeFuncStart := time.Now() + rows, err := slf.tx.Query(query, args...) + TimeFuncPass := time.Since(TimeFuncStart) + + if slf.IsPrintTimeLog(TimeFuncPass) { + service.GetLogger().Printf(service.LEVER_INFO, "DBModule Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + } + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err) + if rows != nil { + rows.Close() + } + return &datasetList, err + } + defer rows.Close() + + for { + dbResult := DBResultEx{} + //取出当前结果集所有行 + for rows.Next() { + if dbResult.RowInfo == nil { + dbResult.RowInfo = make(map[string][]interface{}) + } + //RowInfo map[string][][]sql.NullString //map[fieldname][row][column]sql.NullString + colField, err := rows.Columns() + if err != nil { + return &datasetList, err + } + count := len(colField) + valuePtrs := make([]interface{}, count) + for i := 0; i < count; i++ { + valuePtrs[i] = &sql.NullString{} + } + rows.Scan(valuePtrs...) + + for idx, fieldname := range colField { + fieldRowData := dbResult.RowInfo[strings.ToLower(fieldname)] + fieldRowData = append(fieldRowData, valuePtrs[idx]) + dbResult.RowInfo[strings.ToLower(fieldname)] = fieldRowData + } + dbResult.rowNum += 1 + } + + datasetList.dataSetList = append(datasetList.dataSetList, dbResult) + //取下一个结果集 + hasRet := rows.NextResultSet() + + if hasRet == false { + if rows.Err() != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) + } + break + } + } + + return &datasetList, nil +} + +// Exec ... +func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) { + ret := &DBResultEx{} + if slf.tx == nil { + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + return ret, fmt.Errorf("cannot connect database!") + } + + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + //return ret, fmt.Errorf("cannot connect database!") + return ret, fmt.Errorf("CheckArgs is error!") + } + + TimeFuncStart := time.Now() + res, err := slf.tx.Exec(query, args...) + TimeFuncPass := time.Since(TimeFuncStart) + if slf.IsPrintTimeLog(TimeFuncPass) { + service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + } + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err) + return nil, err + } + + ret.LastInsertID, _ = res.LastInsertId() + ret.RowsAffected, _ = res.RowsAffected() + + return ret, nil +} diff --git a/sysmodule/DBModule_test.go b/sysmodule/DBModule_test.go index ada9031..1eba7f1 100644 --- a/sysmodule/DBModule_test.go +++ b/sysmodule/DBModule_test.go @@ -1,6 +1,7 @@ package sysmodule_test import ( + "fmt" "sync" "testing" @@ -13,38 +14,76 @@ func TestDBModule(t *testing.T) { db.ExitChan = make(chan bool) db.WaitGroup = new(sync.WaitGroup) - db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB") + // db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB") + db.Init(100, "127.0.0.1:3306", "root", "zgh50221", "rebort_message") db.OnInit() - res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1") + tx, err := db.GetTx() if err != nil { - t.Error(err) + fmt.Println("err 1", err) + return + } + res, err := tx.QueryEx("select id as Id, info_type as InfoType, info_type_Name as InfoTypeName from tbl_info_type where id >= 1") + if err != nil { + fmt.Println("err 2", err) + tx.Rollback() + return } out := []struct { - Addtime int64 `json:"addtime"` - Tname string `json:"tname"` - Uuid string `json:"uuid,omitempty"` - AAAA string `json:"xxx"` + Id int64 + InfoType string + InfoTypeName string }{} err = res.UnMarshal(&out) if err != nil { - t.Error(err) + fmt.Println("err 3", err) + tx.Rollback() + return } - - sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1") - res, err = sres.Get(2000) + fmt.Println(out) + _, err = tx.Exec("insert into tbl_info_type(info_type, info_type_name) VALUES (?, ?)", "4", "weibo") if err != nil { - t.Error(err) + fmt.Println("err 4", err) + tx.Rollback() + return } - - out2 := []struct { - Addtime int64 `json:"addtime"` - Tname string `json:"tname"` - Uuid string `json:"uuid,omitempty"` - AAAA string `json:"xxx"` - }{} - - err = res.UnMarshal(&out2) + _, err = tx.Exec("update tbl_info_type set info_types = ? Where id = ?", "5", 0) if err != nil { - t.Error(err) + fmt.Println("err 4", err) + tx.Rollback() + return } + + tx.Commit() + // res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1") + // if err != nil { + // t.Error(err) + // } + // out := []struct { + // Addtime int64 `json:"addtime"` + // Tname string `json:"tname"` + // Uuid string `json:"uuid,omitempty"` + // AAAA string `json:"xxx"` + // }{} + // err = res.UnMarshal(&out) + // if err != nil { + // t.Error(err) + // } + + // sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1") + // res, err = sres.Get(2000) + // if err != nil { + // t.Error(err) + // } + + // out2 := []struct { + // Addtime int64 `json:"addtime"` + // Tname string `json:"tname"` + // Uuid string `json:"uuid,omitempty"` + // AAAA string `json:"xxx"` + // }{} + + // err = res.UnMarshal(&out2) + // if err != nil { + // t.Error(err) + // } }