Browse Source

frps: vhost_http_port and vhost_https_port can be same with frps bind
port

fatedier 6 years ago
parent
commit
5db605ca02
7 changed files with 423 additions and 25 deletions
  1. 1 0
      conf/frps_full.ini
  2. 45 11
      server/service.go
  3. 9 9
      utils/net/conn.go
  4. 210 0
      utils/net/mux/mux.go
  5. 95 0
      utils/net/mux/mux_test.go
  6. 55 0
      utils/net/mux/rule.go
  7. 8 5
      utils/vhost/https.go

+ 1 - 0
conf/frps_full.ini

@@ -16,6 +16,7 @@ kcp_bind_port = 7000
 # proxy_bind_addr = 127.0.0.1
 
 # if you want to support virtual host, you must set the http port for listening (optional)
+# Note: http port and https port can be same with bind_port
 vhost_http_port = 80
 vhost_https_port = 443
 

+ 45 - 11
server/service.go

@@ -26,6 +26,7 @@ import (
 	"github.com/fatedier/frp/models/msg"
 	"github.com/fatedier/frp/utils/log"
 	frpNet "github.com/fatedier/frp/utils/net"
+	"github.com/fatedier/frp/utils/net/mux"
 	"github.com/fatedier/frp/utils/util"
 	"github.com/fatedier/frp/utils/version"
 	"github.com/fatedier/frp/utils/vhost"
@@ -41,6 +42,9 @@ var ServerService *Service
 
 // Server service.
 type Service struct {
+	// Dispatch connections to different handlers listen on same port.
+	muxer *mux.Mux
+
 	// Accept connections from client.
 	listener frpNet.Listener
 
@@ -88,12 +92,33 @@ func NewService() (svr *Service, err error) {
 		return
 	}
 
+	var (
+		httpMuxOn  bool
+		httpsMuxOn bool
+	)
+	if cfg.BindAddr == cfg.ProxyBindAddr {
+		if cfg.BindPort == cfg.VhostHttpPort {
+			httpMuxOn = true
+		}
+		if cfg.BindPort == cfg.VhostHttpsPort {
+			httpsMuxOn = true
+		}
+		if httpMuxOn || httpsMuxOn {
+			svr.muxer = mux.NewMux()
+		}
+	}
+
 	// Listen for accepting connections from client.
-	svr.listener, err = frpNet.ListenTcp(cfg.BindAddr, cfg.BindPort)
+	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.BindPort))
 	if err != nil {
 		err = fmt.Errorf("Create server listener error, %v", err)
 		return
 	}
+	if svr.muxer != nil {
+		go svr.muxer.Serve(ln)
+		ln = svr.muxer.DefaultListener()
+	}
+	svr.listener = frpNet.WrapLogListener(ln)
 	log.Info("frps tcp listen on %s:%d", cfg.BindAddr, cfg.BindPort)
 
 	// Listen for accepting connections from client using kcp protocol.
@@ -117,10 +142,14 @@ func NewService() (svr *Service, err error) {
 			Handler: rp,
 		}
 		var l net.Listener
-		l, err = net.Listen("tcp", address)
-		if err != nil {
-			err = fmt.Errorf("Create vhost http listener error, %v", err)
-			return
+		if httpMuxOn {
+			l = svr.muxer.ListenHttp(0)
+		} else {
+			l, err = net.Listen("tcp", address)
+			if err != nil {
+				err = fmt.Errorf("Create vhost http listener error, %v", err)
+				return
+			}
 		}
 		go server.Serve(l)
 		log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpPort)
