Просмотр исходного кода

Merge pull request #29 from hashicorp/slackpad-pr-28

Properly handles closing streams before they are ACKed.
James Phillips 8 лет назад
Родитель
Сommit
172cde3b6c
3 измененных файлов с 72 добавлено и 9 удалено
  1. 30 5
      session.go
  2. 41 0
      session_test.go
  3. 1 4
      stream.go

+ 30 - 5
session.go

@@ -46,8 +46,11 @@ type Session struct {
 	pingID   uint32
 	pingLock sync.Mutex
 
-	// streams maps a stream id to a stream
+	// streams maps a stream id to a stream, and inflight has an entry
+	// for any outgoing stream that has not yet been established. Both are
+	// protected by streamLock.
 	streams    map[uint32]*Stream
+	inflight   map[uint32]struct{}
 	streamLock sync.Mutex
 
 	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
@@ -90,6 +93,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 		bufRead:    bufio.NewReader(conn),
 		pings:      make(map[uint32]chan struct{}),
 		streams:    make(map[uint32]*Stream),
+		inflight:   make(map[uint32]struct{}),
 		synCh:      make(chan struct{}, config.AcceptBacklog),
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		sendCh:     make(chan sendReady, 64),
@@ -153,7 +157,7 @@ func (s *Session) OpenStream() (*Stream, error) {
 	}
 
 GET_ID:
-	// Get and ID, and check for stream exhaustion
+	// Get an ID, and check for stream exhaustion
 	id := atomic.LoadUint32(&s.nextStreamID)
 	if id >= math.MaxUint32-1 {
 		return nil, ErrStreamsExhausted
@@ -166,10 +170,16 @@ GET_ID:
 	stream := newStream(s, id, streamInit)
 	s.streamLock.Lock()
 	s.streams[id] = stream
+	s.inflight[id] = struct{}{}
 	s.streamLock.Unlock()
 
 	// Send the window update to create
 	if err := stream.sendWindowUpdate(); err != nil {
+		select {
+		case <-s.synCh:
+		default:
+			s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
+		}
 		return nil, err
 	}
 	return stream, nil
@@ -580,19 +590,34 @@ func (s *Session) incomingStream(id uint32) error {
 }
 
 // closeStream is used to close a stream once both sides have
-// issued a close.
+// issued a close. If there was an in-flight SYN and the stream
+// was not yet established, then this will give the credit back.
 func (s *Session) closeStream(id uint32) {
 	s.streamLock.Lock()
+	if _, ok := s.inflight[id]; ok {
+		select {
+		case <-s.synCh:
+		default:
+			s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
+		}
+	}
 	delete(s.streams, id)
 	s.streamLock.Unlock()
 }
 
 // establishStream is used to mark a stream that was in the
 // SYN Sent state as established.
-func (s *Session) establishStream() {
+func (s *Session) establishStream(id uint32) {
+	s.streamLock.Lock()
+	if _, ok := s.inflight[id]; ok {
+		delete(s.inflight, id)
+	} else {
+		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
+	}
 	select {
 	case <-s.synCh:
 	default:
-		panic("established stream without inflight syn")
+		s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
 	}
+	s.streamLock.Unlock()
 }

+ 41 - 0
session_test.go

@@ -148,6 +148,47 @@ func TestPing_Timeout(t *testing.T) {
 	}
 }
 
+func TestCloseBeforeAck(t *testing.T) {
+	cfg := testConf()
+	cfg.AcceptBacklog = 8
+	client, server := testClientServerConfig(cfg)
+
+	defer client.Close()
+	defer server.Close()
+
+	for i := 0; i < 8; i++ {
+		s, err := client.OpenStream()
+		if err != nil {
+			t.Fatal(err)
+		}
+		s.Close()
+	}
+
+	for i := 0; i < 8; i++ {
+		s, err := server.AcceptStream()
+		if err != nil {
+			t.Fatal(err)
+		}
+		s.Close()
+	}
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		s, err := client.OpenStream()
+		if err != nil {
+			t.Fatal(err)
+		}
+		s.Close()
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(time.Second * 5):
+		t.Fatal("timed out trying to open stream")
+	}
+}
+
 func TestAccept(t *testing.T) {
 	client, server := testClientServer()
 	defer client.Close()

+ 1 - 4
stream.go

@@ -327,7 +327,7 @@ func (s *Stream) processFlags(flags uint16) error {
 		if s.state == streamSYNSent {
 			s.state = streamEstablished
 		}
-		s.session.establishStream()
+		s.session.establishStream(s.id)
 	}
 	if flags&flagFIN == flagFIN {
 		switch s.state {
@@ -348,9 +348,6 @@ func (s *Stream) processFlags(flags uint16) error {
 		}
 	}
 	if flags&flagRST == flagRST {
-		if s.state == streamSYNSent {
-			s.session.establishStream()
-		}
 		s.state = streamReset
 		closeStream = true
 		s.notifyWaiting()