request.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package request
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "fmt"
  7. "io"
  8. "net"
  9. "net/http"
  10. "net/url"
  11. "strconv"
  12. "time"
  13. "github.com/fatedier/frp/test/e2e/pkg/rpc"
  14. libnet "github.com/fatedier/golib/net"
  15. )
  // Request is a chainable builder for a single e2e test request. It is
  // created by New, configured through its setter methods, and executed by Do.
  type Request struct {
  	// protocol selects the transport; Do accepts "tcp", "udp", "http" and "https".
  	protocol string

  	// Used by every protocol.
  	addr    string        // host part of the target address
  	port    int           // joined with addr via net.JoinHostPort
  	body    []byte        // payload written to the connection / HTTP request body
  	timeout time.Duration // values <= 0 mean "no limit"

  	// Used only by "http" and "https".
  	method, host, path string
  	headers            map[string]string
  	tlsConfig          *tls.Config

  	// proxyURL applies to "tcp" (dialed through the proxy) and to
  	// "http"/"https" (used as the transport proxy); Do rejects it for "udp".
  	proxyURL string
  }
  31. func New() *Request {
  32. return &Request{
  33. protocol: "tcp",
  34. addr: "127.0.0.1",
  35. method: "GET",
  36. path: "/",
  37. }
  38. }
  39. func (r *Request) Protocol(protocol string) *Request {
  40. r.protocol = protocol
  41. return r
  42. }
  43. func (r *Request) TCP() *Request {
  44. r.protocol = "tcp"
  45. return r
  46. }
  47. func (r *Request) UDP() *Request {
  48. r.protocol = "udp"
  49. return r
  50. }
  51. func (r *Request) HTTP() *Request {
  52. r.protocol = "http"
  53. return r
  54. }
  55. func (r *Request) HTTPS() *Request {
  56. r.protocol = "https"
  57. return r
  58. }
  59. func (r *Request) Proxy(url string) *Request {
  60. r.proxyURL = url
  61. return r
  62. }
  63. func (r *Request) Addr(addr string) *Request {
  64. r.addr = addr
  65. return r
  66. }
  67. func (r *Request) Port(port int) *Request {
  68. r.port = port
  69. return r
  70. }
  71. func (r *Request) HTTPParams(method, host, path string, headers map[string]string) *Request {
  72. r.method = method
  73. r.host = host
  74. r.path = path
  75. r.headers = headers
  76. return r
  77. }
  78. func (r *Request) HTTPHost(host string) *Request {
  79. r.host = host
  80. return r
  81. }
  82. func (r *Request) HTTPPath(path string) *Request {
  83. r.path = path
  84. return r
  85. }
  86. func (r *Request) HTTPHeaders(headers map[string]string) *Request {
  87. r.headers = headers
  88. return r
  89. }
  90. func (r *Request) TLSConfig(tlsConfig *tls.Config) *Request {
  91. r.tlsConfig = tlsConfig
  92. return r
  93. }
  94. func (r *Request) Timeout(timeout time.Duration) *Request {
  95. r.timeout = timeout
  96. return r
  97. }
  98. func (r *Request) Body(content []byte) *Request {
  99. r.body = content
  100. return r
  101. }
  102. func (r *Request) Do() (*Response, error) {
  103. var (
  104. conn net.Conn
  105. err error
  106. )
  107. addr := net.JoinHostPort(r.addr, strconv.Itoa(r.port))
  108. // for protocol http and https
  109. if r.protocol == "http" || r.protocol == "https" {
  110. return r.sendHTTPRequest(r.method, fmt.Sprintf("%s://%s%s", r.protocol, addr, r.path),
  111. r.host, r.headers, r.proxyURL, r.body, r.tlsConfig)
  112. }
  113. // for protocol tcp and udp
  114. if len(r.proxyURL) > 0 {
  115. if r.protocol != "tcp" {
  116. return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
  117. }
  118. conn, err = libnet.DialTcpByProxy(r.proxyURL, addr)
  119. if err != nil {
  120. return nil, err
  121. }
  122. } else {
  123. switch r.protocol {
  124. case "tcp":
  125. conn, err = net.Dial("tcp", addr)
  126. case "udp":
  127. conn, err = net.Dial("udp", addr)
  128. default:
  129. return nil, fmt.Errorf("invalid protocol")
  130. }
  131. if err != nil {
  132. return nil, err
  133. }
  134. }
  135. defer conn.Close()
  136. if r.timeout > 0 {
  137. conn.SetDeadline(time.Now().Add(r.timeout))
  138. }
  139. buf, err := r.sendRequestByConn(conn, r.body)
  140. if err != nil {
  141. return nil, err
  142. }
  143. return &Response{Content: buf}, nil
  144. }
  145. type Response struct {
  146. Code int
  147. Header http.Header
  148. Content []byte
  149. }
  150. func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers map[string]string,
  151. proxy string, body []byte, tlsConfig *tls.Config,
  152. ) (*Response, error) {
  153. var inBody io.Reader
  154. if len(body) != 0 {
  155. inBody = bytes.NewReader(body)
  156. }
  157. req, err := http.NewRequest(method, urlstr, inBody)
  158. if err != nil {
  159. return nil, err
  160. }
  161. if host != "" {
  162. req.Host = host
  163. }
  164. for k, v := range headers {
  165. req.Header.Set(k, v)
  166. }
  167. tr := &http.Transport{
  168. DialContext: (&net.Dialer{
  169. Timeout: time.Second,
  170. KeepAlive: 30 * time.Second,
  171. DualStack: true,
  172. }).DialContext,
  173. MaxIdleConns: 100,
  174. IdleConnTimeout: 90 * time.Second,
  175. TLSHandshakeTimeout: 10 * time.Second,
  176. ExpectContinueTimeout: 1 * time.Second,
  177. TLSClientConfig: tlsConfig,
  178. }
  179. if len(proxy) != 0 {
  180. tr.Proxy = func(req *http.Request) (*url.URL, error) {
  181. return url.Parse(proxy)
  182. }
  183. }
  184. client := http.Client{Transport: tr}
  185. resp, err := client.Do(req)
  186. if err != nil {
  187. return nil, err
  188. }
  189. ret := &Response{Code: resp.StatusCode, Header: resp.Header}
  190. buf, err := io.ReadAll(resp.Body)
  191. if err != nil {
  192. return nil, err
  193. }
  194. ret.Content = buf
  195. return ret, nil
  196. }
  197. func (r *Request) sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
  198. _, err := rpc.WriteBytes(c, content)
  199. if err != nil {
  200. return nil, fmt.Errorf("write error: %v", err)
  201. }
  202. var reader io.Reader = c
  203. if r.protocol == "udp" {
  204. reader = bufio.NewReader(c)
  205. }
  206. buf, err := rpc.ReadBytes(reader)
  207. if err != nil {
  208. return nil, fmt.Errorf("read error: %v", err)
  209. }
  210. return buf, nil
  211. }