@@ -128,13 +157,18 @@ func NewService() (svr *Service, err error) {
 
 	// Create https vhost muxer.
 	if cfg.VhostHttpsPort > 0 {
-		var l frpNet.Listener
-		l, err = frpNet.ListenTcp(cfg.ProxyBindAddr, cfg.VhostHttpsPort)
-		if err != nil {
-			err = fmt.Errorf("Create vhost https listener error, %v", err)
-			return
+		var l net.Listener
+		if httpsMuxOn {
+			l = svr.muxer.ListenHttps(0)
+		} else {
+			l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort))
+			if err != nil {
+				err = fmt.Errorf("Create server listener error, %v", err)
+				return
+			}
 		}
-		svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(l, 30*time.Second)
+
+		svr.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(frpNet.WrapLogListener(l), 30*time.Second)
 		if err != nil {
 			err = fmt.Errorf("Create vhost httpsMuxer error, %v", err)
 			return

+ 9 - 9
utils/net/conn.go

@@ -20,7 +20,6 @@ import (
 	"fmt"
 	"io"
 	"net"
-	"sync"
 	"sync/atomic"
 	"time"
 
@@ -136,7 +135,6 @@ func ConnectServerByProxy(proxyUrl string, protocol string, addr string) (c Conn
 
 type SharedConn struct {
 	Conn
-	sync.Mutex
 	buf *bytes.Buffer
 }
 
@@ -149,22 +147,24 @@ func NewShareConn(conn Conn) (*SharedConn, io.Reader) {
 	return sc, io.TeeReader(conn, sc.buf)
 }
 
+func NewShareConnSize(conn Conn, bufSize int) (*SharedConn, io.Reader) {
+	sc := &SharedConn{
+		Conn: conn,
+		buf:  bytes.NewBuffer(make([]byte, 0, bufSize)),
+	}
+	return sc, io.TeeReader(conn, sc.buf)
+}
+
+// Not thread safety.
 func (sc *SharedConn) Read(p []byte) (n int, err error) {
-	sc.Lock()
 	if sc.buf == nil {
-		sc.Unlock()
 		return sc.Conn.Read(p)
 	}
-	sc.Unlock()
 	n, err = sc.buf.Read(p)
-
 	if err == io.EOF {
-		sc.Lock()
 		sc.buf = nil
-		sc.Unlock()
 		var n2 int
 		n2, err = sc.Conn.Read(p[n:])
-
 		n += n2
 	}
 	return

+ 210 - 0
utils/net/mux/mux.go

@@ -0,0 +1,210 @@
+package mux
+
+import (
+	"fmt"
+	"io"
+	"net"
+	"sort"
+	"sync"
+	"time"
+
+	"github.com/fatedier/frp/utils/errors"
+	frpNet "github.com/fatedier/frp/utils/net"
+)
+
+const (
+	// DefaultTimeout is the default length of time to wait for bytes we need.
+	DefaultTimeout = 10 * time.Second
+)
+
+type Mux struct {
+	ln net.Listener
+
+	defaultLn       *listener
+	lns             []*listener
+	maxNeedBytesNum uint32
+	mu              sync.RWMutex
+}
+
+func NewMux() (mux *Mux) {
+	mux = &Mux{
+		lns: make([]*listener, 0),
+	}
+	return
+}
+
+func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
+	ln := &listener{
+		c:            make(chan net.Conn),
+		mux:          mux,
+		needBytesNum: needBytesNum,
+		matchFn:      fn,
+	}
+
+	mux.mu.Lock()
+	defer mux.mu.Unlock()
+	if needBytesNum > mux.maxNeedBytesNum {
+		mux.maxNeedBytesNum = needBytesNum
+	}
+
+	newlns := append(mux.copyLns(), ln)
+	sort.Slice(newlns, func(i, j int) bool {
+		return newlns[i].needBytesNum < newlns[j].needBytesNum
+	})
+	mux.lns = newlns
+	return ln
+}
+
+func (mux *Mux) ListenHttp(priority int) net.Listener {
+	return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
+}
+
+func (mux *Mux) ListenHttps(priority int) net.Listener {
+	return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
+}
+
+func (mux *Mux) DefaultListener() net.Listener {
+	mux.mu.Lock()
+	defer mux.mu.Unlock()
+	if mux.defaultLn == nil {
+		mux.defaultLn = &listener{
+			c:   make(chan net.Conn),
+			mux: mux,
+		}
+	}
+	return mux.defaultLn
+}
+
+func (mux *Mux) release(ln *listener) bool {
+	result := false
+	mux.mu.Lock()
+	defer mux.mu.Unlock()
+	lns := mux.copyLns()
+
+	for i, l := range lns {
+		if l == ln {
+			lns = append(lns[:i], lns[i+1:]...)
+			result = true
+		}
+	}
+	mux.lns = lns
+	return result
+}
+
+func (mux *Mux) copyLns() []*listener {
+	lns := make([]*listener, 0, len(mux.lns))
+	for _, l := range mux.lns {
+		lns = append(lns, l)
+	}
+	return lns
+}
+
+// Serve handles connections from ln and multiplexes then across registered listeners.
+func (mux *Mux) Serve(ln net.Listener) error {
+	mux.mu.Lock()
+	mux.ln = ln
+	mux.mu.Unlock()
+	for {
+		// Wait for the next connection.
+		// If it returns a temporary error then simply retry.
+		// If it returns any other error then exit immediately.
+		conn, err := ln.Accept()
+		if err, ok := err.(interface {
+			Temporary() bool
+		}); ok && err.Temporary() {
+			continue
+		}
+
+		if err != nil {
+			return err
+		}
+
+		go mux.handleConn(conn)
+	}
+}
+
+func (mux *Mux) handleConn(conn net.Conn) {
+	mux.mu.RLock()
+	maxNeedBytesNum := mux.maxNeedBytesNum
+	lns := mux.lns
+	defaultLn := mux.defaultLn
+	mux.mu.RUnlock()
+
+	shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum))
+	data := make([]byte, maxNeedBytesNum)
+
+	conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
+	_, err := io.ReadFull(rd, data)
+	if err != nil {
+		conn.Close()
+		return
+	}
+	conn.SetReadDeadline(time.Time{})
+
+	for _, ln := range lns {
+		if match := ln.matchFn(data); match {
+			err = errors.PanicToError(func() {
+				ln.c <- shareConn
+			})
+			if err != nil {
+				conn.Close()
+			}
+			return
+		}
+	}
+
+	// No match listeners
+	if defaultLn != nil {
+		err = errors.PanicToError(func() {
+			defaultLn.c <- shareConn
+		})
+		if err != nil {
+			conn.Close()
+		}
+		return
+	}
+
+	// No listeners for this connection, close it.
+	conn.Close()
+	return
+}
+
+type listener struct {
+	mux *Mux
+
+	needBytesNum uint32
+	matchFn      MatchFunc
+
+	c  chan net.Conn
+	mu sync.RWMutex
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (ln *listener) Accept() (net.Conn, error) {
+	conn, ok := <-ln.c
+	if !ok {
+		return nil, fmt.Errorf("network connection closed")
+	}
+	return conn, nil
+}
+
+// Close removes this listener from the parent mux and closes the channel.
+func (ln *listener) Close() error {
+	if ok := ln.mux.release(ln); ok {
+		// Close done to signal to any RLock holders to release their lock.
+		close(ln.c)
+	}
+	return nil
+}
+
+func (ln *listener) Addr() net.Addr {
+	if ln.mux == nil {
+		return nil
+	}
+	ln.mux.mu.RLock()
+	defer ln.mux.mu.RUnlock()
+	if ln.mux.ln == nil {
+		return nil
+	}
+	return ln.mux.ln.Addr()
+}

+ 95 - 0
utils/net/mux/mux_test.go

@@ -0,0 +1,95 @@
+package mux
+
+import (
+	"bufio"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func runHttpSvr(ln net.Listener) *httptest.Server {
+	svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte("http service"))
+	}))
+	svr.Listener = ln
+	svr.Start()
+	return svr
+}
+
+func runHttpsSvr(ln net.Listener) *httptest.Server {
+	svr := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte("https service"))
+	}))
+	svr.Listener = ln
+	svr.StartTLS()
+	return svr
+}
+
+func runEchoSvr(ln net.Listener) {
+	go func() {
+		for {
+			conn, err := ln.Accept()
+			if err != nil {
+				return
+			}
+			rd := bufio.NewReader(conn)
+			data, err := rd.ReadString('\n')
+			if err != nil {
+				return
+			}
+			conn.Write([]byte(data))
+			conn.Close()
+		}
+	}()
+}
+
+func TestMux(t *testing.T) {
+	assert := assert.New(t)
+
+	ln, err := net.Listen("tcp", "127.0.0.1:")
+	assert.NoError(err)
+
+	mux := NewMux()
+	httpLn := mux.ListenHttp(0)
+	httpsLn := mux.ListenHttps(0)
+	defaultLn := mux.DefaultListener()
+	go mux.Serve(ln)
+	time.Sleep(100 * time.Millisecond)
+
+	httpSvr := runHttpSvr(httpLn)
+	defer httpSvr.Close()
+	httpsSvr := runHttpsSvr(httpsLn)
+	defer httpsSvr.Close()
+	runEchoSvr(defaultLn)
+	defer ln.Close()
+
+	// test http service
+	resp, err := http.Get(httpSvr.URL)
+	assert.NoError(err)
+	data, err := ioutil.ReadAll(resp.Body)
+	assert.NoError(err)
+	assert.Equal("http service", string(data))
+
+	// test https service
+	client := httpsSvr.Client()
+	resp, err = client.Get(httpsSvr.URL)
+	assert.NoError(err)
+	data, err = ioutil.ReadAll(resp.Body)
+	assert.NoError(err)
+	assert.Equal("https service", string(data))
+
+	// test echo service
+	conn, err := net.Dial("tcp", ln.Addr().String())
+	assert.NoError(err)
+	_, err = conn.Write([]byte("test echo\n"))
+	assert.NoError(err)
+	data = make([]byte, 1024)
+	n, err := conn.Read(data)
+	assert.NoError(err)
+	assert.Equal("test echo\n", string(data[:n]))
+}

