Browse Source

Bugfix: add ipv6 parsing with address of frps (#2163)

yuyulei 4 years ago
parent
commit
ed61049041
4 changed files with 49 additions and 13 deletions
  1. 5 3
      client/control.go
  2. 4 2
      client/service.go
  3. 7 8
      cmd/frpc/sub/root.go
  4. 33 0
      pkg/util/util/http.go

+ 5 - 3
client/control.go

@@ -17,10 +17,10 @@ package client
 import (
 	"context"
 	"crypto/tls"
-	"fmt"
 	"io"
 	"net"
 	"runtime/debug"
+	"strconv"
 	"sync"
 	"time"
 
@@ -222,8 +222,10 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
 				return
 			}
 		}
-		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol,
-			fmt.Sprintf("%s:%d", ctl.clientCfg.ServerAddr, ctl.clientCfg.ServerPort), tlsConfig)
+
+		address := net.JoinHostPort(ctl.clientCfg.ServerAddr, strconv.Itoa(ctl.clientCfg.ServerPort))
+		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol, address, tlsConfig)
+
 		if err != nil {
 			xl.Warn("start new connection to server error: %v", err)
 			return

+ 4 - 2
client/service.go

@@ -21,6 +21,7 @@ import (
 	"io/ioutil"
 	"net"
 	"runtime"
+	"strconv"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -215,8 +216,9 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
 			return
 		}
 	}
-	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol,
-		fmt.Sprintf("%s:%d", svr.cfg.ServerAddr, svr.cfg.ServerPort), tlsConfig)
+
+	address := net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort))
+	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol, address, tlsConfig)
 	if err != nil {
 		return
 	}

+ 7 - 8
cmd/frpc/sub/root.go

@@ -157,17 +157,16 @@ func parseClientCommonCfgFromIni(content string) (config.ClientCommonConf, error
 func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) {
 	cfg = config.GetDefaultClientConf()
 
-	strs := strings.Split(serverAddr, ":")
-	if len(strs) < 2 {
-		err = fmt.Errorf("invalid server_addr")
+	ipStr, portStr, err := net.SplitHostPort(serverAddr)
+	if err != nil {
+		err = fmt.Errorf("invalid server_addr: %v", err)
 		return
 	}
-	if strs[0] != "" {
-		cfg.ServerAddr = strs[0]
-	}
-	cfg.ServerPort, err = strconv.Atoi(strs[1])
+
+	cfg.ServerAddr = ipStr
+	cfg.ServerPort, err = strconv.Atoi(portStr)
 	if err != nil {
-		err = fmt.Errorf("invalid server_addr")
+		err = fmt.Errorf("invalid server_addr: %v", err)
 		return
 	}
 

+ 33 - 0
pkg/util/util/http.go

@@ -15,6 +15,7 @@
 package util
 
 import (
+	"net"
 	"net/http"
 	"strings"
 )
@@ -33,6 +34,7 @@ func OkResponse() *http.Response {
 	return res
 }
 
+// TODO: use "CanonicalHost" func to replace all "GetHostFromAddr" func.
 func GetHostFromAddr(addr string) (host string) {
 	strs := strings.Split(addr, ":")
 	if len(strs) > 1 {
@@ -42,3 +44,34 @@ func GetHostFromAddr(addr string) (host string) {
 	}
 	return
 }
+
+// canonicalHost strips port from host if present and returns the canonicalized
+// host name.
+func CanonicalHost(host string) (string, error) {
+	var err error
+	host = strings.ToLower(host)
+	if hasPort(host) {
+		host, _, err = net.SplitHostPort(host)
+		if err != nil {
+			return "", err
+		}
+	}
+	if strings.HasSuffix(host, ".") {
+		// Strip trailing dot from fully qualified domain names.
+		host = host[:len(host)-1]
+	}
+	return host, nil
+}
+
+// hasPort reports whether host contains a port number. host may be a host
+// name, an IPv4 or an IPv6 address.
+func hasPort(host string) bool {
+	colons := strings.Count(host, ":")
+	if colons == 0 {
+		return false
+	}
+	if colons == 1 {
+		return true
+	}
+	return host[0] == '[' && strings.Contains(host, "]:")
+}