소스 검색

add e2e tests for ssh tunnel (#3805)

fatedier 1 년 전
부모
커밋
7c799ee921
4개의 변경된 파일300개의 추가작업 그리고 18개의 파일을 삭제
  1. 1 0
      .gitignore
  2. 17 18
      pkg/ssh/server.go
  3. 89 0
      test/e2e/pkg/ssh/client.go
  4. 193 0
      test/e2e/v1/features/ssh_tunnel.go

+ 1 - 0
.gitignore

@@ -33,6 +33,7 @@ lastversion/
 dist/
 .idea/
 .vscode/
+.autogen_ssh_key
 
 # Cache
 *.swp

+ 17 - 18
pkg/ssh/server.go

@@ -56,8 +56,6 @@ type forwardedTCPPayload struct {
 	Addr string
 	Port uint32
 
-	// can be default empty value but do not delete it
-	// because ssh protocol shoule be reserved
 	OriginAddr string
 	OriginPort uint32
 }
@@ -117,6 +115,8 @@ func (s *TunnelServer) Run() error {
 			// join workConn and ssh channel
 			c, err := s.openConn(addr)
 			if err != nil {
+				log.Trace("open conn error: %v", err)
+				workConn.Close()
 				return false
 			}
 			libio.Join(c, workConn)
@@ -180,20 +180,16 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload(
 	go func() {
 		addrGot := false
 		for req := range requests {
-			switch req.Type {
-			case RequestTypeForward:
-				if !addrGot {
-					payload := tcpipForward{}
-					if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
-						return
-					}
-					addrGot = true
-					addrCh <- &payload
-				}
-			default:
-				if req.WantReply {
-					_ = req.Reply(true, nil)
+			if req.Type == RequestTypeForward && !addrGot {
+				payload := tcpipForward{}
+				if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
+					return
 				}
+				addrGot = true
+				addrCh <- &payload
+			}
+			if req.WantReply {
+				_ = req.Reply(true, nil)
 			}
 		}
 	}()
@@ -271,10 +267,10 @@ func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh c
 	go s.keepAlive(ch)
 
 	for req := range reqs {
-		if req.Type != "exec" {
-			continue
+		if req.WantReply {
+			_ = req.Reply(true, nil)
 		}
-		if len(req.Payload) <= 4 {
+		if req.Type != "exec" || len(req.Payload) <= 4 {
 			continue
 		}
 		end := 4 + binary.BigEndian.Uint32(req.Payload[:4])
@@ -310,6 +306,9 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
 	payload := forwardedTCPPayload{
 		Addr: addr.Host,
 		Port: addr.Port,
+		// Note: Here is just for compatibility, not the real source address.
+		OriginAddr: addr.Host,
+		OriginPort: addr.Port,
 	}
 	channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload))
 	if err != nil {

+ 89 - 0
test/e2e/pkg/ssh/client.go

@@ -0,0 +1,89 @@
+package ssh
+
+import (
+	"net"
+
+	libio "github.com/fatedier/golib/io"
+	"golang.org/x/crypto/ssh"
+)
+
+type TunnelClient struct {
+	localAddr string
+	sshServer string
+	commands  string
+
+	sshConn *ssh.Client
+	ln      net.Listener
+}
+
+func NewTunnelClient(localAddr string, sshServer string, commands string) *TunnelClient {
+	return &TunnelClient{
+		localAddr: localAddr,
+		sshServer: sshServer,
+		commands:  commands,
+	}
+}
+
+func (c *TunnelClient) Start() error {
+	config := &ssh.ClientConfig{
+		User:            "v0",
+		HostKeyCallback: func(string, net.Addr, ssh.PublicKey) error { return nil },
+	}
+
+	conn, err := ssh.Dial("tcp", c.sshServer, config)
+	if err != nil {
+		return err
+	}
+	c.sshConn = conn
+
+	l, err := conn.Listen("tcp", "0.0.0.0:80")
+	if err != nil {
+		return err
+	}
+	c.ln = l
+	ch, req, err := conn.OpenChannel("direct", []byte(""))
+	if err != nil {
+		return err
+	}
+	defer ch.Close()
+	go ssh.DiscardRequests(req)
+
+	type command struct {
+		Cmd string
+	}
+	_, err = ch.SendRequest("exec", false, ssh.Marshal(command{Cmd: c.commands}))
+	if err != nil {
+		return err
+	}
+
+	go c.serveListener()
+	return nil
+}
+
+func (c *TunnelClient) Close() {
+	if c.sshConn != nil {
+		_ = c.sshConn.Close()
+	}
+	if c.ln != nil {
+		_ = c.ln.Close()
+	}
+}
+
+func (c *TunnelClient) serveListener() {
+	for {
+		conn, err := c.ln.Accept()
+		if err != nil {
+			return
+		}
+		go c.hanldeConn(conn)
+	}
+}
+
+func (c *TunnelClient) hanldeConn(conn net.Conn) {
+	defer conn.Close()
+	local, err := net.Dial("tcp", c.localAddr)
+	if err != nil {
+		return
+	}
+	_, _, _ = libio.Join(local, conn)
+}

+ 193 - 0
test/e2e/v1/features/ssh_tunnel.go

@@ -0,0 +1,193 @@
+package features
+
+import (
+	"crypto/tls"
+	"fmt"
+	"time"
+
+	"github.com/onsi/ginkgo/v2"
+
+	"github.com/fatedier/frp/pkg/transport"
+	"github.com/fatedier/frp/test/e2e/framework"
+	"github.com/fatedier/frp/test/e2e/framework/consts"
+	"github.com/fatedier/frp/test/e2e/mock/server/httpserver"
+	"github.com/fatedier/frp/test/e2e/mock/server/streamserver"
+	"github.com/fatedier/frp/test/e2e/pkg/request"
+	"github.com/fatedier/frp/test/e2e/pkg/ssh"
+)
+
+var _ = ginkgo.Describe("[Feature: SSH Tunnel]", func() {
+	f := framework.NewDefaultFramework()
+
+	ginkgo.It("tcp", func() {
+		sshPort := f.AllocPort()
+		serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+		sshTunnelGateway.bindPort = %d
+		`, sshPort)
+
+		f.RunProcesses([]string{serverConf}, nil)
+
+		localPort := f.PortByName(framework.TCPEchoServerPort)
+		remotePort := f.AllocPort()
+		tc := ssh.NewTunnelClient(
+			fmt.Sprintf("127.0.0.1:%d", localPort),
+			fmt.Sprintf("127.0.0.1:%d", sshPort),
+			fmt.Sprintf("tcp --remote_port %d", remotePort),
+		)
+		framework.ExpectNoError(tc.Start())
+		defer tc.Close()
+
+		time.Sleep(time.Second)
+		framework.NewRequestExpect(f).Port(remotePort).Ensure()
+	})
+
+	ginkgo.It("http", func() {
+		sshPort := f.AllocPort()
+		vhostPort := f.AllocPort()
+		serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+		vhostHTTPPort = %d
+		sshTunnelGateway.bindPort = %d
+		`, vhostPort, sshPort)
+
+		f.RunProcesses([]string{serverConf}, nil)
+
+		localPort := f.PortByName(framework.HTTPSimpleServerPort)
+		tc := ssh.NewTunnelClient(
+			fmt.Sprintf("127.0.0.1:%d", localPort),
+			fmt.Sprintf("127.0.0.1:%d", sshPort),
+			"http --custom_domain test.example.com",
+		)
+		framework.ExpectNoError(tc.Start())
+		defer tc.Close()
+
+		time.Sleep(time.Second)
+		framework.NewRequestExpect(f).Port(vhostPort).
+			RequestModify(func(r *request.Request) {
+				r.HTTP().HTTPHost("test.example.com")
+			}).
+			Ensure()
+	})
+
+	ginkgo.It("https", func() {
+		sshPort := f.AllocPort()
+		vhostPort := f.AllocPort()
+		serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+		vhostHTTPSPort = %d
+		sshTunnelGateway.bindPort = %d
+		`, vhostPort, sshPort)
+
+		f.RunProcesses([]string{serverConf}, nil)
+
+		localPort := f.AllocPort()
+		testDomain := "test.example.com"
+		tc := ssh.NewTunnelClient(
+			fmt.Sprintf("127.0.0.1:%d", localPort),
+			fmt.Sprintf("127.0.0.1:%d", sshPort),
+			fmt.Sprintf("https --custom_domain %s", testDomain),
+		)
+		framework.ExpectNoError(tc.Start())
+		defer tc.Close()
+
+		tlsConfig, err := transport.NewServerTLSConfig("", "", "")
+		framework.ExpectNoError(err)
+		localServer := httpserver.New(
+			httpserver.WithBindPort(localPort),
+			httpserver.WithTLSConfig(tlsConfig),
+			httpserver.WithResponse([]byte("test")),
+		)
+		f.RunServer("", localServer)
+
+		time.Sleep(time.Second)
+		framework.NewRequestExpect(f).
+			Port(vhostPort).
+			RequestModify(func(r *request.Request) {
+				r.HTTPS().HTTPHost(testDomain).TLSConfig(&tls.Config{
+					ServerName:         testDomain,
+					InsecureSkipVerify: true,
+				})
+			}).
+			ExpectResp([]byte("test")).
+			Ensure()
+	})
+
+	ginkgo.It("tcpmux", func() {
+		sshPort := f.AllocPort()
+		tcpmuxPort := f.AllocPort()
+		serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+		tcpmuxHTTPConnectPort = %d
+		sshTunnelGateway.bindPort = %d
+		`, tcpmuxPort, sshPort)
+
+		f.RunProcesses([]string{serverConf}, nil)
+
+		localPort := f.AllocPort()
+		testDomain := "test.example.com"
+		tc := ssh.NewTunnelClient(
+			fmt.Sprintf("127.0.0.1:%d", localPort),
+			fmt.Sprintf("127.0.0.1:%d", sshPort),
+			fmt.Sprintf("tcpmux --mux=httpconnect --custom_domain %s", testDomain),
+		)
+		framework.ExpectNoError(tc.Start())
+		defer tc.Close()
+
+		localServer := streamserver.New(
+			streamserver.TCP,
+			streamserver.WithBindPort(localPort),
+			streamserver.WithRespContent([]byte("test")),
+		)
+		f.RunServer("", localServer)
+
+		time.Sleep(time.Second)
+		// Request without HTTP connect should get error
+		framework.NewRequestExpect(f).
+			Port(tcpmuxPort).
+			ExpectError(true).
+			Explain("request without HTTP connect expect error").
+			Ensure()
+
+		proxyURL := fmt.Sprintf("http://127.0.0.1:%d", tcpmuxPort)
+		// Request with incorrect connect hostname
+		framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
+			r.Addr("invalid").Proxy(proxyURL)
+		}).ExpectError(true).Explain("request without HTTP connect expect error").Ensure()
+
+		// Request with correct connect hostname
+		framework.NewRequestExpect(f).RequestModify(func(r *request.Request) {
+			r.Addr(testDomain).Proxy(proxyURL)
+		}).ExpectResp([]byte("test")).Ensure()
+	})
+
+	ginkgo.It("stcp", func() {
+		sshPort := f.AllocPort()
+		serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+		sshTunnelGateway.bindPort = %d
+		`, sshPort)
+
+		bindPort := f.AllocPort()
+		visitorConf := consts.DefaultClientConfig + fmt.Sprintf(`
+        [[visitors]]
+		name = "stcp-test-visitor"
+		type = "stcp"
+		serverName = "stcp-test"
+		secretKey = "abcdefg"
+		bindPort = %d
+		`, bindPort)
+
+		f.RunProcesses([]string{serverConf}, []string{visitorConf})
+
+		localPort := f.PortByName(framework.TCPEchoServerPort)
+		tc := ssh.NewTunnelClient(
+			fmt.Sprintf("127.0.0.1:%d", localPort),
+			fmt.Sprintf("127.0.0.1:%d", sshPort),
+			"stcp -n stcp-test --sk=abcdefg --allow_users=\"*\"",
+		)
+		framework.ExpectNoError(tc.Start())
+		defer tc.Close()
+
+		time.Sleep(time.Second)
+
+		framework.NewRequestExpect(f).
+			Port(bindPort).
+			Ensure()
+	})
+})