mux.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package mux
  2. import (
  3. "fmt"
  4. "io"
  5. "net"
  6. "sort"
  7. "sync"
  8. "time"
  9. "github.com/fatedier/frp/utils/errors"
  10. frpNet "github.com/fatedier/frp/utils/net"
  11. )
  12. const (
  13. // DefaultTimeout is the default length of time to wait for bytes we need.
  14. DefaultTimeout = 10 * time.Second
  15. )
  16. type Mux struct {
  17. ln net.Listener
  18. defaultLn *listener
  19. lns []*listener
  20. maxNeedBytesNum uint32
  21. mu sync.RWMutex
  22. }
  23. func NewMux() (mux *Mux) {
  24. mux = &Mux{
  25. lns: make([]*listener, 0),
  26. }
  27. return
  28. }
  29. func (mux *Mux) Listen(priority int, needBytesNum uint32, fn MatchFunc) net.Listener {
  30. ln := &listener{
  31. c: make(chan net.Conn),
  32. mux: mux,
  33. needBytesNum: needBytesNum,
  34. matchFn: fn,
  35. }
  36. mux.mu.Lock()
  37. defer mux.mu.Unlock()
  38. if needBytesNum > mux.maxNeedBytesNum {
  39. mux.maxNeedBytesNum = needBytesNum
  40. }
  41. newlns := append(mux.copyLns(), ln)
  42. sort.Slice(newlns, func(i, j int) bool {
  43. return newlns[i].needBytesNum < newlns[j].needBytesNum
  44. })
  45. mux.lns = newlns
  46. return ln
  47. }
  48. func (mux *Mux) ListenHttp(priority int) net.Listener {
  49. return mux.Listen(priority, HttpNeedBytesNum, HttpMatchFunc)
  50. }
  51. func (mux *Mux) ListenHttps(priority int) net.Listener {
  52. return mux.Listen(priority, HttpsNeedBytesNum, HttpsMatchFunc)
  53. }
  54. func (mux *Mux) DefaultListener() net.Listener {
  55. mux.mu.Lock()
  56. defer mux.mu.Unlock()
  57. if mux.defaultLn == nil {
  58. mux.defaultLn = &listener{
  59. c: make(chan net.Conn),
  60. mux: mux,
  61. }
  62. }
  63. return mux.defaultLn
  64. }
  65. func (mux *Mux) release(ln *listener) bool {
  66. result := false
  67. mux.mu.Lock()
  68. defer mux.mu.Unlock()
  69. lns := mux.copyLns()
  70. for i, l := range lns {
  71. if l == ln {
  72. lns = append(lns[:i], lns[i+1:]...)
  73. result = true
  74. }
  75. }
  76. mux.lns = lns
  77. return result
  78. }
  79. func (mux *Mux) copyLns() []*listener {
  80. lns := make([]*listener, 0, len(mux.lns))
  81. for _, l := range mux.lns {
  82. lns = append(lns, l)
  83. }
  84. return lns
  85. }
  86. // Serve handles connections from ln and multiplexes then across registered listeners.
  87. func (mux *Mux) Serve(ln net.Listener) error {
  88. mux.mu.Lock()
  89. mux.ln = ln
  90. mux.mu.Unlock()
  91. for {
  92. // Wait for the next connection.
  93. // If it returns a temporary error then simply retry.
  94. // If it returns any other error then exit immediately.
  95. conn, err := ln.Accept()
  96. if err, ok := err.(interface {
  97. Temporary() bool
  98. }); ok && err.Temporary() {
  99. continue
  100. }
  101. if err != nil {
  102. return err
  103. }
  104. go mux.handleConn(conn)
  105. }
  106. }
  107. func (mux *Mux) handleConn(conn net.Conn) {
  108. mux.mu.RLock()
  109. maxNeedBytesNum := mux.maxNeedBytesNum
  110. lns := mux.lns
  111. defaultLn := mux.defaultLn
  112. mux.mu.RUnlock()
  113. shareConn, rd := frpNet.NewShareConnSize(frpNet.WrapConn(conn), int(maxNeedBytesNum))
  114. data := make([]byte, maxNeedBytesNum)
  115. conn.SetReadDeadline(time.Now().Add(DefaultTimeout))
  116. _, err := io.ReadFull(rd, data)
  117. if err != nil {
  118. conn.Close()
  119. return
  120. }
  121. conn.SetReadDeadline(time.Time{})
  122. for _, ln := range lns {
  123. if match := ln.matchFn(data); match {
  124. err = errors.PanicToError(func() {
  125. ln.c <- shareConn
  126. })
  127. if err != nil {
  128. conn.Close()
  129. }
  130. return
  131. }
  132. }
  133. // No match listeners
  134. if defaultLn != nil {
  135. err = errors.PanicToError(func() {
  136. defaultLn.c <- shareConn
  137. })
  138. if err != nil {
  139. conn.Close()
  140. }
  141. return
  142. }
  143. // No listeners for this connection, close it.
  144. conn.Close()
  145. return
  146. }
  147. type listener struct {
  148. mux *Mux
  149. needBytesNum uint32
  150. matchFn MatchFunc
  151. c chan net.Conn
  152. mu sync.RWMutex
  153. }
  154. // Accept waits for and returns the next connection to the listener.
  155. func (ln *listener) Accept() (net.Conn, error) {
  156. conn, ok := <-ln.c
  157. if !ok {
  158. return nil, fmt.Errorf("network connection closed")
  159. }
  160. return conn, nil
  161. }
  162. // Close removes this listener from the parent mux and closes the channel.
  163. func (ln *listener) Close() error {
  164. if ok := ln.mux.release(ln); ok {
  165. // Close done to signal to any RLock holders to release their lock.
  166. close(ln.c)
  167. }
  168. return nil
  169. }
  170. func (ln *listener) Addr() net.Addr {
  171. if ln.mux == nil {
  172. return nil
  173. }
  174. ln.mux.mu.RLock()
  175. defer ln.mux.mu.RUnlock()
  176. if ln.mux.ln == nil {
  177. return nil
  178. }
  179. return ln.mux.ln.Addr()
  180. }