1
0

request.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package request
  2. import (
  3. "fmt"
  4. "net"
  5. "time"
  6. libnet "github.com/fatedier/golib/net"
  7. )
  8. type Request struct {
  9. protocol string
  10. addr string
  11. port int
  12. body []byte
  13. timeout time.Duration
  14. proxyURL string
  15. proxyHost string
  16. }
  17. func New() *Request {
  18. return &Request{
  19. protocol: "tcp",
  20. }
  21. }
  22. func (r *Request) Protocol(protocol string) *Request {
  23. r.protocol = protocol
  24. return r
  25. }
  26. func (r *Request) TCP() *Request {
  27. r.protocol = "tcp"
  28. return r
  29. }
  30. func (r *Request) UDP() *Request {
  31. r.protocol = "udp"
  32. return r
  33. }
  34. func (r *Request) Proxy(url, host string) *Request {
  35. r.proxyURL = url
  36. r.proxyHost = host
  37. return r
  38. }
  39. func (r *Request) Addr(addr string) *Request {
  40. r.addr = addr
  41. return r
  42. }
  43. func (r *Request) Port(port int) *Request {
  44. r.port = port
  45. return r
  46. }
  47. func (r *Request) Timeout(timeout time.Duration) *Request {
  48. r.timeout = timeout
  49. return r
  50. }
  51. func (r *Request) Body(content []byte) *Request {
  52. r.body = content
  53. return r
  54. }
  55. func (r *Request) Do() ([]byte, error) {
  56. var (
  57. conn net.Conn
  58. err error
  59. )
  60. if len(r.proxyURL) > 0 {
  61. if r.protocol != "tcp" {
  62. return nil, fmt.Errorf("only tcp protocol is allowed for proxy")
  63. }
  64. conn, err = libnet.DialTcpByProxy(r.proxyURL, r.proxyHost)
  65. if err != nil {
  66. return nil, err
  67. }
  68. } else {
  69. if r.addr == "" {
  70. r.addr = fmt.Sprintf("127.0.0.1:%d", r.port)
  71. }
  72. switch r.protocol {
  73. case "tcp":
  74. conn, err = net.Dial("tcp", r.addr)
  75. case "udp":
  76. conn, err = net.Dial("udp", r.addr)
  77. default:
  78. return nil, fmt.Errorf("invalid protocol")
  79. }
  80. if err != nil {
  81. return nil, err
  82. }
  83. }
  84. defer conn.Close()
  85. if r.timeout > 0 {
  86. conn.SetDeadline(time.Now().Add(r.timeout))
  87. }
  88. return sendRequestByConn(conn, r.body)
  89. }
  90. func SendTCPRequest(port int, content []byte, timeout time.Duration) ([]byte, error) {
  91. c, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
  92. if err != nil {
  93. return nil, fmt.Errorf("connect to tcp server error: %v", err)
  94. }
  95. defer c.Close()
  96. c.SetDeadline(time.Now().Add(timeout))
  97. return sendRequestByConn(c, content)
  98. }
  99. func SendUDPRequest(port int, content []byte, timeout time.Duration) ([]byte, error) {
  100. c, err := net.Dial("udp", fmt.Sprintf("127.0.0.1:%d", port))
  101. if err != nil {
  102. return nil, fmt.Errorf("connect to udp server error: %v", err)
  103. }
  104. defer c.Close()
  105. c.SetDeadline(time.Now().Add(timeout))
  106. return sendRequestByConn(c, content)
  107. }
  108. func sendRequestByConn(c net.Conn, content []byte) ([]byte, error) {
  109. _, err := c.Write(content)
  110. if err != nil {
  111. return nil, fmt.Errorf("write error: %v", err)
  112. }
  113. buf := make([]byte, 2048)
  114. n, err := c.Read(buf)
  115. if err != nil {
  116. return nil, fmt.Errorf("read error: %v", err)
  117. }
  118. return buf[:n], nil
  119. }