+ 55 - 0
utils/net/mux/rule.go

@@ -0,0 +1,55 @@
+package mux
+
+type MatchFunc func(data []byte) (match bool)
+
+var (
+	HttpsNeedBytesNum uint32 = 1
+	HttpNeedBytesNum  uint32 = 3
+	YamuxNeedBytesNum uint32 = 2
+)
+
+var HttpsMatchFunc MatchFunc = func(data []byte) bool {
+	if len(data) < int(HttpsNeedBytesNum) {
+		return false
+	}
+
+	if data[0] == 0x16 {
+		return true
+	} else {
+		return false
+	}
+}
+
+// From https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
+var httpHeadBytes = map[string]struct{}{
+	"GET": struct{}{},
+	"HEA": struct{}{},
+	"POS": struct{}{},
+	"PUT": struct{}{},
+	"DEL": struct{}{},
+	"CON": struct{}{},
+	"OPT": struct{}{},
+	"TRA": struct{}{},
+	"PAT": struct{}{},
+}
+
+var HttpMatchFunc MatchFunc = func(data []byte) bool {
+	if len(data) < int(HttpNeedBytesNum) {
+		return false
+	}
+
+	_, ok := httpHeadBytes[string(data[:3])]
+	return ok
+}
+
+// From https://github.com/hashicorp/yamux/blob/master/spec.md
+var YamuxMatchFunc MatchFunc = func(data []byte) bool {
+	if len(data) < int(YamuxNeedBytesNum) {
+		return false
+	}
+
+	if data[0] == 0 && data[1] >= 0x0 && data[1] <= 0x3 {
+		return true
+	}
+	return false
+}

+ 8 - 5
utils/vhost/https.go

@@ -55,14 +55,17 @@ func readHandshake(rd io.Reader) (host string, err error) {
 	data := pool.GetBuf(1024)
 	origin := data
 	defer pool.PutBuf(origin)
-	length, err := rd.Read(data)
+
+	_, err = io.ReadFull(rd, data[:47])
+	if err != nil {
+		return
+	}
+
+	length, err := rd.Read(data[47:])
 	if err != nil {
 		return
 	} else {
-		if length < 47 {
-			err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
-			return
-		}
+		length += 47
 	}
 	data = data[:length]
 	if uint8(data[5]) != typeClientHello {