Armon Dadgar 10 years ago
parent
commit
154c0d95c8
2 changed files with 63 additions and 70 deletions
  1. 40 45
      session.go
  2. 23 25
      stream.go

+ 40 - 45
session.go

@@ -41,7 +41,7 @@ type Session struct {
 
 	// streams maps a stream id to a stream
 	streams    map[uint32]*Stream
-	streamLock sync.RWMutex
+	streamLock sync.Mutex
 
 	// acceptCh is used to pass ready streams to the client
 	acceptCh chan *Stream
@@ -168,12 +168,25 @@ func (s *Session) Close() error {
 	return nil
 }
 
+// exitErr is used to handle an error that is causing the
+// session to terminate.
+func (s *Session) exitErr(err error) {
+	s.shutdownErr = err
+	s.Close()
+}
+
 // GoAway can be used to prevent accepting further
 // connections. It does not close the underlying conn.
 func (s *Session) GoAway() error {
+	return s.waitForSend(s.goAway(goAwayNormal), nil)
+}
+
+// goAway is used to send a goAway message
+func (s *Session) goAway(reason uint32) header {
 	atomic.SwapInt32(&s.localGoAway, 1)
-	s.goAway(goAwayNormal)
-	return nil
+	hdr := header(make([]byte, headerSize))
+	hdr.encode(typeGoAway, 0, 0, reason)
+	return hdr
 }
 
 // Ping is used to measure the RTT response time
@@ -249,7 +262,7 @@ func (s *Session) sendNoWait(hdr header) error {
 
 // send is a long running goroutine that sends data
 func (s *Session) send() {
-	for {
+	for !s.IsClosed() {
 		select {
 		case ready := <-s.sendCh:
 			// Send a header if ready
@@ -258,8 +271,8 @@ func (s *Session) send() {
 				for sent < len(ready.Hdr) {
 					n, err := s.conn.Write(ready.Hdr[sent:])
 					if err != nil {
-						s.exitErr(err)
 						asyncSendErr(ready.Err, err)
+						s.exitErr(err)
 						return
 					}
 					sent += n
@@ -270,8 +283,8 @@ func (s *Session) send() {
 			if ready.Body != nil {
 				_, err := io.Copy(s.conn, ready.Body)
 				if err != nil {
-					s.exitErr(err)
 					asyncSendErr(ready.Err, err)
+					s.exitErr(err)
 					return
 				}
 			}
@@ -287,6 +300,7 @@ func (s *Session) send() {
 // recv is a long running goroutine that accepts new data
 func (s *Session) recv() {
 	hdr := header(make([]byte, headerSize))
+	var handler func(header) error
 	for !s.IsClosed() {
 		// Read the header
 		if _, err := io.ReadFull(s.conn, hdr); err != nil {
@@ -301,29 +315,25 @@ func (s *Session) recv() {
 		}
 
 		// Switch on the type
-		msgType := hdr.MsgType()
-		switch msgType {
+		switch hdr.MsgType() {
 		case typeData:
-			fallthrough
+			handler = s.handleStreamMessage
 		case typeWindowUpdate:
-			if err := s.handleStreamMessage(hdr); err != nil {
-				s.exitErr(err)
-				return
-			}
+			handler = s.handleStreamMessage
 		case typeGoAway:
-			if err := s.handleGoAway(hdr); err != nil {
-				s.exitErr(err)
-				return
-			}
+			handler = s.handleGoAway
 		case typePing:
-			if err := s.handlePing(hdr); err != nil {
-				s.exitErr(err)
-				return
-			}
+			handler = s.handlePing
 		default:
 			s.exitErr(ErrInvalidMsgType)
 			return
 		}
+
+		// Invoke the handler
+		if err := handler(hdr); err != nil {
+			s.exitErr(err)
+			return
+		}
 	}
 }
 
@@ -339,27 +349,28 @@ func (s *Session) handleStreamMessage(hdr header) error {
 	}
 
 	// Get the stream
-	s.streamLock.RLock()
+	s.streamLock.Lock()
 	stream := s.streams[id]
-	s.streamLock.RUnlock()
+	s.streamLock.Unlock()
 
 	// Make sure we have a stream
 	if stream == nil {
-		s.goAway(goAwayProtoErr)
+		s.sendNoWait(s.goAway(goAwayProtoErr))
 		return ErrMissingStream
 	}
 
 	// Check if this is a window update
 	if hdr.MsgType() == typeWindowUpdate {
 		if err := stream.incrSendWindow(hdr, flags); err != nil {
-			s.goAway(goAwayProtoErr)
+			s.sendNoWait(s.goAway(goAwayProtoErr))
 			return err
 		}
+		return nil
 	}
 
 	// Read the new data
 	if err := stream.readData(hdr, flags, s.conn); err != nil {
-		s.goAway(goAwayProtoErr)
+		s.sendNoWait(s.goAway(goAwayProtoErr))
 		return err
 	}
 	return nil
@@ -405,28 +416,13 @@ func (s *Session) handleGoAway(hdr header) error {
 	return nil
 }
 
-// exitErr is used to handle an error that is causing
-// the listener to exit.
-func (s *Session) exitErr(err error) {
-	s.shutdownErr = err
-	s.Close()
-}
-
-// goAway is used to send a goAway message
-func (s *Session) goAway(reason uint32) {
-	hdr := header(make([]byte, headerSize))
-	hdr.encode(typeGoAway, 0, 0, reason)
-	s.sendNoWait(hdr)
-}
-
 // incomingStream is used to create a new incoming stream
 func (s *Session) incomingStream(id uint32) error {
 	// Reject immediately if we are doing a go away
 	if atomic.LoadInt32(&s.localGoAway) == 1 {
 		hdr := header(make([]byte, headerSize))
 		hdr.encode(typeWindowUpdate, flagRST, id, 0)
-		s.sendNoWait(hdr)
-		return nil
+		return s.waitForSend(hdr, nil)
 	}
 
 	s.streamLock.Lock()
@@ -434,9 +430,8 @@ func (s *Session) incomingStream(id uint32) error {
 
 	// Check if stream already exists
 	if _, ok := s.streams[id]; ok {
-		s.goAway(goAwayProtoErr)
-		s.exitErr(ErrDuplicateStream)
-		return nil
+		s.sendNoWait(s.goAway(goAwayProtoErr))
+		return ErrDuplicateStream
 	}
 
 	// Register the stream

+ 23 - 25
stream.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"compress/lzw"
 	"io"
-	"log"
 	"sync"
 	"time"
 )
@@ -218,7 +217,6 @@ func (s *Stream) sendWindowUpdate() error {
 	if err := s.session.waitForSend(s.sendHdr, nil); err != nil {
 		return err
 	}
-	log.Printf("Window Update %d +%d", s.id, delta)
 
 	// Update our window
 	s.recvWindow += delta
@@ -276,29 +274,6 @@ func (s *Stream) forceClose() {
 	asyncNotify(s.notifyCh)
 }
 
-// SetDeadline sets the read and write deadlines
-func (s *Stream) SetDeadline(t time.Time) error {
-	if err := s.SetReadDeadline(t); err != nil {
-		return err
-	}
-	if err := s.SetWriteDeadline(t); err != nil {
-		return err
-	}
-	return nil
-}
-
-// SetReadDeadline sets the deadline for future Read calls.
-func (s *Stream) SetReadDeadline(t time.Time) error {
-	s.readDeadline = t
-	return nil
-}
-
-// SetWriteDeadline sets the deadline for future Write calls
-func (s *Stream) SetWriteDeadline(t time.Time) error {
-	s.writeDeadline = t
-	return nil
-}
-
 // processFlags is used to update the state of the stream
 // based on set flags, if any. Lock must be held
 func (s *Stream) processFlags(flags uint16) error {
@@ -378,3 +353,26 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
 	asyncNotify(s.notifyCh)
 	return nil
 }
+
+// SetDeadline sets the read and write deadlines
+func (s *Stream) SetDeadline(t time.Time) error {
+	if err := s.SetReadDeadline(t); err != nil {
+		return err
+	}
+	if err := s.SetWriteDeadline(t); err != nil {
+		return err
+	}
+	return nil
+}
+
+// SetReadDeadline sets the deadline for future Read calls.
+func (s *Stream) SetReadDeadline(t time.Time) error {
+	s.readDeadline = t
+	return nil
+}
+
+// SetWriteDeadline sets the deadline for future Write calls
+func (s *Stream) SetWriteDeadline(t time.Time) error {
+	s.writeDeadline = t
+	return nil
+}