Browse Source

add real ip test

fatedier 3 years ago
parent
commit
a51e221db3

+ 4 - 4
.circleci/config.yml

@@ -1,6 +1,6 @@
 version: 2
 jobs:
-  go1.16:
+  go-version-latest:
     docker:
       - image: circleci/golang:1.16-node
     working_directory: /go/src/github.com/fatedier/frp
@@ -8,7 +8,7 @@ jobs:
       - checkout
       - run: make
       - run: make alltest
-  go1.15:
+  go-version-last:
     docker:
       - image: circleci/golang:1.15-node
     working_directory: /go/src/github.com/fatedier/frp
@@ -21,5 +21,5 @@ workflows:
   version: 2
   build_and_test:
     jobs:
-      - go1.16
-      - go1.15
+      - go-version-latest
+      - go-version-last

+ 136 - 4
test/e2e/features/real_ip.go

@@ -1,20 +1,152 @@
 package features
 
 import (
+	"bufio"
+	"fmt"
+	"net"
+	"net/http"
+
+	"github.com/fatedier/frp/pkg/util/log"
 	"github.com/fatedier/frp/test/e2e/framework"
+	"github.com/fatedier/frp/test/e2e/framework/consts"
+	"github.com/fatedier/frp/test/e2e/mock/server/httpserver"
+	"github.com/fatedier/frp/test/e2e/mock/server/streamserver"
+	"github.com/fatedier/frp/test/e2e/pkg/request"
+	"github.com/fatedier/frp/test/e2e/pkg/rpc"
 
 	. "github.com/onsi/ginkgo"
+	pp "github.com/pires/go-proxyproto"
 )
 
 var _ = Describe("[Feature: Real IP]", func() {
 	f := framework.NewDefaultFramework()
 
 	It("HTTP X-Forwarded-For", func() {
-		// TODO
-		_ = f
+		vhostHTTPPort := f.AllocPort()
+		serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+		vhost_http_port = %d
+		`, vhostHTTPPort)
+
+		localPort := f.AllocPort()
+		localServer := httpserver.New(
+			httpserver.WithBindPort(localPort),
+			httpserver.WithHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+				w.Write([]byte(req.Header.Get("X-Forwarded-For")))
+			})),
+		)
+		f.RunServer("", localServer)
+
+		clientConf := consts.DefaultClientConfig
+		clientConf += fmt.Sprintf(`
+		[test]
+		type = http
+		local_port = %d
+		custom_domains = normal.example.com
+		`, localPort)
+
+		f.RunProcesses([]string{serverConf}, []string{clientConf})
+
+		framework.NewRequestExpect(f).Port(vhostHTTPPort).
+			RequestModify(func(r *request.Request) {
+				r.HTTP().HTTPHost("normal.example.com")
+			}).
+			ExpectResp([]byte("127.0.0.1")).
+			Ensure()
+
 	})
 
-	It("Proxy Protocol", func() {
-		// TODO
+	Describe("Proxy Protocol", func() {
+		It("TCP", func() {
+			serverConf := consts.DefaultServerConfig
+			clientConf := consts.DefaultClientConfig
+
+			localPort := f.AllocPort()
+			localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(localPort),
+				streamserver.WithCustomHandler(func(c net.Conn) {
+					defer c.Close()
+					rd := bufio.NewReader(c)
+					ppHeader, err := pp.Read(rd)
+					if err != nil {
+						log.Error("read proxy protocol error: %v", err)
+						return
+					}
+
+					for {
+						if _, err := rpc.ReadBytes(rd); err != nil {
+							return
+						}
+
+						buf := []byte(ppHeader.SourceAddr.String())
+						rpc.WriteBytes(c, buf)
+					}
+				}))
+			f.RunServer("", localServer)
+
+			remotePort := f.AllocPort()
+			clientConf += fmt.Sprintf(`
+			[tcp]
+			type = tcp
+			local_port = %d
+			remote_port = %d
+			proxy_protocol_version = v2
+			`, localPort, remotePort)
+
+			f.RunProcesses([]string{serverConf}, []string{clientConf})
+
+			framework.NewRequestExpect(f).Port(remotePort).Ensure(func(resp *request.Response) bool {
+				log.Trace("ProxyProtocol get SourceAddr: %s", string(resp.Content))
+				addr, err := net.ResolveTCPAddr("tcp", string(resp.Content))
+				if err != nil {
+					return false
+				}
+				if addr.IP.String() != "127.0.0.1" {
+					return false
+				}
+				return true
+			})
+		})
+
+		It("HTTP", func() {
+			vhostHTTPPort := f.AllocPort()
+			serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
+		vhost_http_port = %d
+		`, vhostHTTPPort)
+
+			clientConf := consts.DefaultClientConfig
+
+			localPort := f.AllocPort()
+			var srcAddrRecord string
+			localServer := streamserver.New(streamserver.TCP, streamserver.WithBindPort(localPort),
+				streamserver.WithCustomHandler(func(c net.Conn) {
+					defer c.Close()
+					rd := bufio.NewReader(c)
+					ppHeader, err := pp.Read(rd)
+					if err != nil {
+						log.Error("read proxy protocol error: %v", err)
+						return
+					}
+					srcAddrRecord = ppHeader.SourceAddr.String()
+				}))
+			f.RunServer("", localServer)
+
+			clientConf += fmt.Sprintf(`
+			[test]
+			type = http
+			local_port = %d
+			custom_domains = normal.example.com
+			proxy_protocol_version = v2
+			`, localPort)
+
+			f.RunProcesses([]string{serverConf}, []string{clientConf})
+
+			framework.NewRequestExpect(f).Port(vhostHTTPPort).RequestModify(func(r *request.Request) {
+				r.HTTP().HTTPHost("normal.example.com")
+			}).Ensure(framework.ExpectResponseCode(404))
+
+			log.Trace("ProxyProtocol get SourceAddr: %s", srcAddrRecord)
+			addr, err := net.ResolveTCPAddr("tcp", srcAddrRecord)
+			framework.ExpectNoError(err, srcAddrRecord)
+			framework.ExpectEqualValues("127.0.0.1", addr.IP.String())
+		})
 	})
 })

