Browse Source

use net.JoinHostPort instead of fmt.Sprintf (#2791)

fatedier 3 years ago
parent
commit
6194273615

+ 6 - 10
client/proxy/proxy.go

@@ -347,22 +347,18 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
 	xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
 
 	// Send detect message
-	array := strings.Split(natHoleRespMsg.VisitorAddr, ":")
-	if len(array) <= 1 {
-		xl.Error("get NatHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
+	host, portStr, err := net.SplitHostPort(natHoleRespMsg.VisitorAddr)
+	if err != nil {
+		xl.Error("get NatHoleResp visitor address [%s] error: %v", natHoleRespMsg.VisitorAddr, err)
 	}
 	laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String())
-	/*
-		for i := 1000; i < 65000; i++ {
-			pxy.sendDetectMsg(array[0], int64(i), laddr, "a")
-		}
-	*/
-	port, err := strconv.ParseInt(array[1], 10, 64)
+
+	port, err := strconv.ParseInt(portStr, 10, 64)
 	if err != nil {
 		xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
 		return
 	}
-	pxy.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid))
+	pxy.sendDetectMsg(host, int(port), laddr, []byte(natHoleRespMsg.Sid))
 	xl.Trace("send all detect msg done")
 
 	msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{})

+ 4 - 3
client/visitor.go

@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"strconv"
 	"sync"
 	"time"
 
