提交origin2.0版本

This commit is contained in:
duanhf2012
2020-03-28 09:57:16 +08:00
parent 0d98f77d07
commit 84fb8ab36d
111 changed files with 3657 additions and 8382 deletions

6
network/agent.go Normal file
View File

@@ -0,0 +1,6 @@
package network
type Agent interface {
Run()
OnClose()
}

14
network/conn.go Normal file
View File

@@ -0,0 +1,14 @@
package network
import (
"net"
)
type Conn interface {
ReadMsg() ([]byte, error)
WriteMsg(args ...[]byte) error
LocalAddr() net.Addr
RemoteAddr() net.Addr
Close()
Destroy()
}

View File

@@ -1,98 +0,0 @@
package network
import (
"crypto/tls"
"fmt"
"net/http"
"os"
"time"
"github.com/duanhf2012/origin/service"
"github.com/duanhf2012/origin/sysmodule"
)
type CA struct {
certfile string
keyfile string
}
type HttpServer struct {
port uint16
handler http.Handler
readtimeout time.Duration
writetimeout time.Duration
httpserver *http.Server
caList []CA
ishttps bool
}
func (slf *HttpServer) Init(port uint16, handler http.Handler, readtimeout time.Duration, writetimeout time.Duration) {
slf.port = port
slf.handler = handler
slf.readtimeout = readtimeout
slf.writetimeout = writetimeout
}
func (slf *HttpServer) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
http.HandleFunc(pattern, handler)
}
func (slf *HttpServer) Start() {
go slf.startListen()
}
func (slf *HttpServer) startListen() error {
listenPort := fmt.Sprintf(":%d", slf.port)
var tlscatList []tls.Certificate
var tlsConfig *tls.Config
for _, cadata := range slf.caList {
cer, err := tls.LoadX509KeyPair(cadata.certfile, cadata.keyfile)
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_FATAL, "load CA [%s]-[%s] file is error :%s", cadata.certfile, cadata.keyfile, err.Error())
os.Exit(1)
return nil
}
tlscatList = append(tlscatList, cer)
}
if len(tlscatList) > 0 {
tlsConfig = &tls.Config{Certificates: tlscatList}
}
slf.httpserver = &http.Server{
Addr: listenPort,
Handler: slf.handler,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
TLSConfig: tlsConfig,
}
var err error
if slf.ishttps == true {
err = slf.httpserver.ListenAndServeTLS("", "")
} else {
err = slf.httpserver.ListenAndServe()
}
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_FATAL, "http.ListenAndServe(%d, nil) error:%v\n", listenPort, err)
fmt.Printf("http.ListenAndServe(%d, %v) error\n", slf.port, err)
os.Exit(1)
}
return nil
}
func (slf *HttpServer) SetHttps(certfile string, keyfile string) bool {
if certfile == "" || keyfile == "" {
return false
}
slf.caList = append(slf.caList, CA{certfile, keyfile})
slf.ishttps = true
return true
}

7
network/netserver.go Normal file
View File

@@ -0,0 +1,7 @@
package network
type inetserver interface {
}

10
network/processor.go Normal file
View File

@@ -0,0 +1,10 @@
package network
type Processor interface {
// must goroutine safe
Route(msg interface{}, userData interface{}) error
// must goroutine safe
Unmarshal(data []byte) (interface{}, error)
// must goroutine safe
Marshal(msg interface{}) ([]byte, error)
}

View File

@@ -0,0 +1 @@
package processor

View File

@@ -0,0 +1,17 @@
package processor
type JsonProcessor struct {
//SetByteOrder(littleEndian bool)
//SetMsgLen(lenMsgLen int, minMsgLen uint32, maxMsgLen uint32)
}
func (slf *JsonProcessor) Unmarshal(data []byte) (interface{}, error) {
return nil,nil
}
func (slf *JsonProcessor) Marshal(msg interface{}) ([][]byte, error) {
return nil,nil
}

View File

@@ -0,0 +1 @@
package processor

View File

@@ -0,0 +1,11 @@
package processor
type IProcessor interface {
//SetByteOrder(littleEndian bool)
//SetMsgLen(lenMsgLen int, minMsgLen uint32, maxMsgLen uint32)
Unmarshal(data []byte) (interface{}, error)
// must goroutine safe
Marshal(msg interface{}) ([][]byte, error)
}

View File

@@ -0,0 +1 @@
package processor

130
network/tcp_client.go Normal file
View File

@@ -0,0 +1,130 @@
package network
import (
"github.com/duanhf2012/originnet/log"
"net"
"sync"
"time"
)
type TCPClient struct {
sync.Mutex
Addr string
ConnNum int
ConnectInterval time.Duration
PendingWriteNum int
AutoReconnect bool
NewAgent func(*TCPConn) Agent
conns ConnSet
wg sync.WaitGroup
closeFlag bool
// msg parser
LenMsgLen int
MinMsgLen uint32
MaxMsgLen uint32
LittleEndian bool
msgParser *MsgParser
}
func (client *TCPClient) Start() {
client.init()
for i := 0; i < client.ConnNum; i++ {
client.wg.Add(1)
go client.connect()
}
}
func (client *TCPClient) init() {
client.Lock()
defer client.Unlock()
if client.ConnNum <= 0 {
client.ConnNum = 1
log.Release("invalid ConnNum, reset to %v", client.ConnNum)
}
if client.ConnectInterval <= 0 {
client.ConnectInterval = 3 * time.Second
log.Release("invalid ConnectInterval, reset to %v", client.ConnectInterval)
}
if client.PendingWriteNum <= 0 {
client.PendingWriteNum = 100
log.Release("invalid PendingWriteNum, reset to %v", client.PendingWriteNum)
}
if client.NewAgent == nil {
log.Fatal("NewAgent must not be nil")
}
if client.conns != nil {
log.Fatal("client is running")
}
client.conns = make(ConnSet)
client.closeFlag = false
// msg parser
msgParser := NewMsgParser()
msgParser.SetMsgLen(client.LenMsgLen, client.MinMsgLen, client.MaxMsgLen)
msgParser.SetByteOrder(client.LittleEndian)
client.msgParser = msgParser
}
func (client *TCPClient) dial() net.Conn {
for {
conn, err := net.Dial("tcp", client.Addr)
if err == nil || client.closeFlag {
return conn
}
log.Release("connect to %v error: %v", client.Addr, err)
time.Sleep(client.ConnectInterval)
continue
}
}
func (client *TCPClient) connect() {
defer client.wg.Done()
reconnect:
conn := client.dial()
if conn == nil {
return
}
client.Lock()
if client.closeFlag {
client.Unlock()
conn.Close()
return
}
client.conns[conn] = struct{}{}
client.Unlock()
tcpConn := newTCPConn(conn, client.PendingWriteNum, client.msgParser)
agent := client.NewAgent(tcpConn)
agent.Run()
// cleanup
tcpConn.Close()
client.Lock()
delete(client.conns, conn)
client.Unlock()
agent.OnClose()
if client.AutoReconnect {
time.Sleep(client.ConnectInterval)
goto reconnect
}
}
func (client *TCPClient) Close() {
client.Lock()
client.closeFlag = true
for conn := range client.conns {
conn.Close()
}
client.conns = nil
client.Unlock()
client.wg.Wait()
}

