Browse Source

new feature: assign a random port if remote_port is 0 in type tcp and
udp

fatedier 7 years ago
parent
commit
b2c846664d

+ 1 - 1
.travis.yml

@@ -3,7 +3,7 @@ language: go
 
 go:
     - 1.8.x
-    - 1.x
+    - 1.9.x
 
 install:
     - make

+ 1 - 1
client/admin.go

@@ -31,7 +31,7 @@ var (
 	httpServerWriteTimeout = 10 * time.Second
 )
 
-func (svr *Service) RunAdminServer(addr string, port int64) (err error) {
+func (svr *Service) RunAdminServer(addr string, port int) (err error) {
 	// url router
 	router := httprouter.New()
 

+ 10 - 2
client/admin_api.go

@@ -124,12 +124,20 @@ func NewProxyStatusResp(status *ProxyStatus) ProxyStatusResp {
 			psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
 		}
 		psr.Plugin = cfg.Plugin
-		psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
+		if status.Err != "" {
+			psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort)
+		} else {
+			psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
+		}
 	case *config.UdpProxyConf:
 		if cfg.LocalPort != 0 {
 			psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)
 		}
-		psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
+		if status.Err != "" {
+			psr.RemoteAddr = fmt.Sprintf("%s:%d", config.ClientCommonCfg.ServerAddr, cfg.RemotePort)
+		} else {
+			psr.RemoteAddr = config.ClientCommonCfg.ServerAddr + status.RemoteAddr
+		}
 	case *config.HttpProxyConf:
 		if cfg.LocalPort != 0 {
 			psr.LocalAddr = fmt.Sprintf("%s:%d", cfg.LocalIp, cfg.LocalPort)

+ 4 - 4
client/visitor.go

@@ -77,7 +77,7 @@ type StcpVisitor struct {
 }
 
 func (sv *StcpVisitor) Run() (err error) {
-	sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort))
+	sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort)
 	if err != nil {
 		return
 	}
@@ -164,7 +164,7 @@ type XtcpVisitor struct {
 }
 
 func (sv *XtcpVisitor) Run() (err error) {
-	sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, int64(sv.cfg.BindPort))
+	sv.l, err = frpNet.ListenTcp(sv.cfg.BindAddr, sv.cfg.BindPort)
 	if err != nil {
 		return
 	}
@@ -255,7 +255,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
 		sv.Error("get natHoleResp client address error: %s", natHoleRespMsg.ClientAddr)
 		return
 	}
-	sv.sendDetectMsg(array[0], int64(port), laddr, []byte(natHoleRespMsg.Sid))
+	sv.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid))
 	sv.Trace("send all detect msg done")
 
 	// Listen for visitorConn's address and wait for client connection.
@@ -302,7 +302,7 @@ func (sv *XtcpVisitor) handleConn(userConn frpNet.Conn) {
 	sv.Debug("join connections closed")
 }
 
