Przeglądaj źródła

client/proxy: simplify the code (#3465)

fatedier 1 rok temu
rodzic
commit
cceab7e1b1
5 zmienionych plików z 147 dodań i 224 usunięć
  1. 47 0
      client/proxy/general_tcp.go
  2. 51 203
      client/proxy/proxy.go
  3. 17 0
      client/proxy/sudp.go
  4. 17 1
      client/proxy/udp.go
  5. 15 20
      client/proxy/xtcp.go

+ 47 - 0
client/proxy/general_tcp.go

@@ -0,0 +1,47 @@
+// Copyright 2023 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proxy
+
+import (
+	"reflect"
+
+	"github.com/fatedier/frp/pkg/config"
+)
+
+func init() {
+	pxyConfs := []config.ProxyConf{
+		&config.TCPProxyConf{},
+		&config.HTTPProxyConf{},
+		&config.HTTPSProxyConf{},
+		&config.STCPProxyConf{},
+		&config.TCPMuxProxyConf{},
+	}
+	for _, cfg := range pxyConfs {
+		RegisterProxyFactory(reflect.TypeOf(cfg), NewGeneralTCPProxy)
+	}
+}
+
+// GeneralTCPProxy is a general implementation of Proxy interface for TCP protocol.
+// If the default GeneralTCPProxy cannot meet the requirements, you can customize
+// the implementation of the Proxy interface.
+type GeneralTCPProxy struct {
+	*BaseProxy
+}
+
+func NewGeneralTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
+	return &GeneralTCPProxy{
+		BaseProxy: baseProxy,
+	}
+}

+ 51 - 203
client/proxy/proxy.go

@@ -19,6 +19,7 @@ import (
 	"context"
 	"io"
 	"net"
+	"reflect"
 	"strconv"
 	"strings"
 	"sync"
@@ -37,6 +38,12 @@ import (
 	"github.com/fatedier/frp/pkg/util/xlog"
 )
 
+var proxyFactoryRegistry = map[reflect.Type]func(*BaseProxy, config.ProxyConf) Proxy{}
+
+func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy, config.ProxyConf) Proxy) {
+	proxyFactoryRegistry[proxyConfType] = factory
+}
+
 // Proxy defines how to handle work connections for different proxy type.
 type Proxy interface {
 	Run() error
@@ -60,233 +67,74 @@ func NewProxy(
 	}
 
 	baseProxy := BaseProxy{
-		clientCfg:      clientCfg,
-		limiter:        limiter,
-		msgTransporter: msgTransporter,
-		xl:             xlog.FromContextSafe(ctx),
-		ctx:            ctx,
+		baseProxyConfig: pxyConf.GetBaseConfig(),
+		clientCfg:       clientCfg,
+		limiter:         limiter,
+		msgTransporter:  msgTransporter,
+		xl:              xlog.FromContextSafe(ctx),
+		ctx:             ctx,
 	}
-	switch cfg := pxyConf.(type) {
-	case *config.TCPProxyConf:
-		pxy = &TCPProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-		}
-	case *config.TCPMuxProxyConf:
-		pxy = &TCPMuxProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-		}
-	case *config.UDPProxyConf:
-		pxy = &UDPProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-		}
-	case *config.HTTPProxyConf:
-		pxy = &HTTPProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-		}
-	case *config.HTTPSProxyConf:
-		pxy = &HTTPSProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-		}
-	case *config.STCPProxyConf:
-		pxy = &STCPProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-		}
-	case *config.XTCPProxyConf:
-		pxy = &XTCPProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-		}
-	case *config.SUDPProxyConf:
-		pxy = &SUDPProxy{
-			BaseProxy: &baseProxy,
-			cfg:       cfg,
-			closeCh:   make(chan struct{}),
-		}
+
+	factory := proxyFactoryRegistry[reflect.TypeOf(pxyConf)]
+	if factory == nil {
+		return nil
 	}
