123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374 |
- package yamux
- import (
- "fmt"
- "io"
- "net"
- "sync"
- "time"
- )
// Sentinel errors reported by a Session.
var (
	// ErrInvalidVersion is reported when a received frame carries a
	// protocol version other than the one we speak.
	ErrInvalidVersion = fmt.Errorf("invalid protocol version")

	// ErrInvalidMsgType is reported when a received frame carries a
	// message type we do not recognise.
	ErrInvalidMsgType = fmt.Errorf("invalid msg type")

	// ErrSessionShutdown is reported when the session shuts down while
	// an operation is still in progress.
	ErrSessionShutdown = fmt.Errorf("session shutdown")
)
- // Session is used to wrap a reliable ordered connection and to
- // multiplex it into multiple streams.
- type Session struct {
- // client is true if we are a client size connection
- client bool
- // config holds our configuration
- config *Config
- // conn is the underlying connection
- conn io.ReadWriteCloser
- // nextStreamID is the next stream we should
- // send. This depends if we are a client/server.
- nextStreamID uint32
- // pings is used to track inflight pings
- pings map[uint32]chan struct{}
- pingID uint32
- pingLock sync.Mutex
- // streams maps a stream id to a stream
- streams map[uint32]*Stream
- // acceptCh is used to pass ready streams to the client
- acceptCh chan *Stream
- // sendCh is used to mark a stream as ready to send,
- // or to send a header out directly.
- sendCh chan sendReady
- // shutdown is used to safely close a session
- shutdown bool
- shutdownErr error
- shutdownCh chan struct{}
- shutdownLock sync.Mutex
- }
// hasAddr is implemented by underlying connections that can report their
// own local and remote addresses.
type hasAddr interface {
	// LocalAddr returns the local end of the connection.
	LocalAddr() net.Addr
	// RemoteAddr returns the remote end of the connection.
	RemoteAddr() net.Addr
}
// yamuxAddr is a placeholder address used when the underlying connection
// cannot report a real one.
type yamuxAddr struct {
	// Addr is the label rendered after the "yamux:" prefix by String.
	Addr string
}

// Network returns the fixed network name "yamux".
func (*yamuxAddr) Network() string {
	return "yamux"
}

