session.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. package yamux
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "math"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. )
  12. // Session is used to wrap a reliable ordered connection and to
  13. // multiplex it into multiple streams.
  14. type Session struct {
  15. // remoteGoAway indicates the remote side does
  16. // not want futher connections. Must be first for alignment.
  17. remoteGoAway int32
  18. // localGoAway indicates that we should stop
  19. // accepting futher connections. Must be first for alignment.
  20. localGoAway int32
  21. // nextStreamID is the next stream we should
  22. // send. This depends if we are a client/server.
  23. nextStreamID uint32
  24. // config holds our configuration
  25. config *Config
  26. // conn is the underlying connection
  27. conn io.ReadWriteCloser
  28. // bufRead is a buffered reader
  29. bufRead *bufio.Reader
  30. // pings is used to track inflight pings
  31. pings map[uint32]chan struct{}
  32. pingID uint32
  33. pingLock sync.Mutex
  34. // streams maps a stream id to a stream
  35. streams map[uint32]*Stream
  36. streamLock sync.Mutex
  37. // acceptCh is used to pass ready streams to the client
  38. acceptCh chan *Stream
  39. // sendCh is used to mark a stream as ready to send,
  40. // or to send a header out directly.
  41. sendCh chan sendReady
  42. // shutdown is used to safely close a session
  43. shutdown bool
  44. shutdownErr error
  45. shutdownCh chan struct{}
  46. shutdownLock sync.Mutex
  47. }
  // sendReady is one unit of work for the send goroutine. It carries an
