diff --git a/sysmodule/DBModule.go b/sysmodule/DBModule.go new file mode 100644 index 0000000..d274c97 --- /dev/null +++ b/sysmodule/DBModule.go @@ -0,0 +1,303 @@ +package sysmodule + +import ( + "database/sql" + "errors" + "fmt" + "net/url" + "reflect" + "strconv" + "strings" + "time" +) + +// DBModule ... +type DBModule struct { + db *sql.DB + URL string + UserName string + Password string + DBName string +} + +// DBResult ... +type DBResult struct { + Err error + LastInsertID int64 + RowsAffected int64 + res *sql.Rows +} + +// Next ... +func (slf *DBResult) Next() bool { + if slf.Err != nil { + return false + } + return slf.res.Next() +} + +// Scan ... +func (slf *DBResult) Scan(arg ...interface{}) error { + if slf.Err != nil { + return slf.Err + } + return slf.res.Scan(arg...) +} + +// SQLDecoder ... +type SQLDecoder struct { + res DBResult + tag string + strict bool +} + +// NewSQLDecoder ... +func NewSQLDecoder(res DBResult) *SQLDecoder { + return &SQLDecoder{ + res: res, + tag: "col", + } +} + +// SetSpecificTag ... +func (slf *SQLDecoder) SetSpecificTag(tag string) *SQLDecoder { + slf.tag = tag + return slf +} + +// SetStrictMode ... +func (slf *SQLDecoder) SetStrictMode(strict bool) *SQLDecoder { + slf.strict = strict + return slf +} + +// UnMarshal ... +func (slf *SQLDecoder) UnMarshal(out interface{}) error { + if slf.res.Err != nil { + return slf.res.Err + } + tbm, err := dbResult2Map(slf.res.res) + if err != nil { + return err + } + fmt.Println(tbm) + v := reflect.ValueOf(out) + if v.Kind() != reflect.Ptr { + return errors.New("interface must be a pointer") + } + if v.Elem().Kind() == reflect.Struct { + if len(tbm) != 1 { + return fmt.Errorf("数据结果集的长度不匹配 len=%d", len(tbm)) + } + return slf.mapSingle2interface(tbm[0], v) + } + if v.Elem().Kind() == reflect.Slice { + return slf.mapSlice2interface(tbm, out) + } + return fmt.Errorf("错误的数据类型 %v", v.Elem().Kind()) +} + +func dbResult2Map(rows *sql.Rows) ([]map[string]string, error) { + columns, err := rows.Columns() + if err != nil { + return nil, err + } + count := len(columns) + tableData := make([]map[string]string, 0) + values := make([]string, count) + valuePtrs := make([]interface{}, count) + for rows.Next() { + for i := 0; i < count; i++ { + valuePtrs[i] = &values[i] + } + err := rows.Scan(valuePtrs...) + if err != nil { + fmt.Println(err) + } + entry := make(map[string]string) + for i, col := range columns { + entry[strings.ToLower(col)] = values[i] + } + tableData = append(tableData, entry) + } + return tableData, nil +} + +func (slf *SQLDecoder) mapSingle2interface(m map[string]string, v reflect.Value) error { + t := v.Type() + val := v.Elem() + typ := t.Elem() + + if !val.IsValid() { + return errors.New("数据类型不正确") + } + + for i := 0; i < val.NumField(); i++ { + value := val.Field(i) + kind := value.Kind() + tag := typ.Field(i).Tag.Get(slf.tag) + if tag == "" { + tag = typ.Field(i).Name + } + + if tag != "" && tag != "-" { + vtag := strings.Split(strings.ToLower(tag), ",") + meta, ok := m[vtag[0]] + if !ok { + if slf.strict { + return fmt.Errorf("没有在结果集中找到对应的字段 %s", tag) + } + continue + } + if !value.CanSet() { + return errors.New("结构体字段没有读写权限") + } + if len(meta) == 0 { + continue + } + switch kind { + case reflect.String: + value.SetString(meta) + case reflect.Float32, reflect.Float64: + f, err := strconv.ParseFloat(meta, 64) + if err != nil { + return err + } + value.SetFloat(f) + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + integer64, err := strconv.ParseInt(meta, 10, 64) + if err != nil { + return err + } + value.SetInt(integer64) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + integer64, err := strconv.ParseUint(meta, 10, 64) + if err != nil { + return err + } + value.SetUint(integer64) + case reflect.Bool: + b, err := strconv.ParseBool(meta) + if err != nil { + return err + } + value.SetBool(b) + default: + return errors.New("数据库映射存在不识别的数据类型") + } + } + } + return nil +} + +func (slf *SQLDecoder) mapSlice2interface(data []map[string]string, in interface{}) error { + length := len(data) + + if length > 0 { + v := reflect.ValueOf(in).Elem() + newv := reflect.MakeSlice(v.Type(), 0, length) + v.Set(newv) + v.SetLen(length) + + for i := 0; i < length; i++ { + idxv := v.Index(i) + if idxv.Kind() == reflect.Ptr { + newObj := reflect.New(idxv.Type().Elem()) + v.Index(i).Set(newObj) + idxv = newObj + } else { + idxv = idxv.Addr() + } + err := slf.mapSingle2interface(data[i], idxv) + if err != nil { + return err + } + } + } + return nil +} + +// Connect ... +func (slf *DBModule) Connect() error { + cmd := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true&loc=%s", + slf.UserName, + slf.Password, + slf.URL, + slf.DBName, + url.QueryEscape(time.Local.String())) + + db, err := sql.Open("mysql", cmd) + if err != nil { + return err + } + + err = db.Ping() + if err != nil { + db.Close() + return err + } + slf.db = db + return nil +} + +// SyncDBResult ... +type SyncDBResult struct { + sres chan DBResult +} + +// Get ... +func (slf *SyncDBResult) Get(timeoutMs int) DBResult { + timerC := time.NewTicker(time.Millisecond * time.Duration(timeoutMs)).C + select { + case <-timerC: + break + case rsp := <-slf.sres: + return rsp + } + return DBResult{ + Err: fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs), + } +} + +// Query ... +func (slf *DBModule) Query(query string, args ...interface{}) DBResult { + rows, err := slf.db.Query(query, args...) + return DBResult{ + Err: err, + res: rows, + } +} + +// SyncQuery ... +func (slf *DBModule) SyncQuery(query string, args ...interface{}) SyncDBResult { + ret := SyncDBResult{ + sres: make(chan DBResult, 1), + } + go func() { + rsp := slf.Query(query, args...) + ret.sres <- rsp + }() + return ret +} + +// Exec ... +func (slf *DBModule) Exec(query string, args ...interface{}) DBResult { + ret := DBResult{} + res, err := slf.db.Exec(query, args...) + ret.Err = err + ret.LastInsertID, _ = res.LastInsertId() + ret.RowsAffected, _ = res.RowsAffected() + return ret +} + +// SyncExec ... +func (slf *DBModule) SyncExec(query string, args ...interface{}) SyncDBResult { + ret := SyncDBResult{ + sres: make(chan DBResult, 1), + } + go func() { + rsp := slf.Exec(query, args...) + ret.sres <- rsp + }() + return ret +} diff --git a/sysmodule/HttpClientPoolModule.go b/sysmodule/HttpClientPoolModule.go index 656c9b9..cda6bfe 100644 --- a/sysmodule/HttpClientPoolModule.go +++ b/sysmodule/HttpClientPoolModule.go @@ -25,10 +25,10 @@ type HttpRespone struct { } type SyncHttpRespone struct { - resp chan *HttpRespone + resp chan HttpRespone } -func (slf *SyncHttpRespone) Get(timeoutMs int) *HttpRespone { +func (slf *SyncHttpRespone) Get(timeoutMs int) HttpRespone { timerC := time.NewTicker(time.Millisecond * time.Duration(timeoutMs)).C select { case <-timerC: @@ -36,7 +36,7 @@ func (slf *SyncHttpRespone) Get(timeoutMs int) *HttpRespone { case rsp := <-slf.resp: return rsp } - return &HttpRespone{ + return HttpRespone{ Err: fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs), } } @@ -57,7 +57,7 @@ func (slf *HttpClientPoolModule) Init(maxpool int) { func (slf *HttpClientPoolModule) SyncRequest(method string, url string, body []byte) SyncHttpRespone { ret := SyncHttpRespone{ - resp: make(chan *HttpRespone, 1), + resp: make(chan HttpRespone, 1), } go func() { rsp := slf.Request(method, url, body) @@ -66,11 +66,11 @@ func (slf *HttpClientPoolModule) SyncRequest(method string, url string, body []b return ret } -func (slf *HttpClientPoolModule) Request(method string, url string, body []byte) *HttpRespone { +func (slf *HttpClientPoolModule) Request(method string, url string, body []byte) HttpRespone { if slf.client == nil { panic("Call the init function first") } - ret := &HttpRespone{} + ret := HttpRespone{} req, err := http.NewRequest(method, url, bytes.NewReader(body)) if err != nil { ret.Err = err diff --git a/sysmodule/HttpClientPoolModule_test.go b/sysmodule/HttpClientPoolModule_test.go index 61360f0..ce4a699 100644 --- a/sysmodule/HttpClientPoolModule_test.go +++ b/sysmodule/HttpClientPoolModule_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/duanhf2012/origin/sysmodule" + _ "github.com/go-sql-driver/mysql" ) func TestHttpClientPoolModule(t *testing.T) { @@ -27,3 +28,34 @@ func TestHttpClientPoolModule(t *testing.T) { fmt.Println(rsp1.Status) fmt.Println(string(rsp1.Body)) } + +func TestDBModule(t *testing.T) { + db := sysmodule.DBModule{ + URL: "192.168.0.5:3306", + UserName: "root", + Password: "Root!!2018", + DBName: "QuantFundsDB", + } + db.Connect() + + res := db.Query("select * from tbl_fun_heelthrow where id >= 1") + if res.Err != nil { + t.Error(res.Err) + } + out := []struct { + Addtime int64 `json:"addtime"` + Tname string `json:"tname"` + Uuid string `json:"uuid,omitempty"` + AAAA string `json:"-"` + }{} + err := sysmodule.NewSQLDecoder(res).SetSpecificTag("json").SetStrictMode(true).UnMarshal(&out) + if err != nil { + t.Error(err) + } + + sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1") + res = sres.Get(1000) + if res.Err != nil { + t.Error(res.Err) + } +}