package yamux

import (
	"bytes"
	"io"
	"sync"
	"sync/atomic"
	"time"
)

type streamState int

const (
	streamInit streamState = iota
	streamSYNSent
	streamSYNReceived
	streamEstablished
	streamLocalClose
	streamRemoteClose
	streamClosed
	streamReset
)

// Stream is used to represent a logical stream
// within a session.
type Stream struct {
	recvWindow uint32
	sendWindow uint32

	id      uint32
	session *Session

	state     streamState
	stateLock sync.Mutex

	recvBuf  *bytes.Buffer
	recvLock sync.Mutex

	controlHdr     header
	controlErr     chan error
	controlHdrLock sync.Mutex

	sendHdr  header
	sendErr  chan error
	sendLock sync.Mutex

	recvNotifyCh chan struct{}
	sendNotifyCh chan struct{}

	readDeadline  atomic.Value // time.Time
	writeDeadline atomic.Value // time.Time

	// closeTimer is set with stateLock held to honor the StreamCloseTimeout
	// setting on Session.
	closeTimer *time.Timer
}

// newStream is used to construct a new stream within
// a given session for an ID
func newStream(session *Session, id uint32, state streamState) *Stream {
	s := &Stream{
		id:           id,
		session:      session,
		state:        state,
		controlHdr:   header(make([]byte, headerSize)),
		controlErr:   make(chan error, 1),
		sendHdr:      header(make([]byte, headerSize)),
		sendErr:      make(chan error, 1),
		recvWindow:   initialStreamWindow,
		sendWindow:   initialStreamWindow,
		recvNotifyCh: make(chan struct{}, 1),
		sendNotifyCh: make(chan struct{}, 1),
	}
	s.readDeadline.Store(time.Time{})
	s.writeDeadline.Store(time.Time{})
	return s
}
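
// Illustrative usage sketch, not part of the implementation: callers do not
// construct streams directly. A *Stream is normally obtained from a Session
// (for example via OpenStream on the initiating side or AcceptStream on the
// accepting side, assuming those Session methods) and is then used much like
// a net.Conn:
//
//	stream, err := session.OpenStream()
//	if err != nil {
//		return err
//	}
//	defer stream.Close()
//	if _, err := stream.Write([]byte("ping")); err != nil {
//		return err
//	}
//	buf := make([]byte, 4)
//	if _, err := io.ReadFull(stream, buf); err != nil {
//		return err
//	}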

// Session returns the associated stream session
func (s *Stream) Session() *Session {
	return s.session
}

// StreamID returns the ID of this stream
func (s *Stream) StreamID() uint32 {
	return s.id
}

// Read is used to read from the stream
func (s *Stream) Read(b []byte) (n int, err error) {
	defer asyncNotify(s.recvNotifyCh)
START:
	s.stateLock.Lock()
	switch s.state {
	case streamLocalClose:
		fallthrough
	case streamRemoteClose:
		fallthrough
	case streamClosed:
		s.recvLock.Lock()
		if s.recvBuf == nil || s.recvBuf.Len() == 0 {
			s.recvLock.Unlock()
			s.stateLock.Unlock()
			return 0, io.EOF
		}
		s.recvLock.Unlock()
	case streamReset:
		s.stateLock.Unlock()
		return 0, ErrConnectionReset
	}
	s.stateLock.Unlock()

	// If there is no data available, block
	s.recvLock.Lock()
	if s.recvBuf == nil || s.recvBuf.Len() == 0 {
		s.recvLock.Unlock()
		goto WAIT
	}

	// Read any bytes
	n, _ = s.recvBuf.Read(b)
	s.recvLock.Unlock()

	// Send a window update potentially
	err = s.sendWindowUpdate()
	return n, err

WAIT:
	var timeout <-chan time.Time
	var timer *time.Timer
	readDeadline := s.readDeadline.Load().(time.Time)
	if !readDeadline.IsZero() {
		delay := readDeadline.Sub(time.Now())
		timer = time.NewTimer(delay)
		timeout = timer.C
	}
	select {
	case <-s.recvNotifyCh:
		if timer != nil {
			timer.Stop()
		}
		goto START
	case <-timeout:
		return 0, ErrTimeout
	}
}
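
// Illustrative sketch, not part of the implementation: a blocked Read can be
// bounded with a deadline, in which case it returns ErrTimeout once the
// deadline passes:
//
//	stream.SetReadDeadline(time.Now().Add(5 * time.Second))
//	n, err := stream.Read(buf)
//	if err == ErrTimeout {
//		// no data arrived within five seconds
//	}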

// Write is used to write to the stream
func (s *Stream) Write(b []byte) (n int, err error) {
	s.sendLock.Lock()
	defer s.sendLock.Unlock()
	total := 0
	for total < len(b) {
		n, err := s.write(b[total:])
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}

// write is used to write to the stream, may return on
// a short write.
func (s *Stream) write(b []byte) (n int, err error) {
	var flags uint16
	var max uint32
	var body io.Reader
START:
	s.stateLock.Lock()
	switch s.state {
	case streamLocalClose:
		fallthrough
	case streamClosed:
		s.stateLock.Unlock()
		return 0, ErrStreamClosed
	case streamReset:
		s.stateLock.Unlock()
		return 0, ErrConnectionReset
	}
	s.stateLock.Unlock()

	// If there is no room in the send window, block
	window := atomic.LoadUint32(&s.sendWindow)
	if window == 0 {
		goto WAIT
	}

	// Determine the flags if any
	flags = s.sendFlags()

	// Send up to our send window
	max = min(window, uint32(len(b)))
	body = bytes.NewReader(b[:max])

	// Send the header
	s.sendHdr.encode(typeData, flags, s.id, max)
	if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
		return 0, err
	}

	// Reduce our send window by max. Adding the two's complement
	// (^uint32(max-1) == -max) performs an atomic subtraction.
	atomic.AddUint32(&s.sendWindow, ^uint32(max-1))

	return int(max), err

WAIT:
	var timeout <-chan time.Time
	writeDeadline := s.writeDeadline.Load().(time.Time)
	if !writeDeadline.IsZero() {
		delay := writeDeadline.Sub(time.Now())
		timeout = time.After(delay)
	}
	select {
	case <-s.sendNotifyCh:
		goto START
	case <-timeout:
		return 0, ErrTimeout
	}
	return 0, nil
}
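
// Worked example, illustrative only and assuming the default 256 KiB initial
// window: with sendWindow == 262144, a 1 MiB Write is split by write() into
// frames of at most 262144 bytes. After the first frame the window reaches 0
// and write() blocks in WAIT until the peer's window update arrives via
// incrSendWindow, which re-credits sendWindow and signals sendNotifyCh.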

// sendFlags determines any flags that are appropriate
// based on the current stream state
func (s *Stream) sendFlags() uint16 {
	s.stateLock.Lock()
	defer s.stateLock.Unlock()
	var flags uint16
	switch s.state {
	case streamInit:
		flags |= flagSYN
		s.state = streamSYNSent
	case streamSYNReceived:
		flags |= flagACK
		s.state = streamEstablished
	}
	return flags
}

// sendWindowUpdate potentially sends a window update enabling
// further writes to take place. It acquires the control header
// and receive locks itself.
func (s *Stream) sendWindowUpdate() error {
	s.controlHdrLock.Lock()
	defer s.controlHdrLock.Unlock()

	// Determine the delta update
	max := s.session.config.MaxStreamWindowSize
	var bufLen uint32
	s.recvLock.Lock()
	if s.recvBuf != nil {
		bufLen = uint32(s.recvBuf.Len())
	}
	delta := (max - bufLen) - s.recvWindow

	// Determine the flags if any
	flags := s.sendFlags()

	// Check if we can omit the update
	if delta < (max/2) && flags == 0 {
		s.recvLock.Unlock()
		return nil
	}

	// Update our window
	s.recvWindow += delta
	s.recvLock.Unlock()

	// Send the header
	s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
	if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
		return err
	}
	return nil
}
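
// Worked example, illustrative only and assuming MaxStreamWindowSize is the
// default 256 KiB: if the reader has drained the buffer (bufLen == 0) after
// 100 KiB of the window was consumed (recvWindow == 156 KiB), then
// delta == 256 KiB - 0 - 156 KiB == 100 KiB. Since that is below max/2
// (128 KiB) and no flags are pending, the update is skipped; once more than
// half the window has been consumed, a typeWindowUpdate frame re-credits it.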

// sendClose is used to send a FIN
func (s *Stream) sendClose() error {
	s.controlHdrLock.Lock()
	defer s.controlHdrLock.Unlock()

	flags := s.sendFlags()
	flags |= flagFIN
	s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
	if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
		return err
	}
	return nil
}

// Close is used to close the stream
func (s *Stream) Close() error {
	closeStream := false
	s.stateLock.Lock()
	switch s.state {
	// Opened means we need to signal a close
	case streamSYNSent:
		fallthrough
	case streamSYNReceived:
		fallthrough
	case streamEstablished:
		s.state = streamLocalClose
		goto SEND_CLOSE

	case streamLocalClose:
	case streamRemoteClose:
		s.state = streamClosed
		closeStream = true
		goto SEND_CLOSE

	case streamClosed:
	case streamReset:
	default:
		panic("unhandled state")
	}
	s.stateLock.Unlock()
	return nil

SEND_CLOSE:
	// This shouldn't happen (the more realistic scenario to cancel the
	// timer is via processFlags) but just in case this ever happens, we
	// cancel the timer to prevent dangling timers.
	if s.closeTimer != nil {
		s.closeTimer.Stop()
		s.closeTimer = nil
	}

	// If we have a StreamCloseTimeout set we start the timeout timer.
	// We do this only if we're not already closing the stream since that
	// means this was a graceful close.
	//
	// This prevents memory leaks if one side (this side) closes and the
	// remote side poorly behaves and never responds with a FIN to complete
	// the close. After the specified timeout, we clean our resources up no
	// matter what.
	if !closeStream && s.session.config.StreamCloseTimeout > 0 {
		s.closeTimer = time.AfterFunc(
			s.session.config.StreamCloseTimeout, s.closeTimeout)
	}

	s.stateLock.Unlock()
	s.sendClose()
	s.notifyWaiting()
	if closeStream {
		s.session.closeStream(s.id)
	}
	return nil
}
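
// Close handshake sketch, derived from the states above: a graceful shutdown
// takes one FIN in each direction. The side that calls Close first moves to
// streamLocalClose and sends a FIN; the peer's processFlags sees that FIN,
// moves to streamRemoteClose, and its readers drain any buffered data before
// returning io.EOF. When the peer calls Close in turn, both sides reach
// streamClosed and the stream is removed from the session.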

// closeTimeout is called after StreamCloseTimeout during a close to
// close this stream.
func (s *Stream) closeTimeout() {
	// Close our side forcibly
	s.forceClose()

	// Free the stream from the session map
	s.session.closeStream(s.id)

	// Send a RST so the remote side closes too.
	s.sendLock.Lock()
	defer s.sendLock.Unlock()
	s.sendHdr.encode(typeWindowUpdate, flagRST, s.id, 0)
	s.session.sendNoWait(s.sendHdr)
}

// forceClose is used for when the session is exiting
func (s *Stream) forceClose() {
	s.stateLock.Lock()
	s.state = streamClosed
	s.stateLock.Unlock()
	s.notifyWaiting()
}

// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16) error {
	// Close the stream without holding the state lock. The deferred
	// function is registered before the lock is taken so that it runs
	// after the deferred Unlock below.
	closeStream := false
	defer func() {
		if closeStream {
			if s.closeTimer != nil {
				// Stop our close timeout timer since we gracefully closed
				s.closeTimer.Stop()
			}
			s.session.closeStream(s.id)
		}
	}()

	s.stateLock.Lock()
	defer s.stateLock.Unlock()
	if flags&flagACK == flagACK {
		if s.state == streamSYNSent {
			s.state = streamEstablished
		}
		s.session.establishStream(s.id)
	}
	if flags&flagFIN == flagFIN {
		switch s.state {
		case streamSYNSent:
			fallthrough
		case streamSYNReceived:
			fallthrough
		case streamEstablished:
			s.state = streamRemoteClose
			s.notifyWaiting()
		case streamLocalClose:
			s.state = streamClosed
			closeStream = true
			s.notifyWaiting()
		default:
			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
			return ErrUnexpectedFlag
		}
	}
	if flags&flagRST == flagRST {
		s.state = streamReset
		closeStream = true
		s.notifyWaiting()
	}
	return nil
}

// notifyWaiting notifies all the waiting channels
func (s *Stream) notifyWaiting() {
	asyncNotify(s.recvNotifyCh)
	asyncNotify(s.sendNotifyCh)
}

// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
	if err := s.processFlags(flags); err != nil {
		return err
	}

	// Increase window, unblock a sender
	atomic.AddUint32(&s.sendWindow, hdr.Length())
	asyncNotify(s.sendNotifyCh)
	return nil
}

// readData is used to handle a data frame
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
	if err := s.processFlags(flags); err != nil {
		return err
	}

	// Check that our recv window is not exceeded
	length := hdr.Length()
	if length == 0 {
		return nil
	}

	// Wrap in a limited reader
	conn = &io.LimitedReader{R: conn, N: int64(length)}

	// Copy into buffer
	s.recvLock.Lock()
	if length > s.recvWindow {
		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
		s.recvLock.Unlock()
		return ErrRecvWindowExceeded
	}

	if s.recvBuf == nil {
		// Allocate the receive buffer just-in-time to fit the full data frame.
		// This way we can read in the whole packet without further allocations.
		s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
	}
	if _, err := io.Copy(s.recvBuf, conn); err != nil {
		s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
		s.recvLock.Unlock()
		return err
	}

	// Decrement the receive window
	s.recvWindow -= length
	s.recvLock.Unlock()

	// Unblock any readers
	asyncNotify(s.recvNotifyCh)
	return nil
}

// SetDeadline sets the read and write deadlines
func (s *Stream) SetDeadline(t time.Time) error {
	if err := s.SetReadDeadline(t); err != nil {
		return err
	}
	if err := s.SetWriteDeadline(t); err != nil {
		return err
	}
	return nil
}

// SetReadDeadline sets the deadline for blocked and future Read calls.
func (s *Stream) SetReadDeadline(t time.Time) error {
	s.readDeadline.Store(t)
	asyncNotify(s.recvNotifyCh)
	return nil
}

// SetWriteDeadline sets the deadline for blocked and future Write calls
func (s *Stream) SetWriteDeadline(t time.Time) error {
	s.writeDeadline.Store(t)
	asyncNotify(s.sendNotifyCh)
	return nil
}

// Shrink is used to compact the amount of buffers utilized
// This is useful when using Yamux in a connection pool to reduce
// the idle memory utilization.
func (s *Stream) Shrink() {
	s.recvLock.Lock()
	if s.recvBuf != nil && s.recvBuf.Len() == 0 {
		s.recvBuf = nil
	}
	s.recvLock.Unlock()
}