Ver código fonte

Adds a timeout for sendNoWait, which can block if the connection gets dire.

James Phillips 9 anos atrás
pai
commit
904091c9d5
2 arquivos alterados com 56 adições e 3 exclusões
  1. 8 3
      session.go
  2. 48 0
      session_test.go

+ 8 - 3
session.go

@@ -327,13 +327,19 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e
 	}
 }
 
-// sendNoWait does a send without waiting
+// sendNoWait does a send without waiting. Since there's still a case where
+// sendCh itself can be full, we will enforce the configured HeaderWriteTimeout,
+// since this is a small control header.
 func (s *Session) sendNoWait(hdr header) error {
+	timeout := time.After(s.config.HeaderWriteTimeout)
+
 	select {
 	case s.sendCh <- sendReady{Hdr: hdr}:
 		return nil
 	case <-s.shutdownCh:
 		return ErrSessionShutdown
+	case <-timeout:
+		return ErrHeaderWriteTimeout
 	}
 }
 
@@ -481,8 +487,7 @@ func (s *Session) handlePing(hdr header) error {
 	if flags&flagSYN == flagSYN {
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typePing, flagACK, 0, pingID)
-		s.sendNoWait(hdr)
-		return nil
+		return s.sendNoWait(hdr)
 	}
 
 	// Handle a response

+ 48 - 0
session_test.go

@@ -860,3 +860,51 @@ func TestWindowUpdateWriteDuringRead(t *testing.T) {
 	wg.Wait()
 }
 
+func TestSession_sendNoWait_Timeout(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+	}()
+
+	// The client will open the stream and then block outbound writes, we'll
+	// probe sendNoWait once it gets into that state.
+	go func() {
+		defer wg.Done()
+
+		stream, err := client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		conn := client.conn.(*pipeConn)
+		conn.writeBlocker.Lock()
+
+		hdr := header(make([]byte, headerSize))
+		hdr.encode(typePing, flagACK, 0, 0)
+		for {
+			err = client.sendNoWait(hdr)
+			if err == nil {
+				continue
+			} else if err == ErrHeaderWriteTimeout {
+				break
+			} else {
+				t.Fatalf("err: %v", err)
+			}
+		}
+	}()
+
+	wg.Wait()
+}