Browse Source

websocket: update muxer for websocket

fatedier 6 years ago
parent
commit
7793f55545
4 changed files with 140 additions and 152 deletions
  1. 4 3
      models/config/client_common.go
  2. 16 37
      server/service.go
  3. 66 36
      utils/net/conn.go
  4. 54 76
      utils/net/websocket.go

+ 4 - 3
models/config/client_common.go

@@ -186,9 +186,10 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c
 	}
 
 	if tmpStr, ok = conf.Get("common", "protocol"); ok {
-		// Now it only support tcp and kcp.
-		if tmpStr != "kcp" && tmpStr != "websocket" {
-			tmpStr = "tcp"
+		// Now it only support tcp and kcp and websocket.
+		if tmpStr != "tcp" && tmpStr != "kcp" && tmpStr != "websocket" {
+			err = fmt.Errorf("Parse conf error: invalid protocol")
+			return
 		}
 		cfg.Protocol = tmpStr
 	}

+ 16 - 37
server/service.go

@@ -15,11 +15,11 @@
 package server
 
 import (
+	"bytes"
 	"fmt"
 	"io/ioutil"
 	"net"
 	"net/http"
-	"strings"
 	"time"
 
 	"github.com/fatedier/frp/assets"
@@ -139,6 +139,13 @@ func NewService() (svr *Service, err error) {
 		log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KcpBindPort)
 	}
 
+	// Listen for accepting connections from client using websocket protocol.
+	websocketPrefix := []byte("GET /%23frp")
+	websocketLn := svr.muxer.Listen(0, uint32(len(websocketPrefix)), func(data []byte) bool {
+		return bytes.Equal(data, websocketPrefix)
+	})
+	svr.websocketListener = frpNet.NewWebsocketListener(websocketLn)
+
 	// Create http vhost muxer.
 	if cfg.VhostHttpPort > 0 {
 		rp := vhost.NewHttpReverseProxy()
@@ -150,7 +157,9 @@ func NewService() (svr *Service, err error) {
 			Handler: rp,
 		}
 		var l net.Listener
-		if !httpMuxOn {
+		if httpMuxOn {
+			l = svr.muxer.ListenHttp(1)
+		} else {
 			l, err = net.Listen("tcp", address)
 			if err != nil {
 				err = fmt.Errorf("Create vhost http listener error, %v", err)
@@ -165,7 +174,7 @@ func NewService() (svr *Service, err error) {
 	if cfg.VhostHttpsPort > 0 {
 		var l net.Listener
 		if httpsMuxOn {
-			l = svr.muxer.ListenHttps(0)
+			l = svr.muxer.ListenHttps(1)
 		} else {
 			l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort))
 			if err != nil {
@@ -205,37 +214,6 @@ func NewService() (svr *Service, err error) {
 		log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort)
 	}
 
-	if !httpMuxOn {
-		svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), nil)
-		return
-	}
-
-	// server := &http.Server{}
-	if httpMuxOn {
-		rp := svr.httpReverseProxy
-		svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0),
-			func(w http.ResponseWriter, req *http.Request) bool {
-				domain := getHostFromAddr(req.Host)
-				location := req.URL.Path
-				headers := rp.GetHeaders(domain, location)
-				if headers == nil {
-					return true
-				}
-				rp.ServeHTTP(w, req)
-				return false
-			})
-	}
-
-	return
-}
-
-func getHostFromAddr(addr string) (host string) {
-	strs := strings.Split(addr, ":")
-	if len(strs) > 1 {
-		host = strs[0]
-	} else {
-		host = addr
-	}
 	return
 }
 
@@ -246,9 +224,9 @@ func (svr *Service) Run() {
 	if g.GlbServerCfg.KcpBindPort > 0 {
 		go svr.HandleListener(svr.kcpListener)
 	}
-	if svr.websocketListener != nil {
-		go svr.HandleListener(svr.websocketListener)
-	}
+
+	go svr.HandleListener(svr.websocketListener)
+
 	svr.HandleListener(svr.listener)
 }
 
@@ -260,6 +238,7 @@ func (svr *Service) HandleListener(l frpNet.Listener) {
 			log.Warn("Listener for incoming connections from client closed")
 			return
 		}
+
 		// Start a new goroutine for dealing connections.
 		go func(frpConn frpNet.Conn) {
 			dealFn := func(conn frpNet.Conn) {

+ 66 - 36
utils/net/conn.go

@@ -96,47 +96,34 @@ func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error {
 	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 }
 
-func ConnectServer(protocol string, addr string) (c Conn, err error) {
-	switch protocol {
-	case "tcp":
-		return ConnectTcpServer(addr)
-	case "kcp":
-		kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
-		if errRet != nil {
-			err = errRet
-			return
-		}
-		kcpConn.SetStreamMode(true)
-		kcpConn.SetWriteDelay(true)
-		kcpConn.SetNoDelay(1, 20, 2, 1)
-		kcpConn.SetWindowSize(128, 512)
-		kcpConn.SetMtu(1350)
-		kcpConn.SetACKNoDelay(false)
-		kcpConn.SetReadBuffer(4194304)
-		kcpConn.SetWriteBuffer(4194304)
-		c = WrapConn(kcpConn)
-		return
-	default:
-		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
+type CloseNotifyConn struct {
+	net.Conn
+	log.Logger
+
+	// 1 means closed
+	closeFlag int32
+
+	closeFn func()
+}
+
+// closeFn will be only called once
+func WrapCloseNotifyConn(c net.Conn, closeFn func()) Conn {
+	return &CloseNotifyConn{
+		Conn:    c,
+		Logger:  log.NewPrefixLogger(""),
+		closeFn: closeFn,
 	}
 }
 
-func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn, err error) {
-	switch protocol {
-	case "tcp":
-		var conn net.Conn
-		if conn, err = gnet.DialTcpByProxy(proxyUrl, addr); err != nil {
-			return
+func (cc *CloseNotifyConn) Close() (err error) {
+	pflag := atomic.SwapInt32(&cc.closeFlag, 1)
+	if pflag == 0 {
+		err = cc.Close()
+		if cc.closeFn != nil {
+			cc.closeFn()
 		}
-		return WrapConn(conn), nil
-	case "kcp":
-		// http proxy is not supported for kcp
-		return ConnectServer(protocol, addr)
-	case "websocket":
-		return ConnectWebsocketServer(addr)
-	default:
-		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
 	}
+	return
 }
 
 type StatsConn struct {
@@ -177,3 +164,46 @@ func (statsConn *StatsConn) Close() (err error) {
 	}
 	return
 }
+
+func ConnectServer(protocol string, addr string) (c Conn, err error) {
+	switch protocol {
+	case "tcp":
+		return ConnectTcpServer(addr)
+	case "kcp":
+		kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
+		if errRet != nil {
+			err = errRet
+			return
+		}
+		kcpConn.SetStreamMode(true)
+		kcpConn.SetWriteDelay(true)
+		kcpConn.SetNoDelay(1, 20, 2, 1)
+		kcpConn.SetWindowSize(128, 512)
+		kcpConn.SetMtu(1350)
+		kcpConn.SetACKNoDelay(false)
+		kcpConn.SetReadBuffer(4194304)
+		kcpConn.SetWriteBuffer(4194304)
+		c = WrapConn(kcpConn)
+		return
+	default:
+		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
+	}
+}
+
+func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn, err error) {
+	switch protocol {
+	case "tcp":
+		var conn net.Conn
+		if conn, err = gnet.DialTcpByProxy(proxyUrl, addr); err != nil {
+			return
+		}
+		return WrapConn(conn), nil
+	case "kcp":
+		// http proxy is not supported for kcp
+		return ConnectServer(protocol, addr)
+	case "websocket":
+		return ConnectWebsocketServer(addr)
+	default:
+		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
+	}
+}

+ 54 - 76
utils/net/websocket.go

@@ -1,127 +1,105 @@
 package net
 
 import (
+	"errors"
 	"fmt"
 	"net"
 	"net/http"
 	"net/url"
-	"sync/atomic"
 	"time"
 
 	"github.com/fatedier/frp/utils/log"
+
 	"golang.org/x/net/websocket"
 )
 
+var (
+	ErrWebsocketListenerClosed = errors.New("websocket listener closed")
+)
+
+const (
+	FrpWebsocketPath = "/#frp"
+)
+
 type WebsocketListener struct {
+	net.Addr
+	ln     net.Listener
+	accept chan Conn
 	log.Logger
+
 	server    *http.Server
 	httpMutex *http.ServeMux
-	connChan  chan *WebsocketConn
-	closeFlag bool
 }
 
-func NewWebsocketListener(ln net.Listener,
-	filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) {
-	l = &WebsocketListener{
-		httpMutex: http.NewServeMux(),
-		connChan:  make(chan *WebsocketConn),
-		Logger:    log.NewPrefixLogger(""),
+// ln: tcp listener for websocket connections
+func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
+	wl = &WebsocketListener{
+		Addr:   ln.Addr(),
+		accept: make(chan Conn),
+		Logger: log.NewPrefixLogger(""),
 	}
-	l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) {
-		conn := NewWebScoketConn(c)
-		l.connChan <- conn
-		conn.waitClose()
+
+	muxer := http.NewServeMux()
+	muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
+		notifyCh := make(chan struct{})
+		conn := WrapCloseNotifyConn(c, func() {
+			close(notifyCh)
+		})
+		wl.accept <- conn
+		<-notifyCh
 	}))
-	l.server = &http.Server{
-		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			if filter != nil && !filter(w, r) {
-				return
-			}
-			l.httpMutex.ServeHTTP(w, r)
-		}),
+
+	wl.server = &http.Server{
+		Addr:    ln.Addr().String(),
+		Handler: muxer,
 	}
-	ch := make(chan struct{})
-	go func() {
-		close(ch)
-		err = l.server.Serve(ln)
-	}()
-	<-ch
-	<-time.After(time.Millisecond)
+
+	go wl.server.Serve(ln)
 	return
 }
 
-func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) {
-	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
+func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
+	tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
 	if err != nil {
-		return
+		return nil, err
 	}
-	l, err = NewWebsocketListener(ln, nil)
-	return
+	l := NewWebsocketListener(tcpLn)
+	return l, nil
 }
 
 func (p *WebsocketListener) Accept() (Conn, error) {
-	c := <-p.connChan
+	c, ok := <-p.accept
+	if !ok {
+		return nil, ErrWebsocketListenerClosed
+	}
 	return c, nil
 }
 
 func (p *WebsocketListener) Close() error {
-	if !p.closeFlag {
-		p.closeFlag = true
-		p.server.Close()
-	}
-	return nil
-}
-
-type WebsocketConn struct {
-	net.Conn
-	log.Logger
-	closed int32
-	wait   chan struct{}
-}
-
-func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) {
-	c = &WebsocketConn{
-		Conn:   conn,
-		Logger: log.NewPrefixLogger(""),
-		wait:   make(chan struct{}),
-	}
-	return
+	return p.server.Close()
 }
 
-func (p *WebsocketConn) Close() error {
-	if atomic.SwapInt32(&p.closed, 1) == 1 {
-		return nil
-	}
-	close(p.wait)
-	return p.Conn.Close()
-}
-
-func (p *WebsocketConn) waitClose() {
-	<-p.wait
-}
-
-// ConnectWebsocketServer :
-// addr: ws://domain:port
-func ConnectWebsocketServer(addr string) (c Conn, err error) {
-	addr = "ws://" + addr
+// addr: domain:port
+func ConnectWebsocketServer(addr string) (Conn, error) {
+	addr = "ws://" + addr + FrpWebsocketPath
 	uri, err := url.Parse(addr)
 	if err != nil {
-		return
+		return nil, err
 	}
 
 	origin := "http://" + uri.Host
 	cfg, err := websocket.NewConfig(addr, origin)
 	if err != nil {
-		return
+		return nil, err
 	}
 	cfg.Dialer = &net.Dialer{
-		Timeout: time.Second * 10,
+		Timeout: 10 * time.Second,
 	}
 
 	conn, err := websocket.DialConfig(cfg)
 	if err != nil {
-		return
+		return nil, err
 	}
-	c = NewWebScoketConn(conn)
-	return
+	c := WrapConn(conn)
+	return c, nil
 }