113
network/tcp_conn.go Normal file
View File

@@ -0,0 +1,113 @@
package network
import (
"github.com/duanhf2012/originnet/log"
"net"
"sync"
)
type ConnSet map[net.Conn]struct{}
type TCPConn struct {
sync.Mutex
conn net.Conn
writeChan chan []byte
closeFlag bool
msgParser *MsgParser
}
func newTCPConn(conn net.Conn, pendingWriteNum int, msgParser *MsgParser) *TCPConn {
tcpConn := new(TCPConn)
tcpConn.conn = conn
tcpConn.writeChan = make(chan []byte, pendingWriteNum)
tcpConn.msgParser = msgParser
go func() {
for b := range tcpConn.writeChan {
if b == nil {
break
}
_, err := conn.Write(b)
if err != nil {
break
}
}
conn.Close()
tcpConn.Lock()
tcpConn.closeFlag = true
tcpConn.Unlock()
}()
return tcpConn
}
func (tcpConn *TCPConn) doDestroy() {
tcpConn.conn.(*net.TCPConn).SetLinger(0)
tcpConn.conn.Close()
if !tcpConn.closeFlag {
close(tcpConn.writeChan)
tcpConn.closeFlag = true
}
}
func (tcpConn *TCPConn) Destroy() {
tcpConn.Lock()
defer tcpConn.Unlock()
tcpConn.doDestroy()
}
func (tcpConn *TCPConn) Close() {
tcpConn.Lock()
defer tcpConn.Unlock()
if tcpConn.closeFlag {
return
}
tcpConn.doWrite(nil)
tcpConn.closeFlag = true
}
func (tcpConn *TCPConn) doWrite(b []byte) {
if len(tcpConn.writeChan) == cap(tcpConn.writeChan) {
log.Debug("close conn: channel full")
tcpConn.doDestroy()
return
}
tcpConn.writeChan <- b
}
// b must not be modified by the others goroutines
func (tcpConn *TCPConn) Write(b []byte) {
tcpConn.Lock()
defer tcpConn.Unlock()
if tcpConn.closeFlag || b == nil {
return
}
tcpConn.doWrite(b)
}
func (tcpConn *TCPConn) Read(b []byte) (int, error) {
return tcpConn.conn.Read(b)
}
func (tcpConn *TCPConn) LocalAddr() net.Addr {
return tcpConn.conn.LocalAddr()
}
func (tcpConn *TCPConn) RemoteAddr() net.Addr {
return tcpConn.conn.RemoteAddr()
}
func (tcpConn *TCPConn) ReadMsg() ([]byte, error) {
return tcpConn.msgParser.Read(tcpConn)
}
func (tcpConn *TCPConn) WriteMsg(args ...[]byte) error {
return tcpConn.msgParser.Write(tcpConn, args...)
}

154
network/tcp_msg.go Normal file
View File

@@ -0,0 +1,154 @@
package network
import (
"encoding/binary"
"errors"
"io"
"math"
)
// --------------
// | len | data |
// --------------
type MsgParser struct {
lenMsgLen int
minMsgLen uint32
maxMsgLen uint32
littleEndian bool
}
func NewMsgParser() *MsgParser {
p := new(MsgParser)
p.lenMsgLen = 2
p.minMsgLen = 1
p.maxMsgLen = 4096
p.littleEndian = false
return p
}
// It's dangerous to call the method on reading or writing
func (p *MsgParser) SetMsgLen(lenMsgLen int, minMsgLen uint32, maxMsgLen uint32) {
if lenMsgLen == 1 || lenMsgLen == 2 || lenMsgLen == 4 {
p.lenMsgLen = lenMsgLen
}
if minMsgLen != 0 {
p.minMsgLen = minMsgLen
}
if maxMsgLen != 0 {
p.maxMsgLen = maxMsgLen
}
var max uint32
switch p.lenMsgLen {
case 1:
max = math.MaxUint8
case 2:
max = math.MaxUint16
case 4:
max = math.MaxUint32
}
if p.minMsgLen > max {
p.minMsgLen = max
}
if p.maxMsgLen > max {
p.maxMsgLen = max
}
}
// It's dangerous to call the method on reading or writing
func (p *MsgParser) SetByteOrder(littleEndian bool) {
p.littleEndian = littleEndian
}
// goroutine safe
func (p *MsgParser) Read(conn *TCPConn) ([]byte, error) {
var b [4]byte
bufMsgLen := b[:p.lenMsgLen]
// read len
if _, err := io.ReadFull(conn, bufMsgLen); err != nil {
return nil, err
}
// parse len
var msgLen uint32
switch p.lenMsgLen {
case 1:
msgLen = uint32(bufMsgLen[0])
case 2:
if p.littleEndian {
msgLen = uint32(binary.LittleEndian.Uint16(bufMsgLen))
} else {
msgLen = uint32(binary.BigEndian.Uint16(bufMsgLen))
}
case 4:
if p.littleEndian {
msgLen = binary.LittleEndian.Uint32(bufMsgLen)
} else {
msgLen = binary.BigEndian.Uint32(bufMsgLen)
}
}
// check len
if msgLen > p.maxMsgLen {
return nil, errors.New("message too long")
} else if msgLen < p.minMsgLen {
return nil, errors.New("message too short")
}
// data
msgData := make([]byte, msgLen)
if _, err := io.ReadFull(conn, msgData); err != nil {
return nil, err
}
return msgData, nil
}
// goroutine safe
func (p *MsgParser) Write(conn *TCPConn, args ...[]byte) error {
// get len
var msgLen uint32
for i := 0; i < len(args); i++ {
msgLen += uint32(len(args[i]))
}
// check len
if msgLen > p.maxMsgLen {
return errors.New("message too long")
} else if msgLen < p.minMsgLen {
return errors.New("message too short")
}
msg := make([]byte, uint32(p.lenMsgLen)+msgLen)
// write len
switch p.lenMsgLen {
case 1:
msg[0] = byte(msgLen)
case 2:
if p.littleEndian {
binary.LittleEndian.PutUint16(msg, uint16(msgLen))
} else {
binary.BigEndian.PutUint16(msg, uint16(msgLen))
}
case 4:
if p.littleEndian {
binary.LittleEndian.PutUint32(msg, msgLen)
} else {
binary.BigEndian.PutUint32(msg, msgLen)
}
}
// write data
l := p.lenMsgLen
for i := 0; i < len(args); i++ {
copy(msg[l:], args[i])
l += len(args[i])
}
conn.Write(msg)
return nil
}

