Browse Source

all: fix bug when client shutdown and reconnect, server response already use

1. if client is offline, server will release all resources
2. use a graceful method to shutdown go net.Listeners
3. add closeFlag for Conn, so startHeartBeat func can exit correctly now
fatedier 9 years ago
parent
commit
26479cf92a
6 changed files with 173 additions and 146 deletions
  1. 0 2
      .travis.yml
  2. 44 52
      cmd/frpc/control.go
  3. 12 12
      cmd/frps/control.go
  4. 2 4
      models/client/client.go
  5. 44 38
      models/server/server.go
  6. 71 38
      utils/conn/conn.go

+ 0 - 2
.travis.yml

@@ -1,11 +1,9 @@
-go_import_path: github.com/fatedier/frp
 sudo: false
 language: go
 
 go:
     - 1.4.2
     - 1.5.2 
-    - tip
 
 install:
     - make

+ 44 - 52
cmd/frpc/control.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"encoding/json"
+	"fmt"
 	"io"
 	"sync"
 	"time"
@@ -18,8 +19,8 @@ var isHeartBeatContinue bool = true
 func ControlProcess(cli *client.ProxyClient, wait *sync.WaitGroup) {
 	defer wait.Done()
 
-	c := loginToServer(cli)
-	if c == nil {
+	c, err := loginToServer(cli)
+	if err != nil {
 		log.Error("ProxyName [%s], connect to server failed!", cli.Name)
 		return
 	}
@@ -34,15 +35,15 @@ func ControlProcess(cli *client.ProxyClient, wait *sync.WaitGroup) {
 			var sleepTime time.Duration = 1
 			for {
 				log.Debug("ProxyName [%s], try to reconnect to server[%s:%d]...", cli.Name, client.ServerAddr, client.ServerPort)
-				tmpConn := loginToServer(cli)
-				if tmpConn != nil {
+				tmpConn, err := loginToServer(cli)
+				if err == nil {
 					c.Close()
 					c = tmpConn
 					break
 				}
 
 				if sleepTime < 60 {
-					sleepTime++
+					sleepTime = sleepTime * 2
 				}
 				time.Sleep(sleepTime * time.Second)
 			}
@@ -56,71 +57,62 @@ func ControlProcess(cli *client.ProxyClient, wait *sync.WaitGroup) {
 	}
 }
 
-func loginToServer(cli *client.ProxyClient) (connection *conn.Conn) {
-	c := &conn.Conn{}
-
-	connection = nil
-	for i := 0; i < 1; i++ {
-		err := c.ConnectServer(client.ServerAddr, client.ServerPort)
-		if err != nil {
-			log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, client.ServerAddr, client.ServerPort, err)
-			break
-		}
-
-		req := &msg.ClientCtlReq{
-			Type:      consts.CtlConn,
-			ProxyName: cli.Name,
-			Passwd:    cli.Passwd,
-		}
-		buf, _ := json.Marshal(req)
-		err = c.Write(string(buf) + "\n")
-		if err != nil {
-			log.Error("ProxyName [%s], write to server error, %v", cli.Name, err)
-			break
-		}
-
-		res, err := c.ReadLine()
-		if err != nil {
-			log.Error("ProxyName [%s], read from server error, %v", cli.Name, err)
-			break
-		}
-		log.Debug("ProxyName [%s], read [%s]", cli.Name, res)
+func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
+	c, err = conn.ConnectServer(client.ServerAddr, client.ServerPort)
+	if err != nil {
+		log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, client.ServerAddr, client.ServerPort, err)
+		return
+	}
 
-		clientCtlRes := &msg.ClientCtlRes{}
-		if err = json.Unmarshal([]byte(res), &clientCtlRes); err != nil {
-			log.Error("ProxyName [%s], format server response error, %v", cli.Name, err)
-			break
-		}
+	req := &msg.ClientCtlReq{
+		Type:      consts.CtlConn,
+		ProxyName: cli.Name,
+		Passwd:    cli.Passwd,
+	}
+	buf, _ := json.Marshal(req)
+	err = c.Write(string(buf) + "\n")
+	if err != nil {
+		log.Error("ProxyName [%s], write to server error, %v", cli.Name, err)
+		return
+	}
 
-		if clientCtlRes.Code != 0 {
-			log.Error("ProxyName [%s], start proxy error, %s", cli.Name, clientCtlRes.Msg)
-			break
-		}
+	res, err := c.ReadLine()
+	if err != nil {
+		log.Error("ProxyName [%s], read from server error, %v", cli.Name, err)
+		return
+	}
+	log.Debug("ProxyName [%s], read [%s]", cli.Name, res)
 
-		connection = c
-		go startHeartBeat(connection)
-		log.Debug("ProxyName [%s], connect to server[%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort)
+	clientCtlRes := &msg.ClientCtlRes{}
+	if err = json.Unmarshal([]byte(res), &clientCtlRes); err != nil {
+		log.Error("ProxyName [%s], format server response error, %v", cli.Name, err)
+		return
 	}
 
-	if connection == nil {
-		c.Close()
+	if clientCtlRes.Code != 0 {
+		log.Error("ProxyName [%s], start proxy error, %s", cli.Name, clientCtlRes.Msg)
+		return c, fmt.Errorf("%s", clientCtlRes.Msg)
 	}
 
+	go startHeartBeat(c)
+	log.Debug("ProxyName [%s], connect to server[%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort)
+
 	return
 }
 
-func startHeartBeat(con *conn.Conn) {
-	isHeartBeatContinue = true
+func startHeartBeat(c *conn.Conn) {
 	log.Debug("Start to send heartbeat")
 	for {
 		time.Sleep(time.Duration(client.HeartBeatInterval) * time.Second)
-		if isHeartBeatContinue {
-			err := con.Write("\n")
+		if !c.IsClosed() {
+			err := c.Write("\n")
 			if err != nil {
 				log.Error("Send hearbeat to server failed! Err:%s", err.Error())
+				continue
 			}
 		} else {
 			break
 		}
 	}
+	log.Info("heartbeat exit")
 }

+ 12 - 12
cmd/frps/control.go

@@ -75,8 +75,9 @@ func controlWorker(c *conn.Conn) {
 	serverCtlReq := &msg.ClientCtlReq{}
 	serverCtlReq.Type = consts.WorkConn
 	for {
-		_, isStop := s.WaitUserConn()
-		if isStop {
+		closeFlag := s.WaitUserConn()
+		if closeFlag {
+			log.Debug("ProxyName [%s], goroutine for dealing user conn is closed", s.Name)
 			break
 		}
 		buf, _ := json.Marshal(serverCtlReq)
@@ -90,7 +91,7 @@ func controlWorker(c *conn.Conn) {
 		log.Debug("ProxyName [%s], write to client to add work conn success", s.Name)
 	}
 
-	log.Error("ProxyName [%s], I'm dead!", s.Name)
+	log.Info("ProxyName [%s], I'm dead!", s.Name)
 	return
 }
 
@@ -152,26 +153,25 @@ func readControlMsgFromClient(s *server.ProxyServer, c *conn.Conn) {
 	isContinueRead := true
 	f := func() {
 		isContinueRead = false
-		s.StopWaitUserConn()
+		c.Close()
+		s.Close()
 	}
 	timer := time.AfterFunc(time.Duration(server.HeartBeatTimeout)*time.Second, f)
 	defer timer.Stop()
 
 	for isContinueRead {
-		content, err := c.ReadLine()
-		//log.Debug("Receive msg from client! content:%s", content)
+		_, err := c.ReadLine()
 		if err != nil {
 			if err == io.EOF {
-				log.Warn("Server detect client[%s] is dead!", s.Name)
-				s.StopWaitUserConn()
+				log.Warn("ProxyName [%s], client is dead!", s.Name)
+				c.Close()
+				s.Close()
 				break
 			}
-			log.Error("ProxyName [%s], read error:%s", s.Name, err.Error())
+			log.Error("ProxyName [%s], read error: %v", s.Name, err)
 			continue
 		}
 
-		if content == "\n" {
-			timer.Reset(time.Duration(server.HeartBeatTimeout) * time.Second)
-		}
+		timer.Reset(time.Duration(server.HeartBeatTimeout) * time.Second)
 	}
 }

+ 2 - 4
models/client/client.go

@@ -16,8 +16,7 @@ type ProxyClient struct {
 }
 
 func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) {
-	c = &conn.Conn{}
-	err = c.ConnectServer("127.0.0.1", p.LocalPort)
+	c, err = conn.ConnectServer("127.0.0.1", p.LocalPort)
 	if err != nil {
 		log.Error("ProxyName [%s], connect to local port error, %v", p.Name, err)
 	}
@@ -25,14 +24,13 @@ func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) {
 }
 
 func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) {
-	c = &conn.Conn{}
 	defer func() {
 		if err != nil {
 			c.Close()
 		}
 	}()
 
-	err = c.ConnectServer(addr, port)
+	c, err = conn.ConnectServer(addr, port)
 	if err != nil {
 		log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", p.Name, addr, port, err)
 		return

+ 44 - 38
models/server/server.go

@@ -10,39 +10,38 @@ import (
 )
 
 type ProxyServer struct {
-	Name       string
-	Passwd     string
-	BindAddr   string
-	ListenPort int64
-
-	Status        int64
-	Listener      *conn.Listener  // accept new connection from remote users
-	CtlMsgChan    chan int64      // every time accept a new user conn, put "1" to the channel
-	StopBlockChan chan int64      // put any number to the channel, if you want to stop wait user conn
-	CliConnChan   chan *conn.Conn // get client conns from control goroutine
-	UserConnList  *list.List      // store user conns
-	Mutex         sync.Mutex
+	Name        string
+	Passwd      string
+	BindAddr    string
+	ListenPort  int64
+	Status      int64
+	CliConnChan chan *conn.Conn // get client conns from control goroutine
+
+	listener     *conn.Listener // accept new connection from remote users
+	ctlMsgChan   chan int64     // every time accept a new user conn, put "1" to the channel
+	userConnList *list.List     // store user conns
+	mutex        sync.Mutex
 }
 
 func (p *ProxyServer) Init() {
 	p.Status = consts.Idle
-	p.CtlMsgChan = make(chan int64)
-	p.StopBlockChan = make(chan int64)
 	p.CliConnChan = make(chan *conn.Conn)
-	p.UserConnList = list.New()
+	p.ctlMsgChan = make(chan int64)
+	p.userConnList = list.New()
 }
 
 func (p *ProxyServer) Lock() {
-	p.Mutex.Lock()
+	p.mutex.Lock()
 }
 
 func (p *ProxyServer) Unlock() {
-	p.Mutex.Unlock()
+	p.mutex.Unlock()
 }
 
 // start listening for user conns
 func (p *ProxyServer) Start() (err error) {
-	p.Listener, err = conn.Listen(p.BindAddr, p.ListenPort)
+	p.Init()
+	p.listener, err = conn.Listen(p.BindAddr, p.ListenPort)
 	if err != nil {
 		return err
 	}
@@ -53,10 +52,15 @@ func (p *ProxyServer) Start() (err error) {
 	go func() {
 		for {
 			// block
-			c := p.Listener.GetConn()
+			// if listener is closed, get nil
+			c := p.listener.GetConn()
+			if c == nil {
+				log.Info("ProxyName [%s], listener is closed", p.Name)
+				return
+			}
 			log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr())
 
-			// put to list
+			// insert into list
 			p.Lock()
 			if p.Status != consts.Working {
 				log.Debug("ProxyName [%s] is not working, new user conn close", p.Name)
@@ -64,25 +68,29 @@ func (p *ProxyServer) Start() (err error) {
 				p.Unlock()
 				return
 			}
-			p.UserConnList.PushBack(c)
+			p.userConnList.PushBack(c)
 			p.Unlock()
 
 			// put msg to control conn
-			p.CtlMsgChan <- 1
+			p.ctlMsgChan <- 1
 		}
 	}()
 
 	// start another goroutine for join two conns from client and user
 	go func() {
 		for {
-			cliConn := <-p.CliConnChan
+			cliConn, ok := <-p.CliConnChan
+			if !ok {
+				return
+			}
+
 			p.Lock()
-			element := p.UserConnList.Front()
+			element := p.userConnList.Front()
 
 			var userConn *conn.Conn
 			if element != nil {
 				userConn = element.Value.(*conn.Conn)
-				p.UserConnList.Remove(element)
+				p.userConnList.Remove(element)
 			} else {
 				cliConn.Close()
 				p.Unlock()
@@ -104,21 +112,19 @@ func (p *ProxyServer) Start() (err error) {
 func (p *ProxyServer) Close() {
 	p.Lock()
 	p.Status = consts.Idle
-	p.CtlMsgChan = make(chan int64)
-	p.CliConnChan = make(chan *conn.Conn)
-	p.UserConnList = list.New()
+	p.listener.Close()
+	close(p.ctlMsgChan)
+	close(p.CliConnChan)
+	p.userConnList = list.New()
 	p.Unlock()
 }
 
-func (p *ProxyServer) WaitUserConn() (res int64, isStop bool) {
-	select {
-	case res = <-p.CtlMsgChan:
-		return res, false
-	case <-p.StopBlockChan:
-		return 0, true
-	}
-}
+func (p *ProxyServer) WaitUserConn() (closeFlag bool) {
+	closeFlag = false
 
-func (p *ProxyServer) StopWaitUserConn() {
-	p.StopBlockChan <- 1
+	_, ok := <-p.ctlMsgChan
+	if !ok {
+		closeFlag = true
+	}
+	return
 }

+ 71 - 38
utils/conn/conn.go

@@ -11,33 +11,87 @@ import (
 )
 
 type Listener struct {
-	Addr  net.Addr
-	Conns chan *Conn
+	addr      net.Addr
+	l         *net.TCPListener
+	conns     chan *Conn
+	closeFlag bool
 }
 
-// wait util get one
+func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
+	tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort))
+	listener, err := net.ListenTCP("tcp", tcpAddr)
+	if err != nil {
+		return l, err
+	}
+
+	l = &Listener{
+		addr:      listener.Addr(),
+		l:         listener,
+		conns:     make(chan *Conn),
+		closeFlag: false,
+	}
+
+	go func() {
+		for {
+			conn, err := l.l.AcceptTCP()
+			if err != nil {
+				if l.closeFlag {
+					return
+				}
+				continue
+			}
+
+			c := &Conn{
+				TcpConn:   conn,
+				closeFlag: false,
+			}
+			c.Reader = bufio.NewReader(c.TcpConn)
+			l.conns <- c
+		}
+	}()
+	return l, err
+}
+
+// wait util get one new connection or close
+// if listener is closed, return nil
 func (l *Listener) GetConn() (conn *Conn) {
-	conn = <-l.Conns
+	var ok bool
+	conn, ok = <-l.conns
+	if !ok {
+		return nil
+	}
 	return conn
 }
 
+func (l *Listener) Close() {
+	if l.l != nil && l.closeFlag == false {
+		l.closeFlag = true
+		l.l.Close()
+		close(l.conns)
+	}
+}
+
+// wrap for TCPConn
 type Conn struct {
-	TcpConn *net.TCPConn
-	Reader  *bufio.Reader
+	TcpConn   *net.TCPConn
+	Reader    *bufio.Reader
+	closeFlag bool
 }
 
-func (c *Conn) ConnectServer(host string, port int64) (err error) {
+func ConnectServer(host string, port int64) (c *Conn, err error) {
+	c = &Conn{}
 	servertAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", host, port))
 	if err != nil {
-		return err
+		return
 	}
 	conn, err := net.DialTCP("tcp", nil, servertAddr)
 	if err != nil {
-		return err
+		return
 	}
 	c.TcpConn = conn
 	c.Reader = bufio.NewReader(c.TcpConn)
-	return nil
+	c.closeFlag = false
+	return c, nil
 }
 
 func (c *Conn) GetRemoteAddr() (addr string) {
@@ -50,6 +104,9 @@ func (c *Conn) GetLocalAddr() (addr string) {
 
 func (c *Conn) ReadLine() (buff string, err error) {
 	buff, err = c.Reader.ReadString('\n')
+	if err == io.EOF {
+		c.closeFlag = true
+	}
 	return buff, err
 }
 
@@ -60,40 +117,16 @@ func (c *Conn) Write(content string) (err error) {
 
 func (c *Conn) Close() {
 	if c.TcpConn != nil {
+		c.closeFlag = true
 		c.TcpConn.Close()
 	}
 }
 
-func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
-	tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort))
-	listener, err := net.ListenTCP("tcp", tcpAddr)
-	if err != nil {
-		return l, err
-	}
-
-	l = &Listener{
-		Addr:  listener.Addr(),
-		Conns: make(chan *Conn),
-	}
-
-	go func() {
-		for {
-			conn, err := listener.AcceptTCP()
-			if err != nil {
-				continue
-			}
-
-			c := &Conn{
-				TcpConn: conn,
-			}
-			c.Reader = bufio.NewReader(c.TcpConn)
-			l.Conns <- c
-		}
-	}()
-	return l, err
+func (c *Conn) IsClosed() bool {
+	return c.closeFlag
 }
 
-// will block until conn close
+// will block until connection close
 func Join(c1 *Conn, c2 *Conn) {
 	var wait sync.WaitGroup
 	pipe := func(to *Conn, from *Conn) {