优化ws读写最大限制

This commit is contained in:
boyce
2026-01-08 08:33:29 +08:00
parent ef7ee0ab8e
commit 330644cebb
3 changed files with 19 additions and 15 deletions

View File

@@ -15,16 +15,16 @@ type WSConn struct {
sync.Mutex sync.Mutex
conn *websocket.Conn conn *websocket.Conn
writeChan chan []byte writeChan chan []byte
maxMsgLen uint32 maxWriteMsgLen uint32
closeFlag bool closeFlag bool
header http.Header header http.Header
} }
func newWSConn(conn *websocket.Conn, header http.Header, pendingWriteNum int, maxMsgLen uint32, messageType int) *WSConn { func newWSConn(conn *websocket.Conn, header http.Header, pendingWriteNum int, maxWriteMsgLen uint32, messageType int) *WSConn {
wsConn := new(WSConn) wsConn := new(WSConn)
wsConn.conn = conn wsConn.conn = conn
wsConn.writeChan = make(chan []byte, pendingWriteNum) wsConn.writeChan = make(chan []byte, pendingWriteNum)
wsConn.maxMsgLen = maxMsgLen wsConn.maxWriteMsgLen = maxWriteMsgLen
wsConn.header = header wsConn.header = header
go func() { go func() {
@@ -118,7 +118,7 @@ func (wsConn *WSConn) WriteMsg(args ...[]byte) error {
} }
// check len // check len
if msgLen > wsConn.maxMsgLen { if wsConn.maxWriteMsgLen > 0 && msgLen > wsConn.maxWriteMsgLen {
return errors.New("message too long") return errors.New("message too long")
} else if msgLen < 1 { } else if msgLen < 1 {
return errors.New("message too short") return errors.New("message too short")

View File

@@ -16,7 +16,8 @@ type WSServer struct {
Addr string Addr string
MaxConnNum int MaxConnNum int
PendingWriteNum int PendingWriteNum int
MaxMsgLen uint32 MaxReadMsgLen uint32
MaxWriteMsgLen uint32
CertFile string CertFile string
KeyFile string KeyFile string
NewAgent func(*WSConn) Agent NewAgent func(*WSConn) Agent
@@ -32,7 +33,8 @@ type WSServer struct {
type WSHandler struct { type WSHandler struct {
maxConnNum int maxConnNum int
pendingWriteNum int pendingWriteNum int
maxMsgLen uint32 maxReadMsgLen uint32
maxWriteMsgLen uint32
newAgent func(*WSConn) Agent newAgent func(*WSConn) Agent
upgrader websocket.Upgrader upgrader websocket.Upgrader
conns WebsocketConnSet conns WebsocketConnSet
@@ -55,7 +57,7 @@ func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Error("upgrade fail", log.String("error", err.Error())) log.Error("upgrade fail", log.String("error", err.Error()))
return return
} }
conn.SetReadLimit(int64(handler.maxMsgLen)) conn.SetReadLimit(int64(handler.maxReadMsgLen))
if handler.messageType == 0 { if handler.messageType == 0 {
handler.messageType = websocket.TextMessage handler.messageType = websocket.TextMessage
} }
@@ -93,7 +95,7 @@ func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.SetLinger(0) c.SetLinger(0)
c.SetNoDelay(true) c.SetNoDelay(true)
wsConn := newWSConn(conn, r.Header, handler.pendingWriteNum, handler.maxMsgLen, handler.messageType) wsConn := newWSConn(conn, r.Header, handler.pendingWriteNum, handler.maxWriteMsgLen, handler.messageType)
agent := handler.newAgent(wsConn) agent := handler.newAgent(wsConn)
agent.Run() agent.Run()
@@ -118,7 +120,6 @@ func (server *WSServer) Start() error {
log.Error("WSServer Listen fail", log.String("error", err.Error())) log.Error("WSServer Listen fail", log.String("error", err.Error()))
return err return err
} }
if server.MaxConnNum <= 0 { if server.MaxConnNum <= 0 {
server.MaxConnNum = 100 server.MaxConnNum = 100
log.Info("invalid MaxConnNum", log.Int("reset", server.MaxConnNum)) log.Info("invalid MaxConnNum", log.Int("reset", server.MaxConnNum))
@@ -127,9 +128,9 @@ func (server *WSServer) Start() error {
server.PendingWriteNum = 100 server.PendingWriteNum = 100
log.Info("invalid PendingWriteNum", log.Int("reset", server.PendingWriteNum)) log.Info("invalid PendingWriteNum", log.Int("reset", server.PendingWriteNum))
} }
if server.MaxMsgLen <= 0 { if server.MaxReadMsgLen <= 0 {
server.MaxMsgLen = 4096 server.MaxReadMsgLen = 4096
log.Info("invalid MaxMsgLen", log.Uint32("reset", server.MaxMsgLen)) log.Info("invalid MaxReadMsgLen", log.Uint32("reset", server.MaxReadMsgLen))
} }
if server.HandshakeTimeout <= 0 { if server.HandshakeTimeout <= 0 {
server.HandshakeTimeout = 15 * time.Second server.HandshakeTimeout = 15 * time.Second
@@ -167,7 +168,8 @@ func (server *WSServer) Start() error {
server.handler = &WSHandler{ server.handler = &WSHandler{
maxConnNum: server.MaxConnNum, maxConnNum: server.MaxConnNum,
pendingWriteNum: server.PendingWriteNum, pendingWriteNum: server.PendingWriteNum,
maxMsgLen: server.MaxMsgLen, maxReadMsgLen: server.MaxReadMsgLen,
maxWriteMsgLen: server.MaxWriteMsgLen,
newAgent: server.NewAgent, newAgent: server.NewAgent,
conns: make(WebsocketConnSet), conns: make(WebsocketConnSet),
messageType: server.messageType, messageType: server.messageType,

View File

@@ -34,7 +34,8 @@ type WSCfg struct {
ListenAddr string ListenAddr string
MaxConnNum int MaxConnNum int
PendingWriteNum int PendingWriteNum int
MaxMsgLen uint32 MaxReadMsgLen uint32
MaxWriteMsgLen uint32
LittleEndian bool //是否小端序 LittleEndian bool //是否小端序
KeyFile string KeyFile string
CertFile string CertFile string
@@ -67,7 +68,8 @@ func (ws *WSModule) OnInit() error {
ws.WSServer.MaxConnNum = ws.wsCfg.MaxConnNum ws.WSServer.MaxConnNum = ws.wsCfg.MaxConnNum
ws.WSServer.PendingWriteNum = ws.wsCfg.PendingWriteNum ws.WSServer.PendingWriteNum = ws.wsCfg.PendingWriteNum
ws.WSServer.MaxMsgLen = ws.wsCfg.MaxMsgLen ws.WSServer.MaxReadMsgLen = ws.wsCfg.MaxReadMsgLen
ws.WSServer.MaxWriteMsgLen = ws.wsCfg.MaxWriteMsgLen
ws.WSServer.Addr = ws.wsCfg.ListenAddr ws.WSServer.Addr = ws.wsCfg.ListenAddr
ws.WSServer.HandshakeTimeout = ws.wsCfg.HandshakeTimeoutSecond*time.Second ws.WSServer.HandshakeTimeout = ws.wsCfg.HandshakeTimeoutSecond*time.Second
ws.WSServer.ReadTimeout = ws.wsCfg.ReadTimeoutSecond*time.Second ws.WSServer.ReadTimeout = ws.wsCfg.ReadTimeoutSecond*time.Second