Przeglądaj źródła

fix: duplicate call loginFunc (#3860) (#3875)

modify ext func, specify whether exit immediately
im_zhou 1 rok temu
rodzic
commit
3bf6605e1a
3 zmienionych plików z 17 dodań i 19 usunięć
  1. 3 3
      client/control.go
  2. 8 12
      client/service.go
  3. 6 4
      pkg/util/wait/backoff.go

+ 3 - 3
client/control.go

@@ -239,15 +239,15 @@ func (ctl *Control) heartbeatWorker() {
 	// Users can still enable heartbeat feature by setting HeartbeatInterval to a positive value.
 	if ctl.sessionCtx.Common.Transport.HeartbeatInterval > 0 {
 		// send heartbeat to server
-		sendHeartBeat := func() error {
+		sendHeartBeat := func() (bool, error) {
 			xl.Debug("send heartbeat to server")
 			pingMsg := &msg.Ping{}
 			if err := ctl.sessionCtx.AuthSetter.SetPing(pingMsg); err != nil {
 				xl.Warn("error during ping authentication: %v, skip sending ping message", err)
-				return err
+				return false, err
 			}
 			_ = ctl.msgDispatcher.Send(pingMsg)
-			return nil
+			return false, nil
 		}
 
 		go wait.BackoffUntil(sendHeartBeat,

+ 8 - 12
client/service.go

@@ -192,16 +192,16 @@ func (svr *Service) keepControllerWorking() {
 	// the control immediately exits. It is necessary to limit the frequency of reconnection in this case.
 	// The interval for the first three retries in 1 minute will be very short, and then it will increase exponentially.
 	// The maximum interval is 20 seconds.
-	wait.BackoffUntil(func() error {
+	wait.BackoffUntil(func() (bool, error) {
 		// loopLoginUntilSuccess is another layer of loop that will continuously attempt to
 		// login to the server until successful.
 		svr.loopLoginUntilSuccess(20*time.Second, false)
 		if svr.ctl != nil {
 			<-svr.ctl.Done()
-			return errors.New("control is closed and try another loop")
+			return false, errors.New("control is closed and try another loop")
 		}
 		// If the control is nil, it means that the login failed and the service is also closed.
-		return nil
+		return false, nil
 	}, wait.NewFastBackoffManager(
 		wait.FastBackoffOptions{
 			Duration:        time.Second,
@@ -282,9 +282,8 @@ func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
 
 func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
 	xl := xlog.FromContextSafe(svr.ctx)
-	successCh := make(chan struct{})
 
-	loginFunc := func() error {
+	loginFunc := func() (bool, error) {
 		xl.Info("try to connect to server...")
 		conn, connector, err := svr.login()
 		if err != nil {
@@ -292,7 +291,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
 			if firstLoginExit {
 				svr.cancel(cancelErr{Err: err})
 			}
-			return err
+			return false, err
 		}
 
 		svr.cfgMu.RLock()
@@ -315,7 +314,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
 		if err != nil {
 			conn.Close()
 			xl.Error("NewControl error: %v", err)
-			return err
+			return false, err
 		}
 		ctl.SetInWorkConnCallback(svr.handleWorkConnCb)
 
@@ -328,8 +327,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
 		svr.ctl = ctl
 		svr.ctlMu.Unlock()
 
-		close(successCh)
-		return nil
+		return true, nil
 	}
 
 	// try to reconnect to server until success
@@ -339,9 +337,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
 			Factor:      2,
 			Jitter:      0.1,
 			MaxDuration: maxInterval,
-		}),
-		true,
-		wait.MergeAndCloseOnAnyStopChannel(svr.ctx.Done(), successCh))
+		}), true, svr.ctx.Done())
 }
 
 func (svr *Service) UpdateAllConfigurer(proxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer) error {

+ 6 - 4
pkg/util/wait/backoff.go

@@ -113,7 +113,7 @@ func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousCondit
 	return f.options.Duration
 }
 
-func BackoffUntil(f func() error, backoff BackoffManager, sliding bool, stopCh <-chan struct{}) {
+func BackoffUntil(f func() (bool, error), backoff BackoffManager, sliding bool, stopCh <-chan struct{}) {
 	var delay time.Duration
 	previousError := false
 
@@ -131,7 +131,9 @@ func BackoffUntil(f func() error, backoff BackoffManager, sliding bool, stopCh <
 			delay = backoff.Backoff(delay, previousError)
 		}
 
-		if err := f(); err != nil {
+		if done, err := f(); done {
+			return
+		} else if err != nil {
 			previousError = true
 		} else {
 			previousError = false
@@ -170,9 +172,9 @@ func Jitter(duration time.Duration, maxFactor float64) time.Duration {
 }
 
 func Until(f func(), period time.Duration, stopCh <-chan struct{}) {
-	ff := func() error {
+	ff := func() (bool, error) {
 		f()
-		return nil
+		return false, nil
 	}
 	BackoffUntil(ff, BackoffFunc(func(time.Duration, bool) time.Duration {
 		return period