Browse Source

add ci for setting headers

fatedier 6 years ago
parent
commit
2e497274ba
6 changed files with 41 additions and 18 deletions
  1. 1 1
      README.md
  2. 1 1
      README_zh.md
  3. 8 0
      tests/conf/auto_test_frpc.ini
  4. 23 13
      tests/func_test.go
  5. 5 1
      tests/http_server.go
  6. 3 2
      tests/util.go

+ 1 - 1
README.md

@@ -616,7 +616,7 @@ local_port = 6000-6006,6007
 remote_port = 6000-6006,6007
 ```
 
-frpc will generate 6 proxies like `test_tcp_0, test_tcp_1 ... test_tcp_5`.
+frpc will generate 8 proxies like `test_tcp_0, test_tcp_1 ... test_tcp_7`.
 
 ### Plugin
 

+ 1 - 1
README_zh.md

@@ -654,7 +654,7 @@ local_port = 6000-6006,6007
 remote_port = 6000-6006,6007
 ```
 
-实际连接成功后会创建 6 个 proxy,命名为 `test_tcp_0, test_tcp_1 ... test_tcp_5`。
+实际连接成功后会创建 8 个 proxy,命名为 `test_tcp_0, test_tcp_1 ... test_tcp_7`。
 
 ### 插件
 

+ 8 - 0
tests/conf/auto_test_frpc.ini

@@ -103,6 +103,14 @@ use_compression = true
 http_user = test
 http_user = test
 
+[web06]
+type = http
+local_ip = 127.0.0.1
+local_port = 10704
+custom_domains = test6.frp.com
+host_header_rewrite = test6.frp.com
+header_X-From-Where = frp
+
 [subhost01]
 type = http
 local_ip = 127.0.0.1

+ 23 - 13
tests/func_test.go

@@ -2,6 +2,7 @@ package tests
 
 import (
 	"fmt"
+	"net/http"
 	"net/url"
 	"strings"
 	"testing"
@@ -127,67 +128,76 @@ func TestStcp(t *testing.T) {
 func TestHttp(t *testing.T) {
 	assert := assert.New(t)
 	// web01
-	code, body, err := sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "", nil, "")
+	code, body, _, err := sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(200, code)
 		assert.Equal(TEST_HTTP_NORMAL_STR, body)
 	}
 
 	// web02
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test2.frp.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test2.frp.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(200, code)
 		assert.Equal(TEST_HTTP_NORMAL_STR, body)
 	}
 
 	// error host header
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "errorhost.frp.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "errorhost.frp.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(404, code)
 	}
 
 	// web03
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test3.frp.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test3.frp.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(200, code)
 		assert.Equal(TEST_HTTP_NORMAL_STR, body)
 	}
 
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/foo", TEST_HTTP_FRP_PORT), "test3.frp.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/foo", TEST_HTTP_FRP_PORT), "test3.frp.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(200, code)
 		assert.Equal(TEST_HTTP_FOO_STR, body)
 	}
 
 	// web04
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/bar", TEST_HTTP_FRP_PORT), "test3.frp.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d/bar", TEST_HTTP_FRP_PORT), "test3.frp.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(200, code)
 		assert.Equal(TEST_HTTP_BAR_STR, body)
 	}
 
 	// web05
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(401, code)
 	}
 
-	header := make(map[string]string)
-	header["Authorization"] = basicAuth("test", "test")
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", header, "")
+	headers := make(map[string]string)
+	headers["Authorization"] = basicAuth("test", "test")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test5.frp.com", headers, "")
 	if assert.NoError(err) {
 		assert.Equal(401, code)
 	}
 
+	// web06
+	var header http.Header
+	code, body, header, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test6.frp.com", nil, "")
+	if assert.NoError(err) {
+		assert.Equal(200, code)
+		assert.Equal(TEST_HTTP_NORMAL_STR, body)
+		assert.Equal("true", header.Get("X-Header-Set"))
+	}
+
 	// subhost01
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test01.sub.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test01.sub.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(200, code)
 		assert.Equal("test01.sub.com", body)
 	}
 
 	// subhost02
-	code, body, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test02.sub.com", nil, "")
+	code, body, _, err = sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT), "test02.sub.com", nil, "")
 	if assert.NoError(err) {
 		assert.Equal(200, code)
 		assert.Equal("test02.sub.com", body)
@@ -272,7 +282,7 @@ func TestPluginHttpProxy(t *testing.T) {
 
 		// http proxy
 		addr := status.RemoteAddr
-		code, body, err := sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT),
+		code, body, _, err := sendHttpMsg("GET", fmt.Sprintf("http://127.0.0.1:%d", TEST_HTTP_FRP_PORT),
 			"", nil, "http://"+addr)
 		if assert.NoError(err) {
 			assert.Equal(200, code)

+ 5 - 1
tests/http_server.go

@@ -39,6 +39,10 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) {
 }
 
 func handleHttp(w http.ResponseWriter, r *http.Request) {
+	if r.Header.Get("X-From-Where") == "frp" {
+		w.Header().Set("X-Header-Set", "true")
+	}
+
 	match, err := regexp.Match(`.*\.sub\.com`, []byte(r.Host))
 	if err != nil {
 		w.WriteHeader(500)
@@ -52,7 +56,7 @@ func handleHttp(w http.ResponseWriter, r *http.Request) {
 	}
 
 	if strings.Contains(r.Host, "127.0.0.1") || strings.Contains(r.Host, "test2.frp.com") ||
-		strings.Contains(r.Host, "test5.frp.com") {
+		strings.Contains(r.Host, "test5.frp.com") || strings.Contains(r.Host, "test6.frp.com") {
 		w.WriteHeader(200)
 		w.Write([]byte(TEST_HTTP_NORMAL_STR))
 	} else if strings.Contains(r.Host, "test3.frp.com") {

+ 3 - 2
tests/util.go

@@ -126,7 +126,7 @@ func sendUdpMsg(addr string, msg string) (res string, err error) {
 	return string(buf[:n]), nil
 }
 
-func sendHttpMsg(method, urlStr string, host string, header map[string]string, proxy string) (code int, body string, err error) {
+func sendHttpMsg(method, urlStr string, host string, headers map[string]string, proxy string) (code int, body string, header http.Header, err error) {
 	req, errRet := http.NewRequest(method, urlStr, nil)
 	if errRet != nil {
 		err = errRet
@@ -136,7 +136,7 @@ func sendHttpMsg(method, urlStr string, host string, header map[string]string, p
 	if host != "" {
 		req.Host = host
 	}
-	for k, v := range header {
+	for k, v := range headers {
 		req.Header.Set(k, v)
 	}
 
@@ -167,6 +167,7 @@ func sendHttpMsg(method, urlStr string, host string, header map[string]string, p
 		return
 	}
 	code = resp.StatusCode
+	header = resp.Header
 	buf, errRet := ioutil.ReadAll(resp.Body)
 	if errRet != nil {
 		err = errRet