v2.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package proxyproto
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "io"
  7. )
  8. var (
  9. lengthV4 = uint16(12)
  10. lengthV6 = uint16(36)
  11. lengthUnix = uint16(218)
  12. lengthV4Bytes = func() []byte {
  13. a := make([]byte, 2)
  14. binary.BigEndian.PutUint16(a, lengthV4)
  15. return a
  16. }()
  17. lengthV6Bytes = func() []byte {
  18. a := make([]byte, 2)
  19. binary.BigEndian.PutUint16(a, lengthV6)
  20. return a
  21. }()
  22. lengthUnixBytes = func() []byte {
  23. a := make([]byte, 2)
  24. binary.BigEndian.PutUint16(a, lengthUnix)
  25. return a
  26. }()
  27. )
  28. type _ports struct {
  29. SrcPort uint16
  30. DstPort uint16
  31. }
  32. type _addr4 struct {
  33. Src [4]byte
  34. Dst [4]byte
  35. SrcPort uint16
  36. DstPort uint16
  37. }
  38. type _addr6 struct {
  39. Src [16]byte
  40. Dst [16]byte
  41. _ports
  42. }
  43. type _addrUnix struct {
  44. Src [108]byte
  45. Dst [108]byte
  46. }
  47. func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
  48. // Skip first 12 bytes (signature)
  49. for i := 0; i < 12; i++ {
  50. if _, err = reader.ReadByte(); err != nil {
  51. return nil, ErrCantReadProtocolVersionAndCommand
  52. }
  53. }
  54. header = new(Header)
  55. header.Version = 2
  56. // Read the 13th byte, protocol version and command
  57. b13, err := reader.ReadByte()
  58. if err != nil {
  59. return nil, ErrCantReadProtocolVersionAndCommand
  60. }
  61. header.Command = ProtocolVersionAndCommand(b13)
  62. if _, ok := supportedCommand[header.Command]; !ok {
  63. return nil, ErrUnsupportedProtocolVersionAndCommand
  64. }
  65. // If command is LOCAL, header ends here
  66. if header.Command.IsLocal() {
  67. return header, nil
  68. }
  69. // Read the 14th byte, address family and protocol
  70. b14, err := reader.ReadByte()
  71. if err != nil {
  72. return nil, ErrCantReadAddressFamilyAndProtocol
  73. }
  74. header.TransportProtocol = AddressFamilyAndProtocol(b14)
  75. if _, ok := supportedTransportProtocol[header.TransportProtocol]; !ok {
  76. return nil, ErrUnsupportedAddressFamilyAndProtocol
  77. }
  78. // Make sure there are bytes available as specified in length
  79. var length uint16
  80. if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil {
  81. return nil, ErrCantReadLength
  82. }
  83. if !header.validateLength(length) {
  84. return nil, ErrInvalidLength
  85. }
  86. if _, err := reader.Peek(int(length)); err != nil {
  87. return nil, ErrInvalidLength
  88. }
  89. // Length-limited reader for payload section
  90. payloadReader := io.LimitReader(reader, int64(length))
  91. // Read addresses and ports
  92. if header.TransportProtocol.IsIPv4() {
  93. var addr _addr4
  94. if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
  95. return nil, ErrInvalidAddress
  96. }
  97. header.SourceAddress = addr.Src[:]
  98. header.DestinationAddress = addr.Dst[:]
  99. header.SourcePort = addr.SrcPort
  100. header.DestinationPort = addr.DstPort
  101. } else if header.TransportProtocol.IsIPv6() {
  102. var addr _addr6
  103. if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
  104. return nil, ErrInvalidAddress
  105. }
  106. header.SourceAddress = addr.Src[:]
  107. header.DestinationAddress = addr.Dst[:]
  108. header.SourcePort = addr.SrcPort
  109. header.DestinationPort = addr.DstPort
  110. }
  111. // TODO fully support Unix addresses
  112. // else if header.TransportProtocol.IsUnix() {
  113. // var addr _addrUnix
  114. // if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
  115. // return nil, ErrInvalidAddress
  116. // }
  117. //
  118. //if header.SourceAddress, err = net.ResolveUnixAddr("unix", string(addr.Src[:])); err != nil {
  119. // return nil, ErrCantResolveSourceUnixAddress
  120. //}
  121. //if header.DestinationAddress, err = net.ResolveUnixAddr("unix", string(addr.Dst[:])); err != nil {
  122. // return nil, ErrCantResolveDestinationUnixAddress
  123. //}
  124. //}
  125. // TODO add encapsulated TLV support
  126. // Drain the remaining padding
  127. payloadReader.Read(make([]byte, length))
  128. return header, nil
  129. }
  130. func (header *Header) writeVersion2(w io.Writer) (int64, error) {
  131. var buf bytes.Buffer
  132. buf.Write(SIGV2)
  133. buf.WriteByte(header.Command.toByte())
  134. if !header.Command.IsLocal() {
  135. buf.WriteByte(header.TransportProtocol.toByte())
  136. // TODO add encapsulated TLV length
  137. var addrSrc, addrDst []byte
  138. if header.TransportProtocol.IsIPv4() {
  139. buf.Write(lengthV4Bytes)
  140. addrSrc = header.SourceAddress.To4()
  141. addrDst = header.DestinationAddress.To4()
  142. } else if header.TransportProtocol.IsIPv6() {
  143. buf.Write(lengthV6Bytes)
  144. addrSrc = header.SourceAddress.To16()
  145. addrDst = header.DestinationAddress.To16()
  146. } else if header.TransportProtocol.IsUnix() {
  147. buf.Write(lengthUnixBytes)
  148. // TODO is below right?
  149. addrSrc = []byte(header.SourceAddress.String())
  150. addrDst = []byte(header.DestinationAddress.String())
  151. }
  152. buf.Write(addrSrc)
  153. buf.Write(addrDst)
  154. portSrcBytes := func() []byte {
  155. a := make([]byte, 2)
  156. binary.BigEndian.PutUint16(a, header.SourcePort)
  157. return a
  158. }()
  159. buf.Write(portSrcBytes)
  160. portDstBytes := func() []byte {
  161. a := make([]byte, 2)
  162. binary.BigEndian.PutUint16(a, header.DestinationPort)
  163. return a
  164. }()
  165. buf.Write(portDstBytes)
  166. }
  167. return buf.WriteTo(w)
  168. }
  169. func (header *Header) validateLength(length uint16) bool {
  170. if header.TransportProtocol.IsIPv4() {
  171. return length >= lengthV4
  172. } else if header.TransportProtocol.IsIPv6() {
  173. return length >= lengthV6
  174. } else if header.TransportProtocol.IsUnix() {
  175. return length >= lengthUnix
  176. }
  177. return false
  178. }