Преглед на файлове

Merge pull request #875 from jettyu/jettyu-websocket

websocket protocol
fatedier преди 6 години
родител
ревизия
64136a3b3e
променени са 5 файла, в които са добавени 177 реда и са изтрити 15 реда
  1. 1 1
      conf/frpc_full.ini
  2. 1 1
      models/config/client_common.go
  3. 46 13
      server/service.go
  4. 2 0
      utils/net/conn.go
  5. 127 0
      utils/net/websocket.go

+ 1 - 1
conf/frpc_full.ini

@@ -41,7 +41,7 @@ user = your_name
 login_fail_exit = true
 
 # communication protocol used to connect to server
-# now it supports tcp and kcp, default is tcp
+# now it supports tcp and kcp and websocket, default is tcp
 protocol = tcp
 
 # specify a dns server, so frpc will use this instead of default one

+ 1 - 1
models/config/client_common.go

@@ -187,7 +187,7 @@ func UnmarshalClientConfFromIni(defaultCfg *ClientCommonConf, content string) (c
 
 	if tmpStr, ok = conf.Get("common", "protocol"); ok {
 		// Now it only support tcp and kcp.
-		if tmpStr != "kcp" {
+		if tmpStr != "kcp" && tmpStr != "websocket" {
 			tmpStr = "tcp"
 		}
 		cfg.Protocol = tmpStr

+ 46 - 13
server/service.go

@@ -19,6 +19,7 @@ import (
 	"io/ioutil"
 	"net"
 	"net/http"
+	"strings"
 	"time"
 
 	"github.com/fatedier/frp/assets"
@@ -53,6 +54,9 @@ type Service struct {
 	// Accept connections using kcp
 	kcpListener frpNet.Listener
 
+	// Accept connections using websocket
+	websocketListener frpNet.Listener
+
 	// For https proxies, route requests to different clients by hostname and other infomation
 	VhostHttpsMuxer *vhost.HttpsMuxer
 
@@ -109,9 +113,6 @@ func NewService() (svr *Service, err error) {
 		if cfg.BindPort == cfg.VhostHttpsPort {
 			httpsMuxOn = true
 		}
-		if httpMuxOn || httpsMuxOn {
-			svr.muxer = mux.NewMux()
-		}
 	}
 
 	// Listen for accepting connections from client.
@@ -120,10 +121,11 @@ func NewService() (svr *Service, err error) {
 		err = fmt.Errorf("Create server listener error, %v", err)
 		return
 	}
-	if svr.muxer != nil {
-		go svr.muxer.Serve(ln)
-		ln = svr.muxer.DefaultListener()
-	}
+
+	svr.muxer = mux.NewMux()
+	go svr.muxer.Serve(ln)
+	ln = svr.muxer.DefaultListener()
+
 	svr.listener = frpNet.WrapLogListener(ln)
 	log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort)
 
@@ -148,16 +150,14 @@ func NewService() (svr *Service, err error) {
 			Handler: rp,
 		}
 		var l net.Listener
-		if httpMuxOn {
-			l = svr.muxer.ListenHttp(0)
-		} else {
+		if !httpMuxOn {
 			l, err = net.Listen("tcp", address)
 			if err != nil {
 				err = fmt.Errorf("Create vhost http listener error, %v", err)
 				return
 			}
+			go server.Serve(l)
 		}
-		go server.Serve(l)
 		log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
 	}
 
@@ -204,6 +204,38 @@ func NewService() (svr *Service, err error) {
 		}
 		log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort)
 	}
+
+	if !httpMuxOn {
+		svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0), nil)
+		return
+	}
+
+	// server := &http.Server{}
+	if httpMuxOn {
+		rp := svr.httpReverseProxy
+		svr.websocketListener, err = frpNet.NewWebsocketListener(svr.muxer.ListenHttp(0),
+			func(w http.ResponseWriter, req *http.Request) bool {
+				domain := getHostFromAddr(req.Host)
+				location := req.URL.Path
+				headers := rp.GetHeaders(domain, location)
+				if headers == nil {
+					return true
+				}
+				rp.ServeHTTP(w, req)
+				return false
+			})
+	}
+
+	return
+}
+
+func getHostFromAddr(addr string) (host string) {
+	strs := strings.Split(addr, ":")
+	if len(strs) > 1 {
+		host = strs[0]
+	} else {
+		host = addr
+	}
 	return
 }
 
