防止SQL注入

This commit is contained in:
boyce
2019-07-27 18:41:38 +08:00
parent 468c609481
commit 7bbfc9e21d

View File

@@ -241,14 +241,12 @@ func (slf *DBResult) mapSingle2interface(m map[string]string, v reflect.Value) e
return nil
}
func (slf *DBModule) SetQuerySlowTime(Time time.Duration){
func (slf *DBModule) SetQuerySlowTime(Time time.Duration) {
slf.PrintTime = Time
}
func (slf *DBModule) IsPrintTimeLog(Time time.Duration)bool{
if slf.PrintTime != 0 && Time >= slf.PrintTime{
func (slf *DBModule) IsPrintTimeLog(Time time.Duration) bool {
if slf.PrintTime != 0 && Time >= slf.PrintTime {
return true
}
return false
@@ -337,8 +335,52 @@ func (slf *SyncQueryDBResultEx) Get(timeoutMs int) (*DataSetList, error) {
return nil, fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs)
}
func (slf *DBModule) CheckArgs(args ...interface{}) error {
for _, val := range args {
if reflect.TypeOf(val).Kind() == reflect.String {
retVal := val.(string)
if strings.Contains(retVal, "-") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "&") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "=") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "%") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(retVal, "'") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "delete ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "truncate ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), " or ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
if strings.Contains(strings.ToLower(retVal), "from ") == true {
return fmt.Errorf("CheckArgs is error arg is %+v", retVal)
}
}
}
return nil
}
// Query ...
func (slf *DBModule) Query(query string, args ...interface{}) DBResult {
if slf.CheckArgs(args) != nil {
ret := DBResult{}
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
ret.Err = fmt.Errorf("CheckArgs is error!")
return ret
}
if slf.db == nil {
ret := DBResult{}
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
@@ -363,6 +405,11 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e
datasetList.tag = "json"
datasetList.blur = true
if slf.CheckArgs(args) != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
return &datasetList, fmt.Errorf("CheckArgs is error!")
}
if slf.db == nil {
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
return &datasetList, fmt.Errorf("cannot connect database!")
@@ -372,7 +419,7 @@ func (slf *DBModule) QueryEx(query string, args ...interface{}) (*DataSetList, e
rows, err := slf.db.Query(query, args...)
TimeFuncPass := time.Since(TimeFuncStart)
if slf.IsPrintTimeLog(TimeFuncPass) {
if slf.IsPrintTimeLog(TimeFuncPass) {
service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args)
}
if err != nil {
@@ -454,11 +501,17 @@ func (slf *DBModule) Exec(query string, args ...interface{}) (*DBResultEx, error
return ret, fmt.Errorf("cannot connect database!")
}
if slf.CheckArgs(args) != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "CheckArgs is error :%s", query)
//return ret, fmt.Errorf("cannot connect database!")
return ret, fmt.Errorf("CheckArgs is error!")
}
TimeFuncStart := time.Now()
res, err := slf.db.Exec(query, args...)
TimeFuncPass := time.Since(TimeFuncStart)
if slf.IsPrintTimeLog(TimeFuncPass) {
service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v",TimeFuncPass,query,args)
service.GetLogger().Printf(service.LEVER_INFO, "DBModule QueryEx Time %s , Query :%s , args :%+v", TimeFuncPass, query, args)
}
if err != nil {
service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err)