// String renders the address as "yamux:<Addr>".
func (a *yamuxAddr) String() string {
	return fmt.Sprintf("yamux:%s", a.Addr)
}
// sendReady is a unit of work for the send goroutine: it either marks a
// stream as ready to send or carries a header to be written directly.
type sendReady struct {
	// StreamID identifies the stream that is ready; zero means no stream.
	StreamID uint32
	// Hdr is a header to write out as-is; nil means there is none.
	Hdr []byte
	// Err, when set by the submitter, receives the outcome of the send.
	Err chan error
}
- // newSession is used to construct a new session
- func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
- s := &Session{
- client: client,
- config: config,
- conn: conn,
- pings: make(map[uint32]chan struct{}),
- streams: make(map[uint32]*Stream),
- acceptCh: make(chan *Stream, config.AcceptBacklog),
- sendCh: make(chan sendReady, 64),
- shutdownCh: make(chan struct{}),
- }
- if client {
- s.nextStreamID = 1
- } else {
- s.nextStreamID = 2
- }
- go s.recv()
- go s.send()
- if config.EnableKeepAlive {
- go s.keepalive()
- }
- return s
- }
- // Open is used to create a new stream
- func (s *Session) Open() (*Stream, error) {
- return nil, nil
- }
- // Accept is used to block until the next available stream
- // is ready to be accepted.
- func (s *Session) Accept() (net.Conn, error) {
- return s.AcceptStream()
- }
- // AcceptStream is used to block until the next available stream
- // is ready to be accepted.
- func (s *Session) AcceptStream() (*Stream, error) {
- select {
- case stream := <-s.acceptCh:
- return stream, nil
- case <-s.shutdownCh:
- return nil, s.shutdownErr
- }
- }
- // Close is used to close the session and all streams.
- // Attempts to send a GoAway before closing the connection.
- func (s *Session) Close() error {
- s.shutdownLock.Lock()
- defer s.shutdownLock.Unlock()
- if s.shutdown {
- return nil
- }
- s.shutdown = true
- close(s.shutdownCh)
- s.conn.Close()
- return nil
- }
- // Addr is used to get the address of the listener.
- func (s *Session) Addr() net.Addr {
- return s.LocalAddr()
- }
- // LocalAddr is used to get the local address of the
- // underlying connection.
- func (s *Session) LocalAddr() net.Addr {
- addr, ok := s.conn.(hasAddr)
- if !ok {
- return &yamuxAddr{"local"}
- }
- return addr.LocalAddr()
- }
- // RemoteAddr is used to get the address of remote end
- // of the underlying connection
- func (s *Session) RemoteAddr() net.Addr {
- addr, ok := s.conn.(hasAddr)
- if !ok {
- return &yamuxAddr{"remote"}
- }
- return addr.RemoteAddr()
- }
- // Ping is used to measure the RTT response time
- func (s *Session) Ping() (time.Duration, error) {
- // Get a channel for the ping
- ch := make(chan struct{})
- // Get a new ping id, mark as pending
- s.pingLock.Lock()
- id := s.pingID
- s.pingID++
- s.pings[id] = ch
- s.pingLock.Unlock()
- // Send the ping request
- hdr := header(make([]byte, headerSize))
- hdr.encode(typePing, flagSYN, 0, id)
- if err := s.waitForSend(hdr); err != nil {
- return 0, err
- }
- // Wait for a response
- start := time.Now()
- select {
- case <-ch:
- case <-s.shutdownCh:
- return 0, ErrSessionShutdown
- }
- // Compute the RTT
- return time.Now().Sub(start), nil
- }
- // keepalive is a long running goroutine that periodically does
- // a ping to keep the connection alive.
- func (s *Session) keepalive() {
- for {
- select {
- case <-time.After(s.config.KeepAliveInterval):
- s.Ping()
- case <-s.shutdownCh:
- return
- }
- }
- }
- // waitForSend waits to send a header, checking for a potential shutdown
- func (s *Session) waitForSend(hdr header) error {
- errCh := make(chan error, 1)
- ready := sendReady{Hdr: hdr, Err: errCh}
- select {
- case s.sendCh <- ready:
- case <-s.shutdownCh:
- return ErrSessionShutdown
- }
- select {
- case err := <-errCh:
- return err
- case <-s.shutdownCh:
- return ErrSessionShutdown
- }
- }
- // sendNoWait does a send without waiting
- func (s *Session) sendNoWait(hdr header) error {
- select {
- case s.sendCh <- sendReady{Hdr: hdr}:
- return nil
- case <-s.shutdownCh:
- return ErrSessionShutdown
- }
- }
- // send is a long running goroutine that sends data
- func (s *Session) send() {
- for {
- select {
- case ready := <-s.sendCh:
- // Send data from a stream if ready
- if ready.StreamID != 0 {
- }
- // Send a header if ready
- if ready.Hdr != nil {
- sent := 0
- for sent < len(ready.Hdr) {
- n, err := s.conn.Write(ready.Hdr[sent:])
- if err != nil {
- s.exitErr(err)
- asyncSendErr(ready.Err, err)
- }
- sent += n
- }
- }
- asyncSendErr(ready.Err, nil)
- case <-s.shutdownCh:
- return
- }
- }
- }
- // recv is a long running goroutine that accepts new data
- func (s *Session) recv() {
- hdr := header(make([]byte, headerSize))
- for {
- // Read the header
- if _, err := io.ReadFull(s.conn, hdr); err != nil {
- s.exitErr(err)
- return
- }
- // Verify the version
- if hdr.Version() != protoVersion {
- s.exitErr(ErrInvalidVersion)
- return
- }
- // Switch on the type
- msgType := hdr.MsgType()
- switch msgType {
- case typeData:
- s.handleData(hdr)
- case typeWindowUpdate:
- s.handleWindowUpdate(hdr)
- case typePing:
- s.handlePing(hdr)
- case typeGoAway:
- s.handleGoAway(hdr)
- default:
- s.exitErr(ErrInvalidMsgType)
- return
- }
- }
- }
- // handleData is invokde for a typeData frame
- func (s *Session) handleData(hdr header) {
- flags := hdr.Flags()
- // Check for a new stream creation
- if flags&flagSYN == flagSYN {
- s.createStream(hdr.StreamID())
- }
- }
- // handleWindowUpdate is invokde for a typeWindowUpdate frame
- func (s *Session) handleWindowUpdate(hdr header) {
- flags := hdr.Flags()
- // Check for a new stream creation
- if flags&flagSYN == flagSYN {
- s.createStream(hdr.StreamID())
- }
- }
- // handlePing is invokde for a typePing frame
- func (s *Session) handlePing(hdr header) {
- flags := hdr.Flags()
- pingID := hdr.Length()
- // Check if this is a query, respond back
- if flags&flagSYN == flagSYN {
- hdr := header(make([]byte, headerSize))
- hdr.encode(typePing, flagACK, 0, pingID)
- s.sendNoWait(hdr)
- return
- }
- // Handle a response
- s.pingLock.Lock()
- ch := s.pings[pingID]
- if ch != nil {
- delete(s.pings, pingID)
- close(ch)
- }
- s.pingLock.Unlock()
- }
- // handleGoAway is invokde for a typeGoAway frame
- func (s *Session) handleGoAway(hdr header) {
- }
- // exitErr is used to handle an error that is causing
- // the listener to exit.
- func (s *Session) exitErr(err error) {
- }
- // goAway is used to send a goAway message
- func (s *Session) goAway(reason uint32) {
- hdr := header(make([]byte, headerSize))
- hdr.encode(typeGoAway, 0, 0, reason)
- s.sendNoWait(hdr)
- }
- // createStream is used to create a new stream
- func (s *Session) createStream(id uint32) {
- // TODO
- }
|