fatedier 7 lat temu
rodzic
commit
7957572ced
4 zmienionych plików z 139 dodań i 42 usunięć
  1. 2 1
      models/config/proxy.go
  2. 8 41
      models/config/server_common.go
  3. 65 0
      utils/util/util.go
  4. 64 0
      utils/util/util_test.go

+ 2 - 1
models/config/proxy.go

@@ -23,6 +23,7 @@ import (
 	"github.com/fatedier/frp/models/consts"
 	"github.com/fatedier/frp/models/msg"
 
+	"github.com/fatedier/frp/utils/util"
 	ini "github.com/vaughan0/go-ini"
 )
 
@@ -173,7 +174,7 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
 
 func (cfg *BindInfoConf) check() (err error) {
 	if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 {
-		if _, ok := ServerCommonCfg.PrivilegeAllowPorts[cfg.RemotePort]; !ok {
+		if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok {
 			return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort)
 		}
 	}

+ 8 - 41
models/config/server_common.go

@@ -19,6 +19,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/fatedier/frp/utils/util"
 	ini "github.com/vaughan0/go-ini"
 )
 
@@ -52,7 +53,7 @@ type ServerCommonConf struct {
 	TcpMux         bool
 
 	// if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected
-	PrivilegeAllowPorts map[int64]struct{}
+	PrivilegeAllowPorts [][2]int64
 	MaxPoolCount        int64
 	HeartBeatTimeout    int64
 	UserConnTimeout     int64
@@ -188,47 +189,13 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
 	if cfg.PrivilegeMode == true {
 		cfg.PrivilegeToken, _ = conf.Get("common", "privilege_token")
 
-		cfg.PrivilegeAllowPorts = make(map[int64]struct{})
-		tmpStr, ok = conf.Get("common", "privilege_allow_ports")
+		allowPortsStr, ok := conf.Get("common", "privilege_allow_ports")
+		// TODO: check if conflicts exist in port ranges
 		if ok {
-			// e.g. 1000-2000,2001,2002,3000-4000
-			portRanges := strings.Split(tmpStr, ",")
-			for _, portRangeStr := range portRanges {
-				// 1000-2000 or 2001
-				portArray := strings.Split(portRangeStr, "-")
-				// length: only 1 or 2 is correct
-				rangeType := len(portArray)
-				if rangeType == 1 {
-					// single port
-					singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64)
-					if errRet != nil {
-						err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
-						return
-					}
-					cfg.PrivilegeAllowPorts[singlePort] = struct{}{}
-				} else if rangeType == 2 {
-					// range ports
-					min, errRet := strconv.ParseInt(portArray[0], 10, 64)
-					if errRet != nil {
-						err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
-						return
-					}
-					max, errRet := strconv.ParseInt(portArray[1], 10, 64)
-					if errRet != nil {
-						err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
-						return
-					}
-					if max < min {
-						err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect")
-						return
-					}
-					for i := min; i <= max; i++ {
-						cfg.PrivilegeAllowPorts[i] = struct{}{}
-					}
-				} else {
-					err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect")
-					return
-				}
+			cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr)
+			if err != nil {
+				err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err)
+				return
 			}
 		}
 	}

+ 65 - 0
utils/util/util.go

@@ -19,6 +19,8 @@ import (
 	"crypto/rand"
 	"encoding/hex"
 	"fmt"
+	"strconv"
+	"strings"
 )
 
 // RandId return a rand string used in frp.
@@ -45,3 +47,66 @@ func GetAuthKey(token string, timestamp int64) (key string) {
 	data := md5Ctx.Sum(nil)
 	return hex.EncodeToString(data)
 }
