Browse Source

Merge pull request #564 from fatedier/dev

bump version to v0.14.1
fatedier 7 years ago
parent
commit
a384bf5580

+ 5 - 1
client/visitor.go

@@ -259,7 +259,11 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
 	sv.Trace("send all detect msg done")
 
 	// Listen for visitorConn's address and wait for client connection.
-	lConn, _ := net.ListenUDP("udp", laddr)
+	lConn, err := net.ListenUDP("udp", laddr)
+	if err != nil {
+		sv.Error("listen on visitorConn's local adress error: %v", err)
+		return
+	}
 	lConn.SetReadDeadline(time.Now().Add(5 * time.Second))
 	sidBuf := pool.GetBuf(1024)
 	n, _, err = lConn.ReadFromUDP(sidBuf)

+ 2 - 2
models/config/proxy.go

@@ -635,7 +635,7 @@ func (cfg *StcpProxyConf) LoadFromFile(name string, section ini.Section) (err er
 	if tmpStr == "server" || tmpStr == "visitor" {
 		cfg.Role = tmpStr
 	} else {
-		cfg.Role = "server"
+		return fmt.Errorf("Parse conf error: incorrect role [%s]", tmpStr)
 	}
 
 	cfg.Sk = section["sk"]
@@ -724,7 +724,7 @@ func (cfg *XtcpProxyConf) LoadFromFile(name string, section ini.Section) (err er
 	if tmpStr == "server" || tmpStr == "visitor" {
 		cfg.Role = tmpStr
 	} else {
-		cfg.Role = "server"
+		return fmt.Errorf("Parse conf error: incorrect role [%s]", tmpStr)
 	}
 
 	cfg.Sk = section["sk"]

+ 1 - 1
models/msg/msg.go

@@ -181,5 +181,5 @@ type NatHoleResp struct {
 }
 
 type NatHoleSid struct {
-	Sid string `json"sid"`
+	Sid string `json:"sid"`
 }

+ 1 - 1
models/plugin/http_proxy.go

@@ -111,7 +111,7 @@ func (hp *HttpProxy) Handle(conn io.ReadWriteCloser) {
 	if realConn, ok := conn.(frpNet.Conn); ok {
 		wrapConn = realConn
 	} else {
-		wrapConn = frpNet.WrapReadWriteCloserToConn(conn)
+		wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn)
 	}
 
 	sc, rd := frpNet.NewShareConn(wrapConn)

+ 1 - 1
models/plugin/socks5.go

@@ -50,7 +50,7 @@ func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser) {
 	if realConn, ok := conn.(frpNet.Conn); ok {
 		wrapConn = realConn
 	} else {
-		wrapConn = frpNet.WrapReadWriteCloserToConn(conn)
+		wrapConn = frpNet.WrapReadWriteCloserToConn(conn, realConn)
 	}
 
 	sp.Server.ServeConn(wrapConn)

+ 1 - 1
server/manager.go

@@ -146,7 +146,7 @@ func (vm *VisitorManager) NewConn(name string, conn frpNet.Conn, timestamp int64
 		if useCompression {
 			rwc = frpIo.WithCompression(rwc)
 		}
-		err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc))
+		err = l.PutConn(frpNet.WrapReadWriteCloserToConn(rwc, conn))
 	} else {
 		err = fmt.Errorf("custom listener for [%s] doesn't exist", name)
 		return

+ 44 - 12
server/proxy.go

@@ -189,13 +189,16 @@ func (pxy *TcpProxy) Close() {
 type HttpProxy struct {
 	BaseProxy
 	cfg *config.HttpProxyConf
+
+	closeFuncs []func()
 }
 
 func (pxy *HttpProxy) Run() (err error) {
-	routeConfig := &vhost.VhostRouteConfig{
-		RewriteHost: pxy.cfg.HostHeaderRewrite,
-		Username:    pxy.cfg.HttpUser,
-		Password:    pxy.cfg.HttpPwd,
+	routeConfig := vhost.VhostRouteConfig{
+		RewriteHost:  pxy.cfg.HostHeaderRewrite,
+		Username:     pxy.cfg.HttpUser,
+		Password:     pxy.cfg.HttpPwd,
+		CreateConnFn: pxy.GetRealConn,
 	}
 
 	locations := pxy.cfg.Locations
@@ -206,13 +209,16 @@ func (pxy *HttpProxy) Run() (err error) {
 		routeConfig.Domain = domain
 		for _, location := range locations {
 			routeConfig.Location = location
-			l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig)
+			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
 			if err != nil {
 				return err
 			}
-			l.AddLogPrefix(pxy.name)
+			tmpDomain := routeConfig.Domain
+			tmpLocation := routeConfig.Location
+			pxy.closeFuncs = append(pxy.closeFuncs, func() {
+				pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation)
+			})
 			pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
-			pxy.listeners = append(pxy.listeners, l)
 		}
 	}
 
@@ -220,17 +226,18 @@ func (pxy *HttpProxy) Run() (err error) {
 		routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
 		for _, location := range locations {
 			routeConfig.Location = location
-			l, err := pxy.ctl.svr.VhostHttpMuxer.Listen(routeConfig)
+			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
 			if err != nil {
 				return err
 			}
-			l.AddLogPrefix(pxy.name)
+			tmpDomain := routeConfig.Domain
+			tmpLocation := routeConfig.Location
+			pxy.closeFuncs = append(pxy.closeFuncs, func() {
+				pxy.ctl.svr.httpReverseProxy.UnRegister(tmpDomain, tmpLocation)
+			})
 			pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
-			pxy.listeners = append(pxy.listeners, l)
 		}
 	}
-
-	pxy.startListenHandler(pxy, HandleUserTcpConnection)
 	return
 }
 
@@ -238,8 +245,33 @@ func (pxy *HttpProxy) GetConf() config.ProxyConf {
 	return pxy.cfg
 }
 
+func (pxy *HttpProxy) GetRealConn() (workConn frpNet.Conn, err error) {
+	tmpConn, errRet := pxy.GetWorkConnFromPool()
+	if errRet != nil {
+		err = errRet
+		return
+	}
+
+	var rwc io.ReadWriteCloser = tmpConn
+	if pxy.cfg.UseEncryption {
+		rwc, err = frpIo.WithEncryption(rwc, []byte(config.ServerCommonCfg.PrivilegeToken))
+		if err != nil {
+			pxy.Error("create encryption stream error: %v", err)
+			return
+		}
+	}
+	if pxy.cfg.UseCompression {
+		rwc = frpIo.WithCompression(rwc)
+	}
+	workConn = frpNet.WrapReadWriteCloserToConn(rwc, tmpConn)
+	return
+}
+
 func (pxy *HttpProxy) Close() {
 	pxy.BaseProxy.Close()
+	for _, closeFn := range pxy.closeFuncs {
+		closeFn()
+	}
 }
 
 type HttpsProxy struct {

+ 16 - 11
server/service.go

@@ -16,6 +16,8 @@ package server
 
 import (
 	"fmt"
+	"net"
+	"net/http"
 	"time"
 
 	"github.com/fatedier/frp/assets"
@@ -44,12 +46,11 @@ type Service struct {
 	// Accept connections using kcp.
 	kcpListener frpNet.Listener
 
-	// For http proxies, route requests to different clients by hostname and other infomation.
-	VhostHttpMuxer *vhost.HttpMuxer
-
 	// For https proxies, route requests to different clients by hostname and other infomation.
 	VhostHttpsMuxer *vhost.HttpsMuxer
 
+	httpReverseProxy *vhost.HttpReverseProxy
+
 	// Manage all controllers.
 	ctlManager *ControlManager
 
@@ -93,22 +94,26 @@ func NewService() (svr *Service, err error) {
 			err = fmt.Errorf("Listen on kcp address udp [%s:%d] error: %v", cfg.BindAddr, cfg.KcpBindPort, err)
 			return
 		}
-		log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.BindPort)
+		log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KcpBindPort)
 	}
 
 	// Create http vhost muxer.
 	if cfg.VhostHttpPort > 0 {
-		var l frpNet.Listener
-		l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpPort)
-		if err != nil {
-			err = fmt.Errorf("Create vhost http listener error, %v", err)
-			return
+		rp := vhost.NewHttpReverseProxy()
+		svr.httpReverseProxy = rp
+
+		address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
+		server := &http.Server{
+			Addr:    address,
+			Handler: rp,
 		}
-		svr.VhostHttpMuxer, err = vhost.NewHttpMuxer(l, 30*time.Second)
+		var l net.Listener
+		l, err = net.Listen("tcp", address)
 		if err != nil {
-			err = fmt.Errorf("Create vhost httpMuxer error, %v", err)
+			err = fmt.Errorf("Create vhost http listener error, %v", err)
 			return
 		}
+		go server.Serve(l)
 		log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
 	}
 

+ 19 - 1
utils/net/conn.go

@@ -49,32 +49,50 @@ func WrapConn(c net.Conn) Conn {
 type WrapReadWriteCloserConn struct {
 	io.ReadWriteCloser
 	log.Logger
+
+	underConn net.Conn
 }
 
-func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser) Conn {
+func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser, underConn net.Conn) Conn {
 	return &WrapReadWriteCloserConn{
 		ReadWriteCloser: rwc,
 		Logger:          log.NewPrefixLogger(""),
+		underConn:       underConn,
 	}
 }
 
 func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr {
+	if conn.underConn != nil {
+		return conn.underConn.LocalAddr()
+	}
 	return (*net.TCPAddr)(nil)
 }
 
 func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr {
+	if conn.underConn != nil {
+		return conn.underConn.RemoteAddr()
+	}
 	return (*net.TCPAddr)(nil)
 }
 
 func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error {
+	if conn.underConn != nil {
+		return conn.underConn.SetDeadline(t)
+	}
 	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 }
 
 func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error {
+	if conn.underConn != nil {
+		return conn.underConn.SetReadDeadline(t)
+	}
 	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 }
 
 func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error {
+	if conn.underConn != nil {
+		return conn.underConn.SetWriteDeadline(t)
+	}
 	return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
 }
 

+ 1 - 1
utils/version/version.go

@@ -19,7 +19,7 @@ import (
 	"strings"
 )
 
-var version string = "0.14.0"
+var version string = "0.14.1"
 
 func Full() string {
 	return version

+ 186 - 0
utils/vhost/newhttp.go

@@ -0,0 +1,186 @@
+// Copyright 2017 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vhost
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"log"
+	"net"
+	"net/http"
+	"strings"
+	"sync"
+	"time"
+
+	frpLog "github.com/fatedier/frp/utils/log"
+	"github.com/fatedier/frp/utils/pool"
+)
+
+var (
+	responseHeaderTimeout = time.Duration(30) * time.Second
+
+	ErrRouterConfigConflict = errors.New("router config conflict")
+	ErrNoDomain             = errors.New("no such domain")
+)
+
+func getHostFromAddr(addr string) (host string) {
+	strs := strings.Split(addr, ":")
+	if len(strs) > 1 {
+		host = strs[0]
+	} else {
+		host = addr
+	}
+	return
+}
+
+type HttpReverseProxy struct {
+	proxy *ReverseProxy
+	tr    *http.Transport
+
+	vhostRouter *VhostRouters
+
+	cfgMu sync.RWMutex
+}
+
+func NewHttpReverseProxy() *HttpReverseProxy {
+	rp := &HttpReverseProxy{
+		vhostRouter: NewVhostRouters(),
+	}
+	proxy := &ReverseProxy{
+		Director: func(req *http.Request) {
+			req.URL.Scheme = "http"
+			url := req.Context().Value("url").(string)
+			host := getHostFromAddr(req.Context().Value("host").(string))
+			host = rp.GetRealHost(host, url)
+			if host != "" {
+				req.Host = host
+			}
+			req.URL.Host = req.Host
+		},
+		Transport: &http.Transport{
+			ResponseHeaderTimeout: responseHeaderTimeout,
+			DisableKeepAlives:     true,
+			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+				url := ctx.Value("url").(string)
+				host := getHostFromAddr(ctx.Value("host").(string))
+				return rp.CreateConnection(host, url)
+			},
+		},
+		BufferPool: newWrapPool(),
+		ErrorLog:   log.New(newWrapLogger(), "", 0),
+	}
+	rp.proxy = proxy
+	return rp
+}
+
+func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error {
+	rp.cfgMu.Lock()
+	defer rp.cfgMu.Unlock()
+	_, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location)
+	if ok {
+		return ErrRouterConfigConflict
+	} else {
+		rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
+	}
+	return nil
+}
+
+func (rp *HttpReverseProxy) UnRegister(domain string, location string) {
+	rp.cfgMu.Lock()
+	defer rp.cfgMu.Unlock()
+	rp.vhostRouter.Del(domain, location)
+}
+
+func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host string) {
+	vr, ok := rp.getVhost(domain, location)
+	if ok {
+		host = vr.payload.(*VhostRouteConfig).RewriteHost
+	}
+	return
+}
+
+func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) {
+	vr, ok := rp.getVhost(domain, location)
+	if ok {
+		fn := vr.payload.(*VhostRouteConfig).CreateConnFn
+		if fn != nil {
+			return fn()
+		}
+	}
+	return nil, ErrNoDomain
+}
+
+func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool {
+	vr, ok := rp.getVhost(domain, location)
+	if ok {
+		checkUser := vr.payload.(*VhostRouteConfig).Username
+		checkPasswd := vr.payload.(*VhostRouteConfig).Password
+		if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) {
+			return false
+		}
+	}
+	return true
+}
+
+func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) {
+	rp.cfgMu.RLock()
+	defer rp.cfgMu.RUnlock()
+
+	// first we check the full hostname
+	// if not exist, then check the wildcard_domain such as *.example.com
+	vr, ok = rp.vhostRouter.Get(domain, location)
+	if ok {
+		return
+	}
+
+	domainSplit := strings.Split(domain, ".")
+	if len(domainSplit) < 3 {
+		return vr, false
+	}
+	domainSplit[0] = "*"
+	domain = strings.Join(domainSplit, ".")
+	vr, ok = rp.vhostRouter.Get(domain, location)
+	return
+}
+
+func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	domain := getHostFromAddr(req.Host)
+	location := req.URL.Path
+	user, passwd, _ := req.BasicAuth()
+	if !rp.CheckAuth(domain, location, user, passwd) {
+		rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
+		http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
+		return
+	}
+	rp.proxy.ServeHTTP(rw, req)
+}
+
+type wrapPool struct{}
+
+func newWrapPool() *wrapPool { return &wrapPool{} }
+
+func (p *wrapPool) Get() []byte { return pool.GetBuf(32 * 1024) }
+
+func (p *wrapPool) Put(buf []byte) { pool.PutBuf(buf) }
+
+type wrapLogger struct{}
+
+func newWrapLogger() *wrapLogger { return &wrapLogger{} }
+
+func (l *wrapLogger) Write(p []byte) (n int, err error) {
+	frpLog.Warn("%s", string(bytes.TrimRight(p, "\n")))
+	return len(p), nil
+}

