1
0
Эх сурвалжийг харах

support http load balancing

fatedier 5 жил өмнө
parent
commit
b3ed863021

+ 3 - 0
server/controller/resource.go

@@ -29,6 +29,9 @@ type ResourceController struct {
 	// Tcp Group Controller
 	TcpGroupCtl *group.TcpGroupCtl
 
+	// HTTP Group Controller
+	HTTPGroupCtl *group.HTTPGroupController
+
 	// Manage all tcp ports
 	TcpPortManager *ports.PortManager
 

+ 1 - 0
server/group/group.go

@@ -23,4 +23,5 @@ var (
 	ErrGroupParamsInvalid = errors.New("group params invalid")
 	ErrListenerClosed     = errors.New("group listener closed")
 	ErrGroupDifferentPort = errors.New("group should have same remote port")
+	ErrProxyRepeated      = errors.New("group proxy repeated")
 )

+ 157 - 0
server/group/http.go

@@ -0,0 +1,157 @@
+package group
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+
+	frpNet "github.com/fatedier/frp/utils/net"
+
+	"github.com/fatedier/frp/utils/vhost"
+)
+
+type HTTPGroupController struct {
+	groups map[string]*HTTPGroup
+
+	vhostRouter *vhost.VhostRouters
+
+	mu sync.Mutex
+}
+
+func NewHTTPGroupController(vhostRouter *vhost.VhostRouters) *HTTPGroupController {
+	return &HTTPGroupController{
+		groups:      make(map[string]*HTTPGroup),
+		vhostRouter: vhostRouter,
+	}
+}
+
+func (ctl *HTTPGroupController) Register(proxyName, group, groupKey string,
+	routeConfig vhost.VhostRouteConfig) (err error) {
+
+	indexKey := httpGroupIndex(group, routeConfig.Domain, routeConfig.Location)
+	ctl.mu.Lock()
+	g, ok := ctl.groups[indexKey]
+	if !ok {
+		g = NewHTTPGroup(ctl)
+		ctl.groups[indexKey] = g
+	}
+	ctl.mu.Unlock()
+
+	return g.Register(proxyName, group, groupKey, routeConfig)
+}
+
+func (ctl *HTTPGroupController) UnRegister(proxyName, group, domain, location string) {
+	indexKey := httpGroupIndex(group, domain, location)
+	ctl.mu.Lock()
+	defer ctl.mu.Unlock()
+	g, ok := ctl.groups[indexKey]
+	if !ok {
+		return
+	}
+
+	isEmpty := g.UnRegister(proxyName)
+	if isEmpty {
+		delete(ctl.groups, indexKey)
+	}
+}
+
+type HTTPGroup struct {
+	group    string
+	groupKey string
+	domain   string
+	location string
+
+	createFuncs map[string]vhost.CreateConnFunc
+	pxyNames    []string
+	index       uint64
+	ctl         *HTTPGroupController
+	mu          sync.RWMutex
+}
+
+func NewHTTPGroup(ctl *HTTPGroupController) *HTTPGroup {
+	return &HTTPGroup{
+		createFuncs: make(map[string]vhost.CreateConnFunc),
+		pxyNames:    make([]string, 0),
+		ctl:         ctl,
+	}
+}
+
+func (g *HTTPGroup) Register(proxyName, group, groupKey string,
+	routeConfig vhost.VhostRouteConfig) (err error) {
+
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	if len(g.createFuncs) == 0 {
+		// the first proxy in this group
+		tmp := routeConfig // copy object
+		tmp.CreateConnFn = g.createConn
+		err = g.ctl.vhostRouter.Add(routeConfig.Domain, routeConfig.Location, &tmp)
+		if err != nil {
+			return
+		}
+
+		g.group = group
+		g.groupKey = groupKey
+		g.domain = routeConfig.Domain
+		g.location = routeConfig.Location
+	} else {
+		if g.group != group || g.domain != routeConfig.Domain || g.location != routeConfig.Location {
+			err = ErrGroupParamsInvalid
+			return
+		}
+		if g.groupKey != groupKey {
+			err = ErrGroupAuthFailed
+			return
+		}
+	}
+	if _, ok := g.createFuncs[proxyName]; ok {
+		err = ErrProxyRepeated
+		return
+	}
+	g.createFuncs[proxyName] = routeConfig.CreateConnFn
+	g.pxyNames = append(g.pxyNames, proxyName)
+	return nil
+}
+
+func (g *HTTPGroup) UnRegister(proxyName string) (isEmpty bool) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	delete(g.createFuncs, proxyName)
+	for i, name := range g.pxyNames {
+		if name == proxyName {
+			g.pxyNames = append(g.pxyNames[:i], g.pxyNames[i+1:]...)
+			break
+		}
+	}
+
+	if len(g.createFuncs) == 0 {
+		isEmpty = true
+		g.ctl.vhostRouter.Del(g.domain, g.location)
+	}
+	return
+}
+
+func (g *HTTPGroup) createConn(remoteAddr string) (frpNet.Conn, error) {
+	var f vhost.CreateConnFunc
+	newIndex := atomic.AddUint64(&g.index, 1)
+
+	g.mu.RLock()
+	group := g.group
+	domain := g.domain
+	location := g.location
+	if len(g.pxyNames) > 0 {
+		name := g.pxyNames[int(newIndex)%len(g.pxyNames)]
+		f, _ = g.createFuncs[name]
+	}
+	g.mu.RUnlock()
+
+	if f == nil {
+		return nil, fmt.Errorf("no CreateConnFunc for http group [%s], domain [%s], location [%s]", group, domain, location)
+	}
+
+	return f(remoteAddr)
+}
+
+func httpGroupIndex(group, domain, location string) string {
+	return fmt.Sprintf("%s_%s_%s", group, domain, location)
+}

