diff --git a/rpc/client.go b/rpc/client.go index 9a6137f..7865491 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -347,9 +347,9 @@ func (client *Client) Call(serviceMethod string, args interface{}, reply interfa select { case call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1), false).Done: return call.Error - case <-time.After(15 * time.Second): + case <-time.After(30 * time.Second): } //call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done - return errors.New(fmt.Sprintf("Call RPC %s is time out 10s", serviceMethod)) + return errors.New(fmt.Sprintf("Call RPC %s is time out 30s", serviceMethod)) } diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go index 2e9f04a..94b75b9 100644 --- a/sysmodule/DBModule.go +++ b/sysmodule/DBModule.go @@ -23,13 +23,13 @@ const ( // DBModule ... type DBModule struct { service.BaseModule - db *sql.DB - url string - username string - password string - dbname string - maxconn int - PrintTime time.Duration + db *sql.DB + url string + username string + password string + dbname string + maxconn int + PrintTime time.Duration syncExecuteFun chan SyncFun syncCoroutineNum int } @@ -241,14 +241,12 @@ func (slf *DBResult) mapSingle2interface(m map[string]string, v reflect.Value) e return nil } - -func (slf *DBModule) SetQuerySlowTime(Time time.Duration){ +func (slf *DBModule) SetQuerySlowTime(Time time.Duration) { slf.PrintTime = Time } - -func (slf *DBModule) IsPrintTimeLog(Time time.Duration)bool{ - if slf.PrintTime != 0 && Time >= slf.PrintTime{ +func (slf *DBModule) IsPrintTimeLog(Time time.Duration) bool { + if slf.PrintTime != 0 && Time >= slf.PrintTime { return true } return false @@ -337,8 +335,58 @@ func (slf *SyncQueryDBResultEx) Get(timeoutMs int) (*DataSetList, error) { return nil, fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs) } +func (slf *DBModule) CheckArgs(args ...interface{}) error { + for _, val := range args { + if reflect.TypeOf(val).Kind() == reflect.String { + retVal := val.(string) + if strings.Contains(retVal, "-") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "#") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "&") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "=") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "%") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "'") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "delete ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "truncate ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), " or ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "from ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "set ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + } + } + + return nil +} + // Query ... func (slf *DBModule) Query(query string, args ...interface{}) DBResult { + if slf.CheckArgs(args) != nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + ret.Err = fmt.Errorf("CheckArgs is error!") + return ret + } + if slf.db == nil { ret := DBResult{} service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) @@ -363,6 +411,11 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e datasetList.tag = "json" datasetList.blur = true + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + return &datasetList, fmt.Errorf("CheckArgs is error!") + } + if slf.db == nil { service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) return &datasetList, fmt.Errorf("cannot connect database!") @@ -372,8 +425,8 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e rows, err := slf.db.Query(query, args...) TimeFuncPass := time.Since(TimeFuncStart) - if slf.IsPrintTimeLog(TimeFuncPass) { - service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v",TimeFuncPass,query,args) + if slf.IsPrintTimeLog(TimeFuncPass) { + service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err) @@ -454,11 +507,17 @@ func (slf *DBModule) Exec(query string, args ...interface{}) (*DBResultEx, error return ret, fmt.Errorf("cannot connect database!") } + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + //return ret, fmt.Errorf("cannot connect database!") + return ret, fmt.Errorf("CheckArgs is error!") + } + TimeFuncStart := time.Now() res, err := slf.db.Exec(query, args...) TimeFuncPass := time.Since(TimeFuncStart) if slf.IsPrintTimeLog(TimeFuncPass) { - service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v",TimeFuncPass,query,args) + service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err) @@ -514,7 +573,7 @@ func (slf *DBModule) RunExecuteDBCoroutine() { func (slf *DataSetList) UnMarshal(args ...interface{}) error { if len(slf.dataSetList) != len(args) { - return errors.New("Data set len(%d) is not equal to args!") + return errors.New(fmt.Sprintf("Data set len(%d,%d) is not equal to args!", len(slf.dataSetList), len(args))) } for _, out := range args { diff --git a/sysservice/httpserverervice.go b/sysservice/httpserverervice.go index d07eaa1..2725a05 100644 --- a/sysservice/httpserverervice.go +++ b/sysservice/httpserverervice.go @@ -130,23 +130,6 @@ func (slf *HttpServerService) staticServer(w http.ResponseWriter, r *http.Reques w.Write([]byte(msg)) } - // 在这儿处理例外路由接口 - var errRet error - for _, filter := range slf.httpfiltrateList { - ret := filter(r.URL.Path, w, r) - if ret == nil { - errRet = nil - break - } else { - errRet = ret - } - } - - if errRet != nil { - w.Write([]byte(errRet.Error())) - return - } - nowpath, _ := os.Getwd() upath := r.URL.Path destLocalPath := nowpath + upath @@ -163,6 +146,24 @@ func (slf *HttpServerService) staticServer(w http.ResponseWriter, r *http.Reques } //上传资源 case "POST": + + // 在这儿处理例外路由接口 + var errRet error + for _, filter := range slf.httpfiltrateList { + ret := filter(r.URL.Path, w, r) + if ret == nil { + errRet = nil + break + } else { + errRet = ret + } + } + + if errRet != nil { + w.Write([]byte(errRet.Error())) + return + } + r.ParseMultipartForm(32 << 20) // max memory is set to 32MB resourceFile, resourceFileHeader, err := r.FormFile("file") if err != nil { @@ -190,6 +191,7 @@ func (slf *HttpServerService) staticServer(w http.ResponseWriter, r *http.Reques defer localfd.Close() io.Copy(localfd, resourceFile) + writeResp(http.StatusOK, upath+fileName) }