Browse Source

refactor: refine pkg net utils (#2720)

* refactor: refine pkg net utils

* fix: x

Co-authored-by: blizard863 <760076784@qq.com>
Blizard 3 years ago
parent
commit
ea568e8a4f
6 changed files with 142 additions and 60 deletions
  1. 5 2
      client/control.go
  2. 1 1
      client/proxy/proxy.go
  3. 6 2
      client/service.go
  4. 41 29
      pkg/util/net/conn.go
  5. 89 0
      pkg/util/net/dial.go
  6. 0 26
      pkg/util/net/websocket.go

+ 5 - 2
client/control.go

@@ -234,8 +234,11 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
 			}
 		}
 
-		address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))
-		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte)
+		conn, err = frpNet.DialWithOptions(net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort)),
+			frpNet.WithProxyURL(ctl.clientCfg.HTTPProxy),
+			frpNet.WithProtocol(ctl.clientCfg.Protocol),
+			frpNet.WithTLSConfig(tlsConfig),
+			frpNet.WithDisableCustomTLSHeadByte(ctl.clientCfg.DisableCustomTLSFirstByte))
 
 		if err != nil {
 			xl.Warn("start new connection to server error: %v", err)

+ 1 - 1
client/proxy/proxy.go

@@ -790,7 +790,7 @@ func HandleTCPWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf
 		return
 	}
 
-	localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIP, localInfo.LocalPort))
+	localConn, err := frpNet.DialWithOptions(net.JoinHostPort(localInfo.LocalIP, strconv.Itoa(localInfo.LocalPort)))
 	if err != nil {
 		workConn.Close()
 		xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIP, localInfo.LocalPort, err)

+ 6 - 2
client/service.go

@@ -228,8 +228,12 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
 		}
 	}
 
-	address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))
-	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte)
+	conn, err = frpNet.DialWithOptions(net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)),
+		frpNet.WithProxyURL(svr.cfg.HTTPProxy),
+		frpNet.WithProtocol(svr.cfg.Protocol),
+		frpNet.WithTLSConfig(tlsConfig),
+		frpNet.WithDisableCustomTLSHeadByte(svr.cfg.DisableCustomTLSFirstByte))
+
 	if err != nil {
 		return
 	}

+ 41 - 29
pkg/util/net/conn.go

@@ -16,15 +16,16 @@ package net
 
 import (
 	"context"
-	"crypto/tls"
 	"errors"
 	"fmt"
 	"io"
 	"net"
+	"net/url"
 	"sync/atomic"
 	"time"
 
 	"github.com/fatedier/frp/pkg/util/xlog"
+	"golang.org/x/net/websocket"
 
 	gnet "github.com/fatedier/golib/net"
 	kcp "github.com/fatedier/kcp-go"
@@ -194,50 +195,61 @@ func ConnectServer(protocol string, addr string) (c net.Conn, err error) {
 	case "tcp":
 		return net.Dial("tcp", addr)
 	case "kcp":
-		kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
-		if errRet != nil {
-			err = errRet
-			return
-		}
-		kcpConn.SetStreamMode(true)
-		kcpConn.SetWriteDelay(true)
-		kcpConn.SetNoDelay(1, 20, 2, 1)
-		kcpConn.SetWindowSize(128, 512)
-		kcpConn.SetMtu(1350)
-		kcpConn.SetACKNoDelay(false)
-		kcpConn.SetReadBuffer(4194304)
-		kcpConn.SetWriteBuffer(4194304)
-		c = kcpConn
-		return
+		return DialKCPServer(addr)
+	case "websocket":
+		return DialWebsocketServer(addr)
 	default:
 		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
 	}
 }
 
+func DialKCPServer(addr string) (c net.Conn, err error) {
+	kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3)
+	if errRet != nil {
+		err = errRet
+		return
+	}
+	kcpConn.SetStreamMode(true)
+	kcpConn.SetWriteDelay(true)
+	kcpConn.SetNoDelay(1, 20, 2, 1)
+	kcpConn.SetWindowSize(128, 512)
+	kcpConn.SetMtu(1350)
+	kcpConn.SetACKNoDelay(false)
+	kcpConn.SetReadBuffer(4194304)
+	kcpConn.SetWriteBuffer(4194304)
+	c = kcpConn
+	return
+}
+
 func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) {
 	switch protocol {
 	case "tcp":
 		return gnet.DialTcpByProxy(proxyURL, addr)
-	case "kcp":
-		// http proxy is not supported for kcp
-		return ConnectServer(protocol, addr)
-	case "websocket":
-		return ConnectWebsocketServer(addr)
 	default:
-		return nil, fmt.Errorf("unsupport protocol: %s", protocol)
+		return nil, fmt.Errorf("unsupport protocol: %s when connecting by proxy", protocol)
 	}
 }
 
