Browse Source

Merge pull request #19 from rnapier/keepalive-timeout

Fixes #12 KeepAlive is not working
James Phillips 9 years ago
parent
commit
ddcd0a6ec7
3 changed files with 134 additions and 13 deletions
  1. const.go (+3 -0)
  2. session.go (+11 -1)
  3. session_test.go (+120 -12)

+ 3 - 0
const.go

@@ -48,6 +48,9 @@ var (
 	// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
 	// timeout writing to the underlying stream connection.
 	ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout")
+
+	// ErrKeepAliveTimeout is reported when a failed keepalive ping causes the session to close.
+	ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout")
 )
 
 const (

+ 11 - 1
session.go

@@ -272,6 +272,11 @@ func (s *Session) Ping() (time.Duration, error) {
 	start := time.Now()
 	select {
 	case <-ch:
+	case <-time.After(s.config.ConnectionWriteTimeout):
+		s.pingLock.Lock()
+		delete(s.pings, id) // Ignore it if a response comes later.
+		s.pingLock.Unlock()
+		return 0, ErrTimeout
 	case <-s.shutdownCh:
 		return 0, ErrSessionShutdown
 	}
@@ -286,7 +291,12 @@ func (s *Session) keepalive() {
 	for {
 		select {
 		case <-time.After(s.config.KeepAliveInterval):
-			s.Ping()
+			_, err := s.Ping()
+			if err != nil {
+				s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
+				s.exitErr(ErrKeepAliveTimeout)
+				return
+			}
 		case <-s.shutdownCh:
 			return
 		}

+ 120 - 12
session_test.go

@@ -5,12 +5,31 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
+	"reflect"
 	"runtime"
+	"strings"
 	"sync"
 	"testing"
 	"time"
 )
 
+type logCapture struct{ bytes.Buffer }
+
+func (l *logCapture) logs() []string {
+	return strings.Split(strings.TrimSpace(l.String()), "\n")
+}
+
+func (l *logCapture) match(expect []string) bool {
+	return reflect.DeepEqual(l.logs(), expect)
+}
+
+func captureLogs(s *Session) *logCapture {
+	buf := new(logCapture)
+	s.logger = log.New(buf, "", 0)
+	return buf
+}
+
 type pipeConn struct {
 	reader       *io.PipeReader
 	writer       *io.PipeWriter
@@ -40,12 +59,22 @@ func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
 	return conn1, conn2
 }
 
-func testClientServer() (*Session, *Session) {
+func testConf() *Config {
 	conf := DefaultConfig()
 	conf.AcceptBacklog = 64
 	conf.KeepAliveInterval = 100 * time.Millisecond
 	conf.ConnectionWriteTimeout = 250 * time.Millisecond
-	return testClientServerConfig(conf)
+	return conf
+}
+
+func testConfNoKeepAlive() *Config {
+	conf := testConf()
+	conf.EnableKeepAlive = false
+	return conf
+}
+
+func testClientServer() (*Session, *Session) {
+	return testClientServerConfig(testConf())
 }
 
 func testClientServerConfig(conf *Config) (*Session, *Session) {
@@ -77,6 +106,48 @@ func TestPing(t *testing.T) {
 	}
 }
 
+func TestPing_Timeout(t *testing.T) {
+	client, server := testClientServerConfig(testConfNoKeepAlive())
+	defer client.Close()
+	defer server.Close()
+
+	// Prevent the client from responding
+	clientConn := client.conn.(*pipeConn)
+	clientConn.writeBlocker.Lock()
+
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := server.Ping() // Ping via the server session
+		errCh <- err
+	}()
+
+	select {
+	case err := <-errCh:
+		if err != ErrTimeout {
+			t.Fatalf("err: %v", err)
+		}
+	case <-time.After(client.config.ConnectionWriteTimeout * 2):
+		t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
+	}
+
+	// Verify that we recover, even if we gave up
+	clientConn.writeBlocker.Unlock()
+
+	go func() {
+		_, err := server.Ping() // Ping via the server session
+		errCh <- err
+	}()
+
+	select {
+	case err := <-errCh:
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+	case <-time.After(client.config.ConnectionWriteTimeout):
+		t.Fatalf("timeout")
+	}
+}
+
 func TestAccept(t *testing.T) {
 	client, server := testClientServer()
 	defer client.Close()
@@ -663,6 +734,49 @@ func TestKeepAlive(t *testing.T) {
 	}
 }
 
+func TestKeepAlive_Timeout(t *testing.T) {
+	conn1, conn2 := testConn()
+
+	clientConf := testConf()
+	clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
+	clientConf.EnableKeepAlive = false            // Just test one direction, so it's deterministic who hangs up on whom
+	client, _ := Client(conn1, clientConf)
+	defer client.Close()
+
+	server, _ := Server(conn2, testConf())
+	defer server.Close()
+
+	_ = captureLogs(client) // Client logs aren't part of the test
+	serverLogs := captureLogs(server)
+
+	errCh := make(chan error, 1)
+	go func() {
+		_, err := server.Accept() // Wait until server closes
+		errCh <- err
+	}()
+
+	// Prevent the client from responding
+	clientConn := client.conn.(*pipeConn)
+	clientConn.writeBlocker.Lock()
+
+	select {
+	case err := <-errCh:
+		if err != ErrKeepAliveTimeout {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		t.Fatalf("timeout waiting for timeout")
+	}
+
+	if !server.IsClosed() {
+		t.Fatalf("server should have closed")
+	}
+
+	if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
+		t.Fatalf("server log incorect: %v", serverLogs.logs())
+	}
+}
+
 func TestLargeWindow(t *testing.T) {
 	conf := DefaultConfig()
 	conf.MaxStreamWindowSize *= 2
@@ -807,7 +921,7 @@ func TestBacklogExceeded_Accept(t *testing.T) {
 }
 
 func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
-	client, server := testClientServer()
+	client, server := testClientServerConfig(testConfNoKeepAlive())
 	defer client.Close()
 	defer server.Close()
 
@@ -861,7 +975,7 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
 }
 
 func TestSession_sendNoWait_Timeout(t *testing.T) {
-	client, server := testClientServer()
+	client, server := testClientServerConfig(testConfNoKeepAlive())
 	defer client.Close()
 	defer server.Close()
 
@@ -910,7 +1024,7 @@ func TestSession_sendNoWait_Timeout(t *testing.T) {
 }
 
 func TestSession_PingOfDeath(t *testing.T) {
-	client, server := testClientServer()
+	client, server := testClientServerConfig(testConfNoKeepAlive())
 	defer client.Close()
 	defer server.Close()
 
@@ -981,13 +1095,7 @@ func TestSession_PingOfDeath(t *testing.T) {
 }
 
 func TestSession_ConnectionWriteTimeout(t *testing.T) {
-	// Disable keepalives so they don't detect the failed connection
-	// before the user's write does.
-	conf := DefaultConfig()
-	conf.EnableKeepAlive = false
-	conf.ConnectionWriteTimeout = 250 * time.Millisecond
-
-	client, server := testClientServerConfig(conf)
+	client, server := testClientServerConfig(testConfNoKeepAlive())
 	defer client.Close()
 	defer server.Close()