-	return
+	return factory(&baseProxy, pxyConf)
 }
 
 type BaseProxy struct {
-	closed         bool
-	clientCfg      config.ClientCommonConf
-	msgTransporter transport.MessageTransporter
-	limiter        *rate.Limiter
+	baseProxyConfig *config.BaseProxyConf
+	clientCfg       config.ClientCommonConf
+	msgTransporter  transport.MessageTransporter
+	limiter         *rate.Limiter
+	// proxyPlugin is used to handle connections instead of dialing to local service.
+	// It's only validate for TCP protocol now.
+	proxyPlugin plugin.Plugin
 
 	mu  sync.RWMutex
 	xl  *xlog.Logger
 	ctx context.Context
 }
 
-// TCP
-type TCPProxy struct {
-	*BaseProxy
-
-	cfg         *config.TCPProxyConf
-	proxyPlugin plugin.Plugin
-}
-
-func (pxy *TCPProxy) Run() (err error) {
-	if pxy.cfg.Plugin != "" {
-		pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
-		if err != nil {
-			return
-		}
-	}
-	return
-}
-
-func (pxy *TCPProxy) Close() {
-	if pxy.proxyPlugin != nil {
-		pxy.proxyPlugin.Close()
-	}
-}
-
-func (pxy *TCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
-	HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
-		conn, []byte(pxy.clientCfg.Token), m)
-}
-
-// TCP Multiplexer
-type TCPMuxProxy struct {
-	*BaseProxy
-
-	cfg         *config.TCPMuxProxyConf
-	proxyPlugin plugin.Plugin
-}
-
-func (pxy *TCPMuxProxy) Run() (err error) {
-	if pxy.cfg.Plugin != "" {
-		pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
+func (pxy *BaseProxy) Run() error {
+	if pxy.baseProxyConfig.Plugin != "" {
+		p, err := plugin.Create(pxy.baseProxyConfig.Plugin, pxy.baseProxyConfig.PluginParams)
 		if err != nil {
-			return
+			return err
 		}
+		pxy.proxyPlugin = p
 	}
-	return
+	return nil
 }
 
-func (pxy *TCPMuxProxy) Close() {
+func (pxy *BaseProxy) Close() {
 	if pxy.proxyPlugin != nil {
 		pxy.proxyPlugin.Close()
 	}
 }
 
-func (pxy *TCPMuxProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
-	HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
-		conn, []byte(pxy.clientCfg.Token), m)
-}
-
-// HTTP
-type HTTPProxy struct {
-	*BaseProxy
-
-	cfg         *config.HTTPProxyConf
-	proxyPlugin plugin.Plugin
-}
-
-func (pxy *HTTPProxy) Run() (err error) {
-	if pxy.cfg.Plugin != "" {
-		pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
-		if err != nil {
-			return
-		}
-	}
-	return
-}
-
-func (pxy *HTTPProxy) Close() {
-	if pxy.proxyPlugin != nil {
-		pxy.proxyPlugin.Close()
-	}
-}
-
-func (pxy *HTTPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
-	HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
-		conn, []byte(pxy.clientCfg.Token), m)
-}
-
-// HTTPS
-type HTTPSProxy struct {
-	*BaseProxy
-
-	cfg         *config.HTTPSProxyConf
-	proxyPlugin plugin.Plugin
-}
-
-func (pxy *HTTPSProxy) Run() (err error) {
-	if pxy.cfg.Plugin != "" {
-		pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
-		if err != nil {
-			return
-		}
-	}
-	return
-}
-
-func (pxy *HTTPSProxy) Close() {
-	if pxy.proxyPlugin != nil {
-		pxy.proxyPlugin.Close()
-	}
-}
-
-func (pxy *HTTPSProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
-	HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
-		conn, []byte(pxy.clientCfg.Token), m)
-}
-
-// STCP
-type STCPProxy struct {
-	*BaseProxy
-
-	cfg         *config.STCPProxyConf
-	proxyPlugin plugin.Plugin
-}
-
-func (pxy *STCPProxy) Run() (err error) {
-	if pxy.cfg.Plugin != "" {
-		pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
-		if err != nil {
-			return
-		}
-	}
-	return
-}
-
-func (pxy *STCPProxy) Close() {
-	if pxy.proxyPlugin != nil {
-		pxy.proxyPlugin.Close()
-	}
-}
-
-func (pxy *STCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
-	HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
-		conn, []byte(pxy.clientCfg.Token), m)
+func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
+	pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Token))
 }
 
 // Common handler for tcp work connections.