127
network/tcp_server.go Normal file
View File

@@ -0,0 +1,127 @@
package network
import (
"github.com/duanhf2012/originnet/log"
"net"
"sync"
"time"
)
type TCPServer struct {
Addr string
MaxConnNum int
PendingWriteNum int
NewAgent func(*TCPConn) Agent
ln net.Listener
conns ConnSet
mutexConns sync.Mutex
wgLn sync.WaitGroup
wgConns sync.WaitGroup
// msg parser
LenMsgLen int
MinMsgLen uint32
MaxMsgLen uint32
LittleEndian bool
msgParser *MsgParser
}
func (server *TCPServer) Start() {
server.init()
go server.run()
}
func (server *TCPServer) init() {
ln, err := net.Listen("tcp", server.Addr)
if err != nil {
log.Fatal("%v", err)
}
if server.MaxConnNum <= 0 {
server.MaxConnNum = 100
log.Release("invalid MaxConnNum, reset to %v", server.MaxConnNum)
}
if server.PendingWriteNum <= 0 {
server.PendingWriteNum = 100
log.Release("invalid PendingWriteNum, reset to %v", server.PendingWriteNum)
}
if server.NewAgent == nil {
log.Fatal("NewAgent must not be nil")
}
server.ln = ln
server.conns = make(ConnSet)
// msg parser
msgParser := NewMsgParser()
msgParser.SetMsgLen(server.LenMsgLen, server.MinMsgLen, server.MaxMsgLen)
msgParser.SetByteOrder(server.LittleEndian)
server.msgParser = msgParser
}
func (server *TCPServer) run() {
server.wgLn.Add(1)
defer server.wgLn.Done()
var tempDelay time.Duration
for {
conn, err := server.ln.Accept()
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
log.Release("accept error: %v; retrying in %v", err, tempDelay)
time.Sleep(tempDelay)
continue
}
return
}
tempDelay = 0
server.mutexConns.Lock()
if len(server.conns) >= server.MaxConnNum {
server.mutexConns.Unlock()
conn.Close()
log.Debug("too many connections")
continue
}
server.conns[conn] = struct{}{}
server.mutexConns.Unlock()
server.wgConns.Add(1)
tcpConn := newTCPConn(conn, server.PendingWriteNum, server.msgParser)
agent := server.NewAgent(tcpConn)
go func() {
agent.Run()
// cleanup
tcpConn.Close()
server.mutexConns.Lock()
delete(server.conns, conn)
server.mutexConns.Unlock()
agent.OnClose()
server.wgConns.Done()
}()
}
}
func (server *TCPServer) Close() {
server.ln.Close()
server.wgLn.Wait()
server.mutexConns.Lock()
for conn := range server.conns {
conn.Close()
}
server.conns = nil
server.mutexConns.Unlock()
server.wgConns.Wait()
}

View File

@@ -1,48 +0,0 @@
package network
import (
"fmt"
"github.com/golang/protobuf/proto"
"net"
)
type TcpSocketClient struct {
conn net.Conn
addr string
}
func (slf *TcpSocketClient) Connect(addr string) error{
tcpAddr,terr := net.ResolveTCPAddr("tcp",addr)
if terr != nil {
return terr
}
conn,err := net.DialTCP("tcp",nil,tcpAddr)
if err!=nil {
fmt.Println("Client connect error ! " + err.Error())
return err
}
slf.conn = conn
slf.addr = addr
//
return nil
}
func (slf *TcpSocketClient) SendMsg(packtype uint16,message proto.Message) error{
if slf.conn == nil {
return fmt.Errorf("cannt connect %s",slf.addr)
}
var msg MsgBasePack
data,err := proto.Marshal(message)
if err != nil {
return err
}
msg.Make(packtype,data)
slf.conn.Write(msg.Bytes())
return nil
}

View File

