stream.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  // streamState is the lifecycle state of a Stream.
type streamState int

// Stream lifecycle states, numbered from zero with iota.
const (
	streamInit        streamState = iota // nothing sent yet; sendFlags sends SYN from here
	streamSYNSent                        // SYN sent; an incoming ACK moves to streamEstablished
	streamSYNReceived                    // presumably peer-opened; sendFlags sends ACK from here
	streamEstablished                    // SYN/ACK exchange complete
	streamLocalClose                     // we sent FIN; Read still accepts data
	streamRemoteClose                    // peer sent FIN; write is still allowed
	streamClosed                         // both directions closed, or reset by RST
)
  19. // Stream is used to represent a logical stream
  20. // within a session.
  21. type Stream struct {
  22. recvWindow uint32
  23. sendWindow uint32
  24. id uint32
  25. session *Session
  26. state streamState
  27. stateLock sync.Mutex
  28. recvBuf bytes.Buffer
  29. waitingBuffer *directBuffer
  30. recvLock sync.Mutex
  31. controlHdr header
  32. controlErr chan error
  33. controlHdrLock sync.Mutex
  34. sendHdr header
  35. sendErr chan error
  36. sendLock sync.Mutex
  37. recvNotifyCh chan struct{}
  38. sendNotifyCh chan struct{}
  39. readDeadline time.Time
  40. writeDeadline time.Time
  41. }
  // directBuffer is the slice a blocked Read hands over so readData can copy
// a frame into it directly instead of going through recvBuf.
type directBuffer struct {
	buf   []byte // destination slice (the b passed to Read)
	bytes int    // number of bytes readData copied into buf
}
  46. // newStream is used to construct a new stream within
  47. // a given session for an ID
  48. func newStream(session *Session, id uint32, state streamState) *Stream {
  49. s := &Stream{
  50. id: id,
  51. session: session,
  52. state: state,
  53. controlHdr: header(make([]byte, headerSize)),
  54. controlErr: make(chan error, 1),
  55. sendHdr: header(make([]byte, headerSize)),
  56. sendErr: make(chan error, 1),
  57. recvWindow: initialStreamWindow,
  58. sendWindow: initialStreamWindow,
  59. recvNotifyCh: make(chan struct{}, 1),
  60. sendNotifyCh: make(chan struct{}, 1),
  61. }
  62. return s
  63. }
  64. // Session returns the associated stream session
  65. func (s *Stream) Session() *Session {
  66. return s.session
  67. }
  68. // StreamID returns the ID of this stream
  69. func (s *Stream) StreamID() uint32 {
  70. return s.id
  71. }
  // Read reads data from the stream into b.
