Browse Source

frpc: add disable_custom_tls_first_byte to not send first custom tls to frps (#2520)

fatedier 3 years ago
parent
commit
42745a3da2

+ 1 - 1
client/control.go

@@ -228,7 +228,7 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
 		}
 
 		address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))
-		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig)
+		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig, ctl.clientCfg.DisableCustomTLSFirstByte)
 
 		if err != nil {
 			xl.Warn("start new connection to server error: %v", err)

+ 1 - 1
client/service.go

@@ -232,7 +232,7 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
 	}
 
 	address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))
-	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig)
+	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig, svr.cfg.DisableCustomTLSFirstByte)
 	if err != nil {
 		return
 	}

+ 4 - 0
conf/frpc_full.ini

@@ -105,6 +105,10 @@ udp_packet_size = 1500
 # include other config files for proxies.
 # includes = ./confd/*.ini
 
+# By default, frpc will connect frps with first custom byte if tls is enabled.
+# If DisableCustomTLSFirstByte is true, frpc will not send that custom byte.
+disable_custom_tls_first_byte = false
+
 # 'ssh' is the unique proxy name
 # if user in [common] section is not empty, it will be changed to {user}.{proxy} such as 'your_name.ssh'
 [ssh]

+ 3 - 0
pkg/config/client.go

@@ -124,6 +124,9 @@ type ClientCommonConf struct {
 	// TLSServerName specifices the custom server name of tls certificate. By
 	// default, server name if same to ServerAddr.
 	TLSServerName string `ini:"tls_server_name" json:"tls_server_name"`
+	// By default, frpc will connect frps with first custom byte if tls is enabled.
+	// If DisableCustomTLSFirstByte is true, frpc will not send that custom byte.
+	DisableCustomTLSFirstByte bool `ini:"disable_custom_tls_first_byte" json:"disable_custom_tls_first_byte"`
 	// HeartBeatInterval specifies at what interval heartbeats are sent to the
 	// server, in seconds. It is not recommended to change this value. By
 	// default, this value is 30.

+ 2 - 2
pkg/util/net/conn.go

@@ -228,7 +228,7 @@ func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.
 	}
 }
 
-func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config) (c net.Conn, err error) {
+func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (c net.Conn, err error) {
 	c, err = ConnectServerByProxy(proxyURL, protocol, addr)
 	if err != nil {
 		return
@@ -238,6 +238,6 @@ func ConnectServerByProxyWithTLS(proxyURL string, protocol string, addr string,
 		return
 	}
 
-	c = WrapTLSClientConn(c, tlsConfig)
+	c = WrapTLSClientConn(c, tlsConfig, disableCustomTLSHeadByte)
 	return
 }

+ 13 - 3
pkg/util/net/tls.go

@@ -27,13 +27,18 @@ var (
 	FRPTLSHeadByte = 0x17
 )
 
-func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out net.Conn) {
-	c.Write([]byte{byte(FRPTLSHeadByte)})
+func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config, disableCustomTLSHeadByte bool) (out net.Conn) {
+	if !disableCustomTLSHeadByte {
+		c.Write([]byte{byte(FRPTLSHeadByte)})
+	}
 	out = tls.Client(c, tlsConfig)
 	return
 }
 
-func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration) (out net.Conn, err error) {
+func CheckAndEnableTLSServerConnWithTimeout(
+	c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration,
+) (out net.Conn, isTLS bool, custom bool, err error) {
+
 	sc, r := gnet.NewSharedConnSize(c, 2)
 	buf := make([]byte, 1)
 	var n int
@@ -46,6 +51,11 @@ func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, t
 
 	if n == 1 && int(buf[0]) == FRPTLSHeadByte {
 		out = tls.Server(c, tlsConfig)
+		isTLS = true
+		custom = true
+	} else if n == 1 && int(buf[0]) == 0x16 {
+		out = tls.Server(sc, tlsConfig)
+		isTLS = true
 	} else {
 		if tlsOnly {
 			err = fmt.Errorf("non-TLS connection received on a TlsOnly server")

+ 7 - 5
server/service.go

@@ -258,8 +258,9 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 	}
 
 	// frp tls listener
-	svr.tlsListener = svr.muxer.Listen(1, 1, func(data []byte) bool {
-		return int(data[0]) == frpNet.FRPTLSHeadByte
+	svr.tlsListener = svr.muxer.Listen(2, 1, func(data []byte) bool {
+		// tls first byte can be 0x16 only when vhost https port is not same with bind port
+		return int(data[0]) == frpNet.FRPTLSHeadByte || int(data[0]) == 0x16
 	})
 
 	// Create nat hole controller.
@@ -395,15 +396,16 @@ func (svr *Service) HandleListener(l net.Listener) {
 
 		log.Trace("start check TLS connection...")
 		originConn := c
-		c, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TLSOnly, connReadTimeout)
+		var isTLS, custom bool
+		c, isTLS, custom, err = frpNet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.TLSOnly, connReadTimeout)
 		if err != nil {
 			log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
 			originConn.Close()
 			continue
 		}
-		log.Trace("success check TLS connection")
+		log.Trace("check TLS connection success, isTLS: %v custom: %v", isTLS, custom)
 
-		// Start a new goroutine for dealing connections.
+		// Start a new goroutine to handle connection.
 		go func(ctx context.Context, frpConn net.Conn) {
 			if svr.cfg.TCPMux {
 				fmuxCfg := fmux.DefaultConfig()

+ 18 - 0
test/e2e/basic/client_server.go

@@ -231,4 +231,22 @@ var _ = Describe("[Feature: Client-Server]", func() {
 			})
 		})
 	})
+
+	Describe("TLS with disable_custom_tls_first_byte", func() {
+		supportProtocols := []string{"tcp", "kcp", "websocket"}
+		for _, protocol := range supportProtocols {
+			tmp := protocol
+			defineClientServerTest("TLS over "+strings.ToUpper(tmp), f, &generalTestConfigures{
+				server: fmt.Sprintf(`
+					kcp_bind_port = {{ .%s }}
+					protocol = %s
+					`, consts.PortServerName, protocol),
+				client: fmt.Sprintf(`
+					tls_enable = true
+					protocol = %s
+					disable_custom_tls_first_byte = true
+					`, protocol),
+			})
+		}
+	})
 })