mirror of
https://github.com/duanhf2012/origin.git
synced 2026-02-04 15:04:45 +08:00
405 lines
8.2 KiB
Go
405 lines
8.2 KiB
Go
package sysmodule
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/duanhf2012/origin/service"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
// SyncFun is a deferred unit of database work. Values of this type are queued
// on DBModule.syncExecuteFun and invoked by RunExecuteDBCoroutine.
type SyncFun func()
|
|
|
|
// MAX_EXECUTE_FUN is the buffer capacity of the DBModule.syncExecuteFun queue,
// i.e. the maximum number of deferred database tasks that may be pending.
const MAX_EXECUTE_FUN = 10000
|
|
|
|
// DBModule ...
|
|
type DBModule struct {
|
|
service.BaseModule
|
|
db *sql.DB
|
|
url string
|
|
username string
|
|
password string
|
|
dbname string
|
|
maxconn int
|
|
|
|
syncExecuteFun chan SyncFun
|
|
syncCoroutineNum int
|
|
}
|
|
|
|
// DBResult carries the outcome of a Query or Exec call.
type DBResult struct {
	// Err is the error produced by the database call, if any.
	Err error
	// LastInsertID and RowsAffected are filled in by Exec.
	LastInsertID int64
	RowsAffected int64
	// res is the row set produced by Query; it stays nil for Exec results.
	res *sql.Rows

	// Decoding-related settings used by UnMarshal.

	// tag is the struct-tag key consulted to find the column name of a field.
	tag string
	// blur, when true, lets struct fields without a matching column be
	// skipped instead of causing an error.
	blur bool
}
|
|
|
|
func (slf *DBModule) Init(maxConn int, url string, userName string, password string, dbname string) error {
|
|
slf.url = url
|
|
slf.maxconn = maxConn
|
|
slf.username = userName
|
|
slf.password = password
|
|
slf.dbname = dbname
|
|
slf.syncExecuteFun = make(chan SyncFun, MAX_EXECUTE_FUN)
|
|
|
|
return slf.Connect(slf.maxconn)
|
|
}
|
|
|
|
func (slf *DBModule) OnInit() error {
|
|
for i := 0; i < slf.syncCoroutineNum; i++ {
|
|
go slf.RunExecuteDBCoroutine()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
//Close ...
|
|
func (slf *DBResult) Close() {
|
|
if slf.res != nil {
|
|
slf.res.Close()
|
|
}
|
|
}
|
|
|
|
//NextResult ...
|
|
func (slf *DBResult) NextResult() bool {
|
|
if slf.Err != nil || slf.res == nil {
|
|
return false
|
|
}
|
|
return slf.res.NextResultSet()
|
|
}
|
|
|
|
/*
|
|
// Next ...
|
|
func (slf *DBResult) Next() bool {
|
|
if slf.Err != nil {
|
|
return false
|
|
}
|
|
return slf.res.Next()
|
|
}
|
|
|
|
// Scan ...
|
|
func (slf *DBResult) Scan(arg ...interface{}) error {
|
|
if slf.Err != nil {
|
|
return slf.Err
|
|
}
|
|
return slf.res.Scan(arg...)
|
|
}
|
|
*/
|
|
// SetSpecificTag ...
|
|
func (slf *DBResult) SetSpecificTag(tag string) *DBResult {
|
|
slf.tag = tag
|
|
return slf
|
|
}
|
|
|
|
// SetBlurMode ...
|
|
func (slf *DBResult) SetBlurMode(blur bool) *DBResult {
|
|
slf.blur = blur
|
|
return slf
|
|
}
|
|
|
|
// UnMarshal ...
|
|
func (slf *DBResult) UnMarshal(out interface{}) error {
|
|
if slf.Err != nil {
|
|
return slf.Err
|
|
}
|
|
tbm, err := dbResult2Map(slf.res)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
//fmt.Println(tbm)
|
|
v := reflect.ValueOf(out)
|
|
if v.Kind() != reflect.Ptr {
|
|
return errors.New("interface must be a pointer")
|
|
}
|
|
if v.Elem().Kind() == reflect.Struct {
|
|
if len(tbm) != 1 {
|
|
return fmt.Errorf("数据结果集的长度不匹配 len=%d", len(tbm))
|
|
}
|
|
return slf.mapSingle2interface(tbm[0], v)
|
|
}
|
|
if v.Elem().Kind() == reflect.Slice {
|
|
return slf.mapSlice2interface(tbm, out)
|
|
}
|
|
return fmt.Errorf("错误的数据类型 %v", v.Elem().Kind())
|
|
}
|
|
|
|
// dbResult2Map reads every remaining row of rows into a slice of maps keyed
// by the lower-cased column name, with each value converted to a string.
//
// SQL NULL values are stored as the empty string. A scan failure or an
// iteration error aborts the read and is returned to the caller; on a scan
// failure the row set is closed before returning. A nil rows yields an error.
func dbResult2Map(rows *sql.Rows) ([]map[string]string, error) {
	if rows == nil {
		return nil, errors.New("dbResult2Map: nil result set")
	}

	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// Lower-case the column names once instead of once per row.
	keys := make([]string, len(columns))
	for i, col := range columns {
		keys[i] = strings.ToLower(col)
	}

	// sql.NullString accepts NULL, which a plain *string destination rejects.
	cells := make([]sql.NullString, len(columns))
	dest := make([]interface{}, len(columns))
	for i := range cells {
		dest[i] = &cells[i]
	}

	tableData := make([]map[string]string, 0)
	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			// Close explicitly because the early return stops iterating.
			rows.Close()
			return nil, err
		}
		entry := make(map[string]string, len(keys))
		for i, key := range keys {
			entry[key] = cells[i].String
		}
		tableData = append(tableData, entry)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return tableData, nil
}
|
|
|
|
func (slf *DBResult) mapSingle2interface(m map[string]string, v reflect.Value) error {
|
|
t := v.Type()
|
|
val := v.Elem()
|
|
typ := t.Elem()
|
|
|
|
if !val.IsValid() {
|
|
return errors.New("数据类型不正确")
|
|
}
|
|
|
|
for i := 0; i < val.NumField(); i++ {
|
|
value := val.Field(i)
|
|
kind := value.Kind()
|
|
tag := typ.Field(i).Tag.Get(slf.tag)
|
|
if tag == "" {
|
|
tag = typ.Field(i).Name
|
|
}
|
|
|
|
if tag != "" && tag != "-" {
|
|
vtag := strings.Split(strings.ToLower(tag), ",")
|
|
meta, ok := m[vtag[0]]
|
|
if !ok {
|
|
if !slf.blur {
|
|
return fmt.Errorf("没有在结果集中找到对应的字段 %s", tag)
|
|
}
|
|
continue
|
|
}
|
|
if !value.CanSet() {
|
|
return errors.New("结构体字段没有读写权限")
|
|
}
|
|
if len(meta) == 0 {
|
|
continue
|
|
}
|
|
switch kind {
|
|
case reflect.String:
|
|
value.SetString(meta)
|
|
case reflect.Float32, reflect.Float64:
|
|
f, err := strconv.ParseFloat(meta, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetFloat(f)
|
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
|
integer64, err := strconv.ParseInt(meta, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetInt(integer64)
|
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
|
integer64, err := strconv.ParseUint(meta, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetUint(integer64)
|
|
case reflect.Bool:
|
|
b, err := strconv.ParseBool(meta)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
value.SetBool(b)
|
|
default:
|
|
return errors.New("数据库映射存在不识别的数据类型")
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (slf *DBResult) mapSlice2interface(data []map[string]string, in interface{}) error {
|
|
length := len(data)
|
|
|
|
if length > 0 {
|
|
v := reflect.ValueOf(in).Elem()
|
|
newv := reflect.MakeSlice(v.Type(), 0, length)
|
|
v.Set(newv)
|
|
v.SetLen(length)
|
|
|
|
for i := 0; i < length; i++ {
|
|
idxv := v.Index(i)
|
|
if idxv.Kind() == reflect.Ptr {
|
|
newObj := reflect.New(idxv.Type().Elem())
|
|
v.Index(i).Set(newObj)
|
|
idxv = newObj
|
|
} else {
|
|
idxv = idxv.Addr()
|
|
}
|
|
err := slf.mapSingle2interface(data[i], idxv)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Connect ...
|
|
func (slf *DBModule) Connect(maxConn int) error {
|
|
cmd := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=true&loc=%s",
|
|
slf.username,
|
|
slf.password,
|
|
slf.url,
|
|
slf.dbname,
|
|
url.QueryEscape(time.Local.String()))
|
|
|
|
db, err := sql.Open("mysql", cmd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = db.Ping()
|
|
if err != nil {
|
|
db.Close()
|
|
return err
|
|
}
|
|
slf.db = db
|
|
db.SetMaxOpenConns(maxConn)
|
|
db.SetMaxIdleConns(maxConn)
|
|
db.SetConnMaxLifetime(time.Second * 90)
|
|
|
|
slf.syncCoroutineNum = maxConn
|
|
|
|
return nil
|
|
}
|
|
|
|
// SyncDBResult ...
|
|
type SyncDBResult struct {
|
|
sres chan DBResult
|
|
}
|
|
|
|
// Get ...
|
|
func (slf *SyncDBResult) Get(timeoutMs int) DBResult {
|
|
timerC := time.NewTicker(time.Millisecond * time.Duration(timeoutMs)).C
|
|
select {
|
|
case <-timerC:
|
|
break
|
|
case rsp := <-slf.sres:
|
|
return rsp
|
|
}
|
|
return DBResult{
|
|
Err: fmt.Errorf("Getting the return result timeout [%d]ms", timeoutMs),
|
|
}
|
|
}
|
|
|
|
// Query ...
|
|
func (slf *DBModule) Query(query string, args ...interface{}) DBResult {
|
|
if slf.db == nil {
|
|
ret := DBResult{}
|
|
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
|
|
ret.Err = fmt.Errorf("cannot connect database!")
|
|
return ret
|
|
}
|
|
rows, err := slf.db.Query(query, args...)
|
|
if err != nil {
|
|
service.GetLogger().Printf(service.LEVER_ERROR, "Query:%s(%v)", query, err)
|
|
}
|
|
|
|
return DBResult{
|
|
Err: err,
|
|
res: rows,
|
|
tag: "json",
|
|
blur: true,
|
|
}
|
|
}
|
|
|
|
// SyncQuery ...
|
|
func (slf *DBModule) SyncQuery(query string, args ...interface{}) SyncDBResult {
|
|
ret := SyncDBResult{
|
|
sres: make(chan DBResult, 1),
|
|
}
|
|
|
|
if len(slf.syncExecuteFun) >= MAX_EXECUTE_FUN {
|
|
dbret := DBResult{}
|
|
dbret.Err = fmt.Errorf("chan is full,sql:%s", query)
|
|
ret.sres <- dbret
|
|
|
|
return ret
|
|
}
|
|
|
|
slf.syncExecuteFun <- func() {
|
|
rsp := slf.Query(query, args...)
|
|
ret.sres <- rsp
|
|
}
|
|
|
|
return ret
|
|
}
|
|
|
|
// Exec ...
|
|
func (slf *DBModule) Exec(query string, args ...interface{}) DBResult {
|
|
ret := DBResult{}
|
|
if slf.db == nil {
|
|
service.GetLogger().Printf(service.LEVER_ERROR, "cannot connect database:%s", query)
|
|
ret.Err = fmt.Errorf("cannot connect database!")
|
|
return ret
|
|
}
|
|
|
|
res, err := slf.db.Exec(query, args...)
|
|
ret.Err = err
|
|
if err != nil {
|
|
service.GetLogger().Printf(service.LEVER_ERROR, "Exec:%s(%v)", query, err)
|
|
return ret
|
|
}
|
|
|
|
ret.LastInsertID, _ = res.LastInsertId()
|
|
ret.RowsAffected, _ = res.RowsAffected()
|
|
return ret
|
|
}
|
|
|
|
// SyncExec ...
|
|
func (slf *DBModule) SyncExec(query string, args ...interface{}) SyncDBResult {
|
|
ret := SyncDBResult{
|
|
sres: make(chan DBResult, 1),
|
|
}
|
|
|
|
if len(slf.syncExecuteFun) >= MAX_EXECUTE_FUN {
|
|
dbret := DBResult{}
|
|
dbret.Err = fmt.Errorf("chan is full,sql:%s", query)
|
|
ret.sres <- dbret
|
|
|
|
return ret
|
|
}
|
|
|
|
slf.syncExecuteFun <- func() {
|
|
rsp := slf.Exec(query, args...)
|
|
ret.sres <- rsp
|
|
}
|
|
|
|
return ret
|
|
}
|
|
|
|
func (slf *DBModule) RunExecuteDBCoroutine() {
|
|
slf.WaitGroup.Add(1)
|
|
defer slf.WaitGroup.Done()
|
|
for {
|
|
select {
|
|
case <-slf.ExitChan:
|
|
service.GetLogger().Printf(LEVER_WARN, "stopping module %s...", fmt.Sprintf("%T", slf))
|
|
return
|
|
case fun := <-slf.syncExecuteFun:
|
|
fun()
|
|
}
|
|
}
|
|
|
|
}
|