session.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. package yamux
  2. import (
  3. "fmt"
  4. "io"
  5. "net"
  6. "sync"
  7. "time"
  8. )
  9. var (
  10. // ErrInvalidVersion means we received a frame with an
  11. // invalid version
  12. ErrInvalidVersion = fmt.Errorf("invalid protocol version")
  13. // ErrInvalidMsgType means we received a frame with an
  14. // invalid message type
  15. ErrInvalidMsgType = fmt.Errorf("invalid msg type")
  16. // ErrSessionShutdown is used if there is a shutdown during
  17. // an operation
  18. ErrSessionShutdown = fmt.Errorf("session shutdown")
  19. )
  20. // Session is used to wrap a reliable ordered connection and to
  21. // multiplex it into multiple streams.
  22. type Session struct {
  23. // client is true if we are a client size connection
  24. client bool
  25. // config holds our configuration
  26. config *Config
  27. // conn is the underlying connection
  28. conn io.ReadWriteCloser
  29. // nextStreamID is the next stream we should
  30. // send. This depends if we are a client/server.
  31. nextStreamID uint32
  32. // pings is used to track inflight pings
  33. pings map[uint32]chan struct{}
  34. pingID uint32
  35. pingLock sync.Mutex
  36. // streams maps a stream id to a stream
  37. streams map[uint32]*Stream
  38. // acceptCh is used to pass ready streams to the client
  39. acceptCh chan *Stream
  40. // sendCh is used to mark a stream as ready to send,
  41. // or to send a header out directly.
  42. sendCh chan sendReady
  43. // shutdown is used to safely close a session
  44. shutdown bool
  45. shutdownErr error
  46. shutdownCh chan struct{}
  47. shutdownLock sync.Mutex
  48. }
  49. // hasAddr is used to get the address from the underlying connection
  50. type hasAddr interface {
  51. LocalAddr() net.Addr
  52. RemoteAddr() net.Addr
  53. }
  54. // yamuxAddr is used when we cannot get the underlying address
  55. type yamuxAddr struct {
  56. Addr string
  57. }
  58. func (*yamuxAddr) Network() string {
  59. return "yamux"
  60. }
  61. func (y *yamuxAddr) String() string {
  62. return fmt.Sprintf("yamux:%s", y.Addr)
  63. }
  64. // sendReady is used to either mark a stream as ready
  65. // or to directly send a header
  66. type sendReady struct {
  67. StreamID uint32
  68. Hdr []byte
  69. Err chan error
  70. }
  71. // newSession is used to construct a new session
  72. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  73. s := &Session{
  74. client: client,
  75. config: config,
  76. conn: conn,
  77. pings: make(map[uint32]chan struct{}),
  78. streams: make(map[uint32]*Stream),
  79. acceptCh: make(chan *Stream, config.AcceptBacklog),
  80. sendCh: make(chan sendReady, 64),
  81. shutdownCh: make(chan struct{}),
  82. }
  83. if client {
  84. s.nextStreamID = 1
  85. } else {
  86. s.nextStreamID = 2
  87. }
  88. go s.recv()
  89. go s.send()
  90. if config.EnableKeepAlive {
  91. go s.keepalive()
  92. }
  93. return s
  94. }
  95. // Open is used to create a new stream
  96. func (s *Session) Open() (*Stream, error) {
  97. return nil, nil
  98. }
  99. // Accept is used to block until the next available stream
  100. // is ready to be accepted.
  101. func (s *Session) Accept() (net.Conn, error) {
  102. return s.AcceptStream()
  103. }
  104. // AcceptStream is used to block until the next available stream
  105. // is ready to be accepted.
  106. func (s *Session) AcceptStream() (*Stream, error) {
  107. select {
  108. case stream := <-s.acceptCh:
  109. return stream, nil
  110. case <-s.shutdownCh:
  111. return nil, s.shutdownErr
  112. }
  113. }
  114. // Close is used to close the session and all streams.
  115. // Attempts to send a GoAway before closing the connection.
  116. func (s *Session) Close() error {
  117. s.shutdownLock.Lock()
  118. defer s.shutdownLock.Unlock()
  119. if s.shutdown {
  120. return nil
  121. }
  122. s.shutdown = true
  123. close(s.shutdownCh)
  124. s.conn.Close()
  125. return nil
  126. }
  127. // Addr is used to get the address of the listener.
  128. func (s *Session) Addr() net.Addr {
  129. return s.LocalAddr()
  130. }
  131. // LocalAddr is used to get the local address of the
  132. // underlying connection.
  133. func (s *Session) LocalAddr() net.Addr {
  134. addr, ok := s.conn.(hasAddr)
  135. if !ok {
  136. return &yamuxAddr{"local"}
  137. }
  138. return addr.LocalAddr()
  139. }
  140. // RemoteAddr is used to get the address of remote end
  141. // of the underlying connection
  142. func (s *Session) RemoteAddr() net.Addr {
  143. addr, ok := s.conn.(hasAddr)
  144. if !ok {
  145. return &yamuxAddr{"remote"}
  146. }
  147. return addr.RemoteAddr()
  148. }
  149. // Ping is used to measure the RTT response time
  150. func (s *Session) Ping() (time.Duration, error) {
  151. // Get a channel for the ping
  152. ch := make(chan struct{})
  153. // Get a new ping id, mark as pending
  154. s.pingLock.Lock()
  155. id := s.pingID
  156. s.pingID++
  157. s.pings[id] = ch
  158. s.pingLock.Unlock()
  159. // Send the ping request
  160. hdr := header(make([]byte, headerSize))
  161. hdr.encode(typePing, flagSYN, 0, id)
  162. if err := s.waitForSend(hdr); err != nil {
  163. return 0, err
  164. }
  165. // Wait for a response
  166. start := time.Now()
  167. select {
  168. case <-ch:
  169. case <-s.shutdownCh:
  170. return 0, ErrSessionShutdown
  171. }
  172. // Compute the RTT
  173. return time.Now().Sub(start), nil
  174. }
  175. // keepalive is a long running goroutine that periodically does
  176. // a ping to keep the connection alive.
  177. func (s *Session) keepalive() {
  178. for {
  179. select {
  180. case <-time.After(s.config.KeepAliveInterval):
  181. s.Ping()
  182. case <-s.shutdownCh:
  183. return
  184. }
  185. }
  186. }
  187. // waitForSend waits to send a header, checking for a potential shutdown
  188. func (s *Session) waitForSend(hdr header) error {
  189. errCh := make(chan error, 1)
  190. ready := sendReady{Hdr: hdr, Err: errCh}
  191. select {
  192. case s.sendCh <- ready:
  193. case <-s.shutdownCh:
  194. return ErrSessionShutdown
  195. }
  196. select {
  197. case err := <-errCh:
  198. return err
  199. case <-s.shutdownCh:
  200. return ErrSessionShutdown
  201. }
  202. }
  203. // sendNoWait does a send without waiting
  204. func (s *Session) sendNoWait(hdr header) error {
  205. select {
  206. case s.sendCh <- sendReady{Hdr: hdr}:
  207. return nil
  208. case <-s.shutdownCh:
  209. return ErrSessionShutdown
  210. }
  211. }
  212. // send is a long running goroutine that sends data
  213. func (s *Session) send() {
  214. for {
  215. select {
  216. case ready := <-s.sendCh:
  217. // Send data from a stream if ready
  218. if ready.StreamID != 0 {
  219. }
  220. // Send a header if ready
  221. if ready.Hdr != nil {
  222. sent := 0
  223. for sent < len(ready.Hdr) {
  224. n, err := s.conn.Write(ready.Hdr[sent:])
  225. if err != nil {
  226. s.exitErr(err)
  227. asyncSendErr(ready.Err, err)
  228. }
  229. sent += n
  230. }
  231. }
  232. asyncSendErr(ready.Err, nil)
  233. case <-s.shutdownCh:
  234. return
  235. }
  236. }
  237. }
  238. // recv is a long running goroutine that accepts new data
  239. func (s *Session) recv() {
  240. hdr := header(make([]byte, headerSize))
  241. for {
  242. // Read the header
  243. if _, err := io.ReadFull(s.conn, hdr); err != nil {
  244. s.exitErr(err)
  245. return
  246. }
  247. // Verify the version
  248. if hdr.Version() != protoVersion {
  249. s.exitErr(ErrInvalidVersion)
  250. return
  251. }
  252. // Switch on the type
  253. msgType := hdr.MsgType()
  254. switch msgType {
  255. case typeData:
  256. s.handleData(hdr)
  257. case typeWindowUpdate:
  258. s.handleWindowUpdate(hdr)
  259. case typePing:
  260. s.handlePing(hdr)
  261. case typeGoAway:
  262. s.handleGoAway(hdr)
  263. default:
  264. s.exitErr(ErrInvalidMsgType)
  265. return
  266. }
  267. }
  268. }
  269. // handleData is invokde for a typeData frame
  270. func (s *Session) handleData(hdr header) {
  271. flags := hdr.Flags()
  272. // Check for a new stream creation
  273. if flags&flagSYN == flagSYN {
  274. s.createStream(hdr.StreamID())
  275. }
  276. }
  277. // handleWindowUpdate is invokde for a typeWindowUpdate frame
  278. func (s *Session) handleWindowUpdate(hdr header) {
  279. flags := hdr.Flags()
  280. // Check for a new stream creation
  281. if flags&flagSYN == flagSYN {
  282. s.createStream(hdr.StreamID())
  283. }
  284. }
  285. // handlePing is invokde for a typePing frame
  286. func (s *Session) handlePing(hdr header) {
  287. flags := hdr.Flags()
  288. pingID := hdr.Length()
  289. // Check if this is a query, respond back
  290. if flags&flagSYN == flagSYN {
  291. hdr := header(make([]byte, headerSize))
  292. hdr.encode(typePing, flagACK, 0, pingID)
  293. s.sendNoWait(hdr)
  294. return
  295. }
  296. // Handle a response
  297. s.pingLock.Lock()
  298. ch := s.pings[pingID]
  299. if ch != nil {
  300. delete(s.pings, pingID)
  301. close(ch)
  302. }
  303. s.pingLock.Unlock()
  304. }
// handleGoAway is invoked for a typeGoAway frame.
//
// TODO: not implemented yet; the frame is currently ignored.
func (s *Session) handleGoAway(hdr header) {
}
  308. // exitErr is used to handle an error that is causing
  309. // the listener to exit.
  310. func (s *Session) exitErr(err error) {
  311. }
  312. // goAway is used to send a goAway message
  313. func (s *Session) goAway(reason uint32) {
  314. hdr := header(make([]byte, headerSize))
  315. hdr.encode(typeGoAway, 0, 0, reason)
  316. s.sendNoWait(hdr)
  317. }
// createStream is meant to create a new stream with the given ID in response
// to a SYN from the remote end (handleData and handleWindowUpdate call it).
//
// TODO: not implemented yet. Until it is, remotely opened streams are never
// created, and nothing is ever delivered to AcceptStream.
func (s *Session) createStream(id uint32) {
	// TODO
}