@@ -85,7 +86,7 @@ type STCPVisitor struct {
 }
 
 func (sv *STCPVisitor) Run() (err error) {
-	sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
+	sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
 	if err != nil {
 		return
 	}
@@ -174,7 +175,7 @@ type XTCPVisitor struct {
 }
 
 func (sv *XTCPVisitor) Run() (err error) {
-	sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
+	sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
 	if err != nil {
 		return
 	}
@@ -352,7 +353,7 @@ type SUDPVisitor struct {
 func (sv *SUDPVisitor) Run() (err error) {
 	xl := xlog.FromContextSafe(sv.ctx)
 
-	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
+	addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
 	if err != nil {
 		return fmt.Errorf("sudp ResolveUDPAddr error: %v", err)
 	}

+ 2 - 1
pkg/util/net/udp.go

@@ -18,6 +18,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"strconv"
 	"sync"
 	"time"
 
@@ -163,7 +164,7 @@ type UDPListener struct {
 }
 
 func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
-	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
+	udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
 	if err != nil {
 		return l, err
 	}

+ 2 - 2
pkg/util/net/websocket.go

@@ -2,9 +2,9 @@ package net
 
 import (
 	"errors"
-	"fmt"
 	"net"
 	"net/http"
+	"strconv"
 
 	"golang.org/x/net/websocket"
 )
@@ -52,7 +52,7 @@ func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
 }
 
 func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
-	tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
+	tcpLn, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
pkg/util/tcpmux/httpconnect.go

@@ -48,7 +48,7 @@ func readHTTPConnectRequest(rd io.Reader) (host string, err error) {
 		return
 	}
 
-	host = util.GetHostFromAddr(req.Host)
+	host, _ = util.CanonicalHost(req.Host)
 	return
 }
 

+ 0 - 11
pkg/util/util/http.go

@@ -34,17 +34,6 @@ func OkResponse() *http.Response {
 	return res
 }
 
-// TODO: use "CanonicalHost" func to replace all "GetHostFromAddr" func.
-func GetHostFromAddr(addr string) (host string) {
-	strs := strings.Split(addr, ":")
-	if len(strs) > 1 {
-		host = strs[0]
-	} else {
-		host = addr
-	}
-	return
-}
-
 // canonicalHost strips port from host if present and returns the canonicalized
 // host name.
 func CanonicalHost(host string) (string, error) {

+ 2 - 1
pkg/util/util/util.go

@@ -19,6 +19,7 @@ import (
 	"crypto/rand"
 	"encoding/hex"
 	"fmt"
+	"net"
 	"strconv"
 	"strings"
 )
@@ -52,7 +53,7 @@ func CanonicalAddr(host string, port int) (addr string) {
 	if port == 80 || port == 443 {
 		addr = host
 	} else {
-		addr = fmt.Sprintf("%s:%d", host, port)
+		addr = net.JoinHostPort(host, strconv.Itoa(port))
 	}
 	return
 }

+ 3 - 3
pkg/util/vhost/http.go

@@ -59,7 +59,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
 		Director: func(req *http.Request) {
 			req.URL.Scheme = "http"
 			url := req.Context().Value(RouteInfoURL).(string)
-			oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
+			oldHost, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string))
 			rc := rp.GetRouteConfig(oldHost, url)
 			if rc != nil {
 				if rc.RewriteHost != "" {
@@ -81,7 +81,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
 			IdleConnTimeout:       60 * time.Second,
 			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
 				url := ctx.Value(RouteInfoURL).(string)
-				host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
+				host, _ := util.CanonicalHost(ctx.Value(RouteInfoHost).(string))
 				remote := ctx.Value(RouteInfoRemote).(string)
 				return rp.CreateConnection(host, url, remote)
 			},
@@ -191,7 +191,7 @@ func (rp *HTTPReverseProxy) getVhost(domain string, location string) (vr *Router
 }
 
 func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
-	domain := util.GetHostFromAddr(req.Host)
+	domain, _ := util.CanonicalHost(req.Host)
 	location := req.URL.Path
 	user, passwd, _ := req.BasicAuth()
 	if !rp.CheckAuth(domain, location, user, passwd) {

+ 2 - 2
server/group/tcp.go

@@ -15,8 +15,8 @@
 package group
 
 import (
-	"fmt"
 	"net"
+	"strconv"
 	"sync"
 
 	"github.com/fatedier/frp/server/ports"
@@ -101,7 +101,7 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
 		if err != nil {
 			return
 		}
-		tcpLn, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, port))
+		tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port)))
 		if errRet != nil {
 			err = errRet
 			return

+ 3 - 3
server/ports/ports.go

@@ -2,8 +2,8 @@ package ports
 
 import (
 	"errors"
-	"fmt"
 	"net"
+	"strconv"
 	"sync"
 	"time"
 )
@@ -134,7 +134,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
 
 func (pm *Manager) isPortAvailable(port int) bool {
 	if pm.netType == "udp" {
-		addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port))
+		addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
 		if err != nil {
 			return false
 		}
@@ -146,7 +146,7 @@ func (pm *Manager) isPortAvailable(port int) bool {
 		return true
 	}
 
-	l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port))
+	l, err := net.Listen(pm.netType, net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
 	if err != nil {
 		return false
 	}

+ 2 - 1
server/proxy/tcp.go

@@ -17,6 +17,7 @@ package proxy
 import (
 	"fmt"
 	"net"
+	"strconv"
 
 	"github.com/fatedier/frp/pkg/config"
 )
@@ -54,7 +55,7 @@ func (pxy *TCPProxy) Run() (remoteAddr string, err error) {
 				pxy.rc.TCPPortManager.Release(pxy.realPort)
 			}
 		}()
-		listener, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort))
+		listener, errRet := net.Listen("tcp", net.JoinHostPort(pxy.serverCfg.ProxyBindAddr, strconv.Itoa(pxy.realPort)))
 		if errRet != nil {
 			err = errRet
 			return

+ 2 - 1
server/proxy/udp.go

@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"strconv"
 	"time"
 
 	"github.com/fatedier/frp/pkg/config"
@@ -70,7 +71,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {
 
 	remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
 	pxy.cfg.RemotePort = pxy.realPort
-	addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort))
+	addr, errRet := net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.serverCfg.ProxyBindAddr, strconv.Itoa(pxy.realPort)))
 	if errRet != nil {
 		err = errRet
 		return

+ 6 - 5
server/service.go

@@ -124,7 +124,8 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 	// Create tcpmux httpconnect multiplexer.
 	if cfg.TCPMuxHTTPConnectPort > 0 {
 		var l net.Listener
-		l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.TCPMuxHTTPConnectPort))
+		address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.TCPMuxHTTPConnectPort))
+		l, err = net.Listen("tcp", address)
 		if err != nil {
 			err = fmt.Errorf("Create server listener error, %v", err)
 			return
@@ -135,7 +136,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 			err = fmt.Errorf("Create vhost tcpMuxer error, %v", err)
 			return
 		}
