Browse Source

support proxy protocol for type http

fatedier 6 years ago
parent
commit
b7a73d3469
5 changed files with 21 additions and 7 deletions
  1. 3 0
      client/proxy/proxy.go
  2. 9 2
      server/proxy/http.go
  3. 6 4
      utils/vhost/http.go
  4. 2 0
      utils/vhost/reverseproxy.go
  5. 1 1
      utils/vhost/vhost.go

+ 3 - 0
client/proxy/proxy.go

@@ -523,6 +523,9 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin.
 		// check if we need to send proxy protocol info
 		if baseInfo.ProxyProtocolVersion != "" {
 			if m.SrcAddr != "" && m.SrcPort != 0 {
+				if m.DstAddr == "" {
+					m.DstAddr = "127.0.0.1"
+				}
 				h := &pp.Header{
 					Command:            pp.PROXY,
 					SourceAddress:      net.ParseIP(m.SrcAddr),

+ 9 - 2
server/proxy/http.go

@@ -16,6 +16,7 @@ package proxy
 
 import (
 	"io"
+	"net"
 	"strings"
 
 	"github.com/fatedier/frp/g"
@@ -97,8 +98,14 @@ func (pxy *HttpProxy) GetConf() config.ProxyConf {
 	return pxy.cfg
 }
 
-func (pxy *HttpProxy) GetRealConn() (workConn frpNet.Conn, err error) {
-	tmpConn, errRet := pxy.GetWorkConnFromPool(nil, nil)
+func (pxy *HttpProxy) GetRealConn(remoteAddr string) (workConn frpNet.Conn, err error) {
+	rAddr, errRet := net.ResolveTCPAddr("tcp", remoteAddr)
+	if errRet != nil {
+		pxy.Warn("resolve TCP addr [%s] error: %v", remoteAddr, errRet)
+		// we do not return error here since remoteAddr is not necessary for proxies without proxy protocol enabled
+	}
+
+	tmpConn, errRet := pxy.GetWorkConnFromPool(rAddr, nil)
 	if errRet != nil {
 		err = errRet
 		return

+ 6 - 4
utils/vhost/newhttp.go → utils/vhost/http.go

@@ -89,13 +89,15 @@ func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy {
 			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 				url := ctx.Value("url").(string)
 				host := getHostFromAddr(ctx.Value("host").(string))
-				return rp.CreateConnection(host, url)
+				remote := ctx.Value("remote").(string)
+				return rp.CreateConnection(host, url, remote)
 			},
 		},
 		WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 			url := ctx.Value("url").(string)
 			host := getHostFromAddr(ctx.Value("host").(string))
-			return rp.CreateConnection(host, url)
+			remote := ctx.Value("remote").(string)
+			return rp.CreateConnection(host, url, remote)
 		},
 		BufferPool: newWrapPool(),
 		ErrorLog:   log.New(newWrapLogger(), "", 0),
@@ -138,12 +140,12 @@ func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers
 	return
 }
 
-func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) {
+func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) {
 	vr, ok := rp.getVhost(domain, location)
 	if ok {
 		fn := vr.payload.(*VhostRouteConfig).CreateConnFn
 		if fn != nil {
-			return fn()
+			return fn(remoteAddr)
 		}
 	}
 	return nil, fmt.Errorf("%v: %s %s", ErrNoDomain, domain, location)

+ 2 - 0
utils/vhost/reverseproxy.go

@@ -158,6 +158,7 @@ func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request)
 
 	req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path))
 	req = req.WithContext(context.WithValue(req.Context(), "host", req.Host))
+	req = req.WithContext(context.WithValue(req.Context(), "remote", req.RemoteAddr))
 
 	targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
 	if err != nil {
@@ -215,6 +216,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
 	// Modify for frp
 	outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path))
 	outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host))
+	outreq = outreq.WithContext(context.WithValue(outreq.Context(), "remote", req.RemoteAddr))
 
 	p.Director(outreq)
 	outreq.Close = false

+ 1 - 1
utils/vhost/vhost.go

@@ -51,7 +51,7 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut
 	return mux, nil
 }
 
-type CreateConnFunc func() (frpNet.Conn, error)
+type CreateConnFunc func(remoteAddr string) (frpNet.Conn, error)
 
 type VhostRouteConfig struct {
 	Domain      string