Эх сурвалжийг харах

Better handling of backlog being exceeded

Armon Dadgar 10 жил өмнө
parent
commit
0d3a1c514b
4 өөрчлөгдсөн 61 нэмэгдсэн , 6 устгасан
  1. 4 0
      const.go
  2. 2 3
      session.go
  3. 46 2
      session_test.go
  4. 9 1
      stream.go

+ 4 - 0
const.go

@@ -44,6 +44,10 @@ var (
 
 	// ErrRemoteGoAway is used when we get a go away from the other side
 	ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
+
+	// ErrConnectionReset is sent if a stream is reset. This can happen
+	// if the backlog is exceeded, or if there was a remote GoAway.
+	ErrConnectionReset = fmt.Errorf("connection reset")
 )
 
 const (

+ 2 - 3
session.go

@@ -361,10 +361,9 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	stream := s.streams[id]
 	s.streamLock.Unlock()
 
-	// Make sure we have a stream
+	// If we do not have a stream, likely we sent a RST
 	if stream == nil {
-		s.sendNoWait(s.goAway(goAwayProtoErr))
-		return ErrMissingStream
+		return nil
 	}
 
 	// Check if this is a window update

+ 46 - 2
session_test.go

@@ -34,9 +34,16 @@ func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
 }
 
 func testClientServer() (*Session, *Session) {
+	conf := DefaultConfig()
+	conf.AcceptBacklog = 64
+	conf.KeepAliveInterval = 100 * time.Millisecond
+	return testClientServerConfig(conf)
+}
+
+func testClientServerConfig(conf *Config) (*Session, *Session) {
 	conn1, conn2 := testConn()
-	client, _ := Client(conn1, nil)
-	server, _ := Server(conn2, nil)
+	client, _ := Client(conn1, conf)
+	server, _ := Server(conn2, conf)
 	return client, server
 }
 
@@ -547,3 +554,40 @@ func TestWriteDeadline(t *testing.T) {
 	}
 	t.Fatalf("Expected timeout")
 }
+
+func TestBacklogExceeded(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	// Fill the backlog
+	max := client.config.AcceptBacklog
+	for i := 0; i < max; i++ {
+		stream, err := client.Open()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		if _, err := stream.Write([]byte("foo")); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}
+
+	// Exceed the backlog!
+	stream, err := client.Open()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	defer stream.Close()
+
+	if _, err := stream.Write([]byte("foo")); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	buf := make([]byte, 4)
+	stream.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
+	if _, err := stream.Read(buf); err != ErrConnectionReset {
+		t.Fatalf("err: %v", err)
+	}
+}

+ 9 - 1
stream.go

@@ -18,6 +18,7 @@ const (
 	streamLocalClose
 	streamRemoteClose
 	streamClosed
+	streamReset
 )
 
 // Stream is used to represent a logical stream
@@ -100,6 +101,9 @@ START:
 			s.stateLock.Unlock()
 			return 0, io.EOF
 		}
+	case streamReset:
+		s.stateLock.Unlock()
+		return 0, ErrConnectionReset
 	}
 	s.stateLock.Unlock()
 
@@ -174,6 +178,9 @@ START:
 	case streamClosed:
 		s.stateLock.Unlock()
 		return 0, ErrStreamClosed
+	case streamReset:
+		s.stateLock.Unlock()
+		return 0, ErrConnectionReset
 	}
 	s.stateLock.Unlock()
 
@@ -296,6 +303,7 @@ func (s *Stream) Close() error {
 		goto SEND_CLOSE
 
 	case streamClosed:
+	case streamReset:
 	default:
 		panic("unhandled state")
 	}
@@ -343,7 +351,7 @@ func (s *Stream) processFlags(flags uint16) error {
 			return ErrUnexpectedFlag
 		}
 	} else if flags&flagRST == flagRST {
-		s.state = streamClosed
+		s.state = streamReset
 		s.session.closeStream(s.id, true)
 		s.notifyWaiting()
 	}