浏览代码

newhttp support BasicAuth

fatedier 7 年之前
父节点
当前提交
4cc5ddc012
共有 3 个文件被更改,包括 36 次插入23 次删除
  1. 7 6
      server/proxy.go
  2. 25 17
      utils/vhost/newhttp.go
  3. 4 0
      utils/vhost/vhost.go

+ 7 - 6
server/proxy.go

@@ -194,10 +194,11 @@ type HttpProxy struct {
 }
 
 func (pxy *HttpProxy) Run() (err error) {
-	routeConfig := &vhost.VhostRouteConfig{
-		RewriteHost: pxy.cfg.HostHeaderRewrite,
-		Username:    pxy.cfg.HttpUser,
-		Password:    pxy.cfg.HttpPwd,
+	routeConfig := vhost.VhostRouteConfig{
+		RewriteHost:  pxy.cfg.HostHeaderRewrite,
+		Username:     pxy.cfg.HttpUser,
+		Password:     pxy.cfg.HttpPwd,
+		CreateConnFn: pxy.GetRealConn,
 	}
 
 	locations := pxy.cfg.Locations
@@ -208,7 +209,7 @@ func (pxy *HttpProxy) Run() (err error) {
 		routeConfig.Domain = domain
 		for _, location := range locations {
 			routeConfig.Location = location
-			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetRealConn)
+			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
 			if err != nil {
 				return err
 			}
@@ -225,7 +226,7 @@ func (pxy *HttpProxy) Run() (err error) {
 		routeConfig.Domain = pxy.cfg.SubDomain + "." + config.ServerCommonCfg.SubDomainHost
 		for _, location := range locations {
 			routeConfig.Location = location
-			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig.Domain, routeConfig.Location, routeConfig.RewriteHost, pxy.GetRealConn)
+			err := pxy.ctl.svr.httpReverseProxy.Register(routeConfig)
 			if err != nil {
 				return err
 			}

+ 25 - 17
utils/vhost/newhttp.go

@@ -26,7 +26,6 @@ import (
 	"time"
 
 	frpLog "github.com/fatedier/frp/utils/log"
-	frpNet "github.com/fatedier/frp/utils/net"
 	"github.com/fatedier/frp/utils/pool"
 )
 
@@ -47,13 +46,6 @@ func getHostFromAddr(addr string) (host string) {
 	return
 }
 
-type CreateConnFunc func() (frpNet.Conn, error)
-
-type ProxyOption struct {
-	RewriteHost string
-	DialFunc    CreateConnFunc
-}
-
 type HttpReverseProxy struct {
 	proxy *ReverseProxy
 	tr    *http.Transport
@@ -94,18 +86,14 @@ func NewHttpReverseProxy() *HttpReverseProxy {
 	return rp
 }
 
-func (rp *HttpReverseProxy) Register(domain string, location string, rewriteHost string, fn CreateConnFunc) error {
+func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error {
 	rp.cfgMu.Lock()
 	defer rp.cfgMu.Unlock()
-	_, ok := rp.vhostRouter.Exist(domain, location)
+	_, ok := rp.vhostRouter.Exist(routeCfg.Domain, routeCfg.Location)
 	if ok {
 		return ErrRouterConfigConflict
 	} else {
-		payload := &ProxyOption{
-			RewriteHost: rewriteHost,
-			DialFunc:    fn,
-		}
-		rp.vhostRouter.Add(domain, location, payload)
+		rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg)
 	}
 	return nil
 }
@@ -119,7 +107,7 @@ func (rp *HttpReverseProxy) UnRegister(domain string, location string) {
 func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host string) {
 	vr, ok := rp.getVhost(domain, location)
 	if ok {
-		host = vr.payload.(*ProxyOption).RewriteHost
+		host = vr.payload.(*VhostRouteConfig).RewriteHost
 	}
 	return
 }
@@ -127,7 +115,7 @@ func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host st
 func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) {
 	vr, ok := rp.getVhost(domain, location)
 	if ok {
-		fn := vr.payload.(*ProxyOption).DialFunc
+		fn := vr.payload.(*VhostRouteConfig).CreateConnFn
 		if fn != nil {
 			return fn()
 		}
@@ -135,6 +123,18 @@ func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (ne
 	return nil, ErrNoDomain
 }
 
+func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool {
+	vr, ok := rp.getVhost(domain, location)
+	if ok {
+		checkUser := vr.payload.(*VhostRouteConfig).Username
+		checkPasswd := vr.payload.(*VhostRouteConfig).Password
+		if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) {
+			return false
+		}
+	}
+	return true
+}
+
 func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) {
 	rp.cfgMu.RLock()
 	defer rp.cfgMu.RUnlock()
@@ -157,6 +157,14 @@ func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostR
 }
 
 func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	domain := getHostFromAddr(req.Host)
+	location := req.URL.Path
+	user, passwd, _ := req.BasicAuth()
+	if !rp.CheckAuth(domain, location, user, passwd) {
+		rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
+		http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
+		return
+	}
 	rp.proxy.ServeHTTP(rw, req)
 }
 

+ 4 - 0
utils/vhost/vhost.go

@@ -50,12 +50,16 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut
 	return mux, nil
 }
 
+type CreateConnFunc func() (frpNet.Conn, error)
+
 type VhostRouteConfig struct {
 	Domain      string
 	Location    string
 	RewriteHost string
 	Username    string
 	Password    string
+
+	CreateConnFn CreateConnFunc
 }
 
 // listen for a new domain name, if rewriteHost is not empty  and rewriteFunc is not nil