-func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf, proxyPlugin plugin.Plugin,
-	baseInfo *config.BaseProxyConf, limiter *rate.Limiter, workConn net.Conn, encKey []byte, m *msg.StartWorkConn,
-) {
-	xl := xlog.FromContextSafe(ctx)
+func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWorkConn, encKey []byte) {
+	xl := pxy.xl
+	baseConfig := pxy.baseProxyConfig
 	var (
 		remote io.ReadWriteCloser
 		err    error
 	)
 	remote = workConn
-	if limiter != nil {
-		remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, limiter), limit.NewWriter(workConn, limiter), func() error {
+	if pxy.limiter != nil {
+		remote = libio.WrapReadWriteCloser(limit.NewReader(workConn, pxy.limiter), limit.NewWriter(workConn, pxy.limiter), func() error {
 			return workConn.Close()
 		})
 	}
 
 	xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t",
-		baseInfo.UseEncryption, baseInfo.UseCompression)
-	if baseInfo.UseEncryption {
+		baseConfig.UseEncryption, baseConfig.UseCompression)
+	if baseConfig.UseEncryption {
 		remote, err = libio.WithEncryption(remote, encKey)
 		if err != nil {
 			workConn.Close()
@@ -294,13 +142,13 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
 			return
 		}
 	}
-	if baseInfo.UseCompression {
+	if baseConfig.UseCompression {
 		remote = libio.WithCompression(remote)
 	}
 
 	// check if we need to send proxy protocol info
 	var extraInfo []byte
-	if baseInfo.ProxyProtocolVersion != "" {
+	if baseConfig.ProxyProtocolVersion != "" {
 		if m.SrcAddr != "" && m.SrcPort != 0 {
 			if m.DstAddr == "" {
 				m.DstAddr = "127.0.0.1"
@@ -319,9 +167,9 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
 				h.TransportProtocol = pp.TCPv6
 			}
 
-			if baseInfo.ProxyProtocolVersion == "v1" {
+			if baseConfig.ProxyProtocolVersion == "v1" {
 				h.Version = 1
-			} else if baseInfo.ProxyProtocolVersion == "v2" {
+			} else if baseConfig.ProxyProtocolVersion == "v2" {
 				h.Version = 2
 			}
 
@@ -331,21 +179,21 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
 		}
 	}
 
-	if proxyPlugin != nil {
-		// if plugin is set, let plugin handle connections first
-		xl.Debug("handle by plugin: %s", proxyPlugin.Name())
-		proxyPlugin.Handle(remote, workConn, extraInfo)
+	if pxy.proxyPlugin != nil {
+		// if plugin is set, let plugin handle connection first
+		xl.Debug("handle by plugin: %s", pxy.proxyPlugin.Name())
+		pxy.proxyPlugin.Handle(remote, workConn, extraInfo)
 		xl.Debug("handle by plugin finished")
 		return
 	}
 
 	localConn, err := libdial.Dial(
-		net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)),
+		net.JoinHostPort(baseConfig.LocalIP, strconv.Itoa(baseConfig.LocalPort)),
 		libdial.WithTimeout(10*time.Second),
 	)
 	if err != nil {
 		workConn.Close()
-		xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)
+		xl.Error("connect to local service [%s:%d] error: %v", baseConfig.LocalIP, baseConfig.LocalPort, err)
 		return
 	}
 

+ 17 - 0
client/proxy/sudp.go

@@ -17,6 +17,7 @@ package proxy
 import (
 	"io"
 	"net"
+	"reflect"
 	"strconv"
 	"sync"
 	"time"
@@ -31,6 +32,10 @@ import (
 	utilnet "github.com/fatedier/frp/pkg/util/net"
 )
 
+func init() {
+	RegisterProxyFactory(reflect.TypeOf(&config.SUDPProxyConf{}), NewSUDPProxy)
+}
+
 type SUDPProxy struct {
 	*BaseProxy
 
@@ -41,6 +46,18 @@ type SUDPProxy struct {
 	closeCh chan struct{}
 }
 
+func NewSUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
+	unwrapped, ok := cfg.(*config.SUDPProxyConf)
+	if !ok {
+		return nil
+	}
+	return &SUDPProxy{
+		BaseProxy: baseProxy,
+		cfg:       unwrapped,
+		closeCh:   make(chan struct{}),
+	}
+}
+
 func (pxy *SUDPProxy) Run() (err error) {
 	pxy.localAddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.cfg.LocalIP, strconv.Itoa(pxy.cfg.LocalPort)))
 	if err != nil {

+ 17 - 1
client/proxy/udp.go

@@ -17,6 +17,7 @@ package proxy
 import (
 	"io"
 	"net"
+	"reflect"
 	"strconv"
 	"time"
 
@@ -30,7 +31,10 @@ import (
 	utilnet "github.com/fatedier/frp/pkg/util/net"
 )
 
-// UDP
+func init() {
+	RegisterProxyFactory(reflect.TypeOf(&config.UDPProxyConf{}), NewUDPProxy)
+}
+
 type UDPProxy struct {
 	*BaseProxy
 
@@ -42,6 +46,18 @@ type UDPProxy struct {
 	// include msg.UDPPacket and msg.Ping
 	sendCh   chan msg.Message
 	workConn net.Conn
+	closed   bool
+}
+
+func NewUDPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
+	unwrapped, ok := cfg.(*config.UDPProxyConf)
+	if !ok {
+		return nil
+	}
+	return &UDPProxy{
+		BaseProxy: baseProxy,
+		cfg:       unwrapped,
+	}
 }
 
 func (pxy *UDPProxy) Run() (err error) {

+ 15 - 20
client/proxy/xtcp.go

@@ -17,6 +17,7 @@ package proxy
 import (
 	"io"
 	"net"
+	"reflect"
 	"time"
 
 	fmux "github.com/hashicorp/yamux"
@@ -25,32 +26,28 @@ import (
 	"github.com/fatedier/frp/pkg/config"
 	"github.com/fatedier/frp/pkg/msg"
 	"github.com/fatedier/frp/pkg/nathole"
-	plugin "github.com/fatedier/frp/pkg/plugin/client"
 	"github.com/fatedier/frp/pkg/transport"
 	utilnet "github.com/fatedier/frp/pkg/util/net"
 )
 
-// XTCP
+func init() {
+	RegisterProxyFactory(reflect.TypeOf(&config.XTCPProxyConf{}), NewXTCPProxy)
+}
+
 type XTCPProxy struct {
 	*BaseProxy
 
-	cfg         *config.XTCPProxyConf
-	proxyPlugin plugin.Plugin
+	cfg *config.XTCPProxyConf
 }
 
-func (pxy *XTCPProxy) Run() (err error) {
-	if pxy.cfg.Plugin != "" {
-		pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams)
-		if err != nil {
-			return
-		}
+func NewXTCPProxy(baseProxy *BaseProxy, cfg config.ProxyConf) Proxy {
+	unwrapped, ok := cfg.(*config.XTCPProxyConf)
+	if !ok {
+		return nil
 	}
-	return
-}
-
-func (pxy *XTCPProxy) Close() {
-	if pxy.proxyPlugin != nil {
-		pxy.proxyPlugin.Close()
+	return &XTCPProxy{
+		BaseProxy: baseProxy,
+		cfg:       unwrapped,
 	}
 }
 
@@ -155,8 +152,7 @@ func (pxy *XTCPProxy) listenByKCP(listenConn *net.UDPConn, raddr *net.UDPAddr, s
 			xl.Error("accept connection error: %v", err)
 			return
 		}
-		go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
-			muxConn, []byte(pxy.cfg.Sk), startWorkConnMsg)
+		go pxy.HandleTCPWorkConnection(muxConn, startWorkConnMsg, []byte(pxy.cfg.Sk))
 	}
 }
 
@@ -194,7 +190,6 @@ func (pxy *XTCPProxy) listenByQUIC(listenConn *net.UDPConn, _ *net.UDPAddr, star
 			_ = c.CloseWithError(0, "")
 			return
 		}
-		go HandleTCPWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, pxy.cfg.GetBaseConfig(), pxy.limiter,
-			utilnet.QuicStreamToNetConn(stream, c), []byte(pxy.cfg.Sk), startWorkConnMsg)
+		go pxy.HandleTCPWorkConnection(utilnet.QuicStreamToNetConn(stream, c), startWorkConnMsg, []byte(pxy.cfg.Sk))
 	}
 }