diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go index 6224c81..5b9a3ad 100644 --- a/sysmodule/DBModule.go +++ b/sysmodule/DBModule.go @@ -241,14 +241,12 @@ func (slf *DBResult) mapSingle2interface(m map[string]string, v reflect.Value) e return nil } - -func (slf *DBModule) SetQuerySlowTime(Time time.Duration){ +func (slf *DBModule) SetQuerySlowTime(Time time.Duration) { slf.PrintTime = Time } - -func (slf *DBModule) IsPrintTimeLog(Time time.Duration)bool{ - if slf.PrintTime != 0 && Time >= slf.PrintTime{ +func (slf *DBModule) IsPrintTimeLog(Time time.Duration) bool { + if slf.PrintTime != 0 && Time >= slf.PrintTime { return true } return false @@ -337,8 +335,52 @@ func (slf *SyncQueryDBResultEx) Get(timeoutMs int) (*DataSetList, error) { return nil, fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs) } +func (slf *DBModule) CheckArgs(args ...interface{}) error { + for _, val := range args { + if reflect.TypeOf(val).Kind() == reflect.String { + retVal := val.(string) + if strings.Contains(retVal, "-") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "&") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "=") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "%") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(retVal, "'") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "delete ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "truncate ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), " or ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + if strings.Contains(strings.ToLower(retVal), "from ") == true { + return fmt.Errorf("CheckArgs is error arg is %+v", retVal) + } + } + } + + return nil +} + // Query ... func (slf *DBModule) Query(query string, args ...interface{}) DBResult { + if slf.CheckArgs(args) != nil { + ret := DBResult{} + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + ret.Err = fmt.Errorf("CheckArgs is error!") + return ret + } + if slf.db == nil { ret := DBResult{} service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) @@ -363,6 +405,11 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e datasetList.tag = "json" datasetList.blur = true + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + return &datasetList, fmt.Errorf("CheckArgs is error!") + } + if slf.db == nil { service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query) return &datasetList, fmt.Errorf("cannot connect database!") @@ -372,7 +419,7 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e rows, err := slf.db.Query(query, args...) TimeFuncPass := time.Since(TimeFuncStart) - if slf.IsPrintTimeLog(TimeFuncPass) { + if slf.IsPrintTimeLog(TimeFuncPass) { service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { @@ -454,11 +501,17 @@ func (slf *DBModule) Exec(query string, args ...interface{}) (*DBResultEx, error return ret, fmt.Errorf("cannot connect database!") } + if slf.CheckArgs(args) != nil { + service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query) + //return ret, fmt.Errorf("cannot connect database!") + return ret, fmt.Errorf("CheckArgs is error!") + } + TimeFuncStart := time.Now() res, err := slf.db.Exec(query, args...) TimeFuncPass := time.Since(TimeFuncStart) if slf.IsPrintTimeLog(TimeFuncPass) { - service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v",TimeFuncPass,query,args) + service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args) } if err != nil { service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err)