From ed78b8985daa4b3b39301e47d543a2ab6069367f Mon Sep 17 00:00:00 2001 From: zhengguanghao Date: Tue, 17 Sep 2019 00:44:04 +0800 Subject: [PATCH 1/3] add Tx struct and some method for tx --- sysmodule/DBModule.go | 208 ++++++++++++++++++++++++++++++++++++- sysmodule/DBModule_test.go | 83 +++++++++++---- 2 files changed, 268 insertions(+), 23 deletions(-) diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go index 8d2d0fc..07381ba 100644 --- a/sysmodule/DBModule.go +++ b/sysmodule/DBModule.go @@ -34,6 +34,12 @@ type DBModule struct { syncCoroutineNum int } +// Tx ... +type Tx struct { + tx *sql.Tx + PrintTime time.Duration +} + // DBResult ... type DBResult struct { Err error @@ -476,7 +482,7 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e hasRet := rows.NextResultSet() if hasRet == false { - if rows.Err()!= nil { + if rows.Err() != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) } break @@ -728,3 +734,203 @@ func (slf *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{} } return nil } + +func (slf *DBModule) GetTx() (*Tx, error) { + var txDBMoudule Tx + txdb, err := slf.db.Begin() + if err != nil { + return &txDBMoudule, err + } + txDBMoudule.tx = txdb + return &txDBMoudule, nil +} + +func (slf *Tx) Rollback() error { + return slf.tx.Rollback() +} + +func (slf *Tx) Commit() error { + return slf.tx.Commit() +} + +func (slf *Tx) CheckArgs(args ...interface{}) error { + for _, val := range args { + if reflect.TypeOf(val).Kind() == reflect.String { + retVal := val.(string) + if strings.Contains(retVal, "-") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "#") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "&") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "=") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "%") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "'") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "delete ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "truncate ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), " or ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "from ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "set ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + } + } + + return nil +} + +func (slf *Tx) Query(query string, args ...interface{}) DBResult { + if slf.CheckArgs(args) != nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + ret.Err = fmt.Errorf("CheckArgs is error!") + return ret + } + + if slf.tx == nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + ret.Err = fmt.Errorf("cannot connect database!") + return ret + } + + rows, err := slf.tx.Query(query, args...) + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err) + } + + return DBResult{ + Err: err, + res: rows, + tag: "json", + blur: true, + } +} + +func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool { + if slf.PrintTime != 0 && Time >= slf.PrintTime { + return true + } + return false +} + +func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) { + datasetList := DataSetList{} + datasetList.tag = "json" + datasetList.blur = true + + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + return &datasetList, fmt.Errorf("CheckArgs is error!") + } + + if slf.tx == nil { + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + return &datasetList, fmt.Errorf("cannot connect database!") + } + + TimeFuncStart := time.Now() + rows, err := slf.tx.Query(query, args...) + TimeFuncPass := time.Since(TimeFuncStart) + + if slf.IsPrintTimeLog(TimeFuncPass) { + service.GetLogger().Printf(service.LEVER_INFO, "DBModule Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + } + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err) + if rows != nil { + rows.Close() + } + return &datasetList, err + } + defer rows.Close() + + for { + dbResult := DBResultEx{} + //取出当前结果集所有行 + for rows.Next() { + if dbResult.RowInfo == nil { + dbResult.RowInfo = make(map[string][]interface{}) + } + //RowInfo map[string][][]sql.NullString //map[fieldname][row][column]sql.NullString + colField, err := rows.Columns() + if err != nil { + return &datasetList, err + } + count := len(colField) + valuePtrs := make([]interface{}, count) + for i := 0; i < count; i++ { + valuePtrs[i] = &sql.NullString{} + } + rows.Scan(valuePtrs...) + + for idx, fieldname := range colField { + fieldRowData := dbResult.RowInfo[strings.ToLower(fieldname)] + fieldRowData = append(fieldRowData, valuePtrs[idx]) + dbResult.RowInfo[strings.ToLower(fieldname)] = fieldRowData + } + dbResult.rowNum += 1 + } + + datasetList.dataSetList = append(datasetList.dataSetList, dbResult) + //取下一个结果集 + hasRet := rows.NextResultSet() + + if hasRet == false { + if rows.Err() != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) + } + break + } + } + + return &datasetList, nil +} + +// Exec ... +func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) { + ret := &DBResultEx{} + if slf.tx == nil { + service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) + return ret, fmt.Errorf("cannot connect database!") + } + + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + //return ret, fmt.Errorf("cannot connect database!") + return ret, fmt.Errorf("CheckArgs is error!") + } + + TimeFuncStart := time.Now() + res, err := slf.tx.Exec(query, args...) + TimeFuncPass := time.Since(TimeFuncStart) + if slf.IsPrintTimeLog(TimeFuncPass) { + service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + } + if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err) + return nil, err + } + + ret.LastInsertID, _ = res.LastInsertId() + ret.RowsAffected, _ = res.RowsAffected() + + return ret, nil +} diff --git a/sysmodule/DBModule_test.go b/sysmodule/DBModule_test.go index ada9031..1eba7f1 100644 --- a/sysmodule/DBModule_test.go +++ b/sysmodule/DBModule_test.go @@ -1,6 +1,7 @@ package sysmodule_test import ( + "fmt" "sync" "testing" @@ -13,38 +14,76 @@ func TestDBModule(t *testing.T) { db.ExitChan = make(chan bool) db.WaitGroup = new(sync.WaitGroup) - db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB") + // db.Init(100, "192.168.0.5:3306", "root", "Root!!2018", "QuantFundsDB") + db.Init(100, "127.0.0.1:3306", "root", "zgh50221", "rebort_message") db.OnInit() - res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1") + tx, err := db.GetTx() if err != nil { - t.Error(err) + fmt.Println("err 1", err) + return + } + res, err := tx.QueryEx("select id as Id, info_type as InfoType, info_type_Name as InfoTypeName from tbl_info_type where id >= 1") + if err != nil { + fmt.Println("err 2", err) + tx.Rollback() + return } out := []struct { - Addtime int64 `json:"addtime"` - Tname string `json:"tname"` - Uuid string `json:"uuid,omitempty"` - AAAA string `json:"xxx"` + Id int64 + InfoType string + InfoTypeName string }{} err = res.UnMarshal(&out) if err != nil { - t.Error(err) + fmt.Println("err 3", err) + tx.Rollback() + return } - - sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1") - res, err = sres.Get(2000) + fmt.Println(out) + _, err = tx.Exec("insert into tbl_info_type(info_type, info_type_name) VALUES (?, ?)", "4", "weibo") if err != nil { - t.Error(err) + fmt.Println("err 4", err) + tx.Rollback() + return } - - out2 := []struct { - Addtime int64 `json:"addtime"` - Tname string `json:"tname"` - Uuid string `json:"uuid,omitempty"` - AAAA string `json:"xxx"` - }{} - - err = res.UnMarshal(&out2) + _, err = tx.Exec("update tbl_info_type set info_types = ? Where id = ?", "5", 0) if err != nil { - t.Error(err) + fmt.Println("err 4", err) + tx.Rollback() + return } + + tx.Commit() + // res, err := db.QueryEx("select * from tbl_fun_heelthrow where id >= 1") + // if err != nil { + // t.Error(err) + // } + // out := []struct { + // Addtime int64 `json:"addtime"` + // Tname string `json:"tname"` + // Uuid string `json:"uuid,omitempty"` + // AAAA string `json:"xxx"` + // }{} + // err = res.UnMarshal(&out) + // if err != nil { + // t.Error(err) + // } + + // sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1") + // res, err = sres.Get(2000) + // if err != nil { + // t.Error(err) + // } + + // out2 := []struct { + // Addtime int64 `json:"addtime"` + // Tname string `json:"tname"` + // Uuid string `json:"uuid,omitempty"` + // AAAA string `json:"xxx"` + // }{} + + // err = res.UnMarshal(&out2) + // if err != nil { + // t.Error(err) + // } } From 834f3ff062787205bdfd94169c2b803000b78436 Mon Sep 17 00:00:00 2001 From: zhengguanghao Date: Tue, 17 Sep 2019 10:06:33 +0800 Subject: [PATCH 2/3] add comment --- sysmodule/DBModule.go | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go index 07381ba..450d7d3 100644 --- a/sysmodule/DBModule.go +++ b/sysmodule/DBModule.go @@ -735,24 +735,29 @@ func (slf *DataSetList) rowData2interface(rowIdx int, m map[string][]interface{} return nil } -func (slf *DBModule) GetTx() (*Tx, error) { +// Begin starts a transaction. +func (slf *DBModule) Begin() (*Tx, error) { var txDBMoudule Tx txdb, err := slf.db.Begin() if err != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "Begin error:%s", err.Error()) return &txDBMoudule, err } txDBMoudule.tx = txdb return &txDBMoudule, nil } +// Rollback aborts the transaction. func (slf *Tx) Rollback() error { return slf.tx.Rollback() } +// Commit commits the transaction. func (slf *Tx) Commit() error { return slf.tx.Commit() } +// CheckArgs... func (slf *Tx) CheckArgs(args ...interface{}) error { for _, val := range args { if reflect.TypeOf(val).Kind() == reflect.String { @@ -796,6 +801,7 @@ func (slf *Tx) CheckArgs(args ...interface{}) error { return nil } +// Query executes a query that returns rows, typically a SELECT. func (slf *Tx) Query(query string, args ...interface{}) DBResult { if slf.CheckArgs(args) != nil { ret := DBResult{} @@ -813,7 +819,7 @@ func (slf *Tx) Query(query string, args ...interface{}) DBResult { rows, err := slf.tx.Query(query, args...) if err != nil { - service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err) + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err) } return DBResult{ @@ -824,6 +830,7 @@ func (slf *Tx) Query(query string, args ...interface{}) DBResult { } } +// IsPrintTimeLog... func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool { if slf.PrintTime != 0 && Time >= slf.PrintTime { return true @@ -831,6 +838,7 @@ func (slf *Tx) IsPrintTimeLog(Time time.Duration) bool { return false } +// QueryEx executes a query that return rows. func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) { datasetList := DataSetList{} datasetList.tag = "json" @@ -851,7 +859,7 @@ func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) TimeFuncPass := time.Since(TimeFuncStart) if slf.IsPrintTimeLog(TimeFuncPass) { - service.GetLogger().Printf(service.LEVER_INFO, "DBModule Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + service.GetLogger().Printf(service.LEVER_INFO, "Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%v)", query, err) @@ -895,7 +903,7 @@ func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) if hasRet == false { if rows.Err() != nil { - service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%+v)", query, rows) + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Query:%s(%+v)", query, rows) } break } @@ -904,7 +912,7 @@ func (slf *Tx) QueryEx(query string, args ...interface{}) (*DataSetList, error) return &datasetList, nil } -// Exec ... +// Exec executes a query that doesn't return rows. func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) { ret := &DBResultEx{} if slf.tx == nil { @@ -922,10 +930,10 @@ func (slf *Tx) Exec(query string, args ...interface{}) (*DBResultEx, error) { res, err := slf.tx.Exec(query, args...) TimeFuncPass := time.Since(TimeFuncStart) if slf.IsPrintTimeLog(TimeFuncPass) { - service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) + service.GetLogger().Printf(service.LEVER_INFO, "Tx QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { - service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err) + service.GetLogger().Printf(service.LEVER_ERROR, "Tx Exec:%s(%v)", query, err) return nil, err } From 8bcb2af8b00db4cf880c74146a4dbec9302b01ac Mon Sep 17 00:00:00 2001 From: Ally Dale Date: Tue, 17 Sep 2019 13:45:06 +0800 Subject: [PATCH 3/3] add auto balancing info --- cluster/cluster.go | 45 ++++++++++++++++++++++++++++++++++++++++-- cluster/config.go | 38 +++++++++++++++++++++++++++++++---- sysmodule/LogModule.go | 13 +++++++++--- 3 files changed, 87 insertions(+), 9 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index ee98ced..d028e95 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -5,6 +5,7 @@ import ( "math/rand" "net" "os" + "sort" "strings" "time" @@ -274,7 +275,7 @@ func (slf *CCluster) GetNodeList(NodeServiceMethod string, rpcServerMethod *stri servicename = servicename[1:] nodeidList = append(nodeidList, GetNodeId()) } else { - nodeidList = slf.cfg.GetIdByService(servicename) + nodeidList = slf.cfg.GetIdByService(servicename, "") } } else { nodeidList = slf.GetIdByNodeService(nodename, servicename) @@ -293,7 +294,7 @@ func (slf *CCluster) GetNodeList(NodeServiceMethod string, rpcServerMethod *stri //GetNodeIdByServiceName 根据服务名查找nodeid servicename服务名 bOnline是否需要查找在线服务 func (slf *CCluster) GetNodeIdByServiceName(servicename string, bOnline bool) []int { - nodeIDList := slf.cfg.GetIdByService(servicename) + nodeIDList := slf.cfg.GetIdByService(servicename, "") if bOnline { ret := make([]int, 0, len(nodeIDList)) @@ -308,6 +309,40 @@ func (slf *CCluster) GetNodeIdByServiceName(servicename string, bOnline bool) [] return nodeIDList } +//根据Service获取负载均衡信息 +//负载均衡的策略是从配置获取所有配置了该服务的NodeId 并按NodeId排序 每个node负责处理数组index所在的那一部分 +func (slf *CCluster) GetBalancingInfo(currentNodeId int, servicename string, inSubNet bool) (*BalancingInfo, error) { + subNetName := "" + if inSubNet { + if node, ok := slf.cfg.mapIdNode[currentNodeId]; ok { + subNetName = node.SubNetName + } else { + return nil, fmt.Errorf("[cluster.GetBalancingInfo] cannot find node %d", currentNodeId) + } + } + lst := slf.cfg.GetIdByService(servicename, subNetName) + // if len(lst) <= 0 { + // return nil, fmt.Errorf("[cluster.GetBalancingInfo] cannot find service %s in any node", servicename) + // } + sort.Ints(lst) + ret := &BalancingInfo{ + NodeId: currentNodeId, + ServiceName: servicename, + TotalNum: len(lst), + MyIndex: -1, + NodeList: lst, + } + if _, ok := slf.cfg.mapIdNode[currentNodeId]; ok { + for i, v := range lst { + if v == currentNodeId { + ret.MyIndex = i + break + } + } + } + return ret, nil +} + func (slf *CCluster) CheckNodeIsConnectedByID(nodeid int) bool { if nodeid == GetNodeId() { return true @@ -522,6 +557,12 @@ func GetNodeIdByServiceName(serviceName string, bOnline bool) []int { return InstanceClusterMgr().GetNodeIdByServiceName(serviceName, bOnline) } +//获取服务的负载均衡信息 +//负载均衡的策略是从配置获取所有配置了该服务的NodeId 并按NodeId排序 每个node负责处理数组index所在的那一部分 +func GetBalancingInfo(currentNodeId int, servicename string, inSubNet bool) (*BalancingInfo, error) { + return InstanceClusterMgr().GetBalancingInfo(currentNodeId, servicename, inSubNet) +} + //随机选择在线的node发送 func CallRandomService(NodeServiceMethod string, args interface{}, reply interface{}) error { return InstanceClusterMgr().CallRandomService(NodeServiceMethod, args, reply) diff --git a/cluster/config.go b/cluster/config.go index d51a56a..1b66701 100644 --- a/cluster/config.go +++ b/cluster/config.go @@ -8,6 +8,34 @@ import ( "strings" ) +//负载均衡信息 +type BalancingInfo struct { + NodeId int //我的nodeId + ServiceName string //负载均衡的ServiceName + + TotalNum int //总共有多少个协同Node + MyIndex int //负责的index [0, TotalNum) + NodeList []int //所有协同的node列表 按NodeId升序排列 +} + +//判断hash后的Id是否命中我的NodeId +func (slf *BalancingInfo) Hit(hashId int) bool { + if hashId >= 0 && slf.TotalNum > 0 && slf.MyIndex >= 0 { + return hashId%slf.TotalNum == slf.MyIndex + } + return false +} + +//判断命中的NodeId,-1表示无法取得 +func (slf *BalancingInfo) GetHitNodeId(hashId int) int { + if hashId >= 0 && slf.TotalNum > 0 { + if idx := hashId % slf.TotalNum; idx >= 0 && idx < len(slf.NodeList) { + return slf.NodeList[idx] + } + } + return -1 +} + type CNodeCfg struct { NodeID int NodeName string @@ -148,14 +176,16 @@ func ReadCfg(path string, nodeid int, mapNodeData map[int]NodeData) (*ClusterCon return clsCfg, nil } -func (slf *ClusterConfig) GetIdByService(serviceName string) []int { - var nodeidlist []int - nodeidlist = make([]int, 0) +func (slf *ClusterConfig) GetIdByService(serviceName, subNetName string) []int { + var nodeidlist = []int{} nodeList, ok := slf.mapClusterServiceNode[serviceName] if ok == true { + nodeidlist = make([]int, 0, len(nodeList)) for _, v := range nodeList { - nodeidlist = append(nodeidlist, v.NodeID) + if subNetName == "" || subNetName == v.SubNetName { + nodeidlist = append(nodeidlist, v.NodeID) + } } } diff --git a/sysmodule/LogModule.go b/sysmodule/LogModule.go index ba9ff6e..ff6cc14 100644 --- a/sysmodule/LogModule.go +++ b/sysmodule/LogModule.go @@ -55,7 +55,13 @@ func (slf *LogModule) GetCurrentFileName() string { now := time.Now() fpath := filepath.Join("logs") os.MkdirAll(fpath, os.ModePerm) - fname := slf.logfilename + "-" + now.Format("20060102-150405") + ".log" + y, m, d := now.Date() + h := now.Hour() + mm := now.Minute() + mm -= mm % 15 //15分钟内使用同一个日志文件 + dt := y*10000 + int(m)*100 + d + tm := h*100 + mm + fname := fmt.Sprintf("%s-%d-%d.log", slf.logfilename, dt, tm) ret := filepath.Join(fpath, fname) return ret } @@ -84,9 +90,10 @@ func (slf *LogModule) CheckAndGenFile(fileline string) (newFile bool) { } var err error - slf.logFile, err = os.OpenFile(slf.GetCurrentFileName(), os.O_RDWR|os.O_CREATE|os.O_APPEND, os.ModePerm) + filename := slf.GetCurrentFileName() + slf.logFile, err = os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, os.ModePerm) if err != nil { - fmt.Printf("create log file %+v error!", slf.GetCurrentFileName()) + fmt.Printf("create log file %+v error!", filename) slf.locker.Unlock() return false }