Browse Source

ssh: return informations to client (#3821)

fatedier 1 year ago
parent
commit
95cf418963
5 changed files with 153 additions and 51 deletions
  1. 15 5
      client/service.go
  2. 53 26
      pkg/config/flags.go
  3. 53 19
      pkg/ssh/server.go
  4. 31 0
      pkg/ssh/terminal.go
  5. 1 1
      test/e2e/pkg/ssh/client.go

+ 15 - 5
client/service.go

@@ -42,6 +42,14 @@ func init() {
 	crypto.DefaultSalt = "frp"
 }
 
+type cancelErr struct {
+	Err error
+}
+
+func (e cancelErr) Error() string {
+	return e.Err.Error()
+}
+
 // ServiceOptions contains options for creating a new client service.
 type ServiceOptions struct {
 	Common      *v1.ClientCommonConfig
@@ -108,7 +116,7 @@ type Service struct {
 	// service context
 	ctx context.Context
 	// call cancel to stop service
-	cancel                   context.CancelFunc
+	cancel                   context.CancelCauseFunc
 	gracefulShutdownDuration time.Duration
 
 	connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
@@ -145,7 +153,7 @@ func NewService(options ServiceOptions) (*Service, error) {
 }
 
 func (svr *Service) Run(ctx context.Context) error {
-	ctx, cancel := context.WithCancel(ctx)
+	ctx, cancel := context.WithCancelCause(ctx)
 	svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx))
 	svr.cancel = cancel
 
@@ -157,7 +165,9 @@ func (svr *Service) Run(ctx context.Context) error {
 	// first login to frps
 	svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.common.LoginFailExit))
 	if svr.ctl == nil {
-		return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled")
+		cancelCause := cancelErr{}
+		_ = errors.As(context.Cause(svr.ctx), &cancelCause)
+		return fmt.Errorf("login to the server failed: %v. With loginFailExit enabled, no additional retries will be attempted", cancelCause.Err)
 	}
 
 	go svr.keepControllerWorking()
@@ -280,7 +290,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
 		if err != nil {
 			xl.Warn("connect to server error: %v", err)
 			if firstLoginExit {
-				svr.cancel()
+				svr.cancel(cancelErr{Err: err})
 			}
 			return err
 		}
@@ -356,7 +366,7 @@ func (svr *Service) Close() {
 
 func (svr *Service) GracefulClose(d time.Duration) {
 	svr.gracefulShutdownDuration = d
-	svr.cancel()
+	svr.cancel(nil)
 }
 
 func (svr *Service) stop() {

+ 53 - 26
pkg/config/flags.go

@@ -25,6 +25,18 @@ import (
 	"github.com/fatedier/frp/pkg/config/v1/validation"
 )
 
+type RegisterFlagOption func(*registerFlagOptions)
+
+type registerFlagOptions struct {
+	sshMode bool
+}
+
+func WithSSHMode() RegisterFlagOption {
+	return func(o *registerFlagOptions) {
+		o.sshMode = true
+	}
+}
+
 type BandwidthQuantityFlag struct {
 	V *types.BandwidthQuantity
 }
@@ -41,8 +53,9 @@ func (f *BandwidthQuantityFlag) Type() string {
 	return "string"
 }
 
-func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer) {
-	registerProxyBaseConfigFlags(cmd, c.GetBaseConfig())
+func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer, opts ...RegisterFlagOption) {
+	registerProxyBaseConfigFlags(cmd, c.GetBaseConfig(), opts...)
+
 	switch cc := c.(type) {
 	case *v1.TCPProxyConfig:
 		cmd.Flags().IntVarP(&cc.RemotePort, "remote_port", "r", 0, "remote port")
@@ -73,17 +86,25 @@ func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer) {
 	}
 }
 
-func registerProxyBaseConfigFlags(cmd *cobra.Command, c *v1.ProxyBaseConfig) {
+func registerProxyBaseConfigFlags(cmd *cobra.Command, c *v1.ProxyBaseConfig, opts ...RegisterFlagOption) {
 	if c == nil {
 		return
 	}
+	options := &registerFlagOptions{}
+	for _, opt := range opts {
+		opt(options)
+	}
+
 	cmd.Flags().StringVarP(&c.Name, "proxy_name", "n", "", "proxy name")
-	cmd.Flags().StringVarP(&c.LocalIP, "local_ip", "i", "127.0.0.1", "local ip")
-	cmd.Flags().IntVarP(&c.LocalPort, "local_port", "l", 0, "local port")
-	cmd.Flags().BoolVarP(&c.Transport.UseEncryption, "ue", "", false, "use encryption")
-	cmd.Flags().BoolVarP(&c.Transport.UseCompression, "uc", "", false, "use compression")
-	cmd.Flags().StringVarP(&c.Transport.BandwidthLimitMode, "bandwidth_limit_mode", "", types.BandwidthLimitModeClient, "bandwidth limit mode")
-	cmd.Flags().VarP(&BandwidthQuantityFlag{V: &c.Transport.BandwidthLimit}, "bandwidth_limit", "", "bandwidth limit (e.g. 100KB or 1MB)")
+
+	if !options.sshMode {
+		cmd.Flags().StringVarP(&c.LocalIP, "local_ip", "i", "127.0.0.1", "local ip")
+		cmd.Flags().IntVarP(&c.LocalPort, "local_port", "l", 0, "local port")
+		cmd.Flags().BoolVarP(&c.Transport.UseEncryption, "ue", "", false, "use encryption")
+		cmd.Flags().BoolVarP(&c.Transport.UseCompression, "uc", "", false, "use compression")
+		cmd.Flags().StringVarP(&c.Transport.BandwidthLimitMode, "bandwidth_limit_mode", "", types.BandwidthLimitModeClient, "bandwidth limit mode")
+		cmd.Flags().VarP(&BandwidthQuantityFlag{V: &c.Transport.BandwidthLimit}, "bandwidth_limit", "", "bandwidth limit (e.g. 100KB or 1MB)")
+	}
 }
 
 func registerProxyDomainConfigFlags(cmd *cobra.Command, c *v1.DomainConfig) {
@@ -94,13 +115,13 @@ func registerProxyDomainConfigFlags(cmd *cobra.Command, c *v1.DomainConfig) {
 	cmd.Flags().StringVarP(&c.SubDomain, "sd", "", "", "sub domain")
 }
 
-func RegisterVisitorFlags(cmd *cobra.Command, c v1.VisitorConfigurer) {
-	registerVisitorBaseConfigFlags(cmd, c.GetBaseConfig())
+func RegisterVisitorFlags(cmd *cobra.Command, c v1.VisitorConfigurer, opts ...RegisterFlagOption) {
+	registerVisitorBaseConfigFlags(cmd, c.GetBaseConfig(), opts...)
 
 	// add visitor flags if exist
 }
 
-func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig) {
+func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig, _ ...RegisterFlagOption) {
 	if c == nil {
 		return
 	}
@@ -113,21 +134,27 @@ func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig)
 	cmd.Flags().IntVarP(&c.BindPort, "bind_port", "", 0, "bind port")
 }
 
-func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfig) {
-	cmd.PersistentFlags().StringVarP(&c.ServerAddr, "server_addr", "s", "127.0.0.1", "frp server's address")
-	cmd.PersistentFlags().IntVarP(&c.ServerPort, "server_port", "P", 7000, "frp server's port")
+func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfig, opts ...RegisterFlagOption) {
+	options := &registerFlagOptions{}
+	for _, opt := range opts {
+		opt(options)
+	}
+
+	if !options.sshMode {
+		cmd.PersistentFlags().StringVarP(&c.ServerAddr, "server_addr", "s", "127.0.0.1", "frp server's address")
+		cmd.PersistentFlags().IntVarP(&c.ServerPort, "server_port", "P", 7000, "frp server's port")
+		cmd.PersistentFlags().StringVarP(&c.Transport.Protocol, "protocol", "p", "tcp",
+			fmt.Sprintf("optional values are %v", validation.SupportedTransportProtocols))
+		cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level")
+		cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "console or file path")
+		cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log file reversed days")
+		cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console")
+		cmd.PersistentFlags().StringVarP(&c.Transport.TLS.ServerName, "tls_server_name", "", "", "specify the custom server name of tls certificate")
+		cmd.PersistentFlags().StringVarP(&c.DNSServer, "dns_server", "", "", "specify dns server instead of using system default one")
+		c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls")
+	}
 	cmd.PersistentFlags().StringVarP(&c.User, "user", "u", "", "user")
-	cmd.PersistentFlags().StringVarP(&c.Transport.Protocol, "protocol", "p", "tcp",
-		fmt.Sprintf("optional values are %v", validation.SupportedTransportProtocols))
 	cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token")
-	cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level")
-	cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "console or file path")
-	cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log file reversed days")
-	cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console")
-	cmd.PersistentFlags().StringVarP(&c.Transport.TLS.ServerName, "tls_server_name", "", "", "specify the custom server name of tls certificate")
-	cmd.PersistentFlags().StringVarP(&c.DNSServer, "dns_server", "", "", "specify dns server instead of using system default one")
-
-	c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls")
 }
 
 type PortsRangeSliceFlag struct {
@@ -185,7 +212,7 @@ func (f *BoolFuncFlag) Type() string {
 	return "bool"
 }
 
-func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) {
+func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig, opts ...RegisterFlagOption) {
 	cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address")
 	cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port")
 	cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port")

+ 53 - 19
pkg/ssh/server.go

@@ -27,6 +27,7 @@ import (
 	libio "github.com/fatedier/golib/io"
 	"github.com/samber/lo"
 	"github.com/spf13/cobra"
+	flag "github.com/spf13/pflag"
 	"golang.org/x/crypto/ssh"
 
 	"github.com/fatedier/frp/client/proxy"
@@ -64,6 +65,7 @@ type TunnelServer struct {
 	underlyingConn net.Conn
 	sshConn        *ssh.ServerConn
 	sc             *ssh.ServerConfig
+	firstChannel   ssh.Channel
 
 	vc                 *virtual.Client
 	peerServerListener *netpkg.InternalListener
@@ -86,6 +88,7 @@ func (s *TunnelServer) Run() error {
 	if err != nil {
 		return err
 	}
+
 	s.sshConn = sshConn
 
 	addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second)
@@ -93,9 +96,14 @@ func (s *TunnelServer) Run() error {
 		return err
 	}
 
-	clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
+	clientCfg, pc, helpMessage, err := s.parseClientAndProxyConfigurer(addr, extraPayload)
 	if err != nil {
-		return err
+		if errors.Is(err, flag.ErrHelp) {
+			s.writeToClient(helpMessage)
+			return nil
+		}
+		s.writeToClient(err.Error())
+		return fmt.Errorf("parse flags from ssh client error: %v", err)
 	}
 	clientCfg.Complete()
 	if sshConn.Permissions != nil {
@@ -142,7 +150,11 @@ func (s *TunnelServer) Run() error {
 	xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
 	ctx := xlog.NewContext(context.Background(), xl)
 	go func() {
-		_ = s.vc.Run(ctx)
+		vcErr := s.vc.Run(ctx)
+		if vcErr != nil {
+			s.writeToClient(vcErr.Error())
+		}
+
 		// If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed.
 		// One scenario is that the virtual client exits due to login failure.
 		s.closeDoneChOnce.Do(func() {
@@ -153,9 +165,12 @@ func (s *TunnelServer) Run() error {
 
 	s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})
 
-	if err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
+	if ps, err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
+		s.writeToClient(err.Error())
 		log.Warn("wait proxy status ready error: %v", err)
 	} else {
+		// success
+		s.writeToClient(createSuccessInfo(clientCfg.User, pc, ps))
 		_ = sshConn.Wait()
 	}
 
@@ -168,6 +183,13 @@ func (s *TunnelServer) Run() error {
 	return nil
 }
 
+func (s *TunnelServer) writeToClient(data string) {
+	if s.firstChannel == nil {
+		return
+	}
+	_, _ = s.firstChannel.Write([]byte(data + "\n"))
+}
+
 func (s *TunnelServer) waitForwardAddrAndExtraPayload(
 	channels <-chan ssh.NewChannel,
 	requests <-chan *ssh.Request,
@@ -225,38 +247,47 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload(
 	return addr, extraPayload, nil
 }
 
-func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) {
-	cmd := &cobra.Command{}
+func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, string, error) {
+	helpMessage := ""
+	cmd := &cobra.Command{
+		Use:   "ssh v0@{address} [command]",
+		Short: "ssh v0@{address} [command]",
+		Run:   func(*cobra.Command, []string) {},
+	}
 	args := strings.Split(extraPayload, " ")
 	if len(args) < 1 {
-		return nil, nil, fmt.Errorf("invalid extra payload")
+		return nil, nil, helpMessage, fmt.Errorf("invalid extra payload")
 	}
 	proxyType := strings.TrimSpace(args[0])
 	supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"}
 	if !lo.Contains(supportTypes, proxyType) {
-		return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
+		return nil, nil, helpMessage, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes)
 	}
 	pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType))
 	if pc == nil {
-		return nil, nil, fmt.Errorf("new proxy configurer error")
+		return nil, nil, helpMessage, fmt.Errorf("new proxy configurer error")
 	}
-	config.RegisterProxyFlags(cmd, pc)
+	config.RegisterProxyFlags(cmd, pc, config.WithSSHMode())
 
 	clientCfg := v1.ClientCommonConfig{}
-	config.RegisterClientCommonConfigFlags(cmd, &clientCfg)
+	config.RegisterClientCommonConfigFlags(cmd, &clientCfg, config.WithSSHMode())
 
+	cmd.InitDefaultHelpCmd()
 	if err := cmd.ParseFlags(args); err != nil {
-		return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err)
+		if errors.Is(err, flag.ErrHelp) {
+			helpMessage = cmd.UsageString()
+		}
+		return nil, nil, helpMessage, err
 	}
 	// if name is not set, generate a random one
 	if pc.GetBaseConfig().Name == "" {
 		id, err := util.RandIDWithLen(8)
 		if err != nil {
-			return nil, nil, fmt.Errorf("generate random id error: %v", err)
+			return nil, nil, helpMessage, fmt.Errorf("generate random id error: %v", err)
 		}
 		pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id)
 	}
-	return &clientCfg, pc, nil
+	return &clientCfg, pc, helpMessage, nil
 }
 
 func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) {
@@ -264,6 +295,9 @@ func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh c
 	if err != nil {
 		return
 	}
+	if s.firstChannel == nil {
+		s.firstChannel = ch
+	}
 	go s.keepAlive(ch)
 
 	for req := range reqs {
@@ -320,7 +354,7 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
 	return conn, nil
 }
 
-func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) error {
+func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) (*proxy.WorkingStatus, error) {
 	ticker := time.NewTicker(100 * time.Millisecond)
 	defer ticker.Stop()
 
@@ -336,14 +370,14 @@ func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration)
 			}
 			switch ps.Phase {
 			case proxy.ProxyPhaseRunning:
-				return nil
+				return ps, nil
 			case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed:
-				return errors.New(ps.Err)
+				return ps, errors.New(ps.Err)
 			}
 		case <-timer.C:
