|
@@ -297,6 +297,68 @@ func (svr *Service) Run() {
|
|
|
svr.HandleListener(svr.listener)
|
|
|
}
|
|
|
|
|
|
+func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
|
|
|
+ xl := xlog.FromContextSafe(ctx)
|
|
|
+
|
|
|
+ var (
|
|
|
+ rawMsg msg.Message
|
|
|
+ err error
|
|
|
+ )
|
|
|
+
|
|
|
+ conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
|
|
+ if rawMsg, err = msg.ReadMsg(conn); err != nil {
|
|
|
+ log.Trace("Failed to read message: %v", err)
|
|
|
+ conn.Close()
|
|
|
+ return
|
|
|
+ }
|
|
|
+ conn.SetReadDeadline(time.Time{})
|
|
|
+
|
|
|
+ switch m := rawMsg.(type) {
|
|
|
+ case *msg.Login:
|
|
|
+ // server plugin hook
|
|
|
+ content := &plugin.LoginContent{
|
|
|
+ Login: *m,
|
|
|
+ }
|
|
|
+ retContent, err := svr.pluginManager.Login(content)
|
|
|
+ if err == nil {
|
|
|
+ m = &retContent.Login
|
|
|
+ err = svr.RegisterControl(conn, m)
|
|
|
+ }
|
|
|
+
|
|
|
+ // If login failed, send error message there.
|
|
|
+ // Otherwise send success message in control's work goroutine.
|
|
|
+ if err != nil {
|
|
|
+ xl.Warn("register control error: %v", err)
|
|
|
+ msg.WriteMsg(conn, &msg.LoginResp{
|
|
|
+ Version: version.Full(),
|
|
|
+ Error: util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient),
|
|
|
+ })
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+ case *msg.NewWorkConn:
|
|
|
+ if err := svr.RegisterWorkConn(conn, m); err != nil {
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+ case *msg.NewVisitorConn:
|
|
|
+ if err = svr.RegisterVisitorConn(conn, m); err != nil {
|
|
|
+ xl.Warn("register visitor conn error: %v", err)
|
|
|
+ msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
|
|
+ ProxyName: m.ProxyName,
|
|
|
+ Error: util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient),
|
|
|
+ })
|
|
|
+ conn.Close()
|
|
|
+ } else {
|
|
|
+ msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
|
|
+ ProxyName: m.ProxyName,
|
|
|
+ Error: "",
|
|
|
+ })
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String())
|
|
|
+ conn.Close()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (svr *Service) HandleListener(l net.Listener) {
|
|
|
// Listen for incoming connections from client.
|
|
|
for {
|
|
@@ -307,7 +369,9 @@ func (svr *Service) HandleListener(l net.Listener) {
|
|
|
}
|
|
|
// inject xlog object into net.Conn context
|
|
|
xl := xlog.New()
|
|
|
- c = frpNet.NewContextConn(c, xlog.NewContext(context.Background(), xl))
|
|
|
+ ctx := context.Background()
|
|
|
+
|
|
|
+ c = frpNet.NewContextConn(c, xlog.NewContext(ctx, xl))
|
|
|
|
|
|
log.Trace("start check TLS connection...")
|
|
|
originConn := c
|
|
@@ -320,63 +384,7 @@ func (svr *Service) HandleListener(l net.Listener) {
|
|
|
log.Trace("success check TLS connection")
|
|
|
|
|
|
// Start a new goroutine for dealing connections.
|
|
|
- go func(frpConn net.Conn) {
|
|
|
- dealFn := func(conn net.Conn) {
|
|
|
- var rawMsg msg.Message
|
|
|
- conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
|
|
- if rawMsg, err = msg.ReadMsg(conn); err != nil {
|
|
|
- log.Trace("Failed to read message: %v", err)
|
|
|
- conn.Close()
|
|
|
- return
|
|
|
- }
|
|
|
- conn.SetReadDeadline(time.Time{})
|
|
|
-
|
|
|
- switch m := rawMsg.(type) {
|
|
|
- case *msg.Login:
|
|
|
- // server plugin hook
|
|
|
- content := &plugin.LoginContent{
|
|
|
- Login: *m,
|
|
|
- }
|
|
|
- retContent, err := svr.pluginManager.Login(content)
|
|
|
- if err == nil {
|
|
|
- m = &retContent.Login
|
|
|
- err = svr.RegisterControl(conn, m)
|
|
|
- }
|
|
|
-
|
|
|
- // If login failed, send error message there.
|
|
|
- // Otherwise send success message in control's work goroutine.
|
|
|
- if err != nil {
|
|
|
- xl.Warn("register control error: %v", err)
|
|
|
- msg.WriteMsg(conn, &msg.LoginResp{
|
|
|
- Version: version.Full(),
|
|
|
- Error: util.GenerateResponseErrorString("register control error", err, svr.cfg.DetailedErrorsToClient),
|
|
|
- })
|
|
|
- conn.Close()
|
|
|
- }
|
|
|
- case *msg.NewWorkConn:
|
|
|
- if err := svr.RegisterWorkConn(conn, m); err != nil {
|
|
|
- conn.Close()
|
|
|
- }
|
|
|
- case *msg.NewVisitorConn:
|
|
|
- if err = svr.RegisterVisitorConn(conn, m); err != nil {
|
|
|
- xl.Warn("register visitor conn error: %v", err)
|
|
|
- msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
|
|
- ProxyName: m.ProxyName,
|
|
|
- Error: util.GenerateResponseErrorString("register visitor conn error", err, svr.cfg.DetailedErrorsToClient),
|
|
|
- })
|
|
|
- conn.Close()
|
|
|
- } else {
|
|
|
- msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
|
|
- ProxyName: m.ProxyName,
|
|
|
- Error: "",
|
|
|
- })
|
|
|
- }
|
|
|
- default:
|
|
|
- log.Warn("Error message type for the new connection [%s]", conn.RemoteAddr().String())
|
|
|
- conn.Close()
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
+ go func(ctx context.Context, frpConn net.Conn) {
|
|
|
if svr.cfg.TcpMux {
|
|
|
fmuxCfg := fmux.DefaultConfig()
|
|
|
fmuxCfg.KeepAliveInterval = 20 * time.Second
|
|
@@ -395,12 +403,12 @@ func (svr *Service) HandleListener(l net.Listener) {
|
|
|
session.Close()
|
|
|
return
|
|
|
}
|
|
|
- go dealFn(stream)
|
|
|
+ go svr.handleConnection(ctx, stream)
|
|
|
}
|
|
|
} else {
|
|
|
- dealFn(frpConn)
|
|
|
+ svr.handleConnection(ctx, frpConn)
|
|
|
}
|
|
|
- }(c)
|
|
|
+ }(ctx, c)
|
|
|
}
|
|
|
}
|
|
|
|