session_test.go 23 KB


  1. package yamux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "log"
  8. "reflect"
  9. "runtime"
  10. "strings"
  11. "sync"
  12. "testing"
  13. "time"
  14. )
  // logCapture collects the output of a *log.Logger in memory so a test can
// inspect what was logged.
type logCapture struct{ bytes.Buffer }

// logs returns the captured text, trimmed of surrounding whitespace, split
// into lines. An empty capture yields a single empty string.
func (c *logCapture) logs() []string {
	trimmed := strings.TrimSpace(c.String())
	lines := strings.Split(trimmed, "\n")
	return lines
}

// match reports whether the captured lines are exactly expect.
func (c *logCapture) match(expect []string) bool {
	got := c.logs()
	return reflect.DeepEqual(got, expect)
}
  22. func captureLogs(s *Session) *logCapture {
  23. buf := new(logCapture)
  24. s.logger = log.New(buf, "", 0)
  25. return buf
  26. }
  // pipeConn is an in-memory io.ReadWriteCloser made of two io.Pipe halves.
// Every Write holds writeBlocker, so a test that locks writeBlocker stalls
// all outbound writes on this end.
type pipeConn struct {
	reader       *io.PipeReader
	writer       *io.PipeWriter
	writeBlocker sync.Mutex
}

// Read reads from the inbound pipe half.
func (p *pipeConn) Read(b []byte) (int, error) {
	n, err := p.reader.Read(b)
	return n, err
}

// Write writes to the outbound pipe half while holding writeBlocker.
func (p *pipeConn) Write(b []byte) (n int, err error) {
	p.writeBlocker.Lock()
	defer p.writeBlocker.Unlock()
	n, err = p.writer.Write(b)
	return n, err
}

// Close closes both pipe halves. The reader's result is discarded
// (io.PipeReader.Close documents that it always returns nil); the writer's
// result is returned.
func (p *pipeConn) Close() error {
	_ = p.reader.Close()
	return p.writer.Close()
}

