Browse Source

Adds mitigation for the ping of death.

James Phillips 9 years ago
parent
commit
a785d62dc2
2 changed files with 90 additions and 8 deletions
  1. 18 7
      session.go
  2. 72 1
      session_test.go

+ 18 - 7
session.go

@@ -464,7 +464,9 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	// Check if this is a window update
 	if hdr.MsgType() == typeWindowUpdate {
 		if err := stream.incrSendWindow(hdr, flags); err != nil {
-			s.sendNoWait(s.goAway(goAwayProtoErr))
+			if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+				s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+			}
 			return err
 		}
 		return nil
@@ -472,7 +474,9 @@ func (s *Session) handleStreamMessage(hdr header) error {
 
 	// Read the new data
 	if err := stream.readData(hdr, flags, s.bufRead); err != nil {
-		s.sendNoWait(s.goAway(goAwayProtoErr))
+		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+		}
 		return err
 	}
 	return nil
@@ -483,11 +487,16 @@ func (s *Session) handlePing(hdr header) error {
 	flags := hdr.Flags()
 	pingID := hdr.Length()
 
-	// Check if this is a query, respond back
+	// Check if this is a query, respond back in a separate context so we
+	// don't interfere with the receiving thread blocking for the write.
 	if flags&flagSYN == flagSYN {
-		hdr := header(make([]byte, headerSize))
-		hdr.encode(typePing, flagACK, 0, pingID)
-		return s.sendNoWait(hdr)
+		go func() {
+			hdr := header(make([]byte, headerSize))
+			hdr.encode(typePing, flagACK, 0, pingID)
+			if err := s.sendNoWait(hdr); err != nil {
+				s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
+			}
+		}()
 	}
 
 	// Handle a response
@@ -538,7 +547,9 @@ func (s *Session) incomingStream(id uint32) error {
 	// Check if stream already exists
 	if _, ok := s.streams[id]; ok {
 		s.logger.Printf("[ERR] yamux: duplicate stream declared")
-		s.sendNoWait(s.goAway(goAwayProtoErr))
+		if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+			s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+		}
 		return ErrDuplicateStream
 	}
 

+ 72 - 1
session_test.go

@@ -806,7 +806,7 @@ func TestBacklogExceeded_Accept(t *testing.T) {
 	}
 }
 
-func TestWindowUpdateWriteDuringRead(t *testing.T) {
+func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
 	client, server := testClientServer()
 	defer client.Close()
 	defer server.Close()
@@ -908,3 +908,74 @@ func TestSession_sendNoWait_Timeout(t *testing.T) {
 
 	wg.Wait()
 }
+
+func TestSession_PingOfDeath(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	var doPingOfDeath sync.Mutex
+	doPingOfDeath.Lock()
+
+	// This is used later to block outbound writes.
+	conn := server.conn.(*pipeConn)
+
+	// The server will accept a stream, block outbound writes, and then
+	// flood its send channel so that no more headers can be queued.
+	go func() {
+		defer wg.Done()
+
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		conn.writeBlocker.Lock()
+		for {
+			hdr := header(make([]byte, headerSize))
+			hdr.encode(typePing, 0, 0, 0)
+			err = server.sendNoWait(hdr)
+			if err == nil {
+				continue
+			} else if err == ErrHeaderWriteTimeout {
+				break
+			} else {
+				t.Fatalf("err: %v", err)
+			}
+		}
+
+		doPingOfDeath.Unlock()
+	}()
+
+	// The client will open a stream and then send the server a ping once it
+	// can no longer write. This makes sure the server doesn't deadlock reads
+	// while trying to reply to the ping with no ability to write.
+	go func() {
+		defer wg.Done()
+
+		stream, err := client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		// This ping will never unblock because the ping id will never
+		// show up in a response.
+		doPingOfDeath.Lock()
+		go func() { client.Ping() }()
+
+		// Wait for a while to make sure the previous ping times out,
+		// then turn writes back on and make sure a ping works again.
+		time.Sleep(2 * server.config.HeaderWriteTimeout)
+		conn.writeBlocker.Unlock()
+		if _, err = client.Ping(); err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	wg.Wait()
+}