Browse Source

newhttp: support websocket

fatedier 7 years ago
parent
commit
cf9193a429
2 changed files with 64 additions and 0 deletions
  1. 5 0
      utils/vhost/newhttp.go
  2. 59 0
      utils/vhost/reverseproxy.go

+ 5 - 0
utils/vhost/newhttp.go

@@ -79,6 +79,11 @@ func NewHttpReverseProxy() *HttpReverseProxy {
 				return rp.CreateConnection(host, url)
 			},
 		},
+		WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+			url := ctx.Value("url").(string)
+			host := getHostFromAddr(ctx.Value("host").(string))
+			return rp.CreateConnection(host, url)
+		},
 		BufferPool: newWrapPool(),
 		ErrorLog:   log.New(newWrapLogger(), "", 0),
 	}

+ 59 - 0
utils/vhost/reverseproxy.go

@@ -16,6 +16,8 @@ import (
 	"strings"
 	"sync"
 	"time"
+
+	frpIo "github.com/fatedier/frp/utils/io"
 )
 
 // onExitFlushLoop is a callback set by tests to detect the state of the
@@ -59,6 +61,8 @@ type ReverseProxy struct {
 	// modifies the Response from the backend.
 	// If it returns an error, the proxy returns a StatusBadGateway error.
 	ModifyResponse func(*http.Response) error
+
+	WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
 }
 
 // A BufferPool is an interface for getting and returning temporary
@@ -139,6 +143,48 @@ var hopHeaders = []string{
 }
 
 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	if IsWebsocketRequest(req) {
+		p.serveWebSocket(rw, req)
+	} else {
+		p.serveHTTP(rw, req)
+	}
+}
+
+func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
+	if p.WebSocketDialContext == nil {
+		rw.WriteHeader(500)
+		return
+	}
+
+	req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path))
+	req = req.WithContext(context.WithValue(req.Context(), "host", req.Host))
+
+	targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
+	if err != nil {
+		rw.WriteHeader(501)
+		return
+	}
+	defer targetConn.Close()
+
+	p.Director(req)
+
+	hijacker, ok := rw.(http.Hijacker)
+	if !ok {
+		rw.WriteHeader(500)
+		return
+	}
+	conn, _, errHijack := hijacker.Hijack()
+	if errHijack != nil {
+		rw.WriteHeader(500)
+		return
+	}
+	defer conn.Close()
+
+	req.Write(targetConn)
+	frpIo.Join(conn, targetConn)
+}
+
+func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
 	transport := p.Transport
 	if transport == nil {
 		transport = http.DefaultTransport
@@ -368,3 +414,16 @@ func (m *maxLatencyWriter) flushLoop() {
 }
 
 func (m *maxLatencyWriter) stop() { m.done <- true }
+
+func IsWebsocketRequest(req *http.Request) bool {
+	containsHeader := func(name, value string) bool {
+		items := strings.Split(req.Header.Get(name), ",")
+		for _, item := range items {
+			if value == strings.ToLower(strings.TrimSpace(item)) {
+				return true
+			}
+		}
+		return false
+	}
+	return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket")
+}