mirror of
https://github.com/duanhf2012/origin.git
synced 2026-02-03 22:45:13 +08:00
206 lines
4.9 KiB
Go
206 lines
4.9 KiB
Go
package network
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/duanhf2012/origin/v2/log"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
type WSServer struct {
|
|
Addr string
|
|
MaxConnNum int
|
|
PendingWriteNum int
|
|
MaxReadMsgLen uint32
|
|
MaxWriteMsgLen uint32
|
|
CertFile string
|
|
KeyFile string
|
|
NewAgent func(*WSConn) Agent
|
|
ln net.Listener
|
|
handler *WSHandler
|
|
messageType int
|
|
|
|
HandshakeTimeout time.Duration
|
|
ReadTimeout time.Duration
|
|
WriteTimeout time.Duration
|
|
}
|
|
|
|
type WSHandler struct {
|
|
maxConnNum int
|
|
pendingWriteNum int
|
|
maxReadMsgLen uint32
|
|
maxWriteMsgLen uint32
|
|
newAgent func(*WSConn) Agent
|
|
upgrader websocket.Upgrader
|
|
conns WebsocketConnSet
|
|
mutexConns sync.Mutex
|
|
wg sync.WaitGroup
|
|
messageType int
|
|
}
|
|
|
|
func (handler *WSHandler) SetMessageType(messageType int) {
|
|
handler.messageType = messageType
|
|
}
|
|
|
|
func (handler *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
http.Error(w, "Method not allowed", 405)
|
|
return
|
|
}
|
|
conn, err := handler.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Error("upgrade fail", log.String("error", err.Error()))
|
|
return
|
|
}
|
|
conn.SetReadLimit(int64(handler.maxReadMsgLen))
|
|
if handler.messageType == 0 {
|
|
handler.messageType = websocket.TextMessage
|
|
}
|
|
|
|
handler.wg.Add(1)
|
|
defer handler.wg.Done()
|
|
|
|
handler.mutexConns.Lock()
|
|
if handler.conns == nil {
|
|
handler.mutexConns.Unlock()
|
|
conn.Close()
|
|
return
|
|
}
|
|
if len(handler.conns) >= handler.maxConnNum {
|
|
handler.mutexConns.Unlock()
|
|
conn.Close()
|
|
log.Warn("too many connections")
|
|
return
|
|
}
|
|
handler.conns[conn] = struct{}{}
|
|
handler.mutexConns.Unlock()
|
|
c, ok := conn.NetConn().(*net.TCPConn)
|
|
if !ok {
|
|
tlsConn, ok := conn.NetConn().(*tls.Conn)
|
|
if !ok {
|
|
log.Error("conn error")
|
|
return
|
|
}
|
|
c, ok = tlsConn.NetConn().(*net.TCPConn)
|
|
if !ok {
|
|
log.Error("conn error")
|
|
return
|
|
}
|
|
}
|
|
|
|
c.SetLinger(0)
|
|
c.SetNoDelay(true)
|
|
wsConn := newWSConn(conn, r.Header, handler.pendingWriteNum, handler.maxWriteMsgLen, handler.messageType)
|
|
agent := handler.newAgent(wsConn)
|
|
agent.Run()
|
|
|
|
// cleanup
|
|
wsConn.Close()
|
|
handler.mutexConns.Lock()
|
|
delete(handler.conns, conn)
|
|
handler.mutexConns.Unlock()
|
|
agent.OnClose()
|
|
}
|
|
|
|
func (server *WSServer) SetMessageType(messageType int) {
|
|
server.messageType = messageType
|
|
if server.handler != nil {
|
|
server.handler.SetMessageType(messageType)
|
|
}
|
|
}
|
|
|
|
func (server *WSServer) Start() error {
|
|
ln, err := net.Listen("tcp", server.Addr)
|
|
if err != nil {
|
|
log.Error("WSServer Listen fail", log.String("error", err.Error()))
|
|
return err
|
|
}
|
|
if server.MaxConnNum <= 0 {
|
|
server.MaxConnNum = 100
|
|
log.Info("invalid MaxConnNum", log.Int("reset", server.MaxConnNum))
|
|
}
|
|
if server.PendingWriteNum <= 0 {
|
|
server.PendingWriteNum = 100
|
|
log.Info("invalid PendingWriteNum", log.Int("reset", server.PendingWriteNum))
|
|
}
|
|
if server.MaxReadMsgLen <= 0 {
|
|
server.MaxReadMsgLen = 4096
|
|
log.Info("invalid MaxReadMsgLen", log.Uint32("reset", server.MaxReadMsgLen))
|
|
}
|
|
if server.HandshakeTimeout <= 0 {
|
|
server.HandshakeTimeout = 15 * time.Second
|
|
log.Info("invalid HandshakeTimeout", log.Duration("reset", server.HandshakeTimeout))
|
|
}
|
|
if server.ReadTimeout <= 0 {
|
|
server.ReadTimeout = 15 * time.Second
|
|
log.Info("invalid ReadTimeout", log.Duration("reset", server.ReadTimeout))
|
|
}
|
|
if server.WriteTimeout <= 0 {
|
|
server.WriteTimeout = 15 * time.Second
|
|
log.Info("invalid WriteTimeout", log.Duration("reset", server.WriteTimeout))
|
|
}
|
|
|
|
if server.NewAgent == nil {
|
|
log.Error("NewAgent must not be nil")
|
|
return errors.New("NewAgent must not be nil")
|
|
}
|
|
|
|
if server.CertFile != "" || server.KeyFile != "" {
|
|
config := &tls.Config{}
|
|
config.NextProtos = []string{"http/1.1"}
|
|
|
|
config.Certificates = make([]tls.Certificate, 1)
|
|
config.Certificates[0], err = tls.LoadX509KeyPair(server.CertFile, server.KeyFile)
|
|
if err != nil {
|
|
log.Error("LoadX509KeyPair fail", log.String("error", err.Error()))
|
|
return err
|
|
}
|
|
|
|
ln = tls.NewListener(ln, config)
|
|
}
|
|
|
|
server.ln = ln
|
|
server.handler = &WSHandler{
|
|
maxConnNum: server.MaxConnNum,
|
|
pendingWriteNum: server.PendingWriteNum,
|
|
maxReadMsgLen: server.MaxReadMsgLen,
|
|
maxWriteMsgLen: server.MaxWriteMsgLen,
|
|
newAgent: server.NewAgent,
|
|
conns: make(WebsocketConnSet),
|
|
messageType: server.messageType,
|
|
upgrader: websocket.Upgrader{
|
|
HandshakeTimeout: server.HandshakeTimeout,
|
|
CheckOrigin: func(_ *http.Request) bool { return true },
|
|
},
|
|
}
|
|
|
|
httpServer := &http.Server{
|
|
Addr: server.Addr,
|
|
Handler: server.handler,
|
|
ReadTimeout: server.ReadTimeout,
|
|
WriteTimeout: server.WriteTimeout,
|
|
MaxHeaderBytes: 1024,
|
|
}
|
|
|
|
go httpServer.Serve(ln)
|
|
return nil
|
|
}
|
|
|
|
func (server *WSServer) Close() {
|
|
server.ln.Close()
|
|
|
|
server.handler.mutexConns.Lock()
|
|
for conn := range server.handler.conns {
|
|
conn.Close()
|
|
}
|
|
server.handler.conns = nil
|
|
server.handler.mutexConns.Unlock()
|
|
|
|
server.handler.wg.Wait()
|
|
}
|