// testConn returns two pipeConns wired back to back: bytes written to the
// first are read from the second, and vice versa.
func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
	inA, outB := io.Pipe() // written by the second conn, read by the first
	inB, outA := io.Pipe() // written by the first conn, read by the second
	first := &pipeConn{reader: inA, writer: outA}
	second := &pipeConn{reader: inB, writer: outB}
	return first, second
}
  51. func testConf() *Config {
  52. conf := DefaultConfig()
  53. conf.AcceptBacklog = 64
  54. conf.KeepAliveInterval = 100 * time.Millisecond
  55. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  56. return conf
  57. }
  58. func testConfNoKeepAlive() *Config {
  59. conf := testConf()
  60. conf.EnableKeepAlive = false
  61. return conf
  62. }
  63. func testClientServer() (*Session, *Session) {
  64. return testClientServerConfig(testConf())
  65. }
  66. func testClientServerConfig(conf *Config) (*Session, *Session) {
  67. conn1, conn2 := testConn()
  68. client, _ := Client(conn1, conf)
  69. server, _ := Server(conn2, conf)
  70. return client, server
  71. }
  72. func TestPing(t *testing.T) {
  73. client, server := testClientServer()
  74. defer client.Close()
  75. defer server.Close()
  76. rtt, err := client.Ping()
  77. if err != nil {
  78. t.Fatalf("err: %v", err)
  79. }
  80. if rtt == 0 {
  81. t.Fatalf("bad: %v", rtt)
  82. }
  83. rtt, err = server.Ping()
  84. if err != nil {
  85. t.Fatalf("err: %v", err)
  86. }
  87. if rtt == 0 {
  88. t.Fatalf("bad: %v", rtt)
  89. }
  90. }
  91. func TestPing_Timeout(t *testing.T) {
  92. client, server := testClientServerConfig(testConfNoKeepAlive())
  93. defer client.Close()
  94. defer server.Close()
  95. // Prevent the client from responding
  96. clientConn := client.conn.(*pipeConn)
  97. clientConn.writeBlocker.Lock()
  98. errCh := make(chan error, 1)
  99. go func() {
  100. _, err := server.Ping() // Ping via the server session
  101. errCh <- err
  102. }()
  103. select {
  104. case err := <-errCh:
  105. if err != ErrTimeout {
  106. t.Fatalf("err: %v", err)
  107. }
  108. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  109. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  110. }
  111. // Verify that we recover, even if we gave up
  112. clientConn.writeBlocker.Unlock()
  113. go func() {
  114. _, err := server.Ping() // Ping via the server session
  115. errCh <- err
  116. }()
  117. select {
  118. case err := <-errCh:
  119. if err != nil {
  120. t.Fatalf("err: %v", err)
  121. }
  122. case <-time.After(client.config.ConnectionWriteTimeout):
  123. t.Fatalf("timeout")
  124. }
  125. }
  126. func TestCloseBeforeAck(t *testing.T) {
  127. cfg := testConf()
  128. cfg.AcceptBacklog = 8
  129. client, server := testClientServerConfig(cfg)
  130. defer client.Close()
  131. defer server.Close()
  132. for i := 0; i < 8; i++ {
  133. s, err := client.OpenStream()
  134. if err != nil {
  135. t.Fatal(err)
  136. }
  137. s.Close()
  138. }
  139. for i := 0; i < 8; i++ {
  140. s, err := server.AcceptStream()
  141. if err != nil {
  142. t.Fatal(err)
  143. }
  144. s.Close()
  145. }
  146. done := make(chan struct{})
  147. go func() {
  148. defer close(done)
  149. s, err := client.OpenStream()
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. s.Close()
  154. }()
  155. select {
  156. case <-done:
  157. case <-time.After(time.Second * 5):
  158. t.Fatal("timed out trying to open stream")
  159. }
  160. }
  161. func TestAccept(t *testing.T) {
  162. client, server := testClientServer()
  163. defer client.Close()
  164. defer server.Close()
  165. if client.NumStreams() != 0 {
  166. t.Fatalf("bad")
  167. }
  168. if server.NumStreams() != 0 {
  169. t.Fatalf("bad")
  170. }
  171. wg := &sync.WaitGroup{}
  172. wg.Add(4)
  173. go func() {
  174. defer wg.Done()
  175. stream, err := server.AcceptStream()
  176. if err != nil {
  177. t.Fatalf("err: %v", err)
  178. }
  179. if id := stream.StreamID(); id != 1 {
  180. t.Fatalf("bad: %v", id)
  181. }
  182. if err := stream.Close(); err != nil {
  183. t.Fatalf("err: %v", err)
  184. }
  185. }()
  186. go func() {
  187. defer wg.Done()
  188. stream, err := client.AcceptStream()
  189. if err != nil {
  190. t.Fatalf("err: %v", err)
  191. }
  192. if id := stream.StreamID(); id != 2 {
  193. t.Fatalf("bad: %v", id)
  194. }
  195. if err := stream.Close(); err != nil {
  196. t.Fatalf("err: %v", err)
  197. }
  198. }()
  199. go func() {
  200. defer wg.Done()
  201. stream, err := server.OpenStream()
  202. if err != nil {
  203. t.Fatalf("err: %v", err)
  204. }
  205. if id := stream.StreamID(); id != 2 {
  206. t.Fatalf("bad: %v", id)
  207. }
  208. if err := stream.Close(); err != nil {
  209. t.Fatalf("err: %v", err)
  210. }
  211. }()
  212. go func() {
  213. defer wg.Done()
  214. stream, err := client.OpenStream()
  215. if err != nil {
  216. t.Fatalf("err: %v", err)
  217. }
  218. if id := stream.StreamID(); id != 1 {
  219. t.Fatalf("bad: %v", id)
  220. }
  221. if err := stream.Close(); err != nil {
  222. t.Fatalf("err: %v", err)
  223. }
  224. }()
  225. doneCh := make(chan struct{})
  226. go func() {
  227. wg.Wait()
  228. close(doneCh)
  229. }()
  230. select {
  231. case <-doneCh:
  232. case <-time.After(time.Second):
  233. panic("timeout")
  234. }
  235. }
  236. func TestNonNilInterface(t *testing.T) {
  237. _, server := testClientServer()
  238. server.Close()
  239. conn, err := server.Accept()
  240. if err != nil && conn != nil {
  241. t.Error("bad: accept should return a connection of nil value")
  242. }
  243. conn, err = server.Open()
  244. if err != nil && conn != nil {
  245. t.Error("bad: open should return a connection of nil value")
  246. }
  247. }
  248. func TestSendData_Small(t *testing.T) {
  249. client, server := testClientServer()
  250. defer client.Close()
  251. defer server.Close()
  252. wg := &sync.WaitGroup{}
  253. wg.Add(2)
  254. go func() {
  255. defer wg.Done()
  256. stream, err := server.AcceptStream()
  257. if err != nil {
  258. t.Fatalf("err: %v", err)
  259. }
  260. if server.NumStreams() != 1 {
  261. t.Fatalf("bad")
  262. }
  263. buf := make([]byte, 4)
  264. for i := 0; i < 1000; i++ {
  265. n, err := stream.Read(buf)
  266. if err != nil {
  267. t.Fatalf("err: %v", err)
  268. }
  269. if n != 4 {
  270. t.Fatalf("short read: %d", n)
  271. }
  272. if string(buf) != "test" {
  273. t.Fatalf("bad: %s", buf)
  274. }
  275. }
  276. if err := stream.Close(); err != nil {
  277. t.Fatalf("err: %v", err)
  278. }
  279. }()
  280. go func() {
  281. defer wg.Done()
  282. stream, err := client.Open()
  283. if err != nil {
  284. t.Fatalf("err: %v", err)
  285. }
  286. if client.NumStreams() != 1 {
  287. t.Fatalf("bad")
  288. }
  289. for i := 0; i < 1000; i++ {
  290. n, err := stream.Write([]byte("test"))
  291. if err != nil {
  292. t.Fatalf("err: %v", err)
  293. }
  294. if n != 4 {
  295. t.Fatalf("short write %d", n)
  296. }
  297. }
  298. if err := stream.Close(); err != nil {
  299. t.Fatalf("err: %v", err)
  300. }
  301. }()
  302. doneCh := make(chan struct{})
  303. go func() {
  304. wg.Wait()
  305. close(doneCh)
  306. }()
  307. select {
  308. case <-doneCh:
  309. case <-time.After(time.Second):
  310. panic("timeout")
  311. }
  312. if client.NumStreams() != 0 {
  313. t.Fatalf("bad")
  314. }
  315. if server.NumStreams() != 0 {
  316. t.Fatalf("bad")
  317. }
  318. }
  319. func TestSendData_Large(t *testing.T) {
  320. client, server := testClientServer()
  321. defer client.Close()
  322. defer server.Close()
  323. data := make([]byte, 512*1024)
  324. for idx := range data {
  325. data[idx] = byte(idx % 256)
  326. }
  327. wg := &sync.WaitGroup{}
  328. wg.Add(2)
  329. go func() {
  330. defer wg.Done()
  331. stream, err := server.AcceptStream()
  332. if err != nil {
  333. t.Fatalf("err: %v", err)
  334. }
  335. buf := make([]byte, 4*1024)
  336. for i := 0; i < 128; i++ {
  337. n, err := stream.Read(buf)
  338. if err != nil {
  339. t.Fatalf("err: %v", err)
  340. }
  341. if n != 4*1024 {
  342. t.Fatalf("short read: %d", n)
  343. }
  344. for idx := range buf {
  345. if buf[idx] != byte(idx%256) {
  346. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  347. }
  348. }
  349. }
  350. if err := stream.Close(); err != nil {
  351. t.Fatalf("err: %v", err)
  352. }
  353. }()
  354. go func() {
  355. defer wg.Done()
  356. stream, err := client.Open()
  357. if err != nil {
  358. t.Fatalf("err: %v", err)
  359. }
  360. n, err := stream.Write(data)
  361. if err != nil {
  362. t.Fatalf("err: %v", err)
  363. }
  364. if n != len(data) {
  365. t.Fatalf("short write %d", n)
  366. }
  367. if err := stream.Close(); err != nil {
  368. t.Fatalf("err: %v", err)
  369. }
  370. }()
  371. doneCh := make(chan struct{})
  372. go func() {
  373. wg.Wait()
  374. close(doneCh)
  375. }()
  376. select {
  377. case <-doneCh:
  378. case <-time.After(time.Second):
  379. panic("timeout")
  380. }
  381. }
  382. func TestGoAway(t *testing.T) {
  383. client, server := testClientServer()
  384. defer client.Close()
  385. defer server.Close()
  386. if err := server.GoAway(); err != nil {
  387. t.Fatalf("err: %v", err)
  388. }
  389. _, err := client.Open()
  390. if err != ErrRemoteGoAway {
  391. t.Fatalf("err: %v", err)
  392. }
  393. }
  394. func TestManyStreams(t *testing.T) {
  395. client, server := testClientServer()
  396. defer client.Close()
  397. defer server.Close()
  398. wg := &sync.WaitGroup{}
  399. acceptor := func(i int) {
  400. defer wg.Done()
  401. stream, err := server.AcceptStream()
  402. if err != nil {
  403. t.Fatalf("err: %v", err)
  404. }
  405. defer stream.Close()
  406. buf := make([]byte, 512)
  407. for {
  408. n, err := stream.Read(buf)
  409. if err == io.EOF {
  410. return
  411. }
  412. if err != nil {
  413. t.Fatalf("err: %v", err)
  414. }
  415. if n == 0 {
  416. t.Fatalf("err: %v", err)
  417. }
  418. }
  419. }
  420. sender := func(i int) {
  421. defer wg.Done()
  422. stream, err := client.Open()
  423. if err != nil {
  424. t.Fatalf("err: %v", err)
  425. }
  426. defer stream.Close()
  427. msg := fmt.Sprintf("%08d", i)
  428. for i := 0; i < 1000; i++ {
  429. n, err := stream.Write([]byte(msg))
  430. if err != nil {
  431. t.Fatalf("err: %v", err)
  432. }
  433. if n != len(msg) {
  434. t.Fatalf("short write %d", n)
  435. }
  436. }
  437. }
  438. for i := 0; i < 50; i++ {
  439. wg.Add(2)
  440. go acceptor(i)
  441. go sender(i)
  442. }
  443. wg.Wait()
  444. }
  445. func TestManyStreams_PingPong(t *testing.T) {
  446. client, server := testClientServer()
  447. defer client.Close()
  448. defer server.Close()
  449. wg := &sync.WaitGroup{}
  450. ping := []byte("ping")
  451. pong := []byte("pong")
  452. acceptor := func(i int) {
  453. defer wg.Done()
  454. stream, err := server.AcceptStream()
  455. if err != nil {
  456. t.Fatalf("err: %v", err)
  457. }
  458. defer stream.Close()
  459. buf := make([]byte, 4)
  460. for {
  461. // Read the 'ping'
  462. n, err := stream.Read(buf)
  463. if err == io.EOF {
  464. return
  465. }
  466. if err != nil {
  467. t.Fatalf("err: %v", err)
  468. }
  469. if n != 4 {
  470. t.Fatalf("err: %v", err)
  471. }
  472. if !bytes.Equal(buf, ping) {
  473. t.Fatalf("bad: %s", buf)
  474. }
  475. // Shrink the internal buffer!
  476. stream.Shrink()
  477. // Write out the 'pong'
  478. n, err = stream.Write(pong)
  479. if err != nil {
  480. t.Fatalf("err: %v", err)
  481. }
  482. if n != 4 {
  483. t.Fatalf("err: %v", err)
  484. }
  485. }
  486. }
  487. sender := func(i int) {
  488. defer wg.Done()
  489. stream, err := client.OpenStream()
  490. if err != nil {
  491. t.Fatalf("err: %v", err)
  492. }
  493. defer stream.Close()
  494. buf := make([]byte, 4)
  495. for i := 0; i < 1000; i++ {
  496. // Send the 'ping'
  497. n, err := stream.Write(ping)
  498. if err != nil {
  499. t.Fatalf("err: %v", err)
  500. }
  501. if n != 4 {
  502. t.Fatalf("short write %d", n)
  503. }
  504. // Read the 'pong'
  505. n, err = stream.Read(buf)
  506. if err != nil {
  507. t.Fatalf("err: %v", err)
  508. }
  509. if n != 4 {
  510. t.Fatalf("err: %v", err)
  511. }
  512. if !bytes.Equal(buf, pong) {
  513. t.Fatalf("bad: %s", buf)
  514. }
  515. // Shrink the buffer
  516. stream.Shrink()
  517. }
  518. }
  519. for i := 0; i < 50; i++ {
  520. wg.Add(2)
  521. go acceptor(i)
  522. go sender(i)
  523. }
  524. wg.Wait()
  525. }
  526. func TestHalfClose(t *testing.T) {
  527. client, server := testClientServer()
  528. defer client.Close()
  529. defer server.Close()
  530. stream, err := client.Open()
  531. if err != nil {
  532. t.Fatalf("err: %v", err)
  533. }
  534. if _, err := stream.Write([]byte("a")); err != nil {
  535. t.Fatalf("err: %v", err)
  536. }
  537. stream2, err := server.Accept()
  538. if err != nil {
  539. t.Fatalf("err: %v", err)
  540. }
  541. stream2.Close() // Half close
  542. buf := make([]byte, 4)
  543. n, err := stream2.Read(buf)
  544. if err != nil {
  545. t.Fatalf("err: %v", err)
  546. }
  547. if n != 1 {
  548. t.Fatalf("bad: %v", n)
  549. }
  550. // Send more
  551. if _, err := stream.Write([]byte("bcd")); err != nil {
  552. t.Fatalf("err: %v", err)
  553. }
  554. stream.Close()
  555. // Read after close
  556. n, err = stream2.Read(buf)
  557. if err != nil {
  558. t.Fatalf("err: %v", err)
  559. }
  560. if n != 3 {
  561. t.Fatalf("bad: %v", n)
  562. }
  563. // EOF after close
  564. n, err = stream2.Read(buf)
  565. if err != io.EOF {
  566. t.Fatalf("err: %v", err)
  567. }
  568. if n != 0 {
  569. t.Fatalf("bad: %v", n)
  570. }
  571. }
  572. func TestReadDeadline(t *testing.T) {
  573. client, server := testClientServer()
  574. defer client.Close()
  575. defer server.Close()
  576. stream, err := client.Open()
  577. if err != nil {
  578. t.Fatalf("err: %v", err)
  579. }
  580. defer stream.Close()
  581. stream2, err := server.Accept()
  582. if err != nil {
  583. t.Fatalf("err: %v", err)
  584. }
  585. defer stream2.Close()
  586. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  587. t.Fatalf("err: %v", err)
  588. }
  589. buf := make([]byte, 4)
  590. if _, err := stream.Read(buf); err != ErrTimeout {
  591. t.Fatalf("err: %v", err)
  592. }
  593. }
  594. func TestWriteDeadline(t *testing.T) {
  595. client, server := testClientServer()
  596. defer client.Close()
  597. defer server.Close()
  598. stream, err := client.Open()
  599. if err != nil {
  600. t.Fatalf("err: %v", err)
  601. }
  602. defer stream.Close()
  603. stream2, err := server.Accept()
  604. if err != nil {
  605. t.Fatalf("err: %v", err)
  606. }
  607. defer stream2.Close()
  608. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  609. t.Fatalf("err: %v", err)
  610. }
  611. buf := make([]byte, 512)
  612. for i := 0; i < int(initialStreamWindow); i++ {
  613. _, err := stream.Write(buf)
  614. if err != nil && err == ErrTimeout {
  615. return
  616. } else if err != nil {
  617. t.Fatalf("err: %v", err)
  618. }
  619. }
  620. t.Fatalf("Expected timeout")
  621. }
  622. func TestBacklogExceeded(t *testing.T) {
  623. client, server := testClientServer()
  624. defer client.Close()
  625. defer server.Close()
  626. // Fill the backlog
  627. max := client.config.AcceptBacklog
  628. for i := 0; i < max; i++ {
  629. stream, err := client.Open()
  630. if err != nil {
  631. t.Fatalf("err: %v", err)
  632. }
  633. defer stream.Close()
  634. if _, err := stream.Write([]byte("foo")); err != nil {
  635. t.Fatalf("err: %v", err)
  636. }
  637. }
  638. // Attempt to open a new stream
  639. errCh := make(chan error, 1)
  640. go func() {
  641. _, err := client.Open()
  642. errCh <- err
  643. }()
  644. // Shutdown the server
  645. go func() {
  646. time.Sleep(10 * time.Millisecond)
  647. server.Close()
  648. }()
  649. select {
  650. case err := <-errCh:
  651. if err == nil {
  652. t.Fatalf("open should fail")
  653. }
  654. case <-time.After(time.Second):
  655. t.Fatalf("timeout")
  656. }
  657. }
  658. func TestKeepAlive(t *testing.T) {
  659. client, server := testClientServer()
  660. defer client.Close()
  661. defer server.Close()
  662. time.Sleep(200 * time.Millisecond)
  663. // Ping value should increase
  664. client.pingLock.Lock()
  665. defer client.pingLock.Unlock()
  666. if client.pingID == 0 {
  667. t.Fatalf("should ping")
  668. }
  669. server.pingLock.Lock()
  670. defer server.pingLock.Unlock()
  671. if server.pingID == 0 {
  672. t.Fatalf("should ping")
  673. }
  674. }
  675. func TestKeepAlive_Timeout(t *testing.T) {
  676. conn1, conn2 := testConn()
  677. clientConf := testConf()
  678. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  679. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  680. client, _ := Client(conn1, clientConf)
  681. defer client.Close()
  682. server, _ := Server(conn2, testConf())
  683. defer server.Close()
  684. _ = captureLogs(client) // Client logs aren't part of the test
  685. serverLogs := captureLogs(server)
  686. errCh := make(chan error, 1)
  687. go func() {
  688. _, err := server.Accept() // Wait until server closes
  689. errCh <- err
  690. }()
  691. // Prevent the client from responding
  692. clientConn := client.conn.(*pipeConn)
  693. clientConn.writeBlocker.Lock()
  694. select {
  695. case err := <-errCh:
  696. if err != ErrKeepAliveTimeout {
  697. t.Fatalf("unexpected error: %v", err)
  698. }
  699. case <-time.After(1 * time.Second):
  700. t.Fatalf("timeout waiting for timeout")
  701. }
  702. if !server.IsClosed() {
  703. t.Fatalf("server should have closed")
  704. }
  705. if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
  706. t.Fatalf("server log incorect: %v", serverLogs.logs())
  707. }
  708. }
  709. func TestLargeWindow(t *testing.T) {
  710. conf := DefaultConfig()
  711. conf.MaxStreamWindowSize *= 2
  712. client, server := testClientServerConfig(conf)
  713. defer client.Close()
  714. defer server.Close()
  715. stream, err := client.Open()
  716. if err != nil {
  717. t.Fatalf("err: %v", err)
  718. }
  719. defer stream.Close()
  720. stream2, err := server.Accept()
  721. if err != nil {
  722. t.Fatalf("err: %v", err)
  723. }
  724. defer stream2.Close()
  725. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  726. buf := make([]byte, conf.MaxStreamWindowSize)
  727. n, err := stream.Write(buf)
  728. if err != nil {
  729. t.Fatalf("err: %v", err)
  730. }
  731. if n != len(buf) {
  732. t.Fatalf("short write: %d", n)
  733. }
  734. }
  // UnlimitedReader is an io.Reader that never runs dry. Every Read reports
