stream.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. package yamux
  2. import (
  3. "bytes"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. )
  // streamState is the lifecycle position of a Stream. The values come from
  // iota, so the declaration order below fixes their numeric values.
  type streamState int

  const (
  	// streamInit: no SYN sent yet; sendFlags attaches SYN to the first frame.
  	streamInit streamState = iota
  	// streamSYNSent: SYN has been sent; an incoming ACK moves to established.
  	streamSYNSent
  	// streamSYNReceived: sendFlags will attach ACK to our next frame.
  	streamSYNReceived
  	// streamEstablished: the SYN/ACK exchange has completed.
  	streamEstablished
  	// streamLocalClose: we sent FIN; further writes are refused.
  	streamLocalClose
  	// streamRemoteClose: peer sent FIN; reads drain the buffer, then io.EOF.
  	streamRemoteClose
  	// streamClosed: both directions closed, or the stream was forced/reset.
  	streamClosed
  )
  19. // Stream is used to represent a logical stream
  20. // within a session.
  21. type Stream struct {
  22. recvWindow uint32
  23. sendWindow uint32
  24. id uint32
  25. session *Session
  26. state streamState
  27. stateLock sync.Mutex
  28. recvBuf bytes.Buffer
  29. recvLock sync.Mutex
  30. controlHdr header
  31. controlHdrLock sync.Mutex
  32. sendHdr header
  33. sendLock sync.Mutex
  34. notifyCh chan struct{}
  35. readDeadline time.Time
  36. writeDeadline time.Time
  37. }
  38. // newStream is used to construct a new stream within
  39. // a given session for an ID
  40. func newStream(session *Session, id uint32, state streamState) *Stream {
  41. s := &Stream{
  42. id: id,
  43. session: session,
  44. state: state,
  45. controlHdr: header(make([]byte, headerSize)),
  46. sendHdr: header(make([]byte, headerSize)),
  47. recvWindow: initialStreamWindow,
  48. sendWindow: initialStreamWindow,
  49. notifyCh: make(chan struct{}, 1),
  50. }
  51. return s
  52. }
  53. // Session returns the associated stream session
  54. func (s *Stream) Session() *Session {
  55. return s.session
  56. }
  57. // StreamID returns the ID of this stream
  58. func (s *Stream) StreamID() uint32 {
  59. return s.id
  60. }
  61. // Read is used to read from the stream
  62. func (s *Stream) Read(b []byte) (n int, err error) {
  63. START:
  64. s.stateLock.Lock()
  65. switch s.state {
  66. case streamRemoteClose:
  67. fallthrough
  68. case streamClosed:
  69. if s.recvBuf.Len() == 0 {
  70. s.stateLock.Unlock()
  71. return 0, io.EOF
  72. }
  73. }
  74. s.stateLock.Unlock()
  75. // If there is no data available, block
  76. s.recvLock.Lock()
  77. if s.recvBuf.Len() == 0 {
  78. s.recvLock.Unlock()
  79. goto WAIT
  80. }
  81. // Read any bytes
  82. n, _ = s.recvBuf.Read(b)
  83. s.recvLock.Unlock()
  84. // Send a window update potentially
  85. err = s.sendWindowUpdate()
  86. return n, err
  87. WAIT:
  88. var timeout <-chan time.Time
  89. if !s.readDeadline.IsZero() {
  90. delay := s.readDeadline.Sub(time.Now())
  91. timeout = time.After(delay)
  92. }
  93. select {
  94. case <-s.notifyCh:
  95. goto START
  96. case <-timeout:
  97. return 0, ErrTimeout
  98. }
  99. }
  100. // Write is used to write to the stream
  101. func (s *Stream) Write(b []byte) (n int, err error) {
  102. s.sendLock.Lock()
  103. defer s.sendLock.Unlock()
  104. total := 0
  105. for total < len(b) {
  106. n, err := s.write(b[total:])
  107. total += n
  108. if err != nil {
  109. return total, err
  110. }
  111. }
  112. return total, nil
  113. }
  114. // write is used to write to the stream, may return on
  115. // a short write.
  116. func (s *Stream) write(b []byte) (n int, err error) {
  117. var flags uint16
  118. var max uint32
  119. var body io.Reader
  120. START:
  121. s.stateLock.Lock()
  122. switch s.state {
  123. case streamLocalClose:
  124. fallthrough
  125. case streamClosed:
  126. s.stateLock.Unlock()
  127. return 0, ErrStreamClosed
  128. }
  129. s.stateLock.Unlock()
  130. // If there is no data available, block
  131. if atomic.LoadUint32(&s.sendWindow) == 0 {
  132. goto WAIT
  133. }
  134. // Determine the flags if any
  135. flags = s.sendFlags()
  136. // Send up to our send window
  137. max = min(s.sendWindow, uint32(len(b)))
  138. body = bytes.NewReader(b[:max])
  139. // Send the header
  140. s.sendHdr.encode(typeData, flags, s.id, max)
  141. if err := s.session.waitForSend(s.sendHdr, body); err != nil {
  142. return 0, err
  143. }
  144. // Reduce our send window
  145. atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
  146. // Unlock
  147. return int(max), err
  148. WAIT:
  149. var timeout <-chan time.Time
  150. if !s.writeDeadline.IsZero() {
  151. delay := s.writeDeadline.Sub(time.Now())
  152. timeout = time.After(delay)
  153. }
  154. select {
  155. case <-s.notifyCh:
  156. goto START
  157. case <-timeout:
  158. return 0, ErrTimeout
  159. }
  160. return 0, nil
  161. }
  162. // sendFlags determines any flags that are appropriate
  163. // based on the current stream state
  164. func (s *Stream) sendFlags() uint16 {
  165. s.stateLock.Lock()
  166. defer s.stateLock.Unlock()
  167. var flags uint16
  168. switch s.state {
  169. case streamInit:
  170. flags |= flagSYN
  171. s.state = streamSYNSent
  172. case streamSYNReceived:
  173. flags |= flagACK
  174. s.state = streamEstablished
  175. }
  176. return flags
  177. }
  178. // sendWindowUpdate potentially sends a window update enabling
  179. // further writes to take place. Must be invoked with the lock.
  180. func (s *Stream) sendWindowUpdate() error {
  181. s.controlHdrLock.Lock()
  182. defer s.controlHdrLock.Unlock()
  183. // Determine the delta update
  184. max := s.session.config.MaxStreamWindowSize
  185. delta := max - s.recvWindow
  186. // Determine the flags if any
  187. flags := s.sendFlags()
  188. // Check if we can omit the update
  189. if delta < (max/2) && flags == 0 {
  190. return nil
  191. }
  192. // Send the header
  193. s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
  194. if err := s.session.waitForSend(s.controlHdr, nil); err != nil {
  195. return err
  196. }
  197. // Update our window
  198. s.recvWindow += delta
  199. return nil
  200. }
  201. // sendClose is used to send a FIN
  202. func (s *Stream) sendClose() error {
  203. s.controlHdrLock.Lock()
  204. defer s.controlHdrLock.Unlock()
  205. flags := s.sendFlags()
  206. flags |= flagFIN
  207. s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
  208. if err := s.session.sendNoWait(s.controlHdr); err != nil {
  209. return err
  210. }
  211. return nil
  212. }
  213. // Close is used to close the stream
  214. func (s *Stream) Close() error {
  215. s.stateLock.Lock()
  216. switch s.state {
  217. // Opened means we need to signal a close
  218. case streamSYNSent:
  219. fallthrough
  220. case streamSYNReceived:
  221. fallthrough
  222. case streamEstablished:
  223. s.state = streamLocalClose
  224. goto SEND_CLOSE
  225. case streamLocalClose:
  226. case streamRemoteClose:
  227. s.state = streamClosed
  228. s.session.closeStream(s.id, false)
  229. goto SEND_CLOSE
  230. case streamClosed:
  231. default:
  232. panic("unhandled state")
  233. }
  234. s.stateLock.Unlock()
  235. return nil
  236. SEND_CLOSE:
  237. s.stateLock.Unlock()
  238. s.sendClose()
  239. return nil
  240. }
  241. // forceClose is used for when the session is exiting
  242. func (s *Stream) forceClose() {
  243. s.stateLock.Lock()
  244. s.state = streamClosed
  245. s.stateLock.Unlock()
  246. asyncNotify(s.notifyCh)
  247. }
  248. // processFlags is used to update the state of the stream
  249. // based on set flags, if any. Lock must be held
  250. func (s *Stream) processFlags(flags uint16) error {
  251. s.stateLock.Lock()
  252. defer s.stateLock.Unlock()
  253. if flags&flagACK == flagACK {
  254. if s.state == streamSYNSent {
  255. s.state = streamEstablished
  256. }
  257. } else if flags&flagFIN == flagFIN {
  258. switch s.state {
  259. case streamSYNSent:
  260. fallthrough
  261. case streamSYNReceived:
  262. fallthrough
  263. case streamEstablished:
  264. s.state = streamRemoteClose
  265. case streamLocalClose:
  266. s.state = streamClosed
  267. s.session.closeStream(s.id, true)
  268. default:
  269. return ErrUnexpectedFlag
  270. }
  271. } else if flags&flagRST == flagRST {
  272. s.state = streamClosed
  273. s.session.closeStream(s.id, true)
  274. }
  275. return nil
  276. }
  277. // incrSendWindow updates the size of our send window
  278. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  279. if err := s.processFlags(flags); err != nil {
  280. return err
  281. }
  282. // Increase window, unblock a sender
  283. atomic.AddUint32(&s.sendWindow, hdr.Length())
  284. asyncNotify(s.notifyCh)
  285. return nil
  286. }
  287. // readData is used to handle a data frame
  288. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  289. if err := s.processFlags(flags); err != nil {
  290. return err
  291. }
  292. // Check that our recv window is not exceeded
  293. length := hdr.Length()
  294. if length == 0 {
  295. return nil
  296. }
  297. if length > atomic.LoadUint32(&s.recvWindow) {
  298. return ErrRecvWindowExceeded
  299. }
  300. // Decrement the receive window
  301. atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
  302. // Wrap in a limited reader
  303. conn = &io.LimitedReader{R: conn, N: int64(length)}
  304. // Copy to our buffer
  305. s.recvLock.Lock()
  306. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  307. return err
  308. }
  309. s.recvLock.Unlock()
  310. // Unblock any readers
  311. asyncNotify(s.notifyCh)
  312. return nil
  313. }
  314. // SetDeadline sets the read and write deadlines
  315. func (s *Stream) SetDeadline(t time.Time) error {
  316. if err := s.SetReadDeadline(t); err != nil {
  317. return err
  318. }
  319. if err := s.SetWriteDeadline(t); err != nil {
  320. return err
  321. }
  322. return nil
  323. }
  324. // SetReadDeadline sets the deadline for future Read calls.
  325. func (s *Stream) SetReadDeadline(t time.Time) error {
  326. s.readDeadline = t
  327. return nil
  328. }
  329. // SetWriteDeadline sets the deadline for future Write calls
  330. func (s *Stream) SetWriteDeadline(t time.Time) error {
  331. s.writeDeadline = t
  332. return nil
  333. }