-func (sv *XtcpVisitor) sendDetectMsg(addr string, port int64, laddr *net.UDPAddr, content []byte) (err error) {
+func (sv *XtcpVisitor) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) {
 	daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port))
 	if err != nil {
 		return err

+ 2 - 2
cmd/frpc/main.go

@@ -99,7 +99,7 @@ func main() {
 	if args["status"] != nil {
 		if args["status"].(bool) {
 			if err = CmdStatus(); err != nil {
-				fmt.Println("frps get status error: %v\n", err)
+				fmt.Printf("frps get status error: %v\n", err)
 				os.Exit(1)
 			} else {
 				os.Exit(0)
@@ -132,7 +132,7 @@ func main() {
 			os.Exit(1)
 		}
 		config.ClientCommonCfg.ServerAddr = addr[0]
-		config.ClientCommonCfg.ServerPort = serverPort
+		config.ClientCommonCfg.ServerPort = int(serverPort)
 	}
 
 	if args["-v"] != nil {

+ 1 - 1
cmd/frps/main.go

@@ -91,7 +91,7 @@ func main() {
 			os.Exit(1)
 		}
 		config.ServerCommonCfg.BindAddr = addr[0]
-		config.ServerCommonCfg.BindPort = bindPort
+		config.ServerCommonCfg.BindPort = int(bindPort)
 	}
 
 	if args["-v"] != nil {

+ 17 - 9
models/config/client_common.go

@@ -29,8 +29,8 @@ var ClientCommonCfg *ClientCommonConf
 type ClientCommonConf struct {
 	ConfigFile        string
 	ServerAddr        string
-	ServerPort        int64
-	ServerUdpPort     int64 // this is specified by login response message from frps
+	ServerPort        int
+	ServerUdpPort     int // this is specified by login response message from frps
 	HttpProxy         string
 	LogFile           string
 	LogWay            string
@@ -38,7 +38,7 @@ type ClientCommonConf struct {
 	LogMaxDays        int64
 	PrivilegeToken    string
 	AdminAddr         string
-	AdminPort         int64
+	AdminPort         int
 	AdminUser         string
 	AdminPwd          string
 	PoolCount         int
@@ -93,7 +93,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
 
 	tmpStr, ok = conf.Get("common", "server_port")
 	if ok {
-		cfg.ServerPort, _ = strconv.ParseInt(tmpStr, 10, 64)
+		v, err = strconv.ParseInt(tmpStr, 10, 64)
+		if err != nil {
+			err = fmt.Errorf("Parse conf error: invalid server_port")
+			return
+		}
+		cfg.ServerPort = int(v)
 	}
 
 	tmpStr, ok = conf.Get("common", "http_proxy")
@@ -139,7 +144,10 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
 	tmpStr, ok = conf.Get("common", "admin_port")
 	if ok {
 		if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil {
-			cfg.AdminPort = v
+			cfg.AdminPort = int(v)
+		} else {
+			err = fmt.Errorf("Parse conf error: invalid admin_port")
+			return
 		}
 	}
 
@@ -203,7 +211,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
 	if ok {
 		v, err = strconv.ParseInt(tmpStr, 10, 64)
 		if err != nil {
-			err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect")
+			err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout")
 			return
 		} else {
 			cfg.HeartBeatTimeout = v
@@ -214,7 +222,7 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
 	if ok {
 		v, err = strconv.ParseInt(tmpStr, 10, 64)
 		if err != nil {
-			err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect")
+			err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
 			return
 		} else {
 			cfg.HeartBeatInterval = v
@@ -222,12 +230,12 @@ func LoadClientCommonConf(conf ini.File) (cfg *ClientCommonConf, err error) {
 	}
 
 	if cfg.HeartBeatInterval <= 0 {
-		err = fmt.Errorf("Parse conf error: heartbeat_interval is incorrect")
+		err = fmt.Errorf("Parse conf error: invalid heartbeat_interval")
 		return
 	}
 
 	if cfg.HeartBeatTimeout < cfg.HeartBeatInterval {
-		err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect, heartbeat_timeout is less than heartbeat_interval")
+		err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval")
 		return
 	}
 	return

+ 5 - 8
models/config/proxy.go

@@ -23,7 +23,6 @@ import (
 	"github.com/fatedier/frp/models/consts"
 	"github.com/fatedier/frp/models/msg"
 
-	"github.com/fatedier/frp/utils/util"
 	ini "github.com/vaughan0/go-ini"
 )
 
@@ -163,7 +162,7 @@ func (cfg *BaseProxyConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
 // Bind info
 type BindInfoConf struct {
 	BindAddr   string `json:"bind_addr"`
-	RemotePort int64  `json:"remote_port"`
+	RemotePort int    `json:"remote_port"`
 }
 
 func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool {
@@ -183,10 +182,13 @@ func (cfg *BindInfoConf) LoadFromFile(name string, section ini.Section) (err err
 	var (
 		tmpStr string
 		ok     bool
+		v      int64
 	)
 	if tmpStr, ok = section["remote_port"]; ok {
-		if cfg.RemotePort, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
+		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
 			return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name)
+		} else {
+			cfg.RemotePort = int(v)
 		}
 	} else {
 		return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name)
@@ -199,11 +201,6 @@ func (cfg *BindInfoConf) UnMarshalToMsg(pMsg *msg.NewProxy) {
 }
 
 func (cfg *BindInfoConf) check() (err error) {
-	if len(ServerCommonCfg.PrivilegeAllowPorts) != 0 {
-		if ok := util.ContainsPort(ServerCommonCfg.PrivilegeAllowPorts, cfg.RemotePort); !ok {
-			return fmt.Errorf("remote port [%d] isn't allowed", cfg.RemotePort)
-		}
-	}
 	return nil
 }
 

+ 98 - 57
models/config/server_common.go

@@ -19,7 +19,6 @@ import (
 	"strconv"
 	"strings"
 
-	"github.com/fatedier/frp/utils/util"
 	ini "github.com/vaughan0/go-ini"
 )
 
@@ -29,20 +28,20 @@ var ServerCommonCfg *ServerCommonConf
 type ServerCommonConf struct {
 	ConfigFile    string
 	BindAddr      string
-	BindPort      int64
-	BindUdpPort   int64
-	KcpBindPort   int64
+	BindPort      int
+	BindUdpPort   int
+	KcpBindPort   int
 	ProxyBindAddr string
 
 	// If VhostHttpPort equals 0, don't listen a public port for http protocol.
-	VhostHttpPort int64
+	VhostHttpPort int
 
 	// if VhostHttpsPort equals 0, don't listen a public port for https protocol
-	VhostHttpsPort int64
+	VhostHttpsPort int
 	DashboardAddr  string
 
 	// if DashboardPort equals 0, dashboard is not available
-	DashboardPort  int64
+	DashboardPort  int
 	DashboardUser  string
 	DashboardPwd   string
 	AssetsDir      string
@@ -56,8 +55,7 @@ type ServerCommonConf struct {
 	SubDomainHost  string
 	TcpMux         bool
 
-	// if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected
-	PrivilegeAllowPorts [][2]int64
+	PrivilegeAllowPorts map[int]struct{}
 	MaxPoolCount        int64
 	HeartBeatTimeout    int64
 	UserConnTimeout     int64
@@ -65,31 +63,32 @@ type ServerCommonConf struct {
 
 func GetDefaultServerCommonConf() *ServerCommonConf {
 	return &ServerCommonConf{
-		ConfigFile:       "./frps.ini",
-		BindAddr:         "0.0.0.0",
-		BindPort:         7000,
-		BindUdpPort:      0,
-		KcpBindPort:      0,
-		ProxyBindAddr:    "0.0.0.0",
-		VhostHttpPort:    0,
-		VhostHttpsPort:   0,
-		DashboardAddr:    "0.0.0.0",
-		DashboardPort:    0,
-		DashboardUser:    "admin",
-		DashboardPwd:     "admin",
-		AssetsDir:        "",
-		LogFile:          "console",
-		LogWay:           "console",
-		LogLevel:         "info",
-		LogMaxDays:       3,
-		PrivilegeMode:    true,
-		PrivilegeToken:   "",
-		AuthTimeout:      900,
-		SubDomainHost:    "",
-		TcpMux:           true,
-		MaxPoolCount:     5,
-		HeartBeatTimeout: 90,
-		UserConnTimeout:  10,
+		ConfigFile:          "./frps.ini",
+		BindAddr:            "0.0.0.0",
+		BindPort:            7000,
+		BindUdpPort:         0,
+		KcpBindPort:         0,
+		ProxyBindAddr:       "0.0.0.0",
+		VhostHttpPort:       0,
+		VhostHttpsPort:      0,
+		DashboardAddr:       "0.0.0.0",
+		DashboardPort:       0,
+		DashboardUser:       "admin",
+		DashboardPwd:        "admin",
+		AssetsDir:           "",
+		LogFile:             "console",
+		LogWay:              "console",
+		LogLevel:            "info",
+		LogMaxDays:          3,
+		PrivilegeMode:       true,
+		PrivilegeToken:      "",
+		AuthTimeout:         900,
+		SubDomainHost:       "",
+		TcpMux:              true,
+		PrivilegeAllowPorts: make(map[int]struct{}),
+		MaxPoolCount:        5,
+		HeartBeatTimeout:    90,
+		UserConnTimeout:     10,
 	}
 }
 
@@ -109,25 +108,31 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
 
 	tmpStr, ok = conf.Get("common", "bind_port")
 	if ok {
-		v, err = strconv.ParseInt(tmpStr, 10, 64)
-		if err == nil {
-			cfg.BindPort = v
+		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
+			err = fmt.Errorf("Parse conf error: invalid bind_port")
+			return
+		} else {
+			cfg.BindPort = int(v)
 		}
 	}
 
 	tmpStr, ok = conf.Get("common", "bind_udp_port")
 	if ok {
-		v, err = strconv.ParseInt(tmpStr, 10, 64)
-		if err == nil {
-			cfg.BindUdpPort = v
+		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
+			err = fmt.Errorf("Parse conf error: invalid bind_udp_port")
+			return
+		} else {
+			cfg.BindUdpPort = int(v)
 		}
 	}
 
 	tmpStr, ok = conf.Get("common", "kcp_bind_port")
 	if ok {
-		v, err = strconv.ParseInt(tmpStr, 10, 64)
-		if err == nil && v > 0 {
-			cfg.KcpBindPort = v
+		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
+			err = fmt.Errorf("Parse conf error: invalid kcp_bind_port")
+			return
+		} else {
+			cfg.KcpBindPort = int(v)
 		}
 	}
 
@@ -140,10 +145,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
 
 	tmpStr, ok = conf.Get("common", "vhost_http_port")
 	if ok {
-		cfg.VhostHttpPort, err = strconv.ParseInt(tmpStr, 10, 64)
-		if err != nil {
-			err = fmt.Errorf("Parse conf error: vhost_http_port is incorrect")
+		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
+			err = fmt.Errorf("Parse conf error: invalid vhost_http_port")
 			return
+		} else {
+			cfg.VhostHttpPort = int(v)
 		}
 	} else {
 		cfg.VhostHttpPort = 0
@@ -151,10 +157,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
 
 	tmpStr, ok = conf.Get("common", "vhost_https_port")
 	if ok {
-		cfg.VhostHttpsPort, err = strconv.ParseInt(tmpStr, 10, 64)
-		if err != nil {
-			err = fmt.Errorf("Parse conf error: vhost_https_port is incorrect")
+		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
+			err = fmt.Errorf("Parse conf error: invalid vhost_https_port")
 			return
+		} else {
+			cfg.VhostHttpsPort = int(v)
 		}
 	} else {
 		cfg.VhostHttpsPort = 0
@@ -169,10 +176,11 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
 
 	tmpStr, ok = conf.Get("common", "dashboard_port")
 	if ok {
-		cfg.DashboardPort, err = strconv.ParseInt(tmpStr, 10, 64)
-		if err != nil {
-			err = fmt.Errorf("Parse conf error: dashboard_port is incorrect")
+		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
+			err = fmt.Errorf("Parse conf error: invalid dashboard_port")
 			return
+		} else {
+			cfg.DashboardPort = int(v)
 		}
 	} else {
 		cfg.DashboardPort = 0
@@ -228,12 +236,45 @@ func LoadServerCommonConf(conf ini.File) (cfg *ServerCommonConf, err error) {
 		cfg.PrivilegeToken, _ = conf.Get("common", "privilege_token")
 
 		allowPortsStr, ok := conf.Get("common", "privilege_allow_ports")
-		// TODO: check if conflicts exist in port ranges
 		if ok {
-			cfg.PrivilegeAllowPorts, err = util.GetPortRanges(allowPortsStr)
-			if err != nil {
-				err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err)
-				return
+			// e.g. 1000-2000,2001,2002,3000-4000
+			portRanges := strings.Split(allowPortsStr, ",")
+			for _, portRangeStr := range portRanges {
+				// 1000-2000 or 2001
+				portArray := strings.Split(portRangeStr, "-")
+				// length: only 1 or 2 is correct
+				rangeType := len(portArray)
+				if rangeType == 1 {
+					// single port
+					singlePort, errRet := strconv.ParseInt(portArray[0], 10, 64)
+					if errRet != nil {
+						err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
+						return
+					}
+					cfg.PrivilegeAllowPorts[int(singlePort)] = struct{}{}
+				} else if rangeType == 2 {
+					// range ports
+					min, errRet := strconv.ParseInt(portArray[0], 10, 64)
+					if errRet != nil {
+						err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
+						return
+					}
+					max, errRet := strconv.ParseInt(portArray[1], 10, 64)
+					if errRet != nil {
+						err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", errRet)
+						return
+					}
+					if max < min {
+						err = fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect")
+						return
+					}
+					for i := min; i <= max; i++ {
+						cfg.PrivilegeAllowPorts[int(i)] = struct{}{}
+					}
+				} else {
+					err = fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect")
+					return
+				}
 			}
 		}
 	}

+ 2 - 2
models/msg/msg.go

@@ -92,7 +92,7 @@ type Login struct {
 type LoginResp struct {
 	Version       string `json:"version"`
 	RunId         string `json:"run_id"`
-	ServerUdpPort int64  `json:"server_udp_port"`
+	ServerUdpPort int    `json:"server_udp_port"`
 	Error         string `json:"error"`
 }
 
@@ -104,7 +104,7 @@ type NewProxy struct {
 	UseCompression bool   `json:"use_compression"`
 
 	// tcp and udp only
-	RemotePort int64 `json:"remote_port"`
+	RemotePort int `json:"remote_port"`
 
 	// http and https only
 	CustomDomains     []string `json:"custom_domains"`

+ 1 - 1
server/dashboard.go

@@ -32,7 +32,7 @@ var (
 	httpServerWriteTimeout = 10 * time.Second
 )
 
-func RunDashboardServer(addr string, port int64) (err error) {
+func RunDashboardServer(addr string, port int) (err error) {
 	// url router
 	router := httprouter.New()
 

+ 2 - 2
server/dashboard_api.go

@@ -36,8 +36,8 @@ type ServerInfoResp struct {
 	GeneralResponse
 
 	Version          string `json:"version"`
-	VhostHttpPort    int64  `json:"vhost_http_port"`
-	VhostHttpsPort   int64  `json:"vhost_https_port"`
+	VhostHttpPort    int    `json:"vhost_http_port"`
+	VhostHttpsPort   int    `json:"vhost_https_port"`
 	AuthTimeout      int64  `json:"auth_timeout"`
 	SubdomainHost    string `json:"subdomain_host"`
 	MaxPoolCount     int64  `json:"max_pool_count"`

+ 180 - 0
server/ports.go

@@ -0,0 +1,180 @@
+package server
+
+import (
+	"errors"
+	"fmt"
+	"net"
+	"sync"
+	"time"
+)
+
+const (
+	MinPort                    = 1025
+	MaxPort                    = 65535
+	MaxPortReservedDuration    = time.Duration(24) * time.Hour
+	CleanReservedPortsInterval = time.Hour
+)
+
+var (
+	ErrPortAlreadyUsed = errors.New("port already used")
+	ErrPortNotAllowed  = errors.New("port not allowed")
+	ErrPortUnAvailable = errors.New("port unavailable")
+	ErrNoAvailablePort = errors.New("no available port")
+)
+
+type PortCtx struct {
+	ProxyName  string
+	Port       int
+	Closed     bool
+	UpdateTime time.Time
+}
+
+type PortManager struct {
+	reservedPorts map[string]*PortCtx
+	usedPorts     map[int]*PortCtx
+	freePorts     map[int]struct{}
+
+	bindAddr string
+	netType  string
+	mu       sync.Mutex
+}
+
+func NewPortManager(netType string, bindAddr string, allowPorts map[int]struct{}) *PortManager {
+	pm := &PortManager{
+		reservedPorts: make(map[string]*PortCtx),
+		usedPorts:     make(map[int]*PortCtx),
+		freePorts:     make(map[int]struct{}),
+		bindAddr:      bindAddr,
+		netType:       netType,
+	}
+	if len(allowPorts) > 0 {
+		for port, _ := range allowPorts {
+			pm.freePorts[port] = struct{}{}
+		}
+	} else {
+		for i := MinPort; i <= MaxPort; i++ {
+			pm.freePorts[i] = struct{}{}
+		}
+	}
+	go pm.cleanReservedPortsWorker()
+	return pm
+}
+
+func (pm *PortManager) Acquire(name string, port int) (realPort int, err error) {
+	portCtx := &PortCtx{
+		ProxyName:  name,
+		Closed:     false,
+		UpdateTime: time.Now(),
+	}
+
+	var ok bool
+
+	pm.mu.Lock()
+	defer func() {
+		if err == nil {
+			portCtx.Port = realPort
+		}
+		pm.mu.Unlock()
+	}()
+
+	// check reserved ports first
+	if port == 0 {
+		if ctx, ok := pm.reservedPorts[name]; ok {
+			if pm.isPortAvailable(ctx.Port) {
+				realPort = ctx.Port
+				pm.usedPorts[realPort] = portCtx
+				pm.reservedPorts[name] = portCtx
+				delete(pm.freePorts, realPort)
+				return
+			}
+		}
+	}
+
+	if port == 0 {
+		// get random port
+		count := 0
+		maxTryTimes := 5
+		for k, _ := range pm.freePorts {
+			count++
+			if count > maxTryTimes {
+				break
+			}
+			if pm.isPortAvailable(k) {
+				realPort = k
+				pm.usedPorts[realPort] = portCtx
+				pm.reservedPorts[name] = portCtx
+				delete(pm.freePorts, realPort)
+				break
+			}
+		}
+		if realPort == 0 {
+			err = ErrNoAvailablePort
+		}
+	} else {
+		// specified port
+		if _, ok = pm.freePorts[port]; ok {
+			if pm.isPortAvailable(port) {
+				realPort = port
+				pm.usedPorts[realPort] = portCtx
+				pm.reservedPorts[name] = portCtx
+				delete(pm.freePorts, realPort)
+			} else {
+				err = ErrPortUnAvailable
+			}
+		} else {
+			if _, ok = pm.usedPorts[port]; ok {
+				err = ErrPortAlreadyUsed
+			} else {
+				err = ErrPortNotAllowed
+			}
+		}
+	}
+	return
+}
+
+func (pm *PortManager) isPortAvailable(port int) bool {
+	if pm.netType == "udp" {
+		addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port))
+		if err != nil {
+			return false
+		}
+		l, err := net.ListenUDP("udp", addr)
+		if err != nil {
+			return false
+		}
+		l.Close()
+		return true
+	} else {
+		l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port))
+		if err != nil {
+			return false
+		}
+		l.Close()
+		return true
+	}
+}
+
+func (pm *PortManager) Release(port int) {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+	if ctx, ok := pm.usedPorts[port]; ok {
+		pm.freePorts[port] = struct{}{}
+		delete(pm.usedPorts, port)
+		ctx.Closed = true
+		ctx.UpdateTime = time.Now()
+	}
+}
+
+// Release reserved port if it isn't used in last 24 hours.
+func (pm *PortManager) cleanReservedPortsWorker() {
+	for {
+		time.Sleep(CleanReservedPortsInterval)
+		pm.mu.Lock()
+		for name, ctx := range pm.reservedPorts {
+			if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
+				delete(pm.reservedPorts, name)
+			}
+		}
+		pm.mu.Unlock()
+	}
+}

+ 32 - 4
server/proxy.go

@@ -165,11 +165,24 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) {
 type TcpProxy struct {
 	BaseProxy
 	cfg *config.TcpProxyConf
+
+	realPort int
 }
 
 func (pxy *TcpProxy) Run() (remoteAddr string, err error) {
-	remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort)
-	listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort)
+	pxy.realPort, err = pxy.ctl.svr.tcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort)
+	if err != nil {
+		return
+	}
+	defer func() {
+		if err != nil {
+			pxy.ctl.svr.tcpPortManager.Release(pxy.realPort)
+		}
+	}()
+
+	remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
+	pxy.cfg.RemotePort = pxy.realPort
+	listener, errRet := frpNet.ListenTcp(config.ServerCommonCfg.ProxyBindAddr, pxy.realPort)
 	if errRet != nil {
 		err = errRet
 		return
@@ -188,6 +201,7 @@ func (pxy *TcpProxy) GetConf() config.ProxyConf {
 
 func (pxy *TcpProxy) Close() {
 	pxy.BaseProxy.Close()
+	pxy.ctl.svr.tcpPortManager.Release(pxy.realPort)
 }
 
 type HttpProxy struct {
@@ -412,6 +426,8 @@ type UdpProxy struct {
 	BaseProxy
 	cfg *config.UdpProxyConf
 
+	realPort int
+
 	// udpConn is the listener of udp packages
 	udpConn *net.UDPConn
 
@@ -432,8 +448,19 @@ type UdpProxy struct {
 }
 
 func (pxy *UdpProxy) Run() (remoteAddr string, err error) {
-	remoteAddr = fmt.Sprintf(":%d", pxy.cfg.RemotePort)
-	addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.cfg.RemotePort))
+	pxy.realPort, err = pxy.ctl.svr.udpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort)
+	if err != nil {
+		return
+	}
+	defer func() {
+		if err != nil {
+			pxy.ctl.svr.udpPortManager.Release(pxy.realPort)
+		}
+	}()
+
+	remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
+	pxy.cfg.RemotePort = pxy.realPort
+	addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.ProxyBindAddr, pxy.realPort))
 	if errRet != nil {
 		err = errRet
 		return
@@ -581,6 +608,7 @@ func (pxy *UdpProxy) Close() {
 		close(pxy.readCh)
 		close(pxy.sendCh)
 	}
+	pxy.ctl.svr.udpPortManager.Release(pxy.realPort)
 }
 
 // HandleUserTcpConnection is used for incoming tcp user connections.

+ 9 - 1
server/service.go

@@ -60,17 +60,25 @@ type Service struct {
 	// Manage all visitor listeners.
 	visitorManager *VisitorManager
 
+	// Manage all tcp ports.
+	tcpPortManager *PortManager
+
+	// Manage all udp ports.
+	udpPortManager *PortManager
+
 	// Controller for nat hole connections.
 	natHoleController *NatHoleController
 }
 
 func NewService() (svr *Service, err error) {
+	cfg := config.ServerCommonCfg
 	svr = &Service{
 		ctlManager:     NewControlManager(),
 		pxyManager:     NewProxyManager(),
 		visitorManager: NewVisitorManager(),
+		tcpPortManager: NewPortManager("tcp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts),
+		udpPortManager: NewPortManager("udp", cfg.ProxyBindAddr, cfg.PrivilegeAllowPorts),
 	}
-	cfg := config.ServerCommonCfg
 
 	// Init assets.
 	err = assets.Load(cfg.AssetsDir)

+ 11 - 11
tests/func_test.go

@@ -10,28 +10,28 @@ import (
 
 var (
 	TEST_STR                    = "frp is a fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet."
-	TEST_TCP_PORT        int64  = 10701
-	TEST_TCP_FRP_PORT    int64  = 10801
-	TEST_TCP_EC_FRP_PORT int64  = 10901
+	TEST_TCP_PORT        int    = 10701
+	TEST_TCP_FRP_PORT    int    = 10801
+	TEST_TCP_EC_FRP_PORT int    = 10901
 	TEST_TCP_ECHO_STR    string = "tcp type:" + TEST_STR
 
-	TEST_UDP_PORT        int64  = 10702
-	TEST_UDP_FRP_PORT    int64  = 10802
-	TEST_UDP_EC_FRP_PORT int64  = 10902
+	TEST_UDP_PORT        int    = 10702
+	TEST_UDP_FRP_PORT    int    = 10802
+	TEST_UDP_EC_FRP_PORT int    = 10902
 	TEST_UDP_ECHO_STR    string = "udp type:" + TEST_STR
 
 	TEST_UNIX_DOMAIN_ADDR     string = "/tmp/frp_echo_server.sock"
-	TEST_UNIX_DOMAIN_FRP_PORT int64  = 10803
+	TEST_UNIX_DOMAIN_FRP_PORT int    = 10803
 	TEST_UNIX_DOMAIN_STR      string = "unix domain type:" + TEST_STR
 
-	TEST_HTTP_PORT       int64  = 10704
-	TEST_HTTP_FRP_PORT   int64  = 10804
+	TEST_HTTP_PORT       int    = 10704
+	TEST_HTTP_FRP_PORT   int    = 10804
 	TEST_HTTP_NORMAL_STR string = "http normal string: " + TEST_STR
 	TEST_HTTP_FOO_STR    string = "http foo string: " + TEST_STR
 	TEST_HTTP_BAR_STR    string = "http bar string: " + TEST_STR
 
-	TEST_STCP_FRP_PORT    int64  = 10805
-	TEST_STCP_EC_FRP_PORT int64  = 10905
+	TEST_STCP_FRP_PORT    int    = 10805
+	TEST_STCP_EC_FRP_PORT int    = 10905
 	TEST_STCP_ECHO_STR    string = "stcp type:" + TEST_STR
 )
 

+ 1 - 1
utils/net/kcp.go

@@ -31,7 +31,7 @@ type KcpListener struct {
 	log.Logger
 }
 
-func ListenKcp(bindAddr string, bindPort int64) (l *KcpListener, err error) {
+func ListenKcp(bindAddr string, bindPort int) (l *KcpListener, err error) {
 	listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3)
 	if err != nil {
 		return l, err

+ 1 - 1
utils/net/tcp.go

@@ -33,7 +33,7 @@ type TcpListener struct {
 	log.Logger
 }
 
-func ListenTcp(bindAddr string, bindPort int64) (l *TcpListener, err error) {
+func ListenTcp(bindAddr string, bindPort int) (l *TcpListener, err error) {
 	tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
 	if err != nil {
 		return l, err

+ 1 - 1
utils/net/udp.go

@@ -167,7 +167,7 @@ type UdpListener struct {
 	log.Logger
 }
 
-func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) {
+func ListenUDP(bindAddr string, bindPort int) (l *UdpListener, err error) {
 	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
 	if err != nil {
 		return l, err

+ 0 - 65
utils/util/util.go

@@ -19,8 +19,6 @@ import (
 	"crypto/rand"
 	"encoding/hex"
 	"fmt"
-	"strconv"
-	"strings"
 )
 
 // RandId return a rand string used in frp.
@@ -48,69 +46,6 @@ func GetAuthKey(token string, timestamp int64) (key string) {
 	return hex.EncodeToString(data)
 }
 
-// for example: rangeStr is "1000-2000,2001,2002,3000-4000", return an array as port ranges.
-func GetPortRanges(rangeStr string) (portRanges [][2]int64, err error) {
-	// for example: 1000-2000,2001,2002,3000-4000
-	rangeArray := strings.Split(rangeStr, ",")
-	for _, portRangeStr := range rangeArray {
-		// 1000-2000 or 2001
-		portArray := strings.Split(portRangeStr, "-")
-		// length: only 1 or 2 is correct
-		rangeType := len(portArray)
-		if rangeType == 1 {
-			singlePort, err := strconv.ParseInt(portArray[0], 10, 64)
-			if err != nil {
-				return [][2]int64{}, err
-			}
-			portRanges = append(portRanges, [2]int64{singlePort, singlePort})
-		} else if rangeType == 2 {
-			min, err := strconv.ParseInt(portArray[0], 10, 64)
-			if err != nil {
-				return [][2]int64{}, err
-			}
-			max, err := strconv.ParseInt(portArray[1], 10, 64)
-			if err != nil {
-				return [][2]int64{}, err
-			}
-			if max < min {
-				return [][2]int64{}, fmt.Errorf("range incorrect")
-			}
-			portRanges = append(portRanges, [2]int64{min, max})
-		} else {
-			return [][2]int64{}, fmt.Errorf("format error")
-		}
-	}
-	return portRanges, nil
-}
-
-func ContainsPort(portRanges [][2]int64, port int64) bool {
-	for _, pr := range portRanges {
-		if port >= pr[0] && port <= pr[1] {
-			return true
-		}
-	}
-	return false
-}
-
-func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 {
-	var tmpRanges [][2]int64
-	for _, pr := range portRanges {
-		if port >= pr[0] && port <= pr[1] {
-			leftRange := [2]int64{pr[0], port - 1}
-			rightRange := [2]int64{port + 1, pr[1]}
-			if leftRange[0] <= leftRange[1] {
-				tmpRanges = append(tmpRanges, leftRange)
-			}
-			if rightRange[0] <= rightRange[1] {
-				tmpRanges = append(tmpRanges, rightRange)
-			}
-		} else {
-			tmpRanges = append(tmpRanges, pr)
-		}
-	}
-	return tmpRanges
-}
-
 func CanonicalAddr(host string, port int) (addr string) {
 	if port == 80 || port == 443 {
 		addr = host

+ 0 - 64
utils/util/util_test.go

@@ -20,67 +20,3 @@ func TestGetAuthKey(t *testing.T) {
 	t.Log(key)
 	assert.Equal("6df41a43725f0c770fd56379e12acf8c", key)
 }
-
-func TestGetPortRanges(t *testing.T) {
-	assert := assert.New(t)
-
-	rangesStr := "2000-3000,3001,4000-50000"
-	expect := [][2]int64{
-		[2]int64{2000, 3000},
-		[2]int64{3001, 3001},
-		[2]int64{4000, 50000},
-	}
-	actual, err := GetPortRanges(rangesStr)
-	assert.Nil(err)
-	t.Log(actual)
-	assert.Equal(expect, actual)
-}
-
-func TestContainsPort(t *testing.T) {
-	assert := assert.New(t)
-
-	rangesStr := "2000-3000,3001,4000-50000"
-	portRanges, err := GetPortRanges(rangesStr)
-	assert.Nil(err)
-
-	type Case struct {
-		Port   int64
-		Answer bool
-	}
-	cases := []Case{
-		Case{
-			Port:   3001,
-			Answer: true,
-		},
-		Case{
-			Port:   3002,
-			Answer: false,
-		},
-		Case{
-			Port:   44444,
-			Answer: true,
-		},
-	}
-	for _, elem := range cases {
-		ok := ContainsPort(portRanges, elem.Port)
-		assert.Equal(elem.Answer, ok)
-	}
-}
-
-func TestPortRangesCut(t *testing.T) {
-	assert := assert.New(t)
-
-	rangesStr := "2000-3000,3001,4000-50000"
-	portRanges, err := GetPortRanges(rangesStr)
-	assert.Nil(err)
-
-	expect := [][2]int64{
-		[2]int64{2000, 3000},
-		[2]int64{3001, 3001},
-		[2]int64{4000, 44443},
-		[2]int64{44445, 50000},
-	}
-	actual := PortRangesCut(portRanges, 44444)
-	t.Log(actual)
-	assert.Equal(expect, actual)
-}