|
@@ -41,7 +41,7 @@ type Session struct {
|
|
|
|
|
|
|
|
|
streams map[uint32]*Stream
|
|
|
- streamLock sync.RWMutex
|
|
|
+ streamLock sync.Mutex
|
|
|
|
|
|
|
|
|
acceptCh chan *Stream
|
|
@@ -168,12 +168,25 @@ func (s *Session) Close() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+
|
|
|
+func (s *Session) exitErr(err error) {
|
|
|
+ s.shutdownErr = err
|
|
|
+ s.Close()
|
|
|
+}
|
|
|
+
|
|
|
|
|
|
|
|
|
func (s *Session) GoAway() error {
|
|
|
+ return s.waitForSend(s.goAway(goAwayNormal), nil)
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (s *Session) goAway(reason uint32) header {
|
|
|
atomic.SwapInt32(&s.localGoAway, 1)
|
|
|
- s.goAway(goAwayNormal)
|
|
|
- return nil
|
|
|
+ hdr := header(make([]byte, headerSize))
|
|
|
+ hdr.encode(typeGoAway, 0, 0, reason)
|
|
|
+ return hdr
|
|
|
}
|
|
|
|
|
|
|
|
@@ -249,7 +262,7 @@ func (s *Session) sendNoWait(hdr header) error {
|
|
|
|
|
|
|
|
|
func (s *Session) send() {
|
|
|
- for {
|
|
|
+ for !s.IsClosed() {
|
|
|
select {
|
|
|
case ready := <-s.sendCh:
|
|
|
|
|
@@ -258,8 +271,8 @@ func (s *Session) send() {
|
|
|
for sent < len(ready.Hdr) {
|
|
|
n, err := s.conn.Write(ready.Hdr[sent:])
|
|
|
if err != nil {
|
|
|
- s.exitErr(err)
|
|
|
asyncSendErr(ready.Err, err)
|
|
|
+ s.exitErr(err)
|
|
|
return
|
|
|
}
|
|
|
sent += n
|
|
@@ -270,8 +283,8 @@ func (s *Session) send() {
|
|
|
if ready.Body != nil {
|
|
|
_, err := io.Copy(s.conn, ready.Body)
|
|
|
if err != nil {
|
|
|
- s.exitErr(err)
|
|
|
asyncSendErr(ready.Err, err)
|
|
|
+ s.exitErr(err)
|
|
|
return
|
|
|
}
|
|
|
}
|
|
@@ -287,6 +300,7 @@ func (s *Session) send() {
|
|
|
|
|
|
func (s *Session) recv() {
|
|
|
hdr := header(make([]byte, headerSize))
|
|
|
+ var handler func(header) error
|
|
|
for !s.IsClosed() {
|
|
|
|
|
|
if _, err := io.ReadFull(s.conn, hdr); err != nil {
|
|
@@ -301,29 +315,25 @@ func (s *Session) recv() {
|
|
|
}
|
|
|
|
|
|
|
|
|
- msgType := hdr.MsgType()
|
|
|
- switch msgType {
|
|
|
+ switch hdr.MsgType() {
|
|
|
case typeData:
|
|
|
- fallthrough
|
|
|
+ handler = s.handleStreamMessage
|
|
|
case typeWindowUpdate:
|
|
|
- if err := s.handleStreamMessage(hdr); err != nil {
|
|
|
- s.exitErr(err)
|
|
|
- return
|
|
|
- }
|
|
|
+ handler = s.handleStreamMessage
|
|
|
case typeGoAway:
|
|
|
- if err := s.handleGoAway(hdr); err != nil {
|
|
|
- s.exitErr(err)
|
|
|
- return
|
|
|
- }
|
|
|
+ handler = s.handleGoAway
|
|
|
case typePing:
|
|
|
- if err := s.handlePing(hdr); err != nil {
|
|
|
- s.exitErr(err)
|
|
|
- return
|
|
|
- }
|
|
|
+ handler = s.handlePing
|
|
|
default:
|
|
|
s.exitErr(ErrInvalidMsgType)
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
+
|
|
|
+ if err := handler(hdr); err != nil {
|
|
|
+ s.exitErr(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -339,27 +349,28 @@ func (s *Session) handleStreamMessage(hdr header) error {
|
|
|
}
|
|
|
|
|
|
|
|
|
- s.streamLock.RLock()
|
|
|
+ s.streamLock.Lock()
|
|
|
stream := s.streams[id]
|
|
|
- s.streamLock.RUnlock()
|
|
|
+ s.streamLock.Unlock()
|
|
|
|
|
|
|
|
|
if stream == nil {
|
|
|
- s.goAway(goAwayProtoErr)
|
|
|
+ s.sendNoWait(s.goAway(goAwayProtoErr))
|
|
|
return ErrMissingStream
|
|
|
}
|
|
|
|
|
|
|
|
|
if hdr.MsgType() == typeWindowUpdate {
|
|
|
if err := stream.incrSendWindow(hdr, flags); err != nil {
|
|
|
- s.goAway(goAwayProtoErr)
|
|
|
+ s.sendNoWait(s.goAway(goAwayProtoErr))
|
|
|
return err
|
|
|
}
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
|
|
|
if err := stream.readData(hdr, flags, s.conn); err != nil {
|
|
|
- s.goAway(goAwayProtoErr)
|
|
|
+ s.sendNoWait(s.goAway(goAwayProtoErr))
|
|
|
return err
|
|
|
}
|
|
|
return nil
|
|
@@ -405,28 +416,13 @@ func (s *Session) handleGoAway(hdr header) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-func (s *Session) exitErr(err error) {
|
|
|
- s.shutdownErr = err
|
|
|
- s.Close()
|
|
|
-}
|
|
|
-
|
|
|
-
|
|
|
-func (s *Session) goAway(reason uint32) {
|
|
|
- hdr := header(make([]byte, headerSize))
|
|
|
- hdr.encode(typeGoAway, 0, 0, reason)
|
|
|
- s.sendNoWait(hdr)
|
|
|
-}
|
|
|
-
|
|
|
|
|
|
func (s *Session) incomingStream(id uint32) error {
|
|
|
|
|
|
if atomic.LoadInt32(&s.localGoAway) == 1 {
|
|
|
hdr := header(make([]byte, headerSize))
|
|
|
hdr.encode(typeWindowUpdate, flagRST, id, 0)
|
|
|
- s.sendNoWait(hdr)
|
|
|
- return nil
|
|
|
+ return s.waitForSend(hdr, nil)
|
|
|
}
|
|
|
|
|
|
s.streamLock.Lock()
|
|
@@ -434,9 +430,8 @@ func (s *Session) incomingStream(id uint32) error {
|
|
|
|
|
|
|
|
|
if _, ok := s.streams[id]; ok {
|
|
|
- s.goAway(goAwayProtoErr)
|
|
|
- s.exitErr(ErrDuplicateStream)
|
|
|
- return nil
|
|
|
+ s.sendNoWait(s.goAway(goAwayProtoErr))
|
|
|
+ return ErrDuplicateStream
|
|
|
}
|
|
|
|
|
|
|