session.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. package yamux
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "math"
  9. "net"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. )
  15. // Session is used to wrap a reliable ordered connection and to
  16. // multiplex it into multiple streams.
  17. type Session struct {
  18. // remoteGoAway indicates the remote side does
  19. // not want futher connections. Must be first for alignment.
  20. remoteGoAway int32
  21. // localGoAway indicates that we should stop
  22. // accepting futher connections. Must be first for alignment.
  23. localGoAway int32
  24. // nextStreamID is the next stream we should
  25. // send. This depends if we are a client/server.
  26. nextStreamID uint32
  27. // config holds our configuration
  28. config *Config
  29. // logger is used for our logs
  30. logger *log.Logger
  31. // conn is the underlying connection
  32. conn io.ReadWriteCloser
  33. // bufRead is a buffered reader
  34. bufRead *bufio.Reader
  35. // pings is used to track inflight pings
  36. pings map[uint32]chan struct{}
  37. pingID uint32
  38. pingLock sync.Mutex
  39. // streams maps a stream id to a stream
  40. streams map[uint32]*Stream
  41. streamLock sync.Mutex
  42. // synCh acts like a semaphore. It is sized to the AcceptBacklog which
  43. // is assumed to be symmetric between the client and server. This allows
  44. // the client to avoid exceeding the backlog and instead blocks the open.
  45. synCh chan struct{}
  46. // acceptCh is used to pass ready streams to the client
  47. acceptCh chan *Stream
  48. // sendCh is used to mark a stream as ready to send,
  49. // or to send a header out directly.
  50. sendCh chan sendReady
  51. // recvDoneCh is closed when recv() exits to avoid a race
  52. // between stream registration and stream shutdown
  53. recvDoneCh chan struct{}
  54. // shutdown is used to safely close a session
  55. shutdown bool
  56. shutdownErr error
  57. shutdownCh chan struct{}
  58. shutdownLock sync.Mutex
  59. }
  60. // sendReady is used to either mark a stream as ready
  61. // or to directly send a header
  62. type sendReady struct {
  63. Hdr []byte
  64. Body io.Reader
  65. Err chan error
  66. }
  67. // newSession is used to construct a new session
  68. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  69. s := &Session{
  70. config: config,
  71. logger: log.New(config.LogOutput, "", log.LstdFlags),
  72. conn: conn,
  73. bufRead: bufio.NewReader(conn),
  74. pings: make(map[uint32]chan struct{}),
  75. streams: make(map[uint32]*Stream),
  76. synCh: make(chan struct{}, config.AcceptBacklog),
  77. acceptCh: make(chan *Stream, config.AcceptBacklog),
  78. sendCh: make(chan sendReady, 64),
  79. recvDoneCh: make(chan struct{}),
  80. shutdownCh: make(chan struct{}),
  81. }
  82. if client {
  83. s.nextStreamID = 1
  84. } else {
  85. s.nextStreamID = 2
  86. }
  87. go s.recv()
  88. go s.send()
  89. if config.EnableKeepAlive {
  90. go s.keepalive()
  91. }
  92. return s
  93. }
  94. // IsClosed does a safe check to see if we have shutdown
  95. func (s *Session) IsClosed() bool {
  96. select {
  97. case <-s.shutdownCh:
  98. return true
  99. default:
  100. return false
  101. }
  102. }
  103. // NumStreams returns the number of currently open streams
  104. func (s *Session) NumStreams() int {
  105. s.streamLock.Lock()
  106. num := len(s.streams)
  107. s.streamLock.Unlock()
  108. return num
  109. }
  110. // Open is used to create a new stream as a net.Conn
  111. func (s *Session) Open() (net.Conn, error) {
  112. conn, err := s.OpenStream()
  113. if err != nil {
  114. return nil, err
  115. }
  116. return conn, nil
  117. }
  118. // OpenStream is used to create a new stream
  119. func (s *Session) OpenStream() (*Stream, error) {
  120. if s.IsClosed() {
  121. return nil, ErrSessionShutdown
  122. }
  123. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  124. return nil, ErrRemoteGoAway
  125. }
  126. // Block if we have too many inflight SYNs
  127. select {
  128. case s.synCh <- struct{}{}:
  129. case <-s.shutdownCh:
  130. return nil, ErrSessionShutdown
  131. }
  132. GET_ID:
  133. // Get and ID, and check for stream exhaustion
  134. id := atomic.LoadUint32(&s.nextStreamID)
  135. if id >= math.MaxUint32-1 {
  136. return nil, ErrStreamsExhausted
  137. }
  138. if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
  139. goto GET_ID
  140. }
  141. // Register the stream
  142. stream := newStream(s, id, streamInit)
  143. s.streamLock.Lock()
  144. s.streams[id] = stream
  145. s.streamLock.Unlock()
  146. // Send the window update to create
  147. if err := stream.sendWindowUpdate(); err != nil {
  148. return nil, err
  149. }
  150. return stream, nil
  151. }
  152. // Accept is used to block until the next available stream
  153. // is ready to be accepted.
  154. func (s *Session) Accept() (net.Conn, error) {
  155. conn, err := s.AcceptStream()
  156. if err != nil {
  157. return nil, err
  158. }
  159. return conn, err
  160. }
  161. // AcceptStream is used to block until the next available stream
  162. // is ready to be accepted.
  163. func (s *Session) AcceptStream() (*Stream, error) {
  164. select {
  165. case stream := <-s.acceptCh:
  166. if err := stream.sendWindowUpdate(); err != nil {
  167. return nil, err
  168. }
  169. return stream, nil
  170. case <-s.shutdownCh:
  171. return nil, s.shutdownErr
  172. }
  173. }
  174. // Close is used to close the session and all streams.
  175. // Attempts to send a GoAway before closing the connection.
  176. func (s *Session) Close() error {
  177. s.shutdownLock.Lock()
  178. defer s.shutdownLock.Unlock()
  179. if s.shutdown {
  180. return nil
  181. }
  182. s.shutdown = true
  183. if s.shutdownErr == nil {
  184. s.shutdownErr = ErrSessionShutdown
  185. }
  186. close(s.shutdownCh)
  187. s.conn.Close()
  188. <-s.recvDoneCh
  189. s.streamLock.Lock()
  190. defer s.streamLock.Unlock()
  191. for _, stream := range s.streams {
  192. stream.forceClose()
  193. }
  194. return nil
  195. }
  196. // exitErr is used to handle an error that is causing the
  197. // session to terminate.
  198. func (s *Session) exitErr(err error) {
  199. s.shutdownLock.Lock()
  200. if s.shutdownErr == nil {
  201. s.shutdownErr = err
  202. }
  203. s.shutdownLock.Unlock()
  204. s.Close()
  205. }
  206. // GoAway can be used to prevent accepting further
  207. // connections. It does not close the underlying conn.
  208. func (s *Session) GoAway() error {
  209. return s.waitForSend(s.goAway(goAwayNormal), nil)
  210. }
  211. // goAway is used to send a goAway message
  212. func (s *Session) goAway(reason uint32) header {
  213. atomic.SwapInt32(&s.localGoAway, 1)
  214. hdr := header(make([]byte, headerSize))
  215. hdr.encode(typeGoAway, 0, 0, reason)
  216. return hdr
  217. }
  218. // Ping is used to measure the RTT response time
  219. func (s *Session) Ping() (time.Duration, error) {
  220. // Get a channel for the ping
  221. ch := make(chan struct{})
  222. // Get a new ping id, mark as pending
  223. s.pingLock.Lock()
  224. id := s.pingID
  225. s.pingID++
  226. s.pings[id] = ch
  227. s.pingLock.Unlock()
  228. // Send the ping request
  229. hdr := header(make([]byte, headerSize))
  230. hdr.encode(typePing, flagSYN, 0, id)
  231. if err := s.waitForSend(hdr, nil); err != nil {
  232. return 0, err
  233. }
  234. // Wait for a response
  235. start := time.Now()
  236. select {
  237. case <-ch:
  238. case <-time.After(s.config.ConnectionWriteTimeout):
  239. s.pingLock.Lock()
  240. delete(s.pings, id) // Ignore it if a response comes later.
  241. s.pingLock.Unlock()
  242. return 0, ErrTimeout
  243. case <-s.shutdownCh:
  244. return 0, ErrSessionShutdown
  245. }
  246. // Compute the RTT
  247. return time.Now().Sub(start), nil
  248. }
  249. // keepalive is a long running goroutine that periodically does
  250. // a ping to keep the connection alive.
  251. func (s *Session) keepalive() {
  252. for {
  253. select {
  254. case <-time.After(s.config.KeepAliveInterval):
  255. _, err := s.Ping()
  256. if err != nil {
  257. s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
  258. s.exitErr(ErrKeepAliveTimeout)
  259. return
  260. }
  261. case <-s.shutdownCh:
  262. return
  263. }
  264. }
  265. }
  266. // waitForSendErr waits to send a header, checking for a potential shutdown
  267. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  268. errCh := make(chan error, 1)
  269. return s.waitForSendErr(hdr, body, errCh)
  270. }
  271. // waitForSendErr waits to send a header with optional data, checking for a
  272. // potential shutdown. Since there's the expectation that sends can happen
  273. // in a timely manner, we enforce the connection write timeout here.
  274. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
  275. timer := time.NewTimer(s.config.ConnectionWriteTimeout)
  276. defer timer.Stop()
  277. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  278. select {
  279. case s.sendCh <- ready:
  280. case <-s.shutdownCh:
  281. return ErrSessionShutdown
  282. case <-timer.C:
  283. return ErrConnectionWriteTimeout
  284. }
  285. select {
  286. case err := <-errCh:
  287. return err
  288. case <-s.shutdownCh:
  289. return ErrSessionShutdown
  290. case <-timer.C:
  291. return ErrConnectionWriteTimeout
  292. }
  293. }
  294. // sendNoWait does a send without waiting. Since there's the expectation that
  295. // the send happens right here, we enforce the connection write timeout if we
  296. // can't queue the header to be sent.
  297. func (s *Session) sendNoWait(hdr header) error {
  298. timer := time.NewTimer(s.config.ConnectionWriteTimeout)
  299. defer timer.Stop()
  300. select {
  301. case s.sendCh <- sendReady{Hdr: hdr}:
  302. return nil
  303. case <-s.shutdownCh:
  304. return ErrSessionShutdown
  305. case <-timer.C:
  306. return ErrConnectionWriteTimeout
  307. }
  308. }
  309. // send is a long running goroutine that sends data
  310. func (s *Session) send() {
  311. for {
  312. select {
  313. case ready := <-s.sendCh:
  314. // Send a header if ready
  315. if ready.Hdr != nil {
  316. sent := 0
  317. for sent < len(ready.Hdr) {
  318. n, err := s.conn.Write(ready.Hdr[sent:])
  319. if err != nil {
  320. s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
  321. asyncSendErr(ready.Err, err)
  322. s.exitErr(err)
  323. return
  324. }
  325. sent += n
  326. }
  327. }
  328. // Send data from a body if given
  329. if ready.Body != nil {
  330. _, err := io.Copy(s.conn, ready.Body)
  331. if err != nil {
  332. s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
  333. asyncSendErr(ready.Err, err)
  334. s.exitErr(err)
  335. return
  336. }
  337. }
  338. // No error, successful send
  339. asyncSendErr(ready.Err, nil)
  340. case <-s.shutdownCh:
  341. return
  342. }
  343. }
  344. }
  345. // recv is a long running goroutine that accepts new data
  346. func (s *Session) recv() {
  347. if err := s.recvLoop(); err != nil {
  348. s.exitErr(err)
  349. }
  350. }
  351. // recvLoop continues to receive data until a fatal error is encountered
  352. func (s *Session) recvLoop() error {
  353. defer close(s.recvDoneCh)
  354. hdr := header(make([]byte, headerSize))
  355. var handler func(header) error
  356. for {
  357. // Read the header
  358. if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
  359. if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
  360. s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
  361. }
  362. return err
  363. }
  364. // Verify the version
  365. if hdr.Version() != protoVersion {
  366. s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
  367. return ErrInvalidVersion
  368. }
  369. // Switch on the type
  370. switch hdr.MsgType() {
  371. case typeData:
  372. handler = s.handleStreamMessage
  373. case typeWindowUpdate:
  374. handler = s.handleStreamMessage
  375. case typeGoAway:
  376. handler = s.handleGoAway
  377. case typePing:
  378. handler = s.handlePing
  379. default:
  380. return ErrInvalidMsgType
  381. }
  382. // Invoke the handler
  383. if err := handler(hdr); err != nil {
  384. return err
  385. }
  386. }
  387. }
  388. // handleStreamMessage handles either a data or window update frame
  389. func (s *Session) handleStreamMessage(hdr header) error {
  390. // Check for a new stream creation
  391. id := hdr.StreamID()
  392. flags := hdr.Flags()
  393. if flags&flagSYN == flagSYN {
  394. if err := s.incomingStream(id); err != nil {
  395. return err
  396. }
  397. }
  398. // Get the stream
  399. s.streamLock.Lock()
  400. stream := s.streams[id]
  401. s.streamLock.Unlock()
  402. // If we do not have a stream, likely we sent a RST
  403. if stream == nil {
  404. // Drain any data on the wire
  405. if hdr.MsgType() == typeData && hdr.Length() > 0 {
  406. s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
  407. if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
  408. s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
  409. return nil
  410. }
  411. } else {
  412. s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
  413. }
  414. return nil
  415. }
  416. // Check if this is a window update
  417. if hdr.MsgType() == typeWindowUpdate {
  418. if err := stream.incrSendWindow(hdr, flags); err != nil {
  419. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  420. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  421. }
  422. return err
  423. }
  424. return nil
  425. }
  426. // Read the new data
  427. if err := stream.readData(hdr, flags, s.bufRead); err != nil {
  428. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  429. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  430. }
  431. return err
  432. }
  433. return nil
  434. }
  435. // handlePing is invokde for a typePing frame
  436. func (s *Session) handlePing(hdr header) error {
  437. flags := hdr.Flags()
  438. pingID := hdr.Length()
  439. // Check if this is a query, respond back in a separate context so we
  440. // don't interfere with the receiving thread blocking for the write.
  441. if flags&flagSYN == flagSYN {
  442. go func() {
  443. hdr := header(make([]byte, headerSize))
  444. hdr.encode(typePing, flagACK, 0, pingID)
  445. if err := s.sendNoWait(hdr); err != nil {
  446. s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
  447. }
  448. }()
  449. return nil
  450. }
  451. // Handle a response
  452. s.pingLock.Lock()
  453. ch := s.pings[pingID]
  454. if ch != nil {
  455. delete(s.pings, pingID)
  456. close(ch)
  457. }
  458. s.pingLock.Unlock()
  459. return nil
  460. }
  461. // handleGoAway is invokde for a typeGoAway frame
  462. func (s *Session) handleGoAway(hdr header) error {
  463. code := hdr.Length()
  464. switch code {
  465. case goAwayNormal:
  466. atomic.SwapInt32(&s.remoteGoAway, 1)
  467. case goAwayProtoErr:
  468. s.logger.Printf("[ERR] yamux: received protocol error go away")
  469. return fmt.Errorf("yamux protocol error")
  470. case goAwayInternalErr:
  471. s.logger.Printf("[ERR] yamux: received internal error go away")
  472. return fmt.Errorf("remote yamux internal error")
  473. default:
  474. s.logger.Printf("[ERR] yamux: received unexpected go away")
  475. return fmt.Errorf("unexpected go away received")
  476. }
  477. return nil
  478. }
  479. // incomingStream is used to create a new incoming stream
  480. func (s *Session) incomingStream(id uint32) error {
  481. // Reject immediately if we are doing a go away
  482. if atomic.LoadInt32(&s.localGoAway) == 1 {
  483. hdr := header(make([]byte, headerSize))
  484. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  485. return s.sendNoWait(hdr)
  486. }
  487. // Allocate a new stream
  488. stream := newStream(s, id, streamSYNReceived)
  489. s.streamLock.Lock()
  490. defer s.streamLock.Unlock()
  491. // Check if stream already exists
  492. if _, ok := s.streams[id]; ok {
  493. s.logger.Printf("[ERR] yamux: duplicate stream declared")
  494. if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
  495. s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
  496. }
  497. return ErrDuplicateStream
  498. }
  499. // Register the stream
  500. s.streams[id] = stream
  501. // Check if we've exceeded the backlog
  502. select {
  503. case s.acceptCh <- stream:
  504. return nil
  505. default:
  506. // Backlog exceeded! RST the stream
  507. s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
  508. delete(s.streams, id)
  509. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  510. return s.sendNoWait(stream.sendHdr)
  511. }
  512. }
  513. // closeStream is used to close a stream once both sides have
  514. // issued a close.
  515. func (s *Session) closeStream(id uint32) {
  516. s.streamLock.Lock()
  517. delete(s.streams, id)
  518. s.streamLock.Unlock()
  519. }
  520. // establishStream is used to mark a stream that was in the
  521. // SYN Sent state as established.
  522. func (s *Session) establishStream() {
  523. select {
  524. case <-s.synCh:
  525. default:
  526. panic("established stream without inflight syn")
  527. }
  528. }