Forráskód Böngészése

Tests and improvements to #19 KeepAlive is not working

Delete abandoned pings
Avoid leaking goroutine
Add tests
Rob Napier 9 éve
szülő
commit
bb7a27e978
3 módosított fájl, 120 hozzáadás és 1 törlés
  1. 3 0
      const.go
  2. 5 1
      session.go
  3. 112 0
      session_test.go

+ 3 - 0
const.go

@@ -48,6 +48,9 @@ var (
 	// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
 	// timeout writing to the underlying stream connection.
 	ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout")
+
+	// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
+	ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout")
 )
 
 const (

+ 5 - 1
session.go

@@ -273,6 +273,9 @@ func (s *Session) Ping() (time.Duration, error) {
 	select {
 	case <-ch:
 	case <-time.After(s.config.ConnectionWriteTimeout):
+		s.pingLock.Lock()
+		delete(s.pings, id) // Ignore it if a response comes later.
+		s.pingLock.Unlock()
 		return 0, ErrTimeout
 	case <-s.shutdownCh:
 		return 0, ErrSessionShutdown
@@ -291,7 +294,8 @@ func (s *Session) keepalive() {
 			_, err := s.Ping()
 			if err != nil {
 				s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
-				s.Close()
+				s.exitErr(ErrKeepAliveTimeout)
+				return
 			}
 		case <-s.shutdownCh:
 			return

+ 112 - 0
session_test.go

@@ -5,12 +5,31 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
+	"reflect"
 	"runtime"
+	"strings"
 	"sync"
 	"testing"
 	"time"
 )
 
+type logCapture struct{ bytes.Buffer }
+
+func (l *logCapture) logs() []string {
+	return strings.Split(strings.TrimSpace(l.String()), "\n")
+}
+
+func (l *logCapture) match(expect []string) bool {
+	return reflect.DeepEqual(l.logs(), expect)
+}
+
+func captureLogs(s *Session) *logCapture {
+	buf := new(logCapture)
+	s.logger = log.New(buf, "", 0)
+	return buf
+}
+
 type pipeConn struct {
 	reader       *io.PipeReader
 	writer       *io.PipeWriter
@@ -87,6 +106,48 @@ func TestPing(t *testing.T) {
 	}
 }
 
+func TestPing_Timeout(t *testing.T) {
+	client, server := testClientServerConfig(testConfNoKeepAlive())
+	defer client.Close()
+	defer server.Close()
+
+	// Prevent the client from responding
+	clientConn := client.conn.(*pipeConn)
+	clientConn.writeBlocker.Lock()
+
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := server.Ping() // Ping via the server session
+		errCh <- err
+	}()
+
+	select {
+	case err := <-errCh:
+		if err != ErrTimeout {
+			t.Fatalf("err: %v", err)
+		}
+	case <-time.After(client.config.ConnectionWriteTimeout * 2):
+		t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
+	}
+
+	// Verify that we recover, even if we gave up
+	clientConn.writeBlocker.Unlock()
+
+	go func() {
+		_, err := server.Ping() // Ping via the server session
+		errCh <- err
+	}()
+
+	select {
+	case err := <-errCh:
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	case <-time.After(client.config.ConnectionWriteTimeout):
+		t.Fatalf("timeout")
+	}
+}
+
 func TestAccept(t *testing.T) {
 	client, server := testClientServer()
 	defer client.Close()
@@ -673,6 +734,57 @@ func TestKeepAlive(t *testing.T) {
 	}
 }
 
+func TestKeepAlive_Timeout(t *testing.T) {
+	conn1, conn2 := testConn()
+
+	clientConf := testConf()
+	clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
+	clientConf.EnableKeepAlive = false            // Just test one direction, so it's deterministic who hangs up on whom
+	client, _ := Client(conn1, clientConf)
+	defer client.Close()
+
+	server, _ := Server(conn2, testConf())
+	defer server.Close()
+
+	clientLogs := captureLogs(client)
+	serverLogs := captureLogs(server)
+
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := server.Accept() // Wait until server closes
+		errCh <- err
+	}()
+
+	// Prevent the client from responding
+	clientConn := client.conn.(*pipeConn)
+	clientConn.writeBlocker.Lock()
+
+	select {
+	case err := <-errCh:
+		if err != ErrKeepAliveTimeout {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		t.Fatalf("timeout waiting for timeout")
+	}
+
+	if !client.IsClosed() {
+		t.Fatalf("client should have closed")
+	}
+
+	if !server.IsClosed() {
+		t.Fatalf("server should have closed")
+	}
+
+	if clientLogs.Len() != 0 {
+		t.Fatalf("client log incorect: %v", clientLogs.logs())
+	}
+
+	if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
+		t.Fatalf("server log incorect: %v", serverLogs.logs())
+	}
+}
+
 func TestLargeWindow(t *testing.T) {
 	conf := DefaultConfig()
 	conf.MaxStreamWindowSize *= 2