session.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. package yamux
  2. import (
  3. "bufio"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "math"
  9. "net"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. )
  14. // Session is used to wrap a reliable ordered connection and to
  15. // multiplex it into multiple streams.
  16. type Session struct {
  17. // remoteGoAway indicates the remote side does
  18. // not want futher connections. Must be first for alignment.
  19. remoteGoAway int32
  20. // localGoAway indicates that we should stop
  21. // accepting futher connections. Must be first for alignment.
  22. localGoAway int32
  23. // nextStreamID is the next stream we should
  24. // send. This depends if we are a client/server.
  25. nextStreamID uint32
  26. // config holds our configuration
  27. config *Config
  28. // logger is used for our logs
  29. logger *log.Logger
  30. // conn is the underlying connection
  31. conn io.ReadWriteCloser
  32. // bufRead is a buffered reader
  33. bufRead *bufio.Reader
  34. // pings is used to track inflight pings
  35. pings map[uint32]chan struct{}
  36. pingID uint32
  37. pingLock sync.Mutex
  38. // streams maps a stream id to a stream
  39. streams map[uint32]*Stream
  40. streamLock sync.Mutex
  41. // acceptCh is used to pass ready streams to the client
  42. acceptCh chan *Stream
  43. // sendCh is used to mark a stream as ready to send,
  44. // or to send a header out directly.
  45. sendCh chan sendReady
  46. // shutdown is used to safely close a session
  47. shutdown bool
  48. shutdownErr error
  49. shutdownCh chan struct{}
  50. shutdownLock sync.Mutex
  51. }
  52. // sendReady is used to either mark a stream as ready
  53. // or to directly send a header
  54. type sendReady struct {
  55. Hdr []byte
  56. Body io.Reader
  57. Err chan error
  58. }
  59. // newSession is used to construct a new session
  60. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  61. s := &Session{
  62. config: config,
  63. logger: log.New(config.LogOutput, "", log.LstdFlags),
  64. conn: conn,
  65. bufRead: bufio.NewReader(conn),
  66. pings: make(map[uint32]chan struct{}),
  67. streams: make(map[uint32]*Stream),
  68. acceptCh: make(chan *Stream, config.AcceptBacklog),
  69. sendCh: make(chan sendReady, 64),
  70. shutdownCh: make(chan struct{}),
  71. }
  72. if client {
  73. s.nextStreamID = 1
  74. } else {
  75. s.nextStreamID = 2
  76. }
  77. go s.recv()
  78. go s.send()
  79. if config.EnableKeepAlive {
  80. go s.keepalive()
  81. }
  82. return s
  83. }
  84. // IsClosed does a safe check to see if we have shutdown
  85. func (s *Session) IsClosed() bool {
  86. select {
  87. case <-s.shutdownCh:
  88. return true
  89. default:
  90. return false
  91. }
  92. }
  93. // Open is used to create a new stream as a net.Conn
  94. func (s *Session) Open() (net.Conn, error) {
  95. return s.OpenStream()
  96. }
  97. // OpenStream is used to create a new stream
  98. func (s *Session) OpenStream() (*Stream, error) {
  99. if s.IsClosed() {
  100. return nil, ErrSessionShutdown
  101. }
  102. if atomic.LoadInt32(&s.remoteGoAway) == 1 {
  103. return nil, ErrRemoteGoAway
  104. }
  105. GET_ID:
  106. // Get and ID, and check for stream exhaustion
  107. id := atomic.LoadUint32(&s.nextStreamID)
  108. if id >= math.MaxUint32-1 {
  109. return nil, ErrStreamsExhausted
  110. }
  111. if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
  112. goto GET_ID
  113. }
  114. // Register the stream
  115. stream := newStream(s, id, streamInit)
  116. s.streamLock.Lock()
  117. s.streams[id] = stream
  118. s.streamLock.Unlock()
  119. // Send the window update to create
  120. return stream, stream.sendWindowUpdate()
  121. }
  122. // Accept is used to block until the next available stream
  123. // is ready to be accepted.
  124. func (s *Session) Accept() (net.Conn, error) {
  125. return s.AcceptStream()
  126. }
  127. // AcceptStream is used to block until the next available stream
  128. // is ready to be accepted.
  129. func (s *Session) AcceptStream() (*Stream, error) {
  130. select {
  131. case stream := <-s.acceptCh:
  132. return stream, stream.sendWindowUpdate()
  133. case <-s.shutdownCh:
  134. return nil, s.shutdownErr
  135. }
  136. }
  137. // Close is used to close the session and all streams.
  138. // Attempts to send a GoAway before closing the connection.
  139. func (s *Session) Close() error {
  140. s.shutdownLock.Lock()
  141. defer s.shutdownLock.Unlock()
  142. if s.shutdown {
  143. return nil
  144. }
  145. s.shutdown = true
  146. if s.shutdownErr == nil {
  147. s.shutdownErr = ErrSessionShutdown
  148. }
  149. close(s.shutdownCh)
  150. s.conn.Close()
  151. s.streamLock.Lock()
  152. defer s.streamLock.Unlock()
  153. for _, stream := range s.streams {
  154. stream.forceClose()
  155. }
  156. return nil
  157. }
  158. // exitErr is used to handle an error that is causing the
  159. // session to terminate.
  160. func (s *Session) exitErr(err error) {
  161. s.shutdownLock.Lock()
  162. if s.shutdownErr == nil {
  163. s.shutdownErr = err
  164. }
  165. s.shutdownLock.Unlock()
  166. s.Close()
  167. }
  168. // GoAway can be used to prevent accepting further
  169. // connections. It does not close the underlying conn.
  170. func (s *Session) GoAway() error {
  171. return s.waitForSend(s.goAway(goAwayNormal), nil)
  172. }
  173. // goAway is used to send a goAway message
  174. func (s *Session) goAway(reason uint32) header {
  175. atomic.SwapInt32(&s.localGoAway, 1)
  176. hdr := header(make([]byte, headerSize))
  177. hdr.encode(typeGoAway, 0, 0, reason)
  178. return hdr
  179. }
  180. // Ping is used to measure the RTT response time
  181. func (s *Session) Ping() (time.Duration, error) {
  182. // Get a channel for the ping
  183. ch := make(chan struct{})
  184. // Get a new ping id, mark as pending
  185. s.pingLock.Lock()
  186. id := s.pingID
  187. s.pingID++
  188. s.pings[id] = ch
  189. s.pingLock.Unlock()
  190. // Send the ping request
  191. hdr := header(make([]byte, headerSize))
  192. hdr.encode(typePing, flagSYN, 0, id)
  193. if err := s.waitForSend(hdr, nil); err != nil {
  194. return 0, err
  195. }
  196. // Wait for a response
  197. start := time.Now()
  198. select {
  199. case <-ch:
  200. case <-s.shutdownCh:
  201. return 0, ErrSessionShutdown
  202. }
  203. // Compute the RTT
  204. return time.Now().Sub(start), nil
  205. }
  206. // keepalive is a long running goroutine that periodically does
  207. // a ping to keep the connection alive.
  208. func (s *Session) keepalive() {
  209. for {
  210. select {
  211. case <-time.After(s.config.KeepAliveInterval):
  212. s.Ping()
  213. case <-s.shutdownCh:
  214. return
  215. }
  216. }
  217. }
  218. // waitForSendErr waits to send a header, checking for a potential shutdown
  219. func (s *Session) waitForSend(hdr header, body io.Reader) error {
  220. errCh := make(chan error, 1)
  221. return s.waitForSendErr(hdr, body, errCh)
  222. }
  223. // waitForSendErr waits to send a header, checking for a potential shutdown
  224. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
  225. ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
  226. select {
  227. case s.sendCh <- ready:
  228. case <-s.shutdownCh:
  229. return ErrSessionShutdown
  230. }
  231. select {
  232. case err := <-errCh:
  233. return err
  234. case <-s.shutdownCh:
  235. return ErrSessionShutdown
  236. }
  237. }
  238. // sendNoWait does a send without waiting
  239. func (s *Session) sendNoWait(hdr header) error {
  240. select {
  241. case s.sendCh <- sendReady{Hdr: hdr}:
  242. return nil
  243. case <-s.shutdownCh:
  244. return ErrSessionShutdown
  245. }
  246. }
  247. // send is a long running goroutine that sends data
  248. func (s *Session) send() {
  249. for {
  250. select {
  251. case ready := <-s.sendCh:
  252. // Send a header if ready
  253. if ready.Hdr != nil {
  254. sent := 0
  255. for sent < len(ready.Hdr) {
  256. n, err := s.conn.Write(ready.Hdr[sent:])
  257. if err != nil {
  258. s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
  259. asyncSendErr(ready.Err, err)
  260. s.exitErr(err)
  261. return
  262. }
  263. sent += n
  264. }
  265. }
  266. // Send data from a body if given
  267. if ready.Body != nil {
  268. _, err := io.Copy(s.conn, ready.Body)
  269. if err != nil {
  270. s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
  271. asyncSendErr(ready.Err, err)
  272. s.exitErr(err)
  273. return
  274. }
  275. }
  276. // No error, successful send
  277. asyncSendErr(ready.Err, nil)
  278. case <-s.shutdownCh:
  279. return
  280. }
  281. }
  282. }
  283. // recv is a long running goroutine that accepts new data
  284. func (s *Session) recv() {
  285. hdr := header(make([]byte, headerSize))
  286. var handler func(header) error
  287. for {
  288. // Read the header
  289. if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
  290. if err != io.EOF {
  291. s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
  292. }
  293. s.exitErr(err)
  294. return
  295. }
  296. // Verify the version
  297. if hdr.Version() != protoVersion {
  298. s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
  299. s.exitErr(ErrInvalidVersion)
  300. return
  301. }
  302. // Switch on the type
  303. switch hdr.MsgType() {
  304. case typeData:
  305. handler = s.handleStreamMessage
  306. case typeWindowUpdate:
  307. handler = s.handleStreamMessage
  308. case typeGoAway:
  309. handler = s.handleGoAway
  310. case typePing:
  311. handler = s.handlePing
  312. default:
  313. s.exitErr(ErrInvalidMsgType)
  314. return
  315. }
  316. // Invoke the handler
  317. if err := handler(hdr); err != nil {
  318. s.exitErr(err)
  319. return
  320. }
  321. }
  322. }
  323. // handleStreamMessage handles either a data or window update frame
  324. func (s *Session) handleStreamMessage(hdr header) error {
  325. // Check for a new stream creation
  326. id := hdr.StreamID()
  327. flags := hdr.Flags()
  328. if flags&flagSYN == flagSYN {
  329. if err := s.incomingStream(id); err != nil {
  330. return err
  331. }
  332. }
  333. // Get the stream
  334. s.streamLock.Lock()
  335. stream := s.streams[id]
  336. s.streamLock.Unlock()
  337. // If we do not have a stream, likely we sent a RST
  338. if stream == nil {
  339. // Drain any data on the wire
  340. if hdr.MsgType() == typeData && hdr.Length() > 0 {
  341. if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
  342. s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
  343. return nil
  344. }
  345. }
  346. return nil
  347. }
  348. // Check if this is a window update
  349. if hdr.MsgType() == typeWindowUpdate {
  350. if err := stream.incrSendWindow(hdr, flags); err != nil {
  351. s.sendNoWait(s.goAway(goAwayProtoErr))
  352. return err
  353. }
  354. return nil
  355. }
  356. // Read the new data
  357. if err := stream.readData(hdr, flags, s.bufRead); err != nil {
  358. s.sendNoWait(s.goAway(goAwayProtoErr))
  359. return err
  360. }
  361. return nil
  362. }
  363. // handlePing is invokde for a typePing frame
  364. func (s *Session) handlePing(hdr header) error {
  365. flags := hdr.Flags()
  366. pingID := hdr.Length()
  367. // Check if this is a query, respond back
  368. if flags&flagSYN == flagSYN {
  369. hdr := header(make([]byte, headerSize))
  370. hdr.encode(typePing, flagACK, 0, pingID)
  371. s.sendNoWait(hdr)
  372. return nil
  373. }
  374. // Handle a response
  375. s.pingLock.Lock()
  376. ch := s.pings[pingID]
  377. if ch != nil {
  378. delete(s.pings, pingID)
  379. close(ch)
  380. }
  381. s.pingLock.Unlock()
  382. return nil
  383. }
  384. // handleGoAway is invokde for a typeGoAway frame
  385. func (s *Session) handleGoAway(hdr header) error {
  386. code := hdr.Length()
  387. switch code {
  388. case goAwayNormal:
  389. atomic.SwapInt32(&s.remoteGoAway, 1)
  390. case goAwayProtoErr:
  391. s.logger.Printf("[ERR] yamux: received protocol error go away")
  392. return fmt.Errorf("yamux protocol error")
  393. case goAwayInternalErr:
  394. s.logger.Printf("[ERR] yamux: received internal error go away")
  395. return fmt.Errorf("remote yamux internal error")
  396. default:
  397. s.logger.Printf("[ERR] yamux: received unexpected go away")
  398. return fmt.Errorf("unexpected go away received")
  399. }
  400. return nil
  401. }
  402. // incomingStream is used to create a new incoming stream
  403. func (s *Session) incomingStream(id uint32) error {
  404. // Reject immediately if we are doing a go away
  405. if atomic.LoadInt32(&s.localGoAway) == 1 {
  406. hdr := header(make([]byte, headerSize))
  407. hdr.encode(typeWindowUpdate, flagRST, id, 0)
  408. return s.sendNoWait(hdr)
  409. }
  410. // Allocate a new stream
  411. stream := newStream(s, id, streamSYNReceived)
  412. s.streamLock.Lock()
  413. defer s.streamLock.Unlock()
  414. // Check if stream already exists
  415. if _, ok := s.streams[id]; ok {
  416. s.logger.Printf("[ERR] yamux: duplicate stream declared")
  417. s.sendNoWait(s.goAway(goAwayProtoErr))
  418. return ErrDuplicateStream
  419. }
  420. // Register the stream
  421. s.streams[id] = stream
  422. // Check if we've exceeded the backlog
  423. select {
  424. case s.acceptCh <- stream:
  425. return nil
  426. default:
  427. // Backlog exceeded! RST the stream
  428. s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
  429. delete(s.streams, id)
  430. stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
  431. return s.sendNoWait(stream.sendHdr)
  432. }
  433. }
  434. // closeStream is used to close a stream once both sides have
  435. // issued a close.
  436. func (s *Session) closeStream(id uint32, withLock bool) {
  437. if !withLock {
  438. s.streamLock.Lock()
  439. defer s.streamLock.Unlock()
  440. }
  441. delete(s.streams, id)
  442. }