diff --git a/network/ws_conn.go b/network/ws_conn.go index a45c5c6..2243c84 100644 --- a/network/ws_conn.go +++ b/network/ws_conn.go @@ -15,16 +15,16 @@ type WSConn struct { sync.Mutex conn *websocket.Conn writeChan chan []byte - maxMsgLen uint32 + maxWriteMsgLen uint32 closeFlag bool header http.Header } -func newWSConn(conn *websocket.Conn, header http.Header, pendingWriteNum int, maxMsgLen uint32, messageType int) *WSConn { +func newWSConn(conn *websocket.Conn, header http.Header, pendingWriteNum int, maxWriteMsgLen uint32, messageType int) *WSConn { wsConn := new(WSConn) wsConn.conn = conn wsConn.writeChan = make(chan []byte, pendingWriteNum) - wsConn.maxMsgLen = maxMsgLen + wsConn.maxWriteMsgLen = maxWriteMsgLen wsConn.header = header go func() { @@ -118,7 +118,7 @@ func (wsConn *WSConn) WriteMsg(args ...[]byte) error { } // check len - if msgLen > wsConn.maxMsgLen { + if wsConn.maxWriteMsgLen > 0 && msgLen > wsConn.maxWriteMsgLen { return errors.New("message too long") } else if msgLen < 1 { return errors.New("message too short") diff --git a/network/ws_server.go b/network/ws_server.go index 8623faf..903ea2a 100644 --- a/network/ws_server.go +++ b/network/ws_server.go @@ -16,7 +16,8 @@ type WSServer struct { Addr string MaxConnNum int PendingWriteNum int - MaxMsgLen uint32 + MaxReadMsgLen uint32 + MaxWriteMsgLen uint32 CertFile string KeyFile string NewAgent func(*WSConn) Agent @@ -32,7 +33,8 @@ type WSServer struct { type WSHandler struct { maxConnNum int pendingWriteNum int - maxMsgLen uint32 + maxReadMsgLen uint32 + maxWriteMsgLen uint32 newAgent func(*WSConn) Agent upgrader websocket.Upgrader conns WebsocketConnSet @@ -55,7 +57,7 @@ func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Error("upgrade fail", log.String("error", err.Error())) return } - conn.SetReadLimit(int64(handler.maxMsgLen)) + conn.SetReadLimit(int64(handler.maxReadMsgLen)) if handler.messageType == 0 { handler.messageType = websocket.TextMessage } @@ -93,7 +95,7 @@ func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.SetLinger(0) c.SetNoDelay(true) - wsConn := newWSConn(conn, r.Header, handler.pendingWriteNum, handler.maxMsgLen, handler.messageType) + wsConn := newWSConn(conn, r.Header, handler.pendingWriteNum, handler.maxWriteMsgLen, handler.messageType) agent := handler.newAgent(wsConn) agent.Run() @@ -118,7 +120,6 @@ func (server *WSServer) Start() error { log.Error("WSServer Listen fail", log.String("error", err.Error())) return err } - if server.MaxConnNum <= 0 { server.MaxConnNum = 100 log.Info("invalid MaxConnNum", log.Int("reset", server.MaxConnNum)) @@ -127,9 +128,9 @@ func (server *WSServer) Start() error { server.PendingWriteNum = 100 log.Info("invalid PendingWriteNum", log.Int("reset", server.PendingWriteNum)) } - if server.MaxMsgLen <= 0 { - server.MaxMsgLen = 4096 - log.Info("invalid MaxMsgLen", log.Uint32("reset", server.MaxMsgLen)) + if server.MaxReadMsgLen <= 0 { + server.MaxReadMsgLen = 4096 + log.Info("invalid MaxReadMsgLen", log.Uint32("reset", server.MaxReadMsgLen)) } if server.HandshakeTimeout <= 0 { server.HandshakeTimeout = 15 * time.Second @@ -167,7 +168,8 @@ func (server *WSServer) Start() error { server.handler = &WSHandler{ maxConnNum: server.MaxConnNum, pendingWriteNum: server.PendingWriteNum, - maxMsgLen: server.MaxMsgLen, + maxReadMsgLen: server.MaxReadMsgLen, + maxWriteMsgLen: server.MaxWriteMsgLen, newAgent: server.NewAgent, conns: make(WebsocketConnSet), messageType: server.messageType, diff --git a/sysmodule/netmodule/wsmodule/WSModule.go b/sysmodule/netmodule/wsmodule/WSModule.go index ba2dd43..7cef44b 100644 --- a/sysmodule/netmodule/wsmodule/WSModule.go +++ b/sysmodule/netmodule/wsmodule/WSModule.go @@ -34,7 +34,8 @@ type WSCfg struct { ListenAddr string MaxConnNum int PendingWriteNum int - MaxMsgLen uint32 + MaxReadMsgLen uint32 + MaxWriteMsgLen uint32 LittleEndian bool //是否小端序 KeyFile string CertFile string @@ -67,7 +68,8 @@ func (ws *WSModule) OnInit() error { ws.WSServer.MaxConnNum = ws.wsCfg.MaxConnNum ws.WSServer.PendingWriteNum = ws.wsCfg.PendingWriteNum - ws.WSServer.MaxMsgLen = ws.wsCfg.MaxMsgLen + ws.WSServer.MaxReadMsgLen = ws.wsCfg.MaxReadMsgLen + ws.WSServer.MaxWriteMsgLen = ws.wsCfg.MaxWriteMsgLen ws.WSServer.Addr = ws.wsCfg.ListenAddr ws.WSServer.HandshakeTimeout = ws.wsCfg.HandshakeTimeoutSecond*time.Second ws.WSServer.ReadTimeout = ws.wsCfg.ReadTimeoutSecond*time.Second