Bläddra i källkod

add proxy protocol support for UDP proxies (#4810)

fatedier 4 veckor sedan
förälder
incheckning
ce366ee17f

+ 1 - 1
README.md

@@ -1025,7 +1025,7 @@ You can get user's real IP from HTTP request headers `X-Forwarded-For`.
 
 #### Proxy Protocol
 
-frp supports Proxy Protocol to send user's real IP to local services. It support all types except UDP.
+frp supports Proxy Protocol to send user's real IP to local services.
 
 Here is an example for https service:
 

+ 2 - 1
Release.md

@@ -1,3 +1,4 @@
 ## Features
 
-* Support for YAML merge functionality (anchors and references with dot-prefixed fields) in strict configuration mode without requiring `--strict-config=false` parameter.
+* Support for YAML merge functionality (anchors and references with dot-prefixed fields) in strict configuration mode without requiring `--strict-config=false` parameter.
+* Support for proxy protocol in UDP proxies to preserve real client IP addresses.

+ 4 - 20
client/proxy/proxy.go

@@ -20,13 +20,11 @@ import (
 	"net"
 	"reflect"
 	"strconv"
-	"strings"
 	"sync"
 	"time"
 
 	libio "github.com/fatedier/golib/io"
 	libnet "github.com/fatedier/golib/net"
-	pp "github.com/pires/go-proxyproto"
 	"golang.org/x/time/rate"
 
 	"github.com/fatedier/frp/pkg/config/types"
@@ -35,6 +33,7 @@ import (
 	plugin "github.com/fatedier/frp/pkg/plugin/client"
 	"github.com/fatedier/frp/pkg/transport"
 	"github.com/fatedier/frp/pkg/util/limit"
+	netpkg "github.com/fatedier/frp/pkg/util/net"
 	"github.com/fatedier/frp/pkg/util/xlog"
 	"github.com/fatedier/frp/pkg/vnet"
 )
@@ -176,24 +175,9 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor
 	}
 
 	if baseCfg.Transport.ProxyProtocolVersion != "" && m.SrcAddr != "" && m.SrcPort != 0 {
-		h := &pp.Header{
-			Command:         pp.PROXY,
-			SourceAddr:      connInfo.SrcAddr,
-			DestinationAddr: connInfo.DstAddr,
-		}
-
-		if strings.Contains(m.SrcAddr, ".") {
-			h.TransportProtocol = pp.TCPv4
-		} else {
-			h.TransportProtocol = pp.TCPv6
-		}
-
-		if baseCfg.Transport.ProxyProtocolVersion == "v1" {
-			h.Version = 1
-		} else if baseCfg.Transport.ProxyProtocolVersion == "v2" {
-			h.Version = 2
-		}
-		connInfo.ProxyProtocolHeader = h
+		// Use the common proxy protocol builder function
+		header := netpkg.BuildProxyProtocolHeaderStruct(connInfo.SrcAddr, connInfo.DstAddr, baseCfg.Transport.ProxyProtocolVersion)
+		connInfo.ProxyProtocolHeader = header
 	}
 	connInfo.Conn = remote
 	connInfo.UnderlyingConn = workConn

+ 1 - 1
client/proxy/sudp.go

@@ -205,5 +205,5 @@ func (pxy *SUDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
 	go workConnReaderFn(workConn, readCh)
 	go heartbeatFn(sendCh)
 
-	udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize))
+	udp.Forwarder(pxy.localAddr, readCh, sendCh, int(pxy.clientCfg.UDPPacketSize), pxy.cfg.Transport.ProxyProtocolVersion)
 }

+ 3 - 1
client/proxy/udp.go

@@ -171,5 +171,7 @@ func (pxy *UDPProxy) InWorkConn(conn net.Conn, _ *msg.StartWorkConn) {
 	go workConnSenderFn(pxy.workConn, pxy.sendCh)
 	go workConnReaderFn(pxy.workConn, pxy.readCh)
 	go heartbeatFn(pxy.sendCh)
-	udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize))
+
+	// Call Forwarder with proxy protocol version (empty string means no proxy protocol)
+	udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh, int(pxy.clientCfg.UDPPacketSize), pxy.cfg.Transport.ProxyProtocolVersion)
 }

+ 15 - 1
pkg/proto/udp/udp.go

@@ -24,6 +24,7 @@ import (
 	"github.com/fatedier/golib/pool"
 
 	"github.com/fatedier/frp/pkg/msg"
+	netpkg "github.com/fatedier/frp/pkg/util/net"
 )
 
 func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
@@ -69,7 +70,7 @@ func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh
 	}
 }
 
-func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int) {
+func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int, proxyProtocolVersion string) {
 	var mu sync.RWMutex
 	udpConnMap := make(map[string]*net.UDPConn)
 
@@ -110,6 +111,7 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
 			if err != nil {
 				continue
 			}
+
 			mu.Lock()
 			udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()]
 			if !ok {
@@ -122,6 +124,18 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
 			}
 			mu.Unlock()
 
