|
@@ -6,12 +6,21 @@ import (
|
|
|
"math"
|
|
|
"net"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
// Session is used to wrap a reliable ordered connection and to
|
|
|
// multiplex it into multiple streams.
|
|
|
type Session struct {
|
|
|
+ // remoteGoAway indicates the remote side does
|
|
|
+ // not want futher connections. Must be first for alignment.
|
|
|
+ remoteGoAway int32
|
|
|
+
|
|
|
+ // localGoAway indicates that we should stop
|
|
|
+ // accepting futher connections. Must be first for alignment.
|
|
|
+ localGoAway int32
|
|
|
+
|
|
|
// client is true if we are a client size connection
|
|
|
client bool
|
|
|
|
|
@@ -26,14 +35,6 @@ type Session struct {
|
|
|
pingID uint32
|
|
|
pingLock sync.Mutex
|
|
|
|
|
|
- // remoteGoAway indicates the remote side does
|
|
|
- // not want futher connections
|
|
|
- remoteGoAway bool
|
|
|
-
|
|
|
- // localGoAway indicates that we should stop
|
|
|
- // accepting futher connections
|
|
|
- localGoAway bool
|
|
|
-
|
|
|
// nextStreamID is the next stream we should
|
|
|
// send. This depends if we are a client/server.
|
|
|
nextStreamID uint32
|
|
@@ -89,8 +90,8 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
|
|
return s
|
|
|
}
|
|
|
|
|
|
-// isShutdown does a safe check to see if we have shutdown
|
|
|
-func (s *Session) isShutdown() bool {
|
|
|
+// IsClosed does a safe check to see if we have shutdown
|
|
|
+func (s *Session) IsClosed() bool {
|
|
|
select {
|
|
|
case <-s.shutdownCh:
|
|
|
return true
|
|
@@ -101,10 +102,10 @@ func (s *Session) isShutdown() bool {
|
|
|
|
|
|
// Open is used to create a new stream
|
|
|
func (s *Session) Open() (*Stream, error) {
|
|
|
- if s.isShutdown() {
|
|
|
+ if s.IsClosed() {
|
|
|
return nil, ErrSessionShutdown
|
|
|
}
|
|
|
- if s.remoteGoAway {
|
|
|
+ if atomic.LoadInt32(&s.remoteGoAway) == 1 {
|
|
|
return nil, ErrRemoteGoAway
|
|
|
}
|
|
|
|
|
@@ -170,7 +171,7 @@ func (s *Session) Close() error {
|
|
|
// GoAway can be used to prevent accepting further
|
|
|
// connections. It does not close the underlying conn.
|
|
|
func (s *Session) GoAway() error {
|
|
|
- s.localGoAway = true
|
|
|
+ atomic.SwapInt32(&s.localGoAway, 1)
|
|
|
s.goAway(goAwayNormal)
|
|
|
return nil
|
|
|
}
|
|
@@ -286,7 +287,7 @@ func (s *Session) send() {
|
|
|
// recv is a long running goroutine that accepts new data
|
|
|
func (s *Session) recv() {
|
|
|
hdr := header(make([]byte, headerSize))
|
|
|
- for !s.isShutdown() {
|
|
|
+ for !s.IsClosed() {
|
|
|
// Read the header
|
|
|
if _, err := io.ReadFull(s.conn, hdr); err != nil {
|
|
|
s.exitErr(err)
|
|
@@ -393,7 +394,7 @@ func (s *Session) handleGoAway(hdr header) error {
|
|
|
code := hdr.Length()
|
|
|
switch code {
|
|
|
case goAwayNormal:
|
|
|
- s.remoteGoAway = true
|
|
|
+ atomic.SwapInt32(&s.remoteGoAway, 1)
|
|
|
case goAwayProtoErr:
|
|
|
return fmt.Errorf("yamux protocol error")
|
|
|
case goAwayInternalErr:
|
|
@@ -421,7 +422,7 @@ func (s *Session) goAway(reason uint32) {
|
|
|
// incomingStream is used to create a new incoming stream
|
|
|
func (s *Session) incomingStream(id uint32) error {
|
|
|
// Reject immediately if we are doing a go away
|
|
|
- if s.localGoAway {
|
|
|
+ if atomic.LoadInt32(&s.localGoAway) == 1 {
|
|
|
hdr := header(make([]byte, headerSize))
|
|
|
hdr.encode(typeWindowUpdate, flagRST, id, 0)
|
|
|
s.sendNoWait(hdr)
|