Browse Source

Fix conflicts in fatedier/connection_pool with dev

Conflicts:
	src/frp/cmd/frpc/control.go
	src/frp/cmd/frps/control.go
	src/frp/models/config/config.go
	src/frp/models/server/server.go
fatedier 8 years ago
parent
commit
fd3c97a0e9

+ 1 - 0
conf/frpc.ini

@@ -55,3 +55,4 @@ local_ip = 127.0.0.1
 local_port = 80
 use_gzip = true
 custom_domains = web03.yourdomain.com
+host_header_rewrite = example.com

+ 8 - 8
src/frp/cmd/frpc/control.go

@@ -138,14 +138,14 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
 
 	nowTime := time.Now().Unix()
 	req := &msg.ControlReq{
-		Type:          consts.NewCtlConn,
-		ProxyName:     cli.Name,
-		UseEncryption: cli.UseEncryption,
-		UseGzip:       cli.UseGzip,
-		PoolCount:     cli.PoolCount,
-		PrivilegeMode: cli.PrivilegeMode,
-		ProxyType:     cli.Type,
-		Timestamp:     nowTime,
+		Type:              consts.NewCtlConn,
+		ProxyName:         cli.Name,
+		UseEncryption:     cli.UseEncryption,
+		UseGzip:           cli.UseGzip,
+		PrivilegeMode:     cli.PrivilegeMode,
+		ProxyType:         cli.Type,
+		HostHeaderRewrite: cli.HostHeaderRewrite,
+		Timestamp:         nowTime,
 	}
 	if cli.PrivilegeMode {
 		privilegeKey := pcrypto.GetAuthKey(cli.Name + client.PrivilegeToken + fmt.Sprintf("%d", nowTime))

+ 1 - 0
src/frp/cmd/frps/control.go

@@ -276,6 +276,7 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
 		// set infomations from frpc
 		s.UseEncryption = req.UseEncryption
 		s.UseGzip = req.UseGzip
+		s.HostHeaderRewrite = req.HostHeaderRewrite
 		if req.PoolCount > server.MaxPoolCount {
 			s.PoolCount = server.MaxPoolCount
 		} else if req.PoolCount < 0 {

+ 10 - 0
src/frp/models/client/config.go

@@ -140,6 +140,14 @@ func LoadConf(confFile string) (err error) {
 				proxyClient.UseGzip = true
 			}
 
+			if proxyClient.Type == "http" {
+				// host_header_rewrite
+				tmpStr, ok = section["host_header_rewrite"]
+				if ok {
+					proxyClient.HostHeaderRewrite = tmpStr
+				}
+			}
+
 			// privilege_mode
 			proxyClient.PrivilegeMode = false
 			tmpStr, ok = section["privilege_mode"]
@@ -178,6 +186,7 @@ func LoadConf(confFile string) (err error) {
 						return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", proxyClient.Name)
 					}
 				} else if proxyClient.Type == "http" {
+					// custom_domains
 					domainStr, ok := section["custom_domains"]
 					if ok {
 						proxyClient.CustomDomains = strings.Split(domainStr, ",")
@@ -191,6 +200,7 @@ func LoadConf(confFile string) (err error) {
 						return fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyClient.Name)
 					}
 				} else if proxyClient.Type == "https" {
+					// custom_domains
 					domainStr, ok := section["custom_domains"]
 					if ok {
 						proxyClient.CustomDomains = strings.Split(domainStr, ",")

+ 9 - 8
src/frp/models/config/config.go

@@ -15,12 +15,13 @@
 package config
 
 type BaseConf struct {
-	Name           string
-	AuthToken      string
-	Type           string
-	UseEncryption  bool
-	UseGzip        bool
-	PrivilegeMode  bool
-	PrivilegeToken string
-	PoolCount      int64
+	Name              string
+	AuthToken         string
+	Type              string
+	UseEncryption     bool
+	UseGzip           bool
+	PrivilegeMode     bool
+	PrivilegeToken    string
+	PoolCount         int64
+	HostHeaderRewrite string
 }

+ 7 - 6
src/frp/models/msg/msg.go

@@ -29,12 +29,13 @@ type ControlReq struct {
 	PoolCount     int64  `json:"pool_count"`
 
 	// configures used if privilege_mode is enabled
-	PrivilegeMode bool     `json:"privilege_mode"`
-	PrivilegeKey  string   `json:"privilege_key"`
-	ProxyType     string   `json:"proxy_type"`
-	RemotePort    int64    `json:"remote_port"`
-	CustomDomains []string `json:"custom_domains, omitempty"`
-	Timestamp     int64    `json:"timestamp"`
+	PrivilegeMode     bool     `json:"privilege_mode"`
+	PrivilegeKey      string   `json:"privilege_key"`
+	ProxyType         string   `json:"proxy_type"`
+	RemotePort        int64    `json:"remote_port"`
+	CustomDomains     []string `json:"custom_domains, omitempty"`
+	HostHeaderRewrite string   `json:"host_header_rewrite"`
+	Timestamp         int64    `json:"timestamp"`
 }
 
 type ControlRes struct {

+ 9 - 13
src/frp/models/msg/process.go

@@ -15,12 +15,10 @@
 package msg
 
 import (
-	"bufio"
 	"bytes"
 	"encoding/binary"
 	"fmt"
 	"io"
-	"net"
 	"sync"
 
 	"frp/models/config"
@@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
 		defer wait.Done()
 
 		// we don't care about errors here
-		pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord)
+		pipeEncrypt(from, to, conf, needRecord)
 	}
 
 	decryptPipe := func(to *conn.Conn, from *conn.Conn) {
@@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
 		defer wait.Done()
 
 		// we don't care about errors here
-		pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord)
+		pipeDecrypt(to, from, conf, needRecord)
 	}
 
 	wait.Add(2)
@@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) {
 }
 
 // decrypt msg from reader, then write into writer
-func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
+func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
 	laes := new(pcrypto.Pcrypto)
 	key := conf.AuthToken
 	if conf.PrivilegeMode {
@@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 
 	buf := make([]byte, 5*1024+4)
 	var left, res []byte
-	var cnt int
+	var cnt int = -1
 
 	// record
 	var flowBytes int64 = 0
@@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 		}()
 	}
 