@@ -1,331 +0,0 @@
package network
import (
"bufio"
"encoding/binary"
"fmt"
"github.com/duanhf2012/origin/service"
"github.com/duanhf2012/origin/util"
"github.com/golang/protobuf/proto"
"io"
"net"
"unsafe"
"os"
"time"
)
type ITcpSocketServerReciver interface {
OnConnected(pClient *SClient)
OnDisconnect(pClient *SClient)
OnRecvMsg(pClient *SClient, pPack *MsgBasePack)
}
type SClient struct {
id uint64
conn net.Conn
recvPack *util.SyncQueue
sendPack *util.SyncQueue
tcpserver *TcpSocketServer
remoteip string
starttime int64
bClose bool
}
type TcpSocketServer struct {
listenAddr string //ip:port
mapClient util.Map
MaxRecvPackSize uint16
MaxSendPackSize uint16
iReciver ITcpSocketServerReciver
nodelay bool
}
type MsgBasePack struct {
PackSize uint16
PackType uint16
Body []byte
StartTime time.Time
}
func (slf *TcpSocketServer) Register(listenAddr string,iReciver ITcpSocketServerReciver){
slf.listenAddr = listenAddr
slf.iReciver = iReciver
}
func (slf *TcpSocketServer) Start(){
slf.MaxRecvPackSize = 2048
slf.MaxSendPackSize = 40960
util.Go(slf.listenServer)
}
func (slf *TcpSocketServer) listenServer(){
slf.nodelay = true
listener, err := net.Listen("tcp", slf.listenAddr)
if err != nil {
service.GetLogger().Printf(service.LEVER_FATAL, "TcpSocketServer Listen error %+v",err)
os.Exit(1)
}
var clientId uint64
for {
conn, aerr := listener.Accept()
if aerr != nil {
service.GetLogger().Printf(service.LEVER_FATAL, "TcpSocketServer accept error %+v",aerr)
continue
}
if slf.nodelay {
//conn.(ifaceSetNoDelay)
}
for {
clientId += 1
if slf.mapClient.Get(clientId)!= nil {
continue
}
sc :=&SClient{id:clientId,conn:conn,tcpserver:slf,remoteip:conn.RemoteAddr().String(),starttime:time.Now().UnixNano(),
recvPack:util.NewSyncQueue(),sendPack:util.NewSyncQueue()}
slf.mapClient.Set(clientId,sc)
util.Go(sc.listendata)
//收来自客户端数据
util.Go(sc.onrecv)
//发送数据队列
util.Go(sc.onsend)
break
}
}
}
func (slf *TcpSocketServer) Close(clientid uint64) error {
pClient := slf.mapClient.Get(clientid)
if pClient == nil {
return fmt.Errorf("clientid %d is not in connect pool.",clientid)
}
pClient.(*SClient).Close()
return nil
}
func (slf *TcpSocketServer) SendMsg(clientid uint64,packtype uint16,message proto.Message) error{
pClient := slf.mapClient.Get(clientid)
if pClient == nil {
return fmt.Errorf("clientid %d is not in connect pool.",clientid)
}
return pClient.(*SClient).SendMsg(packtype,message)
}
func (slf *TcpSocketServer) Send(clientid uint64,pack *MsgBasePack) error{
pClient := slf.mapClient.Get(clientid)
if pClient == nil {
return fmt.Errorf("clientid %d is not in connect pool.",clientid)
}
return pClient.(*SClient).Send(pack)
}
func (slf *SClient) listendata(){
defer func() {
slf.Close()
slf.tcpserver.mapClient.Del(slf.id)
slf.tcpserver.iReciver.OnDisconnect(slf)
service.GetLogger().Printf(service.LEVER_DEBUG, "clent id %d return listendata...",slf.id)
}()
slf.tcpserver.iReciver.OnConnected(slf)
//获取一个连接的reader读取流
reader := bufio.NewReader(slf.conn)
//临时接受数据的buff
var buff []byte //tmprecvbuf
var tmpbuff []byte
var buffDataSize uint16
tmpbuff = make([]byte,2048)
//解析包数据
var pack MsgBasePack
for {
n,err := reader.Read(tmpbuff)
if err != nil || err == io.EOF {
service.GetLogger().Printf(service.LEVER_INFO, "clent id %d is disconnect %+v",slf.id,err)
return
}
buff = append(buff,tmpbuff[:n]...)
buffDataSize += uint16(n)
if buffDataSize> slf.tcpserver.MaxRecvPackSize {
service.GetLogger().Print(service.LEVER_WARN,"recv client id %d data size %d is over %d",slf.id,buffDataSize,slf.tcpserver.MaxRecvPackSize)
return
}
fillsize,bfillRet,fillhead := pack.FillData(buff,buffDataSize)
//提交校验头
if fillhead == true {
if pack.PackSize>slf.tcpserver.MaxRecvPackSize {
service.GetLogger().Printf(service.LEVER_WARN, "VerifyPackType error clent id %d is disconnect %d,%d",slf.id,pack.PackType, pack.PackSize)
return
}
}
if bfillRet == true {
slf.recvPack.Push(pack)
pack = MsgBasePack{}
}
if fillsize>0 {
buff = append(buff[fillsize:])
buffDataSize -= fillsize
}
}
}
func (slf *MsgBasePack) Bytes() []byte{
var bRet []byte
bRet = make([]byte,4)
binary.BigEndian.PutUint16(bRet,slf.PackSize)
binary.BigEndian.PutUint16(bRet[2:],slf.PackType)
bRet = append(bRet,slf.Body...)
return bRet
}
//返回值:填充多少字节,是否完成,是否填充头
func (slf *MsgBasePack) FillData(bdata []byte,datasize uint16) (uint16,bool,bool) {
var fillsize uint16
fillhead := false
//解包头
if slf.PackSize ==0 {
if datasize < 4 {
return 0,false,fillhead
}
slf.PackSize= binary.BigEndian.Uint16(bdata[:2])
slf.PackType= binary.BigEndian.Uint16(bdata[2:4])
fillsize += 4
fillhead = true
}
//解包体
if slf.PackSize>0 && datasize+4>=slf.PackSize {
slf.Body = append(slf.Body, bdata[fillsize:slf.PackSize]...)
fillsize += (slf.PackSize - fillsize)
return fillsize,true,fillhead
}
return fillsize,false,fillhead
}
func (slf *MsgBasePack) Clear() {
}
func (slf *MsgBasePack) Make(packtype uint16,data []byte) {
slf.PackType = packtype
slf.Body = data
slf.PackSize = uint16(unsafe.Sizeof(slf.PackType)*2)+uint16(len(data))
}
func (slf *SClient) Send(pack *MsgBasePack) error {
if slf.bClose == true {
return fmt.Errorf("client id %d is close!",slf.id)
}
slf.sendPack.Push(pack)
return nil
}
func (slf *SClient) SendMsg(packtype uint16,message proto.Message) error{
if slf.bClose == true {
return fmt.Errorf("client id %d is close!",slf.id)
}
var msg MsgBasePack
data,err := proto.Marshal(message)
if err != nil {
return err
}
msg.Make(packtype,data)
slf.sendPack.Push(&msg)
return nil
}
func (slf *SClient) onsend(){
defer func() {
slf.Close()
service.GetLogger().Printf(service.LEVER_DEBUG, "clent id %d return onsend...",slf.id)
}()
for {
pack,ok := slf.sendPack.TryPop()
if slf.bClose == true {
break
}
if ok == false || pack == nil {
time.Sleep(time.Millisecond*1)
continue
}
pPackData := pack.(*MsgBasePack)
_,e := slf.conn.Write(pPackData.Bytes())
if e!=nil {
service.GetLogger().Printf(service.LEVER_DEBUG, "clent id %d write error...",slf.id)
return
}
//fmt.Print("xxxxxxxxxxxxxxx:",n,e)
}
}
func (slf *SClient) onrecv(){
defer func() {
slf.Close()
service.GetLogger().Printf(service.LEVER_DEBUG, "clent id %d return onrecv...",slf.id)
}()
for {
pack,ok := slf.recvPack.TryPop()
if slf.bClose == true {
break
}
if ok == false || pack == nil {
time.Sleep(time.Millisecond*1)
continue
}
pMsg := pack.(MsgBasePack)
slf.tcpserver.iReciver.OnRecvMsg(slf,&pMsg)
}
}
func (slf *SClient) Close(){
if slf.bClose == false {
slf.conn.Close()
slf.bClose = true
slf.recvPack.Close()
slf.sendPack.Close()
}
}
func (slf *SClient) GetId() uint64{
return slf.id
}

