Просмотр исходного кода

add host_header_rewrite in frpc.ini to rewrite your requests with a modified Host header

fatedier 8 лет назад
Родитель
Сommit
452e02adab

+ 1 - 0
conf/frpc.ini

@@ -52,3 +52,4 @@ local_ip = 127.0.0.1
 local_port = 80
 use_gzip = true
 custom_domains = web03.yourdomain.com
+host_header_rewrite = example.com

+ 8 - 9
src/frp/cmd/frpc/control.go

@@ -138,15 +138,14 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
 
 	nowTime := time.Now().Unix()
 	req := &msg.ControlReq{
-		Type:          consts.NewCtlConn,
-		ProxyName:     cli.Name,
-		UseEncryption: cli.UseEncryption,
-		UseGzip:       cli.UseGzip,
-		PrivilegeMode: cli.PrivilegeMode,
-		ProxyType:     cli.Type,
-		LocalIp:       cli.LocalIp,
-		LocalPort:     cli.LocalPort,
-		Timestamp:     nowTime,
+		Type:              consts.NewCtlConn,
+		ProxyName:         cli.Name,
+		UseEncryption:     cli.UseEncryption,
+		UseGzip:           cli.UseGzip,
+		PrivilegeMode:     cli.PrivilegeMode,
+		ProxyType:         cli.Type,
+		HostHeaderRewrite: cli.HostHeaderRewrite,
+		Timestamp:         nowTime,
 	}
 	if cli.PrivilegeMode {
 		privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime))

+ 1 - 2
src/frp/cmd/frps/control.go

@@ -276,8 +276,7 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
 		// set infomations from frpc
 		s.UseEncryption = req.UseEncryption
 		s.UseGzip = req.UseGzip
-		s.ClientIp = req.LocalIp
-		s.ClientPort = req.LocalPort
+		s.HostHeaderRewrite = req.HostHeaderRewrite
 
 		// start proxy and listen for user connections, no block
 		err := s.Start(c)

+ 10 - 0
src/frp/models/client/config.go

@@ -140,6 +140,14 @@ func LoadConf(confFile string) (err error) {
 				proxyClient.UseGzip = true
 			}
 
+			if proxyClient.Type == "http" {
+				// host_header_rewrite
+				tmpStr, ok = section["host_header_rewrite"]
+				if ok {
+					proxyClient.HostHeaderRewrite = tmpStr
+				}
+			}
+
 			// privilege_mode
 			proxyClient.PrivilegeMode = false
 			tmpStr, ok = section["privilege_mode"]