//
// Buffered data is returned first. When nothing is buffered, b is registered
// as the stream's waitingBuffer (unless another Read already registered one)
// so readData can copy the next data frame straight into it. The call then
// blocks until recvNotifyCh is signalled or the read deadline passes.
//
// It returns io.EOF once the stream is in streamRemoteClose or streamClosed
// and nothing is left to read. It returns ErrTimeout when the deadline passes
// with no data. Whenever data is returned, sendWindowUpdate is invoked and its
// error is returned alongside n.
func (s *Stream) Read(b []byte) (n int, err error) {
	// Pass the wake-up along so any other blocked reader re-checks the stream.
	defer asyncNotify(s.recvNotifyCh)

	// dBuf is this call's direct-delivery registration, if it made one.
	var dBuf *directBuffer
	for {
		s.stateLock.Lock()
		remoteDone := s.state == streamRemoteClose || s.state == streamClosed
		s.stateLock.Unlock()

		s.recvLock.Lock()
		// readData may already have copied a frame into the front of b.
		// Those bytes must be returned before recvBuf is read into b, which
		// would overwrite them.
		if dBuf != nil && dBuf.bytes > 0 {
			n = dBuf.bytes
			s.recvLock.Unlock()
			return n, s.sendWindowUpdate()
		}
		if s.recvBuf.Len() > 0 {
			// bytes.Buffer.Read only fails when the buffer is empty.
			n, _ = s.recvBuf.Read(b)
			s.recvLock.Unlock()
			return n, s.sendWindowUpdate()
		}
		if remoteDone {
			// Withdraw b so a late frame is not written into the caller's slice.
			if dBuf != nil && s.waitingBuffer == dBuf {
				s.waitingBuffer = nil
			}
			s.recvLock.Unlock()
			return 0, io.EOF
		}
		if s.waitingBuffer == nil {
			dBuf = &directBuffer{buf: b}
			s.waitingBuffer = dBuf
		}
		s.recvLock.Unlock()

		// Block until notified or the deadline passes. A stoppable timer is
		// used because this wait may repeat several times within one Read.
		var timer *time.Timer
		var timeout <-chan time.Time
		if !s.readDeadline.IsZero() {
			timer = time.NewTimer(s.readDeadline.Sub(time.Now()))
			timeout = timer.C
		}
		select {
		case <-s.recvNotifyCh:
			if timer != nil {
				timer.Stop()
			}
			// Loop and re-check state, direct delivery and recvBuf.
		case <-timeout:
			s.recvLock.Lock()
			if dBuf != nil && dBuf.bytes > 0 {
				n = dBuf.bytes
				s.recvLock.Unlock()
				return n, s.sendWindowUpdate()
			}
			// The caller regains ownership of b on return, so readData
			// must no longer write into it.
			if dBuf != nil && s.waitingBuffer == dBuf {
				s.waitingBuffer = nil
			}
			s.recvLock.Unlock()
			return 0, ErrTimeout
		}
	}
}
  126. // Write is used to write to the stream
  127. func (s *Stream) Write(b []byte) (n int, err error) {
  128. s.sendLock.Lock()
  129. defer s.sendLock.Unlock()
  130. total := 0
  131. for total < len(b) {
  132. n, err := s.write(b[total:])
  133. total += n
  134. if err != nil {
  135. return total, err
  136. }
  137. }
  138. return total, nil
  139. }
  // write sends one data frame carrying as much of b as the send window
// currently allows, so it may return a short write.
//
// While the send window is zero it blocks until sendNotifyCh is signalled
// (incrSendWindow and notifyWaiting do this) or the write deadline passes.
// On deadline expiry it returns ErrTimeout. It returns ErrStreamClosed once
// the stream is in streamLocalClose or streamClosed. Write calls it with
// sendLock held, which serializes use of sendHdr and sendErr.
func (s *Stream) write(b []byte) (n int, err error) {
	for {
		s.stateLock.Lock()
		state := s.state
		s.stateLock.Unlock()
		if state == streamLocalClose || state == streamClosed {
			return 0, ErrStreamClosed
		}

		// Take one atomic snapshot of the window. incrSendWindow may grow it
		// concurrently, but only this sendLock holder shrinks it, so the
		// snapshot is a safe upper bound for this frame.
		if window := atomic.LoadUint32(&s.sendWindow); window > 0 {
			flags := s.sendFlags()
			size := min(window, uint32(len(b)))
			s.sendHdr.encode(typeData, flags, s.id, size)
			if err := s.session.waitForSendErr(s.sendHdr, bytes.NewReader(b[:size]), s.sendErr); err != nil {
				return 0, err
			}
			// Debit the window (adding ^(size-1) subtracts size).
			atomic.AddUint32(&s.sendWindow, ^uint32(size-1))
			return int(size), nil
		}

		// Window exhausted: wait for more credit, a close, or the deadline.
		var timer *time.Timer
		var timeout <-chan time.Time
		if !s.writeDeadline.IsZero() {
			timer = time.NewTimer(s.writeDeadline.Sub(time.Now()))
			timeout = timer.C
		}
		select {
		case <-s.sendNotifyCh:
			if timer != nil {
				timer.Stop()
			}
		case <-timeout:
			return 0, ErrTimeout
		}
	}
}
  188. // sendFlags determines any flags that are appropriate
  189. // based on the current stream state
  190. func (s *Stream) sendFlags() uint16 {
  191. s.stateLock.Lock()
  192. defer s.stateLock.Unlock()
  193. var flags uint16
  194. switch s.state {
  195. case streamInit:
  196. flags |= flagSYN
  197. s.state = streamSYNSent
  198. case streamSYNReceived:
  199. flags |= flagACK
  200. s.state = streamEstablished
  201. }
  202. return flags
  203. }
  // sendWindowUpdate potentially sends a window update enabling further
