// Package mysqlmodule provides a service module that wraps database/sql with
// the go-sql-driver/mysql driver and maps result sets onto Go structs.
package mysqlmodule

import (
	"database/sql"
	"errors"
	"fmt"
	"net/url"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/duanhf2012/origin/v2/log"
	"github.com/duanhf2012/origin/v2/service"
	_ "github.com/go-sql-driver/mysql"
)

// SyncFun is a parameterless callback.
type SyncFun func()

// DBExecute bundles a channel of callbacks with an exit-signal channel.
type DBExecute struct {
	syncExecuteFun  chan SyncFun
	syncExecuteExit chan bool
}

// PingExecute holds the state of the background keep-alive loop: the ticker
// that triggers each ping and the channel that tells the loop to stop.
type PingExecute struct {
	tickerPing *time.Ticker
	pintExit   chan bool
}

// MySQLModule is a service module owning one *sql.DB connection pool.
type MySQLModule struct {
	service.Module
	db            *sql.DB
	url           string
	username      string
	password      string
	dbname        string
	slowDuration  time.Duration
	pingCoroutine PingExecute
	waitGroup     sync.WaitGroup
}

// Tx wraps a *sql.Tx together with the slow-statement threshold inherited
// from the module that started it.
type Tx struct {
	tx           *sql.Tx
	slowDuration time.Duration
}

// DBResult is the outcome of one statement or one result set.
type DBResult struct {
	LastInsertID int64
	RowsAffected int64

	rowNum int
	// RowInfo maps a lower-cased column name to that column's values, one
	// entry per row; query stores each value as a *sql.NullString.
	RowInfo map[string][]interface{}
}

// DataSetList is the ordered list of result sets produced by one query.
type DataSetList struct {
	dataSetList       []DBResult
	currentDataSetIdx int32
	// tag is the struct-tag key used to match struct fields to columns.
	tag string
	// blur, when true, skips struct fields that have no matching column
	// instead of reporting an error.
	blur bool
}

// dbControl is the subset of *sql.DB and *sql.Tx used by query and exec.
type dbControl interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
	Query(query string, args ...interface{}) (*sql.Rows, error)
}

// Init stores the connection parameters, creates the 5-second ping ticker
// and opens the connection pool with at most maxConn connections.
func (m *MySQLModule) Init(url string, userName string, password string, dbname string, maxConn int) error {
	m.url = url
	m.username = userName
	m.password = password
	m.dbname = dbname
	m.pingCoroutine = PingExecute{
		tickerPing: time.NewTicker(5 * time.Second),
		pintExit:   make(chan bool, 1),
	}

	if err := m.connect(maxConn); err != nil {
		// The ping goroutine was never started, so release the ticker here.
		m.pingCoroutine.tickerPing.Stop()
		return err
	}
	return nil
}

// SetQuerySlowTime sets the duration at or above which a statement is logged
// as slow. Zero disables slow logging.
func (m *MySQLModule) SetQuerySlowTime(slowDuration time.Duration) {
	m.slowDuration = slowDuration
}

// Query runs a statement that returns rows on the connection pool.
func (m *MySQLModule) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
	if m.db == nil {
		// A nil *sql.DB stored in the dbControl interface would compare
		// non-nil inside query; pass an untyped nil so its guard fires.
		return query(m.slowDuration, nil, strQuery, args...)
	}
	return query(m.slowDuration, m.db, strQuery, args...)
}

// Exec runs a statement that does not return rows on the connection pool.
func (m *MySQLModule) Exec(strSql string, args ...interface{}) (*DBResult, error) {
	if m.db == nil {
		// See Query: avoid wrapping a nil pointer in a non-nil interface.
		return exec(m.slowDuration, nil, strSql, args...)
	}
	return exec(m.slowDuration, m.db, strSql, args...)
}

// Begin starts a transaction. On failure the returned Tx has no underlying
// transaction and the error is non-nil.
func (m *MySQLModule) Begin() (*Tx, error) {
	var txDBModule Tx
	if m.db == nil {
		log.Error("Begin error", log.ErrorField("err", errors.New("cannot connect database")))
		return &txDBModule, fmt.Errorf("cannot connect database")
	}

	txDb, err := m.db.Begin()
	if err != nil {
		log.Error("Begin error", log.ErrorField("err", err))
		return &txDBModule, err
	}

	txDBModule.slowDuration = m.slowDuration
	txDBModule.tx = txDb
	return &txDBModule, nil
}

