package network

import (
	"errors"
	"net"
	"sync"

	"github.com/duanhf2012/origin/log"
	"github.com/gorilla/websocket"
)

// WebsocketConnSet is a set of raw websocket connections.
type WebsocketConnSet map[*websocket.Conn]struct{}

// WSConn wraps a websocket connection with a buffered write queue that is
// drained by a dedicated writer goroutine. The embedded mutex guards
// closeFlag and serializes every producer of writeChan.
type WSConn struct {
	sync.Mutex
	conn      *websocket.Conn
	writeChan chan []byte // pending outgoing messages; a nil entry tells the writer to stop
	maxMsgLen uint32      // upper bound on the total size accepted by WriteMsg
	closeFlag bool        // true once the connection is closing or closed
}

// newWSConn wraps conn and starts the writer goroutine.
//
// pendingWriteNum is the capacity of the write queue, maxMsgLen is the
// largest message WriteMsg accepts, and messageType is passed to
// conn.WriteMessage for every queued message.
func newWSConn(conn *websocket.Conn, pendingWriteNum int, maxMsgLen uint32, messageType int) *WSConn {
	wsConn := new(WSConn)
	wsConn.conn = conn
	wsConn.writeChan = make(chan []byte, pendingWriteNum)
	wsConn.maxMsgLen = maxMsgLen

	go func() {
		// The loop ends when the channel is closed, a nil sentinel is
		// received, or a write fails.
		for b := range wsConn.writeChan {
			if b == nil {
				break
			}

			if err := conn.WriteMessage(messageType, b); err != nil {
				break
			}
		}

		conn.Close()
		wsConn.Lock()
		wsConn.closeFlag = true
		wsConn.Unlock()
	}()

	return wsConn
}

// doDestroy closes the underlying connection immediately and, if the
// connection was not already marked closed, closes the write queue.
// The caller must hold the lock.
func (wsConn *WSConn) doDestroy() {
	// SetLinger(0) makes the OS discard unsent data on close. It only exists
	// on *net.TCPConn, so use a checked assertion: other transports (for
	// example TLS connections) must not cause a panic here.
	if tcpConn, ok := wsConn.conn.UnderlyingConn().(*net.TCPConn); ok {
		_ = tcpConn.SetLinger(0) // best effort; the connection is closed right after
	}
	_ = wsConn.conn.Close() // best effort; nothing useful can be done on failure

	if !wsConn.closeFlag {
		close(wsConn.writeChan)
		wsConn.closeFlag = true
	}
}

// Destroy closes the connection immediately without waiting for queued
// messages to be written.
func (wsConn *WSConn) Destroy() {
	wsConn.Lock()
	defer wsConn.Unlock()

	wsConn.doDestroy()
}

// Close queues a nil sentinel so the writer goroutine closes the connection
// after the messages queued before it. It does nothing if the connection is
// already marked closed.
func (wsConn *WSConn) Close() {
	wsConn.Lock()
	defer wsConn.Unlock()
	if wsConn.closeFlag {
		return
	}

	wsConn.doWrite(nil)
	wsConn.closeFlag = true
}

// doWrite enqueues b for the writer goroutine. If the queue is full the
// connection is destroyed instead of blocking. The caller must hold the lock.
func (wsConn *WSConn) doWrite(b []byte) {
	if len(wsConn.writeChan) == cap(wsConn.writeChan) {
		log.Debug("close conn: channel full")
		wsConn.doDestroy()
		return
	}

	wsConn.writeChan <- b
}

// LocalAddr returns the local network address of the underlying connection.
func (wsConn *WSConn) LocalAddr() net.Addr {
	return wsConn.conn.LocalAddr()
}

// RemoteAddr returns the remote network address of the underlying connection.
func (wsConn *WSConn) RemoteAddr() net.Addr {
	return wsConn.conn.RemoteAddr()
}

// ReadMsg reads the next message from the connection and returns its payload.
// It is not safe for concurrent use by multiple goroutines.
func (wsConn *WSConn) ReadMsg() ([]byte, error) {
	_, b, err := wsConn.conn.ReadMessage()
	return b, err
}

// WriteMsg queues the concatenation of args as a single message.
//
// It returns an error if the total length is zero or exceeds maxMsgLen. If
// the connection is already marked closed the message is dropped and nil is
// returned. A single argument is queued without copying, so args must not be
// modified by other goroutines after the call.
func (wsConn *WSConn) WriteMsg(args ...[]byte) error {
	wsConn.Lock()
	defer wsConn.Unlock()
	if wsConn.closeFlag {
		return nil
	}

	// Sum in 64 bits so that very large inputs cannot wrap around and slip
	// past the maxMsgLen check.
	var msgLen uint64
	for _, arg := range args {
		msgLen += uint64(len(arg))
	}

	switch {
	case msgLen > uint64(wsConn.maxMsgLen):
		return errors.New("message too long")
	case msgLen < 1:
		return errors.New("message too short")
	}

	// A single buffer is queued as is, avoiding a copy.
	if len(args) == 1 {
		wsConn.doWrite(args[0])
		return nil
	}

	// Merge all parts into one contiguous message.
	msg := make([]byte, 0, msgLen)
	for _, arg := range args {
		msg = append(msg, arg...)
	}

	wsConn.doWrite(msg)
	return nil
}