// writes to take place.
//
// The update advertises enough credit to refill recvWindow to the session's
// MaxStreamWindowSize, and it also carries any SYN/ACK returned by
// sendFlags. It is skipped when there are no flags to send and the refill
// would be less than half of the maximum window. controlHdrLock serializes
// it with sendClose, which shares controlHdr; no other lock is required.
func (s *Stream) sendWindowUpdate() error {
	s.controlHdrLock.Lock()
	defer s.controlHdrLock.Unlock()

	// readData debits recvWindow atomically from the receive path, so it
	// must be read atomically here as well.
	maxWindow := s.session.config.MaxStreamWindowSize
	delta := maxWindow - atomic.LoadUint32(&s.recvWindow)

	flags := s.sendFlags()
	if delta < maxWindow/2 && flags == 0 {
		return nil
	}

	// Credit the window before advertising it. The peer may send against
	// the new credit as soon as it sees the update, and readData must
	// already accept those bytes. The atomic add also keeps a concurrent
	// debit from readData, which a plain `+=` could overwrite.
	// NOTE(review): the credit is not rolled back if the send below fails.
	atomic.AddUint32(&s.recvWindow, delta)

	s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
	if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
		return err
	}
	return nil
}
  227. // sendClose is used to send a FIN
  228. func (s *Stream) sendClose() error {
  229. s.controlHdrLock.Lock()
  230. defer s.controlHdrLock.Unlock()
  231. flags := s.sendFlags()
  232. flags |= flagFIN
  233. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  234. if err := s.session.sendNoWait(s.controlHdr); err != nil {
  235. return err
  236. }
  237. return nil
  238. }
  239. // Close is used to close the stream
  240. func (s *Stream) Close() error {
  241. s.stateLock.Lock()
  242. switch s.state {
  243. // Opened means we need to signal a close
  244. case streamSYNSent:
  245. fallthrough
  246. case streamSYNReceived:
  247. fallthrough
  248. case streamEstablished:
  249. s.state = streamLocalClose
  250. goto SEND_CLOSE
  251. case streamLocalClose:
  252. case streamRemoteClose:
  253. s.state = streamClosed
  254. s.session.closeStream(s.id, false)
  255. goto SEND_CLOSE
  256. case streamClosed:
  257. default:
  258. panic("unhandled state")
  259. }
  260. s.stateLock.Unlock()
  261. return nil
  262. SEND_CLOSE:
  263. s.stateLock.Unlock()
  264. s.sendClose()
  265. s.notifyWaiting()
  266. return nil
  267. }
  268. // forceClose is used for when the session is exiting
  269. func (s *Stream) forceClose() {
  270. s.stateLock.Lock()
  271. s.state = streamClosed
  272. s.stateLock.Unlock()
  273. s.notifyWaiting()
  274. }
  275. // processFlags is used to update the state of the stream
  276. // based on set flags, if any. Lock must be held
  277. func (s *Stream) processFlags(flags uint16) error {
  278. s.stateLock.Lock()
  279. defer s.stateLock.Unlock()
  280. if flags&flagACK == flagACK {
  281. if s.state == streamSYNSent {
  282. s.state = streamEstablished
  283. }
  284. } else if flags&flagFIN == flagFIN {
  285. switch s.state {
  286. case streamSYNSent:
  287. fallthrough
  288. case streamSYNReceived:
  289. fallthrough
  290. case streamEstablished:
  291. s.state = streamRemoteClose
  292. s.notifyWaiting()
  293. case streamLocalClose:
  294. s.state = streamClosed
  295. s.session.closeStream(s.id, true)
  296. s.notifyWaiting()
  297. default:
  298. return ErrUnexpectedFlag
  299. }
  300. } else if flags&flagRST == flagRST {
  301. s.state = streamClosed
  302. s.session.closeStream(s.id, true)
  303. s.notifyWaiting()
  304. }
  305. return nil
  306. }
  307. // notifyWaiting notifies all the waiting channels
  308. func (s *Stream) notifyWaiting() {
  309. asyncNotify(s.recvNotifyCh)
  310. asyncNotify(s.sendNotifyCh)
  311. }
  // incrSendWindow handles a window-update frame. It applies hdr's flags and,
