Browse Source

Adds header write timeout to prevent deadlocking reads on window updates.

James Phillips 9 years ago
parent
commit
612eb1754a
4 changed files with 86 additions and 2 deletions
  1. 4 0
      const.go
  2. 7 0
      mux.go
  3. 13 1
      session.go
  4. 62 1
      session_test.go

+ 4 - 0
const.go

@@ -29,6 +29,10 @@ var (
 	// ErrReceiveWindowExceeded indicates the window was exceeded
 	ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
 
+	// ErrHeaderWriteTimeout indicates that we hit an IO deadline waiting
+	// for a header to be written.
+	ErrHeaderWriteTimeout = fmt.Errorf("header write timeout")
+
 	// ErrTimeout is used when we reach an IO deadline
 	ErrTimeout = fmt.Errorf("i/o deadline reached")
 

+ 7 - 0
mux.go

@@ -20,6 +20,12 @@ type Config struct {
 	// KeepAliveInterval is how often to perform the keep alive
 	KeepAliveInterval time.Duration
 
+	// HeaderWriteTimeout is how long we will wait to perform a blocking
+	// operation writing a header, after which we will throw an error and
+	// close the stream. Headers are small, so this should be set to a value
+	// after which you suspect there is something wrong with the connection.
+	HeaderWriteTimeout time.Duration
+
 	// MaxStreamWindowSize is used to control the maximum
 	// window size that we allow for a stream.
 	MaxStreamWindowSize uint32
@@ -34,6 +40,7 @@ func DefaultConfig() *Config {
 		AcceptBacklog:       256,
 		EnableKeepAlive:     true,
 		KeepAliveInterval:   30 * time.Second,
+		HeaderWriteTimeout:  10 * time.Second,
 		MaxStreamWindowSize: initialStreamWindow,
 		LogOutput:           os.Stderr,
 	}

+ 13 - 1
session.go

@@ -299,19 +299,31 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
 	return s.waitForSendErr(hdr, body, errCh)
 }
 
-// waitForSendErr waits to send a header, checking for a potential shutdown
+// waitForSendErr waits to send a header with optional data, checking for a
+// potential shutdown. If the body is not supplied then we will enforce the
+// configured HeaderWriteTimeout, since this is a small control header.
 func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
+	var timeout <- chan time.Time
+	if body == nil {
+		timeout = time.After(s.config.HeaderWriteTimeout)
+	}
+
 	ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
 	select {
 	case s.sendCh <- ready:
 	case <-s.shutdownCh:
 		return ErrSessionShutdown
+	case <-timeout:
+		return ErrHeaderWriteTimeout
 	}
+
 	select {
 	case err := <-errCh:
 		return err
 	case <-s.shutdownCh:
 		return ErrSessionShutdown
+	case <-timeout:
+		return ErrHeaderWriteTimeout
 	}
 }
 

+ 62 - 1
session_test.go

@@ -14,6 +14,7 @@ import (
 type pipeConn struct {
 	reader *io.PipeReader
 	writer *io.PipeWriter
+	writeBlocker sync.Mutex
 }
 
 func (p *pipeConn) Read(b []byte) (int, error) {
@@ -21,6 +22,8 @@ func (p *pipeConn) Read(b []byte) (int, error) {
 }
 
 func (p *pipeConn) Write(b []byte) (int, error) {
+	p.writeBlocker.Lock()
+	defer p.writeBlocker.Unlock()
 	return p.writer.Write(b)
 }
 
@@ -32,13 +35,16 @@ func (p *pipeConn) Close() error {
 func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
 	read1, write1 := io.Pipe()
 	read2, write2 := io.Pipe()
-	return &pipeConn{read1, write2}, &pipeConn{read2, write1}
+	conn1 := &pipeConn{reader: read1, writer: write2}
+	conn2 := &pipeConn{reader: read2, writer: write1}
+	return conn1, conn2
 }
 
 func testClientServer() (*Session, *Session) {
 	conf := DefaultConfig()
 	conf.AcceptBacklog = 64
 	conf.KeepAliveInterval = 100 * time.Millisecond
+	conf.HeaderWriteTimeout = 100 * time.Millisecond
 	return testClientServerConfig(conf)
 }
 
@@ -799,3 +805,58 @@ func TestBacklogExceeded_Accept(t *testing.T) {
 		}
 	}
 }
+
+func TestWindowUpdateWriteDuringRead(t *testing.T) {
+	client, server := testClientServer()
+	defer client.Close()
+	defer server.Close()
+
+	wg := &sync.WaitGroup{}
+	wg.Add(2)
+
+	// Choose a huge flood size that we know will result in a window update.
+	flood := int64(client.config.MaxStreamWindowSize) - 1
+
+	// The server will accept a new stream and then flood data to it.
+	go func() {
+		defer wg.Done()
+
+		stream, err := server.AcceptStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		n, err := stream.Write(make([]byte, flood))
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		if int64(n) != flood {
+			t.Fatalf("short write: %d", n)
+		}
+	}()
+
+	// The client will open a stream, block outbound writes, and then
+	// listen to the flood from the server, which should time out since
+	// it won't be able to send the window update.
+	go func() {
+		defer wg.Done()
+
+		stream, err := client.OpenStream()
+		if err != nil {
+			t.Fatalf("err: %v", err)
+		}
+		defer stream.Close()
+
+		conn := client.conn.(*pipeConn)
+		conn.writeBlocker.Lock()
+
+		_, err = stream.Read(make([]byte, flood))
+		if err != ErrHeaderWriteTimeout {
+			t.Fatalf("err: %v", err)
+		}
+	}()
+
+	wg.Wait()
+}
+