mirror of
https://github.com/duanhf2012/origin.git
synced 2026-02-04 06:54:45 +08:00
添加一个mysql模块
This commit is contained in:
303
sysmodule/DBModule.go
Normal file
303
sysmodule/DBModule.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package sysmodule
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DBModule ...
|
||||
type DBModule struct {
|
||||
db *sql.DB
|
||||
URL string
|
||||
UserName string
|
||||
Password string
|
||||
DBName string
|
||||
}
|
||||
|
||||
// DBResult ...
|
||||
type DBResult struct {
|
||||
Err error
|
||||
LastInsertID int64
|
||||
RowsAffected int64
|
||||
res *sql.Rows
|
||||
}
|
||||
|
||||
// Next ...
|
||||
func (slf *DBResult) Next() bool {
|
||||
if slf.Err != nil {
|
||||
return false
|
||||
}
|
||||
return slf.res.Next()
|
||||
}
|
||||
|
||||
// Scan ...
|
||||
func (slf *DBResult) Scan(arg ...interface{}) error {
|
||||
if slf.Err != nil {
|
||||
return slf.Err
|
||||
}
|
||||
return slf.res.Scan(arg...)
|
||||
}
|
||||
|
||||
// SQLDecoder ...
|
||||
type SQLDecoder struct {
|
||||
res DBResult
|
||||
tag string
|
||||
strict bool
|
||||
}
|
||||
|
||||
// NewSQLDecoder ...
|
||||
func NewSQLDecoder(res DBResult) *SQLDecoder {
|
||||
return &SQLDecoder{
|
||||
res: res,
|
||||
tag: "col",
|
||||
}
|
||||
}
|
||||
|
||||
// SetSpecificTag ...
|
||||
func (slf *SQLDecoder) SetSpecificTag(tag string) *SQLDecoder {
|
||||
slf.tag = tag
|
||||
return slf
|
||||
}
|
||||
|
||||
// SetStrictMode ...
|
||||
func (slf *SQLDecoder) SetStrictMode(strict bool) *SQLDecoder {
|
||||
slf.strict = strict
|
||||
return slf
|
||||
}
|
||||
|
||||
// UnMarshal ...
|
||||
func (slf *SQLDecoder) UnMarshal(out interface{}) error {
|
||||
if slf.res.Err != nil {
|
||||
return slf.res.Err
|
||||
}
|
||||
tbm, err := dbResult2Map(slf.res.res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println(tbm)
|
||||
v := reflect.ValueOf(out)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return errors.New("interface must be a pointer")
|
||||
}
|
||||
if v.Elem().Kind() == reflect.Struct {
|
||||
if len(tbm) != 1 {
|
||||
return fmt.Errorf("数据结果集的长度不匹配 len=%d", len(tbm))
|
||||
}
|
||||
return slf.mapSingle2interface(tbm[0], v)
|
||||
}
|
||||
if v.Elem().Kind() == reflect.Slice {
|
||||
return slf.mapSlice2interface(tbm, out)
|
||||
}
|
||||
return fmt.Errorf("错误的数据类型 %v", v.Elem().Kind())
|
||||
}
|
||||
|
||||
func dbResult2Map(rows *sql.Rows) ([]map[string]string, error) {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count := len(columns)
|
||||
tableData := make([]map[string]string, 0)
|
||||
values := make([]string, count)
|
||||
valuePtrs := make([]interface{}, count)
|
||||
for rows.Next() {
|
||||
for i := 0; i < count; i++ {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
err := rows.Scan(valuePtrs...)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
entry := make(map[string]string)
|
||||
for i, col := range columns {
|
||||
entry[strings.ToLower(col)] = values[i]
|
||||
}
|
||||
tableData = append(tableData, entry)
|
||||
}
|
||||
return tableData, nil
|
||||
}
|
||||
|
||||
func (slf *SQLDecoder) mapSingle2interface(m map[string]string, v reflect.Value) error {
|
||||
t := v.Type()
|
||||
val := v.Elem()
|
||||
typ := t.Elem()
|
||||
|
||||
if !val.IsValid() {
|
||||
return errors.New("数据类型不正确")
|
||||
}
|
||||
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
value := val.Field(i)
|
||||
kind := value.Kind()
|
||||
tag := typ.Field(i).Tag.Get(slf.tag)
|
||||
if tag == "" {
|
||||
tag = typ.Field(i).Name
|
||||
}
|
||||
|
||||
if tag != "" && tag != "-" {
|
||||
vtag := strings.Split(strings.ToLower(tag), ",")
|
||||
meta, ok := m[vtag[0]]
|
||||
if !ok {
|
||||
if slf.strict {
|
||||
return fmt.Errorf("没有在结果集中找到对应的字段 %s", tag)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !value.CanSet() {
|
||||
return errors.New("结构体字段没有读写权限")
|
||||
}
|
||||
if len(meta) == 0 {
|
||||
continue
|
||||
}
|
||||
switch kind {
|
||||
case reflect.String:
|
||||
value.SetString(meta)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
f, err := strconv.ParseFloat(meta, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetFloat(f)
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
integer64, err := strconv.ParseInt(meta, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetInt(integer64)
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
integer64, err := strconv.ParseUint(meta, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetUint(integer64)
|
||||
case reflect.Bool:
|
||||
b, err := strconv.ParseBool(meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value.SetBool(b)
|
||||
default:
|
||||
return errors.New("数据库映射存在不识别的数据类型")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (slf *SQLDecoder) mapSlice2interface(data []map[string]string, in interface{}) error {
|
||||
length := len(data)
|
||||
|
||||
if length > 0 {
|
||||
v := reflect.ValueOf(in).Elem()
|
||||
newv := reflect.MakeSlice(v.Type(), 0, length)
|
||||
v.Set(newv)
|
||||
v.SetLen(length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
idxv := v.Index(i)
|
||||
if idxv.Kind() == reflect.Ptr {
|
||||
newObj := reflect.New(idxv.Type().Elem())
|
||||
v.Index(i).Set(newObj)
|
||||
idxv = newObj
|
||||
} else {
|
||||
idxv = idxv.Addr()
|
||||
}
|
||||
err := slf.mapSingle2interface(data[i], idxv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect ...
|
||||
func (slf *DBModule) Connect() error {
|
||||
cmd := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true&loc=%s",
|
||||
slf.UserName,
|
||||
slf.Password,
|
||||
slf.URL,
|
||||
slf.DBName,
|
||||
url.QueryEscape(time.Local.String()))
|
||||
|
||||
db, err := sql.Open("mysql", cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return err
|
||||
}
|
||||
slf.db = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// SyncDBResult ...
|
||||
type SyncDBResult struct {
|
||||
sres chan DBResult
|
||||
}
|
||||
|
||||
// Get ...
|
||||
func (slf *SyncDBResult) Get(timeoutMs int) DBResult {
|
||||
timerC := time.NewTicker(time.Millisecond * time.Duration(timeoutMs)).C
|
||||
select {
|
||||
case <-timerC:
|
||||
break
|
||||
case rsp := <-slf.sres:
|
||||
return rsp
|
||||
}
|
||||
return DBResult{
|
||||
Err: fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs),
|
||||
}
|
||||
}
|
||||
|
||||
// Query ...
|
||||
func (slf *DBModule) Query(query string, args ...interface{}) DBResult {
|
||||
rows, err := slf.db.Query(query, args...)
|
||||
return DBResult{
|
||||
Err: err,
|
||||
res: rows,
|
||||
}
|
||||
}
|
||||
|
||||
// SyncQuery ...
|
||||
func (slf *DBModule) SyncQuery(query string, args ...interface{}) SyncDBResult {
|
||||
ret := SyncDBResult{
|
||||
sres: make(chan DBResult, 1),
|
||||
}
|
||||
go func() {
|
||||
rsp := slf.Query(query, args...)
|
||||
ret.sres <- rsp
|
||||
}()
|
||||
return ret
|
||||
}
|
||||
|
||||
// Exec ...
|
||||
func (slf *DBModule) Exec(query string, args ...interface{}) DBResult {
|
||||
ret := DBResult{}
|
||||
res, err := slf.db.Exec(query, args...)
|
||||
ret.Err = err
|
||||
ret.LastInsertID, _ = res.LastInsertId()
|
||||
ret.RowsAffected, _ = res.RowsAffected()
|
||||
return ret
|
||||
}
|
||||
|
||||
// SyncExec ...
|
||||
func (slf *DBModule) SyncExec(query string, args ...interface{}) SyncDBResult {
|
||||
ret := SyncDBResult{
|
||||
sres: make(chan DBResult, 1),
|
||||
}
|
||||
go func() {
|
||||
rsp := slf.Exec(query, args...)
|
||||
ret.sres <- rsp
|
||||
}()
|
||||
return ret
|
||||
}
|
||||
@@ -25,10 +25,10 @@ type HttpRespone struct {
|
||||
}
|
||||
|
||||
type SyncHttpRespone struct {
|
||||
resp chan *HttpRespone
|
||||
resp chan HttpRespone
|
||||
}
|
||||
|
||||
func (slf *SyncHttpRespone) Get(timeoutMs int) *HttpRespone {
|
||||
func (slf *SyncHttpRespone) Get(timeoutMs int) HttpRespone {
|
||||
timerC := time.NewTicker(time.Millisecond * time.Duration(timeoutMs)).C
|
||||
select {
|
||||
case <-timerC:
|
||||
@@ -36,7 +36,7 @@ func (slf *SyncHttpRespone) Get(timeoutMs int) *HttpRespone {
|
||||
case rsp := <-slf.resp:
|
||||
return rsp
|
||||
}
|
||||
return &HttpRespone{
|
||||
return HttpRespone{
|
||||
Err: fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs),
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func (slf *HttpClientPoolModule) Init(maxpool int) {
|
||||
|
||||
func (slf *HttpClientPoolModule) SyncRequest(method string, url string, body []byte) SyncHttpRespone {
|
||||
ret := SyncHttpRespone{
|
||||
resp: make(chan *HttpRespone, 1),
|
||||
resp: make(chan HttpRespone, 1),
|
||||
}
|
||||
go func() {
|
||||
rsp := slf.Request(method, url, body)
|
||||
@@ -66,11 +66,11 @@ func (slf *HttpClientPoolModule) SyncRequest(method string, url string, body []b
|
||||
return ret
|
||||
}
|
||||
|
||||
func (slf *HttpClientPoolModule) Request(method string, url string, body []byte) *HttpRespone {
|
||||
func (slf *HttpClientPoolModule) Request(method string, url string, body []byte) HttpRespone {
|
||||
if slf.client == nil {
|
||||
panic("Call the init function first")
|
||||
}
|
||||
ret := &HttpRespone{}
|
||||
ret := HttpRespone{}
|
||||
req, err := http.NewRequest(method, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
ret.Err = err
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/duanhf2012/origin/sysmodule"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func TestHttpClientPoolModule(t *testing.T) {
|
||||
@@ -27,3 +28,34 @@ func TestHttpClientPoolModule(t *testing.T) {
|
||||
fmt.Println(rsp1.Status)
|
||||
fmt.Println(string(rsp1.Body))
|
||||
}
|
||||
|
||||
func TestDBModule(t *testing.T) {
|
||||
db := sysmodule.DBModule{
|
||||
URL: "192.168.0.5:3306",
|
||||
UserName: "root",
|
||||
Password: "Root!!2018",
|
||||
DBName: "QuantFundsDB",
|
||||
}
|
||||
db.Connect()
|
||||
|
||||
res := db.Query("select * from tbl_fun_heelthrow where id >= 1")
|
||||
if res.Err != nil {
|
||||
t.Error(res.Err)
|
||||
}
|
||||
out := []struct {
|
||||
Addtime int64 `json:"addtime"`
|
||||
Tname string `json:"tname"`
|
||||
Uuid string `json:"uuid,omitempty"`
|
||||
AAAA string `json:"-"`
|
||||
}{}
|
||||
err := sysmodule.NewSQLDecoder(res).SetSpecificTag("json").SetStrictMode(true).UnMarshal(&out)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
sres := db.SyncQuery("select * from tbl_fun_heelthrow where id >= 1")
|
||||
res = sres.Get(1000)
|
||||
if res.Err != nil {
|
||||
t.Error(res.Err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user