|
@@ -16,6 +16,7 @@ package conn
|
|
|
|
|
|
import (
|
|
|
"bufio"
|
|
|
+ "bytes"
|
|
|
"encoding/base64"
|
|
|
"fmt"
|
|
|
"io"
|
|
@@ -25,6 +26,8 @@ import (
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
+
|
|
|
+ "github.com/fatedier/frp/src/utils/pool"
|
|
|
)
|
|
|
|
|
|
type Listener struct {
|
|
@@ -61,11 +64,7 @@ func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- c := &Conn{
|
|
|
- TcpConn: conn,
|
|
|
- closeFlag: false,
|
|
|
- }
|
|
|
- c.Reader = bufio.NewReader(c.TcpConn)
|
|
|
+ c := NewConn(conn)
|
|
|
l.accept <- c
|
|
|
}
|
|
|
}()
|
|
@@ -95,20 +94,23 @@ func (l *Listener) Close() error {
|
|
|
type Conn struct {
|
|
|
TcpConn net.Conn
|
|
|
Reader *bufio.Reader
|
|
|
+ buffer *bytes.Buffer
|
|
|
closeFlag bool
|
|
|
- mutex sync.RWMutex
|
|
|
+
|
|
|
+ mutex sync.RWMutex
|
|
|
}
|
|
|
|
|
|
func NewConn(conn net.Conn) (c *Conn) {
|
|
|
- c = &Conn{}
|
|
|
- c.TcpConn = conn
|
|
|
+ c = &Conn{
|
|
|
+ TcpConn: conn,
|
|
|
+ buffer: nil,
|
|
|
+ closeFlag: false,
|
|
|
+ }
|
|
|
c.Reader = bufio.NewReader(c.TcpConn)
|
|
|
- c.closeFlag = false
|
|
|
- return c
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
func ConnectServer(addr string) (c *Conn, err error) {
|
|
|
- c = &Conn{}
|
|
|
servertAddr, err := net.ResolveTCPAddr("tcp", addr)
|
|
|
if err != nil {
|
|
|
return
|
|
@@ -117,9 +119,7 @@ func ConnectServer(addr string) (c *Conn, err error) {
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
- c.TcpConn = conn
|
|
|
- c.Reader = bufio.NewReader(c.TcpConn)
|
|
|
- c.closeFlag = false
|
|
|
+ c = NewConn(conn)
|
|
|
return c, nil
|
|
|
}
|
|
|
|
|
@@ -185,7 +185,23 @@ func (c *Conn) GetLocalAddr() (addr string) {
|
|
|
}
|
|
|
|
|
|
func (c *Conn) Read(p []byte) (n int, err error) {
|
|
|
- n, err = c.Reader.Read(p)
|
|
|
+ c.mutex.RLock()
|
|
|
+ if c.buffer == nil {
|
|
|
+ c.mutex.RUnlock()
|
|
|
+ return c.Reader.Read(p)
|
|
|
+ }
|
|
|
+ c.mutex.RUnlock()
|
|
|
+
|
|
|
+ n, err = c.buffer.Read(p)
|
|
|
+ if err == io.EOF {
|
|
|
+ c.mutex.Lock()
|
|
|
+ c.buffer = nil
|
|
|
+ c.mutex.Unlock()
|
|
|
+ var n2 int
|
|
|
+ n2, err = c.Reader.Read(p[n:])
|
|
|
+
|
|
|
+ n += n2
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
|
|
@@ -212,6 +228,16 @@ func (c *Conn) WriteString(content string) (err error) {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
+func (c *Conn) AppendReaderBuffer(content []byte) {
|
|
|
+ c.mutex.Lock()
|
|
|
+ defer c.mutex.Unlock()
|
|
|
+
|
|
|
+ if c.buffer == nil {
|
|
|
+ c.buffer = bytes.NewBuffer(make([]byte, 0, 2048))
|
|
|
+ }
|
|
|
+ c.buffer.Write(content)
|
|
|
+}
|
|
|
+
|
|
|
func (c *Conn) SetDeadline(t time.Time) error {
|
|
|
return c.TcpConn.SetDeadline(t)
|
|
|
}
|
|
@@ -238,22 +264,36 @@ func (c *Conn) IsClosed() (closeFlag bool) {
|
|
|
}
|
|
|
|
|
|
// when you call this function, you should make sure that
|
|
|
-// remote client won't send any bytes to this socket
|
|
|
+// no bytes were read before
|
|
|
func (c *Conn) CheckClosed() bool {
|
|
|
c.mutex.RLock()
|
|
|
if c.closeFlag {
|
|
|
+ c.mutex.RUnlock()
|
|
|
return true
|
|
|
}
|
|
|
c.mutex.RUnlock()
|
|
|
|
|
|
+ tmp := pool.GetBuf(2048)
|
|
|
+ defer pool.PutBuf(tmp)
|
|
|
err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
|
|
if err != nil {
|
|
|
c.Close()
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
- var tmp []byte = make([]byte, 1)
|
|
|
- _, err = c.TcpConn.Read(tmp)
|
|
|
+ n, err := c.TcpConn.Read(tmp)
|
|
|
+ if err == io.EOF {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ var tmp2 []byte = make([]byte, 1)
|
|
|
+ err = c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
|
|
+ if err != nil {
|
|
|
+ c.Close()
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ n2, err := c.TcpConn.Read(tmp2)
|
|
|
if err == io.EOF {
|
|
|
return true
|
|
|
}
|
|
@@ -263,5 +303,12 @@ func (c *Conn) CheckClosed() bool {
|
|
|
c.Close()
|
|
|
return true
|
|
|
}
|
|
|
+
|
|
|
+ if n > 0 {
|
|
|
+ c.AppendReaderBuffer(tmp[:n])
|
|
|
+ }
|
|
|
+ if n2 > 0 {
|
|
|
+ c.AppendReaderBuffer(tmp2[:n2])
|
|
|
+ }
|
|
|
return false
|
|
|
}
|