// Mirror of https://github.com/duanhf2012/origin.git, synced 2026-02-04 06:54:45 +08:00.
// Original file: 138 lines, 2.4 KiB, Go.

package network
|
|
|
|
import (
|
|
"errors"
|
|
"github.com/duanhf2012/origin/log"
|
|
"github.com/gorilla/websocket"
|
|
"net"
|
|
"sync"
|
|
)
|
|
|
|
type WebsocketConnSet map[*websocket.Conn]struct{}
|
|
|
|
type WSConn struct {
|
|
sync.Mutex
|
|
conn *websocket.Conn
|
|
writeChan chan []byte
|
|
maxMsgLen uint32
|
|
closeFlag bool
|
|
}
|
|
|
|
func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32) *WSConn {
|
|
wsConn := new(WSConn)
|
|
wsConn.conn = conn
|
|
wsConn.writeChan = make(chan []byte, pendingWriteNum)
|
|
wsConn.maxMsgLen = maxMsgLen
|
|
|
|
go func() {
|
|
for b := range wsConn.writeChan {
|
|
if b == nil {
|
|
break
|
|
}
|
|
|
|
err := conn.WriteMessage(websocket.BinaryMessage, b)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
conn.Close()
|
|
wsConn.Lock()
|
|
wsConn.closeFlag = true
|
|
wsConn.Unlock()
|
|
}()
|
|
|
|
return wsConn
|
|
}
|
|
|
|
func (wsConn *WSConn) doDestroy() {
|
|
wsConn.conn.UnderlyingConn().(*net.TCPConn).SetLinger(0)
|
|
wsConn.conn.Close()
|
|
|
|
if !wsConn.closeFlag {
|
|
close(wsConn.writeChan)
|
|
wsConn.closeFlag = true
|
|
}
|
|
}
|
|
|
|
func (wsConn *WSConn) Destroy() {
|
|
wsConn.Lock()
|
|
defer wsConn.Unlock()
|
|
|
|
wsConn.doDestroy()
|
|
}
|
|
|
|
func (wsConn *WSConn) Close() {
|
|
wsConn.Lock()
|
|
defer wsConn.Unlock()
|
|
if wsConn.closeFlag {
|
|
return
|
|
}
|
|
|
|
wsConn.doWrite(nil)
|
|
wsConn.closeFlag = true
|
|
}
|
|
|
|
func (wsConn *WSConn) doWrite(b []byte) {
|
|
if len(wsConn.writeChan) == cap(wsConn.writeChan) {
|
|
log.Debug("close conn: channel full")
|
|
wsConn.doDestroy()
|
|
return
|
|
}
|
|
|
|
wsConn.writeChan <- b
|
|
}
|
|
|
|
func (wsConn *WSConn) LocalAddr() net.Addr {
|
|
return wsConn.conn.LocalAddr()
|
|
}
|
|
|
|
func (wsConn *WSConn) RemoteAddr() net.Addr {
|
|
return wsConn.conn.RemoteAddr()
|
|
}
|
|
|
|
// goroutine not safe
|
|
func (wsConn *WSConn) ReadMsg() ([]byte, error) {
|
|
_, b, err := wsConn.conn.ReadMessage()
|
|
return b, err
|
|
}
|
|
|
|
// args must not be modified by the others goroutines
|
|
func (wsConn *WSConn) WriteMsg(args ...[]byte) error {
|
|
wsConn.Lock()
|
|
defer wsConn.Unlock()
|
|
if wsConn.closeFlag {
|
|
return nil
|
|
}
|
|
|
|
// get len
|
|
var msgLen uint32
|
|
for i := 0; i < len(args); i++ {
|
|
msgLen += uint32(len(args[i]))
|
|
}
|
|
|
|
// check len
|
|
if msgLen > wsConn.maxMsgLen {
|
|
return errors.New("message too long")
|
|
} else if msgLen < 1 {
|
|
return errors.New("message too short")
|
|
}
|
|
|
|
// don't copy
|
|
if len(args) == 1 {
|
|
wsConn.doWrite(args[0])
|
|
return nil
|
|
}
|
|
|
|
// merge the args
|
|
msg := make([]byte, msgLen)
|
|
l := 0
|
|
for i := 0; i < len(args); i++ {
|
|
copy(msg[l:], args[i])
|
|
l += len(args[i])
|
|
}
|
|
|
|
wsConn.doWrite(msg)
|
|
return nil
|
|
}
|