|
@@ -4,6 +4,7 @@ import (
|
|
|
"bytes"
|
|
|
"io"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
@@ -22,17 +23,20 @@ const (
|
|
|
|
|
|
|
|
|
type Stream struct {
|
|
|
+ recvWindow uint32
|
|
|
+ sendWindow uint32
|
|
|
+
|
|
|
id uint32
|
|
|
session *Session
|
|
|
|
|
|
- state streamState
|
|
|
- lock sync.Mutex
|
|
|
+ state streamState
|
|
|
+ stateLock sync.Mutex
|
|
|
|
|
|
- recvBuf bytes.Buffer
|
|
|
- sendHdr header
|
|
|
+ recvBuf bytes.Buffer
|
|
|
+ recvLock sync.Mutex
|
|
|
|
|
|
- recvWindow uint32
|
|
|
- sendWindow uint32
|
|
|
+ sendHdr header
|
|
|
+ sendLock sync.Mutex
|
|
|
|
|
|
notifyCh chan struct{}
|
|
|
|
|
@@ -68,29 +72,31 @@ func (s *Stream) StreamID() uint32 {
|
|
|
|
|
|
func (s *Stream) Read(b []byte) (n int, err error) {
|
|
|
START:
|
|
|
- s.lock.Lock()
|
|
|
+ s.stateLock.Lock()
|
|
|
switch s.state {
|
|
|
case streamRemoteClose:
|
|
|
fallthrough
|
|
|
case streamClosed:
|
|
|
if s.recvBuf.Len() == 0 {
|
|
|
- s.lock.Unlock()
|
|
|
+ s.stateLock.Unlock()
|
|
|
return 0, io.EOF
|
|
|
}
|
|
|
}
|
|
|
+ s.stateLock.Unlock()
|
|
|
|
|
|
|
|
|
+ s.recvLock.Lock()
|
|
|
if s.recvBuf.Len() == 0 {
|
|
|
- s.lock.Unlock()
|
|
|
+ s.recvLock.Unlock()
|
|
|
goto WAIT
|
|
|
}
|
|
|
|
|
|
|
|
|
n, _ = s.recvBuf.Read(b)
|
|
|
+ s.recvLock.Unlock()
|
|
|
|
|
|
|
|
|
err = s.sendWindowUpdate()
|
|
|
- s.lock.Unlock()
|
|
|
return n, err
|
|
|
|
|
|
WAIT:
|
|
@@ -127,18 +133,22 @@ func (s *Stream) write(b []byte) (n int, err error) {
|
|
|
var max uint32
|
|
|
var body io.Reader
|
|
|
START:
|
|
|
- s.lock.Lock()
|
|
|
+ s.stateLock.Lock()
|
|
|
switch s.state {
|
|
|
case streamLocalClose:
|
|
|
fallthrough
|
|
|
case streamClosed:
|
|
|
- s.lock.Unlock()
|
|
|
+ s.stateLock.Unlock()
|
|
|
return 0, ErrStreamClosed
|
|
|
}
|
|
|
+ s.stateLock.Unlock()
|
|
|
+
|
|
|
+
|
|
|
+ s.sendLock.Lock()
|
|
|
|
|
|
|
|
|
- if s.sendWindow == 0 {
|
|
|
- s.lock.Unlock()
|
|
|
+ if atomic.LoadUint32(&s.sendWindow) == 0 {
|
|
|
+ s.sendLock.Unlock()
|
|
|
goto WAIT
|
|
|
}
|
|
|
|
|
@@ -152,15 +162,15 @@ START:
|
|
|
|
|
|
s.sendHdr.encode(typeData, flags, s.id, max)
|
|
|
if err := s.session.waitForSend(s.sendHdr, body); err != nil {
|
|
|
- s.lock.Unlock()
|
|
|
+ s.sendLock.Unlock()
|
|
|
return 0, err
|
|
|
}
|
|
|
|
|
|
|
|
|
- s.sendWindow -= max
|
|
|
+ atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
|
|
|
+ s.sendLock.Unlock()
|
|
|
|
|
|
|
|
|
- s.lock.Unlock()
|
|
|
return int(max), err
|
|
|
|
|
|
WAIT:
|
|
@@ -181,7 +191,8 @@ WAIT:
|
|
|
|
|
|
|
|
|
func (s *Stream) sendFlags() uint16 {
|
|
|
-
|
|
|
+ s.stateLock.Lock()
|
|
|
+ defer s.stateLock.Unlock()
|
|
|
var flags uint16
|
|
|
switch s.state {
|
|
|
case streamInit:
|
|
@@ -233,23 +244,8 @@ func (s *Stream) sendClose() error {
|
|
|
|
|
|
|
|
|
func (s *Stream) Close() error {
|
|
|
- s.lock.Lock()
|
|
|
- defer s.lock.Unlock()
|
|
|
-
|
|
|
+ s.stateLock.Lock()
|
|
|
switch s.state {
|
|
|
-
|
|
|
- case streamLocalClose:
|
|
|
- fallthrough
|
|
|
- case streamClosed:
|
|
|
- return nil
|
|
|
-
|
|
|
-
|
|
|
- case streamRemoteClose:
|
|
|
- s.state = streamClosed
|
|
|
- s.session.closeStream(s.id, false)
|
|
|
- s.sendClose()
|
|
|
- return nil
|
|
|
-
|
|
|
|
|
|
case streamSYNSent:
|
|
|
fallthrough
|
|
@@ -257,23 +253,39 @@ func (s *Stream) Close() error {
|
|
|
fallthrough
|
|
|
case streamEstablished:
|
|
|
s.state = streamLocalClose
|
|
|
- s.sendClose()
|
|
|
- return nil
|
|
|
+ goto SEND_CLOSE
|
|
|
+
|
|
|
+ case streamLocalClose:
|
|
|
+ case streamRemoteClose:
|
|
|
+ s.state = streamClosed
|
|
|
+ s.session.closeStream(s.id, false)
|
|
|
+ goto SEND_CLOSE
|
|
|
+
|
|
|
+ case streamClosed:
|
|
|
+ default:
|
|
|
+ panic("unhandled state")
|
|
|
}
|
|
|
- panic("unhandled state")
|
|
|
+ s.stateLock.Unlock()
|
|
|
+ return nil
|
|
|
+SEND_CLOSE:
|
|
|
+ s.stateLock.Unlock()
|
|
|
+ s.sendClose()
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
|
|
|
func (s *Stream) forceClose() {
|
|
|
- s.lock.Lock()
|
|
|
- defer s.lock.Unlock()
|
|
|
+ s.stateLock.Lock()
|
|
|
s.state = streamClosed
|
|
|
+ s.stateLock.Unlock()
|
|
|
asyncNotify(s.notifyCh)
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (s *Stream) processFlags(flags uint16) error {
|
|
|
+ s.stateLock.Lock()
|
|
|
+ defer s.stateLock.Unlock()
|
|
|
if flags&flagACK == flagACK {
|
|
|
if s.state == streamSYNSent {
|
|
|
s.state = streamEstablished
|
|
@@ -302,42 +314,43 @@ func (s *Stream) processFlags(flags uint16) error {
|
|
|
|
|
|
|
|
|
func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
|
|
|
- s.lock.Lock()
|
|
|
- defer s.lock.Unlock()
|
|
|
if err := s.processFlags(flags); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
|
|
|
- s.sendWindow += hdr.Length()
|
|
|
+ atomic.AddUint32(&s.sendWindow, hdr.Length())
|
|
|
asyncNotify(s.notifyCh)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
|
|
|
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
|
|
|
- s.lock.Lock()
|
|
|
- defer s.lock.Unlock()
|
|
|
if err := s.processFlags(flags); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
|
|
|
length := hdr.Length()
|
|
|
- if length > s.recvWindow {
|
|
|
+ if length == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ if length > atomic.LoadUint32(&s.recvWindow) {
|
|
|
return ErrRecvWindowExceeded
|
|
|
}
|
|
|
|
|
|
|
|
|
- s.recvWindow -= length
|
|
|
+ atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
|
|
|
|
|
|
|
|
|
conn = &io.LimitedReader{R: conn, N: int64(length)}
|
|
|
|
|
|
|
|
|
+ s.recvLock.Lock()
|
|
|
if _, err := io.Copy(&s.recvBuf, conn); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+ s.recvLock.Unlock()
|
|
|
|
|
|
|
|
|
asyncNotify(s.notifyCh)
|