
Code refactoring related to message handling and retry logic. (#3745)

fatedier 1 year ago
parent commit 184223cb2f
11 changed files with 690 additions and 529 deletions
  1. Release.md (+1 -0)
  2. client/admin_api.go (+8 -1)
  3. client/control.go (+103 -172)
  4. client/service.go (+93 -106)
  5. pkg/metrics/metrics.go (+14 -0)
  6. pkg/msg/handler.go (+103 -0)
  7. pkg/transport/message.go (+2 -0)
  8. pkg/util/net/conn.go (+16 -0)
  9. pkg/util/wait/backoff.go (+197 -0)
  10. server/control.go (+141 -244)
  11. server/service.go (+12 -6)

+ 1 - 0
Release.md

@@ -1,3 +1,4 @@
 ### Fixes
 
 * frpc: Return code 1 when the first login attempt fails and exits.
+* When auth.method is `oidc` and auth.additionalScopes contains `HeartBeats`, the application became unresponsive if obtaining the AccessToken failed.

+ 8 - 1
client/admin_api.go

@@ -144,7 +144,14 @@ func (svr *Service) apiStatus(w http.ResponseWriter, _ *http.Request) {
 		_, _ = w.Write(buf)
 	}()
 
-	ps := svr.ctl.pm.GetAllProxyStatus()
+	svr.ctlMu.RLock()
+	ctl := svr.ctl
+	svr.ctlMu.RUnlock()
+	if ctl == nil {
+		return
+	}
+
+	ps := ctl.pm.GetAllProxyStatus()
 	for _, status := range ps {
 		res[status.Type] = append(res[status.Type], NewProxyStatusResp(status, svr.cfg.ServerAddr))
 	}
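
With reconnecting moved out of Run(), svr.ctl may be nil between a control exiting and the next successful login, so the admin API now snapshots the pointer under the read lock before using it. A minimal sketch of that pattern (hypothetical names, not part of this commit):

package example

import "sync"

type control struct{}

type service struct {
	mu  sync.RWMutex
	ctl *control // may be nil while reconnecting
}

// snapshot copies the pointer under the read lock so that later use of
// the control cannot race with a concurrent replacement, and a nil
// result can be handled gracefully instead of panicking.
func (s *service) snapshot() *control {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.ctl
}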

+ 103 - 172
client/control.go

@@ -16,13 +16,10 @@ package client
 
 import (
 	"context"
-	"io"
 	"net"
-	"runtime/debug"
+	"sync/atomic"
 	"time"
 
-	"github.com/fatedier/golib/control/shutdown"
-	"github.com/fatedier/golib/crypto"
 	"github.com/samber/lo"
 
 	"github.com/fatedier/frp/client/proxy"
@@ -31,6 +28,8 @@ import (
 	v1 "github.com/fatedier/frp/pkg/config/v1"
 	"github.com/fatedier/frp/pkg/msg"
 	"github.com/fatedier/frp/pkg/transport"
+	utilnet "github.com/fatedier/frp/pkg/util/net"
+	"github.com/fatedier/frp/pkg/util/wait"
 	"github.com/fatedier/frp/pkg/util/xlog"
 )
 
@@ -39,6 +38,12 @@ type Control struct {
 	ctx context.Context
 	xl  *xlog.Logger
 
+	// The client configuration
+	clientCfg *v1.ClientCommonConfig
+
+	// sets authentication based on selected method
+	authSetter auth.Setter
+
 	// Unique ID obtained from frps.
 	// It should be attached to the login message when reconnecting.
 	runID string
@@ -50,36 +55,25 @@ type Control struct {
 	// manage all visitors
 	vm *visitor.Manager
 
-	// control connection
+	// control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
 	conn net.Conn
 
+	// use cm to create new connections, which could be real TCP connections or virtual streams.
 	cm *ConnectionManager
 
-	// put a message in this channel to send it over control connection to server
-	sendCh chan (msg.Message)
-
-	// read from this channel to get the next message sent by server
-	readCh chan (msg.Message)
-
-	// goroutines can block by reading from this channel, it will be closed only in reader() when control connection is closed
-	closedCh chan struct{}
-
-	closedDoneCh chan struct{}
+	doneCh chan struct{}
 
-	// last time got the Pong message
-	lastPong time.Time
-
-	// The client configuration
-	clientCfg *v1.ClientCommonConfig
-
-	readerShutdown     *shutdown.Shutdown
-	writerShutdown     *shutdown.Shutdown
-	msgHandlerShutdown *shutdown.Shutdown
-
-	// sets authentication based on selected method
-	authSetter auth.Setter
+	// lastPong stores, as a time.Time, the last time a Pong message was received.
+	lastPong atomic.Value
 
+	// The role of msgTransporter is similar to HTTP/2.
+	// It allows multiple messages to be sent simultaneously over the same control connection.
+	// The server's response messages are dispatched to the corresponding waiting goroutines based on the laneKey and message type.
 	msgTransporter transport.MessageTransporter
+
+	// msgDispatcher is a wrapper around the control connection.
+	// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
+	msgDispatcher *msg.Dispatcher
 }
 
 func NewControl(
@@ -88,31 +82,34 @@ func NewControl(
 	pxyCfgs []v1.ProxyConfigurer,
 	visitorCfgs []v1.VisitorConfigurer,
 	authSetter auth.Setter,
-) *Control {
+) (*Control, error) {
 	// new xlog instance
 	ctl := &Control{
-		ctx:                ctx,
-		xl:                 xlog.FromContextSafe(ctx),
-		runID:              runID,
-		conn:               conn,
-		cm:                 cm,
-		pxyCfgs:            pxyCfgs,
-		sendCh:             make(chan msg.Message, 100),
-		readCh:             make(chan msg.Message, 100),
-		closedCh:           make(chan struct{}),
-		closedDoneCh:       make(chan struct{}),
-		clientCfg:          clientCfg,
-		readerShutdown:     shutdown.New(),
-		writerShutdown:     shutdown.New(),
-		msgHandlerShutdown: shutdown.New(),
-		authSetter:         authSetter,
+		ctx:        ctx,
+		xl:         xlog.FromContextSafe(ctx),
+		clientCfg:  clientCfg,
+		authSetter: authSetter,
+		runID:      runID,
+		pxyCfgs:    pxyCfgs,
+		conn:       conn,
+		cm:         cm,
+		doneCh:     make(chan struct{}),
+	}
+	ctl.lastPong.Store(time.Now())
+
+	cryptoRW, err := utilnet.NewCryptoReadWriter(conn, []byte(clientCfg.Auth.Token))
+	if err != nil {
+		return nil, err
 	}
-	ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh)
-	ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter)
 
+	ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
+	ctl.registerMsgHandlers()
+	ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
+
+	ctl.pm = proxy.NewManager(ctl.ctx, clientCfg, ctl.msgTransporter)
 	ctl.vm = visitor.NewManager(ctl.ctx, ctl.runID, ctl.clientCfg, ctl.connectServer, ctl.msgTransporter)
 	ctl.vm.Reload(visitorCfgs)
-	return ctl
+	return ctl, nil
 }
 
 func (ctl *Control) Run() {
@@ -125,7 +122,7 @@ func (ctl *Control) Run() {
 	go ctl.vm.Run()
 }
 
-func (ctl *Control) HandleReqWorkConn(_ *msg.ReqWorkConn) {
+func (ctl *Control) handleReqWorkConn(_ msg.Message) {
 	xl := ctl.xl
 	workConn, err := ctl.connectServer()
 	if err != nil {
@@ -162,8 +159,9 @@ func (ctl *Control) HandleReqWorkConn(_ *msg.ReqWorkConn) {
 	ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn, &startMsg)
 }
 
-func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) {
+func (ctl *Control) handleNewProxyResp(m msg.Message) {
 	xl := ctl.xl
+	inMsg := m.(*msg.NewProxyResp)
 	// Server will return NewProxyResp message to each NewProxy message.
 	// Start a new proxy handler if no error got
 	err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error)
@@ -174,8 +172,9 @@ func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) {
 	}
 }
 
-func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) {
+func (ctl *Control) handleNatHoleResp(m msg.Message) {
 	xl := ctl.xl
+	inMsg := m.(*msg.NatHoleResp)
 
 	// Dispatch the NatHoleResp message to the related proxy.
 	ok := ctl.msgTransporter.DispatchWithType(inMsg, msg.TypeNameNatHoleResp, inMsg.TransactionID)
@@ -184,6 +183,19 @@ func (ctl *Control) HandleNatHoleResp(inMsg *msg.NatHoleResp) {
 	}
 }
 
+func (ctl *Control) handlePong(m msg.Message) {
+	xl := ctl.xl
+	inMsg := m.(*msg.Pong)
+
+	if inMsg.Error != "" {
+		xl.Error("Pong message contains error: %s", inMsg.Error)
+		ctl.conn.Close()
+		return
+	}
+	ctl.lastPong.Store(time.Now())
+	xl.Debug("receive heartbeat from server")
+}
+
 func (ctl *Control) Close() error {
 	return ctl.GracefulClose(0)
 }
@@ -199,9 +211,9 @@ func (ctl *Control) GracefulClose(d time.Duration) error {
 	return nil
 }
 
-// ClosedDoneCh returns a channel that will be closed after all resources are released
-func (ctl *Control) ClosedDoneCh() <-chan struct{} {
-	return ctl.closedDoneCh
+// Done returns a channel that will be closed after all resources are released
+func (ctl *Control) Done() <-chan struct{} {
+	return ctl.doneCh
 }
 
 // connectServer return a new connection to frps
@@ -209,151 +221,70 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
 	return ctl.cm.Connect()
 }
 
-// reader read all messages from frps and send to readCh
-func (ctl *Control) reader() {
-	xl := ctl.xl
-	defer func() {
-		if err := recover(); err != nil {
-			xl.Error("panic error: %v", err)
-			xl.Error(string(debug.Stack()))
-		}
-	}()
-	defer ctl.readerShutdown.Done()
-	defer close(ctl.closedCh)
-
-	encReader := crypto.NewReader(ctl.conn, []byte(ctl.clientCfg.Auth.Token))
-	for {
-		m, err := msg.ReadMsg(encReader)
-		if err != nil {
-			if err == io.EOF {
-				xl.Debug("read from control connection EOF")
-				return
-			}
-			xl.Warn("read error: %v", err)
-			ctl.conn.Close()
-			return
-		}
-		ctl.readCh <- m
-	}
+func (ctl *Control) registerMsgHandlers() {
+	ctl.msgDispatcher.RegisterHandler(&msg.ReqWorkConn{}, msg.AsyncHandler(ctl.handleReqWorkConn))
+	ctl.msgDispatcher.RegisterHandler(&msg.NewProxyResp{}, ctl.handleNewProxyResp)
+	ctl.msgDispatcher.RegisterHandler(&msg.NatHoleResp{}, ctl.handleNatHoleResp)
+	ctl.msgDispatcher.RegisterHandler(&msg.Pong{}, ctl.handlePong)
 }
 
-// writer writes messages got from sendCh to frps
-func (ctl *Control) writer() {
+// heartbeatWorker sends heartbeats to the server and checks for heartbeat timeout.
+func (ctl *Control) heartbeatWorker() {
 	xl := ctl.xl
-	defer ctl.writerShutdown.Done()
-	encWriter, err := crypto.NewWriter(ctl.conn, []byte(ctl.clientCfg.Auth.Token))
-	if err != nil {
-		xl.Error("crypto new writer error: %v", err)
-		ctl.conn.Close()
-		return
-	}
-	for {
-		m, ok := <-ctl.sendCh
-		if !ok {
-			xl.Info("control writer is closing")
-			return
-		}
 
-		if err := msg.WriteMsg(encWriter, m); err != nil {
-			xl.Warn("write message to control connection error: %v", err)
-			return
-		}
-	}
-}
-
-// msgHandler handles all channel events and performs corresponding operations.
-func (ctl *Control) msgHandler() {
-	xl := ctl.xl
-	defer func() {
-		if err := recover(); err != nil {
-			xl.Error("panic error: %v", err)
-			xl.Error(string(debug.Stack()))
+	// TODO(fatedier): Change default value of HeartbeatInterval to -1 if tcpmux is enabled.
+	// Users can still enable heartbeat feature by setting HeartbeatInterval to a positive value.
+	if ctl.clientCfg.Transport.HeartbeatInterval > 0 {
+		// send heartbeat to server
+		sendHeartBeat := func() error {
+			xl.Debug("send heartbeat to server")
+			pingMsg := &msg.Ping{}
+			if err := ctl.authSetter.SetPing(pingMsg); err != nil {
+				xl.Warn("error during ping authentication: %v, skip sending ping message", err)
+				return err
+			}
+			_ = ctl.msgDispatcher.Send(pingMsg)
+			return nil
 		}
-	}()
-	defer ctl.msgHandlerShutdown.Done()
 
-	var hbSendCh <-chan time.Time
-	// TODO(fatedier): disable heartbeat if TCPMux is enabled.
-	// Just keep it here to keep compatible with old version frps.
-	if ctl.clientCfg.Transport.HeartbeatInterval > 0 {
-		hbSend := time.NewTicker(time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second)
-		defer hbSend.Stop()
-		hbSendCh = hbSend.C
+		go wait.BackoffUntil(sendHeartBeat,
+			wait.NewFastBackoffManager(wait.FastBackoffOptions{
+				Duration:           time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second,
+				InitDurationIfFail: time.Second,
+				Factor:             2.0,
+				Jitter:             0.1,
+				MaxDuration:        time.Duration(ctl.clientCfg.Transport.HeartbeatInterval) * time.Second,
+			}),
+			true, ctl.doneCh,
+		)
 	}
 
-	var hbCheckCh <-chan time.Time
 	// Check heartbeat timeout only if TCPMux is not enabled and users don't disable heartbeat feature.
 	if ctl.clientCfg.Transport.HeartbeatInterval > 0 && ctl.clientCfg.Transport.HeartbeatTimeout > 0 &&
 		!lo.FromPtr(ctl.clientCfg.Transport.TCPMux) {
-		hbCheck := time.NewTicker(time.Second)
-		defer hbCheck.Stop()
-		hbCheckCh = hbCheck.C
-	}
 
-	ctl.lastPong = time.Now()
-	for {
-		select {
-		case <-hbSendCh:
-			// send heartbeat to server
-			xl.Debug("send heartbeat to server")
-			pingMsg := &msg.Ping{}
-			if err := ctl.authSetter.SetPing(pingMsg); err != nil {
-				xl.Warn("error during ping authentication: %v. skip sending ping message", err)
-				continue
-			}
-			ctl.sendCh <- pingMsg
-		case <-hbCheckCh:
-			if time.Since(ctl.lastPong) > time.Duration(ctl.clientCfg.Transport.HeartbeatTimeout)*time.Second {
+		go wait.Until(func() {
+			if time.Since(ctl.lastPong.Load().(time.Time)) > time.Duration(ctl.clientCfg.Transport.HeartbeatTimeout)*time.Second {
 				xl.Warn("heartbeat timeout")
-				// let reader() stop
 				ctl.conn.Close()
 				return
 			}
-		case rawMsg, ok := <-ctl.readCh:
-			if !ok {
-				return
-			}
-
-			switch m := rawMsg.(type) {
-			case *msg.ReqWorkConn:
-				go ctl.HandleReqWorkConn(m)
-			case *msg.NewProxyResp:
-				ctl.HandleNewProxyResp(m)
-			case *msg.NatHoleResp:
-				ctl.HandleNatHoleResp(m)
-			case *msg.Pong:
-				if m.Error != "" {
-					xl.Error("Pong contains error: %s", m.Error)
-					ctl.conn.Close()
-					return
-				}
-				ctl.lastPong = time.Now()
-				xl.Debug("receive heartbeat from server")
-			}
-		}
+		}, time.Second, ctl.doneCh)
 	}
 }
 
-// If controler is notified by closedCh, reader and writer and handler will exit
 func (ctl *Control) worker() {
-	go ctl.msgHandler()
-	go ctl.reader()
-	go ctl.writer()
+	go ctl.heartbeatWorker()
+	go ctl.msgDispatcher.Run()
 
-	<-ctl.closedCh
-	// close related channels and wait until other goroutines done
-	close(ctl.readCh)
-	ctl.readerShutdown.WaitDone()
-	ctl.msgHandlerShutdown.WaitDone()
-
-	close(ctl.sendCh)
-	ctl.writerShutdown.WaitDone()
+	<-ctl.msgDispatcher.Done()
+	ctl.conn.Close()
 
 	ctl.pm.Close()
 	ctl.vm.Close()
-
-	close(ctl.closedDoneCh)
 	ctl.cm.Close()
+
+	close(ctl.doneCh)
 }
 
 func (ctl *Control) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error {
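
In this file, lastPong changes from a plain time.Time to an atomic.Value because the Pong handler (running on the dispatcher's read loop) and the heartbeat checker now live in different goroutines. A self-contained sketch of that technique, with hypothetical names:

package example

import (
	"sync/atomic"
	"time"
)

// liveness mirrors the lastPong pattern: an atomic.Value holding a
// time.Time lets one goroutine record pongs while another checks for
// timeouts, without sharing a mutex.
type liveness struct{ last atomic.Value }

func newLiveness() *liveness {
	l := &liveness{}
	l.last.Store(time.Now())
	return l
}

// touch records that a Pong was just received.
func (l *liveness) touch() { l.last.Store(time.Now()) }

// expired reports whether no Pong arrived within timeout.
func (l *liveness) expired(timeout time.Duration) bool {
	return time.Since(l.last.Load().(time.Time)) > timeout
}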

+ 93 - 106
client/service.go

@@ -17,6 +17,7 @@ package client
 import (
 	"context"
 	"crypto/tls"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -24,7 +25,6 @@ import (
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/fatedier/golib/crypto"
@@ -40,8 +40,8 @@ import (
 	"github.com/fatedier/frp/pkg/transport"
 	"github.com/fatedier/frp/pkg/util/log"
 	utilnet "github.com/fatedier/frp/pkg/util/net"
-	"github.com/fatedier/frp/pkg/util/util"
 	"github.com/fatedier/frp/pkg/util/version"
+	"github.com/fatedier/frp/pkg/util/wait"
 	"github.com/fatedier/frp/pkg/util/xlog"
 )
 
@@ -70,12 +70,11 @@ type Service struct {
 	// string if no configuration file was used.
 	cfgFile string
 
-	exit uint32 // 0 means not exit
-
 	// service context
 	ctx context.Context
 	// call cancel to stop service
-	cancel context.CancelFunc
+	cancel           context.CancelFunc
+	gracefulDuration time.Duration
 }
 
 func NewService(
@@ -91,7 +90,6 @@ func NewService(
 		pxyCfgs:     pxyCfgs,
 		visitorCfgs: visitorCfgs,
 		ctx:         context.Background(),
-		exit:        0,
 	}
 }
 
@@ -106,8 +104,6 @@ func (svr *Service) Run(ctx context.Context) error {
 	svr.ctx = xlog.NewContext(ctx, xlog.New())
 	svr.cancel = cancel
 
-	xl := xlog.FromContextSafe(svr.ctx)
-
 	// set custom DNSServer
 	if svr.cfg.DNSServer != "" {
 		dnsAddr := svr.cfg.DNSServer
@@ -124,26 +120,9 @@ func (svr *Service) Run(ctx context.Context) error {
 	}
 
 	// login to frps
-	for {
-		conn, cm, err := svr.login()
-		if err != nil {
-			xl.Warn("login to server failed: %v", err)
-
-			// if login_fail_exit is true, just exit this program
-			// otherwise sleep a while and try again to connect to server
-			if lo.FromPtr(svr.cfg.LoginFailExit) {
-				return err
-			}
-			util.RandomSleep(5*time.Second, 0.9, 1.1)
-		} else {
-			// login success
-			ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
-			ctl.Run()
-			svr.ctlMu.Lock()
-			svr.ctl = ctl
-			svr.ctlMu.Unlock()
-			break
-		}
+	svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.cfg.LoginFailExit))
+	if svr.ctl == nil {
+		return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled")
 	}
 
 	go svr.keepControllerWorking()
@@ -160,80 +139,35 @@ func (svr *Service) Run(ctx context.Context) error {
 		log.Info("admin server listen on %s:%d", svr.cfg.WebServer.Addr, svr.cfg.WebServer.Port)
 	}
 	<-svr.ctx.Done()
-	// service context may not be canceled by svr.Close(), we should call it here to release resources
-	if atomic.LoadUint32(&svr.exit) == 0 {
-		svr.Close()
-	}
+	svr.stop()
 	return nil
 }
 
 func (svr *Service) keepControllerWorking() {
-	xl := xlog.FromContextSafe(svr.ctx)
-	maxDelayTime := 20 * time.Second
-	delayTime := time.Second
-
-	// if frpc reconnect frps, we need to limit retry times in 1min
-	// current retry logic is sleep 0s, 0s, 0s, 1s, 2s, 4s, 8s, ...
-	// when exceed 1min, we will reset delay and counts
-	cutoffTime := time.Now().Add(time.Minute)
-	reconnectDelay := time.Second
-	reconnectCounts := 1
-
-	for {
-		<-svr.ctl.ClosedDoneCh()
-		if atomic.LoadUint32(&svr.exit) != 0 {
-			return
-		}
-
-		// the first three attempts with a low delay
-		if reconnectCounts > 3 {
-			util.RandomSleep(reconnectDelay, 0.9, 1.1)
-			xl.Info("wait %v to reconnect", reconnectDelay)
-			reconnectDelay *= 2
-		} else {
-			util.RandomSleep(time.Second, 0, 0.5)
-		}
-		reconnectCounts++
-
-		now := time.Now()
-		if now.After(cutoffTime) {
-			// reset
-			cutoffTime = now.Add(time.Minute)
-			reconnectDelay = time.Second
-			reconnectCounts = 1
-		}
-
-		for {
-			if atomic.LoadUint32(&svr.exit) != 0 {
-				return
-			}
-
-			xl.Info("try to reconnect to server...")
-			conn, cm, err := svr.login()
-			if err != nil {
-				xl.Warn("reconnect to server error: %v, wait %v for another retry", err, delayTime)
-				util.RandomSleep(delayTime, 0.9, 1.1)
-
-				delayTime *= 2
-				if delayTime > maxDelayTime {
-					delayTime = maxDelayTime
-				}
-				continue
-			}
-			// reconnect success, init delayTime
-			delayTime = time.Second
-
-			ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
-			ctl.Run()
-			svr.ctlMu.Lock()
-			if svr.ctl != nil {
-				svr.ctl.Close()
-			}
-			svr.ctl = ctl
-			svr.ctlMu.Unlock()
-			break
-		}
-	}
+	<-svr.ctl.Done()
+
+	// There is a situation where the login is successful but due to certain reasons,
+	// the control immediately exits. It is necessary to limit the frequency of reconnection in this case.
+	// The interval for the first three retries in 1 minute will be very short, and then it will increase exponentially.
+	// The maximum interval is 20 seconds.
+	wait.BackoffUntil(func() error {
+		// loopLoginUntilSuccess is another layer of loop that will continuously attempt to
+		// login to the server until successful.
+		svr.loopLoginUntilSuccess(20*time.Second, false)
+		<-svr.ctl.Done()
+		return errors.New("control is closed and try another loop")
+	}, wait.NewFastBackoffManager(
+		wait.FastBackoffOptions{
+			Duration:        time.Second,
+			Factor:          2,
+			Jitter:          0.1,
+			MaxDuration:     20 * time.Second,
+			FastRetryCount:  3,
+			FastRetryDelay:  200 * time.Millisecond,
+			FastRetryWindow: time.Minute,
+			FastRetryJitter: 0.5,
+		},
+	), true, svr.ctx.Done())
 }
 
 // login creates a connection to frps and registers itself as a client
@@ -299,6 +233,54 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
 	return
 }
 
+func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
+	xl := xlog.FromContextSafe(svr.ctx)
+	successCh := make(chan struct{})
+
+	loginFunc := func() error {
+		xl.Info("try to connect to server...")
+		conn, cm, err := svr.login()
+		if err != nil {
+			xl.Warn("connect to server error: %v", err)
+			if firstLoginExit {
+				svr.cancel()
+			}
+			return err
+		}
+
+		ctl, err := NewControl(svr.ctx, svr.runID, conn, cm,
+			svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
+		if err != nil {
+			conn.Close()
+			xl.Error("NewControl error: %v", err)
+			return err
+		}
+
+		ctl.Run()
+		// close and replace previous control
+		svr.ctlMu.Lock()
+		if svr.ctl != nil {
+			svr.ctl.Close()
+		}
+		svr.ctl = ctl
+		svr.ctlMu.Unlock()
+
+		close(successCh)
+		return nil
+	}
+
+	// try to reconnect to server until success
+	wait.BackoffUntil(loginFunc, wait.NewFastBackoffManager(
+		wait.FastBackoffOptions{
+			Duration:    time.Second,
+			Factor:      2,
+			Jitter:      0.1,
+			MaxDuration: maxInterval,
+		}),
+		true,
+		wait.MergeAndCloseOnAnyStopChannel(svr.ctx.Done(), successCh))
+}
+
 func (svr *Service) ReloadConf(pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error {
 	svr.cfgMu.Lock()
 	svr.pxyCfgs = pxyCfgs
@@ -320,20 +302,20 @@ func (svr *Service) Close() {
 }
 
 func (svr *Service) GracefulClose(d time.Duration) {
-	atomic.StoreUint32(&svr.exit, 1)
+	svr.gracefulDuration = d
+	svr.cancel()
+}
 
-	svr.ctlMu.RLock()
+func (svr *Service) stop() {
+	svr.ctlMu.Lock()
+	defer svr.ctlMu.Unlock()
 	if svr.ctl != nil {
-		svr.ctl.GracefulClose(d)
+		svr.ctl.GracefulClose(svr.gracefulDuration)
 		svr.ctl = nil
 	}
-	svr.ctlMu.RUnlock()
-
-	if svr.cancel != nil {
-		svr.cancel()
-	}
 }
 
+// ConnectionManager is a wrapper for establishing connections to the server.
 type ConnectionManager struct {
 	ctx context.Context
 	cfg *v1.ClientCommonConfig
@@ -349,6 +331,10 @@ func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *Conn
 	}
 }
 
+// OpenConnection opens an underlying connection to the server.
+// The underlying connection is either a TCP connection or a QUIC connection.
+// After the underlying connection is established, you can call Connect() to get a stream.
+// If TCPMux isn't enabled, the underlying connection is nil, and you will get a new real TCP connection every time you call Connect().
 func (cm *ConnectionManager) OpenConnection() error {
 	xl := xlog.FromContextSafe(cm.ctx)
 
@@ -411,6 +397,7 @@ func (cm *ConnectionManager) OpenConnection() error {
 	return nil
 }
 
+// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
 func (cm *ConnectionManager) Connect() (net.Conn, error) {
 	if cm.quicConn != nil {
 		stream, err := cm.quicConn.OpenStreamSync(context.Background())
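
Per the new doc comments, a caller opens the underlying connection once and then obtains per-use streams from it. A rough usage sketch, assuming cfg is an already-populated v1.ClientCommonConfig:

package example

import (
	"context"

	"github.com/fatedier/frp/client"
	v1 "github.com/fatedier/frp/pkg/config/v1"
)

func dial(ctx context.Context, cfg *v1.ClientCommonConfig) error {
	cm := client.NewConnectionManager(ctx, cfg)
	// One underlying TCP or QUIC connection (or none, if TCPMux is off).
	if err := cm.OpenConnection(); err != nil {
		return err
	}
	defer cm.Close()

	// Each Connect yields a stream over the underlying connection,
	// or a fresh TCP connection when TCPMux is disabled.
	conn, err := cm.Connect()
	if err != nil {
		return err
	}
	defer conn.Close()
	return nil
}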

+ 14 - 0
pkg/metrics/metrics.go

@@ -1,3 +1,17 @@
+// Copyright 2023 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 package metrics
 
 import (

+ 103 - 0
pkg/msg/handler.go

@@ -0,0 +1,103 @@
+// Copyright 2023 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package msg
+
+import (
+	"io"
+	"reflect"
+)
+
+func AsyncHandler(f func(Message)) func(Message) {
+	return func(m Message) {
+		go f(m)
+	}
+}
+
+// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
+type Dispatcher struct {
+	rw io.ReadWriter
+
+	sendCh         chan Message
+	doneCh         chan struct{}
+	msgHandlers    map[reflect.Type]func(Message)
+	defaultHandler func(Message)
+}
+
+func NewDispatcher(rw io.ReadWriter) *Dispatcher {
+	return &Dispatcher{
+		rw:          rw,
+		sendCh:      make(chan Message, 100),
+		doneCh:      make(chan struct{}),
+		msgHandlers: make(map[reflect.Type]func(Message)),
+	}
+}
+
+// Run starts the send and read loops; it returns immediately.
+// The Done channel is closed once the read loop exits on io.EOF or another error.
+func (d *Dispatcher) Run() {
+	go d.sendLoop()
+	go d.readLoop()
+}
+
+func (d *Dispatcher) sendLoop() {
+	for {
+		select {
+		case <-d.doneCh:
+			return
+		case m := <-d.sendCh:
+			_ = WriteMsg(d.rw, m)
+		}
+	}
+}
+
+func (d *Dispatcher) readLoop() {
+	for {
+		m, err := ReadMsg(d.rw)
+		if err != nil {
+			close(d.doneCh)
+			return
+		}
+
+		if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
+			handler(m)
+		} else if d.defaultHandler != nil {
+			d.defaultHandler(m)
+		}
+	}
+}
+
+func (d *Dispatcher) Send(m Message) error {
+	select {
+	case <-d.doneCh:
+		return io.EOF
+	case d.sendCh <- m:
+		return nil
+	}
+}
+
+func (d *Dispatcher) SendChannel() chan Message {
+	return d.sendCh
+}
+
+func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) {
+	d.msgHandlers[reflect.TypeOf(msg)] = handler
+}
+
+func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) {
+	d.defaultHandler = handler
+}
+
+func (d *Dispatcher) Done() chan struct{} {
+	return d.doneCh
+}
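
A minimal usage sketch for the new Dispatcher (illustrative, not part of this commit): handlers are keyed by the concrete message type, and AsyncHandler moves slow handlers off the read loop.

package example

import (
	"net"

	"github.com/fatedier/frp/pkg/msg"
)

func run(conn net.Conn) {
	d := msg.NewDispatcher(conn)
	// Synchronous handler: runs on the read loop, so it must be fast.
	d.RegisterHandler(&msg.Pong{}, func(m msg.Message) {
		_ = m.(*msg.Pong) // e.g. record liveness here
	})
	// Async handler: each message is handled in its own goroutine.
	d.RegisterHandler(&msg.ReqWorkConn{}, msg.AsyncHandler(func(m msg.Message) {
		// e.g. dial a new work connection here
	}))
	d.Run()

	_ = d.Send(&msg.Ping{}) // queued on sendCh, written by sendLoop
	<-d.Done()              // closed once the read loop hits EOF or an error
}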

+ 2 - 0
pkg/transport/message.go

@@ -29,7 +29,9 @@ type MessageTransporter interface {
 	// Recv(ctx context.Context, laneKey string, msgType string) (Message, error)
 	// Do will first send msg, then recv msg with the same laneKey and specified msgType.
 	Do(ctx context.Context, req msg.Message, laneKey, recvMsgType string) (msg.Message, error)
+	// Dispatch will dispatch the message to the related channel registered in the Do function by its laneKey and message type.
 	Dispatch(m msg.Message, laneKey string) bool
+	// Same as Dispatch, but with a specified message type.
 	DispatchWithType(m msg.Message, msgType, laneKey string) bool
 }
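
For example, the client's handleNatHoleResp above feeds DispatchWithType with the message's TransactionID as the laneKey, which unblocks a caller waiting in Do. A hedged sketch of that round trip (assuming msg.NatHoleClient carries the TransactionID used as the lane key):

package example

import (
	"context"

	"github.com/fatedier/frp/pkg/msg"
	"github.com/fatedier/frp/pkg/transport"
)

func natHoleRequest(ctx context.Context, t transport.MessageTransporter, tid string) (*msg.NatHoleResp, error) {
	// Do sends the request, then blocks until a NatHoleResp with the
	// same laneKey (the transaction ID) is dispatched back to this lane.
	raw, err := t.Do(ctx, &msg.NatHoleClient{TransactionID: tid}, tid, msg.TypeNameNatHoleResp)
	if err != nil {
		return nil, err
	}
	return raw.(*msg.NatHoleResp), nil
}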
 

+ 16 - 0
pkg/util/net/conn.go

@@ -22,6 +22,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/fatedier/golib/crypto"
 	quic "github.com/quic-go/quic-go"
 
 	"github.com/fatedier/frp/pkg/util/xlog"
@@ -216,3 +217,18 @@ func (conn *wrapQuicStream) Close() error {
 	conn.Stream.CancelRead(0)
 	return conn.Stream.Close()
 }
+
+func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
+	encReader := crypto.NewReader(rw, key)
+	encWriter, err := crypto.NewWriter(rw, key)
+	if err != nil {
+		return nil, err
+	}
+	return struct {
+		io.Reader
+		io.Writer
+	}{
+		Reader: encReader,
+		Writer: encWriter,
+	}, nil
+}
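
NewCryptoReadWriter pairs golib's crypto reader and writer over a single connection; both ends must derive the stream from the same token, as client/control.go and server/control.go do before handing the connection to msg.NewDispatcher. A small round-trip sketch over an in-memory pipe (illustrative only):

package example

import (
	"net"

	"github.com/fatedier/frp/pkg/msg"
	utilnet "github.com/fatedier/frp/pkg/util/net"
)

func roundTrip(token string) (msg.Message, error) {
	a, b := net.Pipe()
	defer a.Close()
	defer b.Close()

	errCh := make(chan error, 1)
	go func() {
		left, err := utilnet.NewCryptoReadWriter(a, []byte(token))
		if err != nil {
			errCh <- err
			return
		}
		errCh <- msg.WriteMsg(left, &msg.Ping{})
	}()

	right, err := utilnet.NewCryptoReadWriter(b, []byte(token))
	if err != nil {
		return nil, err
	}
	m, err := msg.ReadMsg(right) // *msg.Ping on success
	if err != nil {
		return nil, err
	}
	return m, <-errCh
}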

+ 197 - 0
pkg/util/wait/backoff.go

@@ -0,0 +1,197 @@
+// Copyright 2023 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wait
+
+import (
+	"math/rand"
+	"time"
+
+	"github.com/samber/lo"
+
+	"github.com/fatedier/frp/pkg/util/util"
+)
+
+type BackoffFunc func(previousDuration time.Duration, previousConditionError bool) time.Duration
+
+func (f BackoffFunc) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
+	return f(previousDuration, previousConditionError)
+}
+
+type BackoffManager interface {
+	Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration
+}
+
+type FastBackoffOptions struct {
+	Duration           time.Duration
+	Factor             float64
+	Jitter             float64
+	MaxDuration        time.Duration
+	InitDurationIfFail time.Duration
+
+	// If FastRetryCount > 0, then within the FastRetryWindow time window,
+	// the retry will be performed with a delay of FastRetryDelay for the first FastRetryCount calls.
+	FastRetryCount  int
+	FastRetryDelay  time.Duration
+	FastRetryJitter float64
+	FastRetryWindow time.Duration
+}
+
+type fastBackoffImpl struct {
+	options FastBackoffOptions
+
+	lastCalledTime      time.Time
+	consecutiveErrCount int
+
+	fastRetryCutoffTime     time.Time
+	countsInFastRetryWindow int
+}
+
+func NewFastBackoffManager(options FastBackoffOptions) BackoffManager {
+	return &fastBackoffImpl{
+		options:                 options,
+		countsInFastRetryWindow: 1,
+	}
+}
+
+func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
+	if f.lastCalledTime.IsZero() {
+		f.lastCalledTime = time.Now()
+		return f.options.Duration
+	}
+	now := time.Now()
+	f.lastCalledTime = now
+
+	if previousConditionError {
+		f.consecutiveErrCount++
+	} else {
+		f.consecutiveErrCount = 0
+	}
+
+	if f.options.FastRetryCount > 0 && previousConditionError {
+		f.countsInFastRetryWindow++
+		if f.countsInFastRetryWindow <= f.options.FastRetryCount {
+			return Jitter(f.options.FastRetryDelay, f.options.FastRetryJitter)
+		}
+		if now.After(f.fastRetryCutoffTime) {
+			// reset
+			f.fastRetryCutoffTime = now.Add(f.options.FastRetryWindow)
+			f.countsInFastRetryWindow = 0
+		}
+	}
+
+	if previousConditionError {
+		var duration time.Duration
+		if f.consecutiveErrCount == 1 {
+			duration = util.EmptyOr(f.options.InitDurationIfFail, previousDuration)
+		} else {
+			duration = previousDuration
+		}
+
+		duration = util.EmptyOr(duration, time.Second)
+		if f.options.Factor != 0 {
+			duration = time.Duration(float64(duration) * f.options.Factor)
+		}
+		if f.options.Jitter > 0 {
+			duration = Jitter(duration, f.options.Jitter)
+		}
+		if f.options.MaxDuration > 0 && duration > f.options.MaxDuration {
+			duration = f.options.MaxDuration
+		}
+		return duration
+	}
+	return f.options.Duration
+}
+
+func BackoffUntil(f func() error, backoff BackoffManager, sliding bool, stopCh <-chan struct{}) {
+	var delay time.Duration
+	previousError := false
+
+	ticker := time.NewTicker(backoff.Backoff(delay, previousError))
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-stopCh:
+			return
+		default:
+		}
+
+		if !sliding {
+			delay = backoff.Backoff(delay, previousError)
+		}
+
+		if err := f(); err != nil {
+			previousError = true
+		} else {
+			previousError = false
+		}
+
+		if sliding {
+			delay = backoff.Backoff(delay, previousError)
+		}
+
+		ticker.Reset(delay)
+		select {
+		case <-stopCh:
+			return
+		default:
+		}
+
+		select {
+		case <-stopCh:
+			return
+		case <-ticker.C:
+		}
+	}
+}
+
+// Jitter returns a time.Duration between duration and duration + maxFactor *
+// duration.
+//
+// This allows clients to avoid converging on periodic behavior. If maxFactor
+// is 0.0, a suggested default value will be chosen.
+func Jitter(duration time.Duration, maxFactor float64) time.Duration {
+	if maxFactor <= 0.0 {
+		maxFactor = 1.0
+	}
+	wait := duration + time.Duration(rand.Float64()*maxFactor*float64(duration))
+	return wait
+}
+
+func Until(f func(), period time.Duration, stopCh <-chan struct{}) {
+	ff := func() error {
+		f()
+		return nil
+	}
+	BackoffUntil(ff, BackoffFunc(func(time.Duration, bool) time.Duration {
+		return period
+	}), true, stopCh)
+}
+
+func MergeAndCloseOnAnyStopChannel[T any](upstreams ...<-chan T) <-chan T {
+	out := make(chan T)
+
+	for _, upstream := range upstreams {
+		ch := upstream
+		go lo.Try0(func() {
+			select {
+			case <-ch:
+				close(out)
+			case <-out:
+			}
+		})
+	}
+	return out
+}
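
A usage sketch for the new wait helpers, mirroring the shape of loopLoginUntilSuccess (connect is a hypothetical stand-in): attempts back off exponentially up to MaxDuration, and the loop stops as soon as either channel closes.

package example

import (
	"time"

	"github.com/fatedier/frp/pkg/util/wait"
)

func retryUntil(connect func() error, stopCh <-chan struct{}) {
	successCh := make(chan struct{})
	wait.BackoffUntil(func() error {
		if err := connect(); err != nil {
			return err // previousConditionError=true: the delay grows by Factor
		}
		close(successCh) // stops the loop via the merged stop channel
		return nil
	}, wait.NewFastBackoffManager(wait.FastBackoffOptions{
		Duration:    time.Second,
		Factor:      2,
		Jitter:      0.1,
		MaxDuration: 20 * time.Second,
	}), true, wait.MergeAndCloseOnAnyStopChannel(stopCh, successCh))
}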

+ 141 - 244
server/control.go

@@ -17,15 +17,12 @@ package server
 import (
 	"context"
 	"fmt"
-	"io"
 	"net"
 	"runtime/debug"
 	"sync"
+	"sync/atomic"
 	"time"
 
-	"github.com/fatedier/golib/control/shutdown"
-	"github.com/fatedier/golib/crypto"
-	"github.com/fatedier/golib/errors"
 	"github.com/samber/lo"
 
 	"github.com/fatedier/frp/pkg/auth"
@@ -35,8 +32,10 @@ import (
 	"github.com/fatedier/frp/pkg/msg"
 	plugin "github.com/fatedier/frp/pkg/plugin/server"
 	"github.com/fatedier/frp/pkg/transport"
+	utilnet "github.com/fatedier/frp/pkg/util/net"
 	"github.com/fatedier/frp/pkg/util/util"
 	"github.com/fatedier/frp/pkg/util/version"
+	"github.com/fatedier/frp/pkg/util/wait"
 	"github.com/fatedier/frp/pkg/util/xlog"
 	"github.com/fatedier/frp/server/controller"
 	"github.com/fatedier/frp/server/metrics"
@@ -111,18 +110,16 @@ type Control struct {
 	// other components can use this to communicate with client
 	msgTransporter transport.MessageTransporter
 
+	// msgDispatcher is a wrapper around the control connection.
+	// It provides a channel for sending messages, and you can register handlers to process messages based on their respective types.
+	msgDispatcher *msg.Dispatcher
+
 	// login message
 	loginMsg *msg.Login
 
 	// control connection
 	conn net.Conn
 
-	// put a message in this channel to send it over control connection to client
-	sendCh chan (msg.Message)
-
-	// read from this channel to get the next message sent by client
-	readCh chan (msg.Message)
-
 	// work connections
 	workConnCh chan net.Conn
 
@@ -136,27 +133,21 @@ type Control struct {
 	portsUsedNum int
 
 	// last time got the Ping message
-	lastPing time.Time
+	lastPing atomic.Value
 
 	// A new run id will be generated when a new client login.
 	// If run id got from login message has same run id, it means it's the same client, so we can
 	// replace old controller instantly.
 	runID string
 
-	readerShutdown  *shutdown.Shutdown
-	writerShutdown  *shutdown.Shutdown
-	managerShutdown *shutdown.Shutdown
-	allShutdown     *shutdown.Shutdown
-
-	started bool
-
 	mu sync.RWMutex
 
 	// Server configuration information
 	serverCfg *v1.ServerConfig
 
-	xl  *xlog.Logger
-	ctx context.Context
+	xl     *xlog.Logger
+	ctx    context.Context
+	doneCh chan struct{}
 }
 
 func NewControl(
@@ -168,36 +159,38 @@ func NewControl(
 	ctlConn net.Conn,
 	loginMsg *msg.Login,
 	serverCfg *v1.ServerConfig,
-) *Control {
+) (*Control, error) {
 	poolCount := loginMsg.PoolCount
 	if poolCount > int(serverCfg.Transport.MaxPoolCount) {
 		poolCount = int(serverCfg.Transport.MaxPoolCount)
 	}
 	ctl := &Control{
-		rc:              rc,
-		pxyManager:      pxyManager,
-		pluginManager:   pluginManager,
-		authVerifier:    authVerifier,
-		conn:            ctlConn,
-		loginMsg:        loginMsg,
-		sendCh:          make(chan msg.Message, 10),
-		readCh:          make(chan msg.Message, 10),
-		workConnCh:      make(chan net.Conn, poolCount+10),
-		proxies:         make(map[string]proxy.Proxy),
-		poolCount:       poolCount,
-		portsUsedNum:    0,
-		lastPing:        time.Now(),
-		runID:           loginMsg.RunID,
-		readerShutdown:  shutdown.New(),
-		writerShutdown:  shutdown.New(),
-		managerShutdown: shutdown.New(),
-		allShutdown:     shutdown.New(),
-		serverCfg:       serverCfg,
-		xl:              xlog.FromContextSafe(ctx),
-		ctx:             ctx,
+		rc:            rc,
+		pxyManager:    pxyManager,
+		pluginManager: pluginManager,
+		authVerifier:  authVerifier,
+		conn:          ctlConn,
+		loginMsg:      loginMsg,
+		workConnCh:    make(chan net.Conn, poolCount+10),
+		proxies:       make(map[string]proxy.Proxy),
+		poolCount:     poolCount,
+		portsUsedNum:  0,
+		runID:         loginMsg.RunID,
+		serverCfg:     serverCfg,
+		xl:            xlog.FromContextSafe(ctx),
+		ctx:           ctx,
+		doneCh:        make(chan struct{}),
+	}
+	ctl.lastPing.Store(time.Now())
+
+	cryptoRW, err := utilnet.NewCryptoReadWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token))
+	if err != nil {
+		return nil, err
 	}
-	ctl.msgTransporter = transport.NewMessageTransporter(ctl.sendCh)
-	return ctl
+	ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
+	ctl.registerMsgHandlers()
+	ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher.SendChannel())
+	return ctl, nil
 }
 
 // Start send a login success message to client and start working.
@@ -208,27 +201,18 @@ func (ctl *Control) Start() {
 		Error:   "",
 	}
 	_ = msg.WriteMsg(ctl.conn, loginRespMsg)
-	ctl.mu.Lock()
-	ctl.started = true
-	ctl.mu.Unlock()
 
-	go ctl.writer()
 	go func() {
 		for i := 0; i < ctl.poolCount; i++ {
 			// ignore error here, that means that this control is closed
-			_ = errors.PanicToError(func() {
-				ctl.sendCh <- &msg.ReqWorkConn{}
-			})
+			_ = ctl.msgDispatcher.Send(&msg.ReqWorkConn{})
 		}
 	}()
-
-	go ctl.manager()
-	go ctl.reader()
-	go ctl.stoper()
+	go ctl.worker()
 }
 
 func (ctl *Control) Close() error {
-	ctl.allShutdown.Start()
+	ctl.conn.Close()
 	return nil
 }
 
@@ -236,7 +220,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
 	xl := ctl.xl
 	xl.Info("Replaced by client [%s]", newCtl.runID)
 	ctl.runID = ""
-	ctl.allShutdown.Start()
+	ctl.conn.Close()
 }
 
 func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
@@ -282,9 +266,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
 		xl.Debug("get work connection from pool")
 	default:
 		// no work connections available in the pool, send message to frpc to get more
-		if err = errors.PanicToError(func() {
-			ctl.sendCh <- &msg.ReqWorkConn{}
-		}); err != nil {
+		if err := ctl.msgDispatcher.Send(&msg.ReqWorkConn{}); err != nil {
 			return nil, fmt.Errorf("control is already closed")
 		}
 
@@ -304,92 +286,39 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
 	}
 
 	// When we get a work connection from pool, replace it with a new one.
-	_ = errors.PanicToError(func() {
-		ctl.sendCh <- &msg.ReqWorkConn{}
-	})
+	_ = ctl.msgDispatcher.Send(&msg.ReqWorkConn{})
 	return
 }
 
-func (ctl *Control) writer() {
+func (ctl *Control) heartbeatWorker() {
 	xl := ctl.xl
-	defer func() {
-		if err := recover(); err != nil {
-			xl.Error("panic error: %v", err)
-			xl.Error(string(debug.Stack()))
-		}
-	}()
-
-	defer ctl.allShutdown.Start()
-	defer ctl.writerShutdown.Done()
-
-	encWriter, err := crypto.NewWriter(ctl.conn, []byte(ctl.serverCfg.Auth.Token))
-	if err != nil {
-		xl.Error("crypto new writer error: %v", err)
-		ctl.allShutdown.Start()
-		return
-	}
-	for {
-		m, ok := <-ctl.sendCh
-		if !ok {
-			xl.Info("control writer is closing")
-			return
-		}
 
-		if err := msg.WriteMsg(encWriter, m); err != nil {
-			xl.Warn("write message to control connection error: %v", err)
-			return
-		}
-	}
-}
-
-func (ctl *Control) reader() {
-	xl := ctl.xl
-	defer func() {
-		if err := recover(); err != nil {
-			xl.Error("panic error: %v", err)
-			xl.Error(string(debug.Stack()))
-		}
-	}()
-
-	defer ctl.allShutdown.Start()
-	defer ctl.readerShutdown.Done()
-
-	encReader := crypto.NewReader(ctl.conn, []byte(ctl.serverCfg.Auth.Token))
-	for {
-		m, err := msg.ReadMsg(encReader)
-		if err != nil {
-			if err == io.EOF {
-				xl.Debug("control connection closed")
+	// Application-level heartbeat is not needed if TCPMux is enabled,
+	// since yamux will do the same thing.
+	// TODO(fatedier): change the default HeartbeatTimeout to -1 if TCPMux is enabled. Users can still set it to a positive value to enable it.
+	if !lo.FromPtr(ctl.serverCfg.Transport.TCPMux) && ctl.serverCfg.Transport.HeartbeatTimeout > 0 {
+		go wait.Until(func() {
+			if time.Since(ctl.lastPing.Load().(time.Time)) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
+				xl.Warn("heartbeat timeout")
 				return
 			}
-			xl.Warn("read error: %v", err)
-			ctl.conn.Close()
-			return
-		}
-
-		ctl.readCh <- m
+		}, time.Second, ctl.doneCh)
 	}
 }
 
-func (ctl *Control) stoper() {
+// WaitClosed blocks until the Control is closed.
+func (ctl *Control) WaitClosed() {
+	<-ctl.doneCh
+}
+
+func (ctl *Control) worker() {
 	xl := ctl.xl
-	defer func() {
-		if err := recover(); err != nil {
-			xl.Error("panic error: %v", err)
-			xl.Error(string(debug.Stack()))
-		}
-	}()
 
-	ctl.allShutdown.WaitStart()
+	go ctl.heartbeatWorker()
+	go ctl.msgDispatcher.Run()
 
+	<-ctl.msgDispatcher.Done()
 	ctl.conn.Close()
-	ctl.readerShutdown.WaitDone()
-
-	close(ctl.readCh)
-	ctl.managerShutdown.WaitDone()
-
-	close(ctl.sendCh)
-	ctl.writerShutdown.WaitDone()
 
 	ctl.mu.Lock()
 	defer ctl.mu.Unlock()
@@ -419,136 +348,104 @@ func (ctl *Control) stoper() {
 		}()
 	}
 
-	ctl.allShutdown.Done()
-	xl.Info("client exit success")
 	metrics.Server.CloseClient()
+	xl.Info("client exit success")
+	close(ctl.doneCh)
 }
 
-// block until Control closed
-func (ctl *Control) WaitClosed() {
-	ctl.mu.RLock()
-	started := ctl.started
-	ctl.mu.RUnlock()
-
-	if !started {
-		ctl.allShutdown.Done()
-		return
-	}
-	ctl.allShutdown.WaitDone()
+func (ctl *Control) registerMsgHandlers() {
+	ctl.msgDispatcher.RegisterHandler(&msg.NewProxy{}, ctl.handleNewProxy)
+	ctl.msgDispatcher.RegisterHandler(&msg.Ping{}, ctl.handlePing)
+	ctl.msgDispatcher.RegisterHandler(&msg.NatHoleVisitor{}, msg.AsyncHandler(ctl.handleNatHoleVisitor))
+	ctl.msgDispatcher.RegisterHandler(&msg.NatHoleClient{}, msg.AsyncHandler(ctl.handleNatHoleClient))
+	ctl.msgDispatcher.RegisterHandler(&msg.NatHoleReport{}, msg.AsyncHandler(ctl.handleNatHoleReport))
+	ctl.msgDispatcher.RegisterHandler(&msg.CloseProxy{}, ctl.handleCloseProxy)
 }
 
-func (ctl *Control) manager() {
+func (ctl *Control) handleNewProxy(m msg.Message) {
 	xl := ctl.xl
-	defer func() {
-		if err := recover(); err != nil {
-			xl.Error("panic error: %v", err)
-			xl.Error(string(debug.Stack()))
-		}
-	}()
+	inMsg := m.(*msg.NewProxy)
 
-	defer ctl.allShutdown.Start()
-	defer ctl.managerShutdown.Done()
+	content := &plugin.NewProxyContent{
+		User: plugin.UserInfo{
+			User:  ctl.loginMsg.User,
+			Metas: ctl.loginMsg.Metas,
+			RunID: ctl.loginMsg.RunID,
+		},
+		NewProxy: *inMsg,
+	}
+	var remoteAddr string
+	retContent, err := ctl.pluginManager.NewProxy(content)
+	if err == nil {
+		inMsg = &retContent.NewProxy
+		remoteAddr, err = ctl.RegisterProxy(inMsg)
+	}
 
-	var heartbeatCh <-chan time.Time
-	// Don't need application heartbeat if TCPMux is enabled,
-	// yamux will do same thing.
-	if !lo.FromPtr(ctl.serverCfg.Transport.TCPMux) && ctl.serverCfg.Transport.HeartbeatTimeout > 0 {
-		heartbeat := time.NewTicker(time.Second)
-		defer heartbeat.Stop()
-		heartbeatCh = heartbeat.C
+	// register proxy in this control
+	resp := &msg.NewProxyResp{
+		ProxyName: inMsg.ProxyName,
+	}
+	if err != nil {
+		xl.Warn("new proxy [%s] type [%s] error: %v", inMsg.ProxyName, inMsg.ProxyType, err)
+		resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", inMsg.ProxyName),
+			err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
+	} else {
+		resp.RemoteAddr = remoteAddr
+		xl.Info("new proxy [%s] type [%s] success", inMsg.ProxyName, inMsg.ProxyType)
+		metrics.Server.NewProxy(inMsg.ProxyName, inMsg.ProxyType)
 	}
+	_ = ctl.msgDispatcher.Send(resp)
+}
 
-	for {
-		select {
-		case <-heartbeatCh:
-			if time.Since(ctl.lastPing) > time.Duration(ctl.serverCfg.Transport.HeartbeatTimeout)*time.Second {
-				xl.Warn("heartbeat timeout")
-				return
-			}
-		case rawMsg, ok := <-ctl.readCh:
-			if !ok {
-				return
-			}
+func (ctl *Control) handlePing(m msg.Message) {
+	xl := ctl.xl
+	inMsg := m.(*msg.Ping)
 
-			switch m := rawMsg.(type) {
-			case *msg.NewProxy:
-				content := &plugin.NewProxyContent{
-					User: plugin.UserInfo{
-						User:  ctl.loginMsg.User,
-						Metas: ctl.loginMsg.Metas,
-						RunID: ctl.loginMsg.RunID,
-					},
-					NewProxy: *m,
-				}
-				var remoteAddr string
-				retContent, err := ctl.pluginManager.NewProxy(content)
-				if err == nil {
-					m = &retContent.NewProxy
-					remoteAddr, err = ctl.RegisterProxy(m)
-				}
-
-				// register proxy in this control
-				resp := &msg.NewProxyResp{
-					ProxyName: m.ProxyName,
-				}
-				if err != nil {
-					xl.Warn("new proxy [%s] type [%s] error: %v", m.ProxyName, m.ProxyType, err)
-					resp.Error = util.GenerateResponseErrorString(fmt.Sprintf("new proxy [%s] error", m.ProxyName),
-						err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient))
-				} else {
-					resp.RemoteAddr = remoteAddr
-					xl.Info("new proxy [%s] type [%s] success", m.ProxyName, m.ProxyType)
-					metrics.Server.NewProxy(m.ProxyName, m.ProxyType)
-				}
-				ctl.sendCh <- resp
-			case *msg.NatHoleVisitor:
-				go ctl.HandleNatHoleVisitor(m)
-			case *msg.NatHoleClient:
-				go ctl.HandleNatHoleClient(m)
-			case *msg.NatHoleReport:
-				go ctl.HandleNatHoleReport(m)
-			case *msg.CloseProxy:
-				_ = ctl.CloseProxy(m)
-				xl.Info("close proxy [%s] success", m.ProxyName)
-			case *msg.Ping:
-				content := &plugin.PingContent{
-					User: plugin.UserInfo{
-						User:  ctl.loginMsg.User,
-						Metas: ctl.loginMsg.Metas,
-						RunID: ctl.loginMsg.RunID,
-					},
-					Ping: *m,
-				}
-				retContent, err := ctl.pluginManager.Ping(content)
-				if err == nil {
-					m = &retContent.Ping
-					err = ctl.authVerifier.VerifyPing(m)
-				}
-				if err != nil {
-					xl.Warn("received invalid ping: %v", err)
-					ctl.sendCh <- &msg.Pong{
-						Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
-					}
-					return
-				}
-				ctl.lastPing = time.Now()
-				xl.Debug("receive heartbeat")
-				ctl.sendCh <- &msg.Pong{}
-			}
-		}
+	content := &plugin.PingContent{
+		User: plugin.UserInfo{
+			User:  ctl.loginMsg.User,
+			Metas: ctl.loginMsg.Metas,
+			RunID: ctl.loginMsg.RunID,
+		},
+		Ping: *inMsg,
+	}
+	retContent, err := ctl.pluginManager.Ping(content)
+	if err == nil {
+		inMsg = &retContent.Ping
+		err = ctl.authVerifier.VerifyPing(inMsg)
 	}
+	if err != nil {
+		xl.Warn("received invalid ping: %v", err)
+		_ = ctl.msgDispatcher.Send(&msg.Pong{
+			Error: util.GenerateResponseErrorString("invalid ping", err, lo.FromPtr(ctl.serverCfg.DetailedErrorsToClient)),
+		})
+		return
+	}
+	ctl.lastPing.Store(time.Now())
+	xl.Debug("receive heartbeat")
+	_ = ctl.msgDispatcher.Send(&msg.Pong{})
 }
 
-func (ctl *Control) HandleNatHoleVisitor(m *msg.NatHoleVisitor) {
-	ctl.rc.NatHoleController.HandleVisitor(m, ctl.msgTransporter, ctl.loginMsg.User)
+func (ctl *Control) handleNatHoleVisitor(m msg.Message) {
+	inMsg := m.(*msg.NatHoleVisitor)
+	ctl.rc.NatHoleController.HandleVisitor(inMsg, ctl.msgTransporter, ctl.loginMsg.User)
 }
 
-func (ctl *Control) HandleNatHoleClient(m *msg.NatHoleClient) {
-	ctl.rc.NatHoleController.HandleClient(m, ctl.msgTransporter)
+func (ctl *Control) handleNatHoleClient(m msg.Message) {
+	inMsg := m.(*msg.NatHoleClient)
+	ctl.rc.NatHoleController.HandleClient(inMsg, ctl.msgTransporter)
 }
 
-func (ctl *Control) HandleNatHoleReport(m *msg.NatHoleReport) {
-	ctl.rc.NatHoleController.HandleReport(m)
+func (ctl *Control) handleNatHoleReport(m msg.Message) {
+	inMsg := m.(*msg.NatHoleReport)
+	ctl.rc.NatHoleController.HandleReport(inMsg)
+}
+
+func (ctl *Control) handleCloseProxy(m msg.Message) {
+	xl := ctl.xl
+	inMsg := m.(*msg.CloseProxy)
+	_ = ctl.CloseProxy(inMsg)
+	xl.Info("close proxy [%s] success", inMsg.ProxyName)
 }
 
 func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) {
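
On the server side, Close and Replaced now both reduce to closing the control connection: worker observes the dispatcher exiting, releases proxies, and finally closes doneCh, which is what WaitClosed blocks on. RegisterControl in server/service.go (below) relies on this handoff; a rough sketch, assuming code inside package server:

// replace swaps in a new control for the same RunID and waits for the
// old one to fully release its resources before starting the new one.
func replace(oldCtl, newCtl *Control) {
	if oldCtl != nil {
		oldCtl.Replaced(newCtl) // closes the old control connection
		oldCtl.WaitClosed()     // returns once worker() has closed doneCh
	}
	newCtl.Start()
}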

+ 12 - 6
server/service.go

@@ -516,13 +516,14 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) {
 	}
 }
 
-func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err error) {
+func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) error {
 	// If client's RunID is empty, it's a new client, we just create a new controller.
 	// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
+	var err error
 	if loginMsg.RunID == "" {
 		loginMsg.RunID, err = util.RandID()
 		if err != nil {
-			return
+			return err
 		}
 	}
 
@@ -534,11 +535,16 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err
 		ctlConn.RemoteAddr().String(), loginMsg.Version, loginMsg.Hostname, loginMsg.Os, loginMsg.Arch)
 
 	// Check auth.
-	if err = svr.authVerifier.VerifyLogin(loginMsg); err != nil {
-		return
+	if err := svr.authVerifier.VerifyLogin(loginMsg); err != nil {
+		return err
 	}
 
-	ctl := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.authVerifier, ctlConn, loginMsg, svr.cfg)
+	ctl, err := NewControl(ctx, svr.rc, svr.pxyManager, svr.pluginManager, svr.authVerifier, ctlConn, loginMsg, svr.cfg)
+	if err != nil {
+		xl.Warn("create new controller error: %v", err)
+		// don't return detailed errors to client
+		return fmt.Errorf("unexpect error when creating new controller")
+	}
 	if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil {
 		oldCtl.WaitClosed()
 	}
@@ -553,7 +559,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login) (err
 		ctl.WaitClosed()
 		svr.ctlManager.Del(loginMsg.RunID, ctl)
 	}()
-	return
+	return nil
 }
 
 // RegisterWorkConn register a new work connection to control and proxies need it.