-	nreader := bufio.NewReader(r)
 	for {
 		// there may be more than 1 package in variable
 		// and we read more bytes if unpkgMsg returns an error
 		var newBuf []byte
 		if cnt < 0 {
-			n, err := nreader.Read(buf)
+			n, err := r.Read(buf)
 			if err != nil {
 				return err
 			}
@@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 			}
 		}
 
-		_, err = w.Write(res)
+		_, err = w.WriteBytes(res)
 		if err != nil {
 			return err
 		}
@@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 }
 
 // recvive msg from reader, then encrypt msg into writer
-func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
+func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
 	laes := new(pcrypto.Pcrypto)
 	key := conf.AuthToken
 	if conf.PrivilegeMode {
@@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 		}()
 	}
 
-	nreader := bufio.NewReader(r)
 	buf := make([]byte, 5*1024)
 	for {
-		n, err := nreader.Read(buf)
+		n, err := r.Read(buf)
 		if err != nil {
 			return err
 		}
@@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 		}
 
 		res = pkgMsg(res)
-		_, err = w.Write(res)
+		_, err = w.WriteBytes(res)
 		if err != nil {
 			return err
 		}

+ 7 - 7
src/frp/models/server/server.go

@@ -65,6 +65,7 @@ func NewProxyServerFromCtlMsg(req *msg.ControlReq) (p *ProxyServer) {
 	p.BindAddr = BindAddr
 	p.ListenPort = req.RemotePort
 	p.CustomDomains = req.CustomDomains
+	p.HostHeaderRewrite = req.HostHeaderRewrite
 	return
 }
 
