conn.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package conn
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "net"
  7. "sync"
  8. "github.com/fatedier/frp/utils/log"
  9. )
  10. type Listener struct {
  11. addr net.Addr
  12. l *net.TCPListener
  13. conns chan *Conn
  14. closeFlag bool
  15. }
  16. func Listen(bindAddr string, bindPort int64) (l *Listener, err error) {
  17. tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", bindAddr, bindPort))
  18. listener, err := net.ListenTCP("tcp", tcpAddr)
  19. if err != nil {
  20. return l, err
  21. }
  22. l = &Listener{
  23. addr: listener.Addr(),
  24. l: listener,
  25. conns: make(chan *Conn),
  26. closeFlag: false,
  27. }
  28. go func() {
  29. for {
  30. conn, err := l.l.AcceptTCP()
  31. if err != nil {
  32. if l.closeFlag {
  33. return
  34. }
  35. continue
  36. }
  37. c := &Conn{
  38. TcpConn: conn,
  39. closeFlag: false,
  40. }
  41. c.Reader = bufio.NewReader(c.TcpConn)
  42. l.conns <- c
  43. }
  44. }()
  45. return l, err
  46. }
  47. // wait util get one new connection or listener is closed
  48. // if listener is closed, err returned
  49. func (l *Listener) GetConn() (conn *Conn, err error) {
  50. var ok bool
  51. conn, ok = <-l.conns
  52. if !ok {
  53. return conn, fmt.Errorf("channel close")
  54. }
  55. return conn, nil
  56. }
  57. func (l *Listener) Close() {
  58. if l.l != nil && l.closeFlag == false {
  59. l.closeFlag = true
  60. l.l.Close()
  61. close(l.conns)
  62. }
  63. }
  64. // wrap for TCPConn
  65. type Conn struct {
  66. TcpConn *net.TCPConn
  67. Reader *bufio.Reader
  68. closeFlag bool
  69. }
  70. func ConnectServer(host string, port int64) (c *Conn, err error) {
  71. c = &Conn{}
  72. servertAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", host, port))
  73. if err != nil {
  74. return
  75. }
  76. conn, err := net.DialTCP("tcp", nil, servertAddr)
  77. if err != nil {
  78. return
  79. }
  80. c.TcpConn = conn
  81. c.Reader = bufio.NewReader(c.TcpConn)
  82. c.closeFlag = false
  83. return c, nil
  84. }
  85. func (c *Conn) GetRemoteAddr() (addr string) {
  86. return c.TcpConn.RemoteAddr().String()
  87. }
  88. func (c *Conn) GetLocalAddr() (addr string) {
  89. return c.TcpConn.LocalAddr().String()
  90. }
  91. func (c *Conn) ReadLine() (buff string, err error) {
  92. buff, err = c.Reader.ReadString('\n')
  93. if err == io.EOF {
  94. c.closeFlag = true
  95. }
  96. return buff, err
  97. }
  98. func (c *Conn) Write(content string) (err error) {
  99. _, err = c.TcpConn.Write([]byte(content))
  100. return err
  101. }
  102. func (c *Conn) Close() {
  103. if c.TcpConn != nil && c.closeFlag == false {
  104. c.closeFlag = true
  105. c.TcpConn.Close()
  106. }
  107. }
  108. func (c *Conn) IsClosed() bool {
  109. return c.closeFlag
  110. }
  111. // will block until connection close
  112. func Join(c1 *Conn, c2 *Conn) {
  113. var wait sync.WaitGroup
  114. pipe := func(to *Conn, from *Conn) {
  115. defer to.Close()
  116. defer from.Close()
  117. defer wait.Done()
  118. var err error
  119. _, err = io.Copy(to.TcpConn, from.TcpConn)
  120. if err != nil {
  121. log.Warn("join conns error, %v", err)
  122. }
  123. }
  124. wait.Add(2)
  125. go pipe(c1, c2)
  126. go pipe(c2, c1)
  127. wait.Wait()
  128. return
  129. }