1
0

proxyprotocol_test.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package net
  2. import (
  3. "net"
  4. "testing"
  5. pp "github.com/pires/go-proxyproto"
  6. "github.com/stretchr/testify/require"
  7. )
  8. func TestBuildProxyProtocolHeader(t *testing.T) {
  9. require := require.New(t)
  10. tests := []struct {
  11. name string
  12. srcAddr net.Addr
  13. dstAddr net.Addr
  14. version string
  15. expectError bool
  16. }{
  17. {
  18. name: "UDP IPv4 v2",
  19. srcAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
  20. dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
  21. version: "v2",
  22. expectError: false,
  23. },
  24. {
  25. name: "TCP IPv4 v1",
  26. srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
  27. dstAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
  28. version: "v1",
  29. expectError: false,
  30. },
  31. {
  32. name: "UDP IPv6 v2",
  33. srcAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
  34. dstAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
  35. version: "v2",
  36. expectError: false,
  37. },
  38. {
  39. name: "TCP IPv6 v1",
  40. srcAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
  41. dstAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
  42. version: "v1",
  43. expectError: false,
  44. },
  45. {
  46. name: "nil source address",
  47. srcAddr: nil,
  48. dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
  49. version: "v2",
  50. expectError: false,
  51. },
  52. {
  53. name: "nil destination address",
  54. srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
  55. dstAddr: nil,
  56. version: "v2",
  57. expectError: false,
  58. },
  59. {
  60. name: "unsupported address type",
  61. srcAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
  62. dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
  63. version: "v2",
  64. expectError: false,
  65. },
  66. }
  67. for _, tt := range tests {
  68. header, err := BuildProxyProtocolHeader(tt.srcAddr, tt.dstAddr, tt.version)
  69. if tt.expectError {
  70. require.Error(err, "test case: %s", tt.name)
  71. continue
  72. }
  73. require.NoError(err, "test case: %s", tt.name)
  74. require.NotEmpty(header, "test case: %s", tt.name)
  75. }
  76. }
  77. func TestBuildProxyProtocolHeaderStruct(t *testing.T) {
  78. require := require.New(t)
  79. tests := []struct {
  80. name string
  81. srcAddr net.Addr
  82. dstAddr net.Addr
  83. version string
  84. expectedProtocol pp.AddressFamilyAndProtocol
  85. expectedVersion byte
  86. expectedCommand pp.ProtocolVersionAndCommand
  87. expectedSourceAddr net.Addr
  88. expectedDestAddr net.Addr
  89. }{
  90. {
  91. name: "TCP IPv4 v2",
  92. srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
  93. dstAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
  94. version: "v2",
  95. expectedProtocol: pp.TCPv4,
  96. expectedVersion: 2,
  97. expectedCommand: pp.PROXY,
  98. expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
  99. expectedDestAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
  100. },
  101. {
  102. name: "UDP IPv6 v1",
  103. srcAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
  104. dstAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
  105. version: "v1",
  106. expectedProtocol: pp.UDPv6,
  107. expectedVersion: 1,
  108. expectedCommand: pp.PROXY,
  109. expectedSourceAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
  110. expectedDestAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
  111. },
  112. {
  113. name: "TCP IPv6 default version",
  114. srcAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
  115. dstAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
  116. version: "",
  117. expectedProtocol: pp.TCPv6,
  118. expectedVersion: 2, // default to v2
  119. expectedCommand: pp.PROXY,
  120. expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
  121. expectedDestAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
  122. },
  123. {
  124. name: "nil source address",
  125. srcAddr: nil,
  126. dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
  127. version: "v2",
  128. expectedProtocol: pp.UNSPEC,
  129. expectedVersion: 2,
  130. expectedCommand: pp.LOCAL,
  131. expectedSourceAddr: nil, // go-proxyproto sets both to nil when srcAddr is nil
  132. expectedDestAddr: nil,
  133. },
  134. {
  135. name: "nil destination address",
  136. srcAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
  137. dstAddr: nil,
  138. version: "v2",
  139. expectedProtocol: pp.UNSPEC,
  140. expectedVersion: 2,
  141. expectedCommand: pp.LOCAL,
  142. expectedSourceAddr: nil, // go-proxyproto sets both to nil when dstAddr is nil
  143. expectedDestAddr: nil,
  144. },
  145. {
  146. name: "unsupported address type",
  147. srcAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
  148. dstAddr: &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
  149. version: "v2",
  150. expectedProtocol: pp.UNSPEC,
  151. expectedVersion: 2,
  152. expectedCommand: pp.LOCAL,
  153. expectedSourceAddr: nil, // go-proxyproto sets both to nil for unsupported types
  154. expectedDestAddr: nil,
  155. },
  156. }
  157. for _, tt := range tests {
  158. header := BuildProxyProtocolHeaderStruct(tt.srcAddr, tt.dstAddr, tt.version)
  159. require.NotNil(header, "test case: %s", tt.name)
  160. require.Equal(tt.expectedCommand, header.Command, "test case: %s", tt.name)
  161. require.Equal(tt.expectedSourceAddr, header.SourceAddr, "test case: %s", tt.name)
  162. require.Equal(tt.expectedDestAddr, header.DestinationAddr, "test case: %s", tt.name)
  163. require.Equal(tt.expectedProtocol, header.TransportProtocol, "test case: %s", tt.name)
  164. require.Equal(tt.expectedVersion, header.Version, "test case: %s", tt.name)
  165. }
  166. }