Browse Source

support udp

fatedier 8 years ago
parent
commit
54bbfe26b0
10 changed files with 378 additions and 155 deletions
  1. 2 0
      client/control.go
  2. 88 21
      client/proxy.go
  3. 12 1
      models/msg/msg.go
  4. 3 0
      models/msg/process_test.go
  5. 104 41
      models/proto/udp/udp.go
  6. 0 50
      models/proto/udp/udp_test.go
  7. 4 4
      server/control.go
  8. 148 35
      server/proxy.go
  9. 2 1
      utils/errors/errors.go
  10. 15 2
      utils/net/udp.go

+ 2 - 0
client/control.go

@@ -146,6 +146,8 @@ func (ctl *Control) NewWorkConn() {
 	if pxy, ok := ctl.proxies[startMsg.ProxyName]; ok {
 		go pxy.InWorkConn(workConn)
 		workConn.Info("start a new work connection")
+	} else {
+		workConn.Close()
 	}
 }
 

+ 88 - 21
client/proxy.go

@@ -17,10 +17,15 @@ package client
 import (
 	"fmt"
 	"io"
+	"net"
 
 	"github.com/fatedier/frp/models/config"
+	"github.com/fatedier/frp/models/msg"
 	"github.com/fatedier/frp/models/proto/tcp"
-	"github.com/fatedier/frp/utils/net"
+	"github.com/fatedier/frp/models/proto/udp"
+	"github.com/fatedier/frp/utils/errors"
+	"github.com/fatedier/frp/utils/log"
+	frpNet "github.com/fatedier/frp/utils/net"
 )
 
 // Proxy defines how to work for different proxy type.
@@ -28,40 +33,51 @@ type Proxy interface {
 	Run() error
 
 	// InWorkConn accept work connections registered to server.
-	InWorkConn(conn net.Conn)
+	InWorkConn(conn frpNet.Conn)
 	Close()
+	log.Logger
 }
 
 func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy) {
+	baseProxy := BaseProxy{
+		ctl:    ctl,
+		Logger: log.NewPrefixLogger(pxyConf.GetName()),
+	}
 	switch cfg := pxyConf.(type) {
 	case *config.TcpProxyConf:
 		pxy = &TcpProxy{
-			cfg: cfg,
-			ctl: ctl,
+			BaseProxy: baseProxy,
+			cfg:       cfg,
 		}
 	case *config.UdpProxyConf:
 		pxy = &UdpProxy{
-			cfg: cfg,
-			ctl: ctl,
+			BaseProxy: baseProxy,
+			cfg:       cfg,
 		}
 	case *config.HttpProxyConf:
 		pxy = &HttpProxy{
-			cfg: cfg,
-			ctl: ctl,
+			BaseProxy: baseProxy,
+			cfg:       cfg,
 		}
 	case *config.HttpsProxyConf:
 		pxy = &HttpsProxy{
-			cfg: cfg,
-			ctl: ctl,
+			BaseProxy: baseProxy,
+			cfg:       cfg,
 		}
 	}
 	return
 }
 
