vhost.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. // Copyright 2016 fatedier, fatedier@gmail.com
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package vhost
  15. import (
  16. "bytes"
  17. "fmt"
  18. "io"
  19. "net"
  20. "strings"
  21. "sync"
  22. "time"
  23. "github.com/fatedier/frp/src/utils/conn"
  24. )
  25. type muxFunc func(*conn.Conn) (net.Conn, string, error)
  26. type hostRewriteFunc func(*conn.Conn, string) (net.Conn, error)
  27. type VhostMuxer struct {
  28. listener *conn.Listener
  29. timeout time.Duration
  30. vhostFunc muxFunc
  31. rewriteFunc hostRewriteFunc
  32. registryMap map[string]*Listener
  33. mutex sync.RWMutex
  34. }
  35. func NewVhostMuxer(listener *conn.Listener, vhostFunc muxFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) {
  36. mux = &VhostMuxer{
  37. listener: listener,
  38. timeout: timeout,
  39. vhostFunc: vhostFunc,
  40. rewriteFunc: rewriteFunc,
  41. registryMap: make(map[string]*Listener),
  42. }
  43. go mux.run()
  44. return mux, nil
  45. }
  46. // listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil, then rewrite the host header to rewriteHost
  47. func (v *VhostMuxer) Listen(name string, rewriteHost string) (l *Listener, err error) {
  48. v.mutex.Lock()
  49. defer v.mutex.Unlock()
  50. if _, exist := v.registryMap[name]; exist {
  51. return nil, fmt.Errorf("domain name %s is already bound", name)
  52. }
  53. l = &Listener{
  54. name: name,
  55. rewriteHost: rewriteHost,
  56. mux: v,
  57. accept: make(chan *conn.Conn),
  58. }
  59. v.registryMap[name] = l
  60. return l, nil
  61. }
  62. func (v *VhostMuxer) getListener(name string) (l *Listener, exist bool) {
  63. v.mutex.RLock()
  64. defer v.mutex.RUnlock()
  65. l, exist = v.registryMap[name]
  66. return l, exist
  67. }
  68. func (v *VhostMuxer) unRegister(name string) {
  69. v.mutex.Lock()
  70. defer v.mutex.Unlock()
  71. delete(v.registryMap, name)
  72. }
  73. func (v *VhostMuxer) run() {
  74. for {
  75. conn, err := v.listener.Accept()
  76. if err != nil {
  77. return
  78. }
  79. go v.handle(conn)
  80. }
  81. }
  82. func (v *VhostMuxer) handle(c *conn.Conn) {
  83. if err := c.SetDeadline(time.Now().Add(v.timeout)); err != nil {
  84. return
  85. }
  86. sConn, name, err := v.vhostFunc(c)
  87. if err != nil {
  88. return
  89. }
  90. name = strings.ToLower(name)
  91. l, ok := v.getListener(name)
  92. if !ok {
  93. return
  94. }
  95. if err = sConn.SetDeadline(time.Time{}); err != nil {
  96. return
  97. }
  98. c.SetTcpConn(sConn)
  99. l.accept <- c
  100. }
  101. type Listener struct {
  102. name string
  103. rewriteHost string
  104. mux *VhostMuxer // for closing VhostMuxer
  105. accept chan *conn.Conn
  106. }
  107. func (l *Listener) Accept() (*conn.Conn, error) {
  108. conn, ok := <-l.accept
  109. if !ok {
  110. return nil, fmt.Errorf("Listener closed")
  111. }
  112. // if rewriteFunc is exist and rewriteHost is set
  113. // rewrite http requests with a modified host header
  114. if l.mux.rewriteFunc != nil && l.rewriteHost != "" {
  115. sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost)
  116. if err != nil {
  117. return nil, fmt.Errorf("http host header rewrite failed")
  118. }
  119. conn.SetTcpConn(sConn)
  120. }
  121. return conn, nil
  122. }
  123. func (l *Listener) Close() error {
  124. l.mux.unRegister(l.name)
  125. close(l.accept)
  126. return nil
  127. }
  128. func (l *Listener) Name() string {
  129. return l.name
  130. }
  131. type sharedConn struct {
  132. net.Conn
  133. sync.Mutex
  134. buff *bytes.Buffer
  135. }
  136. // the bytes you read in io.Reader, will be reserved in sharedConn
  137. func newShareConn(conn net.Conn) (*sharedConn, io.Reader) {
  138. sc := &sharedConn{
  139. Conn: conn,
  140. buff: bytes.NewBuffer(make([]byte, 0, 1024)),
  141. }
  142. return sc, io.TeeReader(conn, sc.buff)
  143. }
  144. func (sc *sharedConn) Read(p []byte) (n int, err error) {
  145. sc.Lock()
  146. if sc.buff == nil {
  147. sc.Unlock()
  148. return sc.Conn.Read(p)
  149. }
  150. n, err = sc.buff.Read(p)
  151. if err == io.EOF {
  152. sc.buff = nil
  153. var n2 int
  154. n2, err = sc.Conn.Read(p[n:])
  155. n += n2
  156. }
  157. sc.Unlock()
  158. return
  159. }
  160. func (sc *sharedConn) WriteBuff(buffer []byte) (err error) {
  161. sc.buff.Reset()
  162. _, err = sc.buff.Write(buffer)
  163. return err
  164. }