-			return fmt.Errorf("wait proxy status ready timeout")
+			return nil, fmt.Errorf("wait proxy status ready timeout")
 		case <-s.doneCh:
-			return fmt.Errorf("ssh tunnel server closed")
+			return nil, fmt.Errorf("ssh tunnel server closed")
 		}
 	}
 }

+ 31 - 0
pkg/ssh/terminal.go

@@ -0,0 +1,31 @@
+// Copyright 2023 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ssh
+
+import (
+	"github.com/fatedier/frp/client/proxy"
+	v1 "github.com/fatedier/frp/pkg/config/v1"
+)
+
+func createSuccessInfo(user string, pc v1.ProxyConfigurer, ps *proxy.WorkingStatus) string {
+	base := pc.GetBaseConfig()
+	out := "\n"
+	out += "frp (via SSH) (Ctrl+C to quit)\n\n"
+	out += "User: " + user + "\n"
+	out += "ProxyName: " + base.Name + "\n"
+	out += "Type: " + base.Type + "\n"
+	out += "RemoteAddress: " + ps.RemoteAddr + "\n"
+	return out
+}

+ 1 - 1
test/e2e/pkg/ssh/client.go

@@ -41,7 +41,7 @@ func (c *TunnelClient) Start() error {
 		return err
 	}
 	c.ln = l
-	ch, req, err := conn.OpenChannel("direct", []byte(""))
+	ch, req, err := conn.OpenChannel("session", []byte(""))
 	if err != nil {
 		return err
 	}