Quellcode durchsuchen

change the Host value in http request header

Maodanping vor 8 Jahren
Ursprung
Commit
e0f2993b70

+ 2 - 0
src/frp/cmd/frpc/control.go

@@ -144,6 +144,8 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
 		UseGzip:       cli.UseGzip,
 		PrivilegeMode: cli.PrivilegeMode,
 		ProxyType:     cli.Type,
+		LocalIp:       cli.LocalIp,
+		LocalPort:     cli.LocalPort,
 		Timestamp:     nowTime,
 	}
 	if cli.PrivilegeMode {

+ 2 - 0
src/frp/cmd/frps/control.go

@@ -276,6 +276,8 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
 		// set infomations from frpc
 		s.UseEncryption = req.UseEncryption
 		s.UseGzip = req.UseGzip
+		s.ClientIp = req.LocalIp
+		s.ClientPort = req.LocalPort
 
 		// start proxy and listen for user connections, no block
 		err := s.Start(c)

+ 3 - 0
src/frp/models/config/config.go

@@ -22,4 +22,7 @@ type BaseConf struct {
 	UseGzip        bool
 	PrivilegeMode  bool
 	PrivilegeToken string
+	ClientIp       string
+	ClientPort     int64
+	ServerPort     int64
 }

+ 2 - 0
src/frp/models/msg/msg.go

@@ -26,6 +26,8 @@ type ControlReq struct {
 	AuthKey       string `json:"auth_key"`
 	UseEncryption bool   `json:"use_encryption"`
 	UseGzip       bool   `json:"use_gzip"`
+	LocalIp       string `json:"local_ip"`
+	LocalPort     int64  `json:"local_port"`
 
 	// configures used if privilege_mode is enabled
 	PrivilegeMode bool     `json:"privilege_mode"`

+ 3 - 2
src/frp/models/server/server.go

@@ -64,6 +64,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) {
 	p.BindAddr = BindAddr
 	p.ListenPort = req.RemotePort
 	p.CustomDomains = req.CustomDomains
+	p.ServerPort = VhostHttpPort
 	return
 }
 
@@ -113,7 +114,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		p.listeners = append(p.listeners, l)
 	} else if p.Type == "http" {
 		for _, domain := range p.CustomDomains {
-			l, err := VhostHttpMuxer.Listen(domain)
+			l, err := VhostHttpMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort)
 			if err != nil {
 				return err
 			}
@@ -121,7 +122,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		}
 	} else if p.Type == "https" {
 		for _, domain := range p.CustomDomains {
-			l, err := VhostHttpsMuxer.Listen(domain)
+			l, err := VhostHttpsMuxer.Listen(domain, p.Type, p.ClientIp, p.ClientPort, p.ServerPort)
 			if err != nil {
 				return err
 			}

+ 114 - 0
src/frp/utils/vhost/http.go

@@ -16,12 +16,17 @@ package vhost
 
 import (
 	"bufio"
+	"bytes"
+	"fmt"
+	"io"
 	"net"
 	"net/http"
+	"net/url"
 	"strings"
 	"time"
 
 	"frp/utils/conn"
+	"frp/utils/log"
 )
 
 type HttpMuxer struct {
@@ -45,3 +50,112 @@ func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, e
 	mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
 	return &HttpMuxer{mux}, err
 }
+
+func HostNameRewrite(c *conn.Conn, clientHost string) (_ net.Conn, err error) {
+	log.Info("HostNameRewrite, clientHost: %s", clientHost)
+	sc, rd := newShareConn(c.TcpConn)
+	var buff []byte
+	if buff, err = hostNameRewrite(rd, clientHost); err != nil {
+		return sc, err
+	}
+	err = sc.WriteBuff(buff)
+	return sc, err
+}
+
+func hostNameRewrite(request io.Reader, clientHost string) (_ []byte, err error) {
+	buffer := make([]byte, 1024)
+	request.Read(buffer)
+	log.Debug("before hostNameRewrite:\n %s", string(buffer))
+	retBuffer, err := parseRequest(buffer, clientHost)
+	log.Debug("after hostNameRewrite:\n %s", string(retBuffer))
+	return retBuffer, err
+}
+
+func parseRequest(org []byte, clientHost string) (ret []byte, err error) {
+	tp := bytes.NewBuffer(org)
+	// First line: GET /index.html HTTP/1.0
+	var b []byte
+	if b, err = tp.ReadBytes('\n'); err != nil {
+		return nil, err
+	}
+	req := new(http.Request)
+	//we invoked ReadRequest in GetHttpHostname before, so we ignore error
+	req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
+	rawurl := req.RequestURI
+	//CONNECT www.google.com:443 HTTP/1.1
+	justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
+	if justAuthority {
+		rawurl = "http://" + rawurl
+	}
+	req.URL, _ = url.ParseRequestURI(rawurl)
+	if justAuthority {
+		// Strip the bogus "http://" back off.
+		req.URL.Scheme = ""
+	}
+
+	//  RFC2616: first case
+	//  GET /index.html HTTP/1.1
+	//  Host: www.google.com
+	if req.URL.Host == "" {
+		changedBuf, err := changeHostName(tp, clientHost)
+		buf := new(bytes.Buffer)
+		buf.Write(b)
+		buf.Write(changedBuf)
+		return buf.Bytes(), err
+	}
+
+	// RFC2616: second case
+	// GET http://www.google.com/index.html HTTP/1.1
+	// Host: doesntmatter
+	// In this case, any Host line is ignored.
+	req.URL.Host = clientHost
+	firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
+	buf := new(bytes.Buffer)
+	buf.WriteString(firstLine)
+	tp.WriteTo(buf)
+	return buf.Bytes(), err
+
+}
+
+// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
+func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
+	s1 := strings.Index(line, " ")
+	s2 := strings.Index(line[s1+1:], " ")
+	if s1 < 0 || s2 < 0 {
+		return
+	}
+	s2 += s1 + 1
+	return line[:s1], line[s1+1 : s2], line[s2+1:], true
+}
+
+func changeHostName(buff *bytes.Buffer, clientHost string) (_ []byte, err error) {
+	retBuf := new(bytes.Buffer)
+
+	peek := buff.Bytes()
+	for len(peek) > 0 {
+		i := bytes.IndexByte(peek, '\n')
+		if i < 3 {
+			// Not present (-1) or found within the next few bytes,
+			// implying we're at the end ("\r\n\r\n" or "\n\n")
+			return nil, err
+		}
+		kv := peek[:i]
+		j := bytes.IndexByte(kv, ':')
+		if j < 0 {
+			return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
+		}
+		if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
+			hostHeader := fmt.Sprintf("Host: %s\n", clientHost)
+			retBuf.WriteString(hostHeader)
+			peek = peek[i+1:]
+			break
+		} else {
+			retBuf.Write(peek[:i])
+			retBuf.WriteByte('\n')
+		}
+
+		peek = peek[i+1:]
+	}
+	retBuf.Write(peek)
+	return retBuf.Bytes(), err
+}