View File

@@ -1,259 +0,0 @@
package network
import (
"errors"
"fmt"
"net/http"
"net/url"
"runtime/debug"
"github.com/duanhf2012/origin/service"
"github.com/duanhf2012/origin/sysmodule"
"github.com/gorilla/websocket"
"time"
)
//IWebsocketClient ...
type IWebsocketClient interface {
Init(slf IWebsocketClient, strurl, strProxyPath string, timeoutsec time.Duration) error
Start() error
WriteMessage(msg []byte) error
OnDisconnect() error
OnConnected() error
OnReadMessage(msg []byte) error
ReConnect()
}
//WebsocketClient ...
type WebsocketClient struct {
WsDailer *websocket.Dialer
conn *websocket.Conn
url string
state int //0未连接状态 1正在重连 2连接状态
bwritemsg chan []byte
closer chan bool
slf IWebsocketClient
timeoutsec time.Duration
bRun bool
ping []byte
}
const (
MAX_WRITE_MSG = 10240
)
//Init ...
func (ws *WebsocketClient) Init(slf IWebsocketClient, strurl, strProxyPath string, timeoutsec time.Duration) error {
ws.timeoutsec = timeoutsec
ws.slf = slf
if strProxyPath != "" {
proxy := func(_ *http.Request) (*url.URL, error) {
return url.Parse(strProxyPath)
}
if timeoutsec > 0 {
tosec := timeoutsec * time.Second
ws.WsDailer = &websocket.Dialer{Proxy: proxy, HandshakeTimeout: tosec}
} else {
ws.WsDailer = &websocket.Dialer{Proxy: proxy}
}
} else {
if timeoutsec > 0 {
tosec := timeoutsec * time.Second
ws.WsDailer = &websocket.Dialer{HandshakeTimeout: tosec}
} else {
ws.WsDailer = &websocket.Dialer{}
}
}
ws.url = strurl
ws.ping = []byte(`ping`)
return nil
}
func (ws *WebsocketClient) SetPing(ping string) {
ws.ping = []byte(ping)
}
//OnRun ...
func (ws *WebsocketClient) OnRun() error {
defer func() {
if r := recover(); r != nil {
coreInfo := string(debug.Stack())
coreInfo += "\n" + fmt.Sprintf("Core WebsocketClient url is %s. Core information is %v\n", ws.url, r)
service.GetLogger().Printf(service.LEVER_FATAL, coreInfo)
go ws.OnRun()
}
}()
for {
if ws.bRun == false {
break
}
if ws.state == 0 {
time.Sleep(2 * time.Second)
ws.StartConnect()
} else if ws.state == 1 {
ws.state = 0
close(ws.closer)
ws.conn.Close()
ws.slf.OnDisconnect()
} else if ws.state == 2 {
ws.conn.SetReadDeadline(time.Now().Add(ws.timeoutsec * time.Second))
_, message, err := ws.conn.ReadMessage()
if err != nil {
service.GetLogger().Printf(service.LEVER_WARN, "websocket client is disconnect [%s],information is %v", ws.url, err)
ws.conn.Close()
ws.state = 0
close(ws.closer)
ws.slf.OnDisconnect()
continue
}
ws.slf.OnReadMessage(message)
}
}
return nil
}
//StartConnect ...
func (ws *WebsocketClient) StartConnect() error {
var err error
ws.conn, _, err = ws.WsDailer.Dial(ws.url, nil)
service.GetLogger().Printf(sysmodule.LEVER_INFO, "connecting %s, %+v\n", ws.url, err)
if err != nil {
return err
}
ws.closer = make(chan bool)
ws.bwritemsg = make(chan []byte, MAX_WRITE_MSG)
ws.state = 2
ws.slf.OnConnected()
return nil
}
//Start ...
func (ws *WebsocketClient) Start() error {
if ws.bRun == false {
ws.bRun = true
ws.state = 0
go ws.OnRun()
go ws.writeMsg()
}
return nil
}
//触发
func (ws *WebsocketClient) writeMsg() error {
//dump处理
defer func() {
if r := recover(); r != nil {
coreInfo := string(debug.Stack())
coreInfo += "\n" + fmt.Sprintf("Core WebsocketClient url is %s. Core information is %v\n", ws.url, r)
service.GetLogger().Printf(service.LEVER_FATAL, coreInfo)
go ws.writeMsg()
}
}()
timerC := time.NewTicker(time.Second * 5).C
for {
if ws.bRun == false {
break
}
if ws.state == 0 {
time.Sleep(1 * time.Second)
continue
}
select {
case _, ok := <-ws.closer:
if ok == false {
break
}
case <-timerC:
if ws.state == 2 {
err := ws.WriteMessage(ws.ping)
if err != nil {
service.GetLogger().Printf(service.LEVER_WARN, "websocket client is disconnect [%s],information is %v", ws.url, err)
ws.state = 0
ws.conn.Close()
ws.slf.OnDisconnect()
}
}
case msg, ok := <-ws.bwritemsg:
if ok == true && ws.state == 2 {
ws.conn.SetWriteDeadline(time.Now().Add(ws.timeoutsec * time.Second))
err := ws.conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
service.GetLogger().Printf(service.LEVER_WARN, "websocket client is disconnect [%s],information is %v", ws.url, err)
ws.state = 0
ws.conn.Close()
ws.slf.OnDisconnect()
}
}
}
}
return nil
}
//ReConnect ...
func (ws *WebsocketClient) ReConnect() {
ws.state = 1
}
//WriteMessage ...
func (ws *WebsocketClient) WriteMessage(msg []byte) error {
if ws.closer == nil || ws.bwritemsg == nil {
service.GetLogger().Printf(service.LEVER_WARN, "WriteMessage data fail,websocket client is disconnect.")
return errors.New("riteMessage data fail,websocket client is disconnect.")
}
select {
case <-ws.closer:
service.GetLogger().Printf(service.LEVER_WARN, "WriteMessage data fail,websocket client is disconnect.")
return errors.New("riteMessage data fail,websocket client is disconnect.")
default:
if len(ws.bwritemsg) < MAX_WRITE_MSG {
ws.bwritemsg <- msg
} else {
service.GetLogger().Printf(service.LEVER_ERROR, "WriteMessage data fail,bwriteMsg is overload.")
return errors.New("WriteMessage data fail,bwriteMsg is overload.")
}
}
return nil
}
//OnDisconnect ...
func (ws *WebsocketClient) OnDisconnect() error {
return nil
}
//OnConnected ...
func (ws *WebsocketClient) OnConnected() error {
return nil
}
//OnReadMessage 触发
func (ws *WebsocketClient) OnReadMessage(msg []byte) error {
return nil
}
//Stop ...
func (ws *WebsocketClient) Stop() {
ws.bRun = false
}

