Browse Source

Merge pull request #21 from hashicorp/f-shrink

Support shrinking of receive buffers of a stream
Armon Dadgar 9 years ago
parent
commit
df949784da
2 changed files with 32 additions and 5 deletions
  1. 12 1
      session_test.go
  2. 20 4
      stream.go

+ 12 - 1
session_test.go

@@ -496,6 +496,7 @@ func TestManyStreams_PingPong(t *testing.T) {
 
 		buf := make([]byte, 4)
 		for {
+			// Read the 'ping'
 			n, err := stream.Read(buf)
 			if err == io.EOF {
 				return
@@ -509,6 +510,11 @@ func TestManyStreams_PingPong(t *testing.T) {
 			if !bytes.Equal(buf, ping) {
 				t.Fatalf("bad: %s", buf)
 			}
+
+			// Shrink the internal buffer!
+			stream.Shrink()
+
+			// Write out the 'pong'
 			n, err = stream.Write(pong)
 			if err != nil {
 				t.Fatalf("err: %v", err)
@@ -520,7 +526,7 @@ func TestManyStreams_PingPong(t *testing.T) {
 	}
 	sender := func(i int) {
 		defer wg.Done()
-		stream, err := client.Open()
+		stream, err := client.OpenStream()
 		if err != nil {
 			t.Fatalf("err: %v", err)
 		}
@@ -528,6 +534,7 @@ func TestManyStreams_PingPong(t *testing.T) {
 
 		buf := make([]byte, 4)
 		for i := 0; i < 1000; i++ {
+			// Send the 'ping'
 			n, err := stream.Write(ping)
 			if err != nil {
 				t.Fatalf("err: %v", err)
@@ -536,6 +543,7 @@ func TestManyStreams_PingPong(t *testing.T) {
 				t.Fatalf("short write %d", n)
 			}
 
+			// Read the 'pong'
 			n, err = stream.Read(buf)
 			if err != nil {
 				t.Fatalf("err: %v", err)
@@ -546,6 +554,9 @@ func TestManyStreams_PingPong(t *testing.T) {
 			if !bytes.Equal(buf, pong) {
 				t.Fatalf("bad: %s", buf)
 			}
+
+			// Shrink the buffer
+			stream.Shrink()
 		}
 	}
 

+ 20 - 4
stream.go

@@ -33,7 +33,7 @@ type Stream struct {
 	state     streamState
 	stateLock sync.Mutex
 
-	recvBuf  bytes.Buffer
+	recvBuf  *bytes.Buffer
 	recvLock sync.Mutex
 
 	controlHdr     header
@@ -91,7 +91,7 @@ START:
 	case streamRemoteClose:
 		fallthrough
 	case streamClosed:
-		if s.recvBuf.Len() == 0 {
+		if s.recvBuf == nil || s.recvBuf.Len() == 0 {
 			s.stateLock.Unlock()
 			return 0, io.EOF
 		}
@@ -103,7 +103,7 @@ START:
 
 	// If there is no data available, block
 	s.recvLock.Lock()
-	if s.recvBuf.Len() == 0 {
+	if s.recvBuf == nil || s.recvBuf.Len() == 0 {
 		s.recvLock.Unlock()
 		goto WAIT
 	}
@@ -397,7 +397,12 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 
 	// Copy into buffer
 	s.recvLock.Lock()
-	if _, err := io.Copy(&s.recvBuf, conn); err != nil {
+	if s.recvBuf == nil {
+		// Allocate the receive buffer just-in-time to fit the full data frame.
+		// This way we can read in the whole packet without further allocations.
+		s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
+	}
+	if _, err := io.Copy(s.recvBuf, conn); err != nil {
 		s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
 		s.recvLock.Unlock()
 		return err
@@ -434,3 +439,14 @@ func (s *Stream) SetWriteDeadline(t time.Time) error {
 	s.writeDeadline = t
 	return nil
 }
+
+// Shrink is used to compact the amount of buffers utilized
+// This is useful when using Yamux in a connection pool to reduce
+// the idle memory utilization.
+func (s *Stream) Shrink() {
+	s.recvLock.Lock()
+	if s.recvBuf != nil && s.recvBuf.Len() == 0 {
+		s.recvBuf = nil
+	}
+	s.recvLock.Unlock()
+}