Browse Source

Adding backpressure to Open to avoid RST

Armon Dadgar 10 years ago
parent
commit
b2e55852dd
3 changed files with 87 additions and 18 deletions
  1. 31 2
      session.go
  2. 48 13
      session_test.go
  3. 8 3
      stream.go

+ 31 - 2
session.go

@@ -50,6 +50,11 @@ type Session struct {
 	streams    map[uint32]*Stream
 	streamLock sync.Mutex
 
+	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
+	// is assumed to be symmetric between the client and server. This allows
+	// the client to avoid exceeding the backlog and instead blocks the open.
+	synCh chan struct{}
+
 	// acceptCh is used to pass ready streams to the client
 	acceptCh chan *Stream
 
@@ -85,6 +90,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 		bufRead:    bufio.NewReader(conn),
 		pings:      make(map[uint32]chan struct{}),
 		streams:    make(map[uint32]*Stream),
+		synCh:      make(chan struct{}, config.AcceptBacklog),
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		sendCh:     make(chan sendReady, 64),
 		recvDoneCh: make(chan struct{}),
@@ -135,6 +141,13 @@ func (s *Session) OpenStream() (*Stream, error) {
 		return nil, ErrRemoteGoAway
 	}
 
+	// Block if we have too many inflight SYNs
+	select {
+	case s.synCh <- struct{}{}:
+	case <-s.shutdownCh:
+		return nil, ErrSessionShutdown
+	}
+
 GET_ID:
 	// Get and ID, and check for stream exhaustion
 	id := atomic.LoadUint32(&s.nextStreamID)
@@ -152,7 +165,10 @@ GET_ID:
 	s.streamLock.Unlock()
 
 	// Send the window update to create
-	return stream, stream.sendWindowUpdate()
+	if err := stream.sendWindowUpdate(); err != nil {
+		return nil, err
+	}
+	return stream, nil
 }
 
 // Accept is used to block until the next available stream
@@ -166,7 +182,10 @@ func (s *Session) Accept() (net.Conn, error) {
 func (s *Session) AcceptStream() (*Stream, error) {
 	select {
 	case stream := <-s.acceptCh:
-		return stream, stream.sendWindowUpdate()
+		if err := stream.sendWindowUpdate(); err != nil {
+			return nil, err
+		}
+		return stream, nil
 	case <-s.shutdownCh:
 		return nil, s.shutdownErr
 	}
@@ -521,3 +540,13 @@ func (s *Session) closeStream(id uint32) {
 	delete(s.streams, id)
 	s.streamLock.Unlock()
 }
+
+// establishStream is used to mark a stream that was in the
+// SYN Sent state as established.
+func (s *Session) establishStream() {
+	select {
+	case <-s.synCh:
+	default:
+		panic("established stream without inflight syn")
+	}
+}

+ 48 - 13
session_test.go

@@ -598,21 +598,26 @@ func TestBacklogExceeded(t *testing.T) {
 		}
 	}
 
-	// Exceed the backlog!
-	stream, err := client.Open()
-	if err != nil {
-		t.Fatalf("err: %v", err)
-	}
-	defer stream.Close()
+	// Attempt to open a new stream
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := client.Open()
+		errCh <- err
+	}()
 
-	if _, err := stream.Write([]byte("foo")); err != nil {
-		t.Fatalf("err: %v", err)
-	}
+	// Shutdown the server
+	go func() {
+		time.Sleep(10 * time.Millisecond)
+		server.Close()
+	}()
 
-	buf := make([]byte, 4)
-	stream.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
-	if _, err := stream.Read(buf); err != ErrConnectionReset {
-		t.Fatalf("err: %v", err)
+	select {
+	case err := <-errCh:
+		if err == nil {
+			t.Fatalf("open should fail")
+		}
+	case <-time.After(time.Second):
+		t.Fatalf("timeout")
 	}
 }
 
@@ -749,3 +754,33 @@ func TestSendData_VeryLarge(t *testing.T) {
 		panic("timeout")
 	}
 }
+
+func TestBacklogExceeded_Accept(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	max := 5 * client.config.AcceptBacklog
+	go func() {
+		for i := 0; i < max; i++ {
+			stream, err := server.Accept()
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+			defer stream.Close()
+		}
+	}()
+
+	// Fill the backlog
+	for i := 0; i < max; i++ {
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		if _, err := stream.Write([]byte("foo")); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}
+}

+ 8 - 3
stream.go

@@ -327,8 +327,9 @@ func (s *Stream) processFlags(flags uint16) error {
 		if s.state == streamSYNSent {
 			s.state = streamEstablished
 		}
-
-	} else if flags&flagFIN == flagFIN {
+		s.session.establishStream()
+	}
+	if flags&flagFIN == flagFIN {
 		switch s.state {
 		case streamSYNSent:
 			fallthrough
@@ -345,7 +346,11 @@ func (s *Stream) processFlags(flags uint16) error {
 			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
 			return ErrUnexpectedFlag
 		}
-	} else if flags&flagRST == flagRST {
+	}
+	if flags&flagRST == flagRST {
+		if s.state == streamSYNSent {
+			s.session.establishStream()
+		}
 		s.state = streamReset
 		closeStream = true
 		s.notifyWaiting()