+ 5 - 1
test/e2e/framework/request.go

@@ -17,7 +17,11 @@ func SpecifiedHTTPBodyHandler(body []byte) http.HandlerFunc {
 
 func ExpectResponseCode(code int) EnsureFunc {
 	return func(resp *request.Response) bool {
-		return resp.Code == code
+		if resp.Code == code {
+			return true
+		}
+		flog.Warn("Expect code %d, but got %d", code, resp.Code)
+		return false
 	}
 }
 

+ 18 - 2
test/e2e/mock/server/streamserver/server.go

@@ -1,7 +1,9 @@
 package streamserver
 
 import (
+	"bufio"
 	"fmt"
+	"io"
 	"net"
 
 	libnet "github.com/fatedier/frp/pkg/util/net"
@@ -22,6 +24,8 @@ type Server struct {
 	bindPort    int
 	respContent []byte
 
+	handler func(net.Conn)
+
 	l net.Listener
 }
 
@@ -32,6 +36,7 @@ func New(netType Type, options ...Option) *Server {
 		netType:  netType,
 		bindAddr: "127.0.0.1",
 	}
+	s.handler = s.handle
 
 	for _, option := range options {
 		s = option(s)
@@ -60,6 +65,13 @@ func WithRespContent(content []byte) Option {
 	}
 }
 
+func WithCustomHandler(handler func(net.Conn)) Option {
+	return func(s *Server) *Server {
+		s.handler = handler
+		return s
+	}
+}
+
 func (s *Server) Run() error {
 	if err := s.initListener(); err != nil {
 		return err
@@ -71,7 +83,7 @@ func (s *Server) Run() error {
 			if err != nil {
 				return
 			}
-			go s.handle(c)
+			go s.handler(c)
 		}
 	}()
 	return nil
@@ -101,8 +113,12 @@ func (s *Server) initListener() (err error) {
 func (s *Server) handle(c net.Conn) {
 	defer c.Close()
 
+	var reader io.Reader = c
+	if s.netType == UDP {
+		reader = bufio.NewReader(c)
+	}
 	for {
-		buf, err := rpc.ReadBytes(c)
+		buf, err := rpc.ReadBytes(reader)
 		if err != nil {
 			return
 		}

+ 11 - 5
test/e2e/pkg/request/request.go

@@ -1,6 +1,7 @@
 package request
 
 import (
+	"bufio"
 	"bytes"
 	"fmt"
 	"io"
@@ -120,7 +121,7 @@ func (r *Request) Do() (*Response, error) {
 	addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port))
 	// for protocol http
 	if r.protocol == "http" {
-		return sendHTTPRequest(r.method, fmt.Sprintf("http://%s%s", addr, r.path),
+		return r.sendHTTPRequest(r.method, fmt.Sprintf("http://%s%s", addr, r.path),
 			r.host, r.headers, r.proxyURL, r.body)
 	}
 
@@ -151,7 +152,7 @@ func (r *Request) Do() (*Response, error) {
 	if r.timeout > 0 {
 		conn.SetDeadline(time.Now().Add(r.timeout))
 	}
-	buf, err := sendRequestByConn(conn, r.body)
+	buf, err := r.sendRequestByConn(conn, r.body)
 	if err != nil {
 		return nil, err
 	}
@@ -164,7 +165,7 @@ type Response struct {
 	Content []byte
 }
 
-func sendHTTPRequest(method, urlstr string, host string, headers map[string]string, proxy string, body []byte) (*Response, error) {
+func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers map[string]string, proxy string, body []byte) (*Response, error) {
 	var inBody io.Reader
 	if len(body) != 0 {
 		inBody = bytes.NewReader(body)
@@ -210,13 +211,18 @@ func sendHTTPRequest(method, urlstr string, host string, headers map[string]stri
 	return ret, nil
 }
 
-func sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
+func (r *Request) sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
 	_, err := rpc.WriteBytes(c, content)
 	if err != nil {
 		return nil, fmt.Errorf("write error: %v", err)
 	}
 
-	buf, err := rpc.ReadBytes(c)
+	var reader io.Reader = c
+	if r.protocol == "udp" {
+		reader = bufio.NewReader(c)
+	}
+
+	buf, err := rpc.ReadBytes(reader)
 	if err != nil {
 		return nil, fmt.Errorf("read error: %v", err)
 	}

+ 2 - 6
test/e2e/pkg/rpc/rpc.go

@@ -1,7 +1,6 @@
 package rpc
 
 import (
-	"bufio"
 	"bytes"
 	"encoding/binary"
 	"errors"
@@ -16,15 +15,12 @@ func WriteBytes(w io.Writer, buf []byte) (int, error) {
 }
 
 func ReadBytes(r io.Reader) ([]byte, error) {
-	// To compatible with UDP connection, use bufio reader here to avoid lost conent.
-	rd := bufio.NewReader(r)
-
 	var length int64
-	if err := binary.Read(rd, binary.BigEndian, &length); err != nil {
+	if err := binary.Read(r, binary.BigEndian, &length); err != nil {
 		return nil, err
 	}
 	buffer := make([]byte, length)
-	n, err := io.ReadFull(rd, buffer)
+	n, err := io.ReadFull(r, buffer)
 	if err != nil {
 		return nil, err
 	}