소스 검색

connection pool: ssh can't work when pool_count is set, fix #193

fatedier 8 년 전
부모
커밋
9faf4acd62
3개의 변경된 파일69개의 추가작업 그리고 19개의 파일을 삭제
  1. 1 0
      src/models/server/server.go
  2. 65 18
      src/utils/conn/conn.go
  3. 3 1
      src/utils/vhost/vhost.go

+ 1 - 0
src/models/server/server.go

@@ -384,6 +384,7 @@ func (p *ProxyServer) getWorkConn() (workConn *conn.Conn, err error) {
 				err = fmt.Errorf("ProxyName [%s], no work connections available, control is closing", p.Name)
 				return
 			}
+			log.Debug("ProxyName [%s], get work connection from pool", p.Name)
 		default:
 			// no work connections available in the poll, send message to frpc to get more
 			p.ctlMsgChan <- 1

+ 65 - 18
src/utils/conn/conn.go

@@ -16,6 +16,7 @@ package conn
 
 import (
 	"bufio"
+	"bytes"
 	"encoding/base64"
 	"fmt"
 	"io"
@@ -25,6 +26,8 @@ import (
 	"strings"
 	"sync"
 	"time"
+
+	"github.com/fatedier/frp/src/utils/pool"
 )
 
 type Listener struct {
@@ -61,11 +64,7 @@ func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
 				continue
 			}
 
-			c := &Conn{
-				TcpConn:   conn,
-				closeFlag: false,
-			}
-			c.Reader = bufio.NewReader(c.TcpConn)
+			c := NewConn(conn)
 			l.accept <- c
 		}
 	}()
@@ -95,20 +94,23 @@ func (l *Listener) Close() error {
 type Conn struct {
 	TcpConn   net.Conn
 	Reader    *bufio.Reader
+	buffer    *bytes.Buffer
 	closeFlag bool
-	mutex     sync.RWMutex
+
+	mutex sync.RWMutex
 }
 
 func NewConn(conn net.Conn) (c *Conn) {
-	c = &Conn{}
-	c.TcpConn = conn
+	c = &Conn{
+		TcpConn:   conn,
+		buffer:    nil,
+		closeFlag: false,
+	}
 	c.Reader = bufio.NewReader(c.TcpConn)
-	c.closeFlag = false
-	return c
+	return
 }
 
 func ConnectServer(addr string) (c *Conn, err error) {
-	c = &Conn{}
 	servertAddr, err := net.ResolveTCPAddr("tcp", addr)
 	if err != nil {
 		return
@@ -117,9 +119,7 @@ func ConnectServer(addr string) (c *Conn, err error) {
 	if err != nil {
 		return
 	}
-	c.TcpConn = conn
-	c.Reader = bufio.NewReader(c.TcpConn)
-	c.closeFlag = false
+	c = NewConn(conn)
 	return c, nil
 }
 
@@ -185,7 +185,23 @@ func (c *Conn) GetLocalAddr() (addr string) {
 }
 
 func (c *Conn) Read(p []byte) (n int, err error) {
-	n, err = c.Reader.Read(p)
+	c.mutex.RLock()
+	if c.buffer == nil {
+		c.mutex.RUnlock()
+		return c.Reader.Read(p)
+	}
+	c.mutex.RUnlock()
+
+	n, err = c.buffer.Read(p)
+	if err == io.EOF {
+		c.mutex.Lock()
+		c.buffer = nil
+		c.mutex.Unlock()
+		var n2 int
+		n2, err = c.Reader.Read(p[n:])
+
+		n += n2
+	}
 	return
 }
 
@@ -212,6 +228,16 @@ func (c *Conn) WriteString(content string) (err error) {
 	return err
 }
 
+func (c *Conn) AppendReaderBuffer(content []byte) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	if c.buffer == nil {
+		c.buffer = bytes.NewBuffer(make([]byte, 0, 2048))
+	}
+	c.buffer.Write(content)
+}
+
 func (c *Conn) SetDeadline(t time.Time) error {
 	return c.TcpConn.SetDeadline(t)
 }
@@ -238,22 +264,36 @@ func (c *Conn) IsClosed() (closeFlag bool) {
 }
 
 // when you call this function, you should make sure that
-// remote client won't send any bytes to this socket
+// no bytes were read before
 func (c *Conn) CheckClosed() bool {
 	c.mutex.RLock()
 	if c.closeFlag {
+		c.mutex.RUnlock()
 		return true
 	}
 	c.mutex.RUnlock()
 
+	tmp := pool.GetBuf(2048)
+	defer pool.PutBuf(tmp)
 	err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond))
 	if err != nil {
 		c.Close()
 		return true
 	}
 
-	var tmp []byte = make([]byte, 1)
-	_, err = c.TcpConn.Read(tmp)
+	n, err := c.TcpConn.Read(tmp)
+	if err == io.EOF {
+		return true
+	}
+
+	var tmp2 []byte = make([]byte, 1)
+	err = c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond))
+	if err != nil {
+		c.Close()
+		return true
+	}
+
+	n2, err := c.TcpConn.Read(tmp2)
 	if err == io.EOF {
 		return true
 	}
@@ -263,5 +303,12 @@ func (c *Conn) CheckClosed() bool {
 		c.Close()
 		return true
 	}
+
+	if n > 0 {
+		c.AppendReaderBuffer(tmp[:n])
+	}
+	if n2 > 0 {
+		c.AppendReaderBuffer(tmp2[:n2])
+	}
 	return false
 }

+ 3 - 1
src/utils/vhost/vhost.go

@@ -205,16 +205,18 @@ func (sc *sharedConn) Read(p []byte) (n int, err error) {
 		sc.Unlock()
 		return sc.Conn.Read(p)
 	}
+	sc.Unlock()
 	n, err = sc.buff.Read(p)
 
 	if err == io.EOF {
+		sc.Lock()
 		sc.buff = nil
+		sc.Unlock()
 		var n2 int
 		n2, err = sc.Conn.Read(p[n:])
 
 		n += n2
 	}
-	sc.Unlock()
 	return
 }