+ 39 - 7
src/frp/utils/vhost/vhost.go

@@ -34,6 +34,10 @@ type VhostMuxer struct {
 	vhostFunc   muxFunc
 	registryMap map[string]*Listener
 	mutex       sync.RWMutex
+
+	//build map between custom_domains and client_domain
+	domainMap   map[string]string
+	domainMutex sync.RWMutex
 }
 
 func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
@@ -47,7 +51,7 @@ func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Dura
 	return mux, nil
 }
 
-func (v *VhostMuxer) Listen(name string) (l *Listener, err error) {
+func (v *VhostMuxer) Listen(name, proxytype, clientIp string, clientPort, serverPort int64) (l *Listener, err error) {
 	v.mutex.Lock()
 	defer v.mutex.Unlock()
 	if _, exist := v.registryMap[name]; exist {
@@ -55,9 +59,13 @@ func (v *VhostMuxer) Listen(name string) (l *Listener, err error) {
 	}
 
 	l = &Listener{
-		name:   name,
-		mux:    v,
-		accept: make(chan *conn.Conn),
+		name:       name,
+		mux:        v,
+		accept:     make(chan *conn.Conn),
+		proxyType:  proxytype,
+		clientIp:   clientIp,
+		clientPort: clientPort,
+		serverPort: serverPort,
 	}
 	v.registryMap[name] = l
 	return l, nil
@@ -111,9 +119,13 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
 }
 
 type Listener struct {
-	name   string
-	mux    *VhostMuxer // for closing VhostMuxer
-	accept chan *conn.Conn
+	name       string
+	mux        *VhostMuxer // for closing VhostMuxer
+	accept     chan *conn.Conn
+	proxyType  string //suppor http host rewrite
+	clientIp   string
+	clientPort int64
+	serverPort int64
 }
 
 func (l *Listener) Accept() (*conn.Conn, error) {
@@ -121,6 +133,20 @@ func (l *Listener) Accept() (*conn.Conn, error) {
 	if !ok {
 		return nil, fmt.Errorf("Listener closed")
 	}
+	if net.ParseIP(l.clientIp) == nil && l.proxyType == "http" {
+		if (l.name != l.clientIp) || (l.serverPort != l.clientPort) {
+			clientHost := l.clientIp
+			if l.clientPort != 80 {
+				strPort := fmt.Sprintf(":%d", l.clientPort)
+				clientHost += strPort
+			}
+			retConn, err := HostNameRewrite(conn, clientHost)
+			if err != nil {
+				return nil, fmt.Errorf("http host rewrite failed")
+			}
+			conn.SetTcpConn(retConn)
+		}
+	}
 	return conn, nil
 }
 
@@ -166,3 +192,9 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) {
 	sc.Unlock()
 	return
 }
+
+func (sc *sharedConn) WriteBuff(buffer []byte) (err error) {
+	sc.buff.Reset()
+	_, err = sc.buff.Write(buffer)
+	return err
+}