// Rollback aborts the transaction.
func (slf *Tx) Rollback() error {
	if slf.tx == nil {
		return errors.New("transaction not started")
	}
	return slf.tx.Rollback()
}

// Commit commits the transaction.
func (slf *Tx) Commit() error {
	if slf.tx == nil {
		return errors.New("transaction not started")
	}
	return slf.tx.Commit()
}

// Query executes a query that returns rows inside the transaction.
func (slf *Tx) Query(strQuery string, args ...interface{}) (*DataSetList, error) {
	if slf.tx == nil {
		// Pass an untyped nil so query's nil guard fires instead of panicking.
		return query(slf.slowDuration, nil, strQuery, args...)
	}
	return query(slf.slowDuration, slf.tx, strQuery, args...)
}

// Exec executes a query that doesn't return rows inside the transaction.
func (slf *Tx) Exec(strSql string, args ...interface{}) (*DBResult, error) {
	if slf.tx == nil {
		// Pass an untyped nil so exec's nil guard fires instead of panicking.
		return exec(slf.slowDuration, nil, strSql, args...)
	}
	return exec(slf.slowDuration, slf.tx, strSql, args...)
}

// connect opens the pool from the stored parameters, verifies it with a ping,
// applies the pool limits and starts the background ping goroutine.
func (m *MySQLModule) connect(maxConn int) error {
	cmd := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true&loc=%s&readTimeout=30s&timeout=15s&writeTimeout=30s",
		m.username,
		m.password,
		m.url,
		m.dbname,
		url.QueryEscape(time.Local.String()))

	db, err := sql.Open("mysql", cmd)
	if err != nil {
		return err
	}

	if err = db.Ping(); err != nil {
		db.Close()
		return err
	}

	m.db = db
	db.SetMaxOpenConns(maxConn)
	db.SetMaxIdleConns(maxConn)
	db.SetConnMaxLifetime(time.Second * 90)

	go m.runPing()
	return nil
}

// runPing pings the database on every ticker tick until a value arrives on
// the exit channel. The ticker is stopped when the loop ends.
func (m *MySQLModule) runPing() {
	defer m.pingCoroutine.tickerPing.Stop()

	for {
		select {
		case <-m.pingCoroutine.pintExit:
			log.Error("RunPing stopping", log.String("url", m.url), log.String("dbname", m.dbname))
			return
		case <-m.pingCoroutine.tickerPing.C:
			if m.db == nil {
				continue
			}
			if err := m.db.Ping(); err != nil {
				log.Error("Ping error", log.String("url", m.url), log.String("dbname", m.dbname), log.ErrorField("err", err))
			}
		}
	}
}

// checkArgs rejects any string argument that contains one of a fixed set of
// characters or SQL keywords. Non-string and nil arguments are ignored.
func checkArgs(args ...interface{}) error {
	// Matched against the argument exactly as given.
	rawTokens := [...]string{"-", "#", "&", "=", "%", "'"}
	// Matched against the lower-cased argument.
	keywordTokens := [...]string{"delete ", "truncate ", " or ", "from ", "set "}

	for _, val := range args {
		// reflect.ValueOf(nil) has Kind Invalid, so nil arguments are
		// skipped; String() also covers named string types.
		rv := reflect.ValueOf(val)
		if rv.Kind() != reflect.String {
			continue
		}
		retVal := rv.String()

		for _, token := range rawTokens {
			if strings.Contains(retVal, token) {
				return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
			}
		}

		lowered := strings.ToLower(retVal)
		for _, token := range keywordTokens {
			if strings.Contains(lowered, token) {
				return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
			}
		}
	}
	return nil
}

// checkSlow reports whether elapsed reaches a non-zero slow threshold.
func checkSlow(slowDuration time.Duration, elapsed time.Duration) bool {
	return slowDuration != 0 && elapsed >= slowDuration
}