@@ -81,7 +82,7 @@ func (p *ProxyServer) Init() {
 
 func (p *ProxyServer) Compare(p2 *ProxyServer) bool {
 	if p.Name != p2.Name || p.AuthToken != p2.AuthToken || p.Type != p2.Type ||
-		p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort {
+		p.BindAddr != p2.BindAddr || p.ListenPort != p2.ListenPort || p.HostHeaderRewrite != p2.HostHeaderRewrite {
 		return false
 	}
 	if len(p.CustomDomains) != len(p2.CustomDomains) {
@@ -115,7 +116,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		p.listeners = append(p.listeners, l)
 	} else if p.Type == "http" {
 		for _, domain := range p.CustomDomains {
-			l, err := VhostHttpMuxer.Listen(domain)
+			l, err := VhostHttpMuxer.Listen(domain, p.HostHeaderRewrite)
 			if err != nil {
 				return err
 			}
@@ -123,7 +124,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		}
 	} else if p.Type == "https" {
 		for _, domain := range p.CustomDomains {
-			l, err := VhostHttpsMuxer.Listen(domain)
+			l, err := VhostHttpsMuxer.Listen(domain, p.HostHeaderRewrite)
 			if err != nil {
 				return err
 			}
@@ -160,14 +161,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 					return
 				}
 
-				// start another goroutine for join two connections between frpc and user
-				go func() {
+				go func(userConn *conn.Conn) {
 					workConn, err := p.getWorkConn()
 					if err != nil {
 						return
 					}
 
-					userConn := c
 					// message will be transferred to another without modifying
 					// l means local, r means remote
 					log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
@@ -176,7 +175,8 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 					metric.OpenConnection(p.Name)
 					needRecord := true
 					go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
-				}()
+					metric.OpenConnection(p.Name)
+				}(c)
 			}
 		}(listener)
 	}

+ 20 - 1
src/frp/utils/conn/conn.go

@@ -117,6 +117,16 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
 	return c, nil
 }
 
+// if the tcpConn is different with c.TcpConn
+// you should call c.Close() first
+func (c *Conn) SetTcpConn(tcpConn net.Conn) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	c.TcpConn = tcpConn
+	c.closeFlag = false
+	c.Reader = bufio.NewReader(c.TcpConn)
+}
+
 func (c *Conn) GetRemoteAddr() (addr string) {
 	return c.TcpConn.RemoteAddr().String()
 }
@@ -125,6 +135,11 @@ func (c *Conn) GetLocalAddr() (addr string) {
 	return c.TcpConn.LocalAddr().String()
 }
 
