Browse Source

support proxy protocol in unix_domain_socket

fatedier 5 years ago
parent
commit
6a1f15b25e

+ 36 - 29
client/proxy/proxy.go

@@ -503,10 +503,43 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin.
 		remote = frpIo.WithCompression(remote)
 	}
 
+	// check if we need to send proxy protocol info
+	var extraInfo []byte
+	if baseInfo.ProxyProtocolVersion != "" {
+		if m.SrcAddr != "" && m.SrcPort != 0 {
+			if m.DstAddr == "" {
+				m.DstAddr = "127.0.0.1"
+			}
+			h := &pp.Header{
+				Command:            pp.PROXY,
+				SourceAddress:      net.ParseIP(m.SrcAddr),
+				SourcePort:         m.SrcPort,
+				DestinationAddress: net.ParseIP(m.DstAddr),
+				DestinationPort:    m.DstPort,
+			}
+
+			if h.SourceAddress.To16() == nil {
+				h.TransportProtocol = pp.TCPv4
+			} else {
+				h.TransportProtocol = pp.TCPv6
+			}
+
+			if baseInfo.ProxyProtocolVersion == "v1" {
+				h.Version = 1
+			} else if baseInfo.ProxyProtocolVersion == "v2" {
+				h.Version = 2
+			}
+
+			buf := bytes.NewBuffer(nil)
+			h.WriteTo(buf)
+			extraInfo = buf.Bytes()
+		}
+	}
+
 	if proxyPlugin != nil {
 		// if plugin is set, let plugin handle connections first
 		workConn.Debug("handle by plugin: %s", proxyPlugin.Name())
-		proxyPlugin.Handle(remote, workConn)
+		proxyPlugin.Handle(remote, workConn, extraInfo)
 		workConn.Debug("handle by plugin finished")
 		return
 	} else {
@@ -520,34 +553,8 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin.
 		workConn.Debug("join connections, localConn(l[%s] r[%s]) workConn(l[%s] r[%s])", localConn.LocalAddr().String(),
 			localConn.RemoteAddr().String(), workConn.LocalAddr().String(), workConn.RemoteAddr().String())
 
-		// check if we need to send proxy protocol info
-		if baseInfo.ProxyProtocolVersion != "" {
-			if m.SrcAddr != "" && m.SrcPort != 0 {
-				if m.DstAddr == "" {
-					m.DstAddr = "127.0.0.1"
-				}
-				h := &pp.Header{
-					Command:            pp.PROXY,
-					SourceAddress:      net.ParseIP(m.SrcAddr),
-					SourcePort:         m.SrcPort,
-					DestinationAddress: net.ParseIP(m.DstAddr),
-					DestinationPort:    m.DstPort,
-				}
-
-				if h.SourceAddress.To16() == nil {
-					h.TransportProtocol = pp.TCPv4
-				} else {
-					h.TransportProtocol = pp.TCPv6
-				}
-
-				if baseInfo.ProxyProtocolVersion == "v1" {
-					h.Version = 1
-				} else if baseInfo.ProxyProtocolVersion == "v2" {
-					h.Version = 2
-				}
-
-				h.WriteTo(localConn)
-			}
+		if len(extraInfo) > 0 {
+			localConn.Write(extraInfo)
 		}
 
 		frpIo.Join(localConn, remote)

+ 1 - 1
models/plugin/http_proxy.go

@@ -64,7 +64,7 @@ func (hp *HttpProxy) Name() string {
 	return PluginHttpProxy
 }
 
-func (hp *HttpProxy) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) {
+func (hp *HttpProxy) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) {
 	wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
 
 	sc, rd := gnet.NewSharedConn(wrapConn)

+ 1 - 6
models/plugin/https2http.go

@@ -100,16 +100,11 @@ func (p *HTTPS2HTTPPlugin) genTLSConfig() (*tls.Config, error) {
 	return config, nil
 }
 
-func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) {
+func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) {
 	wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
 	p.l.PutConn(wrapConn)
 }
 
-func (p *HTTPS2HTTPPlugin) handleRequest(w http.ResponseWriter, r *http.Request) {
-	w.Write([]byte("hello"))
-	return
-}
-
 func (p *HTTPS2HTTPPlugin) Name() string {
 	return PluginHTTPS2HTTP
 }

+ 1 - 1
models/plugin/plugin.go

@@ -46,7 +46,7 @@ func Create(name string, params map[string]string) (p Plugin, err error) {
 
 type Plugin interface {
 	Name() string
-	Handle(conn io.ReadWriteCloser, realConn frpNet.Conn)
+	Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte)
 	Close() error
 }
 

+ 1 - 1
models/plugin/socks5.go

@@ -53,7 +53,7 @@ func NewSocks5Plugin(params map[string]string) (p Plugin, err error) {
 	return
 }
 
-func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) {
+func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) {
 	defer conn.Close()
 	wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
 	sp.Server.ServeConn(wrapConn)

+ 1 - 1
models/plugin/static_file.go

@@ -72,7 +72,7 @@ func NewStaticFilePlugin(params map[string]string) (Plugin, error) {
 	return sp, nil
 }
 
-func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) {
+func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) {
 	wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn)
 	sp.l.PutConn(wrapConn)
 }

+ 4 - 1
models/plugin/unix_domain_socket.go

@@ -53,11 +53,14 @@ func NewUnixDomainSocketPlugin(params map[string]string) (p Plugin, err error) {
 	return
 }
 
-func (uds *UnixDomainSocketPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn) {
+func (uds *UnixDomainSocketPlugin) Handle(conn io.ReadWriteCloser, realConn frpNet.Conn, extraBufToLocal []byte) {
 	localConn, err := net.DialUnix("unix", nil, uds.UnixAddr)
 	if err != nil {
 		return
 	}
+	if len(extraBufToLocal) > 0 {
+		localConn.Write(extraBufToLocal)
+	}
 
 	frpIo.Join(localConn, conn)
 }