// len(p) bytes and a nil error without touching p, so the "data" is whatever
// p already held. Each call yields the processor via runtime.Gosched.
type UnlimitedReader struct{}

// Read yields the processor, then claims to have filled all of p.
func (u *UnlimitedReader) Read(p []byte) (int, error) {
	runtime.Gosched()
	n := len(p)
	return n, nil
}
  740. func TestSendData_VeryLarge(t *testing.T) {
  741. client, server := testClientServer()
  742. defer client.Close()
  743. defer server.Close()
  744. var n int64 = 1 * 1024 * 1024 * 1024
  745. var workers int = 16
  746. wg := &sync.WaitGroup{}
  747. wg.Add(workers * 2)
  748. for i := 0; i < workers; i++ {
  749. go func() {
  750. defer wg.Done()
  751. stream, err := server.AcceptStream()
  752. if err != nil {
  753. t.Fatalf("err: %v", err)
  754. }
  755. defer stream.Close()
  756. buf := make([]byte, 4)
  757. _, err = stream.Read(buf)
  758. if err != nil {
  759. t.Fatalf("err: %v", err)
  760. }
  761. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  762. t.Fatalf("bad header")
  763. }
  764. recv, err := io.Copy(ioutil.Discard, stream)
  765. if err != nil {
  766. t.Fatalf("err: %v", err)
  767. }
  768. if recv != n {
  769. t.Fatalf("bad: %v", recv)
  770. }
  771. }()
  772. }
  773. for i := 0; i < workers; i++ {
  774. go func() {
  775. defer wg.Done()
  776. stream, err := client.Open()
  777. if err != nil {
  778. t.Fatalf("err: %v", err)
  779. }
  780. defer stream.Close()
  781. _, err = stream.Write([]byte{0, 1, 2, 3})
  782. if err != nil {
  783. t.Fatalf("err: %v", err)
  784. }
  785. unlimited := &UnlimitedReader{}
  786. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  787. if err != nil {
  788. t.Fatalf("err: %v", err)
  789. }
  790. if sent != n {
  791. t.Fatalf("bad: %v", sent)
  792. }
  793. }()
  794. }
  795. doneCh := make(chan struct{})
  796. go func() {
  797. wg.Wait()
  798. close(doneCh)
  799. }()
  800. select {
  801. case <-doneCh:
  802. case <-time.After(20 * time.Second):
  803. panic("timeout")
  804. }
  805. }
  806. func TestBacklogExceeded_Accept(t *testing.T) {
  807. client, server := testClientServer()
  808. defer client.Close()
  809. defer server.Close()
  810. max := 5 * client.config.AcceptBacklog
  811. go func() {
  812. for i := 0; i < max; i++ {
  813. stream, err := server.Accept()
  814. if err != nil {
  815. t.Fatalf("err: %v", err)
  816. }
  817. defer stream.Close()
  818. }
  819. }()
  820. // Fill the backlog
  821. for i := 0; i < max; i++ {
  822. stream, err := client.Open()
  823. if err != nil {
  824. t.Fatalf("err: %v", err)
  825. }
  826. defer stream.Close()
  827. if _, err := stream.Write([]byte("foo")); err != nil {
  828. t.Fatalf("err: %v", err)
  829. }
  830. }
  831. }
  832. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  833. client, server := testClientServerConfig(testConfNoKeepAlive())
  834. defer client.Close()
  835. defer server.Close()
  836. var wg sync.WaitGroup
  837. wg.Add(2)
  838. // Choose a huge flood size that we know will result in a window update.
  839. flood := int64(client.config.MaxStreamWindowSize) - 1
  840. // The server will accept a new stream and then flood data to it.
  841. go func() {
  842. defer wg.Done()
  843. stream, err := server.AcceptStream()
  844. if err != nil {
  845. t.Fatalf("err: %v", err)
  846. }
  847. defer stream.Close()
  848. n, err := stream.Write(make([]byte, flood))
  849. if err != nil {
  850. t.Fatalf("err: %v", err)
  851. }
  852. if int64(n) != flood {
  853. t.Fatalf("short write: %d", n)
  854. }
  855. }()
  856. // The client will open a stream, block outbound writes, and then
  857. // listen to the flood from the server, which should time out since
  858. // it won't be able to send the window update.
  859. go func() {
  860. defer wg.Done()
  861. stream, err := client.OpenStream()
  862. if err != nil {
  863. t.Fatalf("err: %v", err)
  864. }
  865. defer stream.Close()
  866. conn := client.conn.(*pipeConn)
  867. conn.writeBlocker.Lock()
  868. _, err = stream.Read(make([]byte, flood))
  869. if err != ErrConnectionWriteTimeout {
  870. t.Fatalf("err: %v", err)
  871. }
  872. }()
  873. wg.Wait()
  874. }
  875. func TestSession_sendNoWait_Timeout(t *testing.T) {
  876. client, server := testClientServerConfig(testConfNoKeepAlive())
  877. defer client.Close()
  878. defer server.Close()
  879. var wg sync.WaitGroup
  880. wg.Add(2)
  881. go func() {
  882. defer wg.Done()
  883. stream, err := server.AcceptStream()
  884. if err != nil {
  885. t.Fatalf("err: %v", err)
  886. }
  887. defer stream.Close()
  888. }()
  889. // The client will open the stream and then block outbound writes, we'll
  890. // probe sendNoWait once it gets into that state.
  891. go func() {
  892. defer wg.Done()
  893. stream, err := client.OpenStream()
  894. if err != nil {
  895. t.Fatalf("err: %v", err)
  896. }
  897. defer stream.Close()
  898. conn := client.conn.(*pipeConn)
  899. conn.writeBlocker.Lock()
  900. hdr := header(make([]byte, headerSize))
  901. hdr.encode(typePing, flagACK, 0, 0)
  902. for {
  903. err = client.sendNoWait(hdr)
  904. if err == nil {
  905. continue
  906. } else if err == ErrConnectionWriteTimeout {
  907. break
  908. } else {
  909. t.Fatalf("err: %v", err)
  910. }
  911. }
  912. }()
  913. wg.Wait()
  914. }
  915. func TestSession_PingOfDeath(t *testing.T) {
  916. client, server := testClientServerConfig(testConfNoKeepAlive())
  917. defer client.Close()
  918. defer server.Close()
  919. var wg sync.WaitGroup
  920. wg.Add(2)
  921. var doPingOfDeath sync.Mutex
  922. doPingOfDeath.Lock()
  923. // This is used later to block outbound writes.
  924. conn := server.conn.(*pipeConn)
  925. // The server will accept a stream, block outbound writes, and then
  926. // flood its send channel so that no more headers can be queued.
  927. go func() {
  928. defer wg.Done()
  929. stream, err := server.AcceptStream()
  930. if err != nil {
  931. t.Fatalf("err: %v", err)
  932. }
  933. defer stream.Close()
  934. conn.writeBlocker.Lock()
  935. for {
  936. hdr := header(make([]byte, headerSize))
  937. hdr.encode(typePing, 0, 0, 0)
  938. err = server.sendNoWait(hdr)
  939. if err == nil {
  940. continue
  941. } else if err == ErrConnectionWriteTimeout {
  942. break
  943. } else {
  944. t.Fatalf("err: %v", err)
  945. }
  946. }
  947. doPingOfDeath.Unlock()
  948. }()
  949. // The client will open a stream and then send the server a ping once it
  950. // can no longer write. This makes sure the server doesn't deadlock reads
  951. // while trying to reply to the ping with no ability to write.
  952. go func() {
  953. defer wg.Done()
  954. stream, err := client.OpenStream()
  955. if err != nil {
  956. t.Fatalf("err: %v", err)
  957. }
  958. defer stream.Close()
  959. // This ping will never unblock because the ping id will never
  960. // show up in a response.
  961. doPingOfDeath.Lock()
  962. go func() { client.Ping() }()
  963. // Wait for a while to make sure the previous ping times out,
  964. // then turn writes back on and make sure a ping works again.
  965. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  966. conn.writeBlocker.Unlock()
  967. if _, err = client.Ping(); err != nil {
  968. t.Fatalf("err: %v", err)
  969. }
  970. }()
  971. wg.Wait()
  972. }
  973. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  974. client, server := testClientServerConfig(testConfNoKeepAlive())
  975. defer client.Close()
  976. defer server.Close()
  977. var wg sync.WaitGroup
  978. wg.Add(2)
  979. go func() {
  980. defer wg.Done()
  981. stream, err := server.AcceptStream()
  982. if err != nil {
  983. t.Fatalf("err: %v", err)
  984. }
  985. defer stream.Close()
  986. }()
  987. // The client will open the stream and then block outbound writes, we'll
  988. // tee up a write and make sure it eventually times out.
  989. go func() {
  990. defer wg.Done()
  991. stream, err := client.OpenStream()
  992. if err != nil {
  993. t.Fatalf("err: %v", err)
  994. }
  995. defer stream.Close()
  996. conn := client.conn.(*pipeConn)
  997. conn.writeBlocker.Lock()
  998. // Since the write goroutine is blocked then this will return a
  999. // timeout since it can't get feedback about whether the write
  1000. // worked.
  1001. n, err := stream.Write([]byte("hello"))
  1002. if err != ErrConnectionWriteTimeout {
  1003. t.Fatalf("err: %v", err)
  1004. }
  1005. if n != 0 {
  1006. t.Fatalf("lied about writes: %d", n)
  1007. }
  1008. }()
  1009. wg.Wait()
  1010. }