From e590f0dce9e601374c7a6f44c7094512f555486c Mon Sep 17 00:00:00 2001 From: duanhf2012 Date: Tue, 21 Apr 2020 14:40:31 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Ewebsocket?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- event/eventtype.go | 1 + example/GateService/GateService.go | 19 +++ example/config/cluster/subnet/cluster.json | 2 +- example/config/cluster/subnet/service.json | 6 + example/main.go | 3 +- network/processor/gobprocessor.go | 1 - network/processor/msgpackprocessor.go | 1 - network/processor/protobufprocessor.go | 1 - network/ws_client.go | 132 +++++++++++++++ network/ws_conn.go | 138 ++++++++++++++++ network/ws_server.go | 154 ++++++++++++++++++ sysservice/tcpservice.go | 1 - sysservice/wsservice.go | 178 +++++++++++++++++++++ 13 files changed, 631 insertions(+), 6 deletions(-) delete mode 100644 network/processor/gobprocessor.go delete mode 100644 network/processor/msgpackprocessor.go delete mode 100644 network/processor/protobufprocessor.go create mode 100644 network/ws_client.go create mode 100644 network/ws_conn.go create mode 100644 network/ws_server.go diff --git a/event/eventtype.go b/event/eventtype.go index 249ec9d..071e0b6 100644 --- a/event/eventtype.go +++ b/event/eventtype.go @@ -6,6 +6,7 @@ type EventType int const ( Sys_Event_Tcp EventType = 1 Sys_Event_Http_Event EventType = 2 + Sys_Event_WebSocket EventType = 3 Sys_Event_User_Define EventType = 1000 ) diff --git a/example/GateService/GateService.go b/example/GateService/GateService.go index ac4fcf1..a646dc9 100644 --- a/example/GateService/GateService.go +++ b/example/GateService/GateService.go @@ -14,6 +14,7 @@ import ( type GateService struct { service.Service processor *processor.PBProcessor + processor2 *processor.PBProcessor httpRouter sysservice.IHttpRouter } @@ -24,6 +25,15 @@ func (slf *GateService) OnInit() error{ slf.processor.RegisterConnected(slf.OnConnected) tcpervice.SetProcessor(slf.processor,slf.GetEventHandler()) + + wsService := node.GetService("WSService").(*sysservice.WSService) + slf.processor2 = &processor.PBProcessor{} + slf.processor2.RegisterDisConnected(slf.OnWSDisconnected) + slf.processor2.RegisterConnected(slf.OnWSConnected) + slf.processor2.Register() + wsService.SetProcessor(slf.processor2,slf.GetEventHandler()) + + httpervice := node.GetService("HttpService").(*sysservice.HttpService) slf.httpRouter = sysservice.NewHttpHttpRouter() httpervice.SetHttpRouter(slf.httpRouter,slf.GetEventHandler()) @@ -78,3 +88,12 @@ func (slf *GateService) OnConnected(clientid uint64){ func (slf *GateService) OnDisconnected(clientid uint64){ fmt.Printf("client id %d disconnected",clientid) } + +func (slf *GateService) OnWSConnected(clientid uint64){ + fmt.Printf("client id %d connected",clientid) +} + + +func (slf *GateService) OnWSDisconnected(clientid uint64){ + fmt.Printf("client id %d disconnected",clientid) +} diff --git a/example/config/cluster/subnet/cluster.json b/example/config/cluster/subnet/cluster.json index c51337c..673f938 100644 --- a/example/config/cluster/subnet/cluster.json +++ b/example/config/cluster/subnet/cluster.json @@ -5,7 +5,7 @@ "ListenAddr":"127.0.0.1:8001", "NodeName": "Node_Test1", "remark":"//以_打头的,表示只在本机进程,不对整个子网开发", - "ServiceList": ["TestService1","TestService2","TestServiceCall","GateService","TcpService","HttpService"] + "ServiceList": ["TestService1","TestService2","TestServiceCall","GateService","TcpService","HttpService","WSService"] } ] } \ No newline at end of file diff --git a/example/config/cluster/subnet/service.json b/example/config/cluster/subnet/service.json index c76c7e3..cee8ab3 100644 --- a/example/config/cluster/subnet/service.json +++ b/example/config/cluster/subnet/service.json @@ -19,5 +19,11 @@ "LittleEndian":false, "MinMsgLen":4, "MaxMsgLen":65535 + }, + "WSService":{ + "ListenAddr":"0.0.0.0:9031", + "MaxConnNum":3000, + "PendingWriteNum":10000, + "MaxMsgLen":65535 } } \ No newline at end of file diff --git a/example/main.go b/example/main.go index d5061e0..f5f13b9 100644 --- a/example/main.go +++ b/example/main.go @@ -305,8 +305,9 @@ func main(){ httpService := &sysservice.HttpService{} + wsService := &sysservice.WSService{} - node.Setup(tcpService,gateService,httpService) + node.Setup(tcpService,gateService,httpService,wsService) node.OpenProfilerReport(time.Second*10) node.Start() } diff --git a/network/processor/gobprocessor.go b/network/processor/gobprocessor.go deleted file mode 100644 index 95af6c9..0000000 --- a/network/processor/gobprocessor.go +++ /dev/null @@ -1 +0,0 @@ -package processor diff --git a/network/processor/msgpackprocessor.go b/network/processor/msgpackprocessor.go deleted file mode 100644 index 95af6c9..0000000 --- a/network/processor/msgpackprocessor.go +++ /dev/null @@ -1 +0,0 @@ -package processor diff --git a/network/processor/protobufprocessor.go b/network/processor/protobufprocessor.go deleted file mode 100644 index 95af6c9..0000000 --- a/network/processor/protobufprocessor.go +++ /dev/null @@ -1 +0,0 @@ -package processor diff --git a/network/ws_client.go b/network/ws_client.go new file mode 100644 index 0000000..c8cfd04 --- /dev/null +++ b/network/ws_client.go @@ -0,0 +1,132 @@ +package network + +import ( + "github.com/duanhf2012/origin/log" + "github.com/gorilla/websocket" + "sync" + "time" +) + +type WSClient struct { + sync.Mutex + Addr string + ConnNum int + ConnectInterval time.Duration + PendingWriteNum int + MaxMsgLen uint32 + HandshakeTimeout time.Duration + AutoReconnect bool + NewAgent func(*WSConn) Agent + dialer websocket.Dialer + conns WebsocketConnSet + wg sync.WaitGroup + closeFlag bool +} + +func (client *WSClient) Start() { + client.init() + + for i := 0; i < client.ConnNum; i++ { + client.wg.Add(1) + go client.connect() + } +} + +func (client *WSClient) init() { + client.Lock() + defer client.Unlock() + + if client.ConnNum <= 0 { + client.ConnNum = 1 + log.Release("invalid ConnNum, reset to %v", client.ConnNum) + } + if client.ConnectInterval <= 0 { + client.ConnectInterval = 3 * time.Second + log.Release("invalid ConnectInterval, reset to %v", client.ConnectInterval) + } + if client.PendingWriteNum <= 0 { + client.PendingWriteNum = 100 + log.Release("invalid PendingWriteNum, reset to %v", client.PendingWriteNum) + } + if client.MaxMsgLen <= 0 { + client.MaxMsgLen = 4096 + log.Release("invalid MaxMsgLen, reset to %v", client.MaxMsgLen) + } + if client.HandshakeTimeout <= 0 { + client.HandshakeTimeout = 10 * time.Second + log.Release("invalid HandshakeTimeout, reset to %v", client.HandshakeTimeout) + } + if client.NewAgent == nil { + log.Fatal("NewAgent must not be nil") + } + if client.conns != nil { + log.Fatal("client is running") + } + + client.conns = make(WebsocketConnSet) + client.closeFlag = false + client.dialer = websocket.Dialer{ + HandshakeTimeout: client.HandshakeTimeout, + } +} + +func (client *WSClient) dial() *websocket.Conn { + for { + conn, _, err := client.dialer.Dial(client.Addr, nil) + if err == nil || client.closeFlag { + return conn + } + + log.Release("connect to %v error: %v", client.Addr, err) + time.Sleep(client.ConnectInterval) + continue + } +} + +func (client *WSClient) connect() { + defer client.wg.Done() + +reconnect: + conn := client.dial() + if conn == nil { + return + } + conn.SetReadLimit(int64(client.MaxMsgLen)) + + client.Lock() + if client.closeFlag { + client.Unlock() + conn.Close() + return + } + client.conns[conn] = struct{}{} + client.Unlock() + + wsConn := newWSConn(conn, client.PendingWriteNum, client.MaxMsgLen) + agent := client.NewAgent(wsConn) + agent.Run() + + // cleanup + wsConn.Close() + client.Lock() + delete(client.conns, conn) + client.Unlock() + agent.OnClose() + + if client.AutoReconnect { + time.Sleep(client.ConnectInterval) + goto reconnect + } +} + +func (client *WSClient) Close() { + client.Lock() + client.closeFlag = true + for conn := range client.conns { + conn.Close() + } + client.conns = nil + client.Unlock() + + client.wg.Wait() +} diff --git a/network/ws_conn.go b/network/ws_conn.go new file mode 100644 index 0000000..5dfa598 --- /dev/null +++ b/network/ws_conn.go @@ -0,0 +1,138 @@ +package network + +import ( + "errors" + "github.com/duanhf2012/origin/log" + "github.com/gorilla/websocket" + "net" + "sync" +) + +type WebsocketConnSet map[*websocket.Conn]struct{} + +type WSConn struct { + sync.Mutex + conn *websocket.Conn + writeChan chan []byte + maxMsgLen uint32 + closeFlag bool +} + +func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32) *WSConn { + wsConn := new(WSConn) + wsConn.conn = conn + wsConn.writeChan = make(chan []byte, pendingWriteNum) + wsConn.maxMsgLen = maxMsgLen + + go func() { + for b := range wsConn.writeChan { + if b == nil { + break + } + + err := conn.WriteMessage(websocket.BinaryMessage, b) + if err != nil { + break + } + } + + conn.Close() + wsConn.Lock() + wsConn.closeFlag = true + wsConn.Unlock() + }() + + return wsConn +} + +func (wsConn *WSConn) doDestroy() { + wsConn.conn.UnderlyingConn().(*net.TCPConn).SetLinger(0) + wsConn.conn.Close() + + if !wsConn.closeFlag { + close(wsConn.writeChan) + wsConn.closeFlag = true + } +} + +func (wsConn *WSConn) Destroy() { + wsConn.Lock() + defer wsConn.Unlock() + + wsConn.doDestroy() +} + +func (wsConn *WSConn) Close() { + wsConn.Lock() + defer wsConn.Unlock() + if wsConn.closeFlag { + return + } + + wsConn.doWrite(nil) + wsConn.closeFlag = true +} + +func (wsConn *WSConn) doWrite(b []byte) { + if len(wsConn.writeChan) == cap(wsConn.writeChan) { + log.Debug("close conn: channel full") + wsConn.doDestroy() + return + } + + wsConn.writeChan <- b +} + +func (wsConn *WSConn) LocalAddr() net.Addr { + return wsConn.conn.LocalAddr() +} + +func (wsConn *WSConn) RemoteAddr() net.Addr { + return wsConn.conn.RemoteAddr() +} + +// goroutine not safe +func (wsConn *WSConn) ReadMsg() ([]byte, error) { + _, b, err := wsConn.conn.ReadMessage() + return b, err +} + +// args must not be modified by the others goroutines +func (wsConn *WSConn) WriteMsg(args ...[]byte) error { + wsConn.Lock() + defer wsConn.Unlock() + if wsConn.closeFlag { + return nil + } + + // get len + var msgLen uint32 + for i := 0; i < len(args); i++ { + msgLen += uint32(len(args[i])) + } + + // check len + if msgLen > wsConn.maxMsgLen { + return errors.New("message too long") + } else if msgLen < 1 { + return errors.New("message too short") + } + + // don't copy + if len(args) == 1 { + wsConn.doWrite(args[0]) + return nil + } + + // merge the args + msg := make([]byte, msgLen) + l := 0 + for i := 0; i < len(args); i++ { + copy(msg[l:], args[i]) + l += len(args[i]) + } + + wsConn.doWrite(msg) + + return nil +} diff --git a/network/ws_server.go b/network/ws_server.go new file mode 100644 index 0000000..58aa754 --- /dev/null +++ b/network/ws_server.go @@ -0,0 +1,154 @@ +package network + +import ( + "crypto/tls" + "github.com/duanhf2012/origin/log" + "github.com/gorilla/websocket" + "net" + "net/http" + "sync" + "time" +) + +type WSServer struct { + Addr string + MaxConnNum int + PendingWriteNum int + MaxMsgLen uint32 + HTTPTimeout time.Duration + CertFile string + KeyFile string + NewAgent func(*WSConn) Agent + ln net.Listener + handler *WSHandler +} + +type WSHandler struct { + maxConnNum int + pendingWriteNum int + maxMsgLen uint32 + newAgent func(*WSConn) Agent + upgrader websocket.Upgrader + conns WebsocketConnSet + mutexConns sync.Mutex + wg sync.WaitGroup +} + +func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + http.Error(w, "Method not allowed", 405) + return + } + conn, err := handler.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Debug("upgrade error: %v", err) + return + } + conn.SetReadLimit(int64(handler.maxMsgLen)) + + handler.wg.Add(1) + defer handler.wg.Done() + + handler.mutexConns.Lock() + if handler.conns == nil { + handler.mutexConns.Unlock() + conn.Close() + return + } + if len(handler.conns) >= handler.maxConnNum { + handler.mutexConns.Unlock() + conn.Close() + log.Debug("too many connections") + return + } + handler.conns[conn] = struct{}{} + handler.mutexConns.Unlock() + + wsConn := newWSConn(conn, handler.pendingWriteNum, handler.maxMsgLen) + agent := handler.newAgent(wsConn) + agent.Run() + + // cleanup + wsConn.Close() + handler.mutexConns.Lock() + delete(handler.conns, conn) + handler.mutexConns.Unlock() + agent.OnClose() +} + +func (server *WSServer) Start() { + ln, err := net.Listen("tcp", server.Addr) + if err != nil { + log.Fatal("%v", err) + } + + if server.MaxConnNum <= 0 { + server.MaxConnNum = 100 + log.Release("invalid MaxConnNum, reset to %v", server.MaxConnNum) + } + if server.PendingWriteNum <= 0 { + server.PendingWriteNum = 100 + log.Release("invalid PendingWriteNum, reset to %v", server.PendingWriteNum) + } + if server.MaxMsgLen <= 0 { + server.MaxMsgLen = 4096 + log.Release("invalid MaxMsgLen, reset to %v", server.MaxMsgLen) + } + if server.HTTPTimeout <= 0 { + server.HTTPTimeout = 10 * time.Second + log.Release("invalid HTTPTimeout, reset to %v", server.HTTPTimeout) + } + if server.NewAgent == nil { + log.Fatal("NewAgent must not be nil") + } + + if server.CertFile != "" || server.KeyFile != "" { + config := &tls.Config{} + config.NextProtos = []string{"http/1.1"} + + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(server.CertFile, server.KeyFile) + if err != nil { + log.Fatal("%v", err) + } + + ln = tls.NewListener(ln, config) + } + + server.ln = ln + server.handler = &WSHandler{ + maxConnNum: server.MaxConnNum, + pendingWriteNum: server.PendingWriteNum, + maxMsgLen: server.MaxMsgLen, + newAgent: server.NewAgent, + conns: make(WebsocketConnSet), + upgrader: websocket.Upgrader{ + HandshakeTimeout: server.HTTPTimeout, + CheckOrigin: func(_ *http.Request) bool { return true }, + }, + } + + httpServer := &http.Server{ + Addr: server.Addr, + Handler: server.handler, + ReadTimeout: server.HTTPTimeout, + WriteTimeout: server.HTTPTimeout, + MaxHeaderBytes: 1024, + } + + go httpServer.Serve(ln) +} + +func (server *WSServer) Close() { + server.ln.Close() + + server.handler.mutexConns.Lock() + for conn := range server.handler.conns { + conn.Close() + } + server.handler.conns = nil + server.handler.mutexConns.Unlock() + + server.handler.wg.Wait() +} diff --git a/sysservice/tcpservice.go b/sysservice/tcpservice.go index 422bc62..427b9e3 100644 --- a/sysservice/tcpservice.go +++ b/sysservice/tcpservice.go @@ -13,7 +13,6 @@ type TcpService struct { tcpServer network.TCPServer service.Service - tcpService *TcpService mapClientLocker sync.RWMutex mapClient map[uint64] *Client initClientId uint64 diff --git a/sysservice/wsservice.go b/sysservice/wsservice.go index ce47f20..8b02918 100644 --- a/sysservice/wsservice.go +++ b/sysservice/wsservice.go @@ -1,4 +1,182 @@ package sysservice +import ( + "fmt" + "github.com/duanhf2012/origin/event" + "github.com/duanhf2012/origin/log" + "github.com/duanhf2012/origin/network" + "github.com/duanhf2012/origin/service" + "sync" +) + +type WSService struct { + service.Service + wsServer network.WSServer + + mapClientLocker sync.RWMutex + mapClient map[uint64] *WSClient + initClientId uint64 + process network.Processor +} + +type WSPackType int8 +const( + WPT_Connected WSPackType = 0 + WPT_DisConnected WSPackType = 1 + WPT_Pack WSPackType = 2 + WPT_UnknownPack WSPackType = 3 +) + +type WSPack struct { + Type WSPackType //0表示连接 1表示断开 2表示数据 + MsgProcessor network.Processor + ClientId uint64 + Data interface{} +} +const Default_WS_MaxConnNum = 3000 +const Default_WS_PendingWriteNum = 10000 +const Default_WS_MaxMsgLen = 65535 + + +func (slf *WSService) OnInit() error{ + iConfig := slf.GetServiceCfg() + if iConfig == nil { + return fmt.Errorf("%s service config is error!",slf.GetName()) + } + wsCfg := iConfig.(map[string]interface{}) + addr,ok := wsCfg["ListenAddr"] + if ok == false { + return fmt.Errorf("%s service config is error!",slf.GetName()) + } + + slf.wsServer.Addr = addr.(string) + slf.wsServer.MaxConnNum = Default_WS_MaxConnNum + slf.wsServer.PendingWriteNum = Default_WS_PendingWriteNum + slf.wsServer.MaxMsgLen = Default_WS_MaxMsgLen + MaxConnNum,ok := wsCfg["MaxConnNum"] + if ok == true { + slf.wsServer.MaxConnNum = int(MaxConnNum.(float64)) + } + PendingWriteNum,ok := wsCfg["PendingWriteNum"] + if ok == true { + slf.wsServer.PendingWriteNum = int(PendingWriteNum.(float64)) + } + + MaxMsgLen,ok := wsCfg["MaxMsgLen"] + if ok == true { + slf.wsServer.MaxMsgLen = uint32(MaxMsgLen.(float64)) + } + + slf.mapClient = make( map[uint64] *WSClient,slf.wsServer.MaxConnNum) + slf.wsServer.NewAgent =slf.NewWSClient + slf.wsServer.Start() + return nil +} + +func (slf *WSService) WSEventHandler(ev *event.Event) { + pack := ev.Data.(*WSPack) + switch pack.Type { + case WPT_Connected: + pack.MsgProcessor.ConnectedRoute(pack.ClientId) + case WPT_DisConnected: + pack.MsgProcessor.DisConnectedRoute(pack.ClientId) + case WPT_UnknownPack: + pack.MsgProcessor.UnknownMsgRoute(pack.Data,pack.ClientId) + case WPT_Pack: + pack.MsgProcessor.MsgRoute(pack.Data, pack.ClientId) + } +} + +func (slf *WSService) SetProcessor(process network.Processor,handler event.IEventHandler){ + slf.process = process + slf.RegEventReciverFunc(event.Sys_Event_WebSocket,handler,slf.WSEventHandler) +} + +func (slf *WSService) NewWSClient(conn *network.WSConn) network.Agent { + slf.mapClientLocker.Lock() + defer slf.mapClientLocker.Unlock() + + for { + slf.initClientId+=1 + _,ok := slf.mapClient[slf.initClientId] + if ok == true { + continue + } + + pClient := &WSClient{wsConn:conn, id:slf.initClientId} + pClient.wsService = slf + slf.mapClient[slf.initClientId] = pClient + return pClient + } + + return nil +} + +type WSClient struct { + id uint64 + wsConn *network.WSConn + wsService *WSService +} + +func (slf *WSClient) GetId() uint64 { + return slf.id +} + +func (slf *WSClient) Run() { + slf.wsService.NotifyEvent(&event.Event{Type:event.Sys_Event_WebSocket,Data:&WSPack{ClientId:slf.id,Type:WPT_Connected,MsgProcessor:slf.wsService.process}}) + for{ + bytes,err := slf.wsConn.ReadMsg() + if err != nil { + log.Debug("read client id %d is error:%+v",slf.id,err) + break + } + data,err:=slf.wsService.process.Unmarshal(bytes) + if err != nil { + slf.wsService.NotifyEvent(&event.Event{Type:event.Sys_Event_WebSocket,Data:&WSPack{ClientId:slf.id,Type:WPT_UnknownPack,Data:bytes,MsgProcessor:slf.wsService.process}}) + continue + } + slf.wsService.NotifyEvent(&event.Event{Type:event.Sys_Event_WebSocket,Data:&WSPack{ClientId:slf.id,Type:WPT_Pack,Data:data,MsgProcessor:slf.wsService.process}}) + } +} + +func (slf *WSClient) OnClose(){ + slf.wsService.NotifyEvent(&event.Event{Type:event.Sys_Event_WebSocket,Data:&WSPack{ClientId:slf.id,Type:WPT_DisConnected,MsgProcessor:slf.wsService.process}}) + slf.wsService.mapClientLocker.Lock() + defer slf.wsService.mapClientLocker.Unlock() + delete (slf.wsService.mapClient,slf.GetId()) +} + +func (slf *WSService) SendMsg(clientid uint64,msg interface{}) error{ + slf.mapClientLocker.Lock() + client,ok := slf.mapClient[clientid] + if ok == false{ + slf.mapClientLocker.Unlock() + return fmt.Errorf("client %d is disconnect!",clientid) + } + + slf.mapClientLocker.Unlock() + bytes,err := slf.process.Marshal(msg) + if err != nil { + return err + } + return client.wsConn.WriteMsg(bytes) +} + +func (slf *WSService) Close(clientid uint64) { + slf.mapClientLocker.Lock() + defer slf.mapClientLocker.Unlock() + + client,ok := slf.mapClient[clientid] + if ok == false{ + return + } + + if client.wsConn!=nil { + client.wsConn.Close() + } + + return +} +