Selaa lähdekoodia

feat(nathole): use serverUDPPort in nathole discovery when available (#3382)

fatedier 1 vuosi sitten
vanhempi
commit
3faae194d0
3 muutettua tiedostoa jossa 151 lisäystä ja 97 poistoa
  1. 9 4
      cmd/frpc/sub/nathole.go
  2. 125 93
      pkg/nathole/discovery.go
  3. 17 0
      pkg/nathole/utils.go

+ 9 - 4
cmd/frpc/sub/nathole.go

@@ -53,8 +53,12 @@ var natholeDiscoveryCmd = &cobra.Command{
 			os.Exit(1)
 		}
 
+		serverAddr := ""
+		if cfg.ServerUDPPort != 0 {
+			serverAddr = net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort))
+		}
 		addresses, err := nathole.Discover(
-			net.JoinHostPort(cfg.ServerAddr, strconv.Itoa(cfg.ServerUDPPort)),
+			serverAddr,
 			[]string{cfg.NatHoleSTUNServer},
 			[]byte(cfg.Token),
 		)
@@ -62,6 +66,10 @@ var natholeDiscoveryCmd = &cobra.Command{
 			fmt.Println("discover error:", err)
 			os.Exit(1)
 		}
+		if len(addresses) < 2 {
+			fmt.Printf("discover error: can not get enough addresses, need 2, got: %v\n", addresses)
+			os.Exit(1)
+		}
 
 		natType, behavior, err := nathole.ClassifyNATType(addresses)
 		if err != nil {
@@ -79,8 +87,5 @@ func validateForNatHoleDiscovery(cfg config.ClientCommonConf) error {
 	if cfg.NatHoleSTUNServer == "" {
 		return fmt.Errorf("nat_hole_stun_server can not be empty")
 	}
-	if cfg.ServerUDPPort == 0 {
-		return fmt.Errorf("server udp port can not be empty")
-	}
 	return nil
 }

+ 125 - 93
pkg/nathole/discovery.go

@@ -26,31 +26,12 @@ import (
 
 var responseTimeout = 3 * time.Second
 
-type Address struct {
-	IP   string
-	Port int
-}
-
 type Message struct {
 	Body []byte
 	Addr string
 }
 
 func Discover(serverAddress string, stunServers []string, key []byte) ([]string, error) {
-	// parse address to net.Address
-	stunAddresses := make([]net.Addr, 0, len(stunServers))
-	for _, stunServer := range stunServers {
-		addr, err := net.ResolveUDPAddr("udp4", stunServer)
-		if err != nil {
-			return nil, err
-		}
-		stunAddresses = append(stunAddresses, addr)
-	}
-	serverAddr, err := net.ResolveUDPAddr("udp4", serverAddress)
-	if err != nil {
-		return nil, err
-	}
-
 	// create a discoverConn and get response from messageChan
 	discoverConn, err := listen()
 	if err != nil {
@@ -61,90 +42,29 @@ func Discover(serverAddress string, stunServers []string, key []byte) ([]string,
 	go discoverConn.readLoop()
 
 	addresses := make([]string, 0, len(stunServers)+1)
-	// get external address from frp server
-	externalAddr, err := discoverFromServer(discoverConn, serverAddr, key)
-	if err != nil {
-		return nil, err
-	}
-	addresses = append(addresses, externalAddr)
-
-	for _, stunAddr := range stunAddresses {
-		// get external address from stun server
-		externalAddr, err = discoverFromStunServer(discoverConn, stunAddr)
+	if serverAddress != "" {
+		// get external address from frp server
+		externalAddr, err := discoverConn.discoverFromServer(serverAddress, key)
 		if err != nil {
 			return nil, err
 		}
 		addresses = append(addresses, externalAddr)
 	}
-	return addresses, nil
-}
-
-func discoverFromServer(c *discoverConn, addr net.Addr, key []byte) (string, error) {
-	m := &msg.NatHoleBinding{
-		TransactionID: NewTransactionID(),
-	}
-
-	buf, err := EncodeMessage(m, key)
-	if err != nil {
-		return "", err
-	}
 
-	if _, err := c.conn.WriteTo(buf, addr); err != nil {
-		return "", err
-	}
-
-	var respMsg msg.NatHoleBindingResp
-	select {
-	case rawMsg := <-c.messageChan:
-		if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil {
-			return "", err
+	for _, addr := range stunServers {
+		// get external address from stun server
+		externalAddrs, err := discoverConn.discoverFromStunServer(addr)
+		if err != nil {
+			return nil, err
 		}
-	case <-time.After(responseTimeout):
-		return "", fmt.Errorf("wait response from frp server timeout")
-	}
-
-	if respMsg.TransactionID == "" {
-		return "", fmt.Errorf("error format: no transaction id found")
+		addresses = append(addresses, externalAddrs...)
 	}
-	if respMsg.Error != "" {
-		return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error)
-	}
-	return respMsg.Address, nil
+	return addresses, nil
 }
 
-func discoverFromStunServer(c *discoverConn, addr net.Addr) (string, error) {
-	request, err := stun.Build(stun.TransactionID, stun.BindingRequest)
-	if err != nil {
-		return "", err
-	}
-
-	if err = request.NewTransactionID(); err != nil {
-		return "", err
-	}
-	if _, err := c.conn.WriteTo(request.Raw, addr); err != nil {
-		return "", err
-	}
-
-	var m stun.Message
-	select {
-	case msg := <-c.messageChan:
-		m.Raw = msg.Body
-		if err := m.Decode(); err != nil {
-			return "", err
-		}
-	case <-time.After(responseTimeout):
-		return "", fmt.Errorf("wait response from stun server timeout")
-	}
-
-	xorAddr := &stun.XORMappedAddress{}
-	mappedAddr := &stun.MappedAddress{}
-	if err := xorAddr.GetFrom(&m); err == nil {
-		return xorAddr.String(), nil
-	}
-	if err := mappedAddr.GetFrom(&m); err == nil {
-		return mappedAddr.String(), nil
-	}
-	return "", fmt.Errorf("no address found")
+type stunResponse struct {
+	externalAddr string
+	otherAddr    string
 }
 
 type discoverConn struct {
@@ -190,3 +110,115 @@ func (c *discoverConn) readLoop() {
 		}
 	}
 }
+
+func (c *discoverConn) doSTUNRequest(addr string) (*stunResponse, error) {
+	serverAddr, err := net.ResolveUDPAddr("udp4", addr)
+	if err != nil {
+		return nil, err
+	}
+	request, err := stun.Build(stun.TransactionID, stun.BindingRequest)
+	if err != nil {
+		return nil, err
+	}
+
+	if err = request.NewTransactionID(); err != nil {
+		return nil, err
+	}
+	if _, err := c.conn.WriteTo(request.Raw, serverAddr); err != nil {
+		return nil, err
+	}
+
+	var m stun.Message
+	select {
+	case msg := <-c.messageChan:
+		m.Raw = msg.Body
+		if err := m.Decode(); err != nil {
+			return nil, err
+		}
+	case <-time.After(responseTimeout):
+		return nil, fmt.Errorf("wait response from stun server timeout")
+	}
+	xorAddrGetter := &stun.XORMappedAddress{}
+	mappedAddrGetter := &stun.MappedAddress{}
+	changedAddrGetter := ChangedAddress{}
+	otherAddrGetter := &stun.OtherAddress{}
+
+	resp := &stunResponse{}
+	if err := mappedAddrGetter.GetFrom(&m); err == nil {
+		resp.externalAddr = mappedAddrGetter.String()
+	}
+	if err := xorAddrGetter.GetFrom(&m); err == nil {
+		resp.externalAddr = xorAddrGetter.String()
+	}
+	if err := changedAddrGetter.GetFrom(&m); err == nil {
+		resp.otherAddr = changedAddrGetter.String()
+	}
+	if err := otherAddrGetter.GetFrom(&m); err == nil {
+		resp.otherAddr = otherAddrGetter.String()
+	}
+	return resp, nil
+}
+
+func (c *discoverConn) discoverFromServer(serverAddress string, key []byte) (string, error) {
+	addr, err := net.ResolveUDPAddr("udp4", serverAddress)
+	if err != nil {
+		return "", err
+	}
+	m := &msg.NatHoleBinding{
+		TransactionID: NewTransactionID(),
+	}
+
+	buf, err := EncodeMessage(m, key)
+	if err != nil {
+		return "", err
+	}
+
+	if _, err := c.conn.WriteTo(buf, addr); err != nil {
+		return "", err
+	}
+
+	var respMsg msg.NatHoleBindingResp
+	select {
+	case rawMsg := <-c.messageChan:
+		if err := DecodeMessageInto(rawMsg.Body, key, &respMsg); err != nil {
+			return "", err
+		}
+	case <-time.After(responseTimeout):
+		return "", fmt.Errorf("wait response from frp server timeout")
+	}
+
+	if respMsg.TransactionID == "" {
+		return "", fmt.Errorf("error format: no transaction id found")
+	}
+	if respMsg.Error != "" {
+		return "", fmt.Errorf("get externalAddr from frp server error: %s", respMsg.Error)
+	}
+	return respMsg.Address, nil
+}
+
+func (c *discoverConn) discoverFromStunServer(addr string) ([]string, error) {
+	resp, err := c.doSTUNRequest(addr)
+	if err != nil {
+		return nil, err
+	}
+	if resp.externalAddr == "" {
+		return nil, fmt.Errorf("no external address found")
+	}
+
+	externalAddrs := make([]string, 0, 2)
+	externalAddrs = append(externalAddrs, resp.externalAddr)
+
+	if resp.otherAddr == "" {
+		return externalAddrs, nil
+	}
+
+	// find external address from changed address
+	resp, err = c.doSTUNRequest(resp.otherAddr)
+	if err != nil {
+		return nil, err
+	}
+	if resp.externalAddr != "" {
+		externalAddrs = append(externalAddrs, resp.externalAddr)
+	}
+	return externalAddrs, nil
+}

+ 17 - 0
pkg/nathole/utils.go

@@ -16,8 +16,11 @@ package nathole
 
 import (
 	"bytes"
+	"net"
+	"strconv"
 
 	"github.com/fatedier/golib/crypto"
+	"github.com/pion/stun"
 
 	"github.com/fatedier/frp/pkg/msg"
 )
@@ -46,3 +49,17 @@ func DecodeMessageInto(data, key []byte, m msg.Message) error {
 	}
 	return nil
 }
+
+type ChangedAddress struct {
+	IP   net.IP
+	Port int
+}
+
+func (s *ChangedAddress) GetFrom(m *stun.Message) error {
+	a := (*stun.MappedAddress)(s)
+	return a.GetFromAs(m, stun.AttrChangedAddress)
+}
+
+func (s *ChangedAddress) String() string {
+	return net.JoinHostPort(s.IP.String(), strconv.Itoa(s.Port))
+}