session_test.go 13 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. package yamux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "runtime"
  8. "sync"
  9. "testing"
  10. "time"
  11. )
  // pipeConn pairs the read end of one in-memory pipe with the write end of
  // another, so the two halves can be handled as one read/write/close value.
  type pipeConn struct {
  	reader *io.PipeReader // Read pulls from here
  	writer *io.PipeWriter // Write pushes into here
  }

  // Read fills dst from the read half and reports what the pipe reports.
  func (pc *pipeConn) Read(dst []byte) (int, error) {
  	n, err := pc.reader.Read(dst)
  	return n, err
  }

  // Write pushes src into the write half and reports what the pipe reports.
  func (pc *pipeConn) Write(src []byte) (int, error) {
  	n, err := pc.writer.Write(src)
  	return n, err
  }

  // Close shuts the read half first and the write half second. Only the
  // write half's result is returned; the read half's result is discarded.
  func (pc *pipeConn) Close() error {
  	_ = pc.reader.Close()
  	closeErr := pc.writer.Close()
  	return closeErr
  }
  26. func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  27. read1, write1 := io.Pipe()
  28. read2, write2 := io.Pipe()
  29. return &pipeConn{read1, write2}, &pipeConn{read2, write1}
  30. }
  31. func testClientServer() (*Session, *Session) {
  32. conf := DefaultConfig()
  33. conf.AcceptBacklog = 64
  34. conf.KeepAliveInterval = 100 * time.Millisecond
  35. return testClientServerConfig(conf)
  36. }
  37. func testClientServerConfig(conf *Config) (*Session, *Session) {
  38. conn1, conn2 := testConn()
  39. client, _ := Client(conn1, conf)
  40. server, _ := Server(conn2, conf)
  41. return client, server
  42. }
  43. func TestPing(t *testing.T) {
  44. client, server := testClientServer()
  45. defer client.Close()
  46. defer server.Close()
  47. rtt, err := client.Ping()
  48. if err != nil {
  49. t.Fatalf("err: %v", err)
  50. }
  51. if rtt == 0 {
  52. t.Fatalf("bad: %v", rtt)
  53. }
  54. rtt, err = server.Ping()
  55. if err != nil {
  56. t.Fatalf("err: %v", err)
  57. }
  58. if rtt == 0 {
  59. t.Fatalf("bad: %v", rtt)
  60. }
  61. }
  62. func TestAccept(t *testing.T) {
  63. client, server := testClientServer()
  64. defer client.Close()
  65. defer server.Close()
  66. wg := &sync.WaitGroup{}
  67. wg.Add(4)
  68. go func() {
  69. defer wg.Done()
  70. stream, err := server.AcceptStream()
  71. if err != nil {
  72. t.Fatalf("err: %v", err)
  73. }
  74. if id := stream.StreamID(); id != 1 {
  75. t.Fatalf("bad: %v", id)
  76. }
  77. if err := stream.Close(); err != nil {
  78. t.Fatalf("err: %v", err)
  79. }
  80. }()
  81. go func() {
  82. defer wg.Done()
  83. stream, err := client.AcceptStream()
  84. if err != nil {
  85. t.Fatalf("err: %v", err)
  86. }
  87. if id := stream.StreamID(); id != 2 {
  88. t.Fatalf("bad: %v", id)
  89. }
  90. if err := stream.Close(); err != nil {
  91. t.Fatalf("err: %v", err)
  92. }
  93. }()
  94. go func() {
  95. defer wg.Done()
  96. stream, err := server.OpenStream()
  97. if err != nil {
  98. t.Fatalf("err: %v", err)
  99. }
  100. if id := stream.StreamID(); id != 2 {
  101. t.Fatalf("bad: %v", id)
  102. }
  103. if err := stream.Close(); err != nil {
  104. t.Fatalf("err: %v", err)
  105. }
  106. }()
  107. go func() {
  108. defer wg.Done()
  109. stream, err := client.OpenStream()
  110. if err != nil {
  111. t.Fatalf("err: %v", err)
  112. }
  113. if id := stream.StreamID(); id != 1 {
  114. t.Fatalf("bad: %v", id)
  115. }
  116. if err := stream.Close(); err != nil {
  117. t.Fatalf("err: %v", err)
  118. }
  119. }()
  120. doneCh := make(chan struct{})
  121. go func() {
  122. wg.Wait()
  123. close(doneCh)
  124. }()
  125. select {
  126. case <-doneCh:
  127. case <-time.After(time.Second):
  128. panic("timeout")
  129. }
  130. }
  131. func TestSendData_Small(t *testing.T) {
  132. client, server := testClientServer()
  133. defer client.Close()
  134. defer server.Close()
  135. wg := &sync.WaitGroup{}
  136. wg.Add(2)
  137. go func() {
  138. defer wg.Done()
  139. stream, err := server.AcceptStream()
  140. if err != nil {
  141. t.Fatalf("err: %v", err)
  142. }
  143. buf := make([]byte, 4)
  144. for i := 0; i < 1000; i++ {
  145. n, err := stream.Read(buf)
  146. if err != nil {
  147. t.Fatalf("err: %v", err)
  148. }
  149. if n != 4 {
  150. t.Fatalf("short read: %d", n)
  151. }
  152. if string(buf) != "test" {
  153. t.Fatalf("bad: %s", buf)
  154. }
  155. }
  156. if err := stream.Close(); err != nil {
  157. t.Fatalf("err: %v", err)
  158. }
  159. }()
  160. go func() {
  161. defer wg.Done()
  162. stream, err := client.Open()
  163. if err != nil {
  164. t.Fatalf("err: %v", err)
  165. }
  166. for i := 0; i < 1000; i++ {
  167. n, err := stream.Write([]byte("test"))
  168. if err != nil {
  169. t.Fatalf("err: %v", err)
  170. }
  171. if n != 4 {
  172. t.Fatalf("short write %d", n)
  173. }
  174. }
  175. if err := stream.Close(); err != nil {
  176. t.Fatalf("err: %v", err)
  177. }
  178. }()
  179. doneCh := make(chan struct{})
  180. go func() {
  181. wg.Wait()
  182. close(doneCh)
  183. }()
  184. select {
  185. case <-doneCh:
  186. case <-time.After(time.Second):
  187. panic("timeout")
  188. }
  189. }
  190. func TestSendData_Large(t *testing.T) {
  191. client, server := testClientServer()
  192. defer client.Close()
  193. defer server.Close()
  194. data := make([]byte, 512*1024)
  195. for idx := range data {
  196. data[idx] = byte(idx % 256)
  197. }
  198. wg := &sync.WaitGroup{}
  199. wg.Add(2)
  200. go func() {
  201. defer wg.Done()
  202. stream, err := server.AcceptStream()
  203. if err != nil {
  204. t.Fatalf("err: %v", err)
  205. }
  206. buf := make([]byte, 4*1024)
  207. for i := 0; i < 128; i++ {
  208. n, err := stream.Read(buf)
  209. if err != nil {
  210. t.Fatalf("err: %v", err)
  211. }
  212. if n != 4*1024 {
  213. t.Fatalf("short read: %d", n)
  214. }
  215. for idx := range buf {
  216. if buf[idx] != byte(idx%256) {
  217. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  218. }
  219. }
  220. }
  221. if err := stream.Close(); err != nil {
  222. t.Fatalf("err: %v", err)
  223. }
  224. }()
  225. go func() {
  226. defer wg.Done()
  227. stream, err := client.Open()
  228. if err != nil {
  229. t.Fatalf("err: %v", err)
  230. }
  231. n, err := stream.Write(data)
  232. if err != nil {
  233. t.Fatalf("err: %v", err)
  234. }
  235. if n != len(data) {
  236. t.Fatalf("short write %d", n)
  237. }
  238. if err := stream.Close(); err != nil {
  239. t.Fatalf("err: %v", err)
  240. }
  241. }()
  242. doneCh := make(chan struct{})
  243. go func() {
  244. wg.Wait()
  245. close(doneCh)
  246. }()
  247. select {
  248. case <-doneCh:
  249. case <-time.After(time.Second):
  250. panic("timeout")
  251. }
  252. }
  253. func TestGoAway(t *testing.T) {
  254. client, server := testClientServer()
  255. defer client.Close()
  256. defer server.Close()
  257. if err := server.GoAway(); err != nil {
  258. t.Fatalf("err: %v", err)
  259. }
  260. _, err := client.Open()
  261. if err != ErrRemoteGoAway {
  262. t.Fatalf("err: %v", err)
  263. }
  264. }
  265. func TestManyStreams(t *testing.T) {
  266. client, server := testClientServer()
  267. defer client.Close()
  268. defer server.Close()
  269. wg := &sync.WaitGroup{}
  270. acceptor := func(i int) {
  271. defer wg.Done()
  272. stream, err := server.AcceptStream()
  273. if err != nil {
  274. t.Fatalf("err: %v", err)
  275. }
  276. defer stream.Close()
  277. buf := make([]byte, 512)
  278. for {
  279. n, err := stream.Read(buf)
  280. if err == io.EOF {
  281. return
  282. }
  283. if err != nil {
  284. t.Fatalf("err: %v", err)
  285. }
  286. if n == 0 {
  287. t.Fatalf("err: %v", err)
  288. }
  289. }
  290. }
  291. sender := func(i int) {
  292. defer wg.Done()
  293. stream, err := client.Open()
  294. if err != nil {
  295. t.Fatalf("err: %v", err)
  296. }
  297. defer stream.Close()
  298. msg := fmt.Sprintf("%08d", i)
  299. for i := 0; i < 1000; i++ {
  300. n, err := stream.Write([]byte(msg))
  301. if err != nil {
  302. t.Fatalf("err: %v", err)
  303. }
  304. if n != len(msg) {
  305. t.Fatalf("short write %d", n)
  306. }
  307. }
  308. }
  309. for i := 0; i < 50; i++ {
  310. wg.Add(2)
  311. go acceptor(i)
  312. go sender(i)
  313. }
  314. wg.Wait()
  315. }
  316. func TestManyStreams_PingPong(t *testing.T) {
  317. client, server := testClientServer()
  318. defer client.Close()
  319. defer server.Close()
  320. wg := &sync.WaitGroup{}
  321. ping := []byte("ping")
  322. pong := []byte("pong")
  323. acceptor := func(i int) {
  324. defer wg.Done()
  325. stream, err := server.AcceptStream()
  326. if err != nil {
  327. t.Fatalf("err: %v", err)
  328. }
  329. defer stream.Close()
  330. buf := make([]byte, 4)
  331. for {
  332. n, err := stream.Read(buf)
  333. if err == io.EOF {
  334. return
  335. }
  336. if err != nil {
  337. t.Fatalf("err: %v", err)
  338. }
  339. if n != 4 {
  340. t.Fatalf("err: %v", err)
  341. }
  342. if !bytes.Equal(buf, ping) {
  343. t.Fatalf("bad: %s", buf)
  344. }
  345. n, err = stream.Write(pong)
  346. if err != nil {
  347. t.Fatalf("err: %v", err)
  348. }
  349. if n != 4 {
  350. t.Fatalf("err: %v", err)
  351. }
  352. }
  353. }
  354. sender := func(i int) {
  355. defer wg.Done()
  356. stream, err := client.Open()
  357. if err != nil {
  358. t.Fatalf("err: %v", err)
  359. }
  360. defer stream.Close()
  361. buf := make([]byte, 4)
  362. for i := 0; i < 1000; i++ {
  363. n, err := stream.Write(ping)
  364. if err != nil {
  365. t.Fatalf("err: %v", err)
  366. }
  367. if n != 4 {
  368. t.Fatalf("short write %d", n)
  369. }
  370. n, err = stream.Read(buf)
  371. if err != nil {
  372. t.Fatalf("err: %v", err)
  373. }
  374. if n != 4 {
  375. t.Fatalf("err: %v", err)
  376. }
  377. if !bytes.Equal(buf, pong) {
  378. t.Fatalf("bad: %s", buf)
  379. }
  380. }
  381. }
  382. for i := 0; i < 50; i++ {
  383. wg.Add(2)
  384. go acceptor(i)
  385. go sender(i)
  386. }
  387. wg.Wait()
  388. }
  389. func TestHalfClose(t *testing.T) {
  390. client, server := testClientServer()
  391. defer client.Close()
  392. defer server.Close()
  393. stream, err := client.Open()
  394. if err != nil {
  395. t.Fatalf("err: %v", err)
  396. }
  397. if _, err := stream.Write([]byte("a")); err != nil {
  398. t.Fatalf("err: %v", err)
  399. }
  400. stream2, err := server.Accept()
  401. if err != nil {
  402. t.Fatalf("err: %v", err)
  403. }
  404. stream2.Close() // Half close
  405. buf := make([]byte, 4)
  406. n, err := stream2.Read(buf)
  407. if err != nil {
  408. t.Fatalf("err: %v", err)
  409. }
  410. if n != 1 {
  411. t.Fatalf("bad: %v", n)
  412. }
  413. // Send more
  414. if _, err := stream.Write([]byte("bcd")); err != nil {
  415. t.Fatalf("err: %v", err)
  416. }
  417. stream.Close()
  418. // Read after close
  419. n, err = stream2.Read(buf)
  420. if err != nil {
  421. t.Fatalf("err: %v", err)
  422. }
  423. if n != 3 {
  424. t.Fatalf("bad: %v", n)
  425. }
  426. // EOF after close
  427. n, err = stream2.Read(buf)
  428. if err != io.EOF {
  429. t.Fatalf("err: %v", err)
  430. }
  431. if n != 0 {
  432. t.Fatalf("bad: %v", n)
  433. }
  434. }
  435. func TestReadDeadline(t *testing.T) {
  436. client, server := testClientServer()
  437. defer client.Close()
  438. defer server.Close()
  439. stream, err := client.Open()
  440. if err != nil {
  441. t.Fatalf("err: %v", err)
  442. }
  443. defer stream.Close()
  444. stream2, err := server.Accept()
  445. if err != nil {
  446. t.Fatalf("err: %v", err)
  447. }
  448. defer stream2.Close()
  449. if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
  450. t.Fatalf("err: %v", err)
  451. }
  452. buf := make([]byte, 4)
  453. if _, err := stream.Read(buf); err != ErrTimeout {
  454. t.Fatalf("err: %v", err)
  455. }
  456. }
  457. func TestWriteDeadline(t *testing.T) {
  458. client, server := testClientServer()
  459. defer client.Close()
  460. defer server.Close()
  461. stream, err := client.Open()
  462. if err != nil {
  463. t.Fatalf("err: %v", err)
  464. }
  465. defer stream.Close()
  466. stream2, err := server.Accept()
  467. if err != nil {
  468. t.Fatalf("err: %v", err)
  469. }
  470. defer stream2.Close()
  471. if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
  472. t.Fatalf("err: %v", err)
  473. }
  474. buf := make([]byte, 512)
  475. for i := 0; i < int(initialStreamWindow); i++ {
  476. _, err := stream.Write(buf)
  477. if err != nil && err == ErrTimeout {
  478. return
  479. } else if err != nil {
  480. t.Fatalf("err: %v", err)
  481. }
  482. }
  483. t.Fatalf("Expected timeout")
  484. }
  485. func TestBacklogExceeded(t *testing.T) {
  486. client, server := testClientServer()
  487. defer client.Close()
  488. defer server.Close()
  489. // Fill the backlog
  490. max := client.config.AcceptBacklog
  491. for i := 0; i < max; i++ {
  492. stream, err := client.Open()
  493. if err != nil {
  494. t.Fatalf("err: %v", err)
  495. }
  496. defer stream.Close()
  497. if _, err := stream.Write([]byte("foo")); err != nil {
  498. t.Fatalf("err: %v", err)
  499. }
  500. }
  501. // Exceed the backlog!
  502. stream, err := client.Open()
  503. if err != nil {
  504. t.Fatalf("err: %v", err)
  505. }
  506. defer stream.Close()
  507. if _, err := stream.Write([]byte("foo")); err != nil {
  508. t.Fatalf("err: %v", err)
  509. }
  510. buf := make([]byte, 4)
  511. stream.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
  512. if _, err := stream.Read(buf); err != ErrConnectionReset {
  513. t.Fatalf("err: %v", err)
  514. }
  515. }
  516. func TestKeepAlive(t *testing.T) {
  517. client, server := testClientServer()
  518. defer client.Close()
  519. defer server.Close()
  520. time.Sleep(200 * time.Millisecond)
  521. // Ping value should increase
  522. client.pingLock.Lock()
  523. defer client.pingLock.Unlock()
  524. if client.pingID == 0 {
  525. t.Fatalf("should ping")
  526. }
  527. server.pingLock.Lock()
  528. defer server.pingLock.Unlock()
  529. if server.pingID == 0 {
  530. t.Fatalf("should ping")
  531. }
  532. }
  533. func TestLargeWindow(t *testing.T) {
  534. conf := DefaultConfig()
  535. conf.MaxStreamWindowSize *= 2
  536. client, server := testClientServerConfig(conf)
  537. defer client.Close()
  538. defer server.Close()
  539. stream, err := client.Open()
  540. if err != nil {
  541. t.Fatalf("err: %v", err)
  542. }
  543. defer stream.Close()
  544. stream2, err := server.Accept()
  545. if err != nil {
  546. t.Fatalf("err: %v", err)
  547. }
  548. defer stream2.Close()
  549. stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
  550. buf := make([]byte, conf.MaxStreamWindowSize)
  551. n, err := stream.Write(buf)
  552. if err != nil {
  553. t.Fatalf("err: %v", err)
  554. }
  555. if n != len(buf) {
  556. t.Fatalf("short write: %d", n)
  557. }
  558. }
  // UnlimitedReader is a reader that never ends: every Read reports the whole
  // buffer as filled without writing anything into it.
  type UnlimitedReader struct{}

  // Read yields the processor once, then claims len(dst) bytes were read. It
  // never returns an error and leaves dst's contents untouched.
  func (*UnlimitedReader) Read(dst []byte) (int, error) {
  	runtime.Gosched()
  	filled := len(dst)
  	return filled, nil
  }
  564. func TestSendData_VeryLarge(t *testing.T) {
  565. client, server := testClientServer()
  566. defer client.Close()
  567. defer server.Close()
  568. var n int64 = 1 * 1024 * 1024 * 1024
  569. var workers int = 16
  570. wg := &sync.WaitGroup{}
  571. wg.Add(workers * 2)
  572. for i := 0; i < workers; i++ {
  573. go func() {
  574. defer wg.Done()
  575. stream, err := server.AcceptStream()
  576. if err != nil {
  577. t.Fatalf("err: %v", err)
  578. }
  579. defer stream.Close()
  580. buf := make([]byte, 4)
  581. _, err = stream.Read(buf)
  582. if err != nil {
  583. t.Fatalf("err: %v", err)
  584. }
  585. if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
  586. t.Fatalf("bad header")
  587. }
  588. recv, err := io.Copy(ioutil.Discard, stream)
  589. if err != nil {
  590. t.Fatalf("err: %v", err)
  591. }
  592. if recv != n {
  593. t.Fatalf("bad: %v", recv)
  594. }
  595. }()
  596. }
  597. for i := 0; i < workers; i++ {
  598. go func() {
  599. defer wg.Done()
  600. stream, err := client.Open()
  601. if err != nil {
  602. t.Fatalf("err: %v", err)
  603. }
  604. defer stream.Close()
  605. _, err = stream.Write([]byte{0, 1, 2, 3})
  606. if err != nil {
  607. t.Fatalf("err: %v", err)
  608. }
  609. unlimited := &UnlimitedReader{}
  610. sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
  611. if err != nil {
  612. t.Fatalf("err: %v", err)
  613. }
  614. if sent != n {
  615. t.Fatalf("bad: %v", sent)
  616. }
  617. }()
  618. }
  619. doneCh := make(chan struct{})
  620. go func() {
  621. wg.Wait()
  622. close(doneCh)
  623. }()
  624. select {
  625. case <-doneCh:
  626. case <-time.After(20 * time.Second):
  627. panic("timeout")
  628. }
  629. }