+func (c *Conn) Read(p []byte) (n int, err error) {
+	n, err = c.Reader.Read(p)
+	return
+}
+
 func (c *Conn) ReadLine() (buff string, err error) {
 	buff, err = c.Reader.ReadString('\n')
 	if err != nil {
@@ -138,10 +153,14 @@ func (c *Conn) ReadLine() (buff string, err error) {
 	return buff, err
 }
 
+func (c *Conn) WriteBytes(content []byte) (n int, err error) {
+	n, err = c.TcpConn.Write(content)
+	return
+}
+
 func (c *Conn) Write(content string) (err error) {
 	_, err = c.TcpConn.Write([]byte(content))
 	return err
-
 }
 
 func (c *Conn) SetDeadline(t time.Time) error {

+ 122 - 1
src/frp/utils/vhost/http.go

@@ -16,8 +16,12 @@ package vhost
 
 import (
 	"bufio"
+	"bytes"
+	"fmt"
+	"io"
 	"net"
 	"net/http"
+	"net/url"
 	"strings"
 	"time"
 
@@ -42,6 +46,123 @@ func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
 }
 
 func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
-	mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
+	mux, err := NewVhostMuxer(listener, GetHttpHostname, HttpHostNameRewrite, timeout)
 	return &HttpMuxer{mux}, err
 }
+
+func HttpHostNameRewrite(c *conn.Conn, rewriteHost string) (_ net.Conn, err error) {
+	sc, rd := newShareConn(c.TcpConn)
+	var buff []byte
+	if buff, err = hostNameRewrite(rd, rewriteHost); err != nil {
+		return sc, err
+	}
+	err = sc.WriteBuff(buff)
+	return sc, err
+}
+
+func hostNameRewrite(request io.Reader, rewriteHost string) (_ []byte, err error) {
+	buffer := make([]byte, 1024)
+	request.Read(buffer)
+	retBuffer, err := parseRequest(buffer, rewriteHost)
+	return retBuffer, err
+}
+
+func parseRequest(org []byte, rewriteHost string) (ret []byte, err error) {
+	tp := bytes.NewBuffer(org)
+	// First line: GET /index.html HTTP/1.0
+	var b []byte
+	if b, err = tp.ReadBytes('\n'); err != nil {
+		return nil, err
+	}
+	req := new(http.Request)
+	// we invoked ReadRequest in GetHttpHostname before, so we ignore error
+	req.Method, req.RequestURI, req.Proto, _ = parseRequestLine(string(b))
+	rawurl := req.RequestURI
+	// CONNECT www.google.com:443 HTTP/1.1
+	justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
+	if justAuthority {
+		rawurl = "http://" + rawurl
+	}
+	req.URL, _ = url.ParseRequestURI(rawurl)
+	if justAuthority {
+		// Strip the bogus "http://" back off.
+		req.URL.Scheme = ""
+	}
+
+	//  RFC2616: first case
+	//  GET /index.html HTTP/1.1
+	//  Host: www.google.com
+	if req.URL.Host == "" {
+		changedBuf, err := changeHostName(tp, rewriteHost)
+		buf := new(bytes.Buffer)
+		buf.Write(b)
+		buf.Write(changedBuf)
+		return buf.Bytes(), err
+	}
+
+	// RFC2616: second case
+	// GET http://www.google.com/index.html HTTP/1.1
+	// Host: doesntmatter
+	// In this case, any Host line is ignored.
+	hostPort := strings.Split(req.URL.Host, ":")
+	if len(hostPort) == 1 {
+		req.URL.Host = rewriteHost
+	} else if len(hostPort) == 2 {
+		req.URL.Host = fmt.Sprintf("%s:%s", rewriteHost, hostPort[1])
+	}
+	firstLine := req.Method + " " + req.URL.String() + " " + req.Proto
+	buf := new(bytes.Buffer)
+	buf.WriteString(firstLine)
+	tp.WriteTo(buf)
+	return buf.Bytes(), err
+
+}
+
+// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
+func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
+	s1 := strings.Index(line, " ")
+	s2 := strings.Index(line[s1+1:], " ")
+	if s1 < 0 || s2 < 0 {
+		return
+	}
+	s2 += s1 + 1
+	return line[:s1], line[s1+1 : s2], line[s2+1:], true
+}
+
+func changeHostName(buff *bytes.Buffer, rewriteHost string) (_ []byte, err error) {
+	retBuf := new(bytes.Buffer)
+
+	peek := buff.Bytes()
+	for len(peek) > 0 {
+		i := bytes.IndexByte(peek, '\n')
+		if i < 3 {
+			// Not present (-1) or found within the next few bytes,
+			// implying we're at the end ("\r\n\r\n" or "\n\n")
+			return nil, err
+		}
+		kv := peek[:i]
+		j := bytes.IndexByte(kv, ':')
+		if j < 0 {
+			return nil, fmt.Errorf("malformed MIME header line: " + string(kv))
+		}
+		if strings.Contains(strings.ToLower(string(kv[:j])), "host") {
+			var hostHeader string
+			portPos := bytes.IndexByte(kv[j+1:], ':')
+			if portPos == -1 {
+				hostHeader = fmt.Sprintf("Host: %s\n", rewriteHost)
+			} else {
+				hostHeader = fmt.Sprintf("Host: %s:%s\n", rewriteHost, kv[portPos+1:])
+			}
+			retBuf.WriteString(hostHeader)
+			peek = peek[i+1:]
+			break
+		} else {
+			retBuf.Write(peek[:i])
+			retBuf.WriteByte('\n')
+		}
+
+		peek = peek[i+1:]
+	}
+	retBuf.Write(peek)
+	return retBuf.Bytes(), err
+}

+ 1 - 1
src/frp/utils/vhost/https.go

@@ -47,7 +47,7 @@ type HttpsMuxer struct {
 }
 
 func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
-	mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
+	mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, timeout)
 	return &HttpsMuxer{mux}, err
 }
 