View File

@@ -1,328 +0,0 @@
package network
import (
"crypto/tls"
"errors"
"fmt"
"github.com/duanhf2012/origin/util"
"net/http"
"os"
"runtime/debug"
"sync"
"time"
"github.com/duanhf2012/origin/service"
"github.com/duanhf2012/origin/sysmodule"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/gotoxu/cors"
)
type IWebsocketServer interface {
SendMsg(clientid uint64, messageType int, msg []byte) bool
CreateClient(conn *websocket.Conn) *WSClient
Disconnect(clientid uint64)
ReleaseClient(pclient *WSClient)
Clients() []uint64
BroadcastMsg(messageType int, msg []byte) int
}
type IMessageReceiver interface {
initReciver(messageReciver IMessageReceiver, websocketServer IWebsocketServer)
OnConnected(clientid uint64)
OnDisconnect(clientid uint64, err error)
OnRecvMsg(clientid uint64, msgtype int, data []byte)
OnHandleHttp(w http.ResponseWriter, r *http.Request)
IsInit() bool
}
type Reciver struct {
messageReciver IMessageReceiver
bEnableCompression bool
}
type BaseMessageReciver struct {
messageReciver IMessageReceiver
WsServer IWebsocketServer
}
type WSClient struct {
clientid uint64
conn *websocket.Conn
bwritemsg chan WSMessage
}
type WSMessage struct {
msgtype int
bwritemsg []byte
}
type WebsocketServer struct {
wsUri string
maxClientid uint64 //记录当前最新clientid
mapClient map[uint64]*WSClient
locker sync.RWMutex
port uint16
httpserver *http.Server
reciver map[string]Reciver
caList []CA
iswss bool
}
const (
MAX_MSG_COUNT = 20480
)
func (slf *WebsocketServer) Init(port uint16) {
slf.port = port
slf.mapClient = make(map[uint64]*WSClient)
}
func (slf *WebsocketServer) CreateClient(conn *websocket.Conn) *WSClient {
slf.locker.Lock()
slf.maxClientid++
clientid := slf.maxClientid
pclient := &WSClient{clientid, conn, make(chan WSMessage, MAX_MSG_COUNT+1)}
slf.mapClient[pclient.clientid] = pclient
slf.locker.Unlock()
service.GetLogger().Printf(sysmodule.LEVER_INFO, "Client id %d is connected.", clientid)
return pclient
}
func (slf *WebsocketServer) ReleaseClient(pclient *WSClient) {
pclient.conn.Close()
slf.locker.Lock()
delete(slf.mapClient, pclient.clientid)
slf.locker.Unlock()
//关闭写管道
close(pclient.bwritemsg)
service.GetLogger().Printf(sysmodule.LEVER_INFO, "Client id %d is disconnected.", pclient.clientid)
}
func (slf *WebsocketServer) SetupReciver(pattern string, messageReciver IMessageReceiver, bEnableCompression bool) {
messageReciver.initReciver(messageReciver, slf)
if slf.reciver == nil {
slf.reciver = make(map[string]Reciver)
}
slf.reciver[pattern] = Reciver{messageReciver, bEnableCompression}
}
func (slf *WebsocketServer) startListen() {
listenPort := fmt.Sprintf(":%d", slf.port)
var tlscatList []tls.Certificate
var tlsConfig *tls.Config
for _, cadata := range slf.caList {
cer, err := tls.LoadX509KeyPair(cadata.certfile, cadata.keyfile)
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_FATAL, "load CA %s-%s file is error :%s", cadata.certfile, cadata.keyfile, err.Error())
os.Exit(1)
return
}
tlscatList = append(tlscatList, cer)
}
if len(tlscatList) > 0 {
tlsConfig = &tls.Config{Certificates: tlscatList}
}
slf.httpserver = &http.Server{
Addr: listenPort,
Handler: slf.initRouterHandler(),
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
TLSConfig: tlsConfig,
}
var err error
if slf.iswss == true {
err = slf.httpserver.ListenAndServeTLS("", "")
} else {
err = slf.httpserver.ListenAndServe()
}
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_FATAL, "http.ListenAndServe(%d, nil) error:%v\n", slf.port, err)
os.Exit(1)
}
}
func (slf *WSClient) startSendMsg() {
for {
msgbuf, ok := <-slf.bwritemsg
if ok == false {
break
}
slf.conn.SetWriteDeadline(time.Now().Add(15 * time.Second))
err := slf.conn.WriteMessage(msgbuf.msgtype, msgbuf.bwritemsg)
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_INFO, "write client id %d is error :%v\n", slf.clientid, err)
break
}
}
}
func (slf *WebsocketServer) Start() {
go slf.startListen()
}
func (slf *WebsocketServer) Clients() []uint64 {
slf.locker.RLock()
defer slf.locker.RUnlock()
r := make([]uint64, 0, len(slf.mapClient))
for i, _ := range slf.mapClient {
r = append(r, i)
}
return r
}
func (slf *WebsocketServer) BroadcastMsg(messageType int, msg []byte) int {
slf.locker.RLock()
defer slf.locker.RUnlock()
err := 0
wsMsg := WSMessage{messageType, msg}
for _, value := range slf.mapClient {
if len(value.bwritemsg) >= MAX_MSG_COUNT {
service.GetLogger().Printf(sysmodule.LEVER_ERROR, "message chan is full :%d\n", len(value.bwritemsg))
err++
}
value.bwritemsg <- wsMsg
}
return err
}
func (slf *WebsocketServer) SendMsg(clientid uint64, messageType int, msg []byte) bool {
slf.locker.RLock()
defer slf.locker.RUnlock()
value, ok := slf.mapClient[clientid]
if ok == false {
return false
}
if len(value.bwritemsg) >= MAX_MSG_COUNT {
service.GetLogger().Printf(sysmodule.LEVER_ERROR, "message chan is full :%d\n", len(value.bwritemsg))
return false
}
value.bwritemsg <- WSMessage{messageType, msg}
return true
}
func (slf *WebsocketServer) Disconnect(clientid uint64) {
slf.locker.Lock()
defer slf.locker.Unlock()
value, ok := slf.mapClient[clientid]
if ok == false {
return
}
value.conn.Close()
}
func (slf *WebsocketServer) Stop() {
}
func (slf *BaseMessageReciver) startReadMsg(pclient *WSClient) {
defer func() {
if r := recover(); r != nil {
var coreInfo string
coreInfo = string(debug.Stack())
coreInfo += "\n" + fmt.Sprintf("Core information is %v\n", r)
service.GetLogger().Printf(service.LEVER_FATAL, coreInfo)
slf.messageReciver.OnDisconnect(pclient.clientid, errors.New("Core dump"))
slf.WsServer.ReleaseClient(pclient)
}
}()
var maxTimeStamp int64
var maxMsgType int
logMinMsgTime :=time.Millisecond*300
statisticsIntervalTm := util.Timer{}
statisticsIntervalTm.SetupTimer(1000 * 15)//15秒间隔
for {
pclient.conn.SetReadDeadline(time.Now().Add(15 * time.Second))
msgtype, message, err := pclient.conn.ReadMessage()
if err != nil {
slf.messageReciver.OnDisconnect(pclient.clientid, err)
slf.WsServer.ReleaseClient(pclient)
return
}
if statisticsIntervalTm.CheckTimeOut() {
service.GetLogger().Printf(service.LEVER_INFO, "MaxMsgtype:%d,diff:%d",maxMsgType,maxTimeStamp)
}
//记录处理时间
startRecvTm := time.Now().UnixNano()
slf.messageReciver.OnRecvMsg(pclient.clientid, msgtype, message)
diff := time.Now().UnixNano() - startRecvTm
if diff> maxTimeStamp{
maxTimeStamp = diff
maxMsgType = msgtype
}
if diff >= int64(logMinMsgTime) {
service.GetLogger().Printf(service.LEVER_WARN, "Process slowly MaxMsgtype:%d,diff:%d",maxMsgType,maxTimeStamp)
}
}
}
func (slf *BaseMessageReciver) initReciver(messageReciver IMessageReceiver, websocketServer IWebsocketServer) {
slf.messageReciver = messageReciver
slf.WsServer = websocketServer
}
func (slf *BaseMessageReciver) OnConnected(clientid uint64) {
}
func (slf *BaseMessageReciver) OnDisconnect(clientid uint64, err error) {
}
func (slf *BaseMessageReciver) OnRecvMsg(clientid uint64, msgtype int, data []byte) {
}
func (slf *BaseMessageReciver) OnHandleHttp(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Upgrade(w, r, w.Header(), 1024, 1024)
if err != nil {
http.Error(w, "Could not open websocket connection", http.StatusBadRequest)
return
}
pclient := slf.WsServer.CreateClient(conn)
slf.messageReciver.OnConnected(pclient.clientid)
go pclient.startSendMsg()
go slf.startReadMsg(pclient)
}
func (slf *WebsocketServer) initRouterHandler() http.Handler {
r := mux.NewRouter()
for pattern, reciver := range slf.reciver {
if reciver.messageReciver.IsInit() == true {
r.HandleFunc(pattern, reciver.messageReciver.OnHandleHttp)
}
}
cors := cors.AllowAll()
return cors.Handler(r)
}
func (slf *WebsocketServer) SetWSS(certfile string, keyfile string) bool {
if certfile == "" || keyfile == "" {
return false
}
slf.caList = append(slf.caList, CA{certfile, keyfile})
slf.iswss = true
return true
}