@@ -167,6 +175,7 @@ func LoadConf(confFile string) (err error) {
 						return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name)
 					}
 				} else if proxyClient.Type == "http" {
+					// custom_domains
 					domainStr, ok := section["custom_domains"]
 					if ok {
 						proxyClient.CustomDomains = strings.Split(domainStr, ",")
@@ -180,6 +189,7 @@ func LoadConf(confFile string) (err error) {
 						return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name)
 					}
 				} else if proxyClient.Type == "https" {
+					// custom_domains
 					domainStr, ok := section["custom_domains"]
 					if ok {
 						proxyClient.CustomDomains = strings.Split(domainStr, ",")

+ 8 - 10
src/frp/models/config/config.go

@@ -15,14 +15,12 @@
 package config
 
 type BaseConf struct {
-	Name           string
-	AuthToken      string
-	Type           string
-	UseEncryption  bool
-	UseGzip        bool
-	PrivilegeMode  bool
-	PrivilegeToken string
-	ClientIp       string
-	ClientPort     int64
-	ServerPort     int64
+	Name              string
+	AuthToken         string
+	Type              string
+	UseEncryption     bool
+	UseGzip           bool
+	PrivilegeMode     bool
+	PrivilegeToken    string
+	HostHeaderRewrite string
 }

+ 7 - 8
src/frp/models/msg/msg.go

@@ -26,16 +26,15 @@ type ControlReq struct {
 	AuthKey       string `json:"auth_key"`
 	UseEncryption bool   `json:"use_encryption"`
 	UseGzip       bool   `json:"use_gzip"`
-	LocalIp       string `json:"local_ip"`
-	LocalPort     int64  `json:"local_port"`
 
 	// configures used if privilege_mode is enabled
-	PrivilegeMode bool     `json:"privilege_mode"`
-	PrivilegeKey  string   `json:"privilege_key"`
-	ProxyType     string   `json:"proxy_type"`
-	RemotePort    int64    `json:"remote_port"`
-	CustomDomains []string `json:"custom_domains, omitempty"`
-	Timestamp     int64    `json:"timestamp"`
+	PrivilegeMode     bool     `json:"privilege_mode"`
+	PrivilegeKey      string   `json:"privilege_key"`
+	ProxyType         string   `json:"proxy_type"`
+	RemotePort        int64    `json:"remote_port"`
+	CustomDomains     []string `json:"custom_domains, omitempty"`
+	HostHeaderRewrite string   `json:"host_header_rewrite"`
+	Timestamp         int64    `json:"timestamp"`
 }
 
 type ControlRes struct {

+ 4 - 4
src/frp/models/server/server.go

@@ -64,7 +64,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) {
 	p.BindAddr = BindAddr
 	p.ListenPort = req.RemotePort
 	p.CustomDomains = req.CustomDomains
-	p.ServerPort = VhostHttpPort
+	p.HostHeaderRewrite = req.HostHeaderRewrite
 	return
 }
 
@@ -80,7 +80,7 @@ func (p *ProxyServer) Init() {
 
 func (p *ProxyServer) Compare(p2 *ProxyServer) bool {
 	if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type ||
-		p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort {
+		p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite {
 		return false
 	}
 	if len(p.CustomDomains) != len(p2.CustomDomains) {
@@ -114,7 +114,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		p.listeners = append(p.listeners, l)
 	} else if p.Type == "http" {
 		for _, domain := range p.CustomDomains {
-			l, err := VhostHttpMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort)
+			l, err := VhostHttpMuxer.Listen(domain, p.HostHeaderRewrite)
 			if err != nil {
 				return err
 			}
@@ -122,7 +122,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		}
 	} else if p.Type == "https" {
 		for _, domain := range p.CustomDomains {
-			l, err := VhostHttpsMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort)
+			l, err := VhostHttpsMuxer.Listen(domain, p.HostHeaderRewrite)
 			if err != nil {
 				return err
 			}

+ 5 - 0
src/frp/utils/conn/conn.go

@@ -117,8 +117,13 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
 	return c, nil
 }
 
+// if the tcpConn is different with c.TcpConn
+// you should call c.Close() first
 func (c *Conn) SetTcpConn(tcpConn net.Conn) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
 	c.TcpConn = tcpConn
+	c.closeFlag = false
 	c.Reader = bufio.NewReader(c.TcpConn)
 }
 

+ 23 - 16
src/frp/utils/vhost/http.go

@@ -26,7 +26,6 @@ import (
 	"time"
 
 	"frp/utils/conn"
-	"frp/utils/log"
 )
 
 type HttpMuxer struct {
@@ -47,31 +46,28 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
 }
 
 func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
-	mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
+	mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout)
 	return &HttpMuxer{mux}, err
 }
 
-func HostNameRewrite(c *conn.Conn, clientHost string) (_ net.Conn, err error) {
-	log.Info("HostNameRewrite, clientHost: %s", clientHost)
+func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) {
 	sc, rd := newShareConn(c.TcpConn)
 	var buff []byte
-	if buff, err = hostNameRewrite(rd, clientHost); err != nil {
+	if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
 		return sc, err
 	}
 	err = sc.WriteBuff(buff)
 	return sc, err
 }
 
-func hostNameRewrite(request io.Reader, clientHost string) (_ []byte, err error) {
+func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
 	buffer := make([]byte, 1024)
 	request.Read(buffer)
-	log.Debug("before hostNameRewrite:\n %s", string(buffer))
-	retBuffer, err := parseRequest(buffer, clientHost)
-	log.Debug("after hostNameRewrite:\n %s", string(retBuffer))
+	retBuffer, err := parseRequest(buffer, rewriteHost)
 	return retBuffer, err
 }
 
-func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
+func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
 	tp := bytes.NewBuffer(org)
 	// First line: GET /index.html HTTP/1.0
 	var b []byte
@@ -79,10 +75,10 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
 		return nil, err
 	}
 	req := new(http.Request)
-	//we invoked ReadRequest in GetHttpHostname before, so we ignore error
+	// we invoked ReadRequest in GetHttpHostname before, so we ignore error
 	req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
 	rawurl := req.RequestURI
-	//CONNECT www.google.com:443 HTTP/1.1
+	// CONNECT www.google.com:443 HTTP/1.1
 	justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
 	if justAuthority {
 		rawurl = "http://" + rawurl
@@ -97,7 +93,7 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
 	//  GET /index.html HTTP/1.1
 	//  Host: www.google.com
 	if req.URL.Host == "" {
-		changedBuf, err := changeHostName(tp, clientHost)
+		changedBuf, err := changeHostName(tp, rewriteHost)
 		buf := new(bytes.Buffer)
 		buf.Write(b)
 		buf.Write(changedBuf)
@@ -108,7 +104,12 @@ func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
 	// GET http://www.google.com/index.html HTTP/1.1
 	// Host: doesntmatter
 	// In this case, any Host line is ignored.
-	req.URL.Host = clientHost
+	hostPort := strings.Split(req.URL.Host, ":")
+	if len(hostPort) == 1 {
+		req.URL.Host = rewriteHost
+	} else if len(hostPort) == 2 {
+		req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
+	}
 	firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
 	buf := new(bytes.Buffer)
 	buf.WriteString(firstLine)
@@ -128,7 +129,7 @@ func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
 	return line[:s1], line[s1+1 : s2], line[s2+1:], true
 }
 
-func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error) {
+func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
 	retBuf := new(bytes.Buffer)
 
 	peek := buff.Bytes()
@@ -145,7 +146,13 @@ func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error)
 			return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
 		}
 		if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
