|
@@ -16,6 +16,8 @@ import (
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
+
|
|
|
+ frpIo "github.com/fatedier/frp/utils/io"
|
|
|
)
|
|
|
|
|
|
// onExitFlushLoop is a callback set by tests to detect the state of the
|
|
@@ -59,6 +61,8 @@ type ReverseProxy struct {
|
|
|
// modifies the Response from the backend.
|
|
|
// If it returns an error, the proxy returns a StatusBadGateway error.
|
|
|
ModifyResponse func(*http.Response) error
|
|
|
+
|
|
|
+ WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
|
|
}
|
|
|
|
|
|
// A BufferPool is an interface for getting and returning temporary
|
|
@@ -139,6 +143,48 @@ var hopHeaders = []string{
|
|
|
}
|
|
|
|
|
|
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
|
+ if IsWebsocketRequest(req) {
|
|
|
+ p.serveWebSocket(rw, req)
|
|
|
+ } else {
|
|
|
+ p.serveHTTP(rw, req)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
|
|
|
+ if p.WebSocketDialContext == nil {
|
|
|
+ rw.WriteHeader(500)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path))
|
|
|
+ req = req.WithContext(context.WithValue(req.Context(), "host", req.Host))
|
|
|
+
|
|
|
+ targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
|
|
|
+ if err != nil {
|
|
|
+ rw.WriteHeader(501)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer targetConn.Close()
|
|
|
+
|
|
|
+ p.Director(req)
|
|
|
+
|
|
|
+ hijacker, ok := rw.(http.Hijacker)
|
|
|
+ if !ok {
|
|
|
+ rw.WriteHeader(500)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ conn, _, errHijack := hijacker.Hijack()
|
|
|
+ if errHijack != nil {
|
|
|
+ rw.WriteHeader(500)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer conn.Close()
|
|
|
+
|
|
|
+ req.Write(targetConn)
|
|
|
+ frpIo.Join(conn, targetConn)
|
|
|
+}
|
|
|
+
|
|
|
+func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
|
transport := p.Transport
|
|
|
if transport == nil {
|
|
|
transport = http.DefaultTransport
|
|
@@ -368,3 +414,16 @@ func (m *maxLatencyWriter) flushLoop() {
|
|
|
}
|
|
|
|
|
|
func (m *maxLatencyWriter) stop() { m.done <- true }
|
|
|
+
|
|
|
+func IsWebsocketRequest(req *http.Request) bool {
|
|
|
+ containsHeader := func(name, value string) bool {
|
|
|
+ items := strings.Split(req.Header.Get(name), ",")
|
|
|
+ for _, item := range items {
|
|
|
+ if value == strings.ToLower(strings.TrimSpace(item)) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket")
|
|
|
+}
|