mirror of
https://github.com/duanhf2012/origin.git
synced 2026-02-04 06:54:45 +08:00
add Tx struct and some method for tx
This commit is contained in:
@@ -34,6 +34,12 @@ type DBModule struct {
|
||||
syncCoroutineNum int
|
||||
}
|
||||
|
||||
// Tx ...
|
||||
type Tx struct {
|
||||
tx *sql.Tx
|
||||
PrintTime time.Duration
|
||||
}
|
||||
|
||||
// DBResult ...
|
||||
type DBResult struct {
|
||||
Err error
|
||||
@@ -476,7 +482,7 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e
|
||||
hasRet := rows.NextResultSet()
|
||||
|
||||
if hasRet == false {
|
||||
if rows.Err()!= nil {
|
||||
if rows.Err() != nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows)
|
||||
}
|
||||
break
|
||||
@@ -728,3 +734,203 @@ func (slf *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (slf *DBModule) GetTx() (*Tx, error) {
|
||||
var txDBMoudule Tx
|
||||
txdb, err := slf.db.Begin()
|
||||
if err != nil {
|
||||
return &txDBMoudule, err
|
||||
}
|
||||
txDBMoudule.tx = txdb
|
||||
return &txDBMoudule, nil
|
||||
}
|
||||
|
||||
func (slf *Tx) Rollback() error {
|
||||
return slf.tx.Rollback()
|
||||
}
|
||||
|
||||
func (slf *Tx) Commit() error {
|
||||
return slf.tx.Commit()
|
||||
}
|
||||
|
||||
func (slf *Tx) CheckArgs(args ...interface{}) error {
|
||||
for _, val := range args {
|
||||
if reflect.TypeOf(val).Kind() == reflect.String {
|
||||
retVal := val.(string)
|
||||
if strings.Contains(retVal, "-") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "#") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "&") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "=") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "%") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(retVal, "'") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "delete ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "truncate ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), " or ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "from ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(retVal), "set ") == true {
|
||||
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (slf *Tx) Query(query string, args ...interface{}) DBResult {
|
||||
if slf.CheckArgs(args) != nil {
|
||||
ret := DBResult{}
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
|
||||
ret.Err = fmt.Errorf("CheckArgs is error!")
|
||||
return ret
|
||||
}
|
||||
|
||||
if slf.tx == nil {
|
||||
ret := DBResult{}
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
|
||||
ret.Err = fmt.Errorf("cannot connect database!")
|
||||
return ret
|
||||
}
|
||||
|
||||
rows, err := slf.tx.Query(query, args...)
|
||||
if err != nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err)
|
||||
}
|
||||
|
||||
return DBResult{
|
||||
Err: err,
|
||||
res: rows,
|
||||
tag: "json",
|
||||
blur: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool {
|
||||
if slf.PrintTime != 0 && Time >= slf.PrintTime {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) {
|
||||
datasetList := DataSetList{}
|
||||
datasetList.tag = "json"
|
||||
datasetList.blur = true
|
||||
|
||||
if slf.CheckArgs(args) != nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
|
||||
return &datasetList, fmt.Errorf("CheckArgs is error!")
|
||||
}
|
||||
|
||||
if slf.tx == nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
|
||||
return &datasetList, fmt.Errorf("cannot connect database!")
|
||||
}
|
||||
|
||||
TimeFuncStart := time.Now()
|
||||
rows, err := slf.tx.Query(query, args...)
|
||||
TimeFuncPass := time.Since(TimeFuncStart)
|
||||
|
||||
if slf.IsPrintTimeLog(TimeFuncPass) {
|
||||
service.GetLogger().Printf(service.LEVER_INFO, "DBModule Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args)
|
||||
}
|
||||
if err != nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err)
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
return &datasetList, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for {
|
||||
dbResult := DBResultEx{}
|
||||
//取出当前结果集所有行
|
||||
for rows.Next() {
|
||||
if dbResult.RowInfo == nil {
|
||||
dbResult.RowInfo = make(map[string][]interface{})
|
||||
}
|
||||
//RowInfo map[string][][]sql.NullString //map[fieldname][row][column]sql.NullString
|
||||
colField, err := rows.Columns()
|
||||
if err != nil {
|
||||
return &datasetList, err
|
||||
}
|
||||
count := len(colField)
|
||||
valuePtrs := make([]interface{}, count)
|
||||
for i := 0; i < count; i++ {
|
||||
valuePtrs[i] = &sql.NullString{}
|
||||
}
|
||||
rows.Scan(valuePtrs...)
|
||||
|
||||
for idx, fieldname := range colField {
|
||||
fieldRowData := dbResult.RowInfo[strings.ToLower(fieldname)]
|
||||
fieldRowData = append(fieldRowData, valuePtrs[idx])
|
||||
dbResult.RowInfo[strings.ToLower(fieldname)] = fieldRowData
|
||||
}
|
||||
dbResult.rowNum += 1
|
||||
}
|
||||
|
||||
datasetList.dataSetList = append(datasetList.dataSetList, dbResult)
|
||||
//取下一个结果集
|
||||
hasRet := rows.NextResultSet()
|
||||
|
||||
if hasRet == false {
|
||||
if rows.Err() != nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &datasetList, nil
|
||||
}
|
||||
|
||||
// Exec ...
|
||||
func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) {
|
||||
ret := &DBResultEx{}
|
||||
if slf.tx == nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
|
||||
return ret, fmt.Errorf("cannot connect database!")
|
||||
}
|
||||
|
||||
if slf.CheckArgs(args) != nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
|
||||
//return ret, fmt.Errorf("cannot connect database!")
|
||||
return ret, fmt.Errorf("CheckArgs is error!")
|
||||
}
|
||||
|
||||
TimeFuncStart := time.Now()
|
||||
res, err := slf.tx.Exec(query, args...)
|
||||
TimeFuncPass := time.Since(TimeFuncStart)
|
||||
if slf.IsPrintTimeLog(TimeFuncPass) {
|
||||
service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args)
|
||||
}
|
||||
if err != nil {
|
||||
service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret.LastInsertID, _ = res.LastInsertId()
|
||||
ret.RowsAffected, _ = res.RowsAffected()
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user