+
+// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges.
+func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) {
+	// for example: 1000-2000,2001,2002,3000-4000
+	rangeArray := strings.Split(rangeStr, ",")
+	for _, portRangeStr := range rangeArray {
+		// 1000-2000 or 2001
+		portArray := strings.Split(portRangeStr, "-")
+		// length: only 1 or 2 is correct
+		rangeType := len(portArray)
+		if rangeType == 1 {
+			singlePort, err := strconv.ParseInt(portArray[0], 10, 64)
+			if err != nil {
+				return [][2]int64{}, err
+			}
+			portRanges = append(portRanges, [2]int64{singlePort, singlePort})
+		} else if rangeType == 2 {
+			min, err := strconv.ParseInt(portArray[0], 10, 64)
+			if err != nil {
+				return [][2]int64{}, err
+			}
+			max, err := strconv.ParseInt(portArray[1], 10, 64)
+			if err != nil {
+				return [][2]int64{}, err
+			}
+			if max < min {
+				return [][2]int64{}, fmt.Errorf("range incorrect")
+			}
+			portRanges = append(portRanges, [2]int64{min, max})
+		} else {
+			return [][2]int64{}, fmt.Errorf("format error")
+		}
+	}
+	return portRanges, nil
+}
+
+func ContainsPort(portRanges [][2]int64, port int64) bool {
+	for _, pr := range portRanges {
+		if port >= pr[0] && port <= pr[1] {
+			return true
+		}
+	}
+	return false
+}
+
+func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 {
+	var tmpRanges [][2]int64
+	for _, pr := range portRanges {
+		if port >= pr[0] && port <= pr[1] {
+			leftRange := [2]int64{pr[0], port - 1}
+			rightRange := [2]int64{port + 1, pr[1]}
+			if leftRange[0] <= leftRange[1] {
+				tmpRanges = append(tmpRanges, leftRange)
+			}
+			if rightRange[0] <= rightRange[1] {
+				tmpRanges = append(tmpRanges, rightRange)
+			}
+		} else {
+			tmpRanges = append(tmpRanges, pr)
+		}
+	}
+	return tmpRanges
+}

+ 64 - 0
utils/util/util_test.go

@@ -20,3 +20,67 @@ func TestGetAuthKey(t *testing.T) {
 	t.Log(key)
 	assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
 }
+
+func TestGetPortRanges(t *testing.T) {
+	assert := assert.New(t)
+
+	rangesStr := "2000-3000,3001,4000-50000"
+	expect := [][2]int64{
+		[2]int64{2000, 3000},
+		[2]int64{3001, 3001},
+		[2]int64{4000, 50000},
+	}
+	actual, err := GetPortRanges(rangesStr)
+	assert.Nil(err)
+	t.Log(actual)
+	assert.Equal(expect, actual)
+}
+
+func TestContainsPort(t *testing.T) {
+	assert := assert.New(t)
+
+	rangesStr := "2000-3000,3001,4000-50000"
+	portRanges, err := GetPortRanges(rangesStr)
+	assert.Nil(err)
+
+	type Case struct {
+		Port   int64
+		Answer bool
+	}
+	cases := []Case{
+		Case{
+			Port:   3001,
+			Answer: true,
+		},
+		Case{
+			Port:   3002,
+			Answer: false,
+		},
+		Case{
+			Port:   44444,
+			Answer: true,
+		},
+	}
+	for _, elem := range cases {
+		ok := ContainsPort(portRanges, elem.Port)
+		assert.Equal(elem.Answer, ok)
+	}
+}
+
+func TestPortRangesCut(t *testing.T) {
+	assert := assert.New(t)
+
+	rangesStr := "2000-3000,3001,4000-50000"
+	portRanges, err := GetPortRanges(rangesStr)
+	assert.Nil(err)
+
+	expect := [][2]int64{
+		[2]int64{2000, 3000},
+		[2]int64{3001, 3001},
+		[2]int64{4000, 44443},
+		[2]int64{44445, 50000},
+	}
+	actual := PortRangesCut(portRanges, 44444)
+	t.Log(actual)
+	assert.Equal(expect, actual)
+}