123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- package proxyproto
- import (
- "bufio"
- "bytes"
- "encoding/binary"
- "io"
- )
// Minimum payload lengths, in bytes, of the address section that follows the
// 16-bit length field of a version 2 header, together with their big-endian
// two-byte encodings as written on the wire by writeVersion2.
var (
	// lengthV4 covers two 4-byte addresses and two 2-byte ports (see _addr4).
	lengthV4 = uint16(12)
	// lengthV6 covers two 16-byte addresses and two 2-byte ports (see _addr6).
	lengthV6 = uint16(36)
	// lengthUnix covers two 108-byte socket paths and no ports (see
	// _addrUnix). It was previously 218, which did not match the 216-byte
	// layout and made validateLength reject payloads of exactly that size.
	lengthUnix = uint16(216)

	// lengthV4Bytes is lengthV4 encoded as a big-endian uint16.
	lengthV4Bytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthV4)
		return a
	}()
	// lengthV6Bytes is lengthV6 encoded as a big-endian uint16.
	lengthV6Bytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthV6)
		return a
	}()
	// lengthUnixBytes is lengthUnix encoded as a big-endian uint16.
	lengthUnixBytes = func() []byte {
		a := make([]byte, 2)
		binary.BigEndian.PutUint16(a, lengthUnix)
		return a
	}()
)
// Fixed-size layouts of the address section of a version 2 header. Field
// order is significant: the structs are filled positionally from the wire
// with encoding/binary.
type (
	// _ports is the source/destination port pair.
	_ports struct {
		SrcPort, DstPort uint16
	}

	// _addr4 is the 12-byte IPv4 layout: two 4-byte addresses, then ports.
	_addr4 struct {
		Src, Dst         [4]byte
		SrcPort, DstPort uint16
	}

	// _addr6 is the 36-byte IPv6 layout: two 16-byte addresses, then the
	// embedded port pair.
	_addr6 struct {
		Src, Dst [16]byte
		_ports
	}

	// _addrUnix is the 216-byte layout holding two 108-byte socket paths.
	_addrUnix struct {
		Src, Dst [108]byte
	}
)
- func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
-
- for i := 0; i < 12; i++ {
- if _, err = reader.ReadByte(); err != nil {
- return nil, ErrCantReadProtocolVersionAndCommand
- }
- }
- header = new(Header)
- header.Version = 2
-
- b13, err := reader.ReadByte()
- if err != nil {
- return nil, ErrCantReadProtocolVersionAndCommand
- }
- header.Command = ProtocolVersionAndCommand(b13)
- if _, ok := supportedCommand[header.Command]; !ok {
- return nil, ErrUnsupportedProtocolVersionAndCommand
- }
-
- if header.Command.IsLocal() {
- return header, nil
- }
-
- b14, err := reader.ReadByte()
- if err != nil {
- return nil, ErrCantReadAddressFamilyAndProtocol
- }
- header.TransportProtocol = AddressFamilyAndProtocol(b14)
- if _, ok := supportedTransportProtocol[header.TransportProtocol]; !ok {
- return nil, ErrUnsupportedAddressFamilyAndProtocol
- }
-
- var length uint16
- if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil {
- return nil, ErrCantReadLength
- }
- if !header.validateLength(length) {
- return nil, ErrInvalidLength
- }
- if _, err := reader.Peek(int(length)); err != nil {
- return nil, ErrInvalidLength
- }
-
- payloadReader := io.LimitReader(reader, int64(length))
-
- if header.TransportProtocol.IsIPv4() {
- var addr _addr4
- if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
- return nil, ErrInvalidAddress
- }
- header.SourceAddress = addr.Src[:]
- header.DestinationAddress = addr.Dst[:]
- header.SourcePort = addr.SrcPort
- header.DestinationPort = addr.DstPort
- } else if header.TransportProtocol.IsIPv6() {
- var addr _addr6
- if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil {
- return nil, ErrInvalidAddress
- }
- header.SourceAddress = addr.Src[:]
- header.DestinationAddress = addr.Dst[:]
- header.SourcePort = addr.SrcPort
- header.DestinationPort = addr.DstPort
- }
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- payloadReader.Read(make([]byte, length))
- return header, nil
- }
- func (header *Header) writeVersion2(w io.Writer) (int64, error) {
- var buf bytes.Buffer
- buf.Write(SIGV2)
- buf.WriteByte(header.Command.toByte())
- if !header.Command.IsLocal() {
- buf.WriteByte(header.TransportProtocol.toByte())
-
- var addrSrc, addrDst []byte
- if header.TransportProtocol.IsIPv4() {
- buf.Write(lengthV4Bytes)
- addrSrc = header.SourceAddress.To4()
- addrDst = header.DestinationAddress.To4()
- } else if header.TransportProtocol.IsIPv6() {
- buf.Write(lengthV6Bytes)
- addrSrc = header.SourceAddress.To16()
- addrDst = header.DestinationAddress.To16()
- } else if header.TransportProtocol.IsUnix() {
- buf.Write(lengthUnixBytes)
-
- addrSrc = []byte(header.SourceAddress.String())
- addrDst = []byte(header.DestinationAddress.String())
- }
- buf.Write(addrSrc)
- buf.Write(addrDst)
- portSrcBytes := func() []byte {
- a := make([]byte, 2)
- binary.BigEndian.PutUint16(a, header.SourcePort)
- return a
- }()
- buf.Write(portSrcBytes)
- portDstBytes := func() []byte {
- a := make([]byte, 2)
- binary.BigEndian.PutUint16(a, header.DestinationPort)
- return a
- }()
- buf.Write(portDstBytes)
- }
- return buf.WriteTo(w)
- }
- func (header *Header) validateLength(length uint16) bool {
- if header.TransportProtocol.IsIPv4() {
- return length >= lengthV4
- } else if header.TransportProtocol.IsIPv6() {
- return length >= lengthV6
- } else if header.TransportProtocol.IsUnix() {
- return length >= lengthUnix
- }
- return false
- }
|