dial_test.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package socks_test
  5. import (
  6. "context"
  7. "io"
  8. "math/rand"
  9. "net"
  10. "os"
  11. "testing"
  12. "time"
  13. "golang.org/x/net/internal/socks"
  14. "golang.org/x/net/internal/sockstest"
  15. )
  // Dial target used throughout the tests. targetHostIP lies in the
// 2001:db8::/32 IPv6 documentation prefix; targetHostname is presumably
// chosen so that it never resolves (NOTE(review): confirm intent).
const (
	targetNetwork  = "tcp6"
	targetHostname = "fqdn.doesnotexist"
	targetHostIP   = "2001:db8::1"
	targetPort     = "5963"
)
  22. func TestDial(t *testing.T) {
  23. t.Run("Connect", func(t *testing.T) {
  24. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
  25. if err != nil {
  26. t.Error(err)
  27. return
  28. }
  29. defer ss.Close()
  30. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  31. d.AuthMethods = []socks.AuthMethod{
  32. socks.AuthMethodNotRequired,
  33. socks.AuthMethodUsernamePassword,
  34. }
  35. d.Authenticate = (&socks.UsernamePassword{
  36. Username: "username",
  37. Password: "password",
  38. }).Authenticate
  39. c, err := d.Dial(targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
  40. if err == nil {
  41. c.(*socks.Conn).BoundAddr()
  42. c.Close()
  43. }
  44. if err != nil {
  45. t.Error(err)
  46. return
  47. }
  48. })
  49. t.Run("Cancel", func(t *testing.T) {
  50. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
  51. if err != nil {
  52. t.Error(err)
  53. return
  54. }
  55. defer ss.Close()
  56. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  57. ctx, cancel := context.WithCancel(context.Background())
  58. defer cancel()
  59. dialErr := make(chan error)
  60. go func() {
  61. c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
  62. if err == nil {
  63. c.Close()
  64. }
  65. dialErr <- err
  66. }()
  67. time.Sleep(100 * time.Millisecond)
  68. cancel()
  69. err = <-dialErr
  70. if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
  71. t.Errorf("got %v; want context.Canceled or equivalent", err)
  72. return
  73. }
  74. })
  75. t.Run("Deadline", func(t *testing.T) {
  76. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
  77. if err != nil {
  78. t.Error(err)
  79. return
  80. }
  81. defer ss.Close()
  82. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  83. ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
  84. defer cancel()
  85. c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
  86. if err == nil {
  87. c.Close()
  88. }
  89. if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
  90. t.Errorf("got %v; want context.DeadlineExceeded or equivalent", err)
  91. return
  92. }
  93. })
  94. t.Run("WithRogueServer", func(t *testing.T) {
  95. ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
  96. if err != nil {
  97. t.Error(err)
  98. return
  99. }
  100. defer ss.Close()
  101. d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
  102. for i := 0; i < 2*len(rogueCmdList); i++ {
  103. ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
  104. defer cancel()
  105. c, err := d.DialContext(ctx, targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
  106. if err == nil {
  107. t.Log(c.(*socks.Conn).BoundAddr())
  108. c.Close()
  109. t.Error("should fail")
  110. }
  111. }
  112. })
  113. }
  114. func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
  115. if _, err := sockstest.ParseCmdRequest(b); err != nil {
  116. return err
  117. }
  118. var bb [1]byte
  119. for {
  120. if _, err := rw.Read(bb[:]); err != nil {
  121. return err
  122. }
  123. }
  124. }
  125. func rogueCmdFunc(rw io.ReadWriter, b []byte) error {
  126. if _, err := sockstest.ParseCmdRequest(b); err != nil {
  127. return err
  128. }
  129. rw.Write(rogueCmdList[rand.Intn(len(rogueCmdList))])
  130. return nil
  131. }
  // rogueCmdList holds malformed SOCKS5 replies that rogueCmdFunc sends in
// answer to a command request; dialing through a server that sends any of
// them is expected to fail. Entries follow the RFC 1928 reply layout:
// VER, REP, RSV, ATYP, BND.ADDR, BND.PORT.
var rogueCmdList = [][]byte{
	// Truncated: version byte only.
	{0x05},
	// Unsupported version 6 (IPv4 192.0.2.1, port 0x174b = 5963).
	{0x06, 0x00, 0x00, 0x01, 0xc0, 0x00, 0x02, 0x01, 0x17, 0x4b},
	// Non-zero reserved byte.
	{0x05, 0x00, 0xff, 0x01, 0xc0, 0x00, 0x02, 0x02, 0x17, 0x4b},
	// IPv4 address type, but the port is missing.
	{0x05, 0x00, 0x00, 0x01, 0xc0, 0x00, 0x02, 0x03},
	// Domain-name address type (length 4, "FQDN"), but the port is missing.
	{0x05, 0x00, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N'},
}
  // parseDialError splits a dial error into its underlying cause (perr) and
// the *net.OpError that carried it (nerr), if any.
//
// If err is a *net.OpError, nerr is that error and unwrapping continues with
// its Err field; otherwise nerr is nil. If the error at that point is an
// *os.SyscallError, it is unwrapped one more level. perr is what remains.
// Only direct type assertions are used, so errors wrapped in other ways
// (e.g. via fmt.Errorf with %w) are not unwrapped.
func parseDialError(err error) (perr, nerr error) {
	cause := err
	var opErr error
	if oe, isOp := cause.(*net.OpError); isOp {
		opErr, cause = oe, oe.Err
	}
	if se, isSys := cause.(*os.SyscallError); isSys {
		cause = se.Err
	}
	return cause, opErr
}