@@ -214,8 +246,10 @@ func (svr *Service) Run() {
 	if g.GlbServerCfg.KcpBindPort > 0 {
 		go svr.HandleListener(svr.kcpListener)
 	}
+	if svr.websocketListener != nil {
+		go svr.HandleListener(svr.websocketListener)
+	}
 	svr.HandleListener(svr.listener)
-
 }
 
 func (svr *Service) HandleListener(l frpNet.Listener) {
@@ -226,7 +260,6 @@ func (svr *Service) HandleListener(l frpNet.Listener) {
 			log.Warn("Listener for incoming connections from client closed")
 			return
 		}
-
 		// Start a new goroutine for dealing connections.
 		go func(frpConn frpNet.Conn) {
 			dealFn := func(conn frpNet.Conn) {

+ 2 - 0
utils/net/conn.go

@@ -132,6 +132,8 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
 	case "kcp":
 		// http proxy is not supported for kcp
 		return ConnectServer(protocol, addr)
+	case "websocket":
+		return ConnectWebsocketServer(addr)
 	default:
 		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
 	}

+ 127 - 0
utils/net/websocket.go

@@ -0,0 +1,127 @@
+package net
+
+import (
+	"fmt"
+	"net"
+	"net/http"
+	"net/url"
+	"sync/atomic"
+	"time"
+
+	"github.com/fatedier/frp/utils/log"
+	"golang.org/x/net/websocket"
+)
+
+type WebsocketListener struct {
+	log.Logger
+	server    *http.Server
+	httpMutex *http.ServeMux
+	connChan  chan *WebsocketConn
+	closeFlag bool
+}
+
+func NewWebsocketListener(ln net.Listener,
+	filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) {
+	l = &WebsocketListener{
+		httpMutex: http.NewServeMux(),
+		connChan:  make(chan *WebsocketConn),
+		Logger:    log.NewPrefixLogger(""),
+	}
+	l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) {
+		conn := NewWebScoketConn(c)
+		l.connChan <- conn
+		conn.waitClose()
+	}))
+	l.server = &http.Server{
+		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if filter != nil && !filter(w, r) {
+				return
+			}
+			l.httpMutex.ServeHTTP(w, r)
+		}),
+	}
+	ch := make(chan struct{})
+	go func() {
+		close(ch)
+		err = l.server.Serve(ln)
+	}()
+	<-ch
+	<-time.After(time.Millisecond)
+	return
+}
+
+func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) {
+	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
+	if err != nil {
+		return
+	}
+	l, err = NewWebsocketListener(ln, nil)
+	return
+}
+
+func (p *WebsocketListener) Accept() (Conn, error) {
+	c := <-p.connChan
+	return c, nil
+}
+
+func (p *WebsocketListener) Close() error {
+	if !p.closeFlag {
+		p.closeFlag = true
+		p.server.Close()
+	}
+	return nil
+}
+
+type WebsocketConn struct {
+	net.Conn
+	log.Logger
+	closed int32
+	wait   chan struct{}
+}
+
+func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) {
+	c = &WebsocketConn{
+		Conn:   conn,
+		Logger: log.NewPrefixLogger(""),
+		wait:   make(chan struct{}),
+	}
+	return
+}
+
+func (p *WebsocketConn) Close() error {
+	if atomic.SwapInt32(&p.closed, 1) == 1 {
+		return nil
+	}
+	close(p.wait)
+	return p.Conn.Close()
+}
+
+func (p *WebsocketConn) waitClose() {
+	<-p.wait
+}
+
+// ConnectWebsocketServer :
+// addr: ws://domain:port
+func ConnectWebsocketServer(addr string) (c Conn, err error) {
+	addr = "ws://" + addr
+	uri, err := url.Parse(addr)
+	if err != nil {
+		return
+	}
+
+	origin := "http://" + uri.Host
+	cfg, err := websocket.NewConfig(addr, origin)
+	if err != nil {
+		return
+	}
+	cfg.Dialer = &net.Dialer{
+		Timeout: time.Second * 10,
+	}
+
+	conn, err := websocket.DialConfig(cfg)
+	if err != nil {
+		return
+	}
+	c = NewWebScoketConn(conn)
+	return
+}