Prechádzať zdrojové kódy

improve memory utilization in receive buffer, fix flow control

* flow control was assuming a `Read` consumed the entire buffer
* flow control fix reduces memory utilization when receiving large
  streams of data
* use timer pool to reduce allocations
* use static handler function pointer to avoid closure allocations
  for every frame
Stuart Carnie 7 rokov pred
rodič
commit
ca8dfd01ce
3 zmenil súbory, kde vykonal 50 pridanie a 23 odobranie
  1. 19 16
      session.go
  2. 16 7
      stream.go
  3. 15 0
      util.go

+ 19 - 16
session.go

@@ -329,8 +329,13 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
 // potential shutdown. Since there's the expectation that sends can happen
 // in a timely manner, we enforce the connection write timeout here.
 func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
-	timer := time.NewTimer(s.config.ConnectionWriteTimeout)
-	defer timer.Stop()
+	t := timerPool.Get()
+	timer := t.(*time.Timer)
+	timer.Reset(s.config.ConnectionWriteTimeout)
+	defer func() {
+		timer.Stop()
+		timerPool.Put(t)
+	}()
 
 	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
 	select {
@@ -414,11 +419,19 @@ func (s *Session) recv() {
 	}
 }
 
+var (
+	handlers = []func(*Session, header) error{
+		typeData:         (*Session).handleStreamMessage,
+		typeWindowUpdate: (*Session).handleStreamMessage,
+		typePing:         (*Session).handlePing,
+		typeGoAway:       (*Session).handleGoAway,
+	}
+)
+
 // recvLoop continues to receive data until a fatal error is encountered
 func (s *Session) recvLoop() error {
 	defer close(s.recvDoneCh)
 	hdr := header(make([]byte, headerSize))
-	var handler func(header) error
 	for {
 		// Read the header
 		if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
@@ -434,22 +447,12 @@ func (s *Session) recvLoop() error {
 			return ErrInvalidVersion
 		}
 
-		// Switch on the type
-		switch hdr.MsgType() {
-		case typeData:
-			handler = s.handleStreamMessage
-		case typeWindowUpdate:
-			handler = s.handleStreamMessage
-		case typeGoAway:
-			handler = s.handleGoAway
-		case typePing:
-			handler = s.handlePing
-		default:
+		mt := hdr.MsgType()
+		if mt < typeData || mt > typeGoAway {
 			return ErrInvalidMsgType
 		}
 
-		// Invoke the handler
-		if err := handler(hdr); err != nil {
+		if err := handlers[mt](s, hdr); err != nil {
 			return err
 		}
 	}

+ 16 - 7
stream.go

@@ -242,18 +242,25 @@ func (s *Stream) sendWindowUpdate() error {
 
 	// Determine the delta update
 	max := s.session.config.MaxStreamWindowSize
-	delta := max - atomic.LoadUint32(&s.recvWindow)
+	var bufLen uint32
+	s.recvLock.Lock()
+	if s.recvBuf != nil {
+		bufLen = uint32(s.recvBuf.Len())
+	}
+	delta := (max - bufLen) - s.recvWindow
 
 	// Determine the flags if any
 	flags := s.sendFlags()
 
 	// Check if we can omit the update
 	if delta < (max/2) && flags == 0 {
+		s.recvLock.Unlock()
 		return nil
 	}
 
 	// Update our window
-	atomic.AddUint32(&s.recvWindow, delta)
+	s.recvWindow += delta
+	s.recvLock.Unlock()
 
 	// Send the header
 	s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
@@ -396,16 +403,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 	if length == 0 {
 		return nil
 	}
-	if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
-		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
-		return ErrRecvWindowExceeded
-	}
 
 	// Wrap in a limited reader
 	conn = &io.LimitedReader{R: conn, N: int64(length)}
 
 	// Copy into buffer
 	s.recvLock.Lock()
+
+	if length > s.recvWindow {
+		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
+		return ErrRecvWindowExceeded
+	}
+
 	if s.recvBuf == nil {
 		// Allocate the receive buffer just-in-time to fit the full data frame.
 		// This way we can read in the whole packet without further allocations.
@@ -418,7 +427,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 	}
 
 	// Decrement the receive window
-	atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
+	s.recvWindow += ^uint32(length - 1)
 	s.recvLock.Unlock()
 
 	// Unblock any readers

+ 15 - 0
util.go

@@ -1,5 +1,20 @@
 package yamux
 
+import (
+	"sync"
+	"time"
+)
+
+var (
+	timerPool = &sync.Pool{
+		New: func() interface{} {
+			timer := time.NewTimer(time.Hour * 1e6)
+			timer.Stop()
+			return timer
+		},
+	}
+)
+
 // asyncSendErr is used to try an async send of an error
 func asyncSendErr(ch chan error, err error) {
 	if ch == nil {