+ 370 - 0
utils/vhost/reverseproxy.go

@@ -0,0 +1,370 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package vhost
+
+import (
+	"context"
+	"io"
+	"log"
+	"net"
+	"net/http"
+	"net/url"
+	"strings"
+	"sync"
+	"time"
+)
+
+// onExitFlushLoop is a callback set by tests to detect the state of the
+// flushLoop() goroutine.
+var onExitFlushLoop func()
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+type ReverseProxy struct {
+	// Director must be a function which modifies
+	// the request into a new request to be sent
+	// using Transport. Its response is then copied
+	// back to the original client unmodified.
+	// Director must not access the provided Request
+	// after returning.
+	Director func(*http.Request)
+
+	// The transport used to perform proxy requests.
+	// If nil, http.DefaultTransport is used.
+	Transport http.RoundTripper
+
+	// FlushInterval specifies the flush interval
+	// to flush to the client while copying the
+	// response body.
+	// If zero, no periodic flushing is done.
+	FlushInterval time.Duration
+
+	// ErrorLog specifies an optional logger for errors
+	// that occur when attempting to proxy the request.
+	// If nil, logging goes to os.Stderr via the log package's
+	// standard logger.
+	ErrorLog *log.Logger
+
+	// BufferPool optionally specifies a buffer pool to
+	// get byte slices for use by io.CopyBuffer when
+	// copying HTTP response bodies.
+	BufferPool BufferPool
+
+	// ModifyResponse is an optional function that
+	// modifies the Response from the backend.
+	// If it returns an error, the proxy returns a StatusBadGateway error.
+	ModifyResponse func(*http.Response) error
+}
+
+// A BufferPool is an interface for getting and returning temporary
+// byte slices for use by io.CopyBuffer.
+type BufferPool interface {
+	Get() []byte
+	Put([]byte)
+}
+
+func singleJoiningSlash(a, b string) string {
+	aslash := strings.HasSuffix(a, "/")
+	bslash := strings.HasPrefix(b, "/")
+	switch {
+	case aslash && bslash:
+		return a + b[1:]
+	case !aslash && !bslash:
+		return a + "/" + b
+	}
+	return a + b
+}
+
+// NewSingleHostReverseProxy returns a new ReverseProxy that routes
+// URLs to the scheme, host, and base path provided in target. If the
+// target's path is "/base" and the incoming request was for "/dir",
+// the target request will be for /base/dir.
+// NewSingleHostReverseProxy does not rewrite the Host header.
+// To rewrite Host headers, use ReverseProxy directly with a custom
+// Director policy.
+func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
+	targetQuery := target.RawQuery
+	director := func(req *http.Request) {
+		req.URL.Scheme = target.Scheme
+		req.URL.Host = target.Host
+		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
+		if targetQuery == "" || req.URL.RawQuery == "" {
+			req.URL.RawQuery = targetQuery + req.URL.RawQuery
+		} else {
+			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
+		}
+		if _, ok := req.Header["User-Agent"]; !ok {
+			// explicitly disable User-Agent so it's not set to default value
+			req.Header.Set("User-Agent", "")
+		}
+	}
+	return &ReverseProxy{Director: director}
+}
+
+func copyHeader(dst, src http.Header) {
+	for k, vv := range src {
+		for _, v := range vv {
+			dst.Add(k, v)
+		}
+	}
+}
+
+func cloneHeader(h http.Header) http.Header {
+	h2 := make(http.Header, len(h))
+	for k, vv := range h {
+		vv2 := make([]string, len(vv))
+		copy(vv2, vv)
+		h2[k] = vv2
+	}
+	return h2
+}
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+var hopHeaders = []string{
+	"Connection",
+	"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
+	"Keep-Alive",
+	"Proxy-Authenticate",
+	"Proxy-Authorization",
+	"Te",      // canonicalized version of "TE"
+	"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
+	"Transfer-Encoding",
+	"Upgrade",
+}
+
+func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	transport := p.Transport
+	if transport == nil {
+		transport = http.DefaultTransport
+	}
+
+	ctx := req.Context()
+	if cn, ok := rw.(http.CloseNotifier); ok {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithCancel(ctx)
+		defer cancel()
+		notifyChan := cn.CloseNotify()
+		go func() {
+			select {
+			case <-notifyChan:
+				cancel()
+			case <-ctx.Done():
+			}
+		}()
+	}
+
+	outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay
+	if req.ContentLength == 0 {
+		outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
+	}
+
+	outreq.Header = cloneHeader(req.Header)
+
+	// Modify for frp
+	outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path))
+	outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host))
+
+	p.Director(outreq)
+	outreq.Close = false
+
+	// Remove hop-by-hop headers listed in the "Connection" header.
+	// See RFC 2616, section 14.10.
+	if c := outreq.Header.Get("Connection"); c != "" {
+		for _, f := range strings.Split(c, ",") {
+			if f = strings.TrimSpace(f); f != "" {
+				outreq.Header.Del(f)
+			}
+		}
+	}
+
+	// Remove hop-by-hop headers to the backend. Especially
+	// important is "Connection" because we want a persistent
+	// connection, regardless of what the client sent to us.
+	for _, h := range hopHeaders {
+		if outreq.Header.Get(h) != "" {
+			outreq.Header.Del(h)
+		}
+	}
+
+	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+		// If we aren't the first proxy retain prior
+		// X-Forwarded-For information as a comma+space
+		// separated list and fold multiple headers into one.
+		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
+			clientIP = strings.Join(prior, ", ") + ", " + clientIP
+		}
+		outreq.Header.Set("X-Forwarded-For", clientIP)
+	}
+
+	res, err := transport.RoundTrip(outreq)
+	if err != nil {
+		p.logf("http: proxy error: %v", err)
+		rw.WriteHeader(http.StatusNotFound)
+		rw.Write([]byte(NotFound))
+		return
+	}
+
+	// Remove hop-by-hop headers listed in the
+	// "Connection" header of the response.
+	if c := res.Header.Get("Connection"); c != "" {
+		for _, f := range strings.Split(c, ",") {
+			if f = strings.TrimSpace(f); f != "" {
+				res.Header.Del(f)
+			}
+		}
+	}
+
+	for _, h := range hopHeaders {
+		res.Header.Del(h)
+	}
+
+	if p.ModifyResponse != nil {
+		if err := p.ModifyResponse(res); err != nil {
+			p.logf("http: proxy error: %v", err)
+			rw.WriteHeader(http.StatusBadGateway)
+			return
+		}
+	}
+
+	copyHeader(rw.Header(), res.Header)
+
+	// The "Trailer" header isn't included in the Transport's response,
+	// at least for *http.Transport. Build it up from Trailer.
+	announcedTrailers := len(res.Trailer)
+	if announcedTrailers > 0 {
+		trailerKeys := make([]string, 0, len(res.Trailer))
+		for k := range res.Trailer {
+			trailerKeys = append(trailerKeys, k)
+		}
+		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
+	}
+
+	rw.WriteHeader(res.StatusCode)
+	if len(res.Trailer) > 0 {
+		// Force chunking if we saw a response trailer.
+		// This prevents net/http from calculating the length for short
+		// bodies and adding a Content-Length.
+		if fl, ok := rw.(http.Flusher); ok {
+			fl.Flush()
+		}
+	}
+	p.copyResponse(rw, res.Body)
+	res.Body.Close() // close now, instead of defer, to populate res.Trailer
+
+	if len(res.Trailer) == announcedTrailers {
+		copyHeader(rw.Header(), res.Trailer)
+		return
+	}
+
+	for k, vv := range res.Trailer {
+		k = http.TrailerPrefix + k
+		for _, v := range vv {
+			rw.Header().Add(k, v)
+		}
+	}
+}
+
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+	if p.FlushInterval != 0 {
+		if wf, ok := dst.(writeFlusher); ok {
+			mlw := &maxLatencyWriter{
+				dst:     wf,
+				latency: p.FlushInterval,
+				done:    make(chan bool),
+			}
+			go mlw.flushLoop()
+			defer mlw.stop()
+			dst = mlw
+		}
+	}
+
+	var buf []byte
+	if p.BufferPool != nil {
+		buf = p.BufferPool.Get()
+	}
+	p.copyBuffer(dst, src, buf)
+	if p.BufferPool != nil {
+		p.BufferPool.Put(buf)
+	}
+}
+
+func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+	if len(buf) == 0 {
+		buf = make([]byte, 32*1024)
+	}
+	var written int64
+	for {
+		nr, rerr := src.Read(buf)
+		if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
+			p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
+		}
+		if nr > 0 {
+			nw, werr := dst.Write(buf[:nr])
+			if nw > 0 {
+				written += int64(nw)
+			}
+			if werr != nil {
+				return written, werr
+			}
+			if nr != nw {
+				return written, io.ErrShortWrite
+			}
+		}
+		if rerr != nil {
+			return written, rerr
+		}
+	}
+}
+
+func (p *ReverseProxy) logf(format string, args ...interface{}) {
+	if p.ErrorLog != nil {
+		p.ErrorLog.Printf(format, args...)
+	} else {
+		log.Printf(format, args...)
+	}
+}
+
+type writeFlusher interface {
+	io.Writer
+	http.Flusher
+}
+
+type maxLatencyWriter struct {
+	dst     writeFlusher
+	latency time.Duration
+
+	mu   sync.Mutex // protects Write + Flush
+	done chan bool
+}
+
+func (m *maxLatencyWriter) Write(p []byte) (int, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return m.dst.Write(p)
+}
+
+func (m *maxLatencyWriter) flushLoop() {
+	t := time.NewTicker(m.latency)
+	defer t.Stop()
+	for {
+		select {
+		case <-m.done:
+			if onExitFlushLoop != nil {
+				onExitFlushLoop()
+			}
+			return
+		case <-t.C:
+			m.mu.Lock()
+			m.dst.Flush()
+			m.mu.Unlock()
+		}
+	}
+}
+
+func (m *maxLatencyWriter) stop() { m.done <- true }

