stream.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. package yamux
  2. import (
  3. "bytes"
  4. "compress/lzw"
  5. "io"
  6. "log"
  7. "net"
  8. "sync"
  9. "time"
  10. )
  // streamState identifies where a stream currently is in its
  // open/close lifecycle.
  type streamState int

  // Lifecycle states, numbered from zero in declaration order.
  const (
  	// streamInit: nothing has been sent yet; the first outgoing
  	// frame will carry SYN (see sendFlags).
  	streamInit streamState = iota

  	// streamSYNSent: our SYN went out; an incoming ACK moves the
  	// stream to streamEstablished (see processFlags).
  	streamSYNSent

  	// streamSYNReceived: the next outgoing frame will carry ACK
  	// (see sendFlags).
  	streamSYNReceived

  	// streamEstablished: the handshake is complete.
  	streamEstablished

  	// streamLocalClose: Close was called locally; writes are refused.
  	streamLocalClose

  	// streamRemoteClose: the peer sent FIN; reads return io.EOF once
  	// the receive buffer is drained.
  	streamRemoteClose

  	// streamClosed: the stream is closed in both directions.
  	streamClosed
  )
  21. // Stream is used to represent a logical stream
  22. // within a session.
  23. type Stream struct {
  24. id uint32
  25. session *Session
  26. state streamState
  27. lock sync.Mutex
  28. recvBuf bytes.Buffer
  29. sendHdr header
  30. recvWindow uint32
  31. sendWindow uint32
  32. notifyCh chan struct{}
  33. readDeadline time.Time
  34. writeDeadline time.Time
  35. }
  36. // newStream is used to construct a new stream within
  37. // a given session for an ID
  38. func newStream(session *Session, id uint32, state streamState) *Stream {
  39. s := &Stream{
  40. id: id,
  41. session: session,
  42. state: state,
  43. recvWindow: initialStreamWindow,
  44. sendWindow: initialStreamWindow,
  45. notifyCh: make(chan struct{}, 1),
  46. sendHdr: header(make([]byte, headerSize)),
  47. }
  48. return s
  49. }
  50. // Session returns the associated stream session
  51. func (s *Stream) Session() *Session {
  52. return s.session
  53. }
  54. // StreamID returns the ID of this stream
  55. func (s *Stream) StreamID() uint32 {
  56. return s.id
  57. }
  58. // Read is used to read from the stream
  59. func (s *Stream) Read(b []byte) (n int, err error) {
  60. START:
  61. s.lock.Lock()
  62. switch s.state {
  63. case streamRemoteClose:
  64. fallthrough
  65. case streamClosed:
  66. if s.recvBuf.Len() == 0 {
  67. s.lock.Unlock()
  68. return 0, io.EOF
  69. }
  70. }
  71. // If there is no data available, block
  72. if s.recvBuf.Len() == 0 {
  73. s.lock.Unlock()
  74. goto WAIT
  75. }
  76. // Read any bytes
  77. n, _ = s.recvBuf.Read(b)
  78. // Send a window update potentially
  79. err = s.sendWindowUpdate()
  80. s.lock.Unlock()
  81. return n, err
  82. WAIT:
  83. var timeout <-chan time.Time
  84. if !s.readDeadline.IsZero() {
  85. delay := s.readDeadline.Sub(time.Now())
  86. timeout = time.After(delay)
  87. }
  88. select {
  89. case <-s.notifyCh:
  90. goto START
  91. case <-timeout:
  92. return 0, ErrTimeout
  93. }
  94. }
  95. // Write is used to write to the stream
  96. func (s *Stream) Write(b []byte) (n int, err error) {
  97. total := 0
  98. for total < len(b) {
  99. n, err := s.write(b[total:])
  100. total += n
  101. if err != nil {
  102. return total, err
  103. }
  104. }
  105. return total, nil
  106. }
  107. // write is used to write to the stream, may return on
  108. // a short write.
  109. func (s *Stream) write(b []byte) (n int, err error) {
  110. var flags uint16
  111. var max uint32
  112. var body io.Reader
  113. START:
  114. s.lock.Lock()
  115. switch s.state {
  116. case streamLocalClose:
  117. fallthrough
  118. case streamClosed:
  119. s.lock.Unlock()
  120. return 0, ErrStreamClosed
  121. }
  122. // If there is no data available, block
  123. if s.sendWindow == 0 {
  124. s.lock.Unlock()
  125. goto WAIT
  126. }
  127. // Determine the flags if any
  128. flags = s.sendFlags()
  129. // Send up to our send window
  130. max = min(s.sendWindow, uint32(len(b)))
  131. body = bytes.NewReader(b[:max])
  132. // TODO: Compress
  133. // Send the header
  134. s.sendHdr.encode(typeData, flags, s.id, max)
  135. if err := s.session.waitForSend(s.sendHdr, body); err != nil {
  136. s.lock.Unlock()
  137. return 0, err
  138. }
  139. // Reduce our send window
  140. s.sendWindow -= max
  141. // Unlock
  142. s.lock.Unlock()
  143. return int(max), err
  144. WAIT:
  145. var timeout <-chan time.Time
  146. if !s.writeDeadline.IsZero() {
  147. delay := s.writeDeadline.Sub(time.Now())
  148. timeout = time.After(delay)
  149. }
  150. select {
  151. case <-s.notifyCh:
  152. goto START
  153. case <-timeout:
  154. return 0, ErrTimeout
  155. }
  156. return 0, nil
  157. }
  158. // sendFlags determines any flags that are appropriate
  159. // based on the current stream state
  160. func (s *Stream) sendFlags() uint16 {
  161. // Determine the flags if any
  162. var flags uint16
  163. switch s.state {
  164. case streamInit:
  165. flags |= flagSYN
  166. s.state = streamSYNSent
  167. case streamSYNReceived:
  168. flags |= flagACK
  169. s.state = streamEstablished
  170. }
  171. return flags
  172. }
  173. // sendWindowUpdate potentially sends a window update enabling
  174. // further writes to take place. Must be invoked with the lock.
  175. func (s *Stream) sendWindowUpdate() error {
  176. // Determine the delta update
  177. max := s.session.config.MaxStreamWindowSize
  178. delta := max - s.recvWindow
  179. // Determine the flags if any
  180. flags := s.sendFlags()
  181. // Check if we can omit the update
  182. if delta < (max/2) && flags == 0 {
  183. return nil
  184. }
  185. // Send the header
  186. s.sendHdr.encode(typeWindowUpdate, flags, s.id, delta)
  187. if err := s.session.waitForSend(s.sendHdr, nil); err != nil {
  188. return err
  189. }
  190. log.Printf("Window Update %d +%d", s.id, delta)
  191. // Update our window
  192. s.recvWindow += delta
  193. return nil
  194. }
  195. // sendClose is used to send a FIN
  196. func (s *Stream) sendClose() error {
  197. flags := s.sendFlags()
  198. flags |= flagFIN
  199. s.sendHdr.encode(typeWindowUpdate, flags, s.id, 0)
  200. if err := s.session.waitForSend(s.sendHdr, nil); err != nil {
  201. return err
  202. }
  203. return nil
  204. }
  205. // Close is used to close the stream
  206. func (s *Stream) Close() error {
  207. s.lock.Lock()
  208. defer s.lock.Unlock()
  209. switch s.state {
  210. // Local or full close means nothing to do
  211. case streamLocalClose:
  212. fallthrough
  213. case streamClosed:
  214. return nil
  215. // Remote close, weneed to send FIN and we are done
  216. case streamRemoteClose:
  217. s.state = streamClosed
  218. s.session.closeStream(s.id, false)
  219. s.sendClose()
  220. return nil
  221. // Opened means we need to signal a close
  222. case streamSYNSent:
  223. fallthrough
  224. case streamSYNReceived:
  225. fallthrough
  226. case streamEstablished:
  227. s.state = streamLocalClose
  228. s.sendClose()
  229. return nil
  230. }
  231. panic("unhandled state")
  232. }
  233. // forceClose is used for when the session is exiting
  234. func (s *Stream) forceClose() {
  235. s.lock.Lock()
  236. defer s.lock.Unlock()
  237. s.state = streamClosed
  238. asyncNotify(s.notifyCh)
  239. }
  240. // LocalAddr returns the local address
  241. func (s *Stream) LocalAddr() net.Addr {
  242. return s.session.LocalAddr()
  243. }
  244. // LocalAddr returns the remote address
  245. func (s *Stream) RemoteAddr() net.Addr {
  246. return s.session.RemoteAddr()
  247. }
  248. // SetDeadline sets the read and write deadlines
  249. func (s *Stream) SetDeadline(t time.Time) error {
  250. if err := s.SetReadDeadline(t); err != nil {
  251. return err
  252. }
  253. if err := s.SetWriteDeadline(t); err != nil {
  254. return err
  255. }
  256. return nil
  257. }
  258. // SetReadDeadline sets the deadline for future Read calls.
  259. func (s *Stream) SetReadDeadline(t time.Time) error {
  260. s.readDeadline = t
  261. return nil
  262. }
  263. // SetWriteDeadline sets the deadline for future Write calls
  264. func (s *Stream) SetWriteDeadline(t time.Time) error {
  265. s.writeDeadline = t
  266. return nil
  267. }
  268. // processFlags is used to update the state of the stream
  269. // based on set flags, if any. Lock must be held
  270. func (s *Stream) processFlags(flags uint16) error {
  271. if flags&flagACK == flagACK {
  272. if s.state == streamSYNSent {
  273. s.state = streamEstablished
  274. }
  275. } else if flags&flagFIN == flagFIN {
  276. switch s.state {
  277. case streamSYNSent:
  278. fallthrough
  279. case streamSYNReceived:
  280. fallthrough
  281. case streamEstablished:
  282. s.state = streamRemoteClose
  283. case streamLocalClose:
  284. s.state = streamClosed
  285. s.session.closeStream(s.id, true)
  286. default:
  287. return ErrUnexpectedFlag
  288. }
  289. } else if flags&flagRST == flagRST {
  290. s.state = streamClosed
  291. s.session.closeStream(s.id, true)
  292. }
  293. return nil
  294. }
  295. // incrSendWindow updates the size of our send window
  296. func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
  297. s.lock.Lock()
  298. defer s.lock.Unlock()
  299. if err := s.processFlags(flags); err != nil {
  300. return err
  301. }
  302. // Increase window, unblock a sender
  303. s.sendWindow += hdr.Length()
  304. asyncNotify(s.notifyCh)
  305. return nil
  306. }
  307. // readData is used to handle a data frame
  308. func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
  309. s.lock.Lock()
  310. defer s.lock.Unlock()
  311. if err := s.processFlags(flags); err != nil {
  312. return err
  313. }
  314. // Check that our recv window is not exceeded
  315. length := hdr.Length()
  316. if length > s.recvWindow {
  317. return ErrRecvWindowExceeded
  318. }
  319. // Decrement the receive window
  320. s.recvWindow -= length
  321. // Wrap in a limited reader
  322. conn = &io.LimitedReader{R: conn, N: int64(length)}
  323. // Handle potential data compression
  324. if flags&flagLZW == flagLZW {
  325. cr := lzw.NewReader(conn, lzw.MSB, 8)
  326. defer cr.Close()
  327. conn = cr
  328. }
  329. // Copy to our buffer
  330. if _, err := io.Copy(&s.recvBuf, conn); err != nil {
  331. return err
  332. }
  333. // Unblock any readers
  334. asyncNotify(s.notifyCh)
  335. return nil
  336. }