Pārlūkot izejas kodu

(1)优化重连和心跳检测

Hurricanezwf 9 gadi atpakaļ
vecāks
revīzija
5d6f37aa82
5 mainītis faili ar 104 papildinājumiem un 48 dzēšanām
  1. 2 0
      .gitignore
  2. 13 12
      cmd/frpc/control.go
  3. 50 9
      cmd/frps/control.go
  4. 27 17
      pkg/models/server.go
  5. 12 10
      pkg/utils/conn/conn.go

+ 2 - 0
.gitignore

@@ -26,3 +26,5 @@ _testmain.go
 # Self
 bin/
 
+# Cache
+*.swp

+ 13 - 12
cmd/frpc/control.go

@@ -15,8 +15,7 @@ const (
 	heartbeatDuration = 2 //心跳检测时间间隔,单位秒
 )
 
-// client与server之间连接的保护锁
-var connProtect sync.Mutex
+var isHeartBeatContinue bool = true
 
 func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) {
 	defer wait.Done()
@@ -28,13 +27,11 @@ func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) {
 	}
 	defer c.Close()
 
-	go startHeartBeat(c)
-
 	for {
 		// ignore response content now
 		_, err := c.ReadLine()
 		if err == io.EOF {
-			connProtect.Lock() // 除了这里,其他地方禁止对连接进行任何操作
+			isHeartBeatContinue = false
 			log.Debug("ProxyName [%s], server close this control conn", cli.Name)
 			var sleepTime time.Duration = 1
 			for {
@@ -51,7 +48,6 @@ func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) {
 				}
 				time.Sleep(sleepTime * time.Second)
 			}
-			connProtect.Unlock()
 			continue
 		} else if err != nil {
 			log.Warn("ProxyName [%s], read from server error, %v", cli.Name, err)
@@ -104,6 +100,8 @@ func loginToServer(cli *models.ProxyClient) (connection *conn.Conn) {
 		}
 
 		connection = c
+		go startHeartBeat(connection)
+		log.Debug("ProxyName [%s], connect to server[%s:%d] success!", cli.Name, ServerAddr, ServerPort)
 	}
 
 	if connection == nil {
@@ -114,14 +112,17 @@ func loginToServer(cli *models.ProxyClient) (connection *conn.Conn) {
 }
 
 func startHeartBeat(con *conn.Conn) {
+	isHeartBeatContinue = true
 	for {
 		time.Sleep(heartbeatDuration * time.Second)
-
-		connProtect.Lock()
-		err := con.Write("\r\n")
-		connProtect.Unlock()
-		if err != nil {
-			log.Error("Send hearbeat to server failed! Err:%s", err.Error())
+		if isHeartBeatContinue { // 把isHeartBeatContinue放在这里是为了防止SIGPIPE
+			err := con.Write("\r\n")
+			//log.Debug("send heart beat to server!")
+			if err != nil {
+				log.Error("Send hearbeat to server failed! Err:%s", err.Error())
+			}
+		} else {
+			break
 		}
 	}
 }

+ 50 - 9
cmd/frps/control.go

@@ -1,12 +1,14 @@
 package main
 
 import (
-	"fmt"
 	"encoding/json"
+	"fmt"
+	"io"
+	"time"
 
-	"frp/pkg/utils/log"
-	"frp/pkg/utils/conn"
 	"frp/pkg/models"
+	"frp/pkg/utils/conn"
+	"frp/pkg/utils/log"
 )
 
 func ProcessControlConn(l *conn.Listener) {
@@ -19,6 +21,8 @@ func ProcessControlConn(l *conn.Listener) {
 
 // control connection from every client and server
 func controlWorker(c *conn.Conn) {
+	defer c.Close()
+
 	// the first message is from client to server
 	// if error, close connection
 	res, err := c.ReadLine()
@@ -41,19 +45,20 @@ func controlWorker(c *conn.Conn) {
 		clientCtlRes.Code = 1
 		clientCtlRes.Msg = msg
 	}
-	
+
 	if needRes {
 		buf, _ := json.Marshal(clientCtlRes)
 		err = c.Write(string(buf) + "\n")
 		if err != nil {
 			log.Warn("Write error, %v", err)
+			time.Sleep(1 * time.Second)
+			return
 		}
 	} else {
-	// work conn, just return
+		// work conn, just return
 		return
 	}
 
-	defer c.Close()
 	// others is from server to client
 	server, ok := ProxyServers[clientCtlReq.ProxyName]
 	if !ok {
@@ -61,10 +66,16 @@ func controlWorker(c *conn.Conn) {
 		return
 	}
 
+	// read control msg from client
+	go readControlMsgFromClient(server, c)
+
 	serverCtlReq := &models.ClientCtlReq{}
 	serverCtlReq.Type = models.WorkConn
 	for {
-		server.WaitUserConn()
+		_, isStop := server.WaitUserConn()
+		if isStop {
+			break
+		}
 		buf, _ := json.Marshal(serverCtlReq)
 		err = c.Write(string(buf) + "\n")
 		if err != nil {
@@ -76,6 +87,7 @@ func controlWorker(c *conn.Conn) {
 		log.Debug("ProxyName [%s], write to client to add work conn success", server.Name)
 	}
 
+	log.Error("ProxyName [%s], I'm dead!", server.Name)
 	return
 }
 
@@ -96,7 +108,7 @@ func checkProxy(req *models.ClientCtlReq, c *conn.Conn) (succ bool, msg string,
 		log.Warn(msg)
 		return
 	}
-	
+
 	// control conn
 	if req.Type == models.ControlConn {
 		if server.Status != models.Idle {
@@ -115,7 +127,7 @@ func checkProxy(req *models.ClientCtlReq, c *conn.Conn) (succ bool, msg string,
 
 		log.Info("ProxyName [%s], start proxy success", req.ProxyName)
 	} else if req.Type == models.WorkConn {
-	// work conn
+		// work conn
 		needRes = false
 		if server.Status != models.Working {
 			log.Warn("ProxyName [%s], is not working when it gets one new work conn", req.ProxyName)
@@ -132,3 +144,32 @@ func checkProxy(req *models.ClientCtlReq, c *conn.Conn) (succ bool, msg string,
 	succ = true
 	return
 }
+
+func readControlMsgFromClient(server *models.ProxyServer, c *conn.Conn) {
+	isContinueRead := true
+	f := func() {
+		isContinueRead = false
+		server.StopWaitUserConn()
+	}
+	timer := time.AfterFunc(10*time.Second, f)
+	defer timer.Stop()
+
+	for isContinueRead {
+		content, err := c.ReadLine()
+		//log.Debug("Receive msg from client! content:%s", content)
+		if err != nil {
+			if err == io.EOF {
+				log.Warn("Server detect client[%s] is dead!", server.Name)
+				server.StopWaitUserConn()
+				break
+			}
+			log.Error("ProxyName [%s], read error:%s", server.Name, err.Error())
+			continue
+		}
+
+		if content == "\r\n" {
+			log.Debug("receive hearbeat:%s", content)
+			timer.Reset(10 * time.Second)
+		}
+	}
+}

+ 27 - 17
pkg/models/server.go

@@ -1,8 +1,8 @@
 package models
 
 import (
-	"sync"
 	"container/list"
+	"sync"
 
 	"frp/pkg/utils/conn"
 	"frp/pkg/utils/log"
@@ -14,22 +14,24 @@ const (
 )
 
 type ProxyServer struct {
-	Name			string
-	Passwd			string
-	BindAddr		string
-	ListenPort		int64
-
-	Status			int64
-	Listener		*conn.Listener		// accept new connection from remote users
-	CtlMsgChan		chan int64			// every time accept a new user conn, put "1" to the channel
-	CliConnChan		chan *conn.Conn		// get client conns from control goroutine
-	UserConnList	*list.List			// store user conns
-	Mutex			sync.Mutex
+	Name       string
+	Passwd     string
+	BindAddr   string
+	ListenPort int64
+
+	Status        int64
+	Listener      *conn.Listener  // accept new connection from remote users
+	CtlMsgChan    chan int64      // every time accept a new user conn, put "1" to the channel
+	StopBlockChan chan int64      // put any number to the channel, if you want to stop wait user conn
+	CliConnChan   chan *conn.Conn // get client conns from control goroutine
+	UserConnList  *list.List      // store user conns
+	Mutex         sync.Mutex
 }
 
 func (p *ProxyServer) Init() {
 	p.Status = Idle
 	p.CtlMsgChan = make(chan int64)
+	p.StopBlockChan = make(chan int64)
 	p.CliConnChan = make(chan *conn.Conn)
 	p.UserConnList = list.New()
 }
@@ -55,7 +57,7 @@ func (p *ProxyServer) Start() (err error) {
 	go func() {
 		for {
 			// block
-			c := p.Listener.GetConn()	
+			c := p.Listener.GetConn()
 			log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr())
 
 			// put to list
@@ -93,7 +95,7 @@ func (p *ProxyServer) Start() (err error) {
 
 			// msg will transfer to another without modifying
 			log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", cliConn.GetLocalAddr(), cliConn.GetRemoteAddr(),
-					userConn.GetLocalAddr(), userConn.GetRemoteAddr())
+				userConn.GetLocalAddr(), userConn.GetRemoteAddr())
 			go conn.Join(cliConn, userConn)
 		}
 	}()
@@ -110,7 +112,15 @@ func (p *ProxyServer) Close() {
 	p.Unlock()
 }
 
-func (p *ProxyServer) WaitUserConn() (res int64) {
-	res = <-p.CtlMsgChan
-	return 
+func (p *ProxyServer) WaitUserConn() (res int64, isStop bool) {
+	select {
+	case res = <-p.CtlMsgChan:
+		return res, false
+	case <-p.StopBlockChan:
+		return 0, true
+	}
+}
+
+func (p *ProxyServer) StopWaitUserConn() {
+	p.StopBlockChan <- 1
 }

+ 12 - 10
pkg/utils/conn/conn.go

@@ -1,18 +1,18 @@
 package conn
 
 import (
+	"bufio"
 	"fmt"
+	"io"
 	"net"
-	"bufio"
 	"sync"
-	"io"
 
 	"frp/pkg/utils/log"
 )
 
 type Listener struct {
-	Addr	net.Addr
-	Conns	chan *Conn
+	Addr  net.Addr
+	Conns chan *Conn
 }
 
 // wait util get one
@@ -22,8 +22,8 @@ func (l *Listener) GetConn() (conn *Conn) {
 }
 
 type Conn struct {
-	TcpConn		*net.TCPConn
-	Reader		*bufio.Reader
+	TcpConn *net.TCPConn
+	Reader  *bufio.Reader
 }
 
 func (c *Conn) ConnectServer(host string, port int64) (err error) {
@@ -59,7 +59,9 @@ func (c *Conn) Write(content string) (err error) {
 }
 
 func (c *Conn) Close() {
-	c.TcpConn.Close()
+	if c.TcpConn != nil { // ZWF:我觉得应该加一个非空保护
+		c.TcpConn.Close()
+	}
 }
 
 func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
@@ -70,8 +72,8 @@ func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
 	}
 
 	l = &Listener{
-		Addr:	listener.Addr(),
-		Conns:	make(chan *Conn),
+		Addr:  listener.Addr(),
+		Conns: make(chan *Conn),
 	}
 
 	go func() {
@@ -83,7 +85,7 @@ func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
 			}
 
 			c := &Conn{
-				TcpConn:	conn,
+				TcpConn: conn,
 			}
 			c.Reader = bufio.NewReader(c.TcpConn)
 			l.Conns <- c