+			// Add proxy protocol header if configured
+			if proxyProtocolVersion != "" && udpMsg.RemoteAddr != nil {
+				ppBuf, err := netpkg.BuildProxyProtocolHeader(udpMsg.RemoteAddr, dstAddr, proxyProtocolVersion)
+				if err == nil {
+					// Prepend proxy protocol header to the UDP payload
+					finalBuf := make([]byte, len(ppBuf)+len(buf))
+					copy(finalBuf, ppBuf)
+					copy(finalBuf[len(ppBuf):], buf)
+					buf = finalBuf
+				}
+			}
+
 			_, err = udpConn.Write(buf)
 			if err != nil {
 				udpConn.Close()

+ 45 - 0
pkg/util/net/proxyprotocol.go

@@ -0,0 +1,45 @@
+// Copyright 2025 The frp Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package net
+
+import (
+	"bytes"
+	"fmt"
+	"net"
+
+	pp "github.com/pires/go-proxyproto"
+)
+
+func BuildProxyProtocolHeaderStruct(srcAddr, dstAddr net.Addr, version string) *pp.Header {
+	var versionByte byte
+	if version == "v1" {
+		versionByte = 1
+	} else {
+		versionByte = 2 // default to v2
+	}
+	return pp.HeaderProxyFromAddrs(versionByte, srcAddr, dstAddr)
+}
+
+func BuildProxyProtocolHeader(srcAddr, dstAddr net.Addr, version string) ([]byte, error) {
+	h := BuildProxyProtocolHeaderStruct(srcAddr, dstAddr, version)
+
+	// Convert header to bytes using a buffer
+	var buf bytes.Buffer
+	_, err := h.WriteTo(&buf)
+	if err != nil {
+		return nil, fmt.Errorf("failed to write proxy protocol header: %v", err)
+	}
+	return buf.Bytes(), nil
+}

+ 178 - 0
pkg/util/net/proxyprotocol_test.go

@@ -0,0 +1,178 @@
+package net
+
+import (
+	"net"
+	"testing"
+
+	pp "github.com/pires/go-proxyproto"
+	"github.com/stretchr/testify/require"
+)
+
+func TestBuildProxyProtocolHeader(t *testing.T) {
+	require := require.New(t)
+
+	tests := []struct {
+		name        string
+		srcAddr     net.Addr
+		dstAddr     net.Addr
+		version     string
+		expectError bool
+	}{
+		{
+			name:        "UDP IPv4 v2",
+			srcAddr:     &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
+			dstAddr:     &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
+			version:     "v2",
+			expectError: false,
+		},
+		{
+			name:        "TCP IPv4 v1",
+			srcAddr:     &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
+			dstAddr:     &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
+			version:     "v1",
+			expectError: false,
+		},
+		{
+			name:        "UDP IPv6 v2",
+			srcAddr:     &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
+			dstAddr:     &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
+			version:     "v2",
+			expectError: false,
+		},
+		{
+			name:        "TCP IPv6 v1",
+			srcAddr:     &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
+			dstAddr:     &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
+			version:     "v1",
+			expectError: false,
+		},
+		{
+			name:        "nil source address",
+			srcAddr:     nil,
+			dstAddr:     &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
+			version:     "v2",
+			expectError: false,
+		},
+		{
+			name:        "nil destination address",
+			srcAddr:     &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
+			dstAddr:     nil,
+			version:     "v2",
+			expectError: false,
+		},
+		{
+			name:        "unsupported address type",
+			srcAddr:     &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
+			dstAddr:     &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
+			version:     "v2",
+			expectError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		header, err := BuildProxyProtocolHeader(tt.srcAddr, tt.dstAddr, tt.version)
+
+		if tt.expectError {
+			require.Error(err, "test case: %s", tt.name)
+			continue
+		}
+
+		require.NoError(err, "test case: %s", tt.name)
+		require.NotEmpty(header, "test case: %s", tt.name)
+	}
+}
+
+func TestBuildProxyProtocolHeaderStruct(t *testing.T) {
+	require := require.New(t)
+
+	tests := []struct {
+		name               string
+		srcAddr            net.Addr
+		dstAddr            net.Addr
+		version            string
+		expectedProtocol   pp.AddressFamilyAndProtocol
+		expectedVersion    byte
+		expectedCommand    pp.ProtocolVersionAndCommand
+		expectedSourceAddr net.Addr
+		expectedDestAddr   net.Addr
+	}{
+		{
+			name:               "TCP IPv4 v2",
+			srcAddr:            &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
+			dstAddr:            &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
+			version:            "v2",
+			expectedProtocol:   pp.TCPv4,
+			expectedVersion:    2,
+			expectedCommand:    pp.PROXY,
+			expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
+			expectedDestAddr:   &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 80},
+		},
+		{
+			name:               "UDP IPv6 v1",
+			srcAddr:            &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
+			dstAddr:            &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
+			version:            "v1",
+			expectedProtocol:   pp.UDPv6,
+			expectedVersion:    1,
+			expectedCommand:    pp.PROXY,
+			expectedSourceAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 12345},
+			expectedDestAddr:   &net.UDPAddr{IP: net.ParseIP("::1"), Port: 3306},
+		},
+		{
+			name:               "TCP IPv6 default version",
+			srcAddr:            &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
+			dstAddr:            &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
+			version:            "",
+			expectedProtocol:   pp.TCPv6,
+			expectedVersion:    2, // default to v2
+			expectedCommand:    pp.PROXY,
+			expectedSourceAddr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 12345},
+			expectedDestAddr:   &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 80},
+		},
+		{
+			name:               "nil source address",
+			srcAddr:            nil,
+			dstAddr:            &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
+			version:            "v2",
+			expectedProtocol:   pp.UNSPEC,
+			expectedVersion:    2,
+			expectedCommand:    pp.LOCAL,
+			expectedSourceAddr: nil, // go-proxyproto sets both to nil when srcAddr is nil
+			expectedDestAddr:   nil,
+		},
+		{
+			name:               "nil destination address",
+			srcAddr:            &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345},
+			dstAddr:            nil,
+			version:            "v2",
+			expectedProtocol:   pp.UNSPEC,
+			expectedVersion:    2,
+			expectedCommand:    pp.LOCAL,
+			expectedSourceAddr: nil, // go-proxyproto sets both to nil when dstAddr is nil
+			expectedDestAddr:   nil,
+		},
+		{
+			name:               "unsupported address type",
+			srcAddr:            &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
+			dstAddr:            &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 3306},
+			version:            "v2",
+			expectedProtocol:   pp.UNSPEC,
+			expectedVersion:    2,
+			expectedCommand:    pp.LOCAL,
+			expectedSourceAddr: nil, // go-proxyproto sets both to nil for unsupported types
+			expectedDestAddr:   nil,
+		},
+	}
+
+	for _, tt := range tests {
+		header := BuildProxyProtocolHeaderStruct(tt.srcAddr, tt.dstAddr, tt.version)
+
+		require.NotNil(header, "test case: %s", tt.name)
+
+		require.Equal(tt.expectedCommand, header.Command, "test case: %s", tt.name)
+		require.Equal(tt.expectedSourceAddr, header.SourceAddr, "test case: %s", tt.name)
+		require.Equal(tt.expectedDestAddr, header.DestinationAddr, "test case: %s", tt.name)
+		require.Equal(tt.expectedProtocol, header.TransportProtocol, "test case: %s", tt.name)
+		require.Equal(tt.expectedVersion, header.Version, "test case: %s", tt.name)
+	}
+}