View File

@@ -1,323 +0,0 @@
package network
import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"os"
"reflect"
"runtime/debug"
"sync"
"time"
"github.com/duanhf2012/origin/service"
"github.com/duanhf2012/origin/sysmodule"
"github.com/gotoxu/cors"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
type IWSAgentServer interface {
SendMsg(agentid uint32, messageType int, msg []byte) bool
CreateAgent(urlPattern string, conn *websocket.Conn) IAgent
Disconnect(agentid uint32)
ReleaseAgent(iagent IAgent)
}
type IAgent interface {
initAgent(conn *websocket.Conn, agentid uint32, iagent IAgent, WSAgentServer IWSAgentServer)
startReadMsg()
startSendMsg()
OnConnected()
OnDisconnect(err error)
OnRecvMsg(msgtype int, data []byte)
//OnHandleHttp(w http.ResponseWriter, r *http.Request)
GetAgentId() uint32
getConn() *websocket.Conn
getWriteMsgChan() chan WSAgentMessage
}
type BaseAgent struct {
service.BaseModule
WsServer IWSAgentServer
agent IAgent
agentid uint32
conn *websocket.Conn
bwritemsg chan WSAgentMessage
iagent IAgent
}
type WSAgentMessage struct {
msgtype int
bwritemsg []byte
}
type WSAgentServer struct {
service.BaseModule
wsUri string
maxAgentid uint32 //记录当前最新agentid
//mapAgent map[uint32]IAgent
locker sync.Mutex
port uint16
httpserver *http.Server
regAgent map[string]reflect.Type
caList []CA
iswss bool
}
const (
MAX_AGENT_MSG_COUNT = 20480
)
func (slf *WSAgentServer) Init(port uint16) {
slf.port = port
}
func (slf *WSAgentServer) CreateAgent(urlPattern string, conn *websocket.Conn) IAgent {
slf.locker.Lock()
iAgent, ok := slf.regAgent[urlPattern]
if ok == false {
slf.locker.Unlock()
service.GetLogger().Printf(sysmodule.LEVER_WARN, "Cannot find %s pattern!", urlPattern)
return nil
}
v := reflect.New(iAgent).Elem().Addr().Interface()
if v == nil {
slf.locker.Unlock()
service.GetLogger().Printf(sysmodule.LEVER_WARN, "new %s pattern agent type is error!", urlPattern)
return nil
}
pModule := v.(service.IModule)
iagent := v.(IAgent)
slf.maxAgentid++
agentid := slf.maxAgentid
iagent.initAgent(conn, agentid, iagent, slf)
slf.AddModule(pModule)
slf.locker.Unlock()
service.GetLogger().Printf(sysmodule.LEVER_INFO, "Agent id %d is connected.", iagent.GetAgentId())
return iagent
}
func (slf *WSAgentServer) ReleaseAgent(iagent IAgent) {
iagent.getConn().Close()
slf.locker.Lock()
slf.ReleaseModule(iagent.GetAgentId())
//delete(slf.mapAgent, iagent.GetAgentId())
slf.locker.Unlock()
//关闭写管道
close(iagent.getWriteMsgChan())
service.GetLogger().Printf(sysmodule.LEVER_INFO, "Agent id %d is disconnected.", iagent.GetAgentId())
}
func (slf *WSAgentServer) SetupAgent(pattern string, agent IAgent, bEnableCompression bool) {
if slf.regAgent == nil {
slf.regAgent = make(map[string]reflect.Type)
}
slf.regAgent[pattern] = reflect.TypeOf(agent).Elem() //reflect.TypeOf(agent).Elem()
}
func (slf *WSAgentServer) startListen() {
listenPort := fmt.Sprintf(":%d", slf.port)
var tlscatList []tls.Certificate
var tlsConfig *tls.Config
for _, cadata := range slf.caList {
cer, err := tls.LoadX509KeyPair(cadata.certfile, cadata.keyfile)
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_FATAL, "load CA %s-%s file is error :%s", cadata.certfile, cadata.keyfile, err.Error())
os.Exit(1)
return
}
tlscatList = append(tlscatList, cer)
}
if len(tlscatList) > 0 {
tlsConfig = &tls.Config{Certificates: tlscatList}
}
slf.httpserver = &http.Server{
Addr: listenPort,
Handler: slf.initRouterHandler(),
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
TLSConfig: tlsConfig,
}
var err error
if slf.iswss == true {
err = slf.httpserver.ListenAndServeTLS("", "")
} else {
err = slf.httpserver.ListenAndServe()
}
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_FATAL, "http.ListenAndServe(%d, nil) error:%v\n", slf.port, err)
os.Exit(1)
}
}
func (slf *BaseAgent) startSendMsg() {
for {
msgbuf, ok := <-slf.bwritemsg
if ok == false {
break
}
slf.conn.SetWriteDeadline(time.Now().Add(15 * time.Second))
err := slf.conn.WriteMessage(msgbuf.msgtype, msgbuf.bwritemsg)
if err != nil {
service.GetLogger().Printf(sysmodule.LEVER_INFO, "write agent id %d is error :%v\n", slf.GetAgentId(), err)
break
}
}
}
func (slf *WSAgentServer) Start() {
go slf.startListen()
}
func (slf *WSAgentServer) GetAgentById(agentid uint32) IAgent {
pModule := slf.GetModuleById(agentid)
if pModule == nil {
service.GetLogger().Printf(sysmodule.LEVER_ERROR, "GetAgentById :%d is fail.\n", agentid)
return nil
}
return pModule.(IAgent)
}
func (slf *WSAgentServer) SendMsg(agentid uint32, messageType int, msg []byte) bool {
slf.locker.Lock()
defer slf.locker.Unlock()
iagent := slf.GetAgentById(agentid)
if iagent == nil {
return false
}
if len(iagent.getWriteMsgChan()) >= MAX_AGENT_MSG_COUNT {
service.GetLogger().Printf(sysmodule.LEVER_ERROR, "message chan is full :%d\n", len(iagent.getWriteMsgChan()))
return false
}
iagent.getWriteMsgChan() <- WSAgentMessage{messageType, msg}
return true
}
func (slf *WSAgentServer) Disconnect(agentid uint32) {
slf.locker.Lock()
defer slf.locker.Unlock()
iagent := slf.GetAgentById(agentid)
if iagent == nil {
return
}
iagent.getConn().Close()
}
func (slf *WSAgentServer) Stop() {
}
func (slf *BaseAgent) startReadMsg() {
defer func() {
if r := recover(); r != nil {
var coreInfo string
coreInfo = string(debug.Stack())
coreInfo += "\n" + fmt.Sprintf("Core information is %v\n", r)
service.GetLogger().Printf(service.LEVER_FATAL, coreInfo)
slf.agent.OnDisconnect(errors.New("Core dump"))
slf.WsServer.ReleaseAgent(slf.agent)
}
}()
slf.agent.OnConnected()
for {
slf.conn.SetReadDeadline(time.Now().Add(15 * time.Second))
msgtype, message, err := slf.conn.ReadMessage()
if err != nil {
slf.agent.OnDisconnect(err)
slf.WsServer.ReleaseAgent(slf.agent)
return
}
slf.agent.OnRecvMsg(msgtype, message)
}
}
func (slf *WSAgentServer) initRouterHandler() http.Handler {
r := mux.NewRouter()
for pattern, _ := range slf.regAgent {
r.HandleFunc(pattern, slf.OnHandleHttp)
}
cors := cors.AllowAll()
return cors.Handler(r)
}
func (slf *WSAgentServer) SetWSS(certfile string, keyfile string) bool {
if certfile == "" || keyfile == "" {
return false
}
slf.caList = append(slf.caList, CA{certfile, keyfile})
slf.iswss = true
return true
}
func (slf *BaseAgent) GetAgentId() uint32 {
return slf.agentid
}
func (slf *BaseAgent) initAgent(conn *websocket.Conn, agentid uint32, iagent IAgent, WSAgentServer IWSAgentServer) {
slf.agent = iagent
slf.WsServer = WSAgentServer
slf.bwritemsg = make(chan WSAgentMessage, MAX_AGENT_MSG_COUNT)
slf.agentid = agentid
slf.conn = conn
}
func (slf *BaseAgent) OnConnected() {
}
func (slf *BaseAgent) OnDisconnect(err error) {
}
func (slf *BaseAgent) OnRecvMsg(msgtype int, data []byte) {
}
func (slf *BaseAgent) getConn() *websocket.Conn {
return slf.conn
}
func (slf *BaseAgent) getWriteMsgChan() chan WSAgentMessage {
return slf.bwritemsg
}
func (slf *BaseAgent) SendMsg(agentid uint32, messageType int, msg []byte) bool {
return slf.WsServer.SendMsg(agentid, messageType, msg)
}
func (slf *WSAgentServer) OnHandleHttp(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Upgrade(w, r, w.Header(), 1024, 1024)
if err != nil {
http.Error(w, "Could not open websocket connection!", http.StatusBadRequest)
return
}
agent := slf.CreateAgent(r.URL.Path, conn)
fmt.Print(agent.GetAgentId())
slf.AddModule(agent.(service.IModule))
go agent.startSendMsg()
go agent.startReadMsg()
}