优化ws超时

This commit is contained in:
boyce
2025-11-12 16:46:35 +08:00
parent 4cb6882a1a
commit d7c4cfb1ef
2 changed files with 35 additions and 13 deletions

View File

@@ -3,12 +3,13 @@ package network
import ( import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"github.com/duanhf2012/origin/v2/log"
"github.com/gorilla/websocket"
"net" "net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/duanhf2012/origin/v2/log"
"github.com/gorilla/websocket"
) )
type WSServer struct { type WSServer struct {
@@ -16,13 +17,16 @@ type WSServer struct {
MaxConnNum int MaxConnNum int
PendingWriteNum int PendingWriteNum int
MaxMsgLen uint32 MaxMsgLen uint32
HTTPTimeout time.Duration
CertFile string CertFile string
KeyFile string KeyFile string
NewAgent func(*WSConn) Agent NewAgent func(*WSConn) Agent
ln net.Listener ln net.Listener
handler *WSHandler handler *WSHandler
messageType int messageType int
HandshakeTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
} }
type WSHandler struct { type WSHandler struct {
@@ -73,14 +77,14 @@ func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
handler.conns[conn] = struct{}{} handler.conns[conn] = struct{}{}
handler.mutexConns.Unlock() handler.mutexConns.Unlock()
c,ok:=conn.NetConn().(*net.TCPConn) c, ok := conn.NetConn().(*net.TCPConn)
if !ok { if !ok {
tlsConn,ok := conn.NetConn().(*tls.Conn) tlsConn, ok := conn.NetConn().(*tls.Conn)
if !ok { if !ok {
log.Error("conn error") log.Error("conn error")
return return
} }
c,ok = tlsConn.NetConn().(*net.TCPConn) c, ok = tlsConn.NetConn().(*net.TCPConn)
if !ok { if !ok {
log.Error("conn error") log.Error("conn error")
return return
@@ -127,10 +131,19 @@ func (server *WSServer) Start() error {
server.MaxMsgLen = 4096 server.MaxMsgLen = 4096
log.Info("invalid MaxMsgLen", log.Uint32("reset", server.MaxMsgLen)) log.Info("invalid MaxMsgLen", log.Uint32("reset", server.MaxMsgLen))
} }
if server.HTTPTimeout <= 0 { if server.HandshakeTimeout <= 0 {
server.HTTPTimeout = 10 * time.Second server.HandshakeTimeout = 15 * time.Second
log.Info("invalid HTTPTimeout", log.Duration("reset", server.HTTPTimeout)) log.Info("invalid HandshakeTimeout", log.Duration("reset", server.HandshakeTimeout))
} }
if server.ReadTimeout <= 0 {
server.ReadTimeout = 15 * time.Second
log.Info("invalid ReadTimeout", log.Duration("reset", server.ReadTimeout))
}
if server.WriteTimeout <= 0 {
server.WriteTimeout = 15 * time.Second
log.Info("invalid WriteTimeout", log.Duration("reset", server.WriteTimeout))
}
if server.NewAgent == nil { if server.NewAgent == nil {
log.Error("NewAgent must not be nil") log.Error("NewAgent must not be nil")
return errors.New("NewAgent must not be nil") return errors.New("NewAgent must not be nil")
@@ -159,7 +172,7 @@ func (server *WSServer) Start() error {
conns: make(WebsocketConnSet), conns: make(WebsocketConnSet),
messageType: server.messageType, messageType: server.messageType,
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
HandshakeTimeout: server.HTTPTimeout, HandshakeTimeout: server.HandshakeTimeout,
CheckOrigin: func(_ *http.Request) bool { return true }, CheckOrigin: func(_ *http.Request) bool { return true },
}, },
} }
@@ -167,8 +180,8 @@ func (server *WSServer) Start() error {
httpServer := &http.Server{ httpServer := &http.Server{
Addr: server.Addr, Addr: server.Addr,
Handler: server.handler, Handler: server.handler,
ReadTimeout: server.HTTPTimeout, ReadTimeout: server.ReadTimeout,
WriteTimeout: server.HTTPTimeout, WriteTimeout: server.WriteTimeout,
MaxHeaderBytes: 1024, MaxHeaderBytes: 1024,
} }

View File

@@ -2,13 +2,15 @@ package wsmodule
import ( import (
"fmt" "fmt"
"sync"
"time"
"github.com/duanhf2012/origin/v2/event" "github.com/duanhf2012/origin/v2/event"
"github.com/duanhf2012/origin/v2/log" "github.com/duanhf2012/origin/v2/log"
"github.com/duanhf2012/origin/v2/network" "github.com/duanhf2012/origin/v2/network"
"github.com/duanhf2012/origin/v2/network/processor" "github.com/duanhf2012/origin/v2/network/processor"
"github.com/duanhf2012/origin/v2/service" "github.com/duanhf2012/origin/v2/service"
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
"sync"
) )
type WSModule struct { type WSModule struct {
@@ -36,6 +38,10 @@ type WSCfg struct {
LittleEndian bool //是否小端序 LittleEndian bool //是否小端序
KeyFile string KeyFile string
CertFile string CertFile string
HandshakeTimeoutSecond time.Duration
ReadTimeoutSecond time.Duration
WriteTimeoutSecond time.Duration
} }
type WSPackType int8 type WSPackType int8
@@ -63,6 +69,9 @@ func (ws *WSModule) OnInit() error {
ws.WSServer.PendingWriteNum = ws.wsCfg.PendingWriteNum ws.WSServer.PendingWriteNum = ws.wsCfg.PendingWriteNum
ws.WSServer.MaxMsgLen = ws.wsCfg.MaxMsgLen ws.WSServer.MaxMsgLen = ws.wsCfg.MaxMsgLen
ws.WSServer.Addr = ws.wsCfg.ListenAddr ws.WSServer.Addr = ws.wsCfg.ListenAddr
ws.WSServer.HandshakeTimeout = ws.wsCfg.HandshakeTimeoutSecond*time.Second
ws.WSServer.ReadTimeout = ws.wsCfg.ReadTimeoutSecond*time.Second
ws.WSServer.WriteTimeout = ws.wsCfg.WriteTimeoutSecond*time.Second
if ws.wsCfg.KeyFile != "" && ws.wsCfg.CertFile != "" { if ws.wsCfg.KeyFile != "" && ws.wsCfg.CertFile != "" {
ws.WSServer.KeyFile = ws.wsCfg.KeyFile ws.WSServer.KeyFile = ws.wsCfg.KeyFile