session_test.go 24 KB


  1. package yamux
import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)
  15. type logCapture struct{ bytes.Buffer }
  16. func (l *logCapture) logs() []string {
  17. return strings.Split(strings.TrimSpace(l.String()), "\n")
  18. }
  19. func (l *logCapture) match(expect []string) bool {
  20. return reflect.DeepEqual(l.logs(), expect)
  21. }
  22. func captureLogs(s *Session) *logCapture {
  23. buf := new(logCapture)
  24. s.logger = log.New(buf, "", 0)
  25. return buf
  26. }
  27. type pipeConn struct {
  28. reader *io.PipeReader
  29. writer *io.PipeWriter
  30. writeBlocker sync.Mutex
  31. }
  32. func (p *pipeConn) Read(b []byte) (int, error) {
  33. return p.reader.Read(b)
  34. }
  35. func (p *pipeConn) Write(b []byte) (int, error) {
  36. p.writeBlocker.Lock()
  37. defer p.writeBlocker.Unlock()
  38. return p.writer.Write(b)
  39. }
  40. func (p *pipeConn) Close() error {
  41. p.reader.Close()
  42. return p.writer.Close()
  43. }
  44. func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  45. read1, write1 := io.Pipe()
  46. read2, write2 := io.Pipe()
  47. conn1 := &pipeConn{reader: read1, writer: write2}
  48. conn2 := &pipeConn{reader: read2, writer: write1}
  49. return conn1, conn2
  50. }
  51. func testConf() *Config {
  52. conf := DefaultConfig()
  53. conf.AcceptBacklog = 64
  54. conf.KeepAliveInterval = 100 * time.Millisecond
  55. conf.ConnectionWriteTimeout = 250 * time.Millisecond
  56. return conf
  57. }
  58. func testConfNoKeepAlive() *Config {
  59. conf := testConf()
  60. conf.EnableKeepAlive = false
  61. return conf
  62. }
  63. func testClientServer() (*Session, *Session) {
  64. return testClientServerConfig(testConf())
  65. }
  66. func testClientServerConfig(conf *Config) (*Session, *Session) {
  67. conn1, conn2 := testConn()
  68. client, _ := Client(conn1, conf)
  69. server, _ := Server(conn2, conf)
  70. return client, server
  71. }
  72. func TestPing(t *testing.T) {
  73. client, server := testClientServer()
  74. defer client.Close()
  75. defer server.Close()
  76. rtt, err := client.Ping()
  77. if err != nil {
  78. t.Fatalf("err: %v", err)
  79. }
  80. if rtt == 0 {
  81. t.Fatalf("bad: %v", rtt)
  82. }
  83. rtt, err = server.Ping()
  84. if err != nil {
  85. t.Fatalf("err: %v", err)
  86. }
  87. if rtt == 0 {
  88. t.Fatalf("bad: %v", rtt)
  89. }
  90. }
  91. func TestPing_Timeout(t *testing.T) {
  92. client, server := testClientServerConfig(testConfNoKeepAlive())
  93. defer client.Close()
  94. defer server.Close()
  95. // Prevent the client from responding
  96. clientConn := client.conn.(*pipeConn)
  97. clientConn.writeBlocker.Lock()
  98. errCh := make(chan error, 1)
  99. go func() {
  100. _, err := server.Ping() // Ping via the server session
  101. errCh <- err
  102. }()
  103. select {
  104. case err := <-errCh:
  105. if err != ErrTimeout {
  106. t.Fatalf("err: %v", err)
  107. }
  108. case <-time.After(client.config.ConnectionWriteTimeout * 2):
  109. t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
  110. }
  111. // Verify that we recover, even if we gave up
  112. clientConn.writeBlocker.Unlock()
  113. go func() {
  114. _, err := server.Ping() // Ping via the server session
  115. errCh <- err
  116. }()
  117. select {
  118. case err := <-errCh:
  119. if err != nil {
  120. t.Fatalf("err: %v", err)
  121. }
  122. case <-time.After(client.config.ConnectionWriteTimeout):
  123. t.Fatalf("timeout")
  124. }
  125. }
  126. func TestCloseBeforeAck(t *testing.T) {
  127. cfg := testConf()
  128. cfg.AcceptBacklog = 8
  129. client, server := testClientServerConfig(cfg)
  130. defer client.Close()
  131. defer server.Close()
  132. for i := 0; i < 8; i++ {
  133. s, err := client.OpenStream()
  134. if err != nil {
  135. t.Fatal(err)
  136. }
  137. s.Close()
  138. }
  139. for i := 0; i < 8; i++ {
  140. s, err := server.AcceptStream()
  141. if err != nil {
  142. t.Fatal(err)
  143. }
  144. s.Close()
  145. }
  146. done := make(chan struct{})
  147. go func() {
  148. defer close(done)
  149. s, err := client.OpenStream()
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. s.Close()
  154. }()
  155. select {
  156. case <-done:
  157. case <-time.After(time.Second * 5):
  158. t.Fatal("timed out trying to open stream")
  159. }
  160. }
  161. func TestAccept(t *testing.T) {
  162. client, server := testClientServer()
  163. defer client.Close()
  164. defer server.Close()
  165. if client.NumStreams() != 0 {
  166. t.Fatalf("bad")
  167. }
  168. if server.NumStreams() != 0 {
  169. t.Fatalf("bad")
  170. }
  171. wg := &sync.WaitGroup{}
  172. wg.Add(4)
  173. go func() {
  174. defer wg.Done()
  175. stream, err := server.AcceptStream()
  176. if err != nil {
  177. t.Fatalf("err: %v", err)
  178. }
  179. if id := stream.StreamID(); id != 1 {
  180. t.Fatalf("bad: %v", id)
  181. }
  182. if err := stream.Close(); err != nil {
  183. t.Fatalf("err: %v", err)
  184. }
  185. }()
  186. go func() {
  187. defer wg.Done()
  188. stream, err := client.AcceptStream()
  189. if err != nil {
  190. t.Fatalf("err: %v", err)
  191. }
  192. if id := stream.StreamID(); id != 2 {
  193. t.Fatalf("bad: %v", id)
  194. }
  195. if err := stream.Close(); err != nil {
  196. t.Fatalf("err: %v", err)
  197. }
  198. }()
  199. go func() {
  200. defer wg.Done()
  201. stream, err := server.OpenStream()
  202. if err != nil {
  203. t.Fatalf("err: %v", err)
  204. }
  205. if id := stream.StreamID(); id != 2 {
  206. t.Fatalf("bad: %v", id)
  207. }
  208. if err := stream.Close(); err != nil {
  209. t.Fatalf("err: %v", err)
  210. }
  211. }()
  212. go func() {
  213. defer wg.Done()
  214. stream, err := client.OpenStream()
  215. if err != nil {
  216. t.Fatalf("err: %v", err)
  217. }
  218. if id := stream.StreamID(); id != 1 {
  219. t.Fatalf("bad: %v", id)
  220. }
  221. if err := stream.Close(); err != nil {
  222. t.Fatalf("err: %v", err)
  223. }
  224. }()
  225. doneCh := make(chan struct{})
  226. go func() {
  227. wg.Wait()
  228. close(doneCh)
  229. }()
  230. select {
  231. case <-doneCh:
  232. case <-time.After(time.Second):
  233. panic("timeout")
  234. }
  235. }
  236. func TestNonNilInterface(t *testing.T) {
  237. _, server := testClientServer()
  238. server.Close()
  239. conn, err := server.Accept()
  240. if err != nil && conn != nil {
  241. t.Error("bad: accept should return a connection of nil value")
  242. }
  243. conn, err = server.Open()
  244. if err != nil && conn != nil {
  245. t.Error("bad: open should return a connection of nil value")
  246. }
  247. }
  248. func TestSendData_Small(t *testing.T) {
  249. client, server := testClientServer()
  250. defer client.Close()
  251. defer server.Close()
  252. wg := &sync.WaitGroup{}
  253. wg.Add(2)
  254. go func() {
  255. defer wg.Done()
  256. stream, err := server.AcceptStream()
  257. if err != nil {
  258. t.Fatalf("err: %v", err)
  259. }
  260. if server.NumStreams() != 1 {
  261. t.Fatalf("bad")
  262. }
  263. buf := make([]byte, 4)
  264. for i := 0; i < 1000; i++ {
  265. n, err := stream.Read(buf)
  266. if err != nil {
  267. t.Fatalf("err: %v", err)
  268. }
  269. if n != 4 {
  270. t.Fatalf("short read: %d", n)
  271. }
  272. if string(buf) != "test" {
  273. t.Fatalf("bad: %s", buf)
  274. }
  275. }
  276. if err := stream.Close(); err != nil {
  277. t.Fatalf("err: %v", err)
  278. }
  279. }()
  280. go func() {
  281. defer wg.Done()
  282. stream, err := client.Open()
  283. if err != nil {
  284. t.Fatalf("err: %v", err)
  285. }
  286. if client.NumStreams() != 1 {
  287. t.Fatalf("bad")
  288. }
  289. for i := 0; i < 1000; i++ {
  290. n, err := stream.Write([]byte("test"))
  291. if err != nil {
  292. t.Fatalf("err: %v", err)
  293. }
  294. if n != 4 {
  295. t.Fatalf("short write %d", n)
  296. }
  297. }
  298. if err := stream.Close(); err != nil {
  299. t.Fatalf("err: %v", err)
  300. }
  301. }()
  302. doneCh := make(chan struct{})
  303. go func() {
  304. wg.Wait()
  305. close(doneCh)
  306. }()
  307. select {
  308. case <-doneCh:
  309. case <-time.After(time.Second):
  310. panic("timeout")
  311. }
  312. if client.NumStreams() != 0 {
  313. t.Fatalf("bad")
  314. }
  315. if server.NumStreams() != 0 {
  316. t.Fatalf("bad")
  317. }
  318. }
  319. func TestSendData_Large(t *testing.T) {
  320. client, server := testClientServer()
  321. defer client.Close()
  322. defer server.Close()
  323. const (
  324. sendSize = 250 * 1024 * 1024
  325. recvSize = 4 * 1024
  326. )
  327. data := make([]byte, sendSize)
  328. for idx := range data {
  329. data[idx] = byte(idx % 256)
  330. }
  331. wg := &sync.WaitGroup{}
  332. wg.Add(2)
  333. go func() {
  334. defer wg.Done()
  335. stream, err := server.AcceptStream()
  336. if err != nil {
  337. t.Fatalf("err: %v", err)
  338. }
  339. var sz int
  340. buf := make([]byte, recvSize)
  341. for i := 0; i < sendSize/recvSize; i++ {
  342. n, err := stream.Read(buf)
  343. if err != nil {
  344. t.Fatalf("err: %v", err)
  345. }
  346. if n != recvSize {
  347. t.Fatalf("short read: %d", n)
  348. }
  349. sz += n
  350. for idx := range buf {
  351. if buf[idx] != byte(idx%256) {
  352. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  353. }
  354. }
  355. }
  356. if err := stream.Close(); err != nil {
  357. t.Fatalf("err: %v", err)
  358. }
  359. t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
  360. }()
  361. go func() {
  362. defer wg.Done()
  363. stream, err := client.Open()
  364. if err != nil {
  365. t.Fatalf("err: %v", err)
  366. }
  367. n, err := stream.Write(data)
  368. if err != nil {
  369. t.Fatalf("err: %v", err)
  370. }
  371. if n != len(data) {
  372. t.Fatalf("short write %d", n)
  373. }
  374. if err := stream.Close(); err != nil {
  375. t.Fatalf("err: %v", err)
  376. }
  377. }()
  378. doneCh := make(chan struct{})
  379. go func() {
  380. wg.Wait()
  381. close(doneCh)
  382. }()
  383. select {
  384. case <-doneCh:
  385. case <-time.After(5 * time.Second):
  386. panic("timeout")
  387. }
  388. }
  389. func TestGoAway(t *testing.T) {
  390. client, server := testClientServer()
  391. defer client.Close()
  392. defer server.Close()
  393. if err := server.GoAway(); err != nil {
  394. t.Fatalf("err: %v", err)
  395. }
  396. _, err := client.Open()
  397. if err != ErrRemoteGoAway {
  398. t.Fatalf("err: %v", err)
  399. }
  400. }
  401. func TestManyStreams(t *testing.T) {
  402. client, server := testClientServer()
  403. defer client.Close()
  404. defer server.Close()
  405. wg := &sync.WaitGroup{}
  406. acceptor := func(i int) {
  407. defer wg.Done()
  408. stream, err := server.AcceptStream()
  409. if err != nil {
  410. t.Fatalf("err: %v", err)
  411. }
  412. defer stream.Close()
  413. buf := make([]byte, 512)
  414. for {
  415. n, err := stream.Read(buf)
  416. if err == io.EOF {
  417. return
  418. }
  419. if err != nil {
  420. t.Fatalf("err: %v", err)
  421. }
  422. if n == 0 {
  423. t.Fatalf("err: %v", err)
  424. }
  425. }
  426. }
  427. sender := func(i int) {
  428. defer wg.Done()
  429. stream, err := client.Open()
  430. if err != nil {
  431. t.Fatalf("err: %v", err)
  432. }
  433. defer stream.Close()
  434. msg := fmt.Sprintf("%08d", i)
  435. for i := 0; i < 1000; i++ {
  436. n, err := stream.Write([]byte(msg))
  437. if err != nil {
  438. t.Fatalf("err: %v", err)
  439. }
  440. if n != len(msg) {
  441. t.Fatalf("short write %d", n)
  442. }
  443. }
  444. }
  445. for i := 0; i < 50; i++ {
  446. wg.Add(2)
  447. go acceptor(i)
  448. go sender(i)
  449. }
  450. wg.Wait()
  451. }
  452. func TestManyStreams_PingPong(t *testing.T) {
  453. client, server := testClientServer()
  454. defer client.Close()
  455. defer server.Close()
  456. wg := &sync.WaitGroup{}
  457. ping := []byte("ping")
  458. pong := []byte("pong")
  459. acceptor := func(i int) {
  460. defer wg.Done()
  461. stream, err := server.AcceptStream()
  462. if err != nil {
  463. t.Fatalf("err: %v", err)
  464. }
  465. defer stream.Close()
  466. buf := make([]byte, 4)
  467. for {
  468. // Read the 'ping'
  469. n, err := stream.Read(buf)
  470. if err == io.EOF {
  471. return
  472. }
  473. if err != nil {
  474. t.Fatalf("err: %v", err)
  475. }
  476. if n != 4 {
  477. t.Fatalf("err: %v", err)
  478. }
  479. if !bytes.Equal(buf, ping) {
  480. t.Fatalf("bad: %s", buf)
  481. }
  482. // Shrink the internal buffer!
  483. stream.Shrink()
  484. // Write out the 'pong'
  485. n, err = stream.Write(pong)
  486. if err != nil {
  487. t.Fatalf("err: %v", err)
  488. }
  489. if n != 4 {
  490. t.Fatalf("err: %v", err)
  491. }
  492. }
  493. }
  494. sender := func(i int) {
  495. defer wg.Done()
  496. stream, err := client.OpenStream()
  497. if err != nil {
  498. t.Fatalf("err: %v", err)
  499. }
  500. defer stream.Close()
  501. buf := make([]byte, 4)
  502. for i := 0; i < 1000; i++ {
  503. // Send the 'ping'
  504. n, err := stream.Write(ping)
  505. if err != nil {
  506. t.Fatalf("err: %v", err)
  507. }
  508. if n != 4 {
  509. t.Fatalf("short write %d", n)
  510. }
  511. // Read the 'pong'
  512. n, err = stream.Read(buf)
  513. if err != nil {
  514. t.Fatalf("err: %v", err)
  515. }
  516. if n != 4 {
  517. t.Fatalf("err: %v", err)
  518. }
  519. if !bytes.Equal(buf, pong) {
  520. t.Fatalf("bad: %s", buf)
  521. }
  522. // Shrink the buffer
  523. stream.Shrink()
  524. }
  525. }
  526. for i := 0; i < 50; i++ {
  527. wg.Add(2)
  528. go acceptor(i)
  529. go sender(i)
  530. }
  531. wg.Wait()
  532. }
  533. func TestHalfClose(t *testing.T) {
  534. client, server := testClientServer()
  535. defer client.Close()
  536. defer server.Close()
  537. stream, err := client.Open()
  538. if err != nil {
  539. t.Fatalf("err: %v", err)
  540. }
  541. if _, err = stream.Write([]byte("a")); err != nil {
  542. t.Fatalf("err: %v", err)
  543. }
  544. stream2, err := server.Accept()
  545. if err != nil {
  546. t.Fatalf("err: %v", err)
  547. }
  548. stream2.Close() // Half close
  549. buf := make([]byte, 4)
  550. n, err := stream2.Read(buf)
  551. if err != nil {
  552. t.Fatalf("err: %v", err)
  553. }
  554. if n != 1 {
  555. t.Fatalf("bad: %v", n)
  556. }
  557. // Send more
  558. if _, err = stream.Write([]byte("bcd")); err != nil {
  559. t.Fatalf("err: %v", err)
  560. }
  561. stream.Close()
  562. // Read after close
  563. n, err = stream2.Read(buf)
  564. if err != nil {
  565. t.Fatalf("err: %v", err)
  566. }
  567. if n != 3 {
  568. t.Fatalf("bad: %v", n)
  569. }
  570. // EOF after close
  571. n, err = stream2.Read(buf)
  572. if err != io.EOF {
  573. t.Fatalf("err: %v", err)
  574. }
  575. if n != 0 {
  576. t.Fatalf("bad: %v", n)
  577. }
  578. }
  579. func TestReadDeadline(t *testing.T) {
  580. client, server := testClientServer()
  581. defer client.Close()
  582. defer server.Close()
  583. stream, err := client.Open()
  584. if err != nil {
  585. t.Fatalf("err: %v", err)
  586. }
  587. defer stream.Close()
  588. stream2, err := server.Accept()
  589. if err != nil {
  590. t.Fatalf("err: %v", err)
  591. }
  592. defer stream2.Close()
  593. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  594. t.Fatalf("err: %v", err)
  595. }
  596. buf := make([]byte, 4)
  597. if _, err := stream.Read(buf); err != ErrTimeout {
  598. t.Fatalf("err: %v", err)
  599. }
  600. }
  601. func TestWriteDeadline(t *testing.T) {
  602. client, server := testClientServer()
  603. defer client.Close()
  604. defer server.Close()
  605. stream, err := client.Open()
  606. if err != nil {
  607. t.Fatalf("err: %v", err)
  608. }
  609. defer stream.Close()
  610. stream2, err := server.Accept()
  611. if err != nil {
  612. t.Fatalf("err: %v", err)
  613. }
  614. defer stream2.Close()
  615. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  616. t.Fatalf("err: %v", err)
  617. }
  618. buf := make([]byte, 512)
  619. for i := 0; i < int(initialStreamWindow); i++ {
  620. _, err := stream.Write(buf)
  621. if err != nil && err == ErrTimeout {
  622. return
  623. } else if err != nil {
  624. t.Fatalf("err: %v", err)
  625. }
  626. }
  627. t.Fatalf("Expected timeout")
  628. }
  629. func TestBacklogExceeded(t *testing.T) {
  630. client, server := testClientServer()
  631. defer client.Close()
  632. defer server.Close()
  633. // Fill the backlog
  634. max := client.config.AcceptBacklog
  635. for i := 0; i < max; i++ {
  636. stream, err := client.Open()
  637. if err != nil {
  638. t.Fatalf("err: %v", err)
  639. }
  640. defer stream.Close()
  641. if _, err := stream.Write([]byte("foo")); err != nil {
  642. t.Fatalf("err: %v", err)
  643. }
  644. }
  645. // Attempt to open a new stream
  646. errCh := make(chan error, 1)
  647. go func() {
  648. _, err := client.Open()
  649. errCh <- err
  650. }()
  651. // Shutdown the server
  652. go func() {
  653. time.Sleep(10 * time.Millisecond)
  654. server.Close()
  655. }()
  656. select {
  657. case err := <-errCh:
  658. if err == nil {
  659. t.Fatalf("open should fail")
  660. }
  661. case <-time.After(time.Second):
  662. t.Fatalf("timeout")
  663. }
  664. }
  665. func TestKeepAlive(t *testing.T) {
  666. client, server := testClientServer()
  667. defer client.Close()
  668. defer server.Close()
  669. time.Sleep(200 * time.Millisecond)
  670. // Ping value should increase
  671. client.pingLock.Lock()
  672. defer client.pingLock.Unlock()
  673. if client.pingID == 0 {
  674. t.Fatalf("should ping")
  675. }
  676. server.pingLock.Lock()
  677. defer server.pingLock.Unlock()
  678. if server.pingID == 0 {
  679. t.Fatalf("should ping")
  680. }
  681. }
  682. func TestKeepAlive_Timeout(t *testing.T) {
  683. conn1, conn2 := testConn()
  684. clientConf := testConf()
  685. clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
  686. clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
  687. client, _ := Client(conn1, clientConf)
  688. defer client.Close()
  689. server, _ := Server(conn2, testConf())
  690. defer server.Close()
  691. _ = captureLogs(client) // Client logs aren't part of the test
  692. serverLogs := captureLogs(server)
  693. errCh := make(chan error, 1)
  694. go func() {
  695. _, err := server.Accept() // Wait until server closes
  696. errCh <- err
  697. }()
  698. // Prevent the client from responding
  699. clientConn := client.conn.(*pipeConn)
  700. clientConn.writeBlocker.Lock()
  701. select {
  702. case err := <-errCh:
  703. if err != ErrKeepAliveTimeout {
  704. t.Fatalf("unexpected error: %v", err)
  705. }
  706. case <-time.After(1 * time.Second):
  707. t.Fatalf("timeout waiting for timeout")
  708. }
  709. if !server.IsClosed() {
  710. t.Fatalf("server should have closed")
  711. }
  712. if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
  713. t.Fatalf("server log incorect: %v", serverLogs.logs())
  714. }
  715. }
  716. func TestLargeWindow(t *testing.T) {
  717. conf := DefaultConfig()
  718. conf.MaxStreamWindowSize *= 2
  719. client, server := testClientServerConfig(conf)
  720. defer client.Close()
  721. defer server.Close()
  722. stream, err := client.Open()
  723. if err != nil {
  724. t.Fatalf("err: %v", err)
  725. }
  726. defer stream.Close()
  727. stream2, err := server.Accept()
  728. if err != nil {
  729. t.Fatalf("err: %v", err)
  730. }
  731. defer stream2.Close()
  732. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  733. buf := make([]byte, conf.MaxStreamWindowSize)
  734. n, err := stream.Write(buf)
  735. if err != nil {
  736. t.Fatalf("err: %v", err)
  737. }
  738. if n != len(buf) {
  739. t.Fatalf("short write: %d", n)
  740. }
  741. }
  742. type UnlimitedReader struct{}
  743. func (u *UnlimitedReader) Read(p []byte) (int, error) {
  744. runtime.Gosched()
  745. return len(p), nil
  746. }
  747. func TestSendData_VeryLarge(t *testing.T) {
  748. client, server := testClientServer()
  749. defer client.Close()
  750. defer server.Close()
  751. var n int64 = 1 * 1024 * 1024 * 1024
  752. var workers int = 16
  753. wg := &sync.WaitGroup{}
  754. wg.Add(workers * 2)
  755. for i := 0; i < workers; i++ {
  756. go func() {
  757. defer wg.Done()
  758. stream, err := server.AcceptStream()
  759. if err != nil {
  760. t.Fatalf("err: %v", err)
  761. }
  762. defer stream.Close()
  763. buf := make([]byte, 4)
  764. _, err = stream.Read(buf)
  765. if err != nil {
  766. t.Fatalf("err: %v", err)
  767. }
  768. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  769. t.Fatalf("bad header")
  770. }
  771. recv, err := io.Copy(ioutil.Discard, stream)
  772. if err != nil {
  773. t.Fatalf("err: %v", err)
  774. }
  775. if recv != n {
  776. t.Fatalf("bad: %v", recv)
  777. }
  778. }()
  779. }
  780. for i := 0; i < workers; i++ {
  781. go func() {
  782. defer wg.Done()
  783. stream, err := client.Open()
  784. if err != nil {
  785. t.Fatalf("err: %v", err)
  786. }
  787. defer stream.Close()
  788. _, err = stream.Write([]byte{0, 1, 2, 3})
  789. if err != nil {
  790. t.Fatalf("err: %v", err)
  791. }
  792. unlimited := &UnlimitedReader{}
  793. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  794. if err != nil {
  795. t.Fatalf("err: %v", err)
  796. }
  797. if sent != n {
  798. t.Fatalf("bad: %v", sent)
  799. }
  800. }()
  801. }
  802. doneCh := make(chan struct{})
  803. go func() {
  804. wg.Wait()
  805. close(doneCh)
  806. }()
  807. select {
  808. case <-doneCh:
  809. case <-time.After(20 * time.Second):
  810. panic("timeout")
  811. }
  812. }
  813. func TestBacklogExceeded_Accept(t *testing.T) {
  814. client, server := testClientServer()
  815. defer client.Close()
  816. defer server.Close()
  817. max := 5 * client.config.AcceptBacklog
  818. go func() {
  819. for i := 0; i < max; i++ {
  820. stream, err := server.Accept()
  821. if err != nil {
  822. t.Fatalf("err: %v", err)
  823. }
  824. defer stream.Close()
  825. }
  826. }()
  827. // Fill the backlog
  828. for i := 0; i < max; i++ {
  829. stream, err := client.Open()
  830. if err != nil {
  831. t.Fatalf("err: %v", err)
  832. }
  833. defer stream.Close()
  834. if _, err := stream.Write([]byte("foo")); err != nil {
  835. t.Fatalf("err: %v", err)
  836. }
  837. }
  838. }
  839. func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
  840. client, server := testClientServerConfig(testConfNoKeepAlive())
  841. defer client.Close()
  842. defer server.Close()
  843. var wg sync.WaitGroup
  844. wg.Add(2)
  845. // Choose a huge flood size that we know will result in a window update.
  846. flood := int64(client.config.MaxStreamWindowSize) - 1
  847. // The server will accept a new stream and then flood data to it.
  848. go func() {
  849. defer wg.Done()
  850. stream, err := server.AcceptStream()
  851. if err != nil {
  852. t.Fatalf("err: %v", err)
  853. }
  854. defer stream.Close()
  855. n, err := stream.Write(make([]byte, flood))
  856. if err != nil {
  857. t.Fatalf("err: %v", err)
  858. }
  859. if int64(n) != flood {
  860. t.Fatalf("short write: %d", n)
  861. }
  862. }()
  863. // The client will open a stream, block outbound writes, and then
  864. // listen to the flood from the server, which should time out since
  865. // it won't be able to send the window update.
  866. go func() {
  867. defer wg.Done()
  868. stream, err := client.OpenStream()
  869. if err != nil {
  870. t.Fatalf("err: %v", err)
  871. }
  872. defer stream.Close()
  873. conn := client.conn.(*pipeConn)
  874. conn.writeBlocker.Lock()
  875. _, err = stream.Read(make([]byte, flood))
  876. if err != ErrConnectionWriteTimeout {
  877. t.Fatalf("err: %v", err)
  878. }
  879. }()
  880. wg.Wait()
  881. }
  882. func TestSession_PartialReadWindowUpdate(t *testing.T) {
  883. client, server := testClientServerConfig(testConfNoKeepAlive())
  884. defer client.Close()
  885. defer server.Close()
  886. var wg sync.WaitGroup
  887. wg.Add(1)
  888. // Choose a huge flood size that we know will result in a window update.
  889. flood := int64(client.config.MaxStreamWindowSize)
  890. var wr *Stream
  891. // The server will accept a new stream and then flood data to it.
  892. go func() {
  893. defer wg.Done()
  894. var err error
  895. wr, err = server.AcceptStream()
  896. if err != nil {
  897. t.Fatalf("err: %v", err)
  898. }
  899. defer wr.Close()
  900. if wr.sendWindow != client.config.MaxStreamWindowSize {
  901. t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
  902. }
  903. n, err := wr.Write(make([]byte, flood))
  904. if err != nil {
  905. t.Fatalf("err: %v", err)
  906. }
  907. if int64(n) != flood {
  908. t.Fatalf("short write: %d", n)
  909. }
  910. if wr.sendWindow != 0 {
  911. t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
  912. }
  913. }()
  914. stream, err := client.OpenStream()
  915. if err != nil {
  916. t.Fatalf("err: %v", err)
  917. }
  918. defer stream.Close()
  919. wg.Wait()
  920. _, err = stream.Read(make([]byte, flood/2+1))
  921. if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
  922. t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
  923. }
  924. }
  925. func TestSession_sendNoWait_Timeout(t *testing.T) {
  926. client, server := testClientServerConfig(testConfNoKeepAlive())
  927. defer client.Close()
  928. defer server.Close()
  929. var wg sync.WaitGroup
  930. wg.Add(2)
  931. go func() {
  932. defer wg.Done()
  933. stream, err := server.AcceptStream()
  934. if err != nil {
  935. t.Fatalf("err: %v", err)
  936. }
  937. defer stream.Close()
  938. }()
  939. // The client will open the stream and then block outbound writes, we'll
  940. // probe sendNoWait once it gets into that state.
  941. go func() {
  942. defer wg.Done()
  943. stream, err := client.OpenStream()
  944. if err != nil {
  945. t.Fatalf("err: %v", err)
  946. }
  947. defer stream.Close()
  948. conn := client.conn.(*pipeConn)
  949. conn.writeBlocker.Lock()
  950. hdr := header(make([]byte, headerSize))
  951. hdr.encode(typePing, flagACK, 0, 0)
  952. for {
  953. err = client.sendNoWait(hdr)
  954. if err == nil {
  955. continue
  956. } else if err == ErrConnectionWriteTimeout {
  957. break
  958. } else {
  959. t.Fatalf("err: %v", err)
  960. }
  961. }
  962. }()
  963. wg.Wait()
  964. }
  965. func TestSession_PingOfDeath(t *testing.T) {
  966. client, server := testClientServerConfig(testConfNoKeepAlive())
  967. defer client.Close()
  968. defer server.Close()
  969. var wg sync.WaitGroup
  970. wg.Add(2)
  971. var doPingOfDeath sync.Mutex
  972. doPingOfDeath.Lock()
  973. // This is used later to block outbound writes.
  974. conn := server.conn.(*pipeConn)
  975. // The server will accept a stream, block outbound writes, and then
  976. // flood its send channel so that no more headers can be queued.
  977. go func() {
  978. defer wg.Done()
  979. stream, err := server.AcceptStream()
  980. if err != nil {
  981. t.Fatalf("err: %v", err)
  982. }
  983. defer stream.Close()
  984. conn.writeBlocker.Lock()
  985. for {
  986. hdr := header(make([]byte, headerSize))
  987. hdr.encode(typePing, 0, 0, 0)
  988. err = server.sendNoWait(hdr)
  989. if err == nil {
  990. continue
  991. } else if err == ErrConnectionWriteTimeout {
  992. break
  993. } else {
  994. t.Fatalf("err: %v", err)
  995. }
  996. }
  997. doPingOfDeath.Unlock()
  998. }()
  999. // The client will open a stream and then send the server a ping once it
  1000. // can no longer write. This makes sure the server doesn't deadlock reads
  1001. // while trying to reply to the ping with no ability to write.
  1002. go func() {
  1003. defer wg.Done()
  1004. stream, err := client.OpenStream()
  1005. if err != nil {
  1006. t.Fatalf("err: %v", err)
  1007. }
  1008. defer stream.Close()
  1009. // This ping will never unblock because the ping id will never
  1010. // show up in a response.
  1011. doPingOfDeath.Lock()
  1012. go func() { client.Ping() }()
  1013. // Wait for a while to make sure the previous ping times out,
  1014. // then turn writes back on and make sure a ping works again.
  1015. time.Sleep(2 * server.config.ConnectionWriteTimeout)
  1016. conn.writeBlocker.Unlock()
  1017. if _, err = client.Ping(); err != nil {
  1018. t.Fatalf("err: %v", err)
  1019. }
  1020. }()
  1021. wg.Wait()
  1022. }
  1023. func TestSession_ConnectionWriteTimeout(t *testing.T) {
  1024. client, server := testClientServerConfig(testConfNoKeepAlive())
  1025. defer client.Close()
  1026. defer server.Close()
  1027. var wg sync.WaitGroup
  1028. wg.Add(2)
  1029. go func() {
  1030. defer wg.Done()
  1031. stream, err := server.AcceptStream()
  1032. if err != nil {
  1033. t.Fatalf("err: %v", err)
  1034. }
  1035. defer stream.Close()
  1036. }()
  1037. // The client will open the stream and then block outbound writes, we'll
  1038. // tee up a write and make sure it eventually times out.
  1039. go func() {
  1040. defer wg.Done()
  1041. stream, err := client.OpenStream()
  1042. if err != nil {
  1043. t.Fatalf("err: %v", err)
  1044. }
  1045. defer stream.Close()
  1046. conn := client.conn.(*pipeConn)
  1047. conn.writeBlocker.Lock()
  1048. // Since the write goroutine is blocked then this will return a
  1049. // timeout since it can't get feedback about whether the write
  1050. // worked.
  1051. n, err := stream.Write([]byte("hello"))
  1052. if err != ErrConnectionWriteTimeout {
  1053. t.Fatalf("err: %v", err)
  1054. }
  1055. if n != 0 {
  1056. t.Fatalf("lied about writes: %d", n)
  1057. }
  1058. }()
  1059. wg.Wait()
  1060. }