+ 34 - 10
src/frp/utils/vhost/vhost.go

@@ -27,37 +27,42 @@ import (
 )
 
 type muxFunc func(*conn.Conn) (net.Conn, string, error)
+type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
 
 type VhostMuxer struct {
 	listener    *conn.Listener
 	timeout     time.Duration
 	vhostFunc   muxFunc
+	rewriteFunc hostRewriteFunc
 	registryMap map[string]*Listener
 	mutex       sync.RWMutex
 }
 
-func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
+func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
 	mux = &VhostMuxer{
 		listener:    listener,
 		timeout:     timeout,
 		vhostFunc:   vhostFunc,
+		rewriteFunc: rewriteFunc,
 		registryMap: make(map[string]*Listener),
 	}
 	go mux.run()
 	return mux, nil
 }
 
-func (v *VhostMuxer) Listen(name string) (l *Listener, err error) {
+// listen for a new domain name, if rewriteHost is not empty  and rewriteFunc is not nil, then rewrite the host header to rewriteHost
+func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
 	v.mutex.Lock()
 	defer v.mutex.Unlock()
 	if _, exist := v.registryMap[name]; exist {
-		return nil, fmt.Errorf("name %s is already bound", name)
+		return nil, fmt.Errorf("domain name %s is already bound", name)
 	}
 
 	l = &Listener{
-		name:   name,
-		mux:    v,
-		accept: make(chan *conn.Conn),
+		name:        name,
+		rewriteHost: rewriteHost,
+		mux:         v,
+		accept:      make(chan *conn.Conn),
 	}
 	v.registryMap[name] = l
 	return l, nil
@@ -105,15 +110,16 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
 	if err = sConn.SetDeadline(time.Time{}); err != nil {
 		return
 	}
-	c.TcpConn = sConn
+	c.SetTcpConn(sConn)
 
 	l.accept <- c
 }
 
 type Listener struct {
-	name   string
-	mux    *VhostMuxer // for closing VhostMuxer
-	accept chan *conn.Conn
+	name        string
+	rewriteHost string
+	mux         *VhostMuxer // for closing VhostMuxer
+	accept      chan *conn.Conn
 }
 
 func (l *Listener) Accept() (*conn.Conn, error) {
@@ -121,6 +127,17 @@ func (l *Listener) Accept() (*conn.Conn, error) {
 	if !ok {
 		return nil, fmt.Errorf("Listener closed")
 	}
+
+	// if rewriteFunc is exist and rewriteHost is set
+	// rewrite http requests with a modified host header
+	if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
+		fmt.Printf("host rewrite: %s\n", l.rewriteHost)
+		sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
+		if err != nil {
+			return nil, fmt.Errorf("http host header rewrite failed")
+		}
+		conn.SetTcpConn(sConn)
+	}
 	return conn, nil
 }
 
@@ -140,6 +157,7 @@ type sharedConn struct {
 	buff *bytes.Buffer
 }
 
+// the bytes you read in io.Reader, will be reserved in sharedConn
 func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
 	sc := &sharedConn{
 		Conn: conn,
@@ -166,3 +184,9 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) {
 	sc.Unlock()
 	return
 }
+
+func (sc *sharedConn) WriteBuff(buffer []byte) (err error) {
+	sc.buff.Reset()
+	_, err = sc.buff.Write(buffer)
+	return err
+}