diff --git a/network/kcp_client.go b/network/kcp_client.go index 0ce8df6..5f03e3d 100644 --- a/network/kcp_client.go +++ b/network/kcp_client.go @@ -69,15 +69,15 @@ func (client *KCPClient) init() { if client.MinMsgLen == 0 { client.MinMsgLen = Default_MinMsgLen } - if client.MaxMsgLen == 0 { - client.MaxMsgLen = Default_MaxMsgLen + if client.MaxReadMsgLen == 0 { + client.MaxReadMsgLen = Default_MaxReadMsgLen } if client.LenMsgLen == 0 { client.LenMsgLen = Default_LenMsgLen } maxMsgLen := client.MsgParser.getMaxMsgLen() - if client.MaxMsgLen > maxMsgLen { - client.MaxMsgLen = maxMsgLen + if client.MaxReadMsgLen > maxMsgLen { + client.MaxReadMsgLen = maxMsgLen log.Info("invalid MaxMsgLen", log.Uint32("reset", maxMsgLen)) } diff --git a/network/kcp_server.go b/network/kcp_server.go index 8a3d986..949d1e3 100644 --- a/network/kcp_server.go +++ b/network/kcp_server.go @@ -81,7 +81,8 @@ type KcpCfg struct { LittleEndian bool //是否小端序 LenMsgLen int //消息头占用byte数量,只能是1byte,2byte,4byte。如果是4byte,意味着消息最大可以是math.MaxUint32(4GB) MinMsgLen uint32 //最小消息长度 - MaxMsgLen uint32 //最大消息长度,超过判定不合法,断开连接 + MaxReadMsgLen uint32 //最大读消息长度,超过判定不合法,断开连接 + MaxWriteMsgLen uint32 // 最大写消息长度 PendingWriteNum int //写channel最大消息数量 } @@ -89,7 +90,8 @@ func (kp *KCPServer) Init(kcpCfg *KcpCfg) { kp.kcpCfg = kcpCfg kp.msgParser.Init() kp.msgParser.LenMsgLen = kp.kcpCfg.LenMsgLen - kp.msgParser.MaxMsgLen = kp.kcpCfg.MaxMsgLen + kp.msgParser.MaxReadMsgLen = kp.kcpCfg.MaxReadMsgLen + kp.msgParser.MaxWriteMsgLen = kp.kcpCfg.MaxWriteMsgLen kp.msgParser.MinMsgLen = kp.kcpCfg.MinMsgLen kp.msgParser.LittleEndian = kp.kcpCfg.LittleEndian diff --git a/network/processor/jsonprocessor.go b/network/processor/jsonprocessor.go index b9fafec..541a38b 100644 --- a/network/processor/jsonprocessor.go +++ b/network/processor/jsonprocessor.go @@ -45,9 +45,8 @@ func (jsonProcessor *JsonProcessor) SetByteOrder(littleEndian bool) { } // MsgRoute must goroutine safe -func (jsonProcessor *JsonProcessor) MsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) error { +func (jsonProcessor *JsonProcessor) MsgRoute(clientId string, msg interface{}) error { pPackInfo := msg.(*JsonPackInfo) - defer recyclerReaderBytes(pPackInfo.rawMsg) v, ok := jsonProcessor.mapMsg[pPackInfo.typ] if ok == false { @@ -107,8 +106,7 @@ func (jsonProcessor *JsonProcessor) MakeRawMsg(msgType uint16, msg []byte) *Json return &JsonPackInfo{typ: msgType, rawMsg: msg} } -func (jsonProcessor *JsonProcessor) UnknownMsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) { - defer recyclerReaderBytes(msg.([]byte)) +func (jsonProcessor *JsonProcessor) UnknownMsgRoute(clientId string, msg interface{}) { if jsonProcessor.unknownMessageHandler == nil { log.Debug("Unknown message", log.String("clientId", clientId)) return diff --git a/network/processor/pbprocessor.go b/network/processor/pbprocessor.go index 150d4ae..36bf69d 100644 --- a/network/processor/pbprocessor.go +++ b/network/processor/pbprocessor.go @@ -54,10 +54,8 @@ func (slf *PBPackInfo) GetMsg() proto.Message { } // MsgRoute must goroutine safe -func (pbProcessor *PBProcessor) MsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) error { +func (pbProcessor *PBProcessor) MsgRoute(clientId string, msg interface{}) error { pPackInfo := msg.(*PBPackInfo) - defer recyclerReaderBytes(pPackInfo.rawMsg) - v, ok := pbProcessor.mapMsg[pPackInfo.typ] if ok == false { return fmt.Errorf("cannot find msgtype %d is register", pPackInfo.typ) @@ -134,9 +132,8 @@ func (pbProcessor *PBProcessor) MakeRawMsg(msgType uint16, msg []byte) *PBPackIn return &PBPackInfo{typ: msgType, rawMsg: msg} } -func (pbProcessor *PBProcessor) UnknownMsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) { +func (pbProcessor *PBProcessor) UnknownMsgRoute(clientId string, msg interface{}) { pbProcessor.unknownMessageHandler(clientId, msg.([]byte)) - recyclerReaderBytes(msg.([]byte)) } func (pbProcessor *PBProcessor) ConnectedRoute(clientId string) { diff --git a/network/processor/pbrawprocessor.go b/network/processor/pbrawprocessor.go index c0bd814..bcd3bf9 100644 --- a/network/processor/pbrawprocessor.go +++ b/network/processor/pbrawprocessor.go @@ -39,10 +39,9 @@ func (pbRawProcessor *PBRawProcessor) SetByteOrder(littleEndian bool) { } // MsgRoute must goroutine safe -func (pbRawProcessor *PBRawProcessor) MsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) error { +func (pbRawProcessor *PBRawProcessor) MsgRoute(clientId string, msg interface{}) error { pPackInfo := msg.(*PBRawPackInfo) pbRawProcessor.msgHandler(clientId, pPackInfo.typ, nil, pPackInfo.rawMsg) - recyclerReaderBytes(pPackInfo.rawMsg) return nil } @@ -83,8 +82,7 @@ func (pbRawProcessor *PBRawProcessor) MakeRawMsg(msgType uint16, msg []byte, pbR pbRawPackInfo.rawMsg = msg } -func (pbRawProcessor *PBRawProcessor) UnknownMsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) { - defer recyclerReaderBytes(msg.([]byte)) +func (pbRawProcessor *PBRawProcessor) UnknownMsgRoute(clientId string, msg interface{}) { if pbRawProcessor.unknownMessageHandler == nil { return } diff --git a/network/processor/processor.go b/network/processor/processor.go index f6a7cae..ee9a60e 100644 --- a/network/processor/processor.go +++ b/network/processor/processor.go @@ -2,9 +2,9 @@ package processor type IProcessor interface { // MsgRoute must goroutine safe - MsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) error + MsgRoute(clientId string, msg interface{}) error // UnknownMsgRoute must goroutine safe - UnknownMsgRoute(clientId string, msg interface{}, recyclerReaderBytes func(data []byte)) + UnknownMsgRoute(clientId string, msg interface{}) // ConnectedRoute connect event ConnectedRoute(clientId string) DisConnectedRoute(clientId string) diff --git a/network/tcp_client.go b/network/tcp_client.go index 1f799b6..f3b0d90 100644 --- a/network/tcp_client.go +++ b/network/tcp_client.go @@ -68,18 +68,19 @@ func (client *TCPClient) init() { if client.MinMsgLen == 0 { client.MinMsgLen = Default_MinMsgLen } - if client.MaxMsgLen == 0 { - client.MaxMsgLen = Default_MaxMsgLen + if client.MaxReadMsgLen == 0 { + client.MaxReadMsgLen = Default_MaxReadMsgLen } if client.LenMsgLen == 0 { client.LenMsgLen = Default_LenMsgLen } maxMsgLen := client.MsgParser.getMaxMsgLen() - if client.MaxMsgLen > maxMsgLen { - client.MaxMsgLen = maxMsgLen + if client.MaxReadMsgLen > maxMsgLen { + client.MaxReadMsgLen = maxMsgLen log.Info("invalid MaxMsgLen", log.Uint32("reset", maxMsgLen)) } + client.cons = make(ConnSet) client.closeFlag = false client.MsgParser.Init() diff --git a/network/tcp_msg.go b/network/tcp_msg.go index 6c0d5d0..49ae5df 100644 --- a/network/tcp_msg.go +++ b/network/tcp_msg.go @@ -14,7 +14,8 @@ import ( type MsgParser struct { LenMsgLen int MinMsgLen uint32 - MaxMsgLen uint32 + MaxReadMsgLen uint32 + MaxWriteMsgLen uint32 LittleEndian bool bytespool.IBytesMemPool @@ -67,7 +68,7 @@ func (p *MsgParser) Read(r io.Reader) ([]byte, error) { } // check len - if msgLen > p.MaxMsgLen { + if msgLen > p.MaxReadMsgLen { return nil, errors.New("message too long") } else if msgLen < p.MinMsgLen { return nil, errors.New("message too short") @@ -92,7 +93,7 @@ func (p *MsgParser) Write(conn io.Writer, args ...[]byte) error { } // check len - if msgLen > p.MaxMsgLen { + if p.MaxWriteMsgLen > 0 && msgLen > p.MaxWriteMsgLen { return errors.New("message too long") } else if msgLen < p.MinMsgLen { return errors.New("message too short") diff --git a/network/tcp_server.go b/network/tcp_server.go index 5d3f3d5..c9e51d1 100644 --- a/network/tcp_server.go +++ b/network/tcp_server.go @@ -11,13 +11,13 @@ import ( ) const ( - Default_ReadDeadline = time.Second * 30 //默认读超时30s - Default_WriteDeadline = time.Second * 30 //默认写超时30s - Default_MaxConnNum = 1000000 //默认最大连接数 - Default_PendingWriteNum = 100000 //单连接写消息Channel容量 - Default_MinMsgLen = 2 //最小消息长度2byte - Default_LenMsgLen = 2 //包头字段长度占用2byte - Default_MaxMsgLen = 65535 //最大消息长度 + Default_ReadDeadline = time.Second * 30 // 默认读超时30s + Default_WriteDeadline = time.Second * 30 // 默认写超时30s + Default_MaxConnNum = 1000000 // 默认最大连接数 + Default_PendingWriteNum = 100000 // 单连接写消息Channel容量 + Default_MinMsgLen = 2 // 最小消息长度2byte + Default_LenMsgLen = 2 // 包头字段长度占用2byte + Default_MaxReadMsgLen = 65535 // 最大读消息长度 ) type TCPServer struct { @@ -70,14 +70,14 @@ func (server *TCPServer) init() error { log.Info("invalid LenMsgLen", log.Int("reset", server.LenMsgLen)) } - if server.MaxMsgLen <= 0 { - server.MaxMsgLen = Default_MaxMsgLen - log.Info("invalid MaxMsgLen", log.Uint32("reset to", server.MaxMsgLen)) + if server.MaxReadMsgLen <= 0 { + server.MaxReadMsgLen = Default_MaxReadMsgLen + log.Info("invalid MaxMsgLen", log.Uint32("reset to", server.MaxReadMsgLen)) } maxMsgLen := server.MsgParser.getMaxMsgLen() - if server.MaxMsgLen > maxMsgLen { - server.MaxMsgLen = maxMsgLen + if server.MaxReadMsgLen > maxMsgLen { + server.MaxReadMsgLen = maxMsgLen log.Info("invalid MaxMsgLen", log.Uint32("reset", maxMsgLen)) } diff --git a/rpc/rclient.go b/rpc/rclient.go index dc3139a..86bebfd 100644 --- a/rpc/rclient.go +++ b/rpc/rclient.go @@ -127,9 +127,11 @@ func NewRClient(targetNodeId string, addr string, maxRpcParamLen uint32, compres c.NewAgent = client.NewClientAgent if maxRpcParamLen > 0 { - c.MaxMsgLen = maxRpcParamLen + c.MaxReadMsgLen = maxRpcParamLen + c.MaxWriteMsgLen = maxRpcParamLen } else { - c.MaxMsgLen = math.MaxUint32 + c.MaxReadMsgLen = math.MaxUint32 + c.MaxWriteMsgLen = math.MaxUint32 } client.IRealClient = c client.CallSet = callSet diff --git a/rpc/server.go b/rpc/server.go index 7057b45..3bb1250 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -91,9 +91,11 @@ func (server *Server) Start() error { server.rpcServer.Addr = ":" + splitAddr[1] server.rpcServer.MinMsgLen = 2 if server.maxRpcParamLen > 0 { - server.rpcServer.MaxMsgLen = server.maxRpcParamLen + server.rpcServer.MaxReadMsgLen = server.maxRpcParamLen + server.rpcServer.MaxWriteMsgLen = server.maxRpcParamLen } else { - server.rpcServer.MaxMsgLen = math.MaxUint32 + server.rpcServer.MaxReadMsgLen = math.MaxUint32 + server.rpcServer.MaxWriteMsgLen = math.MaxUint32 } server.rpcServer.MaxConnNum = 100000 diff --git a/sysmodule/netmodule/kcpmodule/KcpModule.go b/sysmodule/netmodule/kcpmodule/KcpModule.go index 6e1b8a2..f37186e 100644 --- a/sysmodule/netmodule/kcpmodule/KcpModule.go +++ b/sysmodule/netmodule/kcpmodule/KcpModule.go @@ -19,6 +19,7 @@ type KcpModule struct { mapClientLocker sync.RWMutex mapClient map[string]*Client process processor.IRawProcessor + newClientIdHandler func() string kcpServer network.KCPServer kcpCfg *network.KcpCfg @@ -56,7 +57,11 @@ func (km *KcpModule) OnInit() error { km.process.SetByteOrder(km.kcpCfg.LittleEndian) km.kcpServer.Init(km.kcpCfg) km.kcpServer.NewAgent = km.NewAgent - + if km.newClientIdHandler == nil { + km.newClientIdHandler = func()string{ + return primitive.NewObjectID().Hex() + } + } return nil } @@ -65,6 +70,10 @@ func (km *KcpModule) Init(kcpCfg *network.KcpCfg, process processor.IRawProcesso km.process = process } +func (km *KcpModule) SetNewClientIdHandler(newClientIdHandler func() string){ + km.newClientIdHandler = newClientIdHandler +} + func (km *KcpModule) Start() error { return km.kcpServer.Start() } @@ -77,9 +86,9 @@ func (km *KcpModule) kcpEventHandler(ev event.IEvent) { case KPTDisConnected: km.process.DisConnectedRoute(e.StringExt[0]) case KPTUnknownPack: - km.process.UnknownMsgRoute(e.StringExt[0], e.Data, e.AnyExt[0].(func(data []byte))) + km.process.UnknownMsgRoute(e.StringExt[0], e.Data) case KPTPack: - km.process.MsgRoute(e.StringExt[0], e.Data, e.AnyExt[0].(func(data []byte))) + km.process.MsgRoute(e.StringExt[0], e.Data) } event.DeleteEvent(ev) @@ -111,7 +120,7 @@ func (km *KcpModule) newClient(conn network.Conn) *Client { km.mapClientLocker.Lock() defer km.mapClientLocker.Unlock() - pClient := &Client{kcpConn: conn.(*network.NetConn), id: primitive.NewObjectID().Hex()} + pClient := &Client{kcpConn: conn.(*network.NetConn), id: km.newClientIdHandler()} pClient.kcpModule = km km.mapClient[pClient.id] = pClient diff --git a/sysmodule/netmodule/tcpmodule/TcpModule.go b/sysmodule/netmodule/tcpmodule/TcpModule.go index 72eb9f7..9d22edd 100644 --- a/sysmodule/netmodule/tcpmodule/TcpModule.go +++ b/sysmodule/netmodule/tcpmodule/TcpModule.go @@ -20,6 +20,7 @@ type TcpModule struct { mapClient map[string]*Client process processor.IRawProcessor tcpCfg *TcpCfg + newClientIdHandler func() string } type TcpPackType int8 @@ -35,6 +36,7 @@ type TcpPack struct { Type TcpPackType //0表示连接 1表示断开 2表示数据 ClientId string Data interface{} + rawData []byte RecyclerReaderBytes func(data []byte) } @@ -51,7 +53,8 @@ type TcpCfg struct { LittleEndian bool //是否小端序 LenMsgLen int //消息头占用byte数量,只能是1byte,2byte,4byte。如果是4byte,意味着消息最大可以是math.MaxUint32(4GB) MinMsgLen uint32 //最小消息长度 - MaxMsgLen uint32 //最大消息长度,超过判定不合法,断开连接 + MaxReadMsgLen uint32 //最大消息长度,超过判定不合法,断开连接 + MaxWriteMsgLen uint32 // 最大写消息长度 ReadDeadlineSecond time.Duration //读超时 WriteDeadlineSecond time.Duration //写超时 } @@ -68,11 +71,17 @@ func (tm *TcpModule) OnInit() error { tm.tcpServer.LittleEndian = tm.tcpCfg.LittleEndian tm.tcpServer.LenMsgLen = tm.tcpCfg.LenMsgLen tm.tcpServer.MinMsgLen = tm.tcpCfg.MinMsgLen - tm.tcpServer.MaxMsgLen = tm.tcpCfg.MaxMsgLen + tm.tcpServer.MaxReadMsgLen = tm.tcpCfg.MaxReadMsgLen + tm.tcpServer.MaxWriteMsgLen = tm.tcpCfg.MaxWriteMsgLen tm.tcpServer.ReadDeadline = tm.tcpCfg.ReadDeadlineSecond * time.Second tm.tcpServer.WriteDeadline = tm.tcpCfg.WriteDeadlineSecond * time.Second tm.mapClient = make(map[string]*Client, tm.tcpServer.MaxConnNum) tm.tcpServer.NewAgent = tm.NewClient + if tm.newClientIdHandler == nil { + tm.newClientIdHandler = func()string{ + return primitive.NewObjectID().Hex() + } + } //3.设置解析处理器 tm.process.SetByteOrder(tm.tcpCfg.LittleEndian) @@ -87,6 +96,10 @@ func (tm *TcpModule) Init(tcpCfg *TcpCfg, process processor.IRawProcessor) { tm.process = process } +func (tm *TcpModule) SetNewClientIdHandler(newClientIdHandler func() string){ + tm.newClientIdHandler = newClientIdHandler +} + func (tm *TcpModule) Start() error { return tm.tcpServer.Start() } @@ -99,9 +112,11 @@ func (tm *TcpModule) tcpEventHandler(ev event.IEvent) { case TPTDisConnected: tm.process.DisConnectedRoute(pack.ClientId) case TPTUnknownPack: - tm.process.UnknownMsgRoute(pack.ClientId, pack.Data, pack.RecyclerReaderBytes) + tm.process.UnknownMsgRoute(pack.ClientId, pack.Data) + pack.RecyclerReaderBytes(pack.rawData) case TPTPack: - tm.process.MsgRoute(pack.ClientId, pack.Data, pack.RecyclerReaderBytes) + tm.process.MsgRoute(pack.ClientId, pack.Data) + pack.RecyclerReaderBytes(pack.rawData) } } @@ -109,7 +124,7 @@ func (tm *TcpModule) NewClient(conn network.Conn) network.Agent { tm.mapClientLocker.Lock() defer tm.mapClientLocker.Unlock() - clientId := primitive.NewObjectID().Hex() + clientId := tm.newClientIdHandler() pClient := &Client{tcpConn: conn.(*network.NetConn), id: clientId} pClient.tcpModule = tm tm.mapClient[clientId] = pClient @@ -138,10 +153,10 @@ func (slf *Client) Run() { } data, err := slf.tcpModule.process.Unmarshal(slf.id, bytes) if err != nil { - slf.tcpModule.NotifyEvent(&event.Event{Type: event.Sys_Event_Tcp, Data: TcpPack{ClientId: slf.id, Type: TPTUnknownPack, Data: bytes, RecyclerReaderBytes: slf.tcpConn.GetRecyclerReaderBytes()}}) + slf.tcpModule.NotifyEvent(&event.Event{Type: event.Sys_Event_Tcp, Data: TcpPack{ClientId: slf.id, Type: TPTUnknownPack, Data: data,rawData: bytes,RecyclerReaderBytes: slf.tcpConn.GetRecyclerReaderBytes()}}) continue } - slf.tcpModule.NotifyEvent(&event.Event{Type: event.Sys_Event_Tcp, Data: TcpPack{ClientId: slf.id, Type: TPTPack, Data: data, RecyclerReaderBytes: slf.tcpConn.GetRecyclerReaderBytes()}}) + slf.tcpModule.NotifyEvent(&event.Event{Type: event.Sys_Event_Tcp, Data: TcpPack{ClientId: slf.id, Type: TPTPack, Data: data,rawData: bytes, RecyclerReaderBytes: slf.tcpConn.GetRecyclerReaderBytes()}}) } } diff --git a/sysmodule/netmodule/wsmodule/WSModule.go b/sysmodule/netmodule/wsmodule/WSModule.go index 7cef44b..f5c7f1c 100644 --- a/sysmodule/netmodule/wsmodule/WSModule.go +++ b/sysmodule/netmodule/wsmodule/WSModule.go @@ -22,6 +22,7 @@ type WSModule struct { mapClient map[string]*WSClient process processor.IRawProcessor wsCfg *WSCfg + newClientIdHandler func() string } type WSClient struct { @@ -74,7 +75,11 @@ func (ws *WSModule) OnInit() error { ws.WSServer.HandshakeTimeout = ws.wsCfg.HandshakeTimeoutSecond*time.Second ws.WSServer.ReadTimeout = ws.wsCfg.ReadTimeoutSecond*time.Second ws.WSServer.WriteTimeout = ws.wsCfg.WriteTimeoutSecond*time.Second - + if ws.newClientIdHandler == nil { + ws.newClientIdHandler = func()string{ + return primitive.NewObjectID().Hex() + } + } if ws.wsCfg.KeyFile != "" && ws.wsCfg.CertFile != "" { ws.WSServer.KeyFile = ws.wsCfg.KeyFile ws.WSServer.CertFile = ws.wsCfg.CertFile @@ -97,6 +102,10 @@ func (ws *WSModule) Init(wsCfg *WSCfg, process processor.IRawProcessor) { ws.process = process } +func (ws *WSModule) SetNewClientIdHandler(newClientIdHandler func() string){ + ws.newClientIdHandler = newClientIdHandler +} + func (ws *WSModule) Start() error { return ws.WSServer.Start() } @@ -109,9 +118,9 @@ func (ws *WSModule) wsEventHandler(ev event.IEvent) { case WPTDisConnected: ws.process.DisConnectedRoute(pack.ClientId) case WPTUnknownPack: - ws.process.UnknownMsgRoute(pack.ClientId, pack.Data, ws.recyclerReaderBytes) + ws.process.UnknownMsgRoute(pack.ClientId, pack.Data) case WPTPack: - ws.process.MsgRoute(pack.ClientId, pack.Data, ws.recyclerReaderBytes) + ws.process.MsgRoute(pack.ClientId, pack.Data) } } @@ -121,7 +130,7 @@ func (ws *WSModule) recyclerReaderBytes([]byte) { func (ws *WSModule) NewWSClient(conn *network.WSConn) network.Agent { ws.mapClientLocker.Lock() defer ws.mapClientLocker.Unlock() - pClient := &WSClient{wsConn: conn, id: primitive.NewObjectID().Hex()} + pClient := &WSClient{wsConn: conn, id: ws.newClientIdHandler()} pClient.wsModule = ws ws.mapClient[pClient.id] = pClient @@ -142,7 +151,7 @@ func (wc *WSClient) Run() { } data, err := wc.wsModule.process.Unmarshal(wc.id, bytes) if err != nil { - wc.wsModule.NotifyEvent(&event.Event{Type: event.Sys_Event_WebSocket, Data: &WSPack{ClientId: wc.id, Type: WPTUnknownPack, Data: bytes}}) + wc.wsModule.NotifyEvent(&event.Event{Type: event.Sys_Event_WebSocket, Data: &WSPack{ClientId: wc.id, Type: WPTUnknownPack, Data: data}}) continue } wc.wsModule.NotifyEvent(&event.Event{Type: event.Sys_Event_WebSocket, Data: &WSPack{ClientId: wc.id, Type: WPTPack, Data: data}})