diff --git a/network/websocketserver.go b/network/websocketserver.go index 5ad3faf..da9e0ae 100644 --- a/network/websocketserver.go +++ b/network/websocketserver.go @@ -14,13 +14,27 @@ import ( type IWebsocketServer interface { SendMsg(clientid uint64, messageType int, msg []byte) bool + CreateClient(conn *websocket.Conn) *WSClient + ReleaseClient(pclient *WSClient) } type IMessageReceiver interface { - OnListen(webServer IWebsocketServer) + initReciver(messageReciver IMessageReceiver, websocketServer IWebsocketServer) + OnConnected(clientid uint64) OnDisconnect(clientid uint64, err error) OnRecvMsg(clientid uint64, msgtype int, data []byte) + OnHandleHttp(w http.ResponseWriter, r *http.Request) +} + +type Reciver struct { + messageReciver IMessageReceiver + bEnableCompression bool +} + +type BaseMessageReciver struct { + messageReciver IMessageReceiver + WsServer IWebsocketServer } type WSClient struct { @@ -38,59 +52,49 @@ type WebsocketServer struct { wsUri string maxClientid uint64 //记录当前最新clientid mapClient map[uint64]*WSClient + locker sync.Mutex - pattern string - port uint16 - bEnableCompression bool - locker sync.Mutex - messageReciver IMessageReceiver + port uint16 httpserver *http.Server + reciver map[string]Reciver } -func (slf *WebsocketServer) wsHandler(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Upgrade(w, r, w.Header(), 1024, 1024) - if err != nil { - http.Error(w, "Could not open websocket connection", http.StatusBadRequest) - } +func (slf *WebsocketServer) Init(port uint16) { - slf.maxClientid++ - pclient := &WSClient{slf.maxClientid, conn, make(chan WSMessage, 1024)} - slf.mapClient[pclient.clientid] = pclient - - slf.messageReciver.OnConnected(pclient.clientid) - go pclient.startSendMsg() - go slf.OnConnected(pclient) -} - -func (slf *WebsocketServer) OnConnected(pclient *WSClient) { - for { - msgtype, message, err := pclient.conn.ReadMessage() - if err != nil { - pclient.conn.Close() - slf.locker.Lock() - defer slf.locker.Unlock() - delete(slf.mapClient, pclient.clientid) - slf.messageReciver.OnDisconnect(pclient.clientid, err) - return - } - - slf.messageReciver.OnRecvMsg(pclient.clientid, msgtype, message) - } -} - -func (slf *WebsocketServer) Init(pattern string, port uint16, messageReciver IMessageReceiver, bEnableCompression bool) { - slf.pattern = pattern slf.port = port - slf.bEnableCompression = bEnableCompression slf.mapClient = make(map[uint64]*WSClient) - slf.messageReciver = messageReciver - //http.HandleFunc(slf.pattern, slf.wsHandler) +} + +func (slf *WebsocketServer) CreateClient(conn *websocket.Conn) *WSClient { + slf.locker.Lock() + slf.maxClientid++ + clientid := slf.maxClientid + pclient := &WSClient{clientid, conn, make(chan WSMessage, 1024)} + slf.mapClient[pclient.clientid] = pclient + slf.locker.Unlock() + + return pclient +} + +func (slf *WebsocketServer) ReleaseClient(pclient *WSClient) { + pclient.conn.Close() + slf.locker.Lock() + delete(slf.mapClient, pclient.clientid) + slf.locker.Unlock() +} + +func (slf *WebsocketServer) SetupReciver(pattern string, messageReciver IMessageReceiver, bEnableCompression bool) { + messageReciver.initReciver(messageReciver, slf) + + if slf.reciver == nil { + slf.reciver = make(map[string]Reciver) + } + slf.reciver[pattern] = Reciver{messageReciver, bEnableCompression} } func (slf *WebsocketServer) startListen() { - listenPort := fmt.Sprintf(":%d", slf.port) slf.httpserver = &http.Server{ @@ -101,13 +105,11 @@ func (slf *WebsocketServer) startListen() { MaxHeaderBytes: 1 << 20, } - slf.messageReciver.OnListen(slf) err := slf.httpserver.ListenAndServe() if err != nil { fmt.Printf("http.ListenAndServe(%d, nil) error\n", slf.port) os.Exit(1) } - } func (slf *WSClient) startSendMsg() { @@ -135,20 +137,55 @@ func (slf *WebsocketServer) SendMsg(clientid uint64, messageType int, msg []byte } func (slf *WebsocketServer) Stop() { +} +func (slf *BaseMessageReciver) startReadMsg(pclient *WSClient) { + for { + msgtype, message, err := pclient.conn.ReadMessage() + if err != nil { + slf.messageReciver.OnDisconnect(pclient.clientid, err) + slf.WsServer.ReleaseClient(pclient) + + return + } + + slf.messageReciver.OnRecvMsg(pclient.clientid, msgtype, message) + } +} + +func (slf *BaseMessageReciver) initReciver(messageReciver IMessageReceiver, websocketServer IWebsocketServer) { + slf.messageReciver = messageReciver + slf.WsServer = websocketServer +} + +func (slf *BaseMessageReciver) OnConnected(clientid uint64) { +} + +func (slf *BaseMessageReciver) OnDisconnect(clientid uint64, err error) { +} + +func (slf *BaseMessageReciver) OnRecvMsg(clientid uint64, msgtype int, data []byte) { +} + +func (slf *BaseMessageReciver) OnHandleHttp(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Upgrade(w, r, w.Header(), 1024, 1024) + if err != nil { + http.Error(w, "Could not open websocket connection", http.StatusBadRequest) + } + + pclient := slf.WsServer.CreateClient(conn) + slf.messageReciver.OnConnected(pclient.clientid) + go pclient.startSendMsg() + go slf.startReadMsg(pclient) } func (slf *WebsocketServer) initRouterHandler() http.Handler { r := mux.NewRouter() - /*r.HandleFunc("/{server:[a-zA-Z0-9]+}/{method:[a-zA-Z0-9]+}", func(w http.ResponseWriter, r *http.Request) { - slf.wsHandler(w, r) - }) - */ - r.HandleFunc(slf.pattern, func(w http.ResponseWriter, r *http.Request) { - slf.wsHandler(w, r) - }) + + for pattern, reciver := range slf.reciver { + r.HandleFunc(pattern, reciver.messageReciver.OnHandleHttp) + } cors := cors.AllowAll() - //return cors.Handler(gziphandler.GzipHandler(r)) return cors.Handler(r) } diff --git a/sysservice/wsserverservice.go b/sysservice/wsserverservice.go index f401c3c..d928233 100644 --- a/sysservice/wsserverservice.go +++ b/sysservice/wsserverservice.go @@ -17,7 +17,7 @@ type WSServerService struct { func (ws *WSServerService) OnInit() error { - ws.wsserver.Init(ws.pattern, ws.port, ws.messageReciver, ws.bEnableCompression) + ws.wsserver.Init(ws.port) return nil } @@ -27,12 +27,10 @@ func (ws *WSServerService) OnRun() bool { return false } -func NewWSServerService(pattern string, port uint16, messageReciver network.IMessageReceiver, bEnableCompression bool) *WSServerService { +func NewWSServerService(port uint16) *WSServerService { wss := new(WSServerService) - wss.pattern = pattern + wss.port = port - wss.messageReciver = messageReciver - wss.bEnableCompression = bEnableCompression wss.Init(wss, 0) return wss @@ -41,3 +39,6 @@ func NewWSServerService(pattern string, port uint16, messageReciver network.IMes func (ws *WSServerService) OnDestory() error { return nil } +func (ws *WSServerService) SetupReciver(pattern string, messageReciver network.IMessageReceiver, bEnableCompression bool) { + ws.wsserver.SetupReciver(pattern, messageReciver, bEnableCompression) +}