Browse Source

fix: frps handle multi conn may happen data race (#1768)

Tank 4 years ago
parent
commit
7728e35c52
1 changed files with 69 additions and 61 deletions
  1. 69 61
      server/service.go

+ 69 - 61
server/service.go

@@ -297,6 +297,68 @@ func (svr *Service) Run() {
 	svr.HandleListener(svr.listener)
 }
 
+func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
+	xl := xlog.FromContextSafe(ctx)
+
+	var (
+		rawMsg msg.Message
+		err    error
+	)
+
+	conn.SetReadDeadline(time.Now().Add(connReadTimeout))
+	if rawMsg, err = msg.ReadMsg(conn); err != nil {
+		log.Trace("Failed to read message: %v", err)
+		conn.Close()
+		return
+	}
+	conn.SetReadDeadline(time.Time{})
+
+	switch m := rawMsg.(type) {
+	case *msg.Login:
+		// server plugin hook
+		content := &plugin.LoginContent{
+			Login: *m,
+		}
+		retContent, err := svr.pluginManager.Login(content)
+		if err == nil {
+			m = &retContent.Login
+			err = svr.RegisterControl(conn, m)
+		}
+
+		// If login failed, send error message there.
+		// Otherwise send success message in control's work goroutine.
+		if err != nil {
+			xl.Warn("register control error: %v", err)
+			msg.WriteMsg(conn, &msg.LoginResp{
+				Version: version.Full(),
+				Error:   util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient),
+			})
+			conn.Close()
+		}
+	case *msg.NewWorkConn:
+		if err := svr.RegisterWorkConn(conn, m); err != nil {
+			conn.Close()
+		}
+	case *msg.NewVisitorConn:
+		if err = svr.RegisterVisitorConn(conn, m); err != nil {
+			xl.Warn("register visitor conn error: %v", err)
+			msg.WriteMsg(conn, &msg.NewVisitorConnResp{
+				ProxyName: m.ProxyName,
+				Error:     util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient),
+			})
+			conn.Close()
+		} else {
+			msg.WriteMsg(conn, &msg.NewVisitorConnResp{
+				ProxyName: m.ProxyName,
+				Error:     "",
+			})
+		}
+	default:
+		log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String())
+		conn.Close()
+	}
+}
+
 func (svr *Service) HandleListener(l net.Listener) {
 	// Listen for incoming connections from client.
 	for {
@@ -307,7 +369,9 @@ func (svr *Service) HandleListener(l net.Listener) {
 		}
 		// inject xlog object into net.Conn context
 		xl := xlog.New()
-		c = frpNet.NewContextConn(c, xlog.NewContext(context.Background(), xl))
+		ctx := context.Background()
+
+		c = frpNet.NewContextConn(c, xlog.NewContext(ctx, xl))
 
 		log.Trace("start check TLS connection...")
 		originConn := c
@@ -320,63 +384,7 @@ func (svr *Service) HandleListener(l net.Listener) {
 		log.Trace("success check TLS connection")
 
 		// Start a new goroutine for dealing connections.
-		go func(frpConn net.Conn) {
-			dealFn := func(conn net.Conn) {
-				var rawMsg msg.Message
-				conn.SetReadDeadline(time.Now().Add(connReadTimeout))
-				if rawMsg, err = msg.ReadMsg(conn); err != nil {
-					log.Trace("Failed to read message: %v", err)
-					conn.Close()
-					return
-				}
-				conn.SetReadDeadline(time.Time{})
-
-				switch m := rawMsg.(type) {
-				case *msg.Login:
-					// server plugin hook
-					content := &plugin.LoginContent{
-						Login: *m,
-					}
-					retContent, err := svr.pluginManager.Login(content)
-					if err == nil {
-						m = &retContent.Login
-						err = svr.RegisterControl(conn, m)
-					}
-
-					// If login failed, send error message there.
-					// Otherwise send success message in control's work goroutine.
-					if err != nil {
-						xl.Warn("register control error: %v", err)
-						msg.WriteMsg(conn, &msg.LoginResp{
-							Version: version.Full(),
-							Error:   util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient),
-						})
-						conn.Close()
-					}
-				case *msg.NewWorkConn:
-					if err := svr.RegisterWorkConn(conn, m); err != nil {
-						conn.Close()
-					}
-				case *msg.NewVisitorConn:
-					if err = svr.RegisterVisitorConn(conn, m); err != nil {
-						xl.Warn("register visitor conn error: %v", err)
-						msg.WriteMsg(conn, &msg.NewVisitorConnResp{
-							ProxyName: m.ProxyName,
-							Error:     util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient),
-						})
-						conn.Close()
-					} else {
-						msg.WriteMsg(conn, &msg.NewVisitorConnResp{
-							ProxyName: m.ProxyName,
-							Error:     "",
-						})
-					}
-				default:
-					log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String())
-					conn.Close()
-				}
-			}
-
+		go func(ctx context.Context, frpConn net.Conn) {
 			if svr.cfg.TcpMux {
 				fmuxCfg := fmux.DefaultConfig()
 				fmuxCfg.KeepAliveInterval = 20 * time.Second
@@ -395,12 +403,12 @@ func (svr *Service) HandleListener(l net.Listener) {
 						session.Close()
 						return
 					}
-					go dealFn(stream)
+					go svr.handleConnection(ctx, stream)
 				}
 			} else {
-				dealFn(frpConn)
+				svr.handleConnection(ctx, frpConn)
 			}
-		}(c)
+		}(ctx, c)
 	}
 }