+ 50 - 0
test/e2e/v1/features/real_ip.go

@@ -227,6 +227,56 @@ var _ = ginkgo.Describe("[Feature: Real IP]", func() {
 			})
 		})
 
+		ginkgo.It("UDP", func() {
+			serverConf := consts.DefaultServerConfig
+			clientConf := consts.DefaultClientConfig
+
+			localPort := f.AllocPort()
+			localServer := streamserver.New(streamserver.UDP, streamserver.WithBindPort(localPort),
+				streamserver.WithCustomHandler(func(c net.Conn) {
+					defer c.Close()
+					rd := bufio.NewReader(c)
+					ppHeader, err := pp.Read(rd)
+					if err != nil {
+						log.Errorf("read proxy protocol error: %v", err)
+						return
+					}
+
+					// Read the actual UDP content after proxy protocol header
+					if _, err := rpc.ReadBytes(rd); err != nil {
+						return
+					}
+
+					buf := []byte(ppHeader.SourceAddr.String())
+					_, _ = rpc.WriteBytes(c, buf)
+				}))
+			f.RunServer("", localServer)
+
+			remotePort := f.AllocPort()
+			clientConf += fmt.Sprintf(`
+			[[proxies]]
+			name = "udp"
+			type = "udp"
+			localPort = %d
+			remotePort = %d
+			transport.proxyProtocolVersion = "v2"
+			`, localPort, remotePort)
+
+			f.RunProcesses([]string{serverConf}, []string{clientConf})
+
+			framework.NewRequestExpect(f).Protocol("udp").Port(remotePort).Ensure(func(resp *request.Response) bool {
+				log.Tracef("udp proxy protocol get SourceAddr: %s", string(resp.Content))
+				addr, err := net.ResolveUDPAddr("udp", string(resp.Content))
+				if err != nil {
+					return false
+				}
+				if addr.IP.String() != "127.0.0.1" {
+					return false
+				}
+				return true
+			})
+		})
+
 		ginkgo.It("HTTP", func() {
 			vhostHTTPPort := f.AllocPort()
 			serverConf := consts.DefaultServerConfig + fmt.Sprintf(`