新增websocket

This commit is contained in:
duanhf2012
2020-04-21 14:40:31 +08:00
parent f74f3a812e
commit e590f0dce9
13 changed files with 631 additions and 6 deletions

View File

@@ -1 +0,0 @@
package processor

View File

@@ -1 +0,0 @@
package processor

View File

@@ -1 +0,0 @@
package processor

132
network/ws_client.go Normal file
View File

@@ -0,0 +1,132 @@
package network
import (
"github.com/duanhf2012/origin/log"
"github.com/gorilla/websocket"
"sync"
"time"
)
type WSClient struct {
sync.Mutex
Addr string
ConnNum int
ConnectInterval time.Duration
PendingWriteNum int
MaxMsgLen uint32
HandshakeTimeout time.Duration
AutoReconnect bool
NewAgent func(*WSConn) Agent
dialer websocket.Dialer
conns WebsocketConnSet
wg sync.WaitGroup
closeFlag bool
}
func (client *WSClient) Start() {
client.init()
for i := 0; i < client.ConnNum; i++ {
client.wg.Add(1)
go client.connect()
}
}
func (client *WSClient) init() {
client.Lock()
defer client.Unlock()
if client.ConnNum <= 0 {
client.ConnNum = 1
log.Release("invalid ConnNum, reset to %v", client.ConnNum)
}
if client.ConnectInterval <= 0 {
client.ConnectInterval = 3 * time.Second
log.Release("invalid ConnectInterval, reset to %v", client.ConnectInterval)
}
if client.PendingWriteNum <= 0 {
client.PendingWriteNum = 100
log.Release("invalid PendingWriteNum, reset to %v", client.PendingWriteNum)
}
if client.MaxMsgLen <= 0 {
client.MaxMsgLen = 4096
log.Release("invalid MaxMsgLen, reset to %v", client.MaxMsgLen)
}
if client.HandshakeTimeout <= 0 {
client.HandshakeTimeout = 10 * time.Second
log.Release("invalid HandshakeTimeout, reset to %v", client.HandshakeTimeout)
}
if client.NewAgent == nil {
log.Fatal("NewAgent must not be nil")
}
if client.conns != nil {
log.Fatal("client is running")
}
client.conns = make(WebsocketConnSet)
client.closeFlag = false
client.dialer = websocket.Dialer{
HandshakeTimeout: client.HandshakeTimeout,
}
}
func (client *WSClient) dial() *websocket.Conn {
for {
conn, _, err := client.dialer.Dial(client.Addr, nil)
if err == nil || client.closeFlag {
return conn
}
log.Release("connect to %v error: %v", client.Addr, err)
time.Sleep(client.ConnectInterval)
continue
}
}
func (client *WSClient) connect() {
defer client.wg.Done()
reconnect:
conn := client.dial()
if conn == nil {
return
}
conn.SetReadLimit(int64(client.MaxMsgLen))
client.Lock()
if client.closeFlag {
client.Unlock()
conn.Close()
return
}
client.conns[conn] = struct{}{}
client.Unlock()
wsConn := newWSConn(conn, client.PendingWriteNum, client.MaxMsgLen)
agent := client.NewAgent(wsConn)
agent.Run()
// cleanup
wsConn.Close()
client.Lock()
delete(client.conns, conn)
client.Unlock()
agent.OnClose()
if client.AutoReconnect {
time.Sleep(client.ConnectInterval)
goto reconnect
}
}
func (client *WSClient) Close() {
client.Lock()
client.closeFlag = true
for conn := range client.conns {
conn.Close()
}
client.conns = nil
client.Unlock()
client.wg.Wait()
}

138
network/ws_conn.go Normal file
View File

@@ -0,0 +1,138 @@
package network
import (
"errors"
"github.com/duanhf2012/origin/log"
"github.com/gorilla/websocket"
"net"
"sync"
)
type WebsocketConnSet map[*websocket.Conn]struct{}
type WSConn struct {
sync.Mutex
conn *websocket.Conn
writeChan chan []byte
maxMsgLen uint32
closeFlag bool
}
func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32) *WSConn {
wsConn := new(WSConn)
wsConn.conn = conn
wsConn.writeChan = make(chan []byte, pendingWriteNum)
wsConn.maxMsgLen = maxMsgLen
go func() {
for b := range wsConn.writeChan {
if b == nil {
break
}
err := conn.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
break
}
}
conn.Close()
wsConn.Lock()
wsConn.closeFlag = true
wsConn.Unlock()
}()
return wsConn
}
func (wsConn *WSConn) doDestroy() {
wsConn.conn.UnderlyingConn().(*net.TCPConn).SetLinger(0)
wsConn.conn.Close()
if !wsConn.closeFlag {
close(wsConn.writeChan)
wsConn.closeFlag = true
}
}
func (wsConn *WSConn) Destroy() {
wsConn.Lock()
defer wsConn.Unlock()
wsConn.doDestroy()
}
func (wsConn *WSConn) Close() {
wsConn.Lock()
defer wsConn.Unlock()
if wsConn.closeFlag {
return
}
wsConn.doWrite(nil)
wsConn.closeFlag = true
}
func (wsConn *WSConn) doWrite(b []byte) {
if len(wsConn.writeChan) == cap(wsConn.writeChan) {
log.Debug("close conn: channel full")
wsConn.doDestroy()
return
}
wsConn.writeChan <- b
}
func (wsConn *WSConn) LocalAddr() net.Addr {
return wsConn.conn.LocalAddr()
}
func (wsConn *WSConn) RemoteAddr() net.Addr {
return wsConn.conn.RemoteAddr()
}
// goroutine not safe
func (wsConn *WSConn) ReadMsg() ([]byte, error) {
_, b, err := wsConn.conn.ReadMessage()
return b, err
}
// args must not be modified by the others goroutines
func (wsConn *WSConn) WriteMsg(args ...[]byte) error {
wsConn.Lock()
defer wsConn.Unlock()
if wsConn.closeFlag {
return nil
}
// get len
var msgLen uint32
for i := 0; i < len(args); i++ {
msgLen += uint32(len(args[i]))
}
// check len
if msgLen > wsConn.maxMsgLen {
return errors.New("message too long")
} else if msgLen < 1 {
return errors.New("message too short")
}
// don't copy
if len(args) == 1 {
wsConn.doWrite(args[0])
return nil
}
// merge the args
msg := make([]byte, msgLen)
l := 0
for i := 0; i < len(args); i++ {
copy(msg[l:], args[i])
l += len(args[i])
}
wsConn.doWrite(msg)
return nil
}

