session.go 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. package yamux
  2. import (
  3. "fmt"
  4. "io"
  5. "math"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  11. // Session is used to wrap a reliable ordered connection and to
  12. // multiplex it into multiple streams.
  13. type Session struct {
  14. // remoteGoAway indicates the remote side does
  15. // not want futher connections. Must be first for alignment.
  16. remoteGoAway int32
  17. // localGoAway indicates that we should stop
  18. // accepting futher connections. Must be first for alignment.
  19. localGoAway int32
  20. // client is true if we are a client size connection
  21. client bool
  22. // config holds our configuration
  23. config *Config
  24. // conn is the underlying connection
  25. conn io.ReadWriteCloser
  26. // pings is used to track inflight pings
  27. pings map[uint32]chan struct{}
  28. pingID uint32
  29. pingLock sync.Mutex
  30. // nextStreamID is the next stream we should
  31. // send. This depends if we are a client/server.
  32. nextStreamID uint32
  33. // streams maps a stream id to a stream
  34. streams map[uint32]*Stream
  35. streamLock sync.RWMutex
  36. // acceptCh is used to pass ready streams to the client
  37. acceptCh chan *Stream
  38. // sendCh is used to mark a stream as ready to send,
  39. // or to send a header out directly.
  40. sendCh chan sendReady
  41. // shutdown is used to safely close a session
  42. shutdown bool
  43. shutdownErr error
  44. shutdownCh chan struct{}
  45. shutdownLock sync.Mutex
  46. }
  47. // sendReady is used to either mark a stream as ready
  48. // or to directly send a header
  49. type sendReady struct {
  50. Hdr []byte
  51. Body io.Reader
  52. Err chan error
  53. }
  54. // newSession is used to construct a new session
  55. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  56. s := &Session{
  57. client: client,
  58. config: config,
  59. conn: conn,
  60. pings: make(map[uint32]chan struct{}),
  61. streams: make(map[uint32]*Stream),
  62. acceptCh: make(chan *Stream, config.AcceptBacklog),
  63. sendCh: make(chan sendReady, 64),
  64. shutdownCh: make(chan struct{}),
  65. }
  66. if client {
  67. s.nextStreamID = 1
  68. } else {
  69. s.nextStreamID = 2
  70. }
  71. go s.recv()
  72. go s.send()
  73. if config.EnableKeepAlive {
  74. go s.keepalive()
  75. }
  76. return s
  77. }
  78. // IsClosed does a safe check to see if we have shutdown
  79. func (s *Session) IsClosed() bool {
  80. select {
  81. case <-s.shutdownCh:
  82. return true
  83. default:
  84. return false
  85. }
  86. }
  87. // Open is used to create a new stream
  88. func (s *Session) Open() (*Stream, error) {
  89. if s.IsClosed() {
  90. return nil, ErrSessionShutdown
  91. }
  92. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  93. return nil, ErrRemoteGoAway
  94. }
  95. s.streamLock.Lock()
  96. defer s.streamLock.Unlock()
  97. // Check if we've exhaused the streams
  98. id := s.nextStreamID
  99. if id >= math.MaxUint32-1 {
  100. return nil, ErrStreamsExhausted
  101. }
  102. s.nextStreamID += 2
  103. // Register the stream
  104. stream := newStream(s, id, streamInit)
  105. s.streams[id] = stream
  106. // Send the window update to create
  107. return stream, stream.sendWindowUpdate()
  108. }
  109. // Accept is used to block until the next available stream
  110. // is ready to be accepted.
  111. func (s *Session) Accept() (net.Conn, error) {
  112. return s.AcceptStream()
  113. }
  114. // AcceptStream is used to block until the next available stream
  115. // is ready to be accepted.
  116. func (s *Session) AcceptStream() (*Stream, error) {
  117. select {
  118. case stream := <-s.acceptCh:
  119. return stream, nil
  120. case <-s.shutdownCh:
  121. return nil, s.shutdownErr
  122. }
  123. }
  124. // Close is used to close the session and all streams.
  125. // Attempts to send a GoAway before closing the connection.
  126. func (s *Session) Close() error {
  127. s.shutdownLock.Lock()
  128. defer s.shutdownLock.Unlock()
  129. if s.shutdown {
  130. return nil
  131. }
  132. s.shutdown = true
  133. if s.shutdownErr == nil {
  134. s.shutdownErr = ErrSessionShutdown
  135. }
  136. close(s.shutdownCh)
  137. s.conn.Close()
  138. s.streamLock.Lock()
  139. defer s.streamLock.Unlock()
  140. for _, stream := range s.streams {
  141. stream.forceClose()
  142. }
  143. return nil
  144. }
  145. // GoAway can be used to prevent accepting further
  146. // connections. It does not close the underlying conn.
  147. func (s *Session) GoAway() error {
  148. atomic.SwapInt32(&s.localGoAway, 1)
  149. s.goAway(goAwayNormal)
  150. return nil
  151. }
  152. // Ping is used to measure the RTT response time
  153. func (s *Session) Ping() (time.Duration, error) {
  154. // Get a channel for the ping
  155. ch := make(chan struct{})
  156. // Get a new ping id, mark as pending
  157. s.pingLock.Lock()
  158. id := s.pingID
  159. s.pingID++
  160. s.pings[id] = ch
  161. s.pingLock.Unlock()
  162. // Send the ping request
  163. hdr := header(make([]byte, headerSize))
  164. hdr.encode(typePing, flagSYN, 0, id)
  165. if err := s.waitForSend(hdr, nil); err != nil {
  166. return 0, err
  167. }
  168. // Wait for a response
  169. start := time.Now()
  170. select {
  171. case <-ch:
  172. case <-s.shutdownCh:
  173. return 0, ErrSessionShutdown
  174. }
  175. // Compute the RTT
  176. return time.Now().Sub(start), nil
  177. }
  178. // keepalive is a long running goroutine that periodically does
  179. // a ping to keep the connection alive.
  180. func (s *Session) keepalive() {
  181. for {
  182. select {
  183. case <-time.After(s.config.KeepAliveInterval):
  184. s.Ping()
  185. case <-s.shutdownCh:
  186. return
  187. }
  188. }
  189. }
  190. // waitForSend waits to send a header, checking for a potential shutdown
  191. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  192. errCh := make(chan error, 1)
  193. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  194. select {
  195. case s.sendCh <- ready:
  196. case <-s.shutdownCh:
  197. return ErrSessionShutdown
  198. }
  199. select {
  200. case err := <-errCh:
  201. return err
  202. case <-s.shutdownCh:
  203. return ErrSessionShutdown
  204. }
  205. }
  206. // sendNoWait does a send without waiting
  207. func (s *Session) sendNoWait(hdr header) error {
  208. select {
  209. case s.sendCh <- sendReady{Hdr: hdr}:
  210. return nil
  211. case <-s.shutdownCh:
  212. return ErrSessionShutdown
  213. }
  214. }
  215. // send is a long running goroutine that sends data
  216. func (s *Session) send() {
  217. for {
  218. select {
  219. case ready := <-s.sendCh:
  220. // Send a header if ready
  221. if ready.Hdr != nil {
  222. sent := 0
  223. for sent < len(ready.Hdr) {
  224. n, err := s.conn.Write(ready.Hdr[sent:])
  225. if err != nil {
  226. s.exitErr(err)
  227. asyncSendErr(ready.Err, err)
  228. return
  229. }
  230. sent += n
  231. }
  232. }
  233. // Send data from a body if given
  234. if ready.Body != nil {
  235. _, err := io.Copy(s.conn, ready.Body)
  236. if err != nil {
  237. s.exitErr(err)
  238. asyncSendErr(ready.Err, err)
  239. return
  240. }
  241. }
  242. // No error, successful send
  243. asyncSendErr(ready.Err, nil)
  244. case <-s.shutdownCh:
  245. return
  246. }
  247. }
  248. }
  249. // recv is a long running goroutine that accepts new data
  250. func (s *Session) recv() {
  251. hdr := header(make([]byte, headerSize))
  252. for !s.IsClosed() {
  253. // Read the header
  254. if _, err := io.ReadFull(s.conn, hdr); err != nil {
  255. s.exitErr(err)
  256. return
  257. }
  258. // Verify the version
  259. if hdr.Version() != protoVersion {
  260. s.exitErr(ErrInvalidVersion)
  261. return
  262. }
  263. // Switch on the type
  264. msgType := hdr.MsgType()
  265. switch msgType {
  266. case typeData:
  267. fallthrough
  268. case typeWindowUpdate:
  269. if err := s.handleStreamMessage(hdr); err != nil {
  270. s.exitErr(err)
  271. return
  272. }
  273. case typeGoAway:
  274. if err := s.handleGoAway(hdr); err != nil {
  275. s.exitErr(err)
  276. return
  277. }
  278. case typePing:
  279. if err := s.handlePing(hdr); err != nil {
  280. s.exitErr(err)
  281. return
  282. }
  283. default:
  284. s.exitErr(ErrInvalidMsgType)
  285. return
  286. }
  287. }
  288. }
  289. // handleStreamMessage handles either a data or window update frame
  290. func (s *Session) handleStreamMessage(hdr header) error {
  291. // Check for a new stream creation
  292. id := hdr.StreamID()
  293. flags := hdr.Flags()
  294. if flags&flagSYN == flagSYN {
  295. if err := s.incomingStream(id); err != nil {
  296. return err
  297. }
  298. }
  299. // Get the stream
  300. s.streamLock.RLock()
  301. stream := s.streams[id]
  302. s.streamLock.RUnlock()
  303. // Make sure we have a stream
  304. if stream == nil {
  305. s.goAway(goAwayProtoErr)
  306. return ErrMissingStream
  307. }
  308. // Check if this is a window update
  309. if hdr.MsgType() == typeWindowUpdate {
  310. if err := stream.incrSendWindow(hdr, flags); err != nil {
  311. s.goAway(goAwayProtoErr)
  312. return err
  313. }
  314. }
  315. // Read the new data
  316. if err := stream.readData(hdr, flags, s.conn); err != nil {
  317. s.goAway(goAwayProtoErr)
  318. return err
  319. }
  320. return nil
  321. }
  322. // handlePing is invokde for a typePing frame
  323. func (s *Session) handlePing(hdr header) error {
  324. flags := hdr.Flags()
  325. pingID := hdr.Length()
  326. // Check if this is a query, respond back
  327. if flags&flagSYN == flagSYN {
  328. hdr := header(make([]byte, headerSize))
  329. hdr.encode(typePing, flagACK, 0, pingID)
  330. s.sendNoWait(hdr)
  331. return nil
  332. }
  333. // Handle a response
  334. s.pingLock.Lock()
  335. ch := s.pings[pingID]
  336. if ch != nil {
  337. delete(s.pings, pingID)
  338. close(ch)
  339. }
  340. s.pingLock.Unlock()
  341. return nil
  342. }
  343. // handleGoAway is invokde for a typeGoAway frame
  344. func (s *Session) handleGoAway(hdr header) error {
  345. code := hdr.Length()
  346. switch code {
  347. case goAwayNormal:
  348. atomic.SwapInt32(&s.remoteGoAway, 1)
  349. case goAwayProtoErr:
  350. return fmt.Errorf("yamux protocol error")
  351. case goAwayInternalErr:
  352. return fmt.Errorf("remote yamux internal error")
  353. default:
  354. return fmt.Errorf("unexpected go away received")
  355. }
  356. return nil
  357. }
  358. // exitErr is used to handle an error that is causing
  359. // the listener to exit.
  360. func (s *Session) exitErr(err error) {
  361. s.shutdownErr = err
  362. s.Close()
  363. }
  364. // goAway is used to send a goAway message
  365. func (s *Session) goAway(reason uint32) {
  366. hdr := header(make([]byte, headerSize))
  367. hdr.encode(typeGoAway, 0, 0, reason)
  368. s.sendNoWait(hdr)
  369. }
  370. // incomingStream is used to create a new incoming stream
  371. func (s *Session) incomingStream(id uint32) error {
  372. // Reject immediately if we are doing a go away
  373. if atomic.LoadInt32(&s.localGoAway) == 1 {
  374. hdr := header(make([]byte, headerSize))
  375. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  376. s.sendNoWait(hdr)
  377. return nil
  378. }
  379. s.streamLock.Lock()
  380. defer s.streamLock.Unlock()
  381. // Check if stream already exists
  382. if _, ok := s.streams[id]; ok {
  383. s.goAway(goAwayProtoErr)
  384. s.exitErr(ErrDuplicateStream)
  385. return nil
  386. }
  387. // Register the stream
  388. stream := newStream(s, id, streamSYNReceived)
  389. s.streams[id] = stream
  390. // Check if we've exceeded the backlog
  391. select {
  392. case s.acceptCh <- stream:
  393. return nil
  394. default:
  395. // Backlog exceeded! RST the stream
  396. delete(s.streams, id)
  397. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  398. s.sendNoWait(stream.sendHdr)
  399. }
  400. return nil
  401. }
  402. // closeStream is used to close a stream once both sides have
  403. // issued a close.
  404. func (s *Session) closeStream(id uint32, withLock bool) {
  405. if !withLock {
  406. s.streamLock.Lock()
  407. defer s.streamLock.Unlock()
  408. }
  409. delete(s.streams, id)
  410. }