group.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. // Copyright 2018 fatedier, fatedier@gmail.com
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package server
  15. import (
  16. "errors"
  17. "fmt"
  18. "net"
  19. "sync"
  20. gerr "github.com/fatedier/golib/errors"
  21. )
  // ErrGroupAuthFailed is returned when a proxy tries to join an existing TCP
// group with a groupKey that differs from the one the group was created with.
var ErrGroupAuthFailed = errors.New("group auth failed")

// ErrGroupParamsInvalid is returned when a proxy tries to join an existing TCP
// group with a different group name, bind address or port.
var ErrGroupParamsInvalid = errors.New("group params invalid")

// ErrListenerClosed is returned by TcpGroupListener.Accept once the listener
// itself, or the group's hand-off channel, has been closed.
var ErrListenerClosed = errors.New("group listener closed")
  27. type TcpGroupListener struct {
  28. groupName string
  29. group *TcpGroup
  30. addr net.Addr
  31. closeCh chan struct{}
  32. }
  33. func newTcpGroupListener(name string, group *TcpGroup, addr net.Addr) *TcpGroupListener {
  34. return &TcpGroupListener{
  35. groupName: name,
  36. group: group,
  37. addr: addr,
  38. closeCh: make(chan struct{}),
  39. }
  40. }
  41. func (ln *TcpGroupListener) Accept() (c net.Conn, err error) {
  42. var ok bool
  43. select {
  44. case <-ln.closeCh:
  45. return nil, ErrListenerClosed
  46. case c, ok = <-ln.group.Accept():
  47. if !ok {
  48. return nil, ErrListenerClosed
  49. }
  50. return c, nil
  51. }
  52. }
  53. func (ln *TcpGroupListener) Addr() net.Addr {
  54. return ln.addr
  55. }
  56. func (ln *TcpGroupListener) Close() (err error) {
  57. close(ln.closeCh)
  58. ln.group.CloseListener(ln)
  59. return
  60. }
  61. type TcpGroup struct {
  62. group string
  63. groupKey string
  64. addr string
  65. port int
  66. realPort int
  67. acceptCh chan net.Conn
  68. index uint64
  69. tcpLn net.Listener
  70. lns []*TcpGroupListener
  71. ctl *TcpGroupCtl
  72. mu sync.Mutex
  73. }
  74. func NewTcpGroup(ctl *TcpGroupCtl) *TcpGroup {
  75. return &TcpGroup{
  76. lns: make([]*TcpGroupListener, 0),
  77. ctl: ctl,
  78. acceptCh: make(chan net.Conn),
  79. }
  80. }
  81. func (tg *TcpGroup) Listen(proxyName string, group string, groupKey string, addr string, port int) (ln *TcpGroupListener, realPort int, err error) {
  82. tg.mu.Lock()
  83. defer tg.mu.Unlock()
  84. if len(tg.lns) == 0 {
  85. realPort, err = tg.ctl.portManager.Acquire(proxyName, port)
  86. if err != nil {
  87. return
  88. }
  89. tcpLn, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, port))
  90. if errRet != nil {
  91. err = errRet
  92. return
  93. }
  94. ln = newTcpGroupListener(group, tg, tcpLn.Addr())
  95. tg.group = group
  96. tg.groupKey = groupKey
  97. tg.addr = addr
  98. tg.port = port
  99. tg.realPort = realPort
  100. tg.tcpLn = tcpLn
  101. tg.lns = append(tg.lns, ln)
  102. if tg.acceptCh == nil {
  103. tg.acceptCh = make(chan net.Conn)
  104. }
  105. go tg.worker()
  106. } else {
  107. if tg.group != group || tg.addr != addr || tg.port != port {
  108. err = ErrGroupParamsInvalid
  109. return
  110. }
  111. if tg.groupKey != groupKey {
  112. err = ErrGroupAuthFailed
  113. return
  114. }
  115. ln = newTcpGroupListener(group, tg, tg.lns[0].Addr())
  116. realPort = tg.realPort
  117. tg.lns = append(tg.lns, ln)
  118. }
  119. return
  120. }
  121. func (tg *TcpGroup) worker() {
  122. for {
  123. c, err := tg.tcpLn.Accept()
  124. if err != nil {
  125. return
  126. }
  127. err = gerr.PanicToError(func() {
  128. tg.acceptCh <- c
  129. })
  130. if err != nil {
  131. return
  132. }
  133. }
  134. }
  135. func (tg *TcpGroup) Accept() <-chan net.Conn {
  136. return tg.acceptCh
  137. }
  138. func (tg *TcpGroup) CloseListener(ln *TcpGroupListener) {
  139. tg.mu.Lock()
  140. defer tg.mu.Unlock()
  141. for i, tmpLn := range tg.lns {
  142. if tmpLn == ln {
  143. tg.lns = append(tg.lns[:i], tg.lns[i+1:]...)
  144. break
  145. }
  146. }
  147. if len(tg.lns) == 0 {
  148. close(tg.acceptCh)
  149. tg.tcpLn.Close()
  150. tg.ctl.portManager.Release(tg.realPort)
  151. tg.ctl.RemoveGroup(tg.group)
  152. }
  153. }
  154. type TcpGroupCtl struct {
  155. groups map[string]*TcpGroup
  156. portManager *PortManager
  157. mu sync.Mutex
  158. }
  159. func NewTcpGroupCtl(portManager *PortManager) *TcpGroupCtl {
  160. return &TcpGroupCtl{
  161. groups: make(map[string]*TcpGroup),
  162. portManager: portManager,
  163. }
  164. }
  165. func (tgc *TcpGroupCtl) Listen(proxyNanme string, group string, groupKey string,
  166. addr string, port int) (l net.Listener, realPort int, err error) {
  167. tgc.mu.Lock()
  168. defer tgc.mu.Unlock()
  169. if tcpGroup, ok := tgc.groups[group]; ok {
  170. return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port)
  171. } else {
  172. tcpGroup = NewTcpGroup(tgc)
  173. tgc.groups[group] = tcpGroup
  174. return tcpGroup.Listen(proxyNanme, group, groupKey, addr, port)
  175. }
  176. }
  177. func (tgc *TcpGroupCtl) RemoveGroup(group string) {
  178. tgc.mu.Lock()
  179. defer tgc.mu.Unlock()
  180. delete(tgc.groups, group)
  181. }