session.go 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package yamux
  2. import (
  3. "fmt"
  4. "io"
  5. "math"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
	// remoteGoAway is set to 1 (atomically) when the remote side sends a
	// normal GoAway, meaning it does not want further streams; Open then
	// fails with ErrRemoteGoAway.
	// NOTE(review): the original note "must be first for alignment" only
	// matters for 64-bit atomic operations; both GoAway flags are int32 —
	// confirm before relying on field order.
	remoteGoAway int32

	// localGoAway is set to 1 (atomically) once we send a GoAway; while it
	// is set, incomingStream refuses new remote streams with an RST.
	localGoAway int32

	// config holds our configuration (accept backlog, keepalive settings).
	config *Config

	// conn is the underlying connection: recv reads frames from it and
	// send writes frames to it.
	conn io.ReadWriteCloser

	// pings maps the ID of each outstanding Ping to the channel that
	// handlePing closes when the response arrives; pingID is the next ID
	// to hand out. Both are protected by pingLock.
	pings    map[uint32]chan struct{}
	pingID   uint32
	pingLock sync.Mutex

	// nextStreamID is the ID Open assigns to the next outbound stream. It
	// starts at 1 for a client and 2 for a server and advances by 2, so the
	// two sides never pick the same ID. Accessed under streamLock.
	nextStreamID uint32

	// streams maps a stream ID to its stream; protected by streamLock.
	streams    map[uint32]*Stream
	streamLock sync.Mutex

	// acceptCh queues incoming streams for AcceptStream. It is buffered to
	// config.AcceptBacklog; when it is full, incomingStream resets new
	// streams instead of queuing them.
	acceptCh chan *Stream

	// sendCh carries write requests (a header and/or a body) to the send
	// goroutine.
	sendCh chan sendReady

	// shutdown and shutdownErr record that, and why, the session closed.
	// Close sets them while holding shutdownLock and then closes shutdownCh,
	// which the session's blocking operations select on to notice shutdown.
	shutdown     bool
	shutdownErr  error
	shutdownCh   chan struct{}
	shutdownLock sync.Mutex
}
// sendReady is a write request for the send goroutine. It carries an
// optional header, an optional body and an optional channel on which the
// outcome of the write is reported.
type sendReady struct {
	// Hdr, if non-empty, is written to the connection first.
	Hdr []byte
	// Body, if non-nil, is copied to the connection after Hdr.
	Body io.Reader
	// Err receives the write result through asyncSendErr. sendNoWait leaves
	// it nil because nobody waits for the result.
	Err chan error
}
  52. // newSession is used to construct a new session
  53. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  54. s := &Session{
  55. config: config,
  56. conn: conn,
  57. pings: make(map[uint32]chan struct{}),
  58. streams: make(map[uint32]*Stream),
  59. acceptCh: make(chan *Stream, config.AcceptBacklog),
  60. sendCh: make(chan sendReady, 64),
  61. shutdownCh: make(chan struct{}),
  62. }
  63. if client {
  64. s.nextStreamID = 1
  65. } else {
  66. s.nextStreamID = 2
  67. }
  68. go s.recv()
  69. go s.send()
  70. if config.EnableKeepAlive {
  71. go s.keepalive()
  72. }
  73. return s
  74. }
  75. // IsClosed does a safe check to see if we have shutdown
  76. func (s *Session) IsClosed() bool {
  77. select {
  78. case <-s.shutdownCh:
  79. return true
  80. default:
  81. return false
  82. }
  83. }
  84. // Open is used to create a new stream
  85. func (s *Session) Open() (*Stream, error) {
  86. if s.IsClosed() {
  87. return nil, ErrSessionShutdown
  88. }
  89. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  90. return nil, ErrRemoteGoAway
  91. }
  92. // Check if we've exhaused the streams
  93. s.streamLock.Lock()
  94. id := s.nextStreamID
  95. if id >= math.MaxUint32-1 {
  96. s.streamLock.Unlock()
  97. return nil, ErrStreamsExhausted
  98. }
  99. s.nextStreamID += 2
  100. // Register the stream
  101. stream := newStream(s, id, streamInit)
  102. s.streams[id] = stream
  103. s.streamLock.Unlock()
  104. // Send the window update to create
  105. return stream, stream.sendWindowUpdate()
  106. }
  107. // Accept is used to block until the next available stream
  108. // is ready to be accepted.
  109. func (s *Session) Accept() (net.Conn, error) {
  110. return s.AcceptStream()
  111. }
  112. // AcceptStream is used to block until the next available stream
  113. // is ready to be accepted.
  114. func (s *Session) AcceptStream() (*Stream, error) {
  115. select {
  116. case stream := <-s.acceptCh:
  117. return stream, nil
  118. case <-s.shutdownCh:
  119. return nil, s.shutdownErr
  120. }
  121. }
  122. // Close is used to close the session and all streams.
  123. // Attempts to send a GoAway before closing the connection.
  124. func (s *Session) Close() error {
  125. s.shutdownLock.Lock()
  126. defer s.shutdownLock.Unlock()
  127. if s.shutdown {
  128. return nil
  129. }
  130. s.shutdown = true
  131. if s.shutdownErr == nil {
  132. s.shutdownErr = ErrSessionShutdown
  133. }
  134. close(s.shutdownCh)
  135. s.conn.Close()
  136. s.streamLock.Lock()
  137. defer s.streamLock.Unlock()
  138. for _, stream := range s.streams {
  139. stream.forceClose()
  140. }
  141. return nil
  142. }
  143. // exitErr is used to handle an error that is causing the
  144. // session to terminate.
  145. func (s *Session) exitErr(err error) {
  146. s.shutdownErr = err
  147. s.Close()
  148. }
  149. // GoAway can be used to prevent accepting further
  150. // connections. It does not close the underlying conn.
  151. func (s *Session) GoAway() error {
  152. return s.waitForSend(s.goAway(goAwayNormal), nil)
  153. }
  154. // goAway is used to send a goAway message
  155. func (s *Session) goAway(reason uint32) header {
  156. atomic.SwapInt32(&s.localGoAway, 1)
  157. hdr := header(make([]byte, headerSize))
  158. hdr.encode(typeGoAway, 0, 0, reason)
  159. return hdr
  160. }
  161. // Ping is used to measure the RTT response time
  162. func (s *Session) Ping() (time.Duration, error) {
  163. // Get a channel for the ping
  164. ch := make(chan struct{})
  165. // Get a new ping id, mark as pending
  166. s.pingLock.Lock()
  167. id := s.pingID
  168. s.pingID++
  169. s.pings[id] = ch
  170. s.pingLock.Unlock()
  171. // Send the ping request
  172. hdr := header(make([]byte, headerSize))
  173. hdr.encode(typePing, flagSYN, 0, id)
  174. if err := s.waitForSend(hdr, nil); err != nil {
  175. return 0, err
  176. }
  177. // Wait for a response
  178. start := time.Now()
  179. select {
  180. case <-ch:
  181. case <-s.shutdownCh:
  182. return 0, ErrSessionShutdown
  183. }
  184. // Compute the RTT
  185. return time.Now().Sub(start), nil
  186. }
  187. // keepalive is a long running goroutine that periodically does
  188. // a ping to keep the connection alive.
  189. func (s *Session) keepalive() {
  190. for {
  191. select {
  192. case <-time.After(s.config.KeepAliveInterval):
  193. s.Ping()
  194. case <-s.shutdownCh:
  195. return
  196. }
  197. }
  198. }
  199. // waitForSend waits to send a header, checking for a potential shutdown
  200. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  201. errCh := make(chan error, 1)
  202. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  203. select {
  204. case s.sendCh <- ready:
  205. case <-s.shutdownCh:
  206. return ErrSessionShutdown
  207. }
  208. select {
  209. case err := <-errCh:
  210. return err
  211. case <-s.shutdownCh:
  212. return ErrSessionShutdown
  213. }
  214. }
  215. // sendNoWait does a send without waiting
  216. func (s *Session) sendNoWait(hdr header) error {
  217. select {
  218. case s.sendCh <- sendReady{Hdr: hdr}:
  219. return nil
  220. case <-s.shutdownCh:
  221. return ErrSessionShutdown
  222. }
  223. }
  224. // send is a long running goroutine that sends data
  225. func (s *Session) send() {
  226. for !s.IsClosed() {
  227. select {
  228. case ready := <-s.sendCh:
  229. // Send a header if ready
  230. if ready.Hdr != nil {
  231. sent := 0
  232. for sent < len(ready.Hdr) {
  233. n, err := s.conn.Write(ready.Hdr[sent:])
  234. if err != nil {
  235. asyncSendErr(ready.Err, err)
  236. s.exitErr(err)
  237. return
  238. }
  239. sent += n
  240. }
  241. }
  242. // Send data from a body if given
  243. if ready.Body != nil {
  244. _, err := io.Copy(s.conn, ready.Body)
  245. if err != nil {
  246. asyncSendErr(ready.Err, err)
  247. s.exitErr(err)
  248. return
  249. }
  250. }
  251. // No error, successful send
  252. asyncSendErr(ready.Err, nil)
  253. case <-s.shutdownCh:
  254. return
  255. }
  256. }
  257. }
  258. // recv is a long running goroutine that accepts new data
  259. func (s *Session) recv() {
  260. hdr := header(make([]byte, headerSize))
  261. var handler func(header) error
  262. for !s.IsClosed() {
  263. // Read the header
  264. if _, err := io.ReadFull(s.conn, hdr); err != nil {
  265. s.exitErr(err)
  266. return
  267. }
  268. // Verify the version
  269. if hdr.Version() != protoVersion {
  270. s.exitErr(ErrInvalidVersion)
  271. return
  272. }
  273. // Switch on the type
  274. switch hdr.MsgType() {
  275. case typeData:
  276. handler = s.handleStreamMessage
  277. case typeWindowUpdate:
  278. handler = s.handleStreamMessage
  279. case typeGoAway:
  280. handler = s.handleGoAway
  281. case typePing:
  282. handler = s.handlePing
  283. default:
  284. s.exitErr(ErrInvalidMsgType)
  285. return
  286. }
  287. // Invoke the handler
  288. if err := handler(hdr); err != nil {
  289. s.exitErr(err)
  290. return
  291. }
  292. }
  293. }
  294. // handleStreamMessage handles either a data or window update frame
  295. func (s *Session) handleStreamMessage(hdr header) error {
  296. // Check for a new stream creation
  297. id := hdr.StreamID()
  298. flags := hdr.Flags()
  299. if flags&flagSYN == flagSYN {
  300. if err := s.incomingStream(id); err != nil {
  301. return err
  302. }
  303. }
  304. // Get the stream
  305. s.streamLock.Lock()
  306. stream := s.streams[id]
  307. s.streamLock.Unlock()
  308. // Make sure we have a stream
  309. if stream == nil {
  310. s.sendNoWait(s.goAway(goAwayProtoErr))
  311. return ErrMissingStream
  312. }
  313. // Check if this is a window update
  314. if hdr.MsgType() == typeWindowUpdate {
  315. if err := stream.incrSendWindow(hdr, flags); err != nil {
  316. s.sendNoWait(s.goAway(goAwayProtoErr))
  317. return err
  318. }
  319. return nil
  320. }
  321. // Read the new data
  322. if err := stream.readData(hdr, flags, s.conn); err != nil {
  323. s.sendNoWait(s.goAway(goAwayProtoErr))
  324. return err
  325. }
  326. return nil
  327. }
  328. // handlePing is invokde for a typePing frame
  329. func (s *Session) handlePing(hdr header) error {
  330. flags := hdr.Flags()
  331. pingID := hdr.Length()
  332. // Check if this is a query, respond back
  333. if flags&flagSYN == flagSYN {
  334. hdr := header(make([]byte, headerSize))
  335. hdr.encode(typePing, flagACK, 0, pingID)
  336. s.sendNoWait(hdr)
  337. return nil
  338. }
  339. // Handle a response
  340. s.pingLock.Lock()
  341. ch := s.pings[pingID]
  342. if ch != nil {
  343. delete(s.pings, pingID)
  344. close(ch)
  345. }
  346. s.pingLock.Unlock()
  347. return nil
  348. }
  349. // handleGoAway is invokde for a typeGoAway frame
  350. func (s *Session) handleGoAway(hdr header) error {
  351. code := hdr.Length()
  352. switch code {
  353. case goAwayNormal:
  354. atomic.SwapInt32(&s.remoteGoAway, 1)
  355. case goAwayProtoErr:
  356. return fmt.Errorf("yamux protocol error")
  357. case goAwayInternalErr:
  358. return fmt.Errorf("remote yamux internal error")
  359. default:
  360. return fmt.Errorf("unexpected go away received")
  361. }
  362. return nil
  363. }
  364. // incomingStream is used to create a new incoming stream
  365. func (s *Session) incomingStream(id uint32) error {
  366. // Reject immediately if we are doing a go away
  367. if atomic.LoadInt32(&s.localGoAway) == 1 {
  368. hdr := header(make([]byte, headerSize))
  369. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  370. return s.waitForSend(hdr, nil)
  371. }
  372. s.streamLock.Lock()
  373. defer s.streamLock.Unlock()
  374. // Check if stream already exists
  375. if _, ok := s.streams[id]; ok {
  376. s.sendNoWait(s.goAway(goAwayProtoErr))
  377. return ErrDuplicateStream
  378. }
  379. // Register the stream
  380. stream := newStream(s, id, streamSYNReceived)
  381. s.streams[id] = stream
  382. // Check if we've exceeded the backlog
  383. select {
  384. case s.acceptCh <- stream:
  385. return nil
  386. default:
  387. // Backlog exceeded! RST the stream
  388. delete(s.streams, id)
  389. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  390. s.sendNoWait(stream.sendHdr)
  391. }
  392. return nil
  393. }
  394. // closeStream is used to close a stream once both sides have
  395. // issued a close.
  396. func (s *Session) closeStream(id uint32, withLock bool) {
  397. if !withLock {
  398. s.streamLock.Lock()
  399. defer s.streamLock.Unlock()
  400. }
  401. delete(s.streams, id)
  402. }