session.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. package yamux
  2. import (
  3. "fmt"
  4. "io"
  5. "math"
  6. "net"
  7. "sync"
  8. "time"
  9. )
  10. // Session is used to wrap a reliable ordered connection and to
  11. // multiplex it into multiple streams.
  12. type Session struct {
  13. // client is true if we are a client size connection
  14. client bool
  15. // config holds our configuration
  16. config *Config
  17. // conn is the underlying connection
  18. conn io.ReadWriteCloser
  19. // pings is used to track inflight pings
  20. pings map[uint32]chan struct{}
  21. pingID uint32
  22. pingLock sync.Mutex
  23. // remoteGoAway indicates the remote side does
  24. // not want futher connections
  25. remoteGoAway bool
  26. // localGoAway indicates that we should stop
  27. // accepting futher connections
  28. localGoAway bool
  29. // nextStreamID is the next stream we should
  30. // send. This depends if we are a client/server.
  31. nextStreamID uint32
  32. // streams maps a stream id to a stream
  33. streams map[uint32]*Stream
  34. streamLock sync.RWMutex
  35. // acceptCh is used to pass ready streams to the client
  36. acceptCh chan *Stream
  37. // sendCh is used to mark a stream as ready to send,
  38. // or to send a header out directly.
  39. sendCh chan sendReady
  40. // shutdown is used to safely close a session
  41. shutdown bool
  42. shutdownErr error
  43. shutdownCh chan struct{}
  44. shutdownLock sync.Mutex
  45. }
  46. // sendReady is used to either mark a stream as ready
  47. // or to directly send a header
  48. type sendReady struct {
  49. Hdr []byte
  50. Body io.Reader
  51. Err chan error
  52. }
  53. // newSession is used to construct a new session
  54. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  55. s := &Session{
  56. client: client,
  57. config: config,
  58. conn: conn,
  59. pings: make(map[uint32]chan struct{}),
  60. streams: make(map[uint32]*Stream),
  61. acceptCh: make(chan *Stream, config.AcceptBacklog),
  62. sendCh: make(chan sendReady, 64),
  63. shutdownCh: make(chan struct{}),
  64. }
  65. if client {
  66. s.nextStreamID = 1
  67. } else {
  68. s.nextStreamID = 2
  69. }
  70. go s.recv()
  71. go s.send()
  72. if config.EnableKeepAlive {
  73. go s.keepalive()
  74. }
  75. return s
  76. }
  77. // isShutdown does a safe check to see if we have shutdown
  78. func (s *Session) isShutdown() bool {
  79. select {
  80. case <-s.shutdownCh:
  81. return true
  82. default:
  83. return false
  84. }
  85. }
  86. // Open is used to create a new stream
  87. func (s *Session) Open() (*Stream, error) {
  88. if s.isShutdown() {
  89. return nil, ErrSessionShutdown
  90. }
  91. if s.remoteGoAway {
  92. return nil, ErrRemoteGoAway
  93. }
  94. s.streamLock.Lock()
  95. defer s.streamLock.Unlock()
  96. // Check if we've exhaused the streams
  97. id := s.nextStreamID
  98. if id >= math.MaxUint32-1 {
  99. return nil, ErrStreamsExhausted
  100. }
  101. s.nextStreamID += 2
  102. // Register the stream
  103. stream := newStream(s, id, streamInit)
  104. s.streams[id] = stream
  105. // Send the window update to create
  106. return stream, stream.sendWindowUpdate()
  107. }
  108. // Accept is used to block until the next available stream
  109. // is ready to be accepted.
  110. func (s *Session) Accept() (net.Conn, error) {
  111. return s.AcceptStream()
  112. }
  113. // AcceptStream is used to block until the next available stream
  114. // is ready to be accepted.
  115. func (s *Session) AcceptStream() (*Stream, error) {
  116. select {
  117. case stream := <-s.acceptCh:
  118. return stream, nil
  119. case <-s.shutdownCh:
  120. return nil, s.shutdownErr
  121. }
  122. }
  123. // Close is used to close the session and all streams.
  124. // Attempts to send a GoAway before closing the connection.
  125. func (s *Session) Close() error {
  126. s.shutdownLock.Lock()
  127. defer s.shutdownLock.Unlock()
  128. if s.shutdown {
  129. return nil
  130. }
  131. s.shutdown = true
  132. if s.shutdownErr == nil {
  133. s.shutdownErr = ErrSessionShutdown
  134. }
  135. close(s.shutdownCh)
  136. s.conn.Close()
  137. s.streamLock.Lock()
  138. defer s.streamLock.Unlock()
  139. for _, stream := range s.streams {
  140. stream.forceClose()
  141. }
  142. return nil
  143. }
  144. // GoAway can be used to prevent accepting further
  145. // connections. It does not close the underlying conn.
  146. func (s *Session) GoAway() error {
  147. s.localGoAway = true
  148. s.goAway(goAwayNormal)
  149. return nil
  150. }
  151. // Ping is used to measure the RTT response time
  152. func (s *Session) Ping() (time.Duration, error) {
  153. // Get a channel for the ping
  154. ch := make(chan struct{})
  155. // Get a new ping id, mark as pending
  156. s.pingLock.Lock()
  157. id := s.pingID
  158. s.pingID++
  159. s.pings[id] = ch
  160. s.pingLock.Unlock()
  161. // Send the ping request
  162. hdr := header(make([]byte, headerSize))
  163. hdr.encode(typePing, flagSYN, 0, id)
  164. if err := s.waitForSend(hdr, nil); err != nil {
  165. return 0, err
  166. }
  167. // Wait for a response
  168. start := time.Now()
  169. select {
  170. case <-ch:
  171. case <-s.shutdownCh:
  172. return 0, ErrSessionShutdown
  173. }
  174. // Compute the RTT
  175. return time.Now().Sub(start), nil
  176. }
  177. // keepalive is a long running goroutine that periodically does
  178. // a ping to keep the connection alive.
  179. func (s *Session) keepalive() {
  180. for {
  181. select {
  182. case <-time.After(s.config.KeepAliveInterval):
  183. s.Ping()
  184. case <-s.shutdownCh:
  185. return
  186. }
  187. }
  188. }
  189. // waitForSend waits to send a header, checking for a potential shutdown
  190. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  191. errCh := make(chan error, 1)
  192. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  193. select {
  194. case s.sendCh <- ready:
  195. case <-s.shutdownCh:
  196. return ErrSessionShutdown
  197. }
  198. select {
  199. case err := <-errCh:
  200. return err
  201. case <-s.shutdownCh:
  202. return ErrSessionShutdown
  203. }
  204. }
  205. // sendNoWait does a send without waiting
  206. func (s *Session) sendNoWait(hdr header) error {
  207. select {
  208. case s.sendCh <- sendReady{Hdr: hdr}:
  209. return nil
  210. case <-s.shutdownCh:
  211. return ErrSessionShutdown
  212. }
  213. }
  214. // send is a long running goroutine that sends data
  215. func (s *Session) send() {
  216. for {
  217. select {
  218. case ready := <-s.sendCh:
  219. // Send a header if ready
  220. if ready.Hdr != nil {
  221. sent := 0
  222. for sent < len(ready.Hdr) {
  223. n, err := s.conn.Write(ready.Hdr[sent:])
  224. if err != nil {
  225. s.exitErr(err)
  226. asyncSendErr(ready.Err, err)
  227. return
  228. }
  229. sent += n
  230. }
  231. }
  232. // Send data from a body if given
  233. if ready.Body != nil {
  234. _, err := io.Copy(s.conn, ready.Body)
  235. if err != nil {
  236. s.exitErr(err)
  237. asyncSendErr(ready.Err, err)
  238. return
  239. }
  240. }
  241. // No error, successful send
  242. asyncSendErr(ready.Err, nil)
  243. case <-s.shutdownCh:
  244. return
  245. }
  246. }
  247. }
  248. // recv is a long running goroutine that accepts new data
  249. func (s *Session) recv() {
  250. hdr := header(make([]byte, headerSize))
  251. for !s.isShutdown() {
  252. // Read the header
  253. if _, err := io.ReadFull(s.conn, hdr); err != nil {
  254. s.exitErr(err)
  255. return
  256. }
  257. // Verify the version
  258. if hdr.Version() != protoVersion {
  259. s.exitErr(ErrInvalidVersion)
  260. return
  261. }
  262. // Switch on the type
  263. msgType := hdr.MsgType()
  264. switch msgType {
  265. case typeData:
  266. fallthrough
  267. case typeWindowUpdate:
  268. if err := s.handleStreamMessage(hdr); err != nil {
  269. s.exitErr(err)
  270. return
  271. }
  272. case typeGoAway:
  273. if err := s.handleGoAway(hdr); err != nil {
  274. s.exitErr(err)
  275. return
  276. }
  277. case typePing:
  278. if err := s.handlePing(hdr); err != nil {
  279. s.exitErr(err)
  280. return
  281. }
  282. default:
  283. s.exitErr(ErrInvalidMsgType)
  284. return
  285. }
  286. }
  287. }
  288. // handleStreamMessage handles either a data or window update frame
  289. func (s *Session) handleStreamMessage(hdr header) error {
  290. // Check for a new stream creation
  291. id := hdr.StreamID()
  292. flags := hdr.Flags()
  293. if flags&flagSYN == flagSYN {
  294. if err := s.incomingStream(id); err != nil {
  295. return err
  296. }
  297. }
  298. // Get the stream
  299. s.streamLock.RLock()
  300. stream := s.streams[id]
  301. s.streamLock.RUnlock()
  302. // Make sure we have a stream
  303. if stream == nil {
  304. s.goAway(goAwayProtoErr)
  305. return ErrMissingStream
  306. }
  307. // Check if this is a window update
  308. if hdr.MsgType() == typeWindowUpdate {
  309. if err := stream.incrSendWindow(hdr, flags); err != nil {
  310. s.goAway(goAwayProtoErr)
  311. return err
  312. }
  313. }
  314. // Read the new data
  315. if err := stream.readData(hdr, flags, s.conn); err != nil {
  316. s.goAway(goAwayProtoErr)
  317. return err
  318. }
  319. return nil
  320. }
  321. // handlePing is invokde for a typePing frame
  322. func (s *Session) handlePing(hdr header) error {
  323. flags := hdr.Flags()
  324. pingID := hdr.Length()
  325. // Check if this is a query, respond back
  326. if flags&flagSYN == flagSYN {
  327. hdr := header(make([]byte, headerSize))
  328. hdr.encode(typePing, flagACK, 0, pingID)
  329. s.sendNoWait(hdr)
  330. return nil
  331. }
  332. // Handle a response
  333. s.pingLock.Lock()
  334. ch := s.pings[pingID]
  335. if ch != nil {
  336. delete(s.pings, pingID)
  337. close(ch)
  338. }
  339. s.pingLock.Unlock()
  340. return nil
  341. }
  342. // handleGoAway is invokde for a typeGoAway frame
  343. func (s *Session) handleGoAway(hdr header) error {
  344. code := hdr.Length()
  345. switch code {
  346. case goAwayNormal:
  347. s.remoteGoAway = true
  348. case goAwayProtoErr:
  349. return fmt.Errorf("yamux protocol error")
  350. case goAwayInternalErr:
  351. return fmt.Errorf("remote yamux internal error")
  352. default:
  353. return fmt.Errorf("unexpected go away received")
  354. }
  355. return nil
  356. }
  357. // exitErr is used to handle an error that is causing
  358. // the listener to exit.
  359. func (s *Session) exitErr(err error) {
  360. s.shutdownErr = err
  361. s.Close()
  362. }
  363. // goAway is used to send a goAway message
  364. func (s *Session) goAway(reason uint32) {
  365. hdr := header(make([]byte, headerSize))
  366. hdr.encode(typeGoAway, 0, 0, reason)
  367. s.sendNoWait(hdr)
  368. }
  369. // incomingStream is used to create a new incoming stream
  370. func (s *Session) incomingStream(id uint32) error {
  371. // Reject immediately if we are doing a go away
  372. if s.localGoAway {
  373. hdr := header(make([]byte, headerSize))
  374. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  375. s.sendNoWait(hdr)
  376. return nil
  377. }
  378. s.streamLock.Lock()
  379. defer s.streamLock.Unlock()
  380. // Check if stream already exists
  381. if _, ok := s.streams[id]; ok {
  382. s.goAway(goAwayProtoErr)
  383. s.exitErr(ErrDuplicateStream)
  384. return nil
  385. }
  386. // Register the stream
  387. stream := newStream(s, id, streamSYNReceived)
  388. s.streams[id] = stream
  389. // Check if we've exceeded the backlog
  390. select {
  391. case s.acceptCh <- stream:
  392. return nil
  393. default:
  394. // Backlog exceeded! RST the stream
  395. delete(s.streams, id)
  396. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  397. s.sendNoWait(stream.sendHdr)
  398. }
  399. return nil
  400. }
  401. // closeStream is used to close a stream once both sides have
  402. // issued a close.
  403. func (s *Session) closeStream(id uint32, withLock bool) {
  404. if !withLock {
  405. s.streamLock.Lock()
  406. defer s.streamLock.Unlock()
  407. }
  408. delete(s.streams, id)
  409. }