Quellcode durchsuchen

Modifies SYN credit system to depend only on the opener. Changes panic to a error.

James Phillips vor 8 Jahren
Ursprung
Commit
85decf1188
2 geänderte Dateien mit 28 neuen und 16 gelöschten Zeilen
  1. 25 5
      session.go
  2. 3 11
      stream.go

+ 25 - 5
session.go

@@ -46,8 +46,11 @@ type Session struct {
 	pingID   uint32
 	pingID   uint32
 	pingLock sync.Mutex
 	pingLock sync.Mutex
 
 
-	// streams maps a stream id to a stream
+	// streams maps a stream id to a stream, and inflight has an entry
+	// for any outgoing stream that has not yet been established. Both are
+	// protected by streamLock.
 	streams    map[uint32]*Stream
 	streams    map[uint32]*Stream
+	inflight   map[uint32]struct{}
 	streamLock sync.Mutex
 	streamLock sync.Mutex
 
 
 	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
 	// synCh acts like a semaphore. It is sized to the AcceptBacklog which
@@ -90,6 +93,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 		bufRead:    bufio.NewReader(conn),
 		bufRead:    bufio.NewReader(conn),
 		pings:      make(map[uint32]chan struct{}),
 		pings:      make(map[uint32]chan struct{}),
 		streams:    make(map[uint32]*Stream),
 		streams:    make(map[uint32]*Stream),
+		inflight:   make(map[uint32]struct{}),
 		synCh:      make(chan struct{}, config.AcceptBacklog),
 		synCh:      make(chan struct{}, config.AcceptBacklog),
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		acceptCh:   make(chan *Stream, config.AcceptBacklog),
 		sendCh:     make(chan sendReady, 64),
 		sendCh:     make(chan sendReady, 64),
@@ -153,7 +157,7 @@ func (s *Session) OpenStream() (*Stream, error) {
 	}
 	}
 
 
 GET_ID:
 GET_ID:
-	// Get and ID, and check for stream exhaustion
+	// Get an ID, and check for stream exhaustion
 	id := atomic.LoadUint32(&s.nextStreamID)
 	id := atomic.LoadUint32(&s.nextStreamID)
 	if id >= math.MaxUint32-1 {
 	if id >= math.MaxUint32-1 {
 		return nil, ErrStreamsExhausted
 		return nil, ErrStreamsExhausted
@@ -166,6 +170,7 @@ GET_ID:
 	stream := newStream(s, id, streamInit)
 	stream := newStream(s, id, streamInit)
 	s.streamLock.Lock()
 	s.streamLock.Lock()
 	s.streams[id] = stream
 	s.streams[id] = stream
+	s.inflight[id] = struct{}{}
 	s.streamLock.Unlock()
 	s.streamLock.Unlock()
 
 
 	// Send the window update to create
 	// Send the window update to create
@@ -580,19 +585,34 @@ func (s *Session) incomingStream(id uint32) error {
 }
 }
 
 
 // closeStream is used to close a stream once both sides have
 // closeStream is used to close a stream once both sides have
-// issued a close.
+// issued a close. If there was an in-flight SYN and the stream
+// was not yet established, then this will give the credit back.
 func (s *Session) closeStream(id uint32) {
 func (s *Session) closeStream(id uint32) {
 	s.streamLock.Lock()
 	s.streamLock.Lock()
+	if _, ok := s.inflight[id]; ok {
+		select {
+		case <-s.synCh:
+		default:
+			s.logger.Printf("[ERR] un-established stream without inflight syn semaphore")
+		}
+	}
 	delete(s.streams, id)
 	delete(s.streams, id)
 	s.streamLock.Unlock()
 	s.streamLock.Unlock()
 }
 }
 
 
 // establishStream is used to mark a stream that was in the
 // establishStream is used to mark a stream that was in the
 // SYN Sent state as established.
 // SYN Sent state as established.
-func (s *Session) establishStream() {
+func (s *Session) establishStream(id uint32) {
+	s.streamLock.Lock()
+	if _, ok := s.inflight[id]; ok {
+		delete(s.inflight, id)
+	} else {
+		s.logger.Printf("[ERR] established stream without inflight syn entry")
+	}
 	select {
 	select {
 	case <-s.synCh:
 	case <-s.synCh:
 	default:
 	default:
-		panic("established stream without inflight syn")
+		s.logger.Printf("[ERR] established stream without inflight syn semaphore")
 	}
 	}
+	s.streamLock.Unlock()
 }
 }

+ 3 - 11
stream.go

@@ -17,7 +17,6 @@ const (
 	streamEstablished
 	streamEstablished
 	streamLocalClose
 	streamLocalClose
 	streamRemoteClose
 	streamRemoteClose
-	streamEarlyClose
 	streamClosed
 	streamClosed
 	streamReset
 	streamReset
 )
 )
@@ -219,9 +218,6 @@ func (s *Stream) sendFlags() uint16 {
 	case streamSYNReceived:
 	case streamSYNReceived:
 		flags |= flagACK
 		flags |= flagACK
 		s.state = streamEstablished
 		s.state = streamEstablished
-	case streamEarlyClose:
-		flags |= flagACK
-		s.state = streamRemoteClose
 	}
 	}
 	return flags
 	return flags
 }
 }
@@ -331,15 +327,14 @@ func (s *Stream) processFlags(flags uint16) error {
 		if s.state == streamSYNSent {
 		if s.state == streamSYNSent {
 			s.state = streamEstablished
 			s.state = streamEstablished
 		}
 		}
-		s.session.establishStream()
+		s.session.establishStream(s.id)
 	}
 	}
 	if flags&flagFIN == flagFIN {
 	if flags&flagFIN == flagFIN {
 		switch s.state {
 		switch s.state {
-		case streamSYNReceived:
-			s.state = streamEarlyClose
-			s.notifyWaiting()
 		case streamSYNSent:
 		case streamSYNSent:
 			fallthrough
 			fallthrough
+		case streamSYNReceived:
+			fallthrough
 		case streamEstablished:
 		case streamEstablished:
 			s.state = streamRemoteClose
 			s.state = streamRemoteClose
 			s.notifyWaiting()
 			s.notifyWaiting()
@@ -353,9 +348,6 @@ func (s *Stream) processFlags(flags uint16) error {
 		}
 		}
 	}
 	}
 	if flags&flagRST == flagRST {
 	if flags&flagRST == flagRST {
-		if s.state == streamSYNSent {
-			s.session.establishStream()
-		}
 		s.state = streamReset
 		s.state = streamReset
 		closeStream = true
 		closeStream = true
 		s.notifyWaiting()
 		s.notifyWaiting()