stream.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  // streamState is the lifecycle state of a Stream. The numeric values are
  // spelled out explicitly because processFlags logs the state with %d.
  type streamState int

  const (
  	streamInit        streamState = 0 // created; SYN not yet sent
  	streamSYNSent     streamState = 1 // our SYN sent; waiting for the peer's ACK
  	streamSYNReceived streamState = 2 // peer's SYN received; our ACK not yet sent
  	streamEstablished streamState = 3 // handshake complete
  	streamLocalClose  streamState = 4 // we sent FIN; the peer has not
  	streamRemoteClose streamState = 5 // the peer sent FIN; we have not
  	streamClosed      streamState = 6 // both sides closed (or forceClose was called)
  	streamReset       streamState = 7 // the peer sent RST
  )
  20. // Stream is used to represent a logical stream
  21. // within a session.
  22. type Stream struct {
  23. recvWindow uint32
  24. sendWindow uint32
  25. id uint32
  26. session *Session
  27. state streamState
  28. stateLock sync.Mutex
  29. recvBuf bytes.Buffer
  30. recvLock sync.Mutex
  31. controlHdr header
  32. controlErr chan error
  33. controlHdrLock sync.Mutex
  34. sendHdr header
  35. sendErr chan error
  36. sendLock sync.Mutex
  37. recvNotifyCh chan struct{}
  38. sendNotifyCh chan struct{}
  39. readDeadline time.Time
  40. writeDeadline time.Time
  41. }
  42. // newStream is used to construct a new stream within
  43. // a given session for an ID
  44. func newStream(session *Session, id uint32, state streamState) *Stream {
  45. s := &Stream{
  46. id: id,
  47. session: session,
  48. state: state,
  49. controlHdr: header(make([]byte, headerSize)),
  50. controlErr: make(chan error, 1),
  51. sendHdr: header(make([]byte, headerSize)),
  52. sendErr: make(chan error, 1),
  53. recvWindow: initialStreamWindow,
  54. sendWindow: initialStreamWindow,
  55. recvNotifyCh: make(chan struct{}, 1),
  56. sendNotifyCh: make(chan struct{}, 1),
  57. }
  58. return s
  59. }
  60. // Session returns the associated stream session
  61. func (s *Stream) Session() *Session {
  62. return s.session
  63. }
  64. // StreamID returns the ID of this stream
  65. func (s *Stream) StreamID() uint32 {
  66. return s.id
  67. }
  68. // Read is used to read from the stream
  69. func (s *Stream) Read(b []byte) (n int, err error) {
  70. defer asyncNotify(s.recvNotifyCh)
  71. START:
  72. s.stateLock.Lock()
  73. switch s.state {
  74. case streamLocalClose:
  75. fallthrough
  76. case streamRemoteClose:
  77. fallthrough
  78. case streamClosed:
  79. if s.recvBuf.Len() == 0 {
  80. s.stateLock.Unlock()
  81. return 0, io.EOF
  82. }
  83. case streamReset:
  84. s.stateLock.Unlock()
  85. return 0, ErrConnectionReset
  86. }
  87. s.stateLock.Unlock()
  88. // If there is no data available, block
  89. s.recvLock.Lock()
  90. if s.recvBuf.Len() == 0 {
  91. s.recvLock.Unlock()
  92. goto WAIT
  93. }
  94. // Read any bytes
  95. n, _ = s.recvBuf.Read(b)
  96. s.recvLock.Unlock()
  97. // Send a window update potentially
  98. err = s.sendWindowUpdate()
  99. return n, err
  100. WAIT:
  101. var timeout <-chan time.Time
  102. if !s.readDeadline.IsZero() {
  103. delay := s.readDeadline.Sub(time.Now())
  104. timeout = time.After(delay)
  105. }
  106. select {
  107. case <-s.recvNotifyCh:
  108. goto START
  109. case <-timeout:
  110. return 0, ErrTimeout
  111. }
  112. }
  113. // Write is used to write to the stream
  114. func (s *Stream) Write(b []byte) (n int, err error) {
  115. s.sendLock.Lock()
  116. defer s.sendLock.Unlock()
  117. total := 0
  118. for total < len(b) {
  119. n, err := s.write(b[total:])
  120. total += n
  121. if err != nil {
  122. return total, err
  123. }
  124. }
  125. return total, nil
  126. }
  127. // write is used to write to the stream, may return on
  128. // a short write.
  129. func (s *Stream) write(b []byte) (n int, err error) {
  130. var flags uint16
  131. var max uint32
  132. var body io.Reader
  133. START:
  134. s.stateLock.Lock()
  135. switch s.state {
  136. case streamLocalClose:
  137. fallthrough
  138. case streamClosed:
  139. s.stateLock.Unlock()
  140. return 0, ErrStreamClosed
  141. case streamReset:
  142. s.stateLock.Unlock()
  143. return 0, ErrConnectionReset
  144. }
  145. s.stateLock.Unlock()
  146. // If there is no data available, block
  147. window := atomic.LoadUint32(&s.sendWindow)
  148. if window == 0 {
  149. goto WAIT
  150. }
  151. // Determine the flags if any
  152. flags = s.sendFlags()
  153. // Send up to our send window
  154. max = min(window, uint32(len(b)))
  155. body = bytes.NewReader(b[:max])
  156. // Send the header
  157. s.sendHdr.encode(typeData, flags, s.id, max)
  158. if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
  159. return 0, err
  160. }
  161. // Reduce our send window
  162. atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
  163. // Unlock
  164. return int(max), err
  165. WAIT:
  166. var timeout <-chan time.Time
  167. if !s.writeDeadline.IsZero() {
  168. delay := s.writeDeadline.Sub(time.Now())
  169. timeout = time.After(delay)
  170. }
  171. select {
  172. case <-s.sendNotifyCh:
  173. goto START
  174. case <-timeout:
  175. return 0, ErrTimeout
  176. }
  177. return 0, nil
  178. }
  179. // sendFlags determines any flags that are appropriate
  180. // based on the current stream state
  181. func (s *Stream) sendFlags() uint16 {
  182. s.stateLock.Lock()
  183. defer s.stateLock.Unlock()
  184. var flags uint16
  185. switch s.state {
  186. case streamInit:
  187. flags |= flagSYN
  188. s.state = streamSYNSent
  189. case streamSYNReceived:
  190. flags |= flagACK
  191. s.state = streamEstablished
  192. }
  193. return flags
  194. }
  195. // sendWindowUpdate potentially sends a window update enabling
  196. // further writes to take place. Must be invoked with the lock.
  197. func (s *Stream) sendWindowUpdate() error {
  198. s.controlHdrLock.Lock()
  199. defer s.controlHdrLock.Unlock()
  200. // Determine the delta update
  201. max := s.session.config.MaxStreamWindowSize
  202. delta := max - atomic.LoadUint32(&s.recvWindow)
  203. // Determine the flags if any
  204. flags := s.sendFlags()
  205. // Check if we can omit the update
  206. if delta < (max/2) && flags == 0 {
  207. return nil
  208. }
  209. // Update our window
  210. atomic.AddUint32(&s.recvWindow, delta)
  211. // Send the header
  212. s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
  213. if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
  214. return err
  215. }
  216. return nil
  217. }
  218. // sendClose is used to send a FIN
  219. func (s *Stream) sendClose() error {
  220. s.controlHdrLock.Lock()
  221. defer s.controlHdrLock.Unlock()
  222. flags := s.sendFlags()
  223. flags |= flagFIN
  224. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  225. if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
  226. return err
  227. }
  228. return nil
  229. }
  230. // Close is used to close the stream
  231. func (s *Stream) Close() error {
  232. closeStream := false
  233. s.stateLock.Lock()
  234. switch s.state {
  235. // Opened means we need to signal a close
  236. case streamSYNSent:
  237. fallthrough
  238. case streamSYNReceived:
  239. fallthrough
  240. case streamEstablished:
  241. s.state = streamLocalClose
  242. goto SEND_CLOSE
  243. case streamLocalClose:
  244. case streamRemoteClose:
  245. s.state = streamClosed
  246. closeStream = true
  247. goto SEND_CLOSE
  248. case streamClosed:
  249. case streamReset:
  250. default:
  251. panic("unhandled state")
  252. }
  253. s.stateLock.Unlock()
  254. return nil
  255. SEND_CLOSE:
  256. s.stateLock.Unlock()
  257. s.sendClose()
  258. s.notifyWaiting()
  259. if closeStream {
  260. s.session.closeStream(s.id)
  261. }
  262. return nil
  263. }
  264. // forceClose is used for when the session is exiting
  265. func (s *Stream) forceClose() {
  266. s.stateLock.Lock()
  267. s.state = streamClosed
  268. s.stateLock.Unlock()
  269. s.notifyWaiting()
  270. }
  271. // processFlags is used to update the state of the stream
  272. // based on set flags, if any. Lock must be held
  273. func (s *Stream) processFlags(flags uint16) error {
  274. // Close the stream without holding the state lock
  275. closeStream := false
  276. defer func() {
  277. if closeStream {
  278. s.session.closeStream(s.id)
  279. }
  280. }()
  281. s.stateLock.Lock()
  282. defer s.stateLock.Unlock()
  283. if flags&flagACK == flagACK {
  284. if s.state == streamSYNSent {
  285. s.state = streamEstablished
  286. }
  287. s.session.establishStream()
  288. }
  289. if flags&flagFIN == flagFIN {
  290. switch s.state {
  291. case streamSYNSent:
  292. fallthrough
  293. case streamSYNReceived:
  294. fallthrough
  295. case streamEstablished:
  296. s.state = streamRemoteClose
  297. s.notifyWaiting()
  298. case streamLocalClose:
  299. s.state = streamClosed
  300. closeStream = true
  301. s.notifyWaiting()
  302. default:
  303. s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
  304. return ErrUnexpectedFlag
  305. }
  306. }
  307. if flags&flagRST == flagRST {
  308. if s.state == streamSYNSent {
  309. s.session.establishStream()
  310. }
  311. s.state = streamReset
  312. closeStream = true
  313. s.notifyWaiting()
  314. }
  315. return nil
  316. }
  317. // notifyWaiting notifies all the waiting channels
  318. func (s *Stream) notifyWaiting() {
  319. asyncNotify(s.recvNotifyCh)
  320. asyncNotify(s.sendNotifyCh)
  321. }
  322. // incrSendWindow updates the size of our send window
  323. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  324. if err := s.processFlags(flags); err != nil {
  325. return err
  326. }
  327. // Increase window, unblock a sender
  328. atomic.AddUint32(&s.sendWindow, hdr.Length())
  329. asyncNotify(s.sendNotifyCh)
  330. return nil
  331. }
  332. // readData is used to handle a data frame
  333. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  334. if err := s.processFlags(flags); err != nil {
  335. return err
  336. }
  337. // Check that our recv window is not exceeded
  338. length := hdr.Length()
  339. if length == 0 {
  340. return nil
  341. }
  342. if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
  343. s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
  344. return ErrRecvWindowExceeded
  345. }
  346. // Wrap in a limited reader
  347. conn = &io.LimitedReader{R: conn, N: int64(length)}
  348. // Copy into buffer
  349. s.recvLock.Lock()
  350. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  351. s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
  352. s.recvLock.Unlock()
  353. return err
  354. }
  355. // Decrement the receive window
  356. atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
  357. s.recvLock.Unlock()
  358. // Unblock any readers
  359. asyncNotify(s.recvNotifyCh)
  360. return nil
  361. }
  362. // SetDeadline sets the read and write deadlines
  363. func (s *Stream) SetDeadline(t time.Time) error {
  364. if err := s.SetReadDeadline(t); err != nil {
  365. return err
  366. }
  367. if err := s.SetWriteDeadline(t); err != nil {
  368. return err
  369. }
  370. return nil
  371. }
  372. // SetReadDeadline sets the deadline for future Read calls.
  373. func (s *Stream) SetReadDeadline(t time.Time) error {
  374. s.readDeadline = t
  375. return nil
  376. }
  377. // SetWriteDeadline sets the deadline for future Write calls
  378. func (s *Stream) SetWriteDeadline(t time.Time) error {
  379. s.writeDeadline = t
  380. return nil
  381. }