session_test.go 31 KB


  1. package yamux
  2. import (
  3. "bytes"
  4. "context"
  5. "fmt"
  6. "io"
  7. "io/ioutil"
  8. "log"
  9. "net"
  10. "reflect"
  11. "runtime"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "time"
  16. )
  // logCapture is an in-memory sink for log output. It embeds bytes.Buffer, so
  // it can be handed to log.New as the destination writer.
  type logCapture struct{ bytes.Buffer }

  // logs returns the captured output split into lines. Leading and trailing
  // whitespace of the whole buffer is trimmed before splitting on "\n".
  func (l *logCapture) logs() []string {
  	trimmed := strings.TrimSpace(l.String())
  	return strings.Split(trimmed, "\n")
  }

  // match reports whether the captured lines are deeply equal to expect.
  func (l *logCapture) match(expect []string) bool {
  	captured := l.logs()
  	return reflect.DeepEqual(captured, expect)
  }
  24. func captureLogs(s *Session) *logCapture {
  25. buf := new(logCapture)
  26. s.logger = log.New(buf, "", 0)
  27. return buf
  28. }
  // pipeConn is an in-memory io.ReadWriteCloser built from one io.Pipe reader
  // and one io.Pipe writer. Tests lock writeBlocker to stall every Write.
  type pipeConn struct {
  	reader       *io.PipeReader
  	writer       *io.PipeWriter
  	writeBlocker sync.Mutex
  }

  // Read reads from the underlying pipe reader.
  func (p *pipeConn) Read(b []byte) (int, error) {
  	return p.reader.Read(b)
  }

  // Write forwards b to the underlying pipe writer while holding writeBlocker,
  // so it blocks for as long as another holder keeps that mutex locked.
  func (p *pipeConn) Write(b []byte) (int, error) {
  	p.writeBlocker.Lock()
  	defer p.writeBlocker.Unlock()

  	written, err := p.writer.Write(b)
  	return written, err
  }

  // Close closes the read half first, then the write half, and returns the
  // result of closing the write half.
  func (p *pipeConn) Close() error {
  	_ = p.reader.Close()
  	return p.writer.Close()
  }

  // testConn returns two cross-connected pipeConns: whatever is written to one
  // is read from the other.
  func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  	aRead, bWrite := io.Pipe()
  	bRead, aWrite := io.Pipe()

  	sideA := &pipeConn{reader: aRead, writer: aWrite}
  	sideB := &pipeConn{reader: bRead, writer: bWrite}
  	return sideA, sideB
  }
  53. func testConf() *Config {
  54. conf := DefaultConfig()
  55. conf.AcceptBacklog = 64
  56. conf.KeepAliveInterval = 100 * time.Millisecond
  57. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  58. return conf
  59. }
  60. func testConfNoKeepAlive() *Config {
  61. conf := testConf()
  62. conf.EnableKeepAlive = false
  63. return conf
  64. }
  65. func testClientServer() (*Session, *Session) {
  66. return testClientServerConfig(testConf())
  67. }
  68. func testClientServerConfig(conf *Config) (*Session, *Session) {
  69. conn1, conn2 := testConn()
  70. client, _ := Client(conn1, conf)
  71. server, _ := Server(conn2, conf)
  72. return client, server
  73. }
  74. func TestPing(t *testing.T) {
  75. client, server := testClientServer()
  76. defer client.Close()
  77. defer server.Close()
  78. rtt, err := client.Ping()
  79. if err != nil {
  80. t.Fatalf("err: %v", err)
  81. }
  82. if rtt == 0 {
  83. t.Fatalf("bad: %v", rtt)
  84. }
  85. rtt, err = server.Ping()
  86. if err != nil {
  87. t.Fatalf("err: %v", err)
  88. }
  89. if rtt == 0 {
  90. t.Fatalf("bad: %v", rtt)
  91. }
  92. }
  93. func TestPing_Timeout(t *testing.T) {
  94. client, server := testClientServerConfig(testConfNoKeepAlive())
  95. defer client.Close()
  96. defer server.Close()
  97. // Prevent the client from responding
  98. clientConn := client.conn.(*pipeConn)
  99. clientConn.writeBlocker.Lock()
  100. errCh := make(chan error, 1)
  101. go func() {
  102. _, err := server.Ping() // Ping via the server session
  103. errCh <- err
  104. }()
  105. select {
  106. case err := <-errCh:
  107. if err != ErrTimeout {
  108. t.Fatalf("err: %v", err)
  109. }
  110. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  111. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  112. }
  113. // Verify that we recover, even if we gave up
  114. clientConn.writeBlocker.Unlock()
  115. go func() {
  116. _, err := server.Ping() // Ping via the server session
  117. errCh <- err
  118. }()
  119. select {
  120. case err := <-errCh:
  121. if err != nil {
  122. t.Fatalf("err: %v", err)
  123. }
  124. case <-time.After(client.config.ConnectionWriteTimeout):
  125. t.Fatalf("timeout")
  126. }
  127. }
  128. func TestCloseBeforeAck(t *testing.T) {
  129. cfg := testConf()
  130. cfg.AcceptBacklog = 8
  131. client, server := testClientServerConfig(cfg)
  132. defer client.Close()
  133. defer server.Close()
  134. for i := 0; i < 8; i++ {
  135. s, err := client.OpenStream()
  136. if err != nil {
  137. t.Fatal(err)
  138. }
  139. s.Close()
  140. }
  141. for i := 0; i < 8; i++ {
  142. s, err := server.AcceptStream()
  143. if err != nil {
  144. t.Fatal(err)
  145. }
  146. s.Close()
  147. }
  148. done := make(chan struct{})
  149. go func() {
  150. defer close(done)
  151. s, err := client.OpenStream()
  152. if err != nil {
  153. t.Fatal(err)
  154. }
  155. s.Close()
  156. }()
  157. select {
  158. case <-done:
  159. case <-time.After(time.Second * 5):
  160. t.Fatal("timed out trying to open stream")
  161. }
  162. }
  163. func TestAccept(t *testing.T) {
  164. client, server := testClientServer()
  165. defer client.Close()
  166. defer server.Close()
  167. if client.NumStreams() != 0 {
  168. t.Fatalf("bad")
  169. }
  170. if server.NumStreams() != 0 {
  171. t.Fatalf("bad")
  172. }
  173. wg := &sync.WaitGroup{}
  174. wg.Add(4)
  175. go func() {
  176. defer wg.Done()
  177. stream, err := server.AcceptStream()
  178. if err != nil {
  179. t.Fatalf("err: %v", err)
  180. }
  181. if id := stream.StreamID(); id != 1 {
  182. t.Fatalf("bad: %v", id)
  183. }
  184. if err := stream.Close(); err != nil {
  185. t.Fatalf("err: %v", err)
  186. }
  187. }()
  188. go func() {
  189. defer wg.Done()
  190. stream, err := client.AcceptStream()
  191. if err != nil {
  192. t.Fatalf("err: %v", err)
  193. }
  194. if id := stream.StreamID(); id != 2 {
  195. t.Fatalf("bad: %v", id)
  196. }
  197. if err := stream.Close(); err != nil {
  198. t.Fatalf("err: %v", err)
  199. }
  200. }()
  201. go func() {
  202. defer wg.Done()
  203. stream, err := server.OpenStream()
  204. if err != nil {
  205. t.Fatalf("err: %v", err)
  206. }
  207. if id := stream.StreamID(); id != 2 {
  208. t.Fatalf("bad: %v", id)
  209. }
  210. if err := stream.Close(); err != nil {
  211. t.Fatalf("err: %v", err)
  212. }
  213. }()
  214. go func() {
  215. defer wg.Done()
  216. stream, err := client.OpenStream()
  217. if err != nil {
  218. t.Fatalf("err: %v", err)
  219. }
  220. if id := stream.StreamID(); id != 1 {
  221. t.Fatalf("bad: %v", id)
  222. }
  223. if err := stream.Close(); err != nil {
  224. t.Fatalf("err: %v", err)
  225. }
  226. }()
  227. doneCh := make(chan struct{})
  228. go func() {
  229. wg.Wait()
  230. close(doneCh)
  231. }()
  232. select {
  233. case <-doneCh:
  234. case <-time.After(time.Second):
  235. panic("timeout")
  236. }
  237. }
  238. func TestOpenStreamTimeout(t *testing.T) {
  239. const timeout = 25 * time.Millisecond
  240. cfg := testConf()
  241. cfg.StreamOpenTimeout = timeout
  242. client, server := testClientServerConfig(cfg)
  243. defer client.Close()
  244. defer server.Close()
  245. clientLogs := captureLogs(client)
  246. // Open a single stream without a server to acknowledge it.
  247. s, err := client.OpenStream()
  248. if err != nil {
  249. t.Fatal(err)
  250. }
  251. // Sleep for longer than the stream open timeout.
  252. // Since no ACKs are received, the stream and session should be closed.
  253. time.Sleep(timeout * 5)
  254. if !clientLogs.match([]string{"[ERR] yamux: aborted stream open (destination=yamux:remote): i/o deadline reached"}) {
  255. t.Fatalf("server log incorect: %v", clientLogs.logs())
  256. }
  257. if s.state != streamClosed {
  258. t.Fatalf("stream should have been closed")
  259. }
  260. if !client.IsClosed() {
  261. t.Fatalf("session should have been closed")
  262. }
  263. }
  264. func TestClose_closeTimeout(t *testing.T) {
  265. conf := testConf()
  266. conf.StreamCloseTimeout = 10 * time.Millisecond
  267. client, server := testClientServerConfig(conf)
  268. defer client.Close()
  269. defer server.Close()
  270. if client.NumStreams() != 0 {
  271. t.Fatalf("bad")
  272. }
  273. if server.NumStreams() != 0 {
  274. t.Fatalf("bad")
  275. }
  276. wg := &sync.WaitGroup{}
  277. wg.Add(2)
  278. // Open a stream on the client but only close it on the server.
  279. // We want to see if the stream ever gets cleaned up on the client.
  280. var clientStream *Stream
  281. go func() {
  282. defer wg.Done()
  283. var err error
  284. clientStream, err = client.OpenStream()
  285. if err != nil {
  286. t.Fatalf("err: %v", err)
  287. }
  288. }()
  289. go func() {
  290. defer wg.Done()
  291. stream, err := server.AcceptStream()
  292. if err != nil {
  293. t.Fatalf("err: %v", err)
  294. }
  295. if err := stream.Close(); err != nil {
  296. t.Fatalf("err: %v", err)
  297. }
  298. }()
  299. doneCh := make(chan struct{})
  300. go func() {
  301. wg.Wait()
  302. close(doneCh)
  303. }()
  304. select {
  305. case <-doneCh:
  306. case <-time.After(time.Second):
  307. panic("timeout")
  308. }
  309. // We should have zero streams after our timeout period
  310. time.Sleep(100 * time.Millisecond)
  311. if v := server.NumStreams(); v > 0 {
  312. t.Fatalf("should have zero streams: %d", v)
  313. }
  314. if v := client.NumStreams(); v > 0 {
  315. t.Fatalf("should have zero streams: %d", v)
  316. }
  317. if _, err := clientStream.Write([]byte("hello")); err == nil {
  318. t.Fatal("should error on write")
  319. } else if err.Error() != "connection reset" {
  320. t.Fatalf("expected connection reset, got %q", err)
  321. }
  322. }
  323. func TestNonNilInterface(t *testing.T) {
  324. _, server := testClientServer()
  325. server.Close()
  326. conn, err := server.Accept()
  327. if err != nil && conn != nil {
  328. t.Error("bad: accept should return a connection of nil value")
  329. }
  330. conn, err = server.Open()
  331. if err != nil && conn != nil {
  332. t.Error("bad: open should return a connection of nil value")
  333. }
  334. }
  335. func TestSendData_Small(t *testing.T) {
  336. client, server := testClientServer()
  337. defer client.Close()
  338. defer server.Close()
  339. wg := &sync.WaitGroup{}
  340. wg.Add(2)
  341. go func() {
  342. defer wg.Done()
  343. stream, err := server.AcceptStream()
  344. if err != nil {
  345. t.Fatalf("err: %v", err)
  346. }
  347. if server.NumStreams() != 1 {
  348. t.Fatalf("bad")
  349. }
  350. buf := make([]byte, 4)
  351. for i := 0; i < 1000; i++ {
  352. n, err := stream.Read(buf)
  353. if err != nil {
  354. t.Fatalf("err: %v", err)
  355. }
  356. if n != 4 {
  357. t.Fatalf("short read: %d", n)
  358. }
  359. if string(buf) != "test" {
  360. t.Fatalf("bad: %s", buf)
  361. }
  362. }
  363. if err := stream.Close(); err != nil {
  364. t.Fatalf("err: %v", err)
  365. }
  366. }()
  367. go func() {
  368. defer wg.Done()
  369. stream, err := client.Open()
  370. if err != nil {
  371. t.Fatalf("err: %v", err)
  372. }
  373. if client.NumStreams() != 1 {
  374. t.Fatalf("bad")
  375. }
  376. for i := 0; i < 1000; i++ {
  377. n, err := stream.Write([]byte("test"))
  378. if err != nil {
  379. t.Fatalf("err: %v", err)
  380. }
  381. if n != 4 {
  382. t.Fatalf("short write %d", n)
  383. }
  384. }
  385. if err := stream.Close(); err != nil {
  386. t.Fatalf("err: %v", err)
  387. }
  388. }()
  389. doneCh := make(chan struct{})
  390. go func() {
  391. wg.Wait()
  392. close(doneCh)
  393. }()
  394. select {
  395. case <-doneCh:
  396. if client.NumStreams() != 0 {
  397. t.Fatalf("bad")
  398. }
  399. if server.NumStreams() != 0 {
  400. t.Fatalf("bad")
  401. }
  402. return
  403. case <-time.After(time.Second):
  404. panic("timeout")
  405. }
  406. }
  407. func TestSendData_Large(t *testing.T) {
  408. client, server := testClientServer()
  409. defer client.Close()
  410. defer server.Close()
  411. const (
  412. sendSize = 250 * 1024 * 1024
  413. recvSize = 4 * 1024
  414. )
  415. data := make([]byte, sendSize)
  416. for idx := range data {
  417. data[idx] = byte(idx % 256)
  418. }
  419. wg := &sync.WaitGroup{}
  420. wg.Add(2)
  421. go func() {
  422. defer wg.Done()
  423. stream, err := server.AcceptStream()
  424. if err != nil {
  425. t.Fatalf("err: %v", err)
  426. }
  427. var sz int
  428. buf := make([]byte, recvSize)
  429. for i := 0; i < sendSize/recvSize; i++ {
  430. n, err := stream.Read(buf)
  431. if err != nil {
  432. t.Fatalf("err: %v", err)
  433. }
  434. if n != recvSize {
  435. t.Fatalf("short read: %d", n)
  436. }
  437. sz += n
  438. for idx := range buf {
  439. if buf[idx] != byte(idx%256) {
  440. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  441. }
  442. }
  443. }
  444. if err := stream.Close(); err != nil {
  445. t.Fatalf("err: %v", err)
  446. }
  447. t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
  448. }()
  449. go func() {
  450. defer wg.Done()
  451. stream, err := client.Open()
  452. if err != nil {
  453. t.Fatalf("err: %v", err)
  454. }
  455. n, err := stream.Write(data)
  456. if err != nil {
  457. t.Fatalf("err: %v", err)
  458. }
  459. if n != len(data) {
  460. t.Fatalf("short write %d", n)
  461. }
  462. if err := stream.Close(); err != nil {
  463. t.Fatalf("err: %v", err)
  464. }
  465. }()
  466. doneCh := make(chan struct{})
  467. go func() {
  468. wg.Wait()
  469. close(doneCh)
  470. }()
  471. select {
  472. case <-doneCh:
  473. return
  474. case <-time.After(5 * time.Second):
  475. panic("timeout")
  476. }
  477. }
  478. func TestGoAway(t *testing.T) {
  479. client, server := testClientServer()
  480. defer client.Close()
  481. defer server.Close()
  482. if err := server.GoAway(); err != nil {
  483. t.Fatalf("err: %v", err)
  484. }
  485. _, err := client.Open()
  486. if err != ErrRemoteGoAway {
  487. t.Fatalf("err: %v", err)
  488. }
  489. }
  490. func TestManyStreams(t *testing.T) {
  491. client, server := testClientServer()
  492. defer client.Close()
  493. defer server.Close()
  494. wg := &sync.WaitGroup{}
  495. acceptor := func(i int) {
  496. defer wg.Done()
  497. stream, err := server.AcceptStream()
  498. if err != nil {
  499. t.Fatalf("err: %v", err)
  500. }
  501. defer stream.Close()
  502. buf := make([]byte, 512)
  503. for {
  504. n, err := stream.Read(buf)
  505. if err == io.EOF {
  506. return
  507. }
  508. if err != nil {
  509. t.Fatalf("err: %v", err)
  510. }
  511. if n == 0 {
  512. t.Fatalf("err: %v", err)
  513. }
  514. }
  515. }
  516. sender := func(i int) {
  517. defer wg.Done()
  518. stream, err := client.Open()
  519. if err != nil {
  520. t.Fatalf("err: %v", err)
  521. }
  522. defer stream.Close()
  523. msg := fmt.Sprintf("%08d", i)
  524. for i := 0; i < 1000; i++ {
  525. n, err := stream.Write([]byte(msg))
  526. if err != nil {
  527. t.Fatalf("err: %v", err)
  528. }
  529. if n != len(msg) {
  530. t.Fatalf("short write %d", n)
  531. }
  532. }
  533. }
  534. for i := 0; i < 50; i++ {
  535. wg.Add(2)
  536. go acceptor(i)
  537. go sender(i)
  538. }
  539. wg.Wait()
  540. }
  541. func TestManyStreams_PingPong(t *testing.T) {
  542. client, server := testClientServer()
  543. defer client.Close()
  544. defer server.Close()
  545. wg := &sync.WaitGroup{}
  546. ping := []byte("ping")
  547. pong := []byte("pong")
  548. acceptor := func(i int) {
  549. defer wg.Done()
  550. stream, err := server.AcceptStream()
  551. if err != nil {
  552. t.Fatalf("err: %v", err)
  553. }
  554. defer stream.Close()
  555. buf := make([]byte, 4)
  556. for {
  557. // Read the 'ping'
  558. n, err := stream.Read(buf)
  559. if err == io.EOF {
  560. return
  561. }
  562. if err != nil {
  563. t.Fatalf("err: %v", err)
  564. }
  565. if n != 4 {
  566. t.Fatalf("err: %v", err)
  567. }
  568. if !bytes.Equal(buf, ping) {
  569. t.Fatalf("bad: %s", buf)
  570. }
  571. // Shrink the internal buffer!
  572. stream.Shrink()
  573. // Write out the 'pong'
  574. n, err = stream.Write(pong)
  575. if err != nil {
  576. t.Fatalf("err: %v", err)
  577. }
  578. if n != 4 {
  579. t.Fatalf("err: %v", err)
  580. }
  581. }
  582. }
  583. sender := func(i int) {
  584. defer wg.Done()
  585. stream, err := client.OpenStream()
  586. if err != nil {
  587. t.Fatalf("err: %v", err)
  588. }
  589. defer stream.Close()
  590. buf := make([]byte, 4)
  591. for i := 0; i < 1000; i++ {
  592. // Send the 'ping'
  593. n, err := stream.Write(ping)
  594. if err != nil {
  595. t.Fatalf("err: %v", err)
  596. }
  597. if n != 4 {
  598. t.Fatalf("short write %d", n)
  599. }
  600. // Read the 'pong'
  601. n, err = stream.Read(buf)
  602. if err != nil {
  603. t.Fatalf("err: %v", err)
  604. }
  605. if n != 4 {
  606. t.Fatalf("err: %v", err)
  607. }
  608. if !bytes.Equal(buf, pong) {
  609. t.Fatalf("bad: %s", buf)
  610. }
  611. // Shrink the buffer
  612. stream.Shrink()
  613. }
  614. }
  615. for i := 0; i < 50; i++ {
  616. wg.Add(2)
  617. go acceptor(i)
  618. go sender(i)
  619. }
  620. wg.Wait()
  621. }
  622. func TestHalfClose(t *testing.T) {
  623. client, server := testClientServer()
  624. defer client.Close()
  625. defer server.Close()
  626. stream, err := client.Open()
  627. if err != nil {
  628. t.Fatalf("err: %v", err)
  629. }
  630. if _, err = stream.Write([]byte("a")); err != nil {
  631. t.Fatalf("err: %v", err)
  632. }
  633. stream2, err := server.Accept()
  634. if err != nil {
  635. t.Fatalf("err: %v", err)
  636. }
  637. stream2.Close() // Half close
  638. buf := make([]byte, 4)
  639. n, err := stream2.Read(buf)
  640. if err != nil {
  641. t.Fatalf("err: %v", err)
  642. }
  643. if n != 1 {
  644. t.Fatalf("bad: %v", n)
  645. }
  646. // Send more
  647. if _, err = stream.Write([]byte("bcd")); err != nil {
  648. t.Fatalf("err: %v", err)
  649. }
  650. stream.Close()
  651. // Read after close
  652. n, err = stream2.Read(buf)
  653. if err != nil {
  654. t.Fatalf("err: %v", err)
  655. }
  656. if n != 3 {
  657. t.Fatalf("bad: %v", n)
  658. }
  659. // EOF after close
  660. n, err = stream2.Read(buf)
  661. if err != io.EOF {
  662. t.Fatalf("err: %v", err)
  663. }
  664. if n != 0 {
  665. t.Fatalf("bad: %v", n)
  666. }
  667. }
  668. func TestHalfCloseSessionShutdown(t *testing.T) {
  669. client, server := testClientServer()
  670. defer client.Close()
  671. defer server.Close()
  672. // dataSize must be large enough to ensure the server will send a window
  673. // update
  674. dataSize := int64(server.config.MaxStreamWindowSize)
  675. data := make([]byte, dataSize)
  676. for idx := range data {
  677. data[idx] = byte(idx % 256)
  678. }
  679. stream, err := client.Open()
  680. if err != nil {
  681. t.Fatalf("err: %v", err)
  682. }
  683. if _, err = stream.Write(data); err != nil {
  684. t.Fatalf("err: %v", err)
  685. }
  686. stream2, err := server.Accept()
  687. if err != nil {
  688. t.Fatalf("err: %v", err)
  689. }
  690. if err := stream.Close(); err != nil {
  691. t.Fatalf("err: %v", err)
  692. }
  693. // Shut down the session of the sending side. This should not cause reads
  694. // to fail on the receiving side.
  695. if err := client.Close(); err != nil {
  696. t.Fatalf("err: %v", err)
  697. }
  698. buf := make([]byte, dataSize)
  699. n, err := stream2.Read(buf)
  700. if err != nil {
  701. t.Fatalf("err: %v", err)
  702. }
  703. if int64(n) != dataSize {
  704. t.Fatalf("bad: %v", n)
  705. }
  706. // EOF after close
  707. n, err = stream2.Read(buf)
  708. if err != io.EOF {
  709. t.Fatalf("err: %v", err)
  710. }
  711. if n != 0 {
  712. t.Fatalf("bad: %v", n)
  713. }
  714. }
  715. func TestReadDeadline(t *testing.T) {
  716. client, server := testClientServer()
  717. defer client.Close()
  718. defer server.Close()
  719. stream, err := client.Open()
  720. if err != nil {
  721. t.Fatalf("err: %v", err)
  722. }
  723. defer stream.Close()
  724. stream2, err := server.Accept()
  725. if err != nil {
  726. t.Fatalf("err: %v", err)
  727. }
  728. defer stream2.Close()
  729. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  730. t.Fatalf("err: %v", err)
  731. }
  732. buf := make([]byte, 4)
  733. _, err = stream.Read(buf)
  734. if err != ErrTimeout {
  735. t.Fatalf("err: %v", err)
  736. }
  737. // See https://github.com/hashicorp/yamux/issues/90
  738. // The standard library's http server package will read from connections in
  739. // the background to detect if they are alive.
  740. //
  741. // It sets a read deadline on connections and detect if the returned error
  742. // is a network timeout error which implements net.Error.
  743. //
  744. // The HTTP server will cancel all server requests if it isn't timeout error
  745. // from the connection.
  746. //
  747. // We assert that we return an error meeting the interface to avoid
  748. // accidently breaking yamux session compatability with the standard
  749. // library's http server implementation.
  750. if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
  751. t.Fatalf("reading timeout error is expected to implement net.Error and return true when calling Timeout()")
  752. }
  753. }
  754. func TestReadDeadline_BlockedRead(t *testing.T) {
  755. client, server := testClientServer()
  756. defer client.Close()
  757. defer server.Close()
  758. stream, err := client.Open()
  759. if err != nil {
  760. t.Fatalf("err: %v", err)
  761. }
  762. defer stream.Close()
  763. stream2, err := server.Accept()
  764. if err != nil {
  765. t.Fatalf("err: %v", err)
  766. }
  767. defer stream2.Close()
  768. // Start a read that will block
  769. errCh := make(chan error, 1)
  770. go func() {
  771. buf := make([]byte, 4)
  772. _, err := stream.Read(buf)
  773. errCh <- err
  774. close(errCh)
  775. }()
  776. // Wait to ensure the read has started.
  777. time.Sleep(5 * time.Millisecond)
  778. // Update the read deadline
  779. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  780. t.Fatalf("err: %v", err)
  781. }
  782. select {
  783. case <-time.After(100 * time.Millisecond):
  784. t.Fatal("expected read timeout")
  785. case err := <-errCh:
  786. if err != ErrTimeout {
  787. t.Fatalf("expected ErrTimeout; got %v", err)
  788. }
  789. }
  790. }
  791. func TestWriteDeadline(t *testing.T) {
  792. client, server := testClientServer()
  793. defer client.Close()
  794. defer server.Close()
  795. stream, err := client.Open()
  796. if err != nil {
  797. t.Fatalf("err: %v", err)
  798. }
  799. defer stream.Close()
  800. stream2, err := server.Accept()
  801. if err != nil {
  802. t.Fatalf("err: %v", err)
  803. }
  804. defer stream2.Close()
  805. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  806. t.Fatalf("err: %v", err)
  807. }
  808. buf := make([]byte, 512)
  809. for i := 0; i < int(initialStreamWindow); i++ {
  810. _, err := stream.Write(buf)
  811. if err != nil && err == ErrTimeout {
  812. return
  813. } else if err != nil {
  814. t.Fatalf("err: %v", err)
  815. }
  816. }
  817. t.Fatalf("Expected timeout")
  818. }
  819. func TestWriteDeadline_BlockedWrite(t *testing.T) {
  820. client, server := testClientServer()
  821. defer client.Close()
  822. defer server.Close()
  823. stream, err := client.Open()
  824. if err != nil {
  825. t.Fatalf("err: %v", err)
  826. }
  827. defer stream.Close()
  828. stream2, err := server.Accept()
  829. if err != nil {
  830. t.Fatalf("err: %v", err)
  831. }
  832. defer stream2.Close()
  833. // Start a goroutine making writes that will block
  834. errCh := make(chan error, 1)
  835. go func() {
  836. buf := make([]byte, 512)
  837. for i := 0; i < int(initialStreamWindow); i++ {
  838. _, err := stream.Write(buf)
  839. if err == nil {
  840. continue
  841. }
  842. errCh <- err
  843. close(errCh)
  844. return
  845. }
  846. close(errCh)
  847. }()
  848. // Wait to ensure the write has started.
  849. time.Sleep(5 * time.Millisecond)
  850. // Update the write deadline
  851. if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  852. t.Fatalf("err: %v", err)
  853. }
  854. select {
  855. case <-time.After(1 * time.Second):
  856. t.Fatal("expected write timeout")
  857. case err := <-errCh:
  858. if err != ErrTimeout {
  859. t.Fatalf("expected ErrTimeout; got %v", err)
  860. }
  861. }
  862. }
  863. func TestBacklogExceeded(t *testing.T) {
  864. client, server := testClientServer()
  865. defer client.Close()
  866. defer server.Close()
  867. // Fill the backlog
  868. max := client.config.AcceptBacklog
  869. for i := 0; i < max; i++ {
  870. stream, err := client.Open()
  871. if err != nil {
  872. t.Fatalf("err: %v", err)
  873. }
  874. defer stream.Close()
  875. if _, err := stream.Write([]byte("foo")); err != nil {
  876. t.Fatalf("err: %v", err)
  877. }
  878. }
  879. // Attempt to open a new stream
  880. errCh := make(chan error, 1)
  881. go func() {
  882. _, err := client.Open()
  883. errCh <- err
  884. }()
  885. // Shutdown the server
  886. go func() {
  887. time.Sleep(10 * time.Millisecond)
  888. server.Close()
  889. }()
  890. select {
  891. case err := <-errCh:
  892. if err == nil {
  893. t.Fatalf("open should fail")
  894. }
  895. case <-time.After(time.Second):
  896. t.Fatalf("timeout")
  897. }
  898. }
  899. func TestKeepAlive(t *testing.T) {
  900. client, server := testClientServer()
  901. defer client.Close()
  902. defer server.Close()
  903. time.Sleep(200 * time.Millisecond)
  904. // Ping value should increase
  905. client.pingLock.Lock()
  906. defer client.pingLock.Unlock()
  907. if client.pingID == 0 {
  908. t.Fatalf("should ping")
  909. }
  910. server.pingLock.Lock()
  911. defer server.pingLock.Unlock()
  912. if server.pingID == 0 {
  913. t.Fatalf("should ping")
  914. }
  915. }
  916. func TestKeepAlive_Timeout(t *testing.T) {
  917. conn1, conn2 := testConn()
  918. clientConf := testConf()
  919. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  920. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  921. client, _ := Client(conn1, clientConf)
  922. defer client.Close()
  923. server, _ := Server(conn2, testConf())
  924. defer server.Close()
  925. _ = captureLogs(client) // Client logs aren't part of the test
  926. serverLogs := captureLogs(server)
  927. errCh := make(chan error, 1)
  928. go func() {
  929. _, err := server.Accept() // Wait until server closes
  930. errCh <- err
  931. }()
  932. // Prevent the client from responding
  933. clientConn := client.conn.(*pipeConn)
  934. clientConn.writeBlocker.Lock()
  935. select {
  936. case err := <-errCh:
  937. if err != ErrKeepAliveTimeout {
  938. t.Fatalf("unexpected error: %v", err)
  939. }
  940. case <-time.After(1 * time.Second):
  941. t.Fatalf("timeout waiting for timeout")
  942. }
  943. clientConn.writeBlocker.Unlock()
  944. if !server.IsClosed() {
  945. t.Fatalf("server should have closed")
  946. }
  947. if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
  948. t.Fatalf("server log incorect: %v", serverLogs.logs())
  949. }
  950. }
  951. func TestLargeWindow(t *testing.T) {
  952. conf := DefaultConfig()
  953. conf.MaxStreamWindowSize *= 2
  954. client, server := testClientServerConfig(conf)
  955. defer client.Close()
  956. defer server.Close()
  957. stream, err := client.Open()
  958. if err != nil {
  959. t.Fatalf("err: %v", err)
  960. }
  961. defer stream.Close()
  962. stream2, err := server.Accept()
  963. if err != nil {
  964. t.Fatalf("err: %v", err)
  965. }
  966. defer stream2.Close()
  967. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  968. buf := make([]byte, conf.MaxStreamWindowSize)
  969. n, err := stream.Write(buf)
  970. if err != nil {
  971. t.Fatalf("err: %v", err)
  972. }
  973. if n != len(buf) {
  974. t.Fatalf("short write: %d", n)
  975. }
  976. }
  // UnlimitedReader is an io.Reader that never runs out of data: every Read
  // reports the whole buffer as filled without writing to it.
  type UnlimitedReader struct{}

  // Read yields the processor once and then reports len(p) bytes read with a
  // nil error. The contents of p are left untouched.
  func (u *UnlimitedReader) Read(p []byte) (int, error) {
  	runtime.Gosched()
  	filled := len(p)
  	return filled, nil
  }
  982. func TestSendData_VeryLarge(t *testing.T) {
  983. client, server := testClientServer()
  984. defer client.Close()
  985. defer server.Close()
  986. var n int64 = 1 * 1024 * 1024 * 1024
  987. var workers int = 16
  988. wg := &sync.WaitGroup{}
  989. wg.Add(workers * 2)
  990. for i := 0; i < workers; i++ {
  991. go func() {
  992. defer wg.Done()
  993. stream, err := server.AcceptStream()
  994. if err != nil {
  995. t.Fatalf("err: %v", err)
  996. }
  997. defer stream.Close()
  998. buf := make([]byte, 4)
  999. _, err = stream.Read(buf)
  1000. if err != nil {
  1001. t.Fatalf("err: %v", err)
  1002. }
  1003. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  1004. t.Fatalf("bad header")
  1005. }
  1006. recv, err := io.Copy(ioutil.Discard, stream)
  1007. if err != nil {
  1008. t.Fatalf("err: %v", err)
  1009. }
  1010. if recv != n {
  1011. t.Fatalf("bad: %v", recv)
  1012. }
  1013. }()
  1014. }
  1015. for i := 0; i < workers; i++ {
  1016. go func() {
  1017. defer wg.Done()
  1018. stream, err := client.Open()
  1019. if err != nil {
  1020. t.Fatalf("err: %v", err)
  1021. }
  1022. defer stream.Close()
  1023. _, err = stream.Write([]byte{0, 1, 2, 3})
  1024. if err != nil {
  1025. t.Fatalf("err: %v", err)
  1026. }
  1027. unlimited := &UnlimitedReader{}
  1028. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  1029. if err != nil {
  1030. t.Fatalf("err: %v", err)
  1031. }
  1032. if sent != n {
  1033. t.Fatalf("bad: %v", sent)
  1034. }
  1035. }()
  1036. }
  1037. doneCh := make(chan struct{})
  1038. go func() {
  1039. wg.Wait()
  1040. close(doneCh)
  1041. }()
  1042. select {
  1043. case <-doneCh:
  1044. case <-time.After(20 * time.Second):
  1045. panic("timeout")
  1046. }
  1047. }
  1048. func TestBacklogExceeded_Accept(t *testing.T) {
  1049. client, server := testClientServer()
  1050. defer client.Close()
  1051. defer server.Close()
  1052. max := 5 * client.config.AcceptBacklog
  1053. go func() {
  1054. for i := 0; i < max; i++ {
  1055. stream, err := server.Accept()
  1056. if err != nil {
  1057. t.Fatalf("err: %v", err)
  1058. }
  1059. defer stream.Close()
  1060. }
  1061. }()
  1062. // Fill the backlog
  1063. for i := 0; i < max; i++ {
  1064. stream, err := client.Open()
  1065. if err != nil {
  1066. t.Fatalf("err: %v", err)
  1067. }
  1068. defer stream.Close()
  1069. if _, err := stream.Write([]byte("foo")); err != nil {
  1070. t.Fatalf("err: %v", err)
  1071. }
  1072. }
  1073. }
  1074. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  1075. client, server := testClientServerConfig(testConfNoKeepAlive())
  1076. defer client.Close()
  1077. defer server.Close()
  1078. var wg sync.WaitGroup
  1079. wg.Add(2)
  1080. // Choose a huge flood size that we know will result in a window update.
  1081. flood := int64(client.config.MaxStreamWindowSize) - 1
  1082. // The server will accept a new stream and then flood data to it.
  1083. go func() {
  1084. defer wg.Done()
  1085. stream, err := server.AcceptStream()
  1086. if err != nil {
  1087. t.Fatalf("err: %v", err)
  1088. }
  1089. defer stream.Close()
  1090. n, err := stream.Write(make([]byte, flood))
  1091. if err != nil {
  1092. t.Fatalf("err: %v", err)
  1093. }
  1094. if int64(n) != flood {
  1095. t.Fatalf("short write: %d", n)
  1096. }
  1097. }()
  1098. // The client will open a stream, block outbound writes, and then
  1099. // listen to the flood from the server, which should time out since
  1100. // it won't be able to send the window update.
  1101. go func() {
  1102. defer wg.Done()
  1103. stream, err := client.OpenStream()
  1104. if err != nil {
  1105. t.Fatalf("err: %v", err)
  1106. }
  1107. defer stream.Close()
  1108. conn := client.conn.(*pipeConn)
  1109. conn.writeBlocker.Lock()
  1110. defer conn.writeBlocker.Unlock()
  1111. _, err = stream.Read(make([]byte, flood))
  1112. if err != ErrConnectionWriteTimeout {
  1113. t.Fatalf("err: %v", err)
  1114. }
  1115. }()
  1116. wg.Wait()
  1117. }
  1118. func TestSession_PartialReadWindowUpdate(t *testing.T) {
  1119. client, server := testClientServerConfig(testConfNoKeepAlive())
  1120. defer client.Close()
  1121. defer server.Close()
  1122. var wg sync.WaitGroup
  1123. wg.Add(1)
  1124. // Choose a huge flood size that we know will result in a window update.
  1125. flood := int64(client.config.MaxStreamWindowSize)
  1126. var wr *Stream
  1127. // The server will accept a new stream and then flood data to it.
  1128. go func() {
  1129. defer wg.Done()
  1130. var err error
  1131. wr, err = server.AcceptStream()
  1132. if err != nil {
  1133. t.Fatalf("err: %v", err)
  1134. }
  1135. defer wr.Close()
  1136. if wr.sendWindow != client.config.MaxStreamWindowSize {
  1137. t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
  1138. }
  1139. n, err := wr.Write(make([]byte, flood))
  1140. if err != nil {
  1141. t.Fatalf("err: %v", err)
  1142. }
  1143. if int64(n) != flood {
  1144. t.Fatalf("short write: %d", n)
  1145. }
  1146. if wr.sendWindow != 0 {
  1147. t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
  1148. }
  1149. }()
  1150. stream, err := client.OpenStream()
  1151. if err != nil {
  1152. t.Fatalf("err: %v", err)
  1153. }
  1154. defer stream.Close()
  1155. wg.Wait()
  1156. _, err = stream.Read(make([]byte, flood/2+1))
  1157. if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
  1158. t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
  1159. }
  1160. }
  1161. func TestSession_sendNoWait_Timeout(t *testing.T) {
  1162. client, server := testClientServerConfig(testConfNoKeepAlive())
  1163. defer client.Close()
  1164. defer server.Close()
  1165. var wg sync.WaitGroup
  1166. wg.Add(2)
  1167. go func() {
  1168. defer wg.Done()
  1169. stream, err := server.AcceptStream()
  1170. if err != nil {
  1171. t.Fatalf("err: %v", err)
  1172. }
  1173. defer stream.Close()
  1174. }()
  1175. // The client will open the stream and then block outbound writes, we'll
  1176. // probe sendNoWait once it gets into that state.
  1177. go func() {
  1178. defer wg.Done()
  1179. stream, err := client.OpenStream()
  1180. if err != nil {
  1181. t.Fatalf("err: %v", err)
  1182. }
  1183. defer stream.Close()
  1184. conn := client.conn.(*pipeConn)
  1185. conn.writeBlocker.Lock()
  1186. defer conn.writeBlocker.Unlock()
  1187. hdr := header(make([]byte, headerSize))
  1188. hdr.encode(typePing, flagACK, 0, 0)
  1189. for {
  1190. err = client.sendNoWait(hdr)
  1191. if err == nil {
  1192. continue
  1193. } else if err == ErrConnectionWriteTimeout {
  1194. break
  1195. } else {
  1196. t.Fatalf("err: %v", err)
  1197. }
  1198. }
  1199. }()
  1200. wg.Wait()
  1201. }
  1202. func TestSession_PingOfDeath(t *testing.T) {
  1203. client, server := testClientServerConfig(testConfNoKeepAlive())
  1204. defer client.Close()
  1205. defer server.Close()
  1206. var wg sync.WaitGroup
  1207. wg.Add(2)
  1208. var doPingOfDeath sync.Mutex
  1209. doPingOfDeath.Lock()
  1210. // This is used later to block outbound writes.
  1211. conn := server.conn.(*pipeConn)
  1212. // The server will accept a stream, block outbound writes, and then
  1213. // flood its send channel so that no more headers can be queued.
  1214. go func() {
  1215. defer wg.Done()
  1216. stream, err := server.AcceptStream()
  1217. if err != nil {
  1218. t.Fatalf("err: %v", err)
  1219. }
  1220. defer stream.Close()
  1221. conn.writeBlocker.Lock()
  1222. for {
  1223. hdr := header(make([]byte, headerSize))
  1224. hdr.encode(typePing, 0, 0, 0)
  1225. err = server.sendNoWait(hdr)
  1226. if err == nil {
  1227. continue
  1228. } else if err == ErrConnectionWriteTimeout {
  1229. break
  1230. } else {
  1231. t.Fatalf("err: %v", err)
  1232. }
  1233. }
  1234. doPingOfDeath.Unlock()
  1235. }()
  1236. // The client will open a stream and then send the server a ping once it
  1237. // can no longer write. This makes sure the server doesn't deadlock reads
  1238. // while trying to reply to the ping with no ability to write.
  1239. go func() {
  1240. defer wg.Done()
  1241. stream, err := client.OpenStream()
  1242. if err != nil {
  1243. t.Fatalf("err: %v", err)
  1244. }
  1245. defer stream.Close()
  1246. // This ping will never unblock because the ping id will never
  1247. // show up in a response.
  1248. doPingOfDeath.Lock()
  1249. go func() { client.Ping() }()
  1250. // Wait for a while to make sure the previous ping times out,
  1251. // then turn writes back on and make sure a ping works again.
  1252. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  1253. conn.writeBlocker.Unlock()
  1254. if _, err = client.Ping(); err != nil {
  1255. t.Fatalf("err: %v", err)
  1256. }
  1257. }()
  1258. wg.Wait()
  1259. }
  1260. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  1261. client, server := testClientServerConfig(testConfNoKeepAlive())
  1262. defer client.Close()
  1263. defer server.Close()
  1264. var wg sync.WaitGroup
  1265. wg.Add(2)
  1266. go func() {
  1267. defer wg.Done()
  1268. stream, err := server.AcceptStream()
  1269. if err != nil {
  1270. t.Fatalf("err: %v", err)
  1271. }
  1272. defer stream.Close()
  1273. }()
  1274. // The client will open the stream and then block outbound writes, we'll
  1275. // tee up a write and make sure it eventually times out.
  1276. go func() {
  1277. defer wg.Done()
  1278. stream, err := client.OpenStream()
  1279. if err != nil {
  1280. t.Fatalf("err: %v", err)
  1281. }
  1282. defer stream.Close()
  1283. conn := client.conn.(*pipeConn)
  1284. conn.writeBlocker.Lock()
  1285. defer conn.writeBlocker.Unlock()
  1286. // Since the write goroutine is blocked then this will return a
  1287. // timeout since it can't get feedback about whether the write
  1288. // worked.
  1289. n, err := stream.Write([]byte("hello"))
  1290. if err != ErrConnectionWriteTimeout {
  1291. t.Fatalf("err: %v", err)
  1292. }
  1293. if n != 0 {
  1294. t.Fatalf("lied about writes: %d", n)
  1295. }
  1296. }()
  1297. wg.Wait()
  1298. }
  1299. func TestCancelAccept(t *testing.T) {
  1300. _, server := testClientServer()
  1301. defer server.Close()
  1302. ctx, cancel := context.WithCancel(context.Background())
  1303. var wg sync.WaitGroup
  1304. wg.Add(1)
  1305. go func() {
  1306. defer wg.Done()
  1307. stream, err := server.AcceptStreamWithContext(ctx)
  1308. if err != context.Canceled {
  1309. t.Fatalf("err: %v", err)
  1310. }
  1311. if stream != nil {
  1312. defer stream.Close()
  1313. }
  1314. }()
  1315. cancel()
  1316. wg.Wait()
  1317. }