// optional header, an optional body to copy after it, and an optional
// channel on which the outcome of the write is reported (via asyncSendErr).
type sendReady struct {
	Hdr  []byte     // header bytes written first; nil to skip
	Body io.Reader  // payload copied after the header; nil to skip
	Err  chan error // receives the write result; nil when nobody waits (sendNoWait)
}
  55. // newSession is used to construct a new session
  56. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  57. s := &Session{
  58. config: config,
  59. conn: conn,
  60. bufRead: bufio.NewReader(conn),
  61. pings: make(map[uint32]chan struct{}),
  62. streams: make(map[uint32]*Stream),
  63. acceptCh: make(chan *Stream, config.AcceptBacklog),
  64. sendCh: make(chan sendReady, 64),
  65. shutdownCh: make(chan struct{}),
  66. }
  67. if client {
  68. s.nextStreamID = 1
  69. } else {
  70. s.nextStreamID = 2
  71. }
  72. go s.recv()
  73. go s.send()
  74. if config.EnableKeepAlive {
  75. go s.keepalive()
  76. }
  77. return s
  78. }
  79. // IsClosed does a safe check to see if we have shutdown
  80. func (s *Session) IsClosed() bool {
  81. select {
  82. case <-s.shutdownCh:
  83. return true
  84. default:
  85. return false
  86. }
  87. }
  88. // Open is used to create a new stream as a net.Conn
  89. func (s *Session) Open() (net.Conn, error) {
  90. return s.OpenStream()
  91. }
  92. // OpenStream is used to create a new stream
  93. func (s *Session) OpenStream() (*Stream, error) {
  94. if s.IsClosed() {
  95. return nil, ErrSessionShutdown
  96. }
  97. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  98. return nil, ErrRemoteGoAway
  99. }
  100. GET_ID:
  101. // Get and ID, and check for stream exhaustion
  102. id := atomic.LoadUint32(&s.nextStreamID)
  103. if id >= math.MaxUint32-1 {
  104. return nil, ErrStreamsExhausted
  105. }
  106. if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
  107. goto GET_ID
  108. }
  109. // Register the stream
  110. stream := newStream(s, id, streamInit)
  111. s.streamLock.Lock()
  112. s.streams[id] = stream
  113. s.streamLock.Unlock()
  114. // Send the window update to create
  115. return stream, stream.sendWindowUpdate()
  116. }
  117. // Accept is used to block until the next available stream
  118. // is ready to be accepted.
  119. func (s *Session) Accept() (net.Conn, error) {
  120. return s.AcceptStream()
  121. }
  122. // AcceptStream is used to block until the next available stream
  123. // is ready to be accepted.
  124. func (s *Session) AcceptStream() (*Stream, error) {
  125. select {
  126. case stream := <-s.acceptCh:
  127. return stream, stream.sendWindowUpdate()
  128. case <-s.shutdownCh:
  129. return nil, s.shutdownErr
  130. }
  131. }
  132. // Close is used to close the session and all streams.
  133. // Attempts to send a GoAway before closing the connection.
  134. func (s *Session) Close() error {
  135. s.shutdownLock.Lock()
  136. defer s.shutdownLock.Unlock()
  137. if s.shutdown {
  138. return nil
  139. }
  140. s.shutdown = true
  141. if s.shutdownErr == nil {
  142. s.shutdownErr = ErrSessionShutdown
  143. }
  144. close(s.shutdownCh)
  145. s.conn.Close()
  146. s.streamLock.Lock()
  147. defer s.streamLock.Unlock()
  148. for _, stream := range s.streams {
  149. stream.forceClose()
  150. }
  151. return nil
  152. }
  153. // exitErr is used to handle an error that is causing the
  154. // session to terminate.
  155. func (s *Session) exitErr(err error) {
  156. s.shutdownErr = err
  157. s.Close()
  158. }
  159. // GoAway can be used to prevent accepting further
  160. // connections. It does not close the underlying conn.
  161. func (s *Session) GoAway() error {
  162. return s.waitForSend(s.goAway(goAwayNormal), nil)
  163. }
  164. // goAway is used to send a goAway message
  165. func (s *Session) goAway(reason uint32) header {
  166. atomic.SwapInt32(&s.localGoAway, 1)
  167. hdr := header(make([]byte, headerSize))
  168. hdr.encode(typeGoAway, 0, 0, reason)
  169. return hdr
  170. }
  171. // Ping is used to measure the RTT response time
  172. func (s *Session) Ping() (time.Duration, error) {
  173. // Get a channel for the ping
  174. ch := make(chan struct{})
  175. // Get a new ping id, mark as pending
  176. s.pingLock.Lock()
  177. id := s.pingID
  178. s.pingID++
  179. s.pings[id] = ch
  180. s.pingLock.Unlock()
  181. // Send the ping request
  182. hdr := header(make([]byte, headerSize))
  183. hdr.encode(typePing, flagSYN, 0, id)
  184. if err := s.waitForSend(hdr, nil); err != nil {
  185. return 0, err
  186. }
  187. // Wait for a response
  188. start := time.Now()
  189. select {
  190. case <-ch:
  191. case <-s.shutdownCh:
  192. return 0, ErrSessionShutdown
  193. }
  194. // Compute the RTT
  195. return time.Now().Sub(start), nil
  196. }
  197. // keepalive is a long running goroutine that periodically does
  198. // a ping to keep the connection alive.
  199. func (s *Session) keepalive() {
  200. for {
  201. select {
  202. case <-time.After(s.config.KeepAliveInterval):
  203. s.Ping()
  204. case <-s.shutdownCh:
  205. return
  206. }
  207. }
  208. }
  209. // waitForSendErr waits to send a header, checking for a potential shutdown
  210. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  211. errCh := make(chan error, 1)
  212. return s.waitForSendErr(hdr, body, errCh)
  213. }
  214. // waitForSendErr waits to send a header, checking for a potential shutdown
  215. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
  216. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  217. select {
  218. case s.sendCh <- ready:
  219. case <-s.shutdownCh:
  220. return ErrSessionShutdown
  221. }
  222. select {
  223. case err := <-errCh:
  224. return err
  225. case <-s.shutdownCh:
  226. return ErrSessionShutdown
  227. }
  228. }
  229. // sendNoWait does a send without waiting
  230. func (s *Session) sendNoWait(hdr header) error {
  231. select {
  232. case s.sendCh <- sendReady{Hdr: hdr}:
  233. return nil
  234. case <-s.shutdownCh:
  235. return ErrSessionShutdown
  236. }
  237. }
  238. // send is a long running goroutine that sends data
  239. func (s *Session) send() {
  240. for {
  241. select {
  242. case ready := <-s.sendCh:
  243. // Send a header if ready
  244. if ready.Hdr != nil {
  245. sent := 0
  246. for sent < len(ready.Hdr) {
  247. n, err := s.conn.Write(ready.Hdr[sent:])
  248. if err != nil {
  249. asyncSendErr(ready.Err, err)
  250. s.exitErr(err)
  251. return
  252. }
  253. sent += n
  254. }
  255. }
  256. // Send data from a body if given
  257. if ready.Body != nil {
  258. _, err := io.Copy(s.conn, ready.Body)
  259. if err != nil {
  260. asyncSendErr(ready.Err, err)
  261. s.exitErr(err)
  262. return
  263. }
  264. }
  265. // No error, successful send
  266. asyncSendErr(ready.Err, nil)
  267. case <-s.shutdownCh:
  268. return
  269. }
  270. }
  271. }
  272. // recv is a long running goroutine that accepts new data
  273. func (s *Session) recv() {
  274. hdr := header(make([]byte, headerSize))
  275. var handler func(header) error
  276. for {
  277. // Read the header
  278. if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
  279. s.exitErr(err)
  280. return
  281. }
  282. // Verify the version
  283. if hdr.Version() != protoVersion {
  284. s.exitErr(ErrInvalidVersion)
  285. return
  286. }
  287. // Switch on the type
  288. switch hdr.MsgType() {
  289. case typeData:
  290. handler = s.handleStreamMessage
  291. case typeWindowUpdate:
  292. handler = s.handleStreamMessage
  293. case typeGoAway:
  294. handler = s.handleGoAway
  295. case typePing:
  296. handler = s.handlePing
  297. default:
  298. s.exitErr(ErrInvalidMsgType)
  299. return
  300. }
  301. // Invoke the handler
  302. if err := handler(hdr); err != nil {
  303. s.exitErr(err)
  304. return
  305. }
  306. }
  307. }
  308. // handleStreamMessage handles either a data or window update frame
  309. func (s *Session) handleStreamMessage(hdr header) error {
  310. // Check for a new stream creation
  311. id := hdr.StreamID()
  312. flags := hdr.Flags()
  313. if flags&flagSYN == flagSYN {
  314. if err := s.incomingStream(id); err != nil {
  315. return err
  316. }
  317. }
  318. // Get the stream
  319. s.streamLock.Lock()
  320. stream := s.streams[id]
  321. s.streamLock.Unlock()
  322. // If we do not have a stream, likely we sent a RST
  323. if stream == nil {
  324. return nil
  325. }
  326. // Check if this is a window update
  327. if hdr.MsgType() == typeWindowUpdate {
  328. if err := stream.incrSendWindow(hdr, flags); err != nil {
  329. s.sendNoWait(s.goAway(goAwayProtoErr))
  330. return err
  331. }
  332. return nil
  333. }
  334. // Read the new data
  335. if err := stream.readData(hdr, flags, s.bufRead); err != nil {
  336. s.sendNoWait(s.goAway(goAwayProtoErr))
  337. return err
  338. }
  339. return nil
  340. }
  341. // handlePing is invokde for a typePing frame
  342. func (s *Session) handlePing(hdr header) error {
  343. flags := hdr.Flags()
  344. pingID := hdr.Length()
  345. // Check if this is a query, respond back
  346. if flags&flagSYN == flagSYN {
  347. hdr := header(make([]byte, headerSize))
  348. hdr.encode(typePing, flagACK, 0, pingID)
  349. s.sendNoWait(hdr)
  350. return nil
  351. }
  352. // Handle a response
  353. s.pingLock.Lock()
  354. ch := s.pings[pingID]
  355. if ch != nil {
  356. delete(s.pings, pingID)
  357. close(ch)
  358. }
  359. s.pingLock.Unlock()
  360. return nil
  361. }
  362. // handleGoAway is invokde for a typeGoAway frame
  363. func (s *Session) handleGoAway(hdr header) error {
  364. code := hdr.Length()
  365. switch code {
  366. case goAwayNormal:
  367. atomic.SwapInt32(&s.remoteGoAway, 1)
  368. case goAwayProtoErr:
  369. return fmt.Errorf("yamux protocol error")
  370. case goAwayInternalErr:
  371. return fmt.Errorf("remote yamux internal error")
  372. default:
  373. return fmt.Errorf("unexpected go away received")
  374. }
  375. return nil
  376. }
  377. // incomingStream is used to create a new incoming stream
  378. func (s *Session) incomingStream(id uint32) error {
  379. // Reject immediately if we are doing a go away
  380. if atomic.LoadInt32(&s.localGoAway) == 1 {
  381. hdr := header(make([]byte, headerSize))
  382. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  383. return s.sendNoWait(hdr)
  384. }
  385. // Allocate a new stream
  386. stream := newStream(s, id, streamSYNReceived)
  387. s.streamLock.Lock()
  388. defer s.streamLock.Unlock()
  389. // Check if stream already exists
  390. if _, ok := s.streams[id]; ok {
  391. s.sendNoWait(s.goAway(goAwayProtoErr))
  392. return ErrDuplicateStream
  393. }
  394. // Register the stream
  395. s.streams[id] = stream
  396. // Check if we've exceeded the backlog
  397. select {
  398. case s.acceptCh <- stream:
  399. return nil
  400. default:
  401. // Backlog exceeded! RST the stream
  402. delete(s.streams, id)
  403. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  404. return s.sendNoWait(stream.sendHdr)
  405. }
  406. }
  407. // closeStream is used to close a stream once both sides have
  408. // issued a close.
  409. func (s *Session) closeStream(id uint32, withLock bool) {
  410. if !withLock {
  411. s.streamLock.Lock()
  412. defer s.streamLock.Unlock()
  413. }
  414. delete(s.streams, id)
  415. }