+ 4 - 3
utils/vhost/router.go

@@ -14,7 +14,8 @@ type VhostRouters struct {
 type VhostRouter struct {
 	domain   string
 	location string
-	listener *Listener
+
+	payload interface{}
 }
 
 func NewVhostRouters() *VhostRouters {
@@ -23,7 +24,7 @@ func NewVhostRouters() *VhostRouters {
 	}
 }
 
-func (r *VhostRouters) Add(domain, location string, l *Listener) {
+func (r *VhostRouters) Add(domain, location string, payload interface{}) {
 	r.mutex.Lock()
 	defer r.mutex.Unlock()
 
@@ -35,7 +36,7 @@ func (r *VhostRouters) Add(domain, location string, l *Listener) {
 	vr := &VhostRouter{
 		domain:   domain,
 		location: location,
-		listener: l,
+		payload:  payload,
 	}
 	vrs = append(vrs, vr)
 

+ 6 - 2
utils/vhost/vhost.go

@@ -50,12 +50,16 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut
 	return mux, nil
 }
 
+type CreateConnFunc func() (frpNet.Conn, error)
+
 type VhostRouteConfig struct {
 	Domain      string
 	Location    string
 	RewriteHost string
 	Username    string
 	Password    string
+
+	CreateConnFn CreateConnFunc
 }
 
 // listen for a new domain name, if rewriteHost is not empty  and rewriteFunc is not nil
@@ -91,7 +95,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
 	// if not exist, then check the wildcard_domain such as *.example.com
 	vr, found := v.registryRouter.Get(name, path)
 	if found {
-		return vr.listener, true
+		return vr.payload.(*Listener), true
 	}
 
 	domainSplit := strings.Split(name, ".")
@@ -106,7 +110,7 @@ func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
 		return
 	}
 
-	return vr.listener, true
+	return vr.payload.(*Listener), true
 }
 
 func (v *VhostMuxer) run() {