Ver código fonte

Merge branch 'fix_some_bugs' of https://github.com/Hurricanezwf/frp into Hurricanezwf-fix_some_bugs

Conflicts:
	cmd/frpc/config.go
	cmd/frpc/control.go
	cmd/frpc/main.go
	cmd/frps/config.go
	cmd/frps/control.go
	cmd/frps/main.go
	pkg/models/server.go
fatedier 9 anos atrás
pai
commit
e805acd1e3

+ 3 - 0
.gitignore

@@ -26,3 +26,6 @@ _testmain.go
 # Self
 bin/
 
+# Cache
+*.swp
+*.swo

+ 6 - 5
cmd/frpc/config.go

@@ -11,11 +11,12 @@ import (
 
 // common config
 var (
-	ServerAddr string = "0.0.0.0"
-	ServerPort int64  = 7000
-	LogFile    string = "./frpc.log"
-	LogLevel   string = "warn"
-	LogWay     string = "file"
+	ServerAddr        string = "0.0.0.0"
+	ServerPort        int64  = 7000
+	LogFile           string = "./frpc.log"
+	LogLevel          string = "warn"
+	LogWay            string = "file"
+	HeartBeatInterval int64  = 5
 )
 
 var ProxyClients map[string]*models.ProxyClient = make(map[string]*models.ProxyClient)

+ 92 - 35
cmd/frpc/control.go

@@ -4,59 +4,47 @@ import (
 	"encoding/json"
 	"io"
 	"sync"
+	"time"
 
 	"github.com/fatedier/frp/pkg/models"
 	"github.com/fatedier/frp/pkg/utils/conn"
 	"github.com/fatedier/frp/pkg/utils/log"
 )
 
+var isHeartBeatContinue bool = true
+
 func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) {
 	defer wait.Done()
 
-	c := &conn.Conn{}
-	err := c.ConnectServer(ServerAddr, ServerPort)
-	if err != nil {
-		log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, ServerAddr, ServerPort, err)
+	c := loginToServer(cli)
+	if c == nil {
+		log.Error("ProxyName [%s], connect to server failed!", cli.Name)
 		return
 	}
 	defer c.Close()
 
-	req := &models.ClientCtlReq{
-		Type:      models.ControlConn,
-		ProxyName: cli.Name,
-		Passwd:    cli.Passwd,
-	}
-	buf, _ := json.Marshal(req)
-	err = c.Write(string(buf) + "\n")
-	if err != nil {
-		log.Error("ProxyName [%s], write to server error, %v", cli.Name, err)
-		return
-	}
-
-	res, err := c.ReadLine()
-	if err != nil {
-		log.Error("ProxyName [%s], read from server error, %v", cli.Name, err)
-		return
-	}
-	log.Debug("ProxyName [%s], read [%s]", cli.Name, res)
-
-	clientCtlRes := &models.ClientCtlRes{}
-	if err = json.Unmarshal([]byte(res), &clientCtlRes); err != nil {
-		log.Error("ProxyName [%s], format server response error, %v", cli.Name, err)
-		return
-	}
-
-	if clientCtlRes.Code != 0 {
-		log.Error("ProxyName [%s], start proxy error, %s", cli.Name, clientCtlRes.Msg)
-		return
-	}
-
 	for {
 		// ignore response content now
 		_, err := c.ReadLine()
 		if err == io.EOF {
+			isHeartBeatContinue = false
 			log.Debug("ProxyName [%s], server close this control conn", cli.Name)
-			break
+			var sleepTime time.Duration = 1
+			for {
+				log.Debug("ProxyName [%s], try to reconnect to server[%s:%d]...", cli.Name, ServerAddr, ServerPort)
+				tmpConn := loginToServer(cli)
+				if tmpConn != nil {
+					c.Close()
+					c = tmpConn
+					break
+				}
+
+				if sleepTime < 60 {
+					sleepTime++
+				}
+				time.Sleep(sleepTime * time.Second)
+			}
+			continue
 		} else if err != nil {
 			log.Warn("ProxyName [%s], read from server error, %v", cli.Name, err)
 			continue
@@ -65,3 +53,72 @@ func ControlProcess(cli *models.ProxyClient, wait *sync.WaitGroup) {
 		cli.StartTunnel(ServerAddr, ServerPort)
 	}
 }
+
+func loginToServer(cli *models.ProxyClient) (connection *conn.Conn) {
+	c := &conn.Conn{}
+
+	connection = nil
+	for i := 0; i < 1; i++ {
+		err := c.ConnectServer(ServerAddr, ServerPort)
+		if err != nil {
+			log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", cli.Name, ServerAddr, ServerPort, err)
+			break
+		}
+
+		req := &models.ClientCtlReq{
+			Type:      models.ControlConn,
+			ProxyName: cli.Name,
+			Passwd:    cli.Passwd,
+		}
+		buf, _ := json.Marshal(req)
+		err = c.Write(string(buf) + "\n")
+		if err != nil {
+			log.Error("ProxyName [%s], write to server error, %v", cli.Name, err)
+			break
+		}
+
+		res, err := c.ReadLine()
+		if err != nil {
+			log.Error("ProxyName [%s], read from server error, %v", cli.Name, err)
+			break
+		}
+		log.Debug("ProxyName [%s], read [%s]", cli.Name, res)
+
+		clientCtlRes := &models.ClientCtlRes{}
+		if err = json.Unmarshal([]byte(res), &clientCtlRes); err != nil {
+			log.Error("ProxyName [%s], format server response error, %v", cli.Name, err)
+			break
+		}
+
+		if clientCtlRes.Code != 0 {
+			log.Error("ProxyName [%s], start proxy error, %s", cli.Name, clientCtlRes.Msg)
+			break
+		}
+
+		connection = c
+		go startHeartBeat(connection)
+		log.Debug("ProxyName [%s], connect to server[%s:%d] success!", cli.Name, ServerAddr, ServerPort)
+	}
+
+	if connection == nil {
+		c.Close()
+	}
+
+	return
+}
+
+func startHeartBeat(con *conn.Conn) {
+	isHeartBeatContinue = true
+	log.Debug("Start to send heartbeat")
+	for {
+		time.Sleep(time.Duration(HeartBeatInterval) * time.Second)
+		if isHeartBeatContinue {
+			err := con.Write("\n")
+			if err != nil {
+				log.Error("Send hearbeat to server failed! Err:%s", err.Error())
+			}
+		} else {
+			break
+		}
+	}
+}

+ 6 - 5
cmd/frps/config.go

@@ -11,11 +11,12 @@ import (
 
 // common config
 var (
-	BindAddr string = "0.0.0.0"
-	BindPort int64  = 9527
-	LogFile  string = "./frps.log"
-	LogLevel string = "warn"
-	LogWay   string = "file"
+	BindAddr         string = "0.0.0.0"
+	BindPort         int64  = 9527
+	LogFile          string = "./frps.log"
+	LogLevel         string = "warn"
+	LogWay           string = "file"
+	HeartBeatTimeout int64  = 30
 )
 
 var ProxyServers map[string]*models.ProxyServer = make(map[string]*models.ProxyServer)

+ 45 - 5
cmd/frps/control.go

@@ -3,6 +3,8 @@ package main
 import (
 	"encoding/json"
 	"fmt"
+	"io"
+	"time"
 
 	"github.com/fatedier/frp/pkg/models"
 	"github.com/fatedier/frp/pkg/utils/conn"
@@ -17,7 +19,7 @@ func ProcessControlConn(l *conn.Listener) {
 	}
 }
 
-// control connection from every client and server
+// connection from every client and server
 func controlWorker(c *conn.Conn) {
 	// the first message is from client to server
 	// if error, close connection
@@ -43,17 +45,21 @@ func controlWorker(c *conn.Conn) {
 	}
 
 	if needRes {
+		// control conn
+		defer c.Close()
+
 		buf, _ := json.Marshal(clientCtlRes)
 		err = c.Write(string(buf) + "\n")
 		if err != nil {
 			log.Warn("Write error, %v", err)
+			time.Sleep(1 * time.Second)
+			return
 		}
 	} else {
 		// work conn, just return
 		return
 	}
 
-	defer c.Close()
 	// others is from server to client
 	server, ok := ProxyServers[clientCtlReq.ProxyName]
 	if !ok {
@@ -61,10 +67,16 @@ func controlWorker(c *conn.Conn) {
 		return
 	}
 
+	// read control msg from client
+	go readControlMsgFromClient(server, c)
+
 	serverCtlReq := &models.ClientCtlReq{}
 	serverCtlReq.Type = models.WorkConn
 	for {
-		server.WaitUserConn()
+		_, isStop := server.WaitUserConn()
+		if isStop {
+			break
+		}
 		buf, _ := json.Marshal(serverCtlReq)
 		err = c.Write(string(buf) + "\n")
 		if err != nil {
@@ -76,6 +88,7 @@ func controlWorker(c *conn.Conn) {
 		log.Debug("ProxyName [%s], write to client to add work conn success", server.Name)
 	}
 
+	log.Error("ProxyName [%s], I'm dead!", server.Name)
 	return
 }
 
@@ -124,11 +137,38 @@ func checkProxy(req *models.ClientCtlReq, c *conn.Conn) (succ bool, msg string,
 
 		server.CliConnChan <- c
 	} else {
-		msg = fmt.Sprintf("ProxyName [%s], type [%d] unsupport", req.ProxyName)
-		log.Warn(msg)
+		log.Warn("ProxyName [%s], type [%d] unsupport", req.ProxyName, req.Type)
 		return
 	}
 
 	succ = true
 	return
 }
+
+func readControlMsgFromClient(server *models.ProxyServer, c *conn.Conn) {
+	isContinueRead := true
+	f := func() {
+		isContinueRead = false
+		server.StopWaitUserConn()
+	}
+	timer := time.AfterFunc(time.Duration(HeartBeatTimeout)*time.Second, f)
+	defer timer.Stop()
+
+	for isContinueRead {
+		content, err := c.ReadLine()
+		//log.Debug("Receive msg from client! content:%s", content)
+		if err != nil {
+			if err == io.EOF {
+				log.Warn("Server detect client[%s] is dead!", server.Name)
+				server.StopWaitUserConn()
+				break
+			}
+			log.Error("ProxyName [%s], read error:%s", server.Name, err.Error())
+			continue
+		}
+
+		if content == "\n" {
+			timer.Reset(time.Duration(HeartBeatTimeout) * time.Second)
+		}
+	}
+}

+ 2 - 2
conf/frpc.ini

@@ -4,9 +4,9 @@ server_addr = 127.0.0.1
 bind_port = 7000
 log_file = ./frpc.log
 # debug, info, warn, error
-log_level = info
+log_level = debug
 # file, console
-log_way = file
+log_way = console
 
 # test1即为name
 [test1]

+ 2 - 2
conf/frps.ini

@@ -4,9 +4,9 @@ bind_addr = 0.0.0.0
 bind_port = 7000
 log_file = ./frps.log
 # debug, info, warn, error
-log_level = info
+log_level = debug
 # file, console
-log_way = file
+log_way = console 
 
 # test1即为name
 [test1]

+ 1 - 0
pkg/models/client.go

@@ -63,6 +63,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro
 		return
 	}
 
+	// l means local, r means remote
 	log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(),
 		remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr())
 	go conn.Join(localConn, remoteConn)

+ 21 - 9
pkg/models/server.go

@@ -19,17 +19,19 @@ type ProxyServer struct {
 	BindAddr   string
 	ListenPort int64
 
-	Status       int64
-	Listener     *conn.Listener  // accept new connection from remote users
-	CtlMsgChan   chan int64      // every time accept a new user conn, put "1" to the channel
-	CliConnChan  chan *conn.Conn // get client conns from control goroutine
-	UserConnList *list.List      // store user conns
-	Mutex        sync.Mutex
+	Status        int64
+	Listener      *conn.Listener  // accept new connection from remote users
+	CtlMsgChan    chan int64      // every time accept a new user conn, put "1" to the channel
+	StopBlockChan chan int64      // put any number to the channel, if you want to stop wait user conn
+	CliConnChan   chan *conn.Conn // get client conns from control goroutine
+	UserConnList  *list.List      // store user conns
+	Mutex         sync.Mutex
 }
 
 func (p *ProxyServer) Init() {
 	p.Status = Idle
 	p.CtlMsgChan = make(chan int64)
+	p.StopBlockChan = make(chan int64)
 	p.CliConnChan = make(chan *conn.Conn)
 	p.UserConnList = list.New()
 }
@@ -87,11 +89,13 @@ func (p *ProxyServer) Start() (err error) {
 				p.UserConnList.Remove(element)
 			} else {
 				cliConn.Close()
+				p.Unlock()
 				continue
 			}
 			p.Unlock()
 
 			// msg will transfer to another without modifying
+			// l means local, r means remote
 			log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", cliConn.GetLocalAddr(), cliConn.GetRemoteAddr(),
 				userConn.GetLocalAddr(), userConn.GetRemoteAddr())
 			go conn.Join(cliConn, userConn)
@@ -110,7 +114,15 @@ func (p *ProxyServer) Close() {
 	p.Unlock()
 }
 
-func (p *ProxyServer) WaitUserConn() (res int64) {
-	res = <-p.CtlMsgChan
-	return
+func (p *ProxyServer) WaitUserConn() (res int64, isStop bool) {
+	select {
+	case res = <-p.CtlMsgChan:
+		return res, false
+	case <-p.StopBlockChan:
+		return 0, true
+	}
+}
+
+func (p *ProxyServer) StopWaitUserConn() {
+	p.StopBlockChan <- 1
 }

+ 3 - 1
pkg/utils/conn/conn.go

@@ -59,7 +59,9 @@ func (c *Conn) Write(content string) (err error) {
 }
 
 func (c *Conn) Close() {
-	c.TcpConn.Close()
+	if c.TcpConn != nil {
+		c.TcpConn.Close()
+	}
 }
 
 func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {

+ 4 - 4
pkg/utils/pcrypto/pcrypto_test.go

@@ -6,7 +6,7 @@ import (
 	"testing"
 )
 
-func Test_Encrypto(t *testing.T) {
+func TestEncrypto(t *testing.T) {
 	pp := new(Pcrypto)
 	pp.Init([]byte("Hana"))
 	res, err := pp.Encrypto([]byte("Just One Test!"))
@@ -17,7 +17,7 @@ func Test_Encrypto(t *testing.T) {
 	fmt.Printf("[%x]\n", res)
 }
 
-func Test_Decrypto(t *testing.T) {
+func TestDecrypto(t *testing.T) {
 	pp := new(Pcrypto)
 	pp.Init([]byte("Hana"))
 	res, err := pp.Encrypto([]byte("Just One Test!"))
@@ -33,13 +33,13 @@ func Test_Decrypto(t *testing.T) {
 	fmt.Printf("[%s]\n", string(res))
 }
 
-func Test_PKCS7Padding(t *testing.T) {
+func TestPKCS7Padding(t *testing.T) {
 	ltt := []byte("Test_PKCS7Padding")
 	ltt = PKCS7Padding(ltt, aes.BlockSize)
 	fmt.Printf("[%x]\n", (ltt))
 }
 
-func Test_PKCS7UnPadding(t *testing.T) {
+func TestPKCS7UnPadding(t *testing.T) {
 	ltt := []byte("Test_PKCS7Padding")
 	ltt = PKCS7Padding(ltt, aes.BlockSize)
 	ltt = PKCS7UnPadding(ltt)