From b943ea9a83dc884850b91f33325eaf5baf2b9597 Mon Sep 17 00:00:00 2001 From: duanhf2012 <6549168@qq.com> Date: Mon, 30 Sep 2024 14:31:24 +0800 Subject: [PATCH] =?UTF-8?q?1.=E4=BC=98=E5=8C=96=E7=BD=91=E7=BB=9C=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=202.=E6=96=B0=E5=A2=9Ekcp=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- event/event.go | 3 + event/eventtype.go | 15 +- network/conn.go | 159 +++++++++++++++++ network/kcp_client.go | 161 +++++++++++++++++ network/kcp_server.go | 257 ++++++++++++++++++++++++++++ network/tcp_client.go | 8 +- network/tcp_conn.go | 168 ------------------ network/tcp_msg.go | 16 +- network/tcp_server.go | 12 +- network/ws_client.go | 2 +- network/ws_conn.go | 10 +- network/ws_server.go | 22 ++- node/node.go | 14 +- rpc/client.go | 4 +- rpc/lclient.go | 2 +- rpc/natsclient.go | 2 +- rpc/rclient.go | 6 +- rpc/server.go | 2 +- sysservice/tcpservice/tcpservice.go | 4 +- sysservice/wsservice/wsservice.go | 3 +- 20 files changed, 656 insertions(+), 214 deletions(-) create mode 100644 network/kcp_client.go create mode 100644 network/kcp_server.go delete mode 100644 network/tcp_conn.go diff --git a/event/event.go b/event/event.go index 90edc5e..b94a0b7 100644 --- a/event/event.go +++ b/event/event.go @@ -17,6 +17,9 @@ type IEvent interface { type Event struct { Type EventType Data interface{} + IntExt [2]int64 + StringExt [2]string + AnyExt [2]any ref bool } diff --git a/event/eventtype.go b/event/eventtype.go index 91c2bf8..7203711 100644 --- a/event/eventtype.go +++ b/event/eventtype.go @@ -10,13 +10,14 @@ const ( Sys_Event_Tcp EventType = -3 Sys_Event_Http_Event EventType = -4 Sys_Event_WebSocket EventType = -5 - Sys_Event_Node_Conn_Event EventType = -6 - Sys_Event_Nats_Conn_Event EventType = -7 - Sys_Event_DiscoverService EventType = -8 - Sys_Event_Retire EventType = -9 - Sys_Event_EtcdDiscovery EventType = -10 - Sys_Event_Gin_Event EventType = -11 - Sys_Event_FrameTick EventType = -12 + Sys_Event_Kcp EventType = -6 + Sys_Event_Node_Conn_Event EventType = -7 + Sys_Event_Nats_Conn_Event EventType = -8 + Sys_Event_DiscoverService EventType = -9 + Sys_Event_Retire EventType = -10 + Sys_Event_EtcdDiscovery EventType = -11 + Sys_Event_Gin_Event EventType = -12 + Sys_Event_FrameTick EventType = -13 Sys_Event_User_Define EventType = 1 ) diff --git a/network/conn.go b/network/conn.go index 1f7aaf6..7d438a3 100644 --- a/network/conn.go +++ b/network/conn.go @@ -1,7 +1,12 @@ package network import ( + "errors" + "github.com/duanhf2012/origin/v2/log" "net" + "sync" + "sync/atomic" + "time" ) type Conn interface { @@ -13,3 +18,157 @@ type Conn interface { Destroy() ReleaseReadMsg(byteBuff []byte) } + +type ConnSet map[net.Conn]struct{} + +type NetConn struct { + sync.Mutex + conn net.Conn + writeChan chan []byte + closeFlag int32 + msgParser *MsgParser +} + +func freeChannel(conn *NetConn) { + for len(conn.writeChan) > 0 { + byteBuff := <-conn.writeChan + if byteBuff != nil { + conn.ReleaseReadMsg(byteBuff) + } + } +} + +func newNetConn(conn net.Conn, pendingWriteNum int, msgParser *MsgParser, writeDeadline time.Duration) *NetConn { + netConn := new(NetConn) + netConn.conn = conn + netConn.writeChan = make(chan []byte, pendingWriteNum) + netConn.msgParser = msgParser + go func() { + for b := range netConn.writeChan { + if b == nil { + break + } + + conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + _, err := conn.Write(b) + netConn.msgParser.ReleaseBytes(b) + + if err != nil { + break + } + } + conn.Close() + netConn.Lock() + freeChannel(netConn) + atomic.StoreInt32(&netConn.closeFlag, 1) + netConn.Unlock() + }() + + return netConn +} + +func (netConn *NetConn) doDestroy() { + netConn.conn.Close() + + if atomic.LoadInt32(&netConn.closeFlag) == 0 { + close(netConn.writeChan) + atomic.StoreInt32(&netConn.closeFlag, 1) + } +} + +func (netConn *NetConn) Destroy() { + netConn.Lock() + defer netConn.Unlock() + + netConn.doDestroy() +} + +func (netConn *NetConn) Close() { + netConn.Lock() + defer netConn.Unlock() + if atomic.LoadInt32(&netConn.closeFlag) == 1 { + return + } + + netConn.doWrite(nil) + atomic.StoreInt32(&netConn.closeFlag, 1) +} + +func (netConn *NetConn) GetRemoteIp() string { + return netConn.conn.RemoteAddr().String() +} + +func (netConn *NetConn) doWrite(b []byte) error { + if len(netConn.writeChan) == cap(netConn.writeChan) { + netConn.ReleaseReadMsg(b) + log.Error("close conn: channel full") + netConn.doDestroy() + return errors.New("close conn: channel full") + } + + netConn.writeChan <- b + return nil +} + +// b must not be modified by the others goroutines +func (netConn *NetConn) Write(b []byte) error { + netConn.Lock() + defer netConn.Unlock() + if atomic.LoadInt32(&netConn.closeFlag) == 1 || b == nil { + netConn.ReleaseReadMsg(b) + return errors.New("conn is close") + } + + return netConn.doWrite(b) +} + +func (netConn *NetConn) Read(b []byte) (int, error) { + return netConn.conn.Read(b) +} + +func (netConn *NetConn) LocalAddr() net.Addr { + return netConn.conn.LocalAddr() +} + +func (netConn *NetConn) RemoteAddr() net.Addr { + return netConn.conn.RemoteAddr() +} + +func (netConn *NetConn) ReadMsg() ([]byte, error) { + return netConn.msgParser.Read(netConn) +} + +func (netConn *NetConn) GetRecyclerReaderBytes() func(data []byte) { + return netConn.msgParser.GetRecyclerReaderBytes() +} + +func (netConn *NetConn) ReleaseReadMsg(byteBuff []byte) { + netConn.msgParser.ReleaseBytes(byteBuff) +} + +func (netConn *NetConn) WriteMsg(args ...[]byte) error { + if atomic.LoadInt32(&netConn.closeFlag) == 1 { + return errors.New("conn is close") + } + return netConn.msgParser.Write(netConn.conn, args...) +} + +func (netConn *NetConn) WriteRawMsg(args []byte) error { + if atomic.LoadInt32(&netConn.closeFlag) == 1 { + return errors.New("conn is close") + } + + return netConn.Write(args) +} + +func (netConn *NetConn) IsConnected() bool { + return atomic.LoadInt32(&netConn.closeFlag) == 0 +} + +func (netConn *NetConn) SetReadDeadline(d time.Duration) { + netConn.conn.SetReadDeadline(time.Now().Add(d)) +} + +func (netConn *NetConn) SetWriteDeadline(d time.Duration) { + netConn.conn.SetWriteDeadline(time.Now().Add(d)) +} diff --git a/network/kcp_client.go b/network/kcp_client.go new file mode 100644 index 0000000..13be511 --- /dev/null +++ b/network/kcp_client.go @@ -0,0 +1,161 @@ +package network + +import ( + "github.com/duanhf2012/origin/v2/log" + kcp "github.com/xtaci/kcp-go/v5" + "net" + "sync" + "time" +) + +type KCPClient struct { + sync.Mutex + Addr string + ConnNum int + ConnectInterval time.Duration + PendingWriteNum int + ReadDeadline time.Duration + WriteDeadline time.Duration + AutoReconnect bool + NewAgent func(conn *NetConn) Agent + cons ConnSet + wg sync.WaitGroup + closeFlag bool + + // msg parser + MsgParser +} + +func (client *KCPClient) Start() { + client.init() + + for i := 0; i < client.ConnNum; i++ { + client.wg.Add(1) + go client.connect() + } +} + +func (client *KCPClient) init() { + client.Lock() + defer client.Unlock() + + if client.ConnNum <= 0 { + client.ConnNum = 1 + log.Info("invalid ConnNum", log.Int("reset", client.ConnNum)) + } + if client.ConnectInterval <= 0 { + client.ConnectInterval = 3 * time.Second + log.Info("invalid ConnectInterval", log.Duration("reset", client.ConnectInterval)) + } + if client.PendingWriteNum <= 0 { + client.PendingWriteNum = 1000 + log.Info("invalid PendingWriteNum", log.Int("reset", client.PendingWriteNum)) + } + if client.ReadDeadline == 0 { + client.ReadDeadline = 15 * time.Second + log.Info("invalid ReadDeadline", log.Int64("reset", int64(client.ReadDeadline.Seconds()))) + } + if client.WriteDeadline == 0 { + client.WriteDeadline = 15 * time.Second + log.Info("invalid WriteDeadline", log.Int64("reset", int64(client.WriteDeadline.Seconds()))) + } + if client.NewAgent == nil { + log.Fatal("NewAgent must not be nil") + } + if client.cons != nil { + log.Fatal("client is running") + } + + if client.MinMsgLen == 0 { + client.MinMsgLen = Default_MinMsgLen + } + if client.MaxMsgLen == 0 { + client.MaxMsgLen = Default_MaxMsgLen + } + if client.LenMsgLen == 0 { + client.LenMsgLen = Default_LenMsgLen + } + maxMsgLen := client.MsgParser.getMaxMsgLen(client.LenMsgLen) + if client.MaxMsgLen > maxMsgLen { + client.MaxMsgLen = maxMsgLen + log.Info("invalid MaxMsgLen", log.Uint32("reset", maxMsgLen)) + } + + client.cons = make(ConnSet) + client.closeFlag = false + client.MsgParser.Init() +} + +func (client *KCPClient) GetCloseFlag() bool { + client.Lock() + defer client.Unlock() + + return client.closeFlag +} + +func (client *KCPClient) dial() net.Conn { + for { + conn, err := kcp.DialWithOptions(client.Addr, nil, 10, 3) + if client.closeFlag { + return conn + } else if err == nil && conn != nil { + conn.SetNoDelay(1, 10, 2, 1) + conn.SetDSCP(46) + conn.SetStreamMode(true) + conn.SetWindowSize(1024, 1024) + return conn + } + + log.Warning("connect error ", log.String("error", err.Error()), log.String("Addr", client.Addr)) + time.Sleep(client.ConnectInterval) + continue + } +} + +func (client *KCPClient) connect() { + defer client.wg.Done() +reconnect: + conn := client.dial() + if conn == nil { + return + } + + client.Lock() + if client.closeFlag { + client.Unlock() + conn.Close() + return + } + client.cons[conn] = struct{}{} + client.Unlock() + + netConn := newNetConn(conn, client.PendingWriteNum, &client.MsgParser, client.WriteDeadline) + agent := client.NewAgent(netConn) + agent.Run() + + // cleanup + netConn.Close() + client.Lock() + delete(client.cons, conn) + client.Unlock() + agent.OnClose() + + if client.AutoReconnect { + time.Sleep(client.ConnectInterval) + goto reconnect + } +} + +func (client *KCPClient) Close(waitDone bool) { + client.Lock() + client.closeFlag = true + for conn := range client.cons { + conn.Close() + } + client.cons = nil + client.Unlock() + + if waitDone == true { + client.wg.Wait() + } +} diff --git a/network/kcp_server.go b/network/kcp_server.go new file mode 100644 index 0000000..171e2ec --- /dev/null +++ b/network/kcp_server.go @@ -0,0 +1,257 @@ +package network + +import ( + "github.com/duanhf2012/origin/v2/log" + "github.com/duanhf2012/origin/v2/network/processor" + kcp "github.com/xtaci/kcp-go/v5" + "sync" + "time" +) + +type KCPServer struct { + NewAgent func(Conn) Agent + + kcpCfg *KcpCfg + blockCrypt kcp.BlockCrypt + + process processor.IRawProcessor + msgParser MsgParser + conns ConnSet + mutexConns sync.Mutex + wgLn sync.WaitGroup + wgConns sync.WaitGroup + + listener *kcp.Listener +} + +/* + NoDelayCfg + +普通模式: ikcp_nodelay(kcp, 0, 40, 0, 0); +极速模式: ikcp_nodelay(kcp, 1, 10, 2, 1); +*/ +type NoDelayCfg struct { + NoDelay int // 是否启用 nodelay模式,0不启用;1启用 + IntervalMill int // 协议内部工作的 interval,单位毫秒,比如 10ms或者 20ms + Resend int // 快速重传模式,默认0关闭,可以设置2(2次ACK跨越将会直接重传) + CongestionControl int // 是否关闭流控,默认是0代表不关闭,1代表关闭 +} + +const ( + DefaultNoDelay = 1 + DefaultIntervalMill = 10 + DefaultResend = 2 + DefaultCongestionControl = 1 + + DefaultMtu = 1400 + DefaultSndWndSize = 4096 + DefaultRcvWndSize = 4096 + DefaultStreamMode = true + DefaultDSCP = 46 + DefaultDataShards = 10 + DefaultParityShards = 0 + + DefaultReadDeadlineMill = 15 * time.Second + DefaultWriteDeadlineMill = 15 * time.Second + + DefaultMaxConnNum = 20000 +) + +type KcpCfg struct { + ListenAddr string // 监听地址 + MaxConnNum int //最大连接数 + NoDelay *NoDelayCfg + + Mtu *int // mtu大小 + SndWndSize *int // 发送窗口大小,默认1024 + RcvWndSize *int // 接收窗口大小,默认1024 + ReadDeadlineMill *time.Duration // 读超时毫秒 + WriteDeadlineMill *time.Duration // 写超时毫秒 + StreamMode *bool // 是否打开流模式,默认true + DSCP *int // 差分服务代码点,默认46 + ReadBuffSize *int // 读Buff大小,默认 + WriteBuffSize *int // 写Buff大小 + + // 用于 FEC(前向纠错)的数据分片和校验分片数量,,默认10,0 + DataShards *int + ParityShards *int + + // 包体内容 + + LittleEndian bool //是否小端序 + LenMsgLen int //消息头占用byte数量,只能是1byte,2byte,4byte。如果是4byte,意味着消息最大可以是math.MaxUint32(4GB) + MinMsgLen uint32 //最小消息长度 + MaxMsgLen uint32 //最大消息长度,超过判定不合法,断开连接 + PendingWriteNum int //写channel最大消息数量 +} + +func (kp *KCPServer) Init(kcpCfg *KcpCfg) { + kp.kcpCfg = kcpCfg + kp.msgParser.Init() + kp.msgParser.LenMsgLen = kp.kcpCfg.LenMsgLen + kp.msgParser.MaxMsgLen = kp.kcpCfg.MaxMsgLen + kp.msgParser.MinMsgLen = kp.kcpCfg.MinMsgLen + kp.msgParser.LittleEndian = kp.kcpCfg.LittleEndian + + // setting default noDelay + if kp.kcpCfg.NoDelay == nil { + var noDelay NoDelayCfg + noDelay.NoDelay = DefaultNoDelay + noDelay.IntervalMill = DefaultIntervalMill + noDelay.Resend = DefaultResend + noDelay.CongestionControl = DefaultCongestionControl + kp.kcpCfg.NoDelay = &noDelay + } + + if kp.kcpCfg.Mtu == nil { + mtu := DefaultMtu + kp.kcpCfg.Mtu = &mtu + } + + if kp.kcpCfg.SndWndSize == nil { + sndWndSize := DefaultSndWndSize + kp.kcpCfg.SndWndSize = &sndWndSize + } + if kp.kcpCfg.RcvWndSize == nil { + rcvWndSize := DefaultRcvWndSize + kp.kcpCfg.RcvWndSize = &rcvWndSize + } + if kp.kcpCfg.ReadDeadlineMill == nil { + readDeadlineMill := DefaultReadDeadlineMill + kp.kcpCfg.ReadDeadlineMill = &readDeadlineMill + } else { + *kp.kcpCfg.ReadDeadlineMill *= time.Millisecond + } + if kp.kcpCfg.WriteDeadlineMill == nil { + writeDeadlineMill := DefaultWriteDeadlineMill + kp.kcpCfg.WriteDeadlineMill = &writeDeadlineMill + } else { + *kp.kcpCfg.WriteDeadlineMill *= time.Millisecond + } + if kp.kcpCfg.StreamMode == nil { + streamMode := DefaultStreamMode + kp.kcpCfg.StreamMode = &streamMode + } + if kp.kcpCfg.DataShards == nil { + dataShards := DefaultDataShards + kp.kcpCfg.DataShards = &dataShards + } + if kp.kcpCfg.ParityShards == nil { + parityShards := DefaultParityShards + kp.kcpCfg.ParityShards = &parityShards + } + if kp.kcpCfg.DSCP == nil { + dss := DefaultDSCP + kp.kcpCfg.DSCP = &dss + } + + if kp.kcpCfg.MaxConnNum == 0 { + kp.kcpCfg.MaxConnNum = DefaultMaxConnNum + } + + kp.conns = make(ConnSet, 2048) + kp.msgParser.Init() + return +} + +func (kp *KCPServer) Start() error { + listener, err := kcp.ListenWithOptions(kp.kcpCfg.ListenAddr, kp.blockCrypt, *kp.kcpCfg.DataShards, *kp.kcpCfg.ParityShards) + if err != nil { + return err + } + + if kp.kcpCfg.ReadBuffSize != nil { + err = listener.SetReadBuffer(*kp.kcpCfg.ReadBuffSize) + if err != nil { + return err + } + } + if kp.kcpCfg.WriteBuffSize != nil { + err = listener.SetWriteBuffer(*kp.kcpCfg.WriteBuffSize) + if err != nil { + return err + } + } + err = listener.SetDSCP(*kp.kcpCfg.DSCP) + if err != nil { + return err + } + + kp.listener = listener + + kp.wgLn.Add(1) + go func() { + defer kp.wgLn.Done() + for kp.run(listener) { + } + }() + + return nil +} + +func (kp *KCPServer) initSession(session *kcp.UDPSession) { + session.SetStreamMode(*kp.kcpCfg.StreamMode) + session.SetWindowSize(*kp.kcpCfg.SndWndSize, *kp.kcpCfg.RcvWndSize) + session.SetNoDelay(kp.kcpCfg.NoDelay.NoDelay, kp.kcpCfg.NoDelay.IntervalMill, kp.kcpCfg.NoDelay.Resend, kp.kcpCfg.NoDelay.CongestionControl) + session.SetDSCP(*kp.kcpCfg.DSCP) + session.SetMtu(*kp.kcpCfg.Mtu) + session.SetACKNoDelay(false) + + //session.SetWriteDeadline(time.Now().Add(*kp.kcpCfg.WriteDeadlineMill)) +} + +func (kp *KCPServer) run(listener *kcp.Listener) bool { + conn, err := listener.Accept() + if err != nil { + log.Error("accept error", log.String("ListenAddr", kp.kcpCfg.ListenAddr), log.ErrorAttr("err", err)) + return false + } + + kp.mutexConns.Lock() + if len(kp.conns) >= kp.kcpCfg.MaxConnNum { + kp.mutexConns.Unlock() + conn.Close() + log.Warning("too many connections") + return true + } + kp.conns[conn] = struct{}{} + kp.mutexConns.Unlock() + + if kp.kcpCfg.ReadBuffSize != nil { + conn.(*kcp.UDPSession).SetReadBuffer(*kp.kcpCfg.ReadBuffSize) + } + if kp.kcpCfg.WriteBuffSize != nil { + conn.(*kcp.UDPSession).SetWriteBuffer(*kp.kcpCfg.WriteBuffSize) + } + kp.initSession(conn.(*kcp.UDPSession)) + + netConn := newNetConn(conn, kp.kcpCfg.PendingWriteNum, &kp.msgParser, *kp.kcpCfg.WriteDeadlineMill) + agent := kp.NewAgent(netConn) + kp.wgConns.Add(1) + go func() { + agent.Run() + // cleanup + conn.Close() + kp.mutexConns.Lock() + delete(kp.conns, conn) + kp.mutexConns.Unlock() + agent.OnClose() + + kp.wgConns.Done() + }() + + return true +} + +func (kp *KCPServer) Close() { + kp.listener.Close() + kp.wgLn.Wait() + + kp.mutexConns.Lock() + for conn := range kp.conns { + conn.Close() + } + kp.conns = nil + kp.mutexConns.Unlock() + kp.wgConns.Wait() +} diff --git a/network/tcp_client.go b/network/tcp_client.go index fb8df9d..fde1084 100644 --- a/network/tcp_client.go +++ b/network/tcp_client.go @@ -1,5 +1,5 @@ -package network +package network import ( "github.com/duanhf2012/origin/v2/log" "net" @@ -16,7 +16,7 @@ type TCPClient struct { ReadDeadline time.Duration WriteDeadline time.Duration AutoReconnect bool - NewAgent func(*TCPConn) Agent + NewAgent func(conn *NetConn) Agent cons ConnSet wg sync.WaitGroup closeFlag bool @@ -82,7 +82,7 @@ func (client *TCPClient) init() { client.cons = make(ConnSet) client.closeFlag = false - client.MsgParser.init() + client.MsgParser.Init() } func (client *TCPClient) GetCloseFlag() bool{ @@ -126,7 +126,7 @@ reconnect: client.cons[conn] = struct{}{} client.Unlock() - tcpConn := newTCPConn(conn, client.PendingWriteNum, &client.MsgParser,client.WriteDeadline) + tcpConn := newNetConn(conn, client.PendingWriteNum, &client.MsgParser,client.WriteDeadline) agent := client.NewAgent(tcpConn) agent.Run() diff --git a/network/tcp_conn.go b/network/tcp_conn.go deleted file mode 100644 index f97fa24..0000000 --- a/network/tcp_conn.go +++ /dev/null @@ -1,168 +0,0 @@ -package network - -import ( - "errors" - "github.com/duanhf2012/origin/v2/log" - "net" - "sync" - "sync/atomic" - "time" -) - -type ConnSet map[net.Conn]struct{} - -type TCPConn struct { - sync.Mutex - conn net.Conn - writeChan chan []byte - closeFlag int32 - msgParser *MsgParser -} - -func freeChannel(conn *TCPConn) { - for len(conn.writeChan) > 0 { - byteBuff := <-conn.writeChan - if byteBuff != nil { - conn.ReleaseReadMsg(byteBuff) - } - } -} - -func newTCPConn(conn net.Conn, pendingWriteNum int, msgParser *MsgParser, writeDeadline time.Duration) *TCPConn { - tcpConn := new(TCPConn) - tcpConn.conn = conn - tcpConn.writeChan = make(chan []byte, pendingWriteNum) - tcpConn.msgParser = msgParser - go func() { - for b := range tcpConn.writeChan { - if b == nil { - break - } - - conn.SetWriteDeadline(time.Now().Add(writeDeadline)) - _, err := conn.Write(b) - tcpConn.msgParser.ReleaseBytes(b) - - if err != nil { - break - } - } - conn.Close() - tcpConn.Lock() - freeChannel(tcpConn) - atomic.StoreInt32(&tcpConn.closeFlag, 1) - tcpConn.Unlock() - }() - - return tcpConn -} - -func (tcpConn *TCPConn) doDestroy() { - tcpConn.conn.(*net.TCPConn).SetLinger(0) - tcpConn.conn.Close() - - if atomic.LoadInt32(&tcpConn.closeFlag) == 0 { - close(tcpConn.writeChan) - atomic.StoreInt32(&tcpConn.closeFlag, 1) - } -} - -func (tcpConn *TCPConn) Destroy() { - tcpConn.Lock() - defer tcpConn.Unlock() - - tcpConn.doDestroy() -} - -func (tcpConn *TCPConn) Close() { - tcpConn.Lock() - defer tcpConn.Unlock() - if atomic.LoadInt32(&tcpConn.closeFlag) == 1 { - return - } - - tcpConn.doWrite(nil) - atomic.StoreInt32(&tcpConn.closeFlag, 1) -} - -func (tcpConn *TCPConn) GetRemoteIp() string { - return tcpConn.conn.RemoteAddr().String() -} - -func (tcpConn *TCPConn) doWrite(b []byte) error { - if len(tcpConn.writeChan) == cap(tcpConn.writeChan) { - tcpConn.ReleaseReadMsg(b) - log.Error("close conn: channel full") - tcpConn.doDestroy() - return errors.New("close conn: channel full") - } - - tcpConn.writeChan <- b - return nil -} - -// b must not be modified by the others goroutines -func (tcpConn *TCPConn) Write(b []byte) error { - tcpConn.Lock() - defer tcpConn.Unlock() - if atomic.LoadInt32(&tcpConn.closeFlag) == 1 || b == nil { - tcpConn.ReleaseReadMsg(b) - return errors.New("conn is close") - } - - return tcpConn.doWrite(b) -} - -func (tcpConn *TCPConn) Read(b []byte) (int, error) { - return tcpConn.conn.Read(b) -} - -func (tcpConn *TCPConn) LocalAddr() net.Addr { - return tcpConn.conn.LocalAddr() -} - -func (tcpConn *TCPConn) RemoteAddr() net.Addr { - return tcpConn.conn.RemoteAddr() -} - -func (tcpConn *TCPConn) ReadMsg() ([]byte, error) { - return tcpConn.msgParser.Read(tcpConn) -} - -func (tcpConn *TCPConn) GetRecyclerReaderBytes() func(data []byte) { - bytePool := tcpConn.msgParser.IBytesMemPool - return func(data []byte) { - bytePool.ReleaseBytes(data) - } -} - -func (tcpConn *TCPConn) ReleaseReadMsg(byteBuff []byte) { - tcpConn.msgParser.ReleaseBytes(byteBuff) -} - -func (tcpConn *TCPConn) WriteMsg(args ...[]byte) error { - if atomic.LoadInt32(&tcpConn.closeFlag) == 1 { - return errors.New("conn is close") - } - return tcpConn.msgParser.Write(tcpConn, args...) -} - -func (tcpConn *TCPConn) WriteRawMsg(args []byte) error { - if atomic.LoadInt32(&tcpConn.closeFlag) == 1 { - return errors.New("conn is close") - } - - return tcpConn.Write(args) -} - -func (tcpConn *TCPConn) IsConnected() bool { - return atomic.LoadInt32(&tcpConn.closeFlag) == 0 -} - -func (tcpConn *TCPConn) SetReadDeadline(d time.Duration) { - tcpConn.conn.SetReadDeadline(time.Now().Add(d)) -} - -func (tcpConn *TCPConn) SetWriteDeadline(d time.Duration) { - tcpConn.conn.SetWriteDeadline(time.Now().Add(d)) -} diff --git a/network/tcp_msg.go b/network/tcp_msg.go index 12ab697..f0184b7 100644 --- a/network/tcp_msg.go +++ b/network/tcp_msg.go @@ -33,17 +33,17 @@ func (p *MsgParser) getMaxMsgLen(lenMsgLen int) uint32 { } } -func (p *MsgParser) init() { +func (p *MsgParser) Init() { p.IBytesMemPool = bytespool.NewMemAreaPool() } // goroutine safe -func (p *MsgParser) Read(conn *TCPConn) ([]byte, error) { +func (p *MsgParser) Read(r io.Reader) ([]byte, error) { var b [4]byte bufMsgLen := b[:p.LenMsgLen] // read len - if _, err := io.ReadFull(conn, bufMsgLen); err != nil { + if _, err := io.ReadFull(r, bufMsgLen); err != nil { return nil, err } @@ -75,7 +75,7 @@ func (p *MsgParser) Read(conn *TCPConn) ([]byte, error) { // data msgData := p.MakeBytes(int(msgLen)) - if _, err := io.ReadFull(conn, msgData[:msgLen]); err != nil { + if _, err := io.ReadFull(r, msgData[:msgLen]); err != nil { p.ReleaseBytes(msgData) return nil, err } @@ -84,7 +84,7 @@ func (p *MsgParser) Read(conn *TCPConn) ([]byte, error) { } // goroutine safe -func (p *MsgParser) Write(conn *TCPConn, args ...[]byte) error { +func (p *MsgParser) Write(conn io.Writer, args ...[]byte) error { // get len var msgLen uint32 for i := 0; i < len(args); i++ { @@ -129,3 +129,9 @@ func (p *MsgParser) Write(conn *TCPConn, args ...[]byte) error { return nil } + +func (p *MsgParser) GetRecyclerReaderBytes() func(data []byte) { + return func(data []byte) { + p.IBytesMemPool.ReleaseBytes(data) + } +} \ No newline at end of file diff --git a/network/tcp_server.go b/network/tcp_server.go index d068ec5..ac9fae4 100644 --- a/network/tcp_server.go +++ b/network/tcp_server.go @@ -27,7 +27,7 @@ type TCPServer struct { ReadDeadline time.Duration WriteDeadline time.Duration - NewAgent func(*TCPConn) Agent + NewAgent func(conn Conn) Agent ln net.Listener conns ConnSet mutexConns sync.Mutex @@ -42,6 +42,8 @@ func (server *TCPServer) Start() error { if err != nil { return err } + + server.wgLn.Add(1) go server.run() return nil @@ -99,8 +101,8 @@ func (server *TCPServer) init() error { } server.ln = ln - server.conns = make(ConnSet) - server.MsgParser.init() + server.conns = make(ConnSet, 2048) + server.MsgParser.Init() return nil } @@ -114,7 +116,6 @@ func (server *TCPServer) GetNetMemPool() bytespool.IBytesMemPool { } func (server *TCPServer) run() { - server.wgLn.Add(1) defer server.wgLn.Done() var tempDelay time.Duration @@ -137,6 +138,7 @@ func (server *TCPServer) run() { return } + conn.(*net.TCPConn).SetLinger(0) conn.(*net.TCPConn).SetNoDelay(true) tempDelay = 0 @@ -152,7 +154,7 @@ func (server *TCPServer) run() { server.mutexConns.Unlock() server.wgConns.Add(1) - tcpConn := newTCPConn(conn, server.PendingWriteNum, &server.MsgParser, server.WriteDeadline) + tcpConn := newNetConn(conn, server.PendingWriteNum, &server.MsgParser, server.WriteDeadline) agent := server.NewAgent(tcpConn) go func() { diff --git a/network/ws_client.go b/network/ws_client.go index 8b76717..312b127 100644 --- a/network/ws_client.go +++ b/network/ws_client.go @@ -108,7 +108,7 @@ reconnect: client.cons[conn] = struct{}{} client.Unlock() - wsConn := newWSConn(conn, client.PendingWriteNum, client.MaxMsgLen,client.MessageType) + wsConn := newWSConn(conn,nil, client.PendingWriteNum, client.MaxMsgLen,client.MessageType) agent := client.NewAgent(wsConn) agent.Run() diff --git a/network/ws_conn.go b/network/ws_conn.go index 9f290ba..a45c5c6 100644 --- a/network/ws_conn.go +++ b/network/ws_conn.go @@ -5,6 +5,7 @@ import ( "github.com/duanhf2012/origin/v2/log" "github.com/gorilla/websocket" "net" + "net/http" "sync" ) @@ -16,13 +17,15 @@ type WSConn struct { writeChan chan []byte maxMsgLen uint32 closeFlag bool + header http.Header } -func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32, messageType int) *WSConn { +func newWSConn(conn *websocket.Conn, header http.Header, pendingWriteNum int, maxMsgLen uint32, messageType int) *WSConn { wsConn := new(WSConn) wsConn.conn = conn wsConn.writeChan = make(chan []byte, pendingWriteNum) wsConn.maxMsgLen = maxMsgLen + wsConn.header = header go func() { for b := range wsConn.writeChan { @@ -46,7 +49,6 @@ func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32, mess } func (wsConn *WSConn) doDestroy() { - wsConn.conn.UnderlyingConn().(*net.TCPConn).SetLinger(0) wsConn.conn.Close() if !wsConn.closeFlag { @@ -83,6 +85,10 @@ func (wsConn *WSConn) doWrite(b []byte) { wsConn.writeChan <- b } +func (wsConn *WSConn) GetHeader() http.Header { + return wsConn.header +} + func (wsConn *WSConn) LocalAddr() net.Addr { return wsConn.conn.LocalAddr() } diff --git a/network/ws_server.go b/network/ws_server.go index 394eb2a..0fe65e7 100644 --- a/network/ws_server.go +++ b/network/ws_server.go @@ -2,6 +2,7 @@ package network import ( "crypto/tls" + "errors" "github.com/duanhf2012/origin/v2/log" "github.com/gorilla/websocket" "net" @@ -47,7 +48,7 @@ func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } conn, err := handler.upgrader.Upgrade(w, r, nil) if err != nil { - log.Error("upgrade fail",log.String("error",err.Error())) + log.Error("upgrade fail", log.String("error", err.Error())) return } conn.SetReadLimit(int64(handler.maxMsgLen)) @@ -73,7 +74,9 @@ func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { handler.conns[conn] = struct{}{} handler.mutexConns.Unlock() - wsConn := newWSConn(conn, handler.pendingWriteNum, handler.maxMsgLen, handler.messageType) + conn.UnderlyingConn().(*net.TCPConn).SetLinger(0) + conn.UnderlyingConn().(*net.TCPConn).SetNoDelay(true) + wsConn := newWSConn(conn, r.Header, handler.pendingWriteNum, handler.maxMsgLen, handler.messageType) agent := handler.newAgent(wsConn) agent.Run() @@ -92,10 +95,11 @@ func (server *WSServer) SetMessageType(messageType int) { } } -func (server *WSServer) Start() { +func (server *WSServer) Start() error { ln, err := net.Listen("tcp", server.Addr) if err != nil { - log.Fatal("WSServer Listen fail",log.String("error", err.Error())) + log.Error("WSServer Listen fail", log.String("error", err.Error())) + return err } if server.MaxConnNum <= 0 { @@ -115,18 +119,19 @@ func (server *WSServer) Start() { log.Info("invalid HTTPTimeout", log.Duration("reset", server.HTTPTimeout)) } if server.NewAgent == nil { - log.Fatal("NewAgent must not be nil") + log.Error("NewAgent must not be nil") + return errors.New("NewAgent must not be nil") } if server.CertFile != "" || server.KeyFile != "" { config := &tls.Config{} config.NextProtos = []string{"http/1.1"} - var err error config.Certificates = make([]tls.Certificate, 1) config.Certificates[0], err = tls.LoadX509KeyPair(server.CertFile, server.KeyFile) if err != nil { - log.Fatal("LoadX509KeyPair fail",log.String("error", err.Error())) + log.Error("LoadX509KeyPair fail", log.String("error", err.Error())) + return err } ln = tls.NewListener(ln, config) @@ -139,7 +144,7 @@ func (server *WSServer) Start() { maxMsgLen: server.MaxMsgLen, newAgent: server.NewAgent, conns: make(WebsocketConnSet), - messageType:server.messageType, + messageType: server.messageType, upgrader: websocket.Upgrader{ HandshakeTimeout: server.HTTPTimeout, CheckOrigin: func(_ *http.Request) bool { return true }, @@ -155,6 +160,7 @@ func (server *WSServer) Start() { } go httpServer.Serve(ln) + return nil } func (server *WSServer) Close() { diff --git a/node/node.go b/node/node.go index b23a8ef..ef24a1b 100644 --- a/node/node.go +++ b/node/node.go @@ -382,6 +382,11 @@ func startNode(args interface{}) error { return nil } +type templateServicePoint[T any] interface { + *T + service.IService +} + func Setup(s ...service.IService) { for _, sv := range s { sv.OnSetup(sv) @@ -389,12 +394,19 @@ func Setup(s ...service.IService) { } } -func SetupTemplate(fs ...func() service.IService) { +func SetupTemplateFunc(fs ...func() service.IService) { for _, f := range fs { preSetupTemplateService = append(preSetupTemplateService, f) } } +func SetupTemplate[T any,P templateServicePoint[T]]() { + SetupTemplateFunc(func() service.IService{ + var t T + return P(&t) + }) +} + func GetService(serviceName string) service.IService { return service.GetService(serviceName) } diff --git a/rpc/client.go b/rpc/client.go index 891d009..e349466 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -29,7 +29,7 @@ type IWriter interface { } type IRealClient interface { - SetConn(conn *network.TCPConn) + SetConn(conn *network.NetConn) Close(waitDone bool) AsyncCall(NodeId string, timeout time.Duration, rpcHandler IRpcHandler, serviceMethod string, callback reflect.Value, args interface{}, replyParam interface{}, cancelable bool) (CancelRpc, error) @@ -52,7 +52,7 @@ type Client struct { IRealClient } -func (client *Client) NewClientAgent(conn *network.TCPConn) network.Agent { +func (client *Client) NewClientAgent(conn *network.NetConn) network.Agent { client.SetConn(conn) return client diff --git a/rpc/lclient.go b/rpc/lclient.go index 5c0dbd1..2cce955 100644 --- a/rpc/lclient.go +++ b/rpc/lclient.go @@ -31,7 +31,7 @@ func (lc *LClient) IsConnected() bool { return true } -func (lc *LClient) SetConn(conn *network.TCPConn) { +func (lc *LClient) SetConn(conn *network.NetConn) { } func (lc *LClient) Close(waitDone bool) { diff --git a/rpc/natsclient.go b/rpc/natsclient.go index e7030d6..f4d5609 100644 --- a/rpc/natsclient.go +++ b/rpc/natsclient.go @@ -29,7 +29,7 @@ func (nc *NatsClient) onSubscribe(msg *nats.Msg) { nc.client.processRpcResponse(msg.Data) } -func (nc *NatsClient) SetConn(conn *network.TCPConn) { +func (nc *NatsClient) SetConn(conn *network.NetConn) { } func (nc *NatsClient) Close(waitDone bool) { diff --git a/rpc/rclient.go b/rpc/rclient.go index ff85a95..43fe727 100644 --- a/rpc/rclient.go +++ b/rpc/rclient.go @@ -15,7 +15,7 @@ import ( type RClient struct { selfClient *Client network.TCPClient - conn *network.TCPConn + conn *network.NetConn notifyEventFun NotifyEventFun } @@ -27,7 +27,7 @@ func (rc *RClient) IsConnected() bool { return rc.conn != nil && rc.conn.IsConnected() == true } -func (rc *RClient) GetConn() *network.TCPConn { +func (rc *RClient) GetConn() *network.NetConn { rc.Lock() conn := rc.conn rc.Unlock() @@ -35,7 +35,7 @@ func (rc *RClient) GetConn() *network.TCPConn { return conn } -func (rc *RClient) SetConn(conn *network.TCPConn) { +func (rc *RClient) SetConn(conn *network.NetConn) { rc.Lock() rc.conn = conn rc.Unlock() diff --git a/rpc/server.go b/rpc/server.go index 0055960..595c093 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -211,7 +211,7 @@ func (agent *RpcAgent) Destroy() { agent.conn.Destroy() } -func (server *Server) NewAgent(c *network.TCPConn) network.Agent { +func (server *Server) NewAgent(c network.Conn) network.Agent { agent := &RpcAgent{conn: c, rpcServer: server} return agent diff --git a/sysservice/tcpservice/tcpservice.go b/sysservice/tcpservice/tcpservice.go index a21af28..64daac1 100644 --- a/sysservice/tcpservice/tcpservice.go +++ b/sysservice/tcpservice/tcpservice.go @@ -95,9 +95,7 @@ func (tcpService *TcpService) OnInit() error { tcpService.mapClient = make(map[string]*Client, tcpService.tcpServer.MaxConnNum) tcpService.tcpServer.NewAgent = tcpService.NewClient - tcpService.tcpServer.Start() - - return nil + return tcpService.tcpServer.Start() } func (tcpService *TcpService) TcpEventHandler(ev event.IEvent) { diff --git a/sysservice/wsservice/wsservice.go b/sysservice/wsservice/wsservice.go index 8dfc704..d68f8ef 100644 --- a/sysservice/wsservice/wsservice.go +++ b/sysservice/wsservice/wsservice.go @@ -80,8 +80,7 @@ func (ws *WSService) OnInit() error { ws.mapClient = make(map[string]*WSClient, ws.wsServer.MaxConnNum) ws.wsServer.NewAgent = ws.NewWSClient - ws.wsServer.Start() - return nil + return ws.wsServer.Start() } func (ws *WSService) SetMessageType(messageType int) {