154
network/ws_server.go Normal file
View File

@@ -0,0 +1,154 @@
package network
import (
"crypto/tls"
"github.com/duanhf2012/origin/log"
"github.com/gorilla/websocket"
"net"
"net/http"
"sync"
"time"
)
type WSServer struct {
Addr string
MaxConnNum int
PendingWriteNum int
MaxMsgLen uint32
HTTPTimeout time.Duration
CertFile string
KeyFile string
NewAgent func(*WSConn) Agent
ln net.Listener
handler *WSHandler
}
type WSHandler struct {
maxConnNum int
pendingWriteNum int
maxMsgLen uint32
newAgent func(*WSConn) Agent
upgrader websocket.Upgrader
conns WebsocketConnSet
mutexConns sync.Mutex
wg sync.WaitGroup
}
func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
return
}
conn, err := handler.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Debug("upgrade error: %v", err)
return
}
conn.SetReadLimit(int64(handler.maxMsgLen))
handler.wg.Add(1)
defer handler.wg.Done()
handler.mutexConns.Lock()
if handler.conns == nil {
handler.mutexConns.Unlock()
conn.Close()
return
}
if len(handler.conns) >= handler.maxConnNum {
handler.mutexConns.Unlock()
conn.Close()
log.Debug("too many connections")
return
}
handler.conns[conn] = struct{}{}
handler.mutexConns.Unlock()
wsConn := newWSConn(conn, handler.pendingWriteNum, handler.maxMsgLen)
agent := handler.newAgent(wsConn)
agent.Run()
// cleanup
wsConn.Close()
handler.mutexConns.Lock()
delete(handler.conns, conn)
handler.mutexConns.Unlock()
agent.OnClose()
}
func (server *WSServer) Start() {
ln, err := net.Listen("tcp", server.Addr)
if err != nil {
log.Fatal("%v", err)
}
if server.MaxConnNum <= 0 {
server.MaxConnNum = 100
log.Release("invalid MaxConnNum, reset to %v", server.MaxConnNum)
}
if server.PendingWriteNum <= 0 {
server.PendingWriteNum = 100
log.Release("invalid PendingWriteNum, reset to %v", server.PendingWriteNum)
}
if server.MaxMsgLen <= 0 {
server.MaxMsgLen = 4096
log.Release("invalid MaxMsgLen, reset to %v", server.MaxMsgLen)
}
if server.HTTPTimeout <= 0 {
server.HTTPTimeout = 10 * time.Second
log.Release("invalid HTTPTimeout, reset to %v", server.HTTPTimeout)
}
if server.NewAgent == nil {
log.Fatal("NewAgent must not be nil")
}
if server.CertFile != "" || server.KeyFile != "" {
config := &tls.Config{}
config.NextProtos = []string{"http/1.1"}
var err error
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(server.CertFile, server.KeyFile)
if err != nil {
log.Fatal("%v", err)
}
ln = tls.NewListener(ln, config)
}
server.ln = ln
server.handler = &WSHandler{
maxConnNum: server.MaxConnNum,
pendingWriteNum: server.PendingWriteNum,
maxMsgLen: server.MaxMsgLen,
newAgent: server.NewAgent,
conns: make(WebsocketConnSet),
upgrader: websocket.Upgrader{
HandshakeTimeout: server.HTTPTimeout,
CheckOrigin: func(_ *http.Request) bool { return true },
},
}
httpServer := &http.Server{
Addr: server.Addr,
Handler: server.handler,
ReadTimeout: server.HTTPTimeout,
WriteTimeout: server.HTTPTimeout,
MaxHeaderBytes: 1024,
}
go httpServer.Serve(ln)
}
func (server *WSServer) Close() {
server.ln.Close()
server.handler.mutexConns.Lock()
for conn := range server.handler.conns {
conn.Close()
}
server.handler.conns = nil
server.handler.mutexConns.Unlock()
server.handler.wg.Wait()
}