-func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) {
-	c, err = ConnectServerByProxy(proxyURL, protocol, addr)
+// addr: domain:port
+func DialWebsocketServer(addr string) (net.Conn, error) {
+	addr = "ws://" + addr + FrpWebsocketPath
+	uri, err := url.Parse(addr)
 	if err != nil {
-		return
+		return nil, err
 	}
 
-	if tlsConfig == nil {
-		return
+	origin := "http://" + uri.Host
+	cfg, err := websocket.NewConfig(addr, origin)
+	if err != nil {
+		return nil, err
+	}
+	cfg.Dialer = &net.Dialer{
+		Timeout: 10 * time.Second,
 	}
 
-	c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte)
-	return
+	conn, err := websocket.DialConfig(cfg)
+	if err != nil {
+		return nil, err
+	}
+	return conn, nil
 }

+ 89 - 0
pkg/util/net/dial.go

@@ -0,0 +1,89 @@
+package net
+
+import (
+	"crypto/tls"
+	"net"
+)
+
+type dialOptions struct {
+	proxyURL                 string
+	protocol                 string
+	tlsConfig                *tls.Config
+	disableCustomTLSHeadByte bool
+}
+
+type DialOption interface {
+	apply(*dialOptions)
+}
+
+type EmptyDialOption struct{}
+
+func (EmptyDialOption) apply(*dialOptions) {}
+
+type funcDialOption struct {
+	f func(*dialOptions)
+}
+
+func (fdo *funcDialOption) apply(do *dialOptions) {
+	fdo.f(do)
+}
+
+func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
+	return &funcDialOption{
+		f: f,
+	}
+}
+
+func DefaultDialOptions() dialOptions {
+	return dialOptions{
+		protocol: "tcp",
+	}
+}
+
+func WithProxyURL(proxyURL string) DialOption {
+	return newFuncDialOption(func(do *dialOptions) {
+		do.proxyURL = proxyURL
+	})
+}
+
+func WithTLSConfig(tlsConfig *tls.Config) DialOption {
+	return newFuncDialOption(func(do *dialOptions) {
+		do.tlsConfig = tlsConfig
+	})
+}
+
+func WithDisableCustomTLSHeadByte(disableCustomTLSHeadByte bool) DialOption {
+	return newFuncDialOption(func(do *dialOptions) {
+		do.disableCustomTLSHeadByte = disableCustomTLSHeadByte
+	})
+}
+
+func WithProtocol(protocol string) DialOption {
+	return newFuncDialOption(func(do *dialOptions) {
+		do.protocol = protocol
+	})
+}
+
+func DialWithOptions(addr string, opts ...DialOption) (c net.Conn, err error) {
+	op := DefaultDialOptions()
+
+	for _, opt := range opts {
+		opt.apply(&op)
+	}
+
+	if op.proxyURL == "" {
+		c, err = ConnectServer(op.protocol, addr)
+	} else {
+		c, err = ConnectServerByProxy(op.proxyURL, op.protocol, addr)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	if op.tlsConfig == nil {
+		return
+	}
+
+	c = WrapTLSClientConn(c, op.tlsConfig, op.disableCustomTLSHeadByte)
+	return
+}

+ 0 - 26
pkg/util/net/websocket.go

@@ -5,8 +5,6 @@ import (
 	"fmt"
 	"net"
 	"net/http"
-	"net/url"
-	"time"
 
 	"golang.org/x/net/websocket"
 )
@@ -77,27 +75,3 @@ func (p *WebsocketListener) Close() error {
 func (p *WebsocketListener) Addr() net.Addr {
 	return p.ln.Addr()
 }
-
-// addr: domain:port
-func ConnectWebsocketServer(addr string) (net.Conn, error) {
-	addr = "ws://" + addr + FrpWebsocketPath
-	uri, err := url.Parse(addr)
-	if err != nil {
-		return nil, err
-	}
-
-	origin := "http://" + uri.Host
-	cfg, err := websocket.NewConfig(addr, origin)
-	if err != nil {
-		return nil, err
-	}
-	cfg.Dialer = &net.Dialer{
-		Timeout: 10 * time.Second,
-	}
-
-	conn, err := websocket.DialConfig(cfg)
-	if err != nil {
-		return nil, err
-	}
-	return conn, nil
-}