// query runs strQuery through db and collects every result set it produces.
// All values are scanned as *sql.NullString and stored per lower-cased column
// name. On error the returned DataSetList holds whatever was read so far.
func query(slowDuration time.Duration, db dbControl, strQuery string, args ...interface{}) (*DataSetList, error) {
	datasetList := DataSetList{tag: "json", blur: true}

	// NOTE(review): args is passed as one value, not spread, so checkArgs
	// sees a slice and never inspects the individual strings. This is kept
	// on purpose: arguments are bound through placeholders, and spreading
	// them would reject ordinary values such as dates containing "-".
	if checkArgs(args) != nil {
		log.Error("CheckArgs is error", log.String("sql", strQuery))
		return &datasetList, fmt.Errorf("checkArgs is error")
	}

	if db == nil {
		log.Error("cannot connect database", log.String("sql", strQuery))
		return &datasetList, fmt.Errorf("cannot connect database")
	}

	start := time.Now()
	rows, err := db.Query(strQuery, args...)
	elapsed := time.Since(start)
	if checkSlow(slowDuration, elapsed) {
		log.Error("Query slow", log.Int64("time_ms", elapsed.Milliseconds()), log.String("sql", strQuery), log.Any("args", args))
	}
	if err != nil {
		log.Error("Query error", log.String("sql", strQuery), log.ErrorField("err", err))
		if rows != nil {
			rows.Close()
		}
		return &datasetList, err
	}
	defer rows.Close()

	for {
		dbResult := DBResult{}
		var columns []string // lower-cased column names of this result set

		// Read every row of the current result set.
		for rows.Next() {
			if dbResult.RowInfo == nil {
				// Resolve the column names once per result set, on its
				// first row, so an empty set keeps a nil RowInfo.
				colField, colErr := rows.Columns()
				if colErr != nil {
					return &datasetList, colErr
				}
				columns = make([]string, len(colField))
				for i, fieldName := range colField {
					columns[i] = strings.ToLower(fieldName)
				}
				dbResult.RowInfo = make(map[string][]interface{}, len(columns))
			}

			valuePtrs := make([]interface{}, len(columns))
			for i := range valuePtrs {
				valuePtrs[i] = &sql.NullString{}
			}
			if scanErr := rows.Scan(valuePtrs...); scanErr != nil {
				log.Error("Query scan error", log.String("sql", strQuery), log.ErrorField("err", scanErr))
				return &datasetList, scanErr
			}

			for idx, fieldName := range columns {
				dbResult.RowInfo[fieldName] = append(dbResult.RowInfo[fieldName], valuePtrs[idx])
			}
			dbResult.rowNum++
		}

		datasetList.dataSetList = append(datasetList.dataSetList, dbResult)

		// Advance to the next result set.
		if !rows.NextResultSet() {
			// Next and NextResultSet both return false on failure, so the
			// iteration error has to be read from rows.Err.
			if rowErr := rows.Err(); rowErr != nil {
				log.Error("NextResultSet error", log.String("sql", strQuery), log.ErrorField("err", rowErr))
				return &datasetList, rowErr
			}
			break
		}
	}

	return &datasetList, nil
}

// exec runs strSql through db and reports the last insert id and the number
// of affected rows. It returns a nil result when the statement itself fails.
func exec(slowDuration time.Duration, db dbControl, strSql string, args ...interface{}) (*DBResult, error) {
	ret := &DBResult{}
	if db == nil {
		log.Error("cannot connect database", log.String("sql", strSql))
		return ret, fmt.Errorf("cannot connect database")
	}

	// NOTE(review): args is passed as one value, not spread, so checkArgs
	// never inspects the individual strings; see the same note in query.
	if checkArgs(args) != nil {
		log.Error("CheckArgs is error", log.String("sql", strSql))
		return ret, fmt.Errorf("checkArgs is error")
	}

	start := time.Now()
	res, err := db.Exec(strSql, args...)
	elapsed := time.Since(start)
	if checkSlow(slowDuration, elapsed) {
		log.Error("Exec slow", log.Int64("time_ms", elapsed.Milliseconds()), log.String("sql", strSql), log.Any("args", args))
	}
	if err != nil {
		log.Error("Exec error", log.String("sql", strSql), log.ErrorField("err", err))
		return nil, err
	}

	// Best effort: errors from these two calls are deliberately ignored.
	ret.LastInsertID, _ = res.LastInsertId()
	ret.RowsAffected, _ = res.RowsAffected()
	return ret, nil
}