+ 71 - 51
server/group/tcp.go

@@ -24,46 +24,47 @@ import (
 	gerr "github.com/fatedier/golib/errors"
 )
 
-type TcpGroupListener struct {
-	groupName string
-	group     *TcpGroup
+// TcpGroupCtl manage all TcpGroups
+type TcpGroupCtl struct {
+	groups map[string]*TcpGroup
 
-	addr    net.Addr
-	closeCh chan struct{}
+	// portManager is used to manage port
+	portManager *ports.PortManager
+	mu          sync.Mutex
 }
 
-func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener {
-	return &TcpGroupListener{
-		groupName: name,
-		group:     group,
-		addr:      addr,
-		closeCh:   make(chan struct{}),
+// NewTcpGroupCtl return a new TcpGroupCtl
+func NewTcpGroupCtl(portManager *ports.PortManager) *TcpGroupCtl {
+	return &TcpGroupCtl{
+		groups:      make(map[string]*TcpGroup),
+		portManager: portManager,
 	}
 }
 
-func (ln *TcpGroupListener) Accept() (c net.Conn, err error) {
-	var ok bool
-	select {
-	case <-ln.closeCh:
-		return nil, ErrListenerClosed
-	case c, ok = <-ln.group.Accept():
-		if !ok {
-			return nil, ErrListenerClosed
-		}
-		return c, nil
+// Listen is the wrapper for TcpGroup's Listen
+// If there are no group, we will create one here
+func (tgc *TcpGroupCtl) Listen(proxyName string, group string, groupKey string,
+	addr string, port int) (l net.Listener, realPort int, err error) {
+
+	tgc.mu.Lock()
+	tcpGroup, ok := tgc.groups[group]
+	if !ok {
+		tcpGroup = NewTcpGroup(tgc)
+		tgc.groups[group] = tcpGroup
 	}
-}
+	tgc.mu.Unlock()
 
-func (ln *TcpGroupListener) Addr() net.Addr {
-	return ln.addr
+	return tcpGroup.Listen(proxyName, group, groupKey, addr, port)
 }
 
-func (ln *TcpGroupListener) Close() (err error) {
-	close(ln.closeCh)
-	ln.group.CloseListener(ln)
-	return
+// RemoveGroup remove TcpGroup from controller
+func (tgc *TcpGroupCtl) RemoveGroup(group string) {
+	tgc.mu.Lock()
+	defer tgc.mu.Unlock()
+	delete(tgc.groups, group)
 }
 
+// TcpGroup route connections to different proxies
 type TcpGroup struct {
 	group    string
 	groupKey string
@@ -79,6 +80,7 @@ type TcpGroup struct {
 	mu       sync.Mutex
 }
 
+// NewTcpGroup return a new TcpGroup
 func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup {
 	return &TcpGroup{
 		lns:      make([]*TcpGroupListener, 0),
@@ -87,10 +89,14 @@ func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup {
 	}
 }
 
+// Listen will return a new TcpGroupListener
+// if TcpGroup already has a listener, just add a new TcpGroupListener to the queues
+// otherwise, listen on the real address
 func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TcpGroupListener, realPort int, err error) {
 	tg.mu.Lock()
 	defer tg.mu.Unlock()
 	if len(tg.lns) == 0 {
+		// the first listener, listen on the real address
 		realPort, err = tg.ctl.portManager.Acquire(proxyName, port)
 		if err != nil {
 			return
@@ -114,6 +120,7 @@ func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr
 		}
 		go tg.worker()
 	} else {
+		// address and port in the same group must be equal
 		if tg.group != group || tg.addr != addr {
 			err = ErrGroupParamsInvalid
 			return
@@ -133,6 +140,7 @@ func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr
 	return
 }
 
+// worker is called when the real tcp listener has been created
 func (tg *TcpGroup) worker() {
 	for {
 		c, err := tg.tcpLn.Accept()
@@ -152,6 +160,7 @@ func (tg *TcpGroup) Accept() <-chan net.Conn {
 	return tg.acceptCh
 }
 
+// CloseListener remove the TcpGroupListener from the TcpGroup
 func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) {
 	tg.mu.Lock()
 	defer tg.mu.Unlock()
@@ -169,36 +178,47 @@ func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) {
 	}
 }
 
-type TcpGroupCtl struct {
-	groups map[string]*TcpGroup
+// TcpGroupListener
+type TcpGroupListener struct {
+	groupName string
+	group     *TcpGroup
 
-	portManager *ports.PortManager
-	mu          sync.Mutex
+	addr    net.Addr
+	closeCh chan struct{}
 }
 
-func NewTcpGroupCtl(portManager *ports.PortManager) *TcpGroupCtl {
-	return &TcpGroupCtl{
-		groups:      make(map[string]*TcpGroup),
-		portManager: portManager,
+func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener {
+	return &TcpGroupListener{
+		groupName: name,
+		group:     group,
+		addr:      addr,
+		closeCh:   make(chan struct{}),
 	}
 }
 
-func (tgc *TcpGroupCtl) Listen(proxyNanme string, group string, groupKey string,
-	addr string, port int) (l net.Listener, realPort int, err error) {
-
-	tgc.mu.Lock()
-	defer tgc.mu.Unlock()
-	if tcpGroup, ok := tgc.groups[group]; ok {
-		return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port)
-	} else {
-		tcpGroup = NewTcpGroup(tgc)
-		tgc.groups[group] = tcpGroup
-		return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port)
+// Accept will accept connections from TcpGroup
+func (ln *TcpGroupListener) Accept() (c net.Conn, err error) {
+	var ok bool
+	select {
+	case <-ln.closeCh:
+		return nil, ErrListenerClosed
+	case c, ok = <-ln.group.Accept():
+		if !ok {
+			return nil, ErrListenerClosed
+		}
+		return c, nil
 	}
 }
 
-func (tgc *TcpGroupCtl) RemoveGroup(group string) {
-	tgc.mu.Lock()
-	defer tgc.mu.Unlock()
-	delete(tgc.groups, group)
+func (ln *TcpGroupListener) Addr() net.Addr {
+	return ln.addr
+}
+
+// Close close the listener
+func (ln *TcpGroupListener) Close() (err error) {
+	close(ln.closeCh)
+
+	// remove self from TcpGroup
+	ln.group.CloseListener(ln)
+	return
 }

+ 51 - 17
server/proxy/http.go

@@ -50,6 +50,12 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) {
 		locations = []string{""}
 	}
 
+	defer func() {
+		if err != nil {
+			pxy.Close()
+		}
+	}()
+
 	addrs := make([]string, 0)
 	for _, domain := range pxy.cfg.CustomDomains {
 		if domain == "" {
@@ -59,17 +65,31 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) {
 		routeConfig.Domain = domain
 		for _, location := range locations {
 			routeConfig.Location = location
-			err = pxy.rc.HttpReverseProxy.Register(routeConfig)
-			if err != nil {
-				return
-			}
 			tmpDomain := routeConfig.Domain
 			tmpLocation := routeConfig.Location
-			addrs = append(addrs, util.CanonicalAddr(tmpDomain, int(g.GlbServerCfg.VhostHttpPort)))
-			pxy.closeFuncs = append(pxy.closeFuncs, func() {
-				pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation)
-			})
-			pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
+
+			// handle group
+			if pxy.cfg.Group != "" {
+				err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.Group, pxy.cfg.GroupKey, routeConfig)
+				if err != nil {
+					return
+				}
+
+				pxy.closeFuncs = append(pxy.closeFuncs, func() {
+					pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation)
+				})
+			} else {
+				// no group
+				err = pxy.rc.HttpReverseProxy.Register(routeConfig)
+				if err != nil {
+					return
+				}
+				pxy.closeFuncs = append(pxy.closeFuncs, func() {
+					pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation)
+				})
+			}
+			addrs = append(addrs, util.CanonicalAddr(routeConfig.Domain, int(g.GlbServerCfg.VhostHttpPort)))
+			pxy.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group)
 		}
 	}
 
@@ -77,17 +97,31 @@ func (pxy *HttpProxy) Run() (remoteAddr string, err error) {
 		routeConfig.Domain = pxy.cfg.SubDomain + "." + g.GlbServerCfg.SubDomainHost
 		for _, location := range locations {
 			routeConfig.Location = location
-			err = pxy.rc.HttpReverseProxy.Register(routeConfig)
-			if err != nil {
-				return
-			}
 			tmpDomain := routeConfig.Domain
 			tmpLocation := routeConfig.Location
+
+			// handle group
+			if pxy.cfg.Group != "" {
+				err = pxy.rc.HTTPGroupCtl.Register(pxy.name, pxy.cfg.Group, pxy.cfg.GroupKey, routeConfig)
+				if err != nil {
+					return
+				}
+
+				pxy.closeFuncs = append(pxy.closeFuncs, func() {
+					pxy.rc.HTTPGroupCtl.UnRegister(pxy.name, pxy.cfg.Group, tmpDomain, tmpLocation)
+				})
+			} else {
+				err = pxy.rc.HttpReverseProxy.Register(routeConfig)
+				if err != nil {
+					return
+				}
+				pxy.closeFuncs = append(pxy.closeFuncs, func() {
+					pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation)
+				})
+			}
 			addrs = append(addrs, util.CanonicalAddr(tmpDomain, g.GlbServerCfg.VhostHttpPort))
-			pxy.closeFuncs = append(pxy.closeFuncs, func() {
-				pxy.rc.HttpReverseProxy.UnRegister(tmpDomain, tmpLocation)
-			})
-			pxy.Info("http proxy listen for host [%s] location [%s]", routeConfig.Domain, routeConfig.Location)
+
+			pxy.Info("http proxy listen for host [%s] location [%s] group [%s]", routeConfig.Domain, routeConfig.Location, pxy.cfg.Group)
 		}
 	}
 	remoteAddr = strings.Join(addrs, ",")

+ 5 - 0
server/proxy/https.go

@@ -31,6 +31,11 @@ type HttpsProxy struct {
 func (pxy *HttpsProxy) Run() (remoteAddr string, err error) {
 	routeConfig := &vhost.VhostRouteConfig{}
 
+	defer func() {
+		if err != nil {
+			pxy.Close()
+		}
+	}()
 	addrs := make([]string, 0)
 	for _, domain := range pxy.cfg.CustomDomains {
 		if domain == "" {

+ 2 - 0
server/proxy/proxy.go

@@ -72,6 +72,8 @@ func (pxy *BaseProxy) Close() {
 	}
 }
 
+// GetWorkConnFromPool try to get a new work connections from pool
+// for quickly response, we immediately send the StartWorkConn message to frpc after take out one from pool
 func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn frpNet.Conn, err error) {
 	// try all connections from the pool
 	for i := 0; i < pxy.poolCount+1; i++ {

+ 9 - 2
server/service.go

@@ -76,6 +76,9 @@ type Service struct {
 	// Manage all proxies
 	pxyManager *proxy.ProxyManager
 
+	// HTTP vhost router
+	httpVhostRouter *vhost.VhostRouters
+
 	// All resource managers and controllers
 	rc *controller.ResourceController
 
@@ -95,12 +98,16 @@ func NewService() (svr *Service, err error) {
 			TcpPortManager: ports.NewPortManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
 			UdpPortManager: ports.NewPortManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
 		},
-		tlsConfig: generateTLSConfig(),
+		httpVhostRouter: vhost.NewVhostRouters(),
+		tlsConfig:       generateTLSConfig(),
 	}
 
 	// Init group controller
 	svr.rc.TcpGroupCtl = group.NewTcpGroupCtl(svr.rc.TcpPortManager)
 
+	// Init HTTP group controller
+	svr.rc.HTTPGroupCtl = group.NewHTTPGroupController(svr.httpVhostRouter)
+
 	// Init assets
 	err = assets.Load(cfg.AssetsDir)
 	if err != nil {
@@ -159,7 +166,7 @@ func NewService() (svr *Service, err error) {
 	if cfg.VhostHttpPort > 0 {
 		rp := vhost.NewHttpReverseProxy(vhost.HttpReverseProxyOptions{
 			ResponseHeaderTimeoutS: cfg.VhostHttpTimeout,
-		})
+		}, svr.httpVhostRouter)
 		svr.rc.HttpReverseProxy = rp
 
 		address := fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)

+ 12 - 20
utils/vhost/http.go

@@ -23,7 +23,6 @@ import (
 	"net"
 	"net/http"
 	"strings"
-	"sync"
 	"time"
 
 	frpLog "github.com/fatedier/frp/utils/log"
@@ -32,8 +31,7 @@ import (
 )
 
 var (
-	ErrRouterConfigConflict = errors.New("router config conflict")
-	ErrNoDomain             = errors.New("no such domain")
+	ErrNoDomain = errors.New("no such domain")
 )
 
 func getHostFromAddr(addr string) (host string) {
@@ -51,21 +49,19 @@ type HttpReverseProxyOptions struct {
 }
 
 type HttpReverseProxy struct {
-	proxy *ReverseProxy
-
+	proxy       *ReverseProxy
 	vhostRouter *VhostRouters
 
 	responseHeaderTimeout time.Duration
-	cfgMu                 sync.RWMutex
 }
 
-func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy {
+func NewHttpReverseProxy(option HttpReverseProxyOptions, vhostRouter *VhostRouters) *HttpReverseProxy {
 	if option.ResponseHeaderTimeoutS <= 0 {
 		option.ResponseHeaderTimeoutS = 60
 	}
 	rp := &HttpReverseProxy{
 		responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second,
-		vhostRouter:           NewVhostRouters(),
+		vhostRouter:           vhostRouter,
 	}
 	proxy := &ReverseProxy{
 		Director: func(req *http.Request) {
@@ -106,21 +102,18 @@ func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy {
 	return rp
 }
 
+// Register register the route config to reverse proxy
+// reverse proxy will use CreateConnFn from routeCfg to create a connection to the remote service
 func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error {
-	rp.cfgMu.Lock()
-	defer rp.cfgMu.Unlock()
-	_, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location)
-	if ok {
-		return ErrRouterConfigConflict
-	} else {
-		rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
+	err := rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
+	if err != nil {
+		return err
 	}
 	return nil
 }
 
+// UnRegister unregister route config by domain and location
 func (rp *HttpReverseProxy) UnRegister(domain string, location string) {
-	rp.cfgMu.Lock()
-	defer rp.cfgMu.Unlock()
 	rp.vhostRouter.Del(domain, location)
 }
 
@@ -140,6 +133,7 @@ func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers
 	return
 }
 
+// CreateConnection create a new connection by route config
 func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) {
 	vr, ok := rp.getVhost(domain, location)
 	if ok {
@@ -163,10 +157,8 @@ func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) boo
 	return true
 }
 
+// getVhost get vhost router by domain and location
 func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) {
-	rp.cfgMu.RLock()
-	defer rp.cfgMu.RUnlock()
-
 	// first we check the full hostname
 	// if not exist, then check the wildcard_domain such as *.example.com
 	vr, ok = rp.vhostRouter.Get(domain, location)

+ 12 - 5
utils/vhost/router.go

@@ -1,11 +1,16 @@
 package vhost
 
 import (
+	"errors"
 	"sort"
 	"strings"
 	"sync"
 )
 
+var (
+	ErrRouterConfigConflict = errors.New("router config conflict")
+)
+
 type VhostRouters struct {
 	RouterByDomain map[string][]*VhostRouter
 	mutex          sync.RWMutex
@@ -24,10 +29,14 @@ func NewVhostRouters() *VhostRouters {
 	}
 }
 
-func (r *VhostRouters) Add(domain, location string, payload interface{}) {
+func (r *VhostRouters) Add(domain, location string, payload interface{}) error {
 	r.mutex.Lock()
 	defer r.mutex.Unlock()
 
+	if _, exist := r.exist(domain, location); exist {
+		return ErrRouterConfigConflict
+	}
+
 	vrs, found := r.RouterByDomain[domain]
 	if !found {
 		vrs = make([]*VhostRouter, 0, 1)
@@ -42,6 +51,7 @@ func (r *VhostRouters) Add(domain, location string, payload interface{}) {
 
 	sort.Sort(sort.Reverse(ByLocation(vrs)))
 	r.RouterByDomain[domain] = vrs
+	return nil
 }
 
 func (r *VhostRouters) Del(domain, location string) {
@@ -80,10 +90,7 @@ func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) {
 	return
 }
 
-func (r *VhostRouters) Exist(host, path string) (vr *VhostRouter, exist bool) {
-	r.mutex.RLock()
-	defer r.mutex.RUnlock()
-
+func (r *VhostRouters) exist(host, path string) (vr *VhostRouter, exist bool) {
 	vrs, found := r.RouterByDomain[host]
 	if !found {
 		return

+ 5 - 14
utils/vhost/vhost.go

@@ -15,7 +15,6 @@ package vhost
 import (
 	"fmt"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/fatedier/frp/utils/log"
@@ -35,7 +34,6 @@ type VhostMuxer struct {
 	authFunc       httpAuthFunc
 	rewriteFunc    hostRewriteFunc
 	registryRouter *VhostRouters
-	mutex          sync.RWMutex
 }
 
 func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
@@ -53,6 +51,7 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut
 
 type CreateConnFunc func(remoteAddr string) (frpNet.Conn, error)
 
+// VhostRouteConfig is the params used to match HTTP requests
 type VhostRouteConfig struct {
 	Domain      string
 	Location    string
@@ -67,14 +66,6 @@ type VhostRouteConfig struct {
 // listen for a new domain name, if rewriteHost is not empty  and rewriteFunc is not nil
 // then rewrite the host header to rewriteHost
 func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) {
-	v.mutex.Lock()
-	defer v.mutex.Unlock()
-
-	_, ok := v.registryRouter.Exist(cfg.Domain, cfg.Location)
-	if ok {
-		return nil, fmt.Errorf("hostname [%s] location [%s] is already registered", cfg.Domain, cfg.Location)
-	}
-
 	l = &Listener{
 		name:        cfg.Domain,
 		location:    cfg.Location,
@@ -85,14 +76,14 @@ func (v *VhostMuxer) Listen(cfg *VhostRouteConfig) (l *Listener, err error) {
 		accept:      make(chan frpNet.Conn),
 		Logger:      log.NewPrefixLogger(""),
 	}
-	v.registryRouter.Add(cfg.Domain, cfg.Location, l)
+	err = v.registryRouter.Add(cfg.Domain, cfg.Location, l)
+	if err != nil {
+		return
+	}
 	return l, nil
 }
 
 func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) {
-	v.mutex.RLock()
-	defer v.mutex.RUnlock()
-
 	// first we check the full hostname
 	// if not exist, then check the wildcard_domain such as *.example.com
 	vr, found := v.registryRouter.Get(name, path)