mirror of
https://github.com/duanhf2012/origin.git
synced 2026-02-15 08:14:46 +08:00
新增websocket
This commit is contained in:
@@ -1 +0,0 @@
|
||||
package processor
|
||||
@@ -1 +0,0 @@
|
||||
package processor
|
||||
@@ -1 +0,0 @@
|
||||
package processor
|
||||
132
network/ws_client.go
Normal file
132
network/ws_client.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"github.com/duanhf2012/origin/log"
|
||||
"github.com/gorilla/websocket"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WSClient struct {
|
||||
sync.Mutex
|
||||
Addr string
|
||||
ConnNum int
|
||||
ConnectInterval time.Duration
|
||||
PendingWriteNum int
|
||||
MaxMsgLen uint32
|
||||
HandshakeTimeout time.Duration
|
||||
AutoReconnect bool
|
||||
NewAgent func(*WSConn) Agent
|
||||
dialer websocket.Dialer
|
||||
conns WebsocketConnSet
|
||||
wg sync.WaitGroup
|
||||
closeFlag bool
|
||||
}
|
||||
|
||||
func (client *WSClient) Start() {
|
||||
client.init()
|
||||
|
||||
for i := 0; i < client.ConnNum; i++ {
|
||||
client.wg.Add(1)
|
||||
go client.connect()
|
||||
}
|
||||
}
|
||||
|
||||
func (client *WSClient) init() {
|
||||
client.Lock()
|
||||
defer client.Unlock()
|
||||
|
||||
if client.ConnNum <= 0 {
|
||||
client.ConnNum = 1
|
||||
log.Release("invalid ConnNum, reset to %v", client.ConnNum)
|
||||
}
|
||||
if client.ConnectInterval <= 0 {
|
||||
client.ConnectInterval = 3 * time.Second
|
||||
log.Release("invalid ConnectInterval, reset to %v", client.ConnectInterval)
|
||||
}
|
||||
if client.PendingWriteNum <= 0 {
|
||||
client.PendingWriteNum = 100
|
||||
log.Release("invalid PendingWriteNum, reset to %v", client.PendingWriteNum)
|
||||
}
|
||||
if client.MaxMsgLen <= 0 {
|
||||
client.MaxMsgLen = 4096
|
||||
log.Release("invalid MaxMsgLen, reset to %v", client.MaxMsgLen)
|
||||
}
|
||||
if client.HandshakeTimeout <= 0 {
|
||||
client.HandshakeTimeout = 10 * time.Second
|
||||
log.Release("invalid HandshakeTimeout, reset to %v", client.HandshakeTimeout)
|
||||
}
|
||||
if client.NewAgent == nil {
|
||||
log.Fatal("NewAgent must not be nil")
|
||||
}
|
||||
if client.conns != nil {
|
||||
log.Fatal("client is running")
|
||||
}
|
||||
|
||||
client.conns = make(WebsocketConnSet)
|
||||
client.closeFlag = false
|
||||
client.dialer = websocket.Dialer{
|
||||
HandshakeTimeout: client.HandshakeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (client *WSClient) dial() *websocket.Conn {
|
||||
for {
|
||||
conn, _, err := client.dialer.Dial(client.Addr, nil)
|
||||
if err == nil || client.closeFlag {
|
||||
return conn
|
||||
}
|
||||
|
||||
log.Release("connect to %v error: %v", client.Addr, err)
|
||||
time.Sleep(client.ConnectInterval)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
func (client *WSClient) connect() {
|
||||
defer client.wg.Done()
|
||||
|
||||
reconnect:
|
||||
conn := client.dial()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
conn.SetReadLimit(int64(client.MaxMsgLen))
|
||||
|
||||
client.Lock()
|
||||
if client.closeFlag {
|
||||
client.Unlock()
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
client.conns[conn] = struct{}{}
|
||||
client.Unlock()
|
||||
|
||||
wsConn := newWSConn(conn, client.PendingWriteNum, client.MaxMsgLen)
|
||||
agent := client.NewAgent(wsConn)
|
||||
agent.Run()
|
||||
|
||||
// cleanup
|
||||
wsConn.Close()
|
||||
client.Lock()
|
||||
delete(client.conns, conn)
|
||||
client.Unlock()
|
||||
agent.OnClose()
|
||||
|
||||
if client.AutoReconnect {
|
||||
time.Sleep(client.ConnectInterval)
|
||||
goto reconnect
|
||||
}
|
||||
}
|
||||
|
||||
func (client *WSClient) Close() {
|
||||
client.Lock()
|
||||
client.closeFlag = true
|
||||
for conn := range client.conns {
|
||||
conn.Close()
|
||||
}
|
||||
client.conns = nil
|
||||
client.Unlock()
|
||||
|
||||
client.wg.Wait()
|
||||
}
|
||||
138
network/ws_conn.go
Normal file
138
network/ws_conn.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/duanhf2012/origin/log"
|
||||
"github.com/gorilla/websocket"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type WebsocketConnSet map[*websocket.Conn]struct{}
|
||||
|
||||
type WSConn struct {
|
||||
sync.Mutex
|
||||
conn *websocket.Conn
|
||||
writeChan chan []byte
|
||||
maxMsgLen uint32
|
||||
closeFlag bool
|
||||
}
|
||||
|
||||
func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32) *WSConn {
|
||||
wsConn := new(WSConn)
|
||||
wsConn.conn = conn
|
||||
wsConn.writeChan = make(chan []byte, pendingWriteNum)
|
||||
wsConn.maxMsgLen = maxMsgLen
|
||||
|
||||
go func() {
|
||||
for b := range wsConn.writeChan {
|
||||
if b == nil {
|
||||
break
|
||||
}
|
||||
|
||||
err := conn.WriteMessage(websocket.BinaryMessage, b)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
wsConn.Lock()
|
||||
wsConn.closeFlag = true
|
||||
wsConn.Unlock()
|
||||
}()
|
||||
|
||||
return wsConn
|
||||
}
|
||||
|
||||
func (wsConn *WSConn) doDestroy() {
|
||||
wsConn.conn.UnderlyingConn().(*net.TCPConn).SetLinger(0)
|
||||
wsConn.conn.Close()
|
||||
|
||||
if !wsConn.closeFlag {
|
||||
close(wsConn.writeChan)
|
||||
wsConn.closeFlag = true
|
||||
}
|
||||
}
|
||||
|
||||
func (wsConn *WSConn) Destroy() {
|
||||
wsConn.Lock()
|
||||
defer wsConn.Unlock()
|
||||
|
||||
wsConn.doDestroy()
|
||||
}
|
||||
|
||||
func (wsConn *WSConn) Close() {
|
||||
wsConn.Lock()
|
||||
defer wsConn.Unlock()
|
||||
if wsConn.closeFlag {
|
||||
return
|
||||
}
|
||||
|
||||
wsConn.doWrite(nil)
|
||||
wsConn.closeFlag = true
|
||||
}
|
||||
|
||||
func (wsConn *WSConn) doWrite(b []byte) {
|
||||
if len(wsConn.writeChan) == cap(wsConn.writeChan) {
|
||||
log.Debug("close conn: channel full")
|
||||
wsConn.doDestroy()
|
||||
return
|
||||
}
|
||||
|
||||
wsConn.writeChan <- b
|
||||
}
|
||||
|
||||
func (wsConn *WSConn) LocalAddr() net.Addr {
|
||||
return wsConn.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (wsConn *WSConn) RemoteAddr() net.Addr {
|
||||
return wsConn.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// goroutine not safe
|
||||
func (wsConn *WSConn) ReadMsg() ([]byte, error) {
|
||||
_, b, err := wsConn.conn.ReadMessage()
|
||||
return b, err
|
||||
}
|
||||
|
||||
// args must not be modified by the others goroutines
|
||||
func (wsConn *WSConn) WriteMsg(args ...[]byte) error {
|
||||
wsConn.Lock()
|
||||
defer wsConn.Unlock()
|
||||
if wsConn.closeFlag {
|
||||
return nil
|
||||
}
|
||||
|
||||
// get len
|
||||
var msgLen uint32
|
||||
for i := 0; i < len(args); i++ {
|
||||
msgLen += uint32(len(args[i]))
|
||||
}
|
||||
|
||||
// check len
|
||||
if msgLen > wsConn.maxMsgLen {
|
||||
return errors.New("message too long")
|
||||
} else if msgLen < 1 {
|
||||
return errors.New("message too short")
|
||||
}
|
||||
|
||||
// don't copy
|
||||
if len(args) == 1 {
|
||||
wsConn.doWrite(args[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
// merge the args
|
||||
msg := make([]byte, msgLen)
|
||||
l := 0
|
||||
for i := 0; i < len(args); i++ {
|
||||
copy(msg[l:], args[i])
|
||||
l += len(args[i])
|
||||
}
|
||||
|
||||
wsConn.doWrite(msg)
|
||||
|
||||
return nil
|
||||
}
|
||||
154
network/ws_server.go
Normal file
154
network/ws_server.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/duanhf2012/origin/log"
|
||||
"github.com/gorilla/websocket"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WSServer struct {
|
||||
Addr string
|
||||
MaxConnNum int
|
||||
PendingWriteNum int
|
||||
MaxMsgLen uint32
|
||||
HTTPTimeout time.Duration
|
||||
CertFile string
|
||||
KeyFile string
|
||||
NewAgent func(*WSConn) Agent
|
||||
ln net.Listener
|
||||
handler *WSHandler
|
||||
}
|
||||
|
||||
type WSHandler struct {
|
||||
maxConnNum int
|
||||
pendingWriteNum int
|
||||
maxMsgLen uint32
|
||||
newAgent func(*WSConn) Agent
|
||||
upgrader websocket.Upgrader
|
||||
conns WebsocketConnSet
|
||||
mutexConns sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
http.Error(w, "Method not allowed", 405)
|
||||
return
|
||||
}
|
||||
conn, err := handler.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Debug("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetReadLimit(int64(handler.maxMsgLen))
|
||||
|
||||
handler.wg.Add(1)
|
||||
defer handler.wg.Done()
|
||||
|
||||
handler.mutexConns.Lock()
|
||||
if handler.conns == nil {
|
||||
handler.mutexConns.Unlock()
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if len(handler.conns) >= handler.maxConnNum {
|
||||
handler.mutexConns.Unlock()
|
||||
conn.Close()
|
||||
log.Debug("too many connections")
|
||||
return
|
||||
}
|
||||
handler.conns[conn] = struct{}{}
|
||||
handler.mutexConns.Unlock()
|
||||
|
||||
wsConn := newWSConn(conn, handler.pendingWriteNum, handler.maxMsgLen)
|
||||
agent := handler.newAgent(wsConn)
|
||||
agent.Run()
|
||||
|
||||
// cleanup
|
||||
wsConn.Close()
|
||||
handler.mutexConns.Lock()
|
||||
delete(handler.conns, conn)
|
||||
handler.mutexConns.Unlock()
|
||||
agent.OnClose()
|
||||
}
|
||||
|
||||
func (server *WSServer) Start() {
|
||||
ln, err := net.Listen("tcp", server.Addr)
|
||||
if err != nil {
|
||||
log.Fatal("%v", err)
|
||||
}
|
||||
|
||||
if server.MaxConnNum <= 0 {
|
||||
server.MaxConnNum = 100
|
||||
log.Release("invalid MaxConnNum, reset to %v", server.MaxConnNum)
|
||||
}
|
||||
if server.PendingWriteNum <= 0 {
|
||||
server.PendingWriteNum = 100
|
||||
log.Release("invalid PendingWriteNum, reset to %v", server.PendingWriteNum)
|
||||
}
|
||||
if server.MaxMsgLen <= 0 {
|
||||
server.MaxMsgLen = 4096
|
||||
log.Release("invalid MaxMsgLen, reset to %v", server.MaxMsgLen)
|
||||
}
|
||||
if server.HTTPTimeout <= 0 {
|
||||
server.HTTPTimeout = 10 * time.Second
|
||||
log.Release("invalid HTTPTimeout, reset to %v", server.HTTPTimeout)
|
||||
}
|
||||
if server.NewAgent == nil {
|
||||
log.Fatal("NewAgent must not be nil")
|
||||
}
|
||||
|
||||
if server.CertFile != "" || server.KeyFile != "" {
|
||||
config := &tls.Config{}
|
||||
config.NextProtos = []string{"http/1.1"}
|
||||
|
||||
var err error
|
||||
config.Certificates = make([]tls.Certificate, 1)
|
||||
config.Certificates[0], err = tls.LoadX509KeyPair(server.CertFile, server.KeyFile)
|
||||
if err != nil {
|
||||
log.Fatal("%v", err)
|
||||
}
|
||||
|
||||
ln = tls.NewListener(ln, config)
|
||||
}
|
||||
|
||||
server.ln = ln
|
||||
server.handler = &WSHandler{
|
||||
maxConnNum: server.MaxConnNum,
|
||||
pendingWriteNum: server.PendingWriteNum,
|
||||
maxMsgLen: server.MaxMsgLen,
|
||||
newAgent: server.NewAgent,
|
||||
conns: make(WebsocketConnSet),
|
||||
upgrader: websocket.Upgrader{
|
||||
HandshakeTimeout: server.HTTPTimeout,
|
||||
CheckOrigin: func(_ *http.Request) bool { return true },
|
||||
},
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: server.Addr,
|
||||
Handler: server.handler,
|
||||
ReadTimeout: server.HTTPTimeout,
|
||||
WriteTimeout: server.HTTPTimeout,
|
||||
MaxHeaderBytes: 1024,
|
||||
}
|
||||
|
||||
go httpServer.Serve(ln)
|
||||
}
|
||||
|
||||
func (server *WSServer) Close() {
|
||||
server.ln.Close()
|
||||
|
||||
server.handler.mutexConns.Lock()
|
||||
for conn := range server.handler.conns {
|
||||
conn.Close()
|
||||
}
|
||||
server.handler.conns = nil
|
||||
server.handler.mutexConns.Unlock()
|
||||
|
||||
server.handler.wg.Wait()
|
||||
}
|
||||
Reference in New Issue
Block a user