// UnMarshal copies the result sets into args, one argument per result set, in
// order. Each argument must be a pointer: a pointer to a struct receives the
// first row, a pointer to a slice receives every row. Pointers to any other
// kind are skipped.
func (ds *DataSetList) UnMarshal(args ...interface{}) error {
	if len(ds.dataSetList) != len(args) {
		return fmt.Errorf("Data set len(%d,%d) is not equal to args!", len(ds.dataSetList), len(args))
	}

	// Start from the first result set on every call; without this a second
	// call would index past the end of dataSetList.
	ds.currentDataSetIdx = 0

	for _, out := range args {
		v := reflect.ValueOf(out)
		if v.Kind() != reflect.Ptr {
			return errors.New("interface must be a pointer")
		}

		switch v.Elem().Kind() {
		case reflect.Struct:
			if err := ds.rowData2interface(0, ds.dataSetList[ds.currentDataSetIdx].RowInfo, v); err != nil {
				return err
			}
		case reflect.Slice:
			if err := ds.slice2interface(out); err != nil {
				return err
			}
		}

		ds.currentDataSetIdx++
	}

	return nil
}

// slice2interface fills the slice that in points to with one element per row
// of the current result set. The slice is left untouched when there are no
// rows. Elements may be structs or pointers to structs.
func (ds *DataSetList) slice2interface(in interface{}) error {
	length := ds.dataSetList[ds.currentDataSetIdx].rowNum
	if length == 0 {
		return nil
	}

	v := reflect.ValueOf(in).Elem()
	v.Set(reflect.MakeSlice(v.Type(), length, length))

	for i := 0; i < length; i++ {
		idxV := v.Index(i)
		if idxV.Kind() == reflect.Ptr {
			// Pointer elements start out nil; allocate the target first.
			newObj := reflect.New(idxV.Type().Elem())
			idxV.Set(newObj)
			idxV = newObj
		} else {
			idxV = idxV.Addr()
		}

		if err := ds.rowData2interface(i, ds.dataSetList[ds.currentDataSetIdx].RowInfo, idxV); err != nil {
			return err
		}
	}

	return nil
}

// rowData2interface copies row rowIdx of m into the struct that v points to.
// A field is matched to a column by the name part of its ds.tag struct tag,
// or by the field name when there is no tag, compared in lower case. Fields
// tagged "-" are skipped, and NULL or empty values leave the field unchanged.
func (ds *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{}, v reflect.Value) error {
	t := v.Type()
	val := v.Elem()
	typ := t.Elem()

	// NumField panics on anything but a struct, so reject other targets.
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return errors.New("incorrect data type")
	}

	for i := 0; i < val.NumField(); i++ {
		value := val.Field(i)
		kind := value.Kind()

		tag := typ.Field(i).Tag.Get(ds.tag)
		// Drop tag options such as ",omitempty"; only the name is a column.
		if comma := strings.Index(tag, ","); comma >= 0 {
			tag = tag[:comma]
		}
		if tag == "" {
			tag = typ.Field(i).Name
		}
		if tag == "-" {
			continue
		}

		vTag := strings.ToLower(tag)
		columnData, ok := m[vTag]
		if !ok {
			if !ds.blur {
				return fmt.Errorf("cannot find filed name %s", vTag)
			}
			continue
		}
		if len(columnData) <= rowIdx {
			return fmt.Errorf("datasource column is error %s", tag)
		}

		meta, ok := columnData[rowIdx].(*sql.NullString)
		if !ok || meta == nil {
			if !ds.blur {
				return fmt.Errorf("no corresponding field was found in the result set %s", tag)
			}
			continue
		}

		if !value.CanSet() {
			return errors.New("struct fields do not have read or write permissions")
		}
		if !meta.Valid || len(meta.String) == 0 {
			continue
		}

		switch kind {
		case reflect.String:
			value.SetString(meta.String)
		case reflect.Float32, reflect.Float64:
			f, err := strconv.ParseFloat(meta.String, 64)
			if err != nil {
				return err
			}
			value.SetFloat(f)
		case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
			integer64, err := strconv.ParseInt(meta.String, 10, 64)
			if err != nil {
				return err
			}
			value.SetInt(integer64)
		case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
			integer64, err := strconv.ParseUint(meta.String, 10, 64)
			if err != nil {
				return err
			}
			value.SetUint(integer64)
		case reflect.Bool:
			b, err := strconv.ParseBool(meta.String)
			if err != nil {
				return err
			}
			value.SetBool(b)
		default:
			return errors.New("the database map has unrecognized data types")
		}
	}

	return nil
}