session.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. package yamux
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "math"
  9. "net"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. )
// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
	// remoteGoAway indicates the remote side does not want further
	// connections. It is accessed with sync/atomic and is kept at the
	// top of the struct together with the other atomically accessed fields.
	remoteGoAway int32

	// localGoAway indicates that we should stop accepting further
	// connections. It is accessed with sync/atomic.
	localGoAway int32

	// nextStreamID is the next stream we should
	// send. This depends if we are a client/server.
	// It is advanced with an atomic compare-and-swap.
	nextStreamID uint32

	// config holds our configuration
	config *Config

	// logger is used for our logs
	logger *log.Logger

	// conn is the underlying connection
	conn io.ReadWriteCloser

	// bufRead is a buffered reader wrapping conn
	bufRead *bufio.Reader

	// pings is used to track inflight pings, keyed by ping id. The channel
	// is closed when the matching reply arrives. pingID is the id that the
	// next ping will use. Both are protected by pingLock.
	pings    map[uint32]chan struct{}
	pingID   uint32
	pingLock sync.Mutex

	// streams maps a stream id to a stream, and inflight has an entry
	// for any outgoing stream that has not yet been established. Both are
	// protected by streamLock.
	streams    map[uint32]*Stream
	inflight   map[uint32]struct{}
	streamLock sync.Mutex

	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
	// is assumed to be symmetric between the client and server. This allows
	// the client to avoid exceeding the backlog and instead blocks the open.
	synCh chan struct{}

	// acceptCh is used to pass ready streams to the client
	acceptCh chan *Stream

	// sendCh is used to mark a stream as ready to send,
	// or to send a header out directly.
	sendCh chan sendReady

	// recvDoneCh is closed when recv() exits to avoid a race
	// between stream registration and stream shutdown
	recvDoneCh chan struct{}

	// shutdown is used to safely close a session. shutdown and shutdownErr
	// are written while holding shutdownLock; shutdownCh is closed exactly
	// once, when the session is closed.
	shutdown     bool
	shutdownErr  error
	shutdownCh   chan struct{}
	shutdownLock sync.Mutex
}
// sendReady is used to either mark a stream as ready
// or to directly send a header
type sendReady struct {
	// Hdr is written to the connection first; it is skipped when nil.
	Hdr []byte
	// Body, when non-nil, is copied to the connection after the header.
	Body io.Reader
	// Err is passed to asyncSendErr together with the outcome of the
	// write; senders that do not wait for the result leave it nil.
	Err chan error
}
  70. // newSession is used to construct a new session
  71. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  72. s := &Session{
  73. config: config,
  74. logger: log.New(config.LogOutput, "", log.LstdFlags),
  75. conn: conn,
  76. bufRead: bufio.NewReader(conn),
  77. pings: make(map[uint32]chan struct{}),
  78. streams: make(map[uint32]*Stream),
  79. inflight: make(map[uint32]struct{}),
  80. synCh: make(chan struct{}, config.AcceptBacklog),
  81. acceptCh: make(chan *Stream, config.AcceptBacklog),
  82. sendCh: make(chan sendReady, 64),
  83. recvDoneCh: make(chan struct{}),
  84. shutdownCh: make(chan struct{}),
  85. }
  86. if client {
  87. s.nextStreamID = 1
  88. } else {
  89. s.nextStreamID = 2
  90. }
  91. go s.recv()
  92. go s.send()
  93. if config.EnableKeepAlive {
  94. go s.keepalive()
  95. }
  96. return s
  97. }
  98. // IsClosed does a safe check to see if we have shutdown
  99. func (s *Session) IsClosed() bool {
  100. select {
  101. case <-s.shutdownCh:
  102. return true
  103. default:
  104. return false
  105. }
  106. }
  107. // NumStreams returns the number of currently open streams
  108. func (s *Session) NumStreams() int {
  109. s.streamLock.Lock()
  110. num := len(s.streams)
  111. s.streamLock.Unlock()
  112. return num
  113. }
  114. // Open is used to create a new stream as a net.Conn
  115. func (s *Session) Open() (net.Conn, error) {
  116. conn, err := s.OpenStream()
  117. if err != nil {
  118. return nil, err
  119. }
  120. return conn, nil
  121. }
  122. // OpenStream is used to create a new stream
  123. func (s *Session) OpenStream() (*Stream, error) {
  124. if s.IsClosed() {
  125. return nil, ErrSessionShutdown
  126. }
  127. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  128. return nil, ErrRemoteGoAway
  129. }
  130. // Block if we have too many inflight SYNs
  131. select {
  132. case s.synCh <- struct{}{}:
  133. case <-s.shutdownCh:
  134. return nil, ErrSessionShutdown
  135. }
  136. GET_ID:
  137. // Get an ID, and check for stream exhaustion
  138. id := atomic.LoadUint32(&s.nextStreamID)
  139. if id >= math.MaxUint32-1 {
  140. return nil, ErrStreamsExhausted
  141. }
  142. if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
  143. goto GET_ID
  144. }
  145. // Register the stream
  146. stream := newStream(s, id, streamInit)
  147. s.streamLock.Lock()
  148. s.streams[id] = stream
  149. s.inflight[id] = struct{}{}
  150. s.streamLock.Unlock()
  151. // Send the window update to create
  152. if err := stream.sendWindowUpdate(); err != nil {
  153. select {
  154. case <-s.synCh:
  155. default:
  156. s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
  157. }
  158. return nil, err
  159. }
  160. return stream, nil
  161. }
  162. // Accept is used to block until the next available stream
  163. // is ready to be accepted.
  164. func (s *Session) Accept() (net.Conn, error) {
  165. conn, err := s.AcceptStream()
  166. if err != nil {
  167. return nil, err
  168. }
  169. return conn, err
  170. }
  171. // AcceptStream is used to block until the next available stream
  172. // is ready to be accepted.
  173. func (s *Session) AcceptStream() (*Stream, error) {
  174. select {
  175. case stream := <-s.acceptCh:
  176. if err := stream.sendWindowUpdate(); err != nil {
  177. return nil, err
  178. }
  179. return stream, nil
  180. case <-s.shutdownCh:
  181. return nil, s.shutdownErr
  182. }
  183. }
  184. // Close is used to close the session and all streams.
  185. // Attempts to send a GoAway before closing the connection.
  186. func (s *Session) Close() error {
  187. s.shutdownLock.Lock()
  188. defer s.shutdownLock.Unlock()
  189. if s.shutdown {
  190. return nil
  191. }
  192. s.shutdown = true
  193. if s.shutdownErr == nil {
  194. s.shutdownErr = ErrSessionShutdown
  195. }
  196. close(s.shutdownCh)
  197. s.conn.Close()
  198. <-s.recvDoneCh
  199. s.streamLock.Lock()
  200. defer s.streamLock.Unlock()
  201. for _, stream := range s.streams {
  202. stream.forceClose()
  203. }
  204. return nil
  205. }
  206. // exitErr is used to handle an error that is causing the
  207. // session to terminate.
  208. func (s *Session) exitErr(err error) {
  209. s.shutdownLock.Lock()
  210. if s.shutdownErr == nil {
  211. s.shutdownErr = err
  212. }
  213. s.shutdownLock.Unlock()
  214. s.Close()
  215. }
  216. // GoAway can be used to prevent accepting further
  217. // connections. It does not close the underlying conn.
  218. func (s *Session) GoAway() error {
  219. return s.waitForSend(s.goAway(goAwayNormal), nil)
  220. }
  221. // goAway is used to send a goAway message
  222. func (s *Session) goAway(reason uint32) header {
  223. atomic.SwapInt32(&s.localGoAway, 1)
  224. hdr := header(make([]byte, headerSize))
  225. hdr.encode(typeGoAway, 0, 0, reason)
  226. return hdr
  227. }
  228. // Ping is used to measure the RTT response time
  229. func (s *Session) Ping() (time.Duration, error) {
  230. // Get a channel for the ping
  231. ch := make(chan struct{})
  232. // Get a new ping id, mark as pending
  233. s.pingLock.Lock()
  234. id := s.pingID
  235. s.pingID++
  236. s.pings[id] = ch
  237. s.pingLock.Unlock()
  238. // Send the ping request
  239. hdr := header(make([]byte, headerSize))
  240. hdr.encode(typePing, flagSYN, 0, id)
  241. if err := s.waitForSend(hdr, nil); err != nil {
  242. return 0, err
  243. }
  244. // Wait for a response
  245. start := time.Now()
  246. select {
  247. case <-ch:
  248. case <-time.After(s.config.ConnectionWriteTimeout):
  249. s.pingLock.Lock()
  250. delete(s.pings, id) // Ignore it if a response comes later.
  251. s.pingLock.Unlock()
  252. return 0, ErrTimeout
  253. case <-s.shutdownCh:
  254. return 0, ErrSessionShutdown
  255. }
  256. // Compute the RTT
  257. return time.Now().Sub(start), nil
  258. }
  259. // keepalive is a long running goroutine that periodically does
  260. // a ping to keep the connection alive.
  261. func (s *Session) keepalive() {
  262. for {
  263. select {
  264. case <-time.After(s.config.KeepAliveInterval):
  265. _, err := s.Ping()
  266. if err != nil {
  267. s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
  268. s.exitErr(ErrKeepAliveTimeout)
  269. return
  270. }
  271. case <-s.shutdownCh:
  272. return
  273. }
  274. }
  275. }
  276. // waitForSendErr waits to send a header, checking for a potential shutdown
  277. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  278. errCh := make(chan error, 1)
  279. return s.waitForSendErr(hdr, body, errCh)
  280. }
  281. // waitForSendErr waits to send a header with optional data, checking for a
  282. // potential shutdown. Since there's the expectation that sends can happen
  283. // in a timely manner, we enforce the connection write timeout here.
  284. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
  285. timer := time.NewTimer(s.config.ConnectionWriteTimeout)
  286. defer timer.Stop()
  287. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  288. select {
  289. case s.sendCh <- ready:
  290. case <-s.shutdownCh:
  291. return ErrSessionShutdown
  292. case <-timer.C:
  293. return ErrConnectionWriteTimeout
  294. }
  295. select {
  296. case err := <-errCh:
  297. return err
  298. case <-s.shutdownCh:
  299. return ErrSessionShutdown
  300. case <-timer.C:
  301. return ErrConnectionWriteTimeout
  302. }
  303. }
  304. // sendNoWait does a send without waiting. Since there's the expectation that
  305. // the send happens right here, we enforce the connection write timeout if we
  306. // can't queue the header to be sent.
  307. func (s *Session) sendNoWait(hdr header) error {
  308. timer := time.NewTimer(s.config.ConnectionWriteTimeout)
  309. defer timer.Stop()
  310. select {
  311. case s.sendCh <- sendReady{Hdr: hdr}:
  312. return nil
  313. case <-s.shutdownCh:
  314. return ErrSessionShutdown
  315. case <-timer.C:
  316. return ErrConnectionWriteTimeout
  317. }
  318. }
  319. // send is a long running goroutine that sends data
  320. func (s *Session) send() {
  321. for {
  322. select {
  323. case ready := <-s.sendCh:
  324. // Send a header if ready
  325. if ready.Hdr != nil {
  326. sent := 0
  327. for sent < len(ready.Hdr) {
  328. n, err := s.conn.Write(ready.Hdr[sent:])
  329. if err != nil {
  330. s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
  331. asyncSendErr(ready.Err, err)
  332. s.exitErr(err)
  333. return
  334. }
  335. sent += n
  336. }
  337. }
  338. // Send data from a body if given
  339. if ready.Body != nil {
  340. _, err := io.Copy(s.conn, ready.Body)
  341. if err != nil {
  342. s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
  343. asyncSendErr(ready.Err, err)
  344. s.exitErr(err)
  345. return
  346. }
  347. }
  348. // No error, successful send
  349. asyncSendErr(ready.Err, nil)
  350. case <-s.shutdownCh:
  351. return
  352. }
  353. }
  354. }
  355. // recv is a long running goroutine that accepts new data
  356. func (s *Session) recv() {
  357. if err := s.recvLoop(); err != nil {
  358. s.exitErr(err)
  359. }
  360. }
  361. // recvLoop continues to receive data until a fatal error is encountered
  362. func (s *Session) recvLoop() error {
  363. defer close(s.recvDoneCh)
  364. hdr := header(make([]byte, headerSize))
  365. var handler func(header) error
  366. for {
  367. // Read the header
  368. if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
  369. if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
  370. s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
  371. }
  372. return err
  373. }
  374. // Verify the version
  375. if hdr.Version() != protoVersion {
  376. s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
  377. return ErrInvalidVersion
  378. }
  379. // Switch on the type
  380. switch hdr.MsgType() {
  381. case typeData:
  382. handler = s.handleStreamMessage
  383. case typeWindowUpdate:
  384. handler = s.handleStreamMessage
  385. case typeGoAway:
  386. handler = s.handleGoAway
  387. case typePing:
  388. handler = s.handlePing
  389. default:
  390. return ErrInvalidMsgType
  391. }
  392. // Invoke the handler
  393. if err := handler(hdr); err != nil {
  394. return err
  395. }
  396. }
  397. }
  398. // handleStreamMessage handles either a data or window update frame
  399. func (s *Session) handleStreamMessage(hdr header) error {
  400. // Check for a new stream creation
  401. id := hdr.StreamID()
  402. flags := hdr.Flags()
  403. if flags&flagSYN == flagSYN {
  404. if err := s.incomingStream(id); err != nil {
  405. return err
  406. }
  407. }
  408. // Get the stream
  409. s.streamLock.Lock()
  410. stream := s.streams[id]
  411. s.streamLock.Unlock()
  412. // If we do not have a stream, likely we sent a RST
  413. if stream == nil {
  414. // Drain any data on the wire
  415. if hdr.MsgType() == typeData && hdr.Length() > 0 {
  416. s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
  417. if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
  418. s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
  419. return nil
  420. }
  421. } else {
  422. s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
  423. }
  424. return nil
  425. }
  426. // Check if this is a window update
  427. if hdr.MsgType() == typeWindowUpdate {
  428. if err := stream.incrSendWindow(hdr, flags); err != nil {
  429. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  430. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  431. }
  432. return err
  433. }
  434. return nil
  435. }
  436. // Read the new data
  437. if err := stream.readData(hdr, flags, s.bufRead); err != nil {
  438. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  439. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  440. }
  441. return err
  442. }
  443. return nil
  444. }
  445. // handlePing is invokde for a typePing frame
  446. func (s *Session) handlePing(hdr header) error {
  447. flags := hdr.Flags()
  448. pingID := hdr.Length()
  449. // Check if this is a query, respond back in a separate context so we
  450. // don't interfere with the receiving thread blocking for the write.
  451. if flags&flagSYN == flagSYN {
  452. go func() {
  453. hdr := header(make([]byte, headerSize))
  454. hdr.encode(typePing, flagACK, 0, pingID)
  455. if err := s.sendNoWait(hdr); err != nil {
  456. s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
  457. }
  458. }()
  459. return nil
  460. }
  461. // Handle a response
  462. s.pingLock.Lock()
  463. ch := s.pings[pingID]
  464. if ch != nil {
  465. delete(s.pings, pingID)
  466. close(ch)
  467. }
  468. s.pingLock.Unlock()
  469. return nil
  470. }
  471. // handleGoAway is invokde for a typeGoAway frame
  472. func (s *Session) handleGoAway(hdr header) error {
  473. code := hdr.Length()
  474. switch code {
  475. case goAwayNormal:
  476. atomic.SwapInt32(&s.remoteGoAway, 1)
  477. case goAwayProtoErr:
  478. s.logger.Printf("[ERR] yamux: received protocol error go away")
  479. return fmt.Errorf("yamux protocol error")
  480. case goAwayInternalErr:
  481. s.logger.Printf("[ERR] yamux: received internal error go away")
  482. return fmt.Errorf("remote yamux internal error")
  483. default:
  484. s.logger.Printf("[ERR] yamux: received unexpected go away")
  485. return fmt.Errorf("unexpected go away received")
  486. }
  487. return nil
  488. }
  489. // incomingStream is used to create a new incoming stream
  490. func (s *Session) incomingStream(id uint32) error {
  491. // Reject immediately if we are doing a go away
  492. if atomic.LoadInt32(&s.localGoAway) == 1 {
  493. hdr := header(make([]byte, headerSize))
  494. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  495. return s.sendNoWait(hdr)
  496. }
  497. // Allocate a new stream
  498. stream := newStream(s, id, streamSYNReceived)
  499. s.streamLock.Lock()
  500. defer s.streamLock.Unlock()
  501. // Check if stream already exists
  502. if _, ok := s.streams[id]; ok {
  503. s.logger.Printf("[ERR] yamux: duplicate stream declared")
  504. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  505. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  506. }
  507. return ErrDuplicateStream
  508. }
  509. // Register the stream
  510. s.streams[id] = stream
  511. // Check if we've exceeded the backlog
  512. select {
  513. case s.acceptCh <- stream:
  514. return nil
  515. default:
  516. // Backlog exceeded! RST the stream
  517. s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
  518. delete(s.streams, id)
  519. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  520. return s.sendNoWait(stream.sendHdr)
  521. }
  522. }
  523. // closeStream is used to close a stream once both sides have
  524. // issued a close. If there was an in-flight SYN and the stream
  525. // was not yet established, then this will give the credit back.
  526. func (s *Session) closeStream(id uint32) {
  527. s.streamLock.Lock()
  528. if _, ok := s.inflight[id]; ok {
  529. select {
  530. case <-s.synCh:
  531. default:
  532. s.logger.Printf("[ERR] yamux: un-established stream without inflight syn semaphore")
  533. }
  534. }
  535. delete(s.streams, id)
  536. s.streamLock.Unlock()
  537. }
  538. // establishStream is used to mark a stream that was in the
  539. // SYN Sent state as established.
  540. func (s *Session) establishStream(id uint32) {
  541. s.streamLock.Lock()
  542. if _, ok := s.inflight[id]; ok {
  543. delete(s.inflight, id)
  544. } else {
  545. s.logger.Printf("[ERR] yamux: established stream without inflight syn entry")
  546. }
  547. select {
  548. case <-s.synCh:
  549. default:
  550. s.logger.Printf("[ERR] yamux: established stream without inflight syn semaphore")
  551. }
  552. s.streamLock.Unlock()
  553. }