Quellcode durchsuchen

Increase safety of GoAway

Armon Dadgar vor 10 Jahren
Ursprung
Commit
dcc15f74df
1 geänderte Dateien mit 17 neuen und 16 gelöschten Zeilen
  1. 17 16
      session.go

+ 17 - 16
session.go

@@ -6,12 +6,21 @@ import (
 	"math"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 )
 
 // Session is used to wrap a reliable ordered connection and to
 // multiplex it into multiple streams.
 type Session struct {
+	// remoteGoAway indicates the remote side does
+	// not want futher connections. Must be first for alignment.
+	remoteGoAway int32
+
+	// localGoAway indicates that we should stop
+	// accepting futher connections. Must be first for alignment.
+	localGoAway int32
+
 	// client is true if we are a client size connection
 	client bool
 
@@ -26,14 +35,6 @@ type Session struct {
 	pingID   uint32
 	pingLock sync.Mutex
 
-	// remoteGoAway indicates the remote side does
-	// not want futher connections
-	remoteGoAway bool
-
-	// localGoAway indicates that we should stop
-	// accepting futher connections
-	localGoAway bool
-
 	// nextStreamID is the next stream we should
 	// send. This depends if we are a client/server.
 	nextStreamID uint32
@@ -89,8 +90,8 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 	return s
 }
 
-// isShutdown does a safe check to see if we have shutdown
-func (s *Session) isShutdown() bool {
+// IsClosed does a safe check to see if we have shutdown
+func (s *Session) IsClosed() bool {
 	select {
 	case <-s.shutdownCh:
 		return true
@@ -101,10 +102,10 @@ func (s *Session) isShutdown() bool {
 
 // Open is used to create a new stream
 func (s *Session) Open() (*Stream, error) {
-	if s.isShutdown() {
+	if s.IsClosed() {
 		return nil, ErrSessionShutdown
 	}
-	if s.remoteGoAway {
+	if atomic.LoadInt32(&s.remoteGoAway) == 1 {
 		return nil, ErrRemoteGoAway
 	}
 
@@ -170,7 +171,7 @@ func (s *Session) Close() error {
 // GoAway can be used to prevent accepting further
 // connections. It does not close the underlying conn.
 func (s *Session) GoAway() error {
-	s.localGoAway = true
+	atomic.SwapInt32(&s.localGoAway, 1)
 	s.goAway(goAwayNormal)
 	return nil
 }
@@ -286,7 +287,7 @@ func (s *Session) send() {
 // recv is a long running goroutine that accepts new data
 func (s *Session) recv() {
 	hdr := header(make([]byte, headerSize))
-	for !s.isShutdown() {
+	for !s.IsClosed() {
 		// Read the header
 		if _, err := io.ReadFull(s.conn, hdr); err != nil {
 			s.exitErr(err)
@@ -393,7 +394,7 @@ func (s *Session) handleGoAway(hdr header) error {
 	code := hdr.Length()
 	switch code {
 	case goAwayNormal:
-		s.remoteGoAway = true
+		atomic.SwapInt32(&s.remoteGoAway, 1)
 	case goAwayProtoErr:
 		return fmt.Errorf("yamux protocol error")
 	case goAwayInternalErr:
@@ -421,7 +422,7 @@ func (s *Session) goAway(reason uint32) {
 // incomingStream is used to create a new incoming stream
 func (s *Session) incomingStream(id uint32) error {
 	// Reject immediately if we are doing a go away
-	if s.localGoAway {
+	if atomic.LoadInt32(&s.localGoAway) == 1 {
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typeWindowUpdate, flagRST, id, 0)
 		s.sendNoWait(hdr)