session_test.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. package yamux
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "sync"
  7. "testing"
  8. "time"
  9. )
  10. type pipeConn struct {
  11. reader *io.PipeReader
  12. writer *io.PipeWriter
  13. }
  14. func (p *pipeConn) Read(b []byte) (int, error) {
  15. return p.reader.Read(b)
  16. }
  17. func (p *pipeConn) Write(b []byte) (int, error) {
  18. return p.writer.Write(b)
  19. }
  20. func (p *pipeConn) Close() error {
  21. p.reader.Close()
  22. return p.writer.Close()
  23. }
  24. func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
  25. read1, write1 := io.Pipe()
  26. read2, write2 := io.Pipe()
  27. return &pipeConn{read1, write2}, &pipeConn{read2, write1}
  28. }
  29. func testClientServer() (*Session, *Session) {
  30. conn1, conn2 := testConn()
  31. client, _ := Client(conn1, nil)
  32. server, _ := Server(conn2, nil)
  33. return client, server
  34. }
  35. func TestPing(t *testing.T) {
  36. client, server := testClientServer()
  37. defer client.Close()
  38. defer server.Close()
  39. rtt, err := client.Ping()
  40. if err != nil {
  41. t.Fatalf("err: %v", err)
  42. }
  43. if rtt == 0 {
  44. t.Fatalf("bad: %v", rtt)
  45. }
  46. rtt, err = server.Ping()
  47. if err != nil {
  48. t.Fatalf("err: %v", err)
  49. }
  50. if rtt == 0 {
  51. t.Fatalf("bad: %v", rtt)
  52. }
  53. }
  54. func TestAccept(t *testing.T) {
  55. client, server := testClientServer()
  56. defer client.Close()
  57. defer server.Close()
  58. wg := &sync.WaitGroup{}
  59. wg.Add(4)
  60. go func() {
  61. defer wg.Done()
  62. stream, err := server.AcceptStream()
  63. if err != nil {
  64. t.Fatalf("err: %v", err)
  65. }
  66. if id := stream.StreamID(); id != 1 {
  67. t.Fatalf("bad: %v", id)
  68. }
  69. if err := stream.Close(); err != nil {
  70. t.Fatalf("err: %v", err)
  71. }
  72. }()
  73. go func() {
  74. defer wg.Done()
  75. stream, err := client.AcceptStream()
  76. if err != nil {
  77. t.Fatalf("err: %v", err)
  78. }
  79. if id := stream.StreamID(); id != 2 {
  80. t.Fatalf("bad: %v", id)
  81. }
  82. if err := stream.Close(); err != nil {
  83. t.Fatalf("err: %v", err)
  84. }
  85. }()
  86. go func() {
  87. defer wg.Done()
  88. stream, err := server.Open()
  89. if err != nil {
  90. t.Fatalf("err: %v", err)
  91. }
  92. if id := stream.StreamID(); id != 2 {
  93. t.Fatalf("bad: %v", id)
  94. }
  95. if err := stream.Close(); err != nil {
  96. t.Fatalf("err: %v", err)
  97. }
  98. }()
  99. go func() {
  100. defer wg.Done()
  101. stream, err := client.Open()
  102. if err != nil {
  103. t.Fatalf("err: %v", err)
  104. }
  105. if id := stream.StreamID(); id != 1 {
  106. t.Fatalf("bad: %v", id)
  107. }
  108. if err := stream.Close(); err != nil {
  109. t.Fatalf("err: %v", err)
  110. }
  111. }()
  112. doneCh := make(chan struct{})
  113. go func() {
  114. wg.Wait()
  115. close(doneCh)
  116. }()
  117. select {
  118. case <-doneCh:
  119. case <-time.After(time.Second):
  120. panic("timeout")
  121. }
  122. }
  123. func TestSendData_Small(t *testing.T) {
  124. client, server := testClientServer()
  125. defer client.Close()
  126. defer server.Close()
  127. wg := &sync.WaitGroup{}
  128. wg.Add(2)
  129. go func() {
  130. defer wg.Done()
  131. stream, err := server.AcceptStream()
  132. if err != nil {
  133. t.Fatalf("err: %v", err)
  134. }
  135. buf := make([]byte, 4)
  136. for i := 0; i < 1000; i++ {
  137. n, err := stream.Read(buf)
  138. if err != nil {
  139. t.Fatalf("err: %v", err)
  140. }
  141. if n != 4 {
  142. t.Fatalf("short read: %d", n)
  143. }
  144. if string(buf) != "test" {
  145. t.Fatalf("bad: %s", buf)
  146. }
  147. }
  148. if err := stream.Close(); err != nil {
  149. t.Fatalf("err: %v", err)
  150. }
  151. }()
  152. go func() {
  153. defer wg.Done()
  154. stream, err := client.Open()
  155. if err != nil {
  156. t.Fatalf("err: %v", err)
  157. }
  158. for i := 0; i < 1000; i++ {
  159. n, err := stream.Write([]byte("test"))
  160. if err != nil {
  161. t.Fatalf("err: %v", err)
  162. }
  163. if n != 4 {
  164. t.Fatalf("short write %d", n)
  165. }
  166. }
  167. if err := stream.Close(); err != nil {
  168. t.Fatalf("err: %v", err)
  169. }
  170. }()
  171. doneCh := make(chan struct{})
  172. go func() {
  173. wg.Wait()
  174. close(doneCh)
  175. }()
  176. select {
  177. case <-doneCh:
  178. case <-time.After(time.Second):
  179. panic("timeout")
  180. }
  181. }
  182. func TestSendData_Large(t *testing.T) {
  183. client, server := testClientServer()
  184. defer client.Close()
  185. defer server.Close()
  186. data := make([]byte, 512*1024)
  187. for idx := range data {
  188. data[idx] = byte(idx % 256)
  189. }
  190. wg := &sync.WaitGroup{}
  191. wg.Add(2)
  192. go func() {
  193. defer wg.Done()
  194. stream, err := server.AcceptStream()
  195. if err != nil {
  196. t.Fatalf("err: %v", err)
  197. }
  198. buf := make([]byte, 4*1024)
  199. for i := 0; i < 128; i++ {
  200. n, err := stream.Read(buf)
  201. if err != nil {
  202. t.Fatalf("err: %v", err)
  203. }
  204. if n != 4*1024 {
  205. t.Fatalf("short read: %d", n)
  206. }
  207. for idx := range buf {
  208. if buf[idx] != byte(idx%256) {
  209. t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
  210. }
  211. }
  212. }
  213. if err := stream.Close(); err != nil {
  214. t.Fatalf("err: %v", err)
  215. }
  216. }()
  217. go func() {
  218. defer wg.Done()
  219. stream, err := client.Open()
  220. if err != nil {
  221. t.Fatalf("err: %v", err)
  222. }
  223. n, err := stream.Write(data)
  224. if err != nil {
  225. t.Fatalf("err: %v", err)
  226. }
  227. if n != len(data) {
  228. t.Fatalf("short write %d", n)
  229. }
  230. if err := stream.Close(); err != nil {
  231. t.Fatalf("err: %v", err)
  232. }
  233. }()
  234. doneCh := make(chan struct{})
  235. go func() {
  236. wg.Wait()
  237. close(doneCh)
  238. }()
  239. select {
  240. case <-doneCh:
  241. case <-time.After(time.Second):
  242. panic("timeout")
  243. }
  244. }
  245. func TestGoAway(t *testing.T) {
  246. client, server := testClientServer()
  247. defer client.Close()
  248. defer server.Close()
  249. if err := server.GoAway(); err != nil {
  250. t.Fatalf("err: %v", err)
  251. }
  252. _, err := client.Open()
  253. if err != ErrRemoteGoAway {
  254. t.Fatalf("err: %v", err)
  255. }
  256. }
  257. func TestManyStreams(t *testing.T) {
  258. client, server := testClientServer()
  259. defer client.Close()
  260. defer server.Close()
  261. wg := &sync.WaitGroup{}
  262. acceptor := func(i int) {
  263. defer wg.Done()
  264. stream, err := server.AcceptStream()
  265. if err != nil {
  266. t.Fatalf("err: %v", err)
  267. }
  268. defer stream.Close()
  269. buf := make([]byte, 512)
  270. for {
  271. n, err := stream.Read(buf)
  272. if err == io.EOF {
  273. return
  274. }
  275. if err != nil {
  276. t.Fatalf("err: %v", err)
  277. }
  278. if n == 0 {
  279. t.Fatalf("err: %v", err)
  280. }
  281. }
  282. }
  283. sender := func(i int) {
  284. defer wg.Done()
  285. stream, err := client.Open()
  286. if err != nil {
  287. t.Fatalf("err: %v", err)
  288. }
  289. defer stream.Close()
  290. msg := fmt.Sprintf("%08d", i)
  291. for i := 0; i < 1000; i++ {
  292. n, err := stream.Write([]byte(msg))
  293. if err != nil {
  294. t.Fatalf("err: %v", err)
  295. }
  296. if n != len(msg) {
  297. t.Fatalf("short write %d", n)
  298. }
  299. }
  300. }
  301. for i := 0; i < 50; i++ {
  302. wg.Add(2)
  303. go acceptor(i)
  304. go sender(i)
  305. }
  306. wg.Wait()
  307. }
  308. func TestManyStreams_PingPong(t *testing.T) {
  309. client, server := testClientServer()
  310. defer client.Close()
  311. defer server.Close()
  312. wg := &sync.WaitGroup{}
  313. ping := []byte("ping")
  314. pong := []byte("pong")
  315. acceptor := func(i int) {
  316. defer wg.Done()
  317. stream, err := server.AcceptStream()
  318. if err != nil {
  319. t.Fatalf("err: %v", err)
  320. }
  321. defer stream.Close()
  322. buf := make([]byte, 4)
  323. for {
  324. n, err := stream.Read(buf)
  325. if err == io.EOF {
  326. return
  327. }
  328. if err != nil {
  329. t.Fatalf("err: %v", err)
  330. }
  331. if n != 4 {
  332. t.Fatalf("err: %v", err)
  333. }
  334. if !bytes.Equal(buf, ping) {
  335. t.Fatalf("bad: %s", buf)
  336. }
  337. n, err = stream.Write(pong)
  338. if err != nil {
  339. t.Fatalf("err: %v", err)
  340. }
  341. if n != 4 {
  342. t.Fatalf("err: %v", err)
  343. }
  344. }
  345. }
  346. sender := func(i int) {
  347. defer wg.Done()
  348. stream, err := client.Open()
  349. if err != nil {
  350. t.Fatalf("err: %v", err)
  351. }
  352. defer stream.Close()
  353. buf := make([]byte, 4)
  354. for i := 0; i < 1000; i++ {
  355. n, err := stream.Write(ping)
  356. if err != nil {
  357. t.Fatalf("err: %v", err)
  358. }
  359. if n != 4 {
  360. t.Fatalf("short write %d", n)
  361. }
  362. n, err = stream.Read(buf)
  363. if err != nil {
  364. t.Fatalf("err: %v", err)
  365. }
  366. if n != 4 {
  367. t.Fatalf("err: %v", err)
  368. }
  369. if !bytes.Equal(buf, pong) {
  370. t.Fatalf("bad: %s", buf)
  371. }
  372. }
  373. }
  374. for i := 0; i < 50; i++ {
  375. wg.Add(2)
  376. go acceptor(i)
  377. go sender(i)
  378. }
  379. wg.Wait()
  380. }