Browse Source

Adding recv buffer bypass

Armon Dadgar 10 years ago
parent
commit
bc7d83979f
1 changed files with 35 additions and 4 deletions
  1. 35 4
      stream.go

+ 35 - 4
stream.go

@@ -32,8 +32,9 @@ type Stream struct {
 	state     streamState
 	stateLock sync.Mutex
 
-	recvBuf  bytes.Buffer
-	recvLock sync.Mutex
+	recvBuf       bytes.Buffer
+	waitingBuffer *directBuffer
+	recvLock      sync.Mutex
 
 	controlHdr     header
 	controlHdrLock sync.Mutex
@@ -48,6 +49,11 @@ type Stream struct {
 	writeDeadline time.Time
 }
 
+type directBuffer struct {
+	buf   []byte
+	bytes int
+}
+
 // newStream is used to construct a new stream within
 // a given session for an ID
 func newStream(session *Session, id uint32, state streamState) *Stream {
@@ -78,6 +84,8 @@ func (s *Stream) StreamID() uint32 {
 // Read is used to read from the stream
 func (s *Stream) Read(b []byte) (n int, err error) {
 	defer asyncNotify(s.recvNotifyCh)
+	var dBuf *directBuffer
+
 START:
 	s.stateLock.Lock()
 	switch s.state {
@@ -94,6 +102,11 @@ START:
 	// If there is no data available, block
 	s.recvLock.Lock()
 	if s.recvBuf.Len() == 0 {
+		// Mark ourself as waiting potentially
+		if s.waitingBuffer == nil {
+			dBuf = &directBuffer{buf: b}
+			s.waitingBuffer = dBuf
+		}
 		s.recvLock.Unlock()
 		goto WAIT
 	}
@@ -114,8 +127,14 @@ WAIT:
 	}
 	select {
 	case <-s.recvNotifyCh:
+		if dBuf != nil && dBuf.bytes > 0 {
+			return dBuf.bytes, nil
+		}
 		goto START
 	case <-timeout:
+		if dBuf != nil && dBuf.bytes > 0 {
+			return dBuf.bytes, nil
+		}
 		return 0, ErrTimeout
 	}
 }
@@ -364,9 +383,21 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 	// Wrap in a limited reader
 	conn = &io.LimitedReader{R: conn, N: int64(length)}
 
-	// Copy to our buffer
+	// Copy into waiting buffer if any
 	s.recvLock.Lock()
-	if _, err := io.Copy(&s.recvBuf, conn); err != nil {
+	if s.waitingBuffer != nil {
+		n, err := conn.Read(s.waitingBuffer.buf)
+		s.waitingBuffer.bytes = n
+		s.waitingBuffer = nil
+		if err != nil {
+			s.recvLock.Unlock()
+			return err
+		}
+	}
+
+	// Copy to our buffer anything left
+	if n, err := io.Copy(&s.recvBuf, conn); err != nil {
+		s.recvLock.Unlock()
 		return err
 	}
 	s.recvLock.Unlock()