Преглед изворни кода

Merge pull request #50 from fatedier/fatedier/fix_package_loss

frp/models/msg: fix a bug if local service write to socket immediatel…
fatedier пре 8 година
родитељ
комит
4067591a4d

+ 9 - 13
src/frp/models/msg/process.go

@@ -15,12 +15,10 @@
 package msg
 
 import (
-	"bufio"
 	"bytes"
 	"encoding/binary"
 	"fmt"
 	"io"
-	"net"
 	"sync"
 
 	"frp/models/config"
@@ -61,7 +59,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
 		defer wait.Done()
 
 		// we don't care about errors here
-		pipeEncrypt(from.TcpConn, to.TcpConn, conf, needRecord)
+		pipeEncrypt(from, to, conf, needRecord)
 	}
 
 	decryptPipe := func(to *conn.Conn, from *conn.Conn) {
@@ -70,7 +68,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
 		defer wait.Done()
 
 		// we don't care about errors here
-		pipeDecrypt(to.TcpConn, from.TcpConn, conf, needRecord)
+		pipeDecrypt(to, from, conf, needRecord)
 	}
 
 	wait.Add(2)
@@ -106,7 +104,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) {
 }
 
 // decrypt msg from reader, then write into writer
-func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
+func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
 	laes := new(pcrypto.Pcrypto)
 	key := conf.AuthToken
 	if conf.PrivilegeMode {
@@ -119,7 +117,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 
 	buf := make([]byte, 5*1024+4)
 	var left, res []byte
-	var cnt int
+	var cnt int = -1
 
 	// record
 	var flowBytes int64 = 0
@@ -129,13 +127,12 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 		}()
 	}
 
-	nreader := bufio.NewReader(r)
 	for {
 		// there may be more than 1 package in variable
 		// and we read more bytes if unpkgMsg returns an error
 		var newBuf []byte
 		if cnt < 0 {
-			n, err := nreader.Read(buf)
+			n, err := r.Read(buf)
 			if err != nil {
 				return err
 			}
@@ -165,7 +162,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 			}
 		}
 
-		_, err = w.Write(res)
+		_, err = w.WriteBytes(res)
 		if err != nil {
 			return err
 		}
@@ -182,7 +179,7 @@ func pipeDecrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 }
 
 // recvive msg from reader, then encrypt msg into writer
-func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool) (err error) {
+func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
 	laes := new(pcrypto.Pcrypto)
 	key := conf.AuthToken
 	if conf.PrivilegeMode {
@@ -201,10 +198,9 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 		}()
 	}
 
-	nreader := bufio.NewReader(r)
 	buf := make([]byte, 5*1024)
 	for {
-		n, err := nreader.Read(buf)
+		n, err := r.Read(buf)
 		if err != nil {
 			return err
 		}
@@ -235,7 +231,7 @@ func pipeEncrypt(r net.Conn, w net.Conn, conf config.BaseConf, needRecord bool)
 		}
 
 		res = pkgMsg(res)
-		_, err = w.Write(res)
+		_, err = w.WriteBytes(res)
 		if err != nil {
 			return err
 		}

+ 2 - 3
src/frp/models/server/server.go

@@ -154,13 +154,12 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 				}
 
 				// start another goroutine for join two conns from frpc and user
-				go func() {
+				go func(userConn *conn.Conn) {
 					workConn, err := p.getWorkConn()
 					if err != nil {
 						return
 					}
 
-					userConn := c
 					// msg will transfer to another without modifying
 					// l means local, r means remote
 					log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
@@ -169,7 +168,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 					needRecord := true
 					go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
 					metric.OpenConnection(p.Name)
-				}()
+				}(c)
 			}
 		}(listener)
 	}

+ 15 - 1
src/frp/utils/conn/conn.go

@@ -117,6 +117,11 @@ func ConnectServer(host string, port int64) (c *Conn, err error) {
 	return c, nil
 }
 
+func (c *Conn) SetTcpConn(tcpConn net.Conn) {
+	c.TcpConn = tcpConn
+	c.Reader = bufio.NewReader(c.TcpConn)
+}
+
 func (c *Conn) GetRemoteAddr() (addr string) {
 	return c.TcpConn.RemoteAddr().String()
 }
@@ -125,6 +130,11 @@ func (c *Conn) GetLocalAddr() (addr string) {
 	return c.TcpConn.LocalAddr().String()
 }
 
+func (c *Conn) Read(p []byte) (n int, err error) {
+	n, err = c.Reader.Read(p)
+	return
+}
+
 func (c *Conn) ReadLine() (buff string, err error) {
 	buff, err = c.Reader.ReadString('\n')
 	if err != nil {
@@ -138,10 +148,14 @@ func (c *Conn) ReadLine() (buff string, err error) {
 	return buff, err
 }
 
+func (c *Conn) WriteBytes(content []byte) (n int, err error) {
+	n, err = c.TcpConn.Write(content)
+	return
+}
+
 func (c *Conn) Write(content string) (err error) {
 	_, err = c.TcpConn.Write([]byte(content))
 	return err
-
 }
 
 func (c *Conn) SetDeadline(t time.Time) error {

+ 1 - 1
src/frp/utils/vhost/vhost.go

@@ -105,7 +105,7 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
 	if err = sConn.SetDeadline(time.Time{}); err != nil {
 		return
 	}
-	c.TcpConn = sConn
+	c.SetTcpConn(sConn)
 
 	l.accept <- c
 }