-		log.Info("tcpmux httpconnect multiplexer listen on %s:%d", cfg.ProxyBindAddr, cfg.TCPMuxHTTPConnectPort)
+		log.Info("tcpmux httpconnect multiplexer listen on %s", address)
 	}
 
 	// Init all plugins
@@ -199,7 +200,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 			err = fmt.Errorf("Listen on kcp address udp %s error: %v", address, err)
 			return
 		}
-		log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KCPBindPort)
+		log.Info("frps kcp listen on udp %s", address)
 	}
 
 	// Listen for accepting connections from client using websocket protocol.
@@ -232,7 +233,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 			}
 		}
 		go server.Serve(l)
-		log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHTTPPort)
+		log.Info("http service listen on %s", address)
 	}
 
 	// Create https vhost muxer.
@@ -288,7 +289,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 			err = fmt.Errorf("Create dashboard web server error, %v", err)
 			return
 		}
-		log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort)
+		log.Info("Dashboard listen on %s", address)
 		statsEnable = true
 	}
 	if statsEnable {

+ 20 - 0
test/e2e/basic/client_server.go

@@ -249,4 +249,24 @@ var _ = Describe("[Feature: Client-Server]", func() {
 			})
 		}
 	})
+
+	Describe("IPv6 bind address", func() {
+		supportProtocols := []string{"tcp", "kcp", "websocket"}
+		for _, protocol := range supportProtocols {
+			tmp := protocol
+			defineClientServerTest("IPv6 bind address: "+strings.ToUpper(tmp), f, &generalTestConfigures{
+				server: fmt.Sprintf(`
+					bind_addr = ::
+					kcp_bind_port = {{ .%s }}
+					protocol = %s
+					`, consts.PortServerName, protocol),
+				client: fmt.Sprintf(`
+					tls_enable = true
+					protocol = %s
+					disable_custom_tls_first_byte = true
+					`, protocol),
+			})
+		}
+	})
+
 })

+ 1 - 2
test/e2e/mock/server/httpserver/server.go

@@ -2,7 +2,6 @@ package httpserver
 
 import (
 	"crypto/tls"
-	"fmt"
 	"net"
 	"net/http"
 	"strconv"
@@ -97,7 +96,7 @@ func (s *Server) Close() error {
 }
 
 func (s *Server) initListener() (err error) {
-	s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort))
+	s.l, err = net.Listen("tcp", net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort)))
 	return
 }
 

+ 2 - 1
test/e2e/mock/server/streamserver/server.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"strconv"
 
 	libnet "github.com/fatedier/frp/pkg/util/net"
 	"github.com/fatedier/frp/test/e2e/pkg/rpc"
@@ -99,7 +100,7 @@ func (s *Server) Close() error {
 func (s *Server) initListener() (err error) {
 	switch s.netType {
 	case TCP:
-		s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort))
+		s.l, err = net.Listen("tcp", net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort)))
 	case UDP:
 		s.l, err = libnet.ListenUDP(s.bindAddr, s.bindPort)
 	case Unix:

+ 3 - 2
test/e2e/pkg/port/port.go

@@ -3,6 +3,7 @@ package port
 import (
 	"fmt"
 	"net"
+	"strconv"
 	"sync"
 
 	"k8s.io/apimachinery/pkg/util/sets"
@@ -57,7 +58,7 @@ func (pa *Allocator) GetByName(portName string) int {
 			return 0
 		}
 
-		l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
+		l, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
 		if err != nil {
 			// Maybe not controlled by us, mark it used.
 			pa.used.Insert(port)
@@ -65,7 +66,7 @@ func (pa *Allocator) GetByName(portName string) int {
 		}
 		l.Close()
 
-		udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", port))
+		udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
 		if err != nil {
 			continue
 		}