request.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package request
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net"
  8. "net/http"
  9. "net/url"
  10. "strconv"
  11. "time"
  12. "github.com/fatedier/frp/test/e2e/pkg/rpc"
  13. libnet "github.com/fatedier/golib/net"
  14. )
  15. type Request struct {
  16. protocol string
  17. // for all protocol
  18. addr string
  19. port int
  20. body []byte
  21. timeout time.Duration
  22. // for http
  23. method string
  24. host string
  25. path string
  26. headers map[string]string
  27. proxyURL string
  28. }
  29. func New() *Request {
  30. return &Request{
  31. protocol: "tcp",
  32. addr: "127.0.0.1",
  33. method: "GET",
  34. path: "/",
  35. }
  36. }
  37. func (r *Request) Protocol(protocol string) *Request {
  38. r.protocol = protocol
  39. return r
  40. }
  41. func (r *Request) TCP() *Request {
  42. r.protocol = "tcp"
  43. return r
  44. }
  45. func (r *Request) UDP() *Request {
  46. r.protocol = "udp"
  47. return r
  48. }
  49. func (r *Request) HTTP() *Request {
  50. r.protocol = "http"
  51. return r
  52. }
  53. func (r *Request) Proxy(url string) *Request {
  54. r.proxyURL = url
  55. return r
  56. }
  57. func (r *Request) Addr(addr string) *Request {
  58. r.addr = addr
  59. return r
  60. }
  61. func (r *Request) Port(port int) *Request {
  62. r.port = port
  63. return r
  64. }
  65. func (r *Request) HTTPParams(method, host, path string, headers map[string]string) *Request {
  66. r.method = method
  67. r.host = host
  68. r.path = path
  69. r.headers = headers
  70. return r
  71. }
  72. func (r *Request) HTTPHost(host string) *Request {
  73. r.host = host
  74. return r
  75. }
  76. func (r *Request) HTTPPath(path string) *Request {
  77. r.path = path
  78. return r
  79. }
  80. func (r *Request) HTTPHeaders(headers map[string]string) *Request {
  81. r.headers = headers
  82. return r
  83. }
  84. func (r *Request) Timeout(timeout time.Duration) *Request {
  85. r.timeout = timeout
  86. return r
  87. }
  88. func (r *Request) Body(content []byte) *Request {
  89. r.body = content
  90. return r
  91. }
  92. func (r *Request) Do() (*Response, error) {
  93. var (
  94. conn net.Conn
  95. err error
  96. )
  97. addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port))
  98. // for protocol http
  99. if r.protocol == "http" {
  100. return sendHTTPRequest(r.method, fmt.Sprintf("http://%s%s", addr, r.path),
  101. r.host, r.headers, r.proxyURL, r.body)
  102. }
  103. // for protocol tcp and udp
  104. if len(r.proxyURL) > 0 {
  105. if r.protocol != "tcp" {
  106. return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
  107. }
  108. conn, err = libnet.DialTcpByProxy(r.proxyURL, addr)
  109. if err != nil {
  110. return nil, err
  111. }
  112. } else {
  113. switch r.protocol {
  114. case "tcp":
  115. conn, err = net.Dial("tcp", addr)
  116. case "udp":
  117. conn, err = net.Dial("udp", addr)
  118. default:
  119. return nil, fmt.Errorf("invalid protocol")
  120. }
  121. if err != nil {
  122. return nil, err
  123. }
  124. }
  125. defer conn.Close()
  126. if r.timeout > 0 {
  127. conn.SetDeadline(time.Now().Add(r.timeout))
  128. }
  129. buf, err := sendRequestByConn(conn, r.body)
  130. if err != nil {
  131. return nil, err
  132. }
  133. return &Response{Content: buf}, nil
  134. }
  135. type Response struct {
  136. Code int
  137. Header http.Header
  138. Content []byte
  139. }
  140. func sendHTTPRequest(method, urlstr string, host string, headers map[string]string, proxy string, body []byte) (*Response, error) {
  141. var inBody io.Reader
  142. if len(body) != 0 {
  143. inBody = bytes.NewReader(body)
  144. }
  145. req, err := http.NewRequest(method, urlstr, inBody)
  146. if err != nil {
  147. return nil, err
  148. }
  149. if host != "" {
  150. req.Host = host
  151. }
  152. for k, v := range headers {
  153. req.Header.Set(k, v)
  154. }
  155. tr := &http.Transport{
  156. DialContext: (&net.Dialer{
  157. Timeout: time.Second,
  158. KeepAlive: 30 * time.Second,
  159. DualStack: true,
  160. }).DialContext,
  161. MaxIdleConns: 100,
  162. IdleConnTimeout: 90 * time.Second,
  163. TLSHandshakeTimeout: 10 * time.Second,
  164. ExpectContinueTimeout: 1 * time.Second,
  165. }
  166. if len(proxy) != 0 {
  167. tr.Proxy = func(req *http.Request) (*url.URL, error) {
  168. return url.Parse(proxy)
  169. }
  170. }
  171. client := http.Client{Transport: tr}
  172. resp, err := client.Do(req)
  173. if err != nil {
  174. return nil, err
  175. }
  176. ret := &Response{Code: resp.StatusCode, Header: resp.Header}
  177. buf, err := ioutil.ReadAll(resp.Body)
  178. if err != nil {
  179. return nil, err
  180. }
  181. ret.Content = buf
  182. return ret, nil
  183. }
  184. func sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
  185. _, err := rpc.WriteBytes(c, content)
  186. if err != nil {
  187. return nil, fmt.Errorf("write error: %v", err)
  188. }
  189. buf, err := rpc.ReadBytes(c)
  190. if err != nil {
  191. return nil, fmt.Errorf("read error: %v", err)
  192. }
  193. return buf, nil
  194. }