// if that succeeds, grows the send window by hdr.Length() and wakes a
// blocked writer. A flag error is returned without touching the window.
func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
	err := s.processFlags(flags)
	if err == nil {
		atomic.AddUint32(&s.sendWindow, hdr.Length())
		asyncNotify(s.sendNotifyCh)
	}
	return err
}
  // readData handles an incoming data frame for this stream.
//
// It first applies the frame's flags via processFlags. It then checks
// hdr.Length() against recvWindow, debits the window, and consumes exactly
// that many bytes from conn. If a Read has registered a waitingBuffer, the
// first chunk goes straight into it (a single Read call, so possibly short)
// and the remainder is appended to recvBuf. Blocked readers are then
// notified.
//
// Errors:
//   - ErrRecvWindowExceeded if the frame is larger than the receive window.
//   - io.ErrUnexpectedEOF if conn ends before the whole frame was read.
//   - Any other error returned by conn.
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
	if err := s.processFlags(flags); err != nil {
		return err
	}

	length := hdr.Length()
	if length == 0 {
		return nil
	}
	if length > atomic.LoadUint32(&s.recvWindow) {
		return ErrRecvWindowExceeded
	}
	// Debit the window (adding ^(length-1) subtracts length).
	atomic.AddUint32(&s.recvWindow, ^uint32(length-1))

	frame := &io.LimitedReader{R: conn, N: int64(length)}
	err := func() error {
		s.recvLock.Lock()
		defer s.recvLock.Unlock()

		if dBuf := s.waitingBuffer; dBuf != nil {
			n, err := frame.Read(dBuf.buf)
			dBuf.bytes = n
			s.waitingBuffer = nil
			if err != nil {
				return err
			}
		}
		if _, err := io.Copy(&s.recvBuf, frame); err != nil {
			return err
		}
		// io.Copy treats EOF as success, so a truncated frame shows up only
		// as bytes still left in the limit.
		if frame.N > 0 {
			return io.ErrUnexpectedEOF
		}
		return nil
	}()

	// Wake readers even on failure: part of the frame may already have been
	// delivered into a waiting buffer.
	asyncNotify(s.recvNotifyCh)
	return err
}
  // SetDeadline applies t as both the read and the write deadline. The
// read deadline is set first, and its error (if any) stops the call before
// the write deadline is touched.
func (s *Stream) SetDeadline(t time.Time) error {
	if err := s.SetReadDeadline(t); err != nil {
		return err
	}
	return s.SetWriteDeadline(t)
}
  // SetReadDeadline sets the deadline for Read calls; a zero t means no
// deadline. It always returns nil.
//
// Read computes its timer only when it starts waiting, so a Read that is
// already blocked would otherwise keep its old deadline. Signalling
// recvNotifyCh wakes it; Read treats that as an ordinary wake-up, re-checks
// the stream, and re-arms its timer with the new value.
// NOTE(review): readDeadline itself is still written without a lock.
func (s *Stream) SetReadDeadline(t time.Time) error {
	s.readDeadline = t
	asyncNotify(s.recvNotifyCh)
	return nil
}
  // SetWriteDeadline sets the deadline for Write calls; a zero t means no
// deadline. It always returns nil.
//
// write computes its timer only when it starts waiting for send window, so
// a write that is already blocked would otherwise keep its old deadline.
// Signalling sendNotifyCh wakes it; write re-checks state and window, then
// re-arms its timer with the new value.
// NOTE(review): writeDeadline itself is still written without a lock.
func (s *Stream) SetWriteDeadline(t time.Time) error {
	s.writeDeadline = t
	asyncNotify(s.sendNotifyCh)
	return nil
}