|
@@ -5,12 +5,31 @@ import (
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
+ "log"
|
|
|
+ "reflect"
|
|
|
"runtime"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"testing"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
+type logCapture struct{ bytes.Buffer }
|
|
|
+
|
|
|
+func (l *logCapture) logs() []string {
|
|
|
+ return strings.Split(strings.TrimSpace(l.String()), "\n")
|
|
|
+}
|
|
|
+
|
|
|
+func (l *logCapture) match(expect []string) bool {
|
|
|
+ return reflect.DeepEqual(l.logs(), expect)
|
|
|
+}
|
|
|
+
|
|
|
+func captureLogs(s *Session) *logCapture {
|
|
|
+ buf := new(logCapture)
|
|
|
+ s.logger = log.New(buf, "", 0)
|
|
|
+ return buf
|
|
|
+}
|
|
|
+
|
|
|
type pipeConn struct {
|
|
|
reader *io.PipeReader
|
|
|
writer *io.PipeWriter
|
|
@@ -87,6 +106,48 @@ func TestPing(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestPing_Timeout(t *testing.T) {
|
|
|
+ client, server := testClientServerConfig(testConfNoKeepAlive())
|
|
|
+ defer client.Close()
|
|
|
+ defer server.Close()
|
|
|
+
|
|
|
+ // Prevent the client from responding
|
|
|
+ clientConn := client.conn.(*pipeConn)
|
|
|
+ clientConn.writeBlocker.Lock()
|
|
|
+
|
|
|
+ errCh := make(chan error, 1)
|
|
|
+ go func() {
|
|
|
+ _, err := server.Ping() // Ping via the server session
|
|
|
+ errCh <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case err := <-errCh:
|
|
|
+ if err != ErrTimeout {
|
|
|
+ t.Fatalf("err: %v", err)
|
|
|
+ }
|
|
|
+ case <-time.After(client.config.ConnectionWriteTimeout * 2):
|
|
|
+ t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Verify that we recover, even if we gave up
|
|
|
+ clientConn.writeBlocker.Unlock()
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ _, err := server.Ping() // Ping via the server session
|
|
|
+ errCh <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case err := <-errCh:
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("err: %v", err)
|
|
|
+ }
|
|
|
+ case <-time.After(client.config.ConnectionWriteTimeout):
|
|
|
+ t.Fatalf("timeout")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestAccept(t *testing.T) {
|
|
|
client, server := testClientServer()
|
|
|
defer client.Close()
|
|
@@ -673,6 +734,57 @@ func TestKeepAlive(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestKeepAlive_Timeout(t *testing.T) {
|
|
|
+ conn1, conn2 := testConn()
|
|
|
+
|
|
|
+ clientConf := testConf()
|
|
|
+ clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
|
|
|
+ clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
|
|
|
+ client, _ := Client(conn1, clientConf)
|
|
|
+ defer client.Close()
|
|
|
+
|
|
|
+ server, _ := Server(conn2, testConf())
|
|
|
+ defer server.Close()
|
|
|
+
|
|
|
+ clientLogs := captureLogs(client)
|
|
|
+ serverLogs := captureLogs(server)
|
|
|
+
|
|
|
+ errCh := make(chan error, 1)
|
|
|
+ go func() {
|
|
|
+ _, err := server.Accept() // Wait until server closes
|
|
|
+ errCh <- err
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Prevent the client from responding
|
|
|
+ clientConn := client.conn.(*pipeConn)
|
|
|
+ clientConn.writeBlocker.Lock()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case err := <-errCh:
|
|
|
+ if err != ErrKeepAliveTimeout {
|
|
|
+ t.Fatalf("unexpected error: %v", err)
|
|
|
+ }
|
|
|
+ case <-time.After(1 * time.Second):
|
|
|
+ t.Fatalf("timeout waiting for timeout")
|
|
|
+ }
|
|
|
+
|
|
|
+ if !client.IsClosed() {
|
|
|
+ t.Fatalf("client should have closed")
|
|
|
+ }
|
|
|
+
|
|
|
+ if !server.IsClosed() {
|
|
|
+ t.Fatalf("server should have closed")
|
|
|
+ }
|
|
|
+
|
|
|
+ if clientLogs.Len() != 0 {
|
|
|
+ t.Fatalf("client log incorect: %v", clientLogs.logs())
|
|
|
+ }
|
|
|
+
|
|
|
+ if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
|
|
|
+ t.Fatalf("server log incorect: %v", serverLogs.logs())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestLargeWindow(t *testing.T) {
|
|
|
conf := DefaultConfig()
|
|
|
conf.MaxStreamWindowSize *= 2
|