+type BaseProxy struct {
+	ctl *Control
+	log.Logger
+}
+
 // TCP
 type TcpProxy struct {
+	BaseProxy
+
 	cfg *config.TcpProxyConf
-	ctl *Control
 }
 
 func (pxy *TcpProxy) Run() (err error) {
@@ -71,15 +87,16 @@ func (pxy *TcpProxy) Run() (err error) {
 func (pxy *TcpProxy) Close() {
 }
 
-func (pxy *TcpProxy) InWorkConn(conn net.Conn) {
+func (pxy *TcpProxy) InWorkConn(conn frpNet.Conn) {
 	defer conn.Close()
 	HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, &pxy.cfg.BaseProxyConf, conn)
 }
 
 // HTTP
 type HttpProxy struct {
+	BaseProxy
+
 	cfg *config.HttpProxyConf
-	ctl *Control
 }
 
 func (pxy *HttpProxy) Run() (err error) {
@@ -89,15 +106,16 @@ func (pxy *HttpProxy) Run() (err error) {
 func (pxy *HttpProxy) Close() {
 }
 
-func (pxy *HttpProxy) InWorkConn(conn net.Conn) {
+func (pxy *HttpProxy) InWorkConn(conn frpNet.Conn) {
 	defer conn.Close()
 	HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, &pxy.cfg.BaseProxyConf, conn)
 }
 
 // HTTPS
 type HttpsProxy struct {
+	BaseProxy
+
 	cfg *config.HttpsProxyConf
-	ctl *Control
 }
 
 func (pxy *HttpsProxy) Run() (err error) {
@@ -107,31 +125,80 @@ func (pxy *HttpsProxy) Run() (err error) {
 func (pxy *HttpsProxy) Close() {
 }
 
-func (pxy *HttpsProxy) InWorkConn(conn net.Conn) {
+func (pxy *HttpsProxy) InWorkConn(conn frpNet.Conn) {
 	defer conn.Close()
 	HandleTcpWorkConnection(&pxy.cfg.LocalSvrConf, &pxy.cfg.BaseProxyConf, conn)
 }
 
 // UDP
 type UdpProxy struct {
+	BaseProxy
+
 	cfg *config.UdpProxyConf
-	ctl *Control
+
+	localAddr *net.UDPAddr
+	readCh    chan *msg.UdpPacket
+	sendCh    chan *msg.UdpPacket
+	workConn  frpNet.Conn
 }
 
 func (pxy *UdpProxy) Run() (err error) {
+	pxy.localAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pxy.cfg.LocalIp, pxy.cfg.LocalPort))
+	if err != nil {
+		return
+	}
 	return
 }
 
 func (pxy *UdpProxy) Close() {
+	pxy.workConn.Close()
+	close(pxy.readCh)
+	close(pxy.sendCh)
 }
 
-func (pxy *UdpProxy) InWorkConn(conn net.Conn) {
-	defer conn.Close()
+func (pxy *UdpProxy) InWorkConn(conn frpNet.Conn) {
+	if pxy.workConn != nil {
+		pxy.workConn.Close()
+		close(pxy.readCh)
+		close(pxy.sendCh)
+	}
+	pxy.workConn = conn
+	pxy.readCh = make(chan *msg.UdpPacket, 64)
+	pxy.sendCh = make(chan *msg.UdpPacket, 64)
+
+	workConnReaderFn := func(conn net.Conn) {
+		for {
+			var udpMsg msg.UdpPacket
+			if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil {
+				pxy.Warn("read from workConn for udp error: %v", errRet)
+				return
+			}
+			if errRet := errors.PanicToError(func() {
+				pxy.readCh <- &udpMsg
+			}); errRet != nil {
+				pxy.Info("reader goroutine for udp work connection closed")
+				return
+			}
+		}
+	}
+	workConnSenderFn := func(conn net.Conn) {
+		var errRet error
+		for udpMsg := range pxy.sendCh {
+			if errRet = msg.WriteMsg(conn, udpMsg); errRet != nil {
+				pxy.Info("sender goroutine for udp work connection closed")
+				return
+			}
+		}
+	}
+
+	go workConnSenderFn(pxy.workConn)
+	go workConnReaderFn(pxy.workConn)
+	udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh)
 }
 
 // Common handler for tcp work connections.
-func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, baseInfo *config.BaseProxyConf, workConn net.Conn) {
-	localConn, err := net.ConnectTcpServer(fmt.Sprintf("%s:%d", localInfo.LocalIp, localInfo.LocalPort))
+func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, baseInfo *config.BaseProxyConf, workConn frpNet.Conn) {
+	localConn, err := frpNet.ConnectTcpServer(fmt.Sprintf("%s:%d", localInfo.LocalIp, localInfo.LocalPort))
 	if err != nil {
 		workConn.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIp, localInfo.LocalPort, err)
 		return

+ 12 - 1
models/msg/msg.go

@@ -14,7 +14,10 @@
 
 package msg
 
-import "reflect"
+import (
+	"net"
+	"reflect"
+)
 
 const (
 	TypeLogin         = 'o'
@@ -26,6 +29,7 @@ const (
 	TypeStartWorkConn = 's'
 	TypePing          = 'h'
 	TypePong          = '4'
+	TypeUdpPacket     = 'u'
 )
 
 var (
@@ -46,6 +50,7 @@ func init() {
 	TypeMap[TypeStartWorkConn] = reflect.TypeOf(StartWorkConn{})
 	TypeMap[TypePing] = reflect.TypeOf(Ping{})
 	TypeMap[TypePong] = reflect.TypeOf(Pong{})
+	TypeMap[TypeUdpPacket] = reflect.TypeOf(UdpPacket{})
 
 	for k, v := range TypeMap {
 		TypeStringMap[v] = k
@@ -116,3 +121,9 @@ type Ping struct {
 
 type Pong struct {
 }
+
+type UdpPacket struct {
+	Content    string       `json:"c"`
+	LocalAddr  *net.UDPAddr `json:"l"`
+	RemoteAddr *net.UDPAddr `json:"r"`
+}

+ 3 - 0
models/msg/process_test.go

@@ -17,6 +17,9 @@ package msg
 import (
 	"bytes"
 	"encoding/binary"
+	"encoding/json"
+	"fmt"
+	"net"
 	"reflect"
 	"testing"
 

+ 104 - 41
models/proto/udp/udp.go

@@ -1,4 +1,4 @@
-// Copyright 2016 fatedier, fatedier@gmail.com
+// Copyright 2017 fatedier, fatedier@gmail.com
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -16,57 +16,120 @@ package udp
 
 import (
 	"encoding/base64"
-	"encoding/json"
 	"net"
-)
+	"sync"
+	"time"
 
-type UdpPacket struct {
-	Content []byte       `json:"-"`
-	Src     *net.UDPAddr `json:"-"`
-	Dst     *net.UDPAddr `json:"-"`
+	"github.com/fatedier/frp/models/msg"
+	"github.com/fatedier/frp/utils/errors"
+	"github.com/fatedier/frp/utils/pool"
+)
 
-	EncodeContent string `json:"content"`
-	SrcStr        string `json:"src"`
-	DstStr        string `json:"dst"`
+func NewUdpPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UdpPacket {
+	return &msg.UdpPacket{
+		Content:    base64.StdEncoding.EncodeToString(buf),
+		LocalAddr:  laddr,
+		RemoteAddr: raddr,
+	}
 }
 
-func NewUdpPacket(content []byte, src, dst *net.UDPAddr) *UdpPacket {
-	up := &UdpPacket{
-		Src:           src,
-		Dst:           dst,
-		EncodeContent: base64.StdEncoding.EncodeToString(content),
-		SrcStr:        src.String(),
-		DstStr:        dst.String(),
-	}
-	return up
+func GetContent(m *msg.UdpPacket) (buf []byte, err error) {
+	buf, err = base64.StdEncoding.DecodeString(m.Content)
+	return
 }
 
-// parse one udp packet struct to bytes
-func (up *UdpPacket) Pack() []byte {
-	b, _ := json.Marshal(up)
-	return b
+func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UdpPacket, sendCh chan<- *msg.UdpPacket) {
+	// read
+	go func() {
+		for udpMsg := range readCh {
+			buf, err := GetContent(udpMsg)
+			if err != nil {
+				continue
+			}
+			udpConn.WriteToUDP(buf, udpMsg.RemoteAddr)
+		}
+	}()
+
+	// write
+	go func() {
+		buf := pool.GetBuf(1500)
+		defer pool.PutBuf(buf)
+		for {
+			n, remoteAddr, err := udpConn.ReadFromUDP(buf)
+			if err != nil {
+				udpConn.Close()
+				return
+			}
+			udpMsg := NewUdpPacket(buf[:n], nil, remoteAddr)
+			select {
+			case sendCh <- udpMsg:
+			default:
+			}
+		}
+	}()
 }
 
-// parse from bytes to UdpPacket struct
-func (up *UdpPacket) UnPack(packet []byte) error {
-	err := json.Unmarshal(packet, &up)
-	if err != nil {
-		return err
-	}
+func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UdpPacket, sendCh chan<- *msg.UdpPacket) {
+	var (
+		mu sync.RWMutex
+	)
+	udpConnMap := make(map[string]*net.UDPConn)
 
-	up.Content, err = base64.StdEncoding.DecodeString(up.EncodeContent)
-	if err != nil {
-		return err
-	}
+	// read from dstAddr and write to sendCh
+	writerFn := func(raddr *net.UDPAddr, udpConn *net.UDPConn) {
+		addr := raddr.String()
+		defer func() {
+			mu.Lock()
+			delete(udpConnMap, addr)
+			mu.Unlock()
+		}()
 
-	up.Src, err = net.ResolveUDPAddr("udp", up.SrcStr)
-	if err != nil {
-		return err
-	}
+		buf := pool.GetBuf(1500)
+		for {
+			udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
+			n, _, err := udpConn.ReadFromUDP(buf)
+			if err != nil {
+				return
+			}
 
-	up.Dst, err = net.ResolveUDPAddr("udp", up.DstStr)
-	if err != nil {
-		return err
+			udpMsg := NewUdpPacket(buf[:n], nil, raddr)
+			if err = errors.PanicToError(func() {
+				select {
+				case sendCh <- udpMsg:
+				default:
+				}
+			}); err != nil {
+				return
+			}
+		}
 	}
-	return nil
+
+	// read from readCh
+	go func() {
+		for udpMsg := range readCh {
+			buf, err := GetContent(udpMsg)
+			if err != nil {
+				continue
+			}
+			mu.Lock()
+			udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()]
+			if !ok {
+				udpConn, err = net.DialUDP("udp", nil, dstAddr)
+				if err != nil {
+					continue
+				}
+				udpConnMap[udpMsg.RemoteAddr.String()] = udpConn
+			}
+			mu.Unlock()
+
+			_, err = udpConn.Write(buf)
+			if err != nil {
+				udpConn.Close()
+			}
+
+			if !ok {
+				go writerFn(udpMsg.RemoteAddr, udpConn)
+			}
+		}
+	}()
 }

+ 0 - 50
models/proto/udp/udp_test.go

@@ -1,50 +0,0 @@
-// Copyright 2016 fatedier, fatedier@gmail.com
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package udp
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-var (
-	content string = "udp packet test"
-	src     string = "1.1.1.1:1000"
-	dst     string = "2.2.2.2:2000"
-
-	udpMsg *UdpPacket
-)
-
-func init() {
-	srcAddr, _ := net.ResolveUDPAddr("udp", src)
-	dstAddr, _ := net.ResolveUDPAddr("udp", dst)
-	udpMsg = NewUdpPacket([]byte(content), srcAddr, dstAddr)
-}
-
-func TestPack(t *testing.T) {
-	assert := assert.New(t)
-	msg := udpMsg.Pack()
-	assert.Equal(string(msg), `{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`)
-}
-
-func TestUnpack(t *testing.T) {
-	assert := assert.New(t)
-	udpMsg.UnPack([]byte(`{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`))
-	assert.Equal(content, string(udpMsg.Content))
-	assert.Equal(src, udpMsg.Src.String())
-	assert.Equal(dst, udpMsg.Dst.String())
-}

+ 4 - 4
server/control.go

@@ -131,7 +131,7 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
 	select {
 	case workConn, ok = <-ctl.workConnCh:
 		if !ok {
-			err = fmt.Errorf("no work connections available, control is closing")
+			err = errors.ErrCtlClosed
 			return
 		}
 		ctl.conn.Debug("get work connection from pool")
@@ -148,8 +148,8 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
 		select {
 		case workConn, ok = <-ctl.workConnCh:
 			if !ok {
-				err = fmt.Errorf("no work connections available, control is closing")
-				ctl.conn.Warn("%v", err)
+				err = errors.ErrCtlClosed
+				ctl.conn.Warn("no work connections avaiable, %v", err)
 				return
 			}
 
@@ -251,8 +251,8 @@ func (ctl *Control) stoper() {
 	}
 
 	for _, pxy := range ctl.proxies {
-		ctl.svr.DelProxy(pxy.GetName())
 		pxy.Close()
+		ctl.svr.DelProxy(pxy.GetName())
 	}
 
 	ctl.allShutdown.Done()

+ 148 - 35
server/proxy.go

@@ -1,14 +1,19 @@
 package server
 
 import (
+	"context"
 	"fmt"
 	"io"
+	"net"
+	"time"
 
 	"github.com/fatedier/frp/models/config"
 	"github.com/fatedier/frp/models/msg"
 	"github.com/fatedier/frp/models/proto/tcp"
+	"github.com/fatedier/frp/models/proto/udp"
+	"github.com/fatedier/frp/utils/errors"
 	"github.com/fatedier/frp/utils/log"
-	"github.com/fatedier/frp/utils/net"
+	frpNet "github.com/fatedier/frp/utils/net"
 	"github.com/fatedier/frp/utils/vhost"
 )
 
@@ -17,6 +22,7 @@ type Proxy interface {
 	GetControl() *Control
 	GetName() string
 	GetConf() config.ProxyConf
+	GetWorkConnFromPool() (workConn frpNet.Conn, err error)
 	Close()
 	log.Logger
 }
@@ -24,7 +30,7 @@ type Proxy interface {
 type BaseProxy struct {
 	name      string
 	ctl       *Control
-	listeners []net.Listener
+	listeners []frpNet.Listener
 	log.Logger
 }
 
@@ -43,12 +49,41 @@ func (pxy *BaseProxy) Close() {
 	}
 }
 
+func (pxy *BaseProxy) GetWorkConnFromPool() (workConn frpNet.Conn, err error) {
+	ctl := pxy.GetControl()
+	// try all connections from the pool
+	for i := 0; i < ctl.poolCount+1; i++ {
+		if workConn, err = ctl.GetWorkConn(); err != nil {
+			pxy.Warn("failed to get work connection: %v", err)
+			return
+		}
+		pxy.Info("get a new work connection: [%s]", workConn.RemoteAddr().String())
+		workConn.AddLogPrefix(pxy.GetName())
+
+		err := msg.WriteMsg(workConn, &msg.StartWorkConn{
+			ProxyName: pxy.GetName(),
+		})
+		if err != nil {
+			workConn.Warn("failed to send message to work connection from pool: %v, times: %d", err, i)
+			workConn.Close()
+		} else {
+			break
+		}
+	}
+
+	if err != nil {
+		pxy.Error("try to get work connection failed in the end")
+		return
+	}
+	return
+}
+
 // startListenHandler start a goroutine handler for each listener.
-// p: p will just be passed to handler(Proxy, net.Conn).
+// p: p will just be passed to handler(Proxy, frpNet.Conn).
 // handler: each proxy type can set different handler function to deal with connections accepted from listeners.
-func (pxy *BaseProxy) startListenHandler(p Proxy, handler func(Proxy, net.Conn)) {
+func (pxy *BaseProxy) startListenHandler(p Proxy, handler func(Proxy, frpNet.Conn)) {
 	for _, listener := range pxy.listeners {
-		go func(l net.Listener) {
+		go func(l frpNet.Listener) {
 			for {
 				// block
 				// if listener is closed, err returned
@@ -68,7 +103,7 @@ func NewProxy(ctl *Control, pxyConf config.ProxyConf) (pxy Proxy, err error) {
 	basePxy := BaseProxy{
 		name:      pxyConf.GetName(),
 		ctl:       ctl,
-		listeners: make([]net.Listener, 0),
+		listeners: make([]frpNet.Listener, 0),
 		Logger:    log.NewPrefixLogger(ctl.runId),
 	}
 	switch cfg := pxyConf.(type) {
@@ -105,7 +140,7 @@ type TcpProxy struct {
 }
 
 func (pxy *TcpProxy) Run() error {
-	listener, err := net.ListenTcp(config.ServerCommonCfg.BindAddr, pxy.cfg.RemotePort)
+	listener, err := frpNet.ListenTcp(config.ServerCommonCfg.BindAddr, pxy.cfg.RemotePort)
 	if err != nil {
 		return err
 	}
@@ -226,10 +261,106 @@ func (pxy *HttpsProxy) Close() {
 type UdpProxy struct {
 	BaseProxy
 	cfg *config.UdpProxyConf
+
+	udpConn      *net.UDPConn
+	workConn     net.Conn
+	sendCh       chan *msg.UdpPacket
+	readCh       chan *msg.UdpPacket
+	checkCloseCh chan int
 }
 
 func (pxy *UdpProxy) Run() (err error) {
-	return
+	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", config.ServerCommonCfg.BindAddr, pxy.cfg.RemotePort))
+	if err != nil {
+		return err
+	}
+	udpConn, err := net.ListenUDP("udp", addr)
+	if err != nil {
+		pxy.Warn("listen udp port error: %v", err)
+		return err
+	}
+	pxy.Info("udp proxy listen port [%d]", pxy.cfg.RemotePort)
+
+	pxy.udpConn = udpConn
+	pxy.sendCh = make(chan *msg.UdpPacket, 64)
+	pxy.readCh = make(chan *msg.UdpPacket, 64)
+	pxy.checkCloseCh = make(chan int)
+
+	workConnReaderFn := func(conn net.Conn) {
+		for {
+			var udpMsg msg.UdpPacket
+			if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil {
+				pxy.Warn("read from workConn for udp error: %v", errRet)
+				conn.Close()
+				// notity proxy to start a new work connection
+				errors.PanicToError(func() {
+					pxy.checkCloseCh <- 1
+				})
+				return
+			}
+			if errRet := errors.PanicToError(func() {
+				pxy.readCh <- &udpMsg
+			}); errRet != nil {
+				pxy.Info("reader goroutine for udp work connection closed")
+				return
+			}
+		}
+	}
+	workConnSenderFn := func(conn net.Conn, ctx context.Context) {
+		var errRet error
+		for {
+			select {
+			case udpMsg, ok := <-pxy.sendCh:
+				if !ok {
+					return
+				}
+				if errRet = msg.WriteMsg(conn, udpMsg); errRet != nil {
+					pxy.Info("sender goroutine for udp work connection closed: %v", errRet)
+					return
+				} else {
+					continue
+				}
+			case <-ctx.Done():
+				pxy.Info("sender goroutine for udp work connection closed")
+				return
+			}
+		}
+	}
+
+	go func() {
+		for {
+			// Sleep a while for waiting control send the NewProxyResp to client.
+			time.Sleep(500 * time.Millisecond)
+			workConn, err := pxy.GetWorkConnFromPool()
+			if err != nil {
+				time.Sleep(5 * time.Second)
+				// check if proxy is closed
+				select {
+				case _, ok := <-pxy.checkCloseCh:
+					if !ok {
+						return
+					}
+				default:
+				}
+				continue
+			}
+			pxy.workConn = workConn
+			ctx, cancel := context.WithCancel(context.Background())
+			go workConnReaderFn(workConn)
+			go workConnSenderFn(workConn, ctx)
+			_, ok := <-pxy.checkCloseCh
+			cancel()
+			if !ok {
+				return
+			}
+		}
+	}()
+
+	// Read from user connections and send wrapped udp message to sendCh.
+	// Client will transfor udp message to local udp service and waiting for response for a while.
+	// Response will be wrapped to be transfored in work connection to server.
+	udp.ForwardUserConn(udpConn, pxy.readCh, pxy.sendCh)
+	return nil
 }
 
 func (pxy *UdpProxy) GetConf() config.ProxyConf {
@@ -238,42 +369,24 @@ func (pxy *UdpProxy) GetConf() config.ProxyConf {
 
 func (pxy *UdpProxy) Close() {
 	pxy.BaseProxy.Close()
+	pxy.workConn.Close()
+	pxy.udpConn.Close()
+	close(pxy.checkCloseCh)
+	close(pxy.readCh)
+	close(pxy.sendCh)
 }
 
 // HandleUserTcpConnection is used for incoming tcp user connections.
 // It can be used for tcp, http, https type.
-func HandleUserTcpConnection(pxy Proxy, userConn net.Conn) {
+func HandleUserTcpConnection(pxy Proxy, userConn frpNet.Conn) {
 	defer userConn.Close()
-	ctl := pxy.GetControl()
-	var (
-		workConn net.Conn
-		err      error
-	)
-	// try all connections from the pool
-	for i := 0; i < ctl.poolCount+1; i++ {
-		if workConn, err = ctl.GetWorkConn(); err != nil {
-			pxy.Warn("failed to get work connection: %v", err)
-			return
-		}
-		defer workConn.Close()
-		pxy.Info("get a new work connection: [%s]", workConn.RemoteAddr().String())
-		workConn.AddLogPrefix(pxy.GetName())
-
-		err := msg.WriteMsg(workConn, &msg.StartWorkConn{
-			ProxyName: pxy.GetName(),
-		})
-		if err != nil {
-			workConn.Warn("failed to send message to work connection from pool: %v, times: %d", err, i)
-			workConn.Close()
-		} else {
-			break
-		}
-	}
 
+	// try all connections from the pool
+	workConn, err := pxy.GetWorkConnFromPool()
 	if err != nil {
-		pxy.Error("try to get work connection failed in the end")
 		return
 	}
+	defer workConn.Close()
 
 	var local io.ReadWriteCloser = workConn
 	cfg := pxy.GetConf().GetBaseInfo()

+ 2 - 1
utils/errors/errors.go

@@ -20,7 +20,8 @@ import (
 )
 
 var (
-	ErrMsgType = errors.New("message type error")
+	ErrMsgType   = errors.New("message type error")
+	ErrCtlClosed = errors.New("control is closed")
 )
 
 func PanicToError(fn func()) (err error) {

+ 15 - 2
utils/net/udp.go

@@ -221,12 +221,25 @@ func ListenUDP(bindAddr string, bindPort int64) (l *UdpListener, err error) {
 	return
 }
 
-func (l *UdpListener) writeUdpPacket(packet *UdpPacket) {
+func (l *UdpListener) writeUdpPacket(packet *UdpPacket) (err error) {
 	defer func() {
-		if err := recover(); err != nil {
+		if errRet := recover(); errRet != nil {
+			err = fmt.Errorf("udp write closed listener")
+			l.Info("udp write closed listener")
 		}
 	}()
 	l.writeCh <- packet
+	return
+}
+
+func (l *UdpListener) WriteMsg(buf []byte, remoteAddr *net.UDPAddr) (err error) {
+	// only set remote addr here
+	packet := &UdpPacket{
+		Buf:        buf,
+		RemoteAddr: remoteAddr,
+	}
+	err = l.writeUdpPacket(packet)
+	return
 }
 
 func (l *UdpListener) Accept() (Conn, error) {