-			hostHeader := fmt.Sprintf("Host: %s\n", clientHost)
+			var hostHeader string
+			portPos := bytes.IndexByte(kv[j+1:], ':')
+			if portPos == -1 {
+				hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
+			} else {
+				hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
+			}
 			retBuf.WriteString(hostHeader)
 			peek = peek[i+1:]
 			break

+ 1 - 1
src/frp/utils/vhost/https.go

@@ -47,7 +47,7 @@ type HttpsMuxer struct {
 }
 
 func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
-	mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
+	mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout)
 	return &HttpsMuxer{mux}, err
 }
 

+ 25 - 29
src/frp/utils/vhost/vhost.go

@@ -27,41 +27,42 @@ import (
 )
 
 type muxFunc func(*conn.Conn) (net.Conn, string, error)
+type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
 
 type VhostMuxer struct {
 	listener    *conn.Listener
 	timeout     time.Duration
 	vhostFunc   muxFunc
+	rewriteFunc hostRewriteFunc
 	registryMap map[string]*Listener
 	mutex       sync.RWMutex
 }
 
-func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
+func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
 	mux = &VhostMuxer{
 		listener:    listener,
 		timeout:     timeout,
 		vhostFunc:   vhostFunc,
+		rewriteFunc: rewriteFunc,
 		registryMap: make(map[string]*Listener),
 	}
 	go mux.run()
 	return mux, nil
 }
 
-func (v *VhostMuxer) Listen(name, proxytype, clientIp string, clientPort, serverPort int64) (l *Listener, err error) {
+// listen for a new domain name, if rewriteHost is not empty  and rewriteFunc is not nil, then rewrite the host header to rewriteHost
+func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
 	v.mutex.Lock()
 	defer v.mutex.Unlock()
 	if _, exist := v.registryMap[name]; exist {
-		return nil, fmt.Errorf("name %s is already bound", name)
+		return nil, fmt.Errorf("domain name %s is already bound", name)
 	}
 
 	l = &Listener{
-		name:       name,
-		mux:        v,
-		accept:     make(chan *conn.Conn),
-		proxyType:  proxytype,
-		clientIp:   clientIp,
-		clientPort: clientPort,
-		serverPort: serverPort,
+		name:        name,
+		rewriteHost: rewriteHost,
+		mux:         v,
+		accept:      make(chan *conn.Conn),
 	}
 	v.registryMap[name] = l
 	return l, nil
@@ -115,13 +116,10 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
 }
 
 type Listener struct {
-	name       string
-	mux        *VhostMuxer // for closing VhostMuxer
-	accept     chan *conn.Conn
-	proxyType  string //suppor http host rewrite
-	clientIp   string
-	clientPort int64
-	serverPort int64
+	name        string
+	rewriteHost string
+	mux         *VhostMuxer // for closing VhostMuxer
+	accept      chan *conn.Conn
 }
 
 func (l *Listener) Accept() (*conn.Conn, error) {
@@ -129,19 +127,16 @@ func (l *Listener) Accept() (*conn.Conn, error) {
 	if !ok {
 		return nil, fmt.Errorf("Listener closed")
 	}
-	if net.ParseIP(l.clientIp) == nil && l.proxyType == "http" {
-		if (l.name != l.clientIp) || (l.serverPort != l.clientPort) {
-			clientHost := l.clientIp
-			if l.clientPort != 80 {
-				strPort := fmt.Sprintf(":%d", l.clientPort)
-				clientHost += strPort
-			}
-			retConn, err := HostNameRewrite(conn, clientHost)
-			if err != nil {
-				return nil, fmt.Errorf("http host rewrite failed")
-			}
-			conn.SetTcpConn(retConn)
+
+	// if rewriteFunc is exist and rewriteHost is set
+	// rewrite http requests with a modified host header
+	if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
+		fmt.Printf("host rewrite: %s\n", l.rewriteHost)
+		sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
+		if err != nil {
+			return nil, fmt.Errorf("http host header rewrite failed")
 		}
+		conn.SetTcpConn(sConn)
 	}
 	return conn, nil
 }
@@ -162,6 +157,7 @@ type sharedConn struct {
 	buff *bytes.Buffer
 }
 
+// the bytes you read in io.Reader, will be reserved in sharedConn
 func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
 	sc := &sharedConn{
 		Conn: conn,