request.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package request
  2. import (
  3. "bufio"
  4. "bytes"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strconv"
  12. "time"
  13. "github.com/fatedier/frp/test/e2e/pkg/rpc"
  14. libnet "github.com/fatedier/golib/net"
  15. )
  // Request is a fluent builder describing a single test request. Build one
// with New, adjust it with the chainable setters below (each one mutates the
// receiver in place and returns it), then execute it with Do.
type Request struct {
	// protocol names the transport: "tcp", "udp" or "http".
	protocol string

	// Settings shared by every protocol.
	addr    string
	port    int
	body    []byte
	timeout time.Duration

	// Settings consulted on the HTTP path.
	method  string
	host    string
	path    string
	headers map[string]string

	// proxyURL is honoured by Do for http and tcp requests.
	proxyURL string
}

// New returns a Request aimed at 127.0.0.1 over TCP, with HTTP method GET and
// path "/" preset. The port starts at 0 and must normally be set via Port.
func New() *Request {
	r := new(Request)
	r.protocol = "tcp"
	r.addr = "127.0.0.1"
	r.method = "GET"
	r.path = "/"
	return r
}

// Protocol selects the transport by name ("tcp", "udp" or "http").
func (r *Request) Protocol(protocol string) *Request {
	r.protocol = protocol
	return r
}

// TCP is shorthand for Protocol("tcp").
func (r *Request) TCP() *Request {
	return r.Protocol("tcp")
}

// UDP is shorthand for Protocol("udp").
func (r *Request) UDP() *Request {
	return r.Protocol("udp")
}

// HTTP is shorthand for Protocol("http").
func (r *Request) HTTP() *Request {
	return r.Protocol("http")
}

// Proxy records the proxy URL the request should be routed through.
func (r *Request) Proxy(url string) *Request {
	r.proxyURL = url
	return r
}

// Addr sets the target host address.
func (r *Request) Addr(addr string) *Request {
	r.addr = addr
	return r
}

// Port sets the target port.
func (r *Request) Port(port int) *Request {
	r.port = port
	return r
}

// HTTPParams sets the HTTP method, Host override, path and headers in one call.
func (r *Request) HTTPParams(method, host, path string, headers map[string]string) *Request {
	r.method = method
	return r.HTTPHost(host).HTTPPath(path).HTTPHeaders(headers)
}

// HTTPHost sets the value used to override the HTTP Host header.
func (r *Request) HTTPHost(host string) *Request {
	r.host = host
	return r
}

// HTTPPath sets the request path appended to addr:port.
func (r *Request) HTTPPath(path string) *Request {
	r.path = path
	return r
}

// HTTPHeaders sets the extra HTTP headers. The map is stored as-is, not copied.
func (r *Request) HTTPHeaders(headers map[string]string) *Request {
	r.headers = headers
	return r
}

// Timeout stores a time limit for the request; zero (the default from New)
// means none is set.
func (r *Request) Timeout(timeout time.Duration) *Request {
	r.timeout = timeout
	return r
}

// Body sets the payload to send. The slice is stored as-is, not copied.
func (r *Request) Body(content []byte) *Request {
	r.body = content
	return r
}
  93. func (r *Request) Do() (*Response, error) {
  94. var (
  95. conn net.Conn
  96. err error
  97. )
  98. addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port))
  99. // for protocol http
  100. if r.protocol == "http" {
  101. return r.sendHTTPRequest(r.method, fmt.Sprintf("http://%s%s", addr, r.path),
  102. r.host, r.headers, r.proxyURL, r.body)
  103. }
  104. // for protocol tcp and udp
  105. if len(r.proxyURL) > 0 {
  106. if r.protocol != "tcp" {
  107. return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
  108. }
  109. conn, err = libnet.DialTcpByProxy(r.proxyURL, addr)
  110. if err != nil {
  111. return nil, err
  112. }
  113. } else {
  114. switch r.protocol {
  115. case "tcp":
  116. conn, err = net.Dial("tcp", addr)
  117. case "udp":
  118. conn, err = net.Dial("udp", addr)
  119. default:
  120. return nil, fmt.Errorf("invalid protocol")
  121. }
  122. if err != nil {
  123. return nil, err
  124. }
  125. }
  126. defer conn.Close()
  127. if r.timeout > 0 {
  128. conn.SetDeadline(time.Now().Add(r.timeout))
  129. }
  130. buf, err := r.sendRequestByConn(conn, r.body)
  131. if err != nil {
  132. return nil, err
  133. }
  134. return &Response{Content: buf}, nil
  135. }
  // Response is the result returned by Request.Do.
//
// For tcp and udp requests only Content is populated; for http requests Code
// and Header are filled in from the HTTP reply as well.
type Response struct {
	Code    int         // HTTP status code; zero for tcp/udp.
	Header  http.Header // HTTP response headers; nil for tcp/udp.
	Content []byte      // Reply payload: the HTTP body, or the bytes read back from the connection.
}
  141. func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers map[string]string, proxy string, body []byte) (*Response, error) {
  142. var inBody io.Reader
  143. if len(body) != 0 {
  144. inBody = bytes.NewReader(body)
  145. }
  146. req, err := http.NewRequest(method, urlstr, inBody)
  147. if err != nil {
  148. return nil, err
  149. }
  150. if host != "" {
  151. req.Host = host
  152. }
  153. for k, v := range headers {
  154. req.Header.Set(k, v)
  155. }
  156. tr := &http.Transport{
  157. DialContext: (&net.Dialer{
  158. Timeout: time.Second,
  159. KeepAlive: 30 * time.Second,
  160. DualStack: true,
  161. }).DialContext,
  162. MaxIdleConns: 100,
  163. IdleConnTimeout: 90 * time.Second,
  164. TLSHandshakeTimeout: 10 * time.Second,
  165. ExpectContinueTimeout: 1 * time.Second,
  166. }
  167. if len(proxy) != 0 {
  168. tr.Proxy = func(req *http.Request) (*url.URL, error) {
  169. return url.Parse(proxy)
  170. }
  171. }
  172. client := http.Client{Transport: tr}
  173. resp, err := client.Do(req)
  174. if err != nil {
  175. return nil, err
  176. }
  177. ret := &Response{Code: resp.StatusCode, Header: resp.Header}
  178. buf, err := ioutil.ReadAll(resp.Body)
  179. if err != nil {
  180. return nil, err
  181. }
  182. ret.Content = buf
  183. return ret, nil
  184. }
  185. func (r *Request) sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
  186. _, err := rpc.WriteBytes(c, content)
  187. if err != nil {
  188. return nil, fmt.Errorf("write error: %v", err)
  189. }
  190. var reader io.Reader = c
  191. if r.protocol == "udp" {
  192. reader = bufio.NewReader(c)
  193. }
  194. buf, err := rpc.ReadBytes(reader)
  195. if err != nil {
  196. return nil, fmt.Errorf("read error: %v", err)
  197. }
  198. return buf, nil
  199. }