Sfoglia il codice sorgente

vhost: check host and location for url router

fatedier 8 anni fa
parent
commit
04a4591caa
3 ha cambiato i file con 27 aggiunte e 3 eliminazioni
  1. 2 1
      src/models/server/server.go
  2. 20 2
      src/utils/vhost/router.go
  3. 5 0
      src/utils/vhost/vhost.go

+ 2 - 1
src/models/server/server.go

@@ -271,6 +271,8 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 
 func (p *ProxyServer) Close() {
 	p.Lock()
+	defer p.Unlock()
+
 	if p.Status != consts.Closed {
 		p.Status = consts.Closed
 		for _, l := range p.listeners {
@@ -298,7 +300,6 @@ func (p *ProxyServer) Close() {
 	if p.PrivilegeMode {
 		DeleteProxy(p.Name)
 	}
-	p.Unlock()
 }
 
 func (p *ProxyServer) WaitUserConn() (closeFlag bool) {

+ 20 - 2
src/utils/vhost/router.go

@@ -72,7 +72,7 @@ func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) {
 		return
 	}
 
-	//can't support load balance,will to do
+	// can't support load balance, will to do
 	for _, vr = range vrs {
 		if strings.HasPrefix(path, vr.location) {
 			return vr, true
@@ -82,7 +82,25 @@ func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) {
 	return
 }
 
-//sort by location
+func (r *VhostRouters) Exist(host, path string) (vr *VhostRouter, exist bool) {
+	r.mutex.RLock()
+	defer r.mutex.RUnlock()
+
+	vrs, found := r.RouterByDomain[host]
+	if !found {
+		return
+	}
+
+	for _, vr = range vrs {
+		if path == vr.location {
+			return vr, true
+		}
+	}
+
+	return
+}
+
+// sort by location
 type ByLocation []*VhostRouter
 
 func (a ByLocation) Len() int {

+ 5 - 0
src/utils/vhost/vhost.go

@@ -56,6 +56,11 @@ func (v *VhostMuxer) Listen(name, location, rewriteHost, userName, passWord stri
 	v.mutex.Lock()
 	defer v.mutex.Unlock()
 
+	_, ok := v.registryRouter.Exist(name, location)
+	if ok {
+		return nil, fmt.Errorf("hostname [%s] location [%s] is already registered", name, location)
+	}
+
 	l = &Listener{
 		name:        name,
 		rewriteHost: rewriteHost,