Browse Source

support udp type

fatedier 8 years ago
parent
commit
f2999e3317

+ 2 - 2
.travis.yml

@@ -3,8 +3,8 @@ language: go
 
 go:
     - 1.5.4
-    - 1.6.3
-    - 1.7
+    - 1.6.4
+    - 1.7.4
 
 install:
     - make

+ 7 - 0
conf/frpc.ini

@@ -29,6 +29,13 @@ use_gzip = false
 # connections will be established in advance, default value is zero
 pool_count = 10
 
+[dns]
+type = udp
+local_ip = 127.0.0.1
+local_port = 53
+use_encryption = true
+use_gzip = true
+
 # Resolve your domain names to [server_addr] so you can use http://web01.yourdomain.com to browse web01 and http://web02.yourdomain.com to browse web02, the domains are set in frps.ini
 [web01]
 type = http

+ 6 - 0
conf/frps.ini

@@ -34,6 +34,12 @@ auth_token = 123
 bind_addr = 0.0.0.0
 listen_port = 6000
 
+[dns]
+type = udp
+auth_token = 123
+bind_addr = 0.0.0.0
+listen_port = 53
+
 [web01]
 # if type equals http, vhost_http_port must be set
 type = http

+ 8 - 2
src/cmd/frpc/control.go

@@ -120,7 +120,7 @@ func msgSender(cli *client.ProxyClient, c *conn.Conn, msgSendChan chan interface
 		}
 
 		buf, _ := json.Marshal(msg)
-		err := c.Write(string(buf) + "\n")
+		err := c.WriteString(string(buf) + "\n")
 		if err != nil {
 			log.Warn("ProxyName [%s], write to server error, proxy exit", cli.Name)
 			c.Close()
@@ -165,7 +165,7 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
 	}
 
 	buf, _ := json.Marshal(req)
-	err = c.Write(string(buf) + "\n")
+	err = c.WriteString(string(buf) + "\n")
 	if err != nil {
 		log.Error("ProxyName [%s], write to server error, %v", cli.Name, err)
 		return
@@ -190,6 +190,12 @@ func loginToServer(cli *client.ProxyClient) (c *conn.Conn, err error) {
 	}
 
 	log.Info("ProxyName [%s], connect to server [%s:%d] success!", cli.Name, client.ServerAddr, client.ServerPort)
+
+	if cli.Type == "udp" {
+		// we only need one udp work connection
+		// all udp messages will be forwarded throngh this connection
+		go cli.StartUdpTunnelOnce(client.ServerAddr, client.ServerPort)
+	}
 	return
 }
 

+ 15 - 3
src/cmd/frps/control.go

@@ -71,14 +71,14 @@ func controlWorker(c *conn.Conn) {
 	// login when type is NewCtlConn or NewWorkConn
 	ret, info := doLogin(cliReq, c)
 	// if login type is NewWorkConn, nothing will be send to frpc
-	if cliReq.Type != consts.NewWorkConn {
+	if cliReq.Type == consts.NewCtlConn {
 		cliRes := &msg.ControlRes{
 			Type: consts.NewCtlConnRes,
 			Code: ret,
 			Msg:  info,
 		}
 		byteBuf, _ := json.Marshal(cliRes)
-		err = c.Write(string(byteBuf) + "\n")
+		err = c.WriteString(string(byteBuf) + "\n")
 		if err != nil {
 			log.Warn("ProxyName [%s], write to client error, proxy exit", cliReq.ProxyName)
 			return
@@ -144,9 +144,11 @@ func msgReader(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}
 		if err != nil {
 			if err == io.EOF {
 				log.Warn("ProxyName [%s], client is dead!", s.Name)
+				s.Close()
 				return err
 			} else if c == nil || c.IsClosed() {
 				log.Warn("ProxyName [%s], client connection is closed", s.Name)
+				s.Close()
 				return err
 			}
 			log.Warn("ProxyName [%s], read error: %v", s.Name, err)
@@ -183,7 +185,7 @@ func msgSender(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}
 		}
 
 		buf, _ := json.Marshal(msg)
-		err := c.Write(string(buf) + "\n")
+		err := c.WriteString(string(buf) + "\n")
 		if err != nil {
 			log.Warn("ProxyName [%s], write to client error, proxy exit", s.Name)
 			s.Close()
@@ -193,6 +195,9 @@ func msgSender(s *server.ProxyServer, c *conn.Conn, msgSendChan chan interface{}
 }
 
 // if success, ret equals 0, otherwise greater than 0
+// NewCtlConn
+// NewWorkConn
+// NewWorkConnUdp
 func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
 	ret = 1
 	// check if PrivilegeMode is enabled
@@ -325,6 +330,13 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
 		}
 		// the connection will close after join over
 		s.RegisterNewWorkConn(c)
+	} else if req.Type == consts.NewWorkConnUdp {
+		// work conn for udp
+		if s.Status != consts.Working {
+			log.Warn("ProxyName [%s], is not working when it gets one new work connnection for udp", req.ProxyName)
+			return
+		}
+		s.RegisterNewWorkConnUdp(c)
 	} else {
 		info = fmt.Sprintf("Unsupport login message type [%d]", req.Type)
 		log.Warn("Unsupport login message type [%d]", req.Type)

+ 70 - 19
src/models/client/client.go

@@ -17,6 +17,7 @@ package client
 import (
 	"encoding/json"
 	"fmt"
+	"sync"
 	"time"
 
 	"github.com/fatedier/frp/src/models/config"
@@ -34,19 +35,71 @@ type ProxyClient struct {
 
 	RemotePort    int64
 	CustomDomains []string
+
+	udpTunnel *conn.Conn
+	once      sync.Once
+}
+
+// if proxy type is udp, keep a tcp connection for transferring udp packages
+func (pc *ProxyClient) StartUdpTunnelOnce(addr string, port int64) {
+	pc.once.Do(func() {
+		var err error
+		var c *conn.Conn
+		udpProcessor := NewUdpProcesser(nil, pc.LocalIp, pc.LocalPort)
+		for {
+			if pc.udpTunnel == nil || pc.udpTunnel.IsClosed() {
+				if HttpProxy == "" {
+					c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", addr, port))
+				} else {
+					c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port))
+				}
+				if err != nil {
+					log.Error("ProxyName [%s], udp tunnel connect to server [%s:%d] error, %v", pc.Name, addr, port, err)
+					time.Sleep(5 * time.Second)
+					continue
+				}
+
+				nowTime := time.Now().Unix()
+				req := &msg.ControlReq{
+					Type:          consts.NewWorkConnUdp,
+					ProxyName:     pc.Name,
+					PrivilegeMode: pc.PrivilegeMode,
+					Timestamp:     nowTime,
+				}
+				if pc.PrivilegeMode == true {
+					req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime))
+				} else {
+					req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime))
+				}
+
+				buf, _ := json.Marshal(req)
+				err = c.WriteString(string(buf) + "\n")
+				if err != nil {
+					log.Error("ProxyName [%s], udp tunnel write to server error, %v", pc.Name, err)
+					c.Close()
+					time.Sleep(1 * time.Second)
+					continue
+				}
+				pc.udpTunnel = c
+				udpProcessor.UpdateTcpConn(pc.udpTunnel)
+				udpProcessor.Run()
+			}
+			time.Sleep(1 * time.Second)
+		}
+	})
 }
 
-func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) {
-	c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", p.LocalIp, p.LocalPort))
+func (pc *ProxyClient) GetLocalConn() (c *conn.Conn, err error) {
+	c, err = conn.ConnectServer(fmt.Sprintf("%s:%d", pc.LocalIp, pc.LocalPort))
 	if err != nil {
-		log.Error("ProxyName [%s], connect to local port error, %v", p.Name, err)
+		log.Error("ProxyName [%s], connect to local port error, %v", pc.Name, err)
 	}
 	return
 }
 
-func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) {
+func (pc *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err error) {
 	defer func() {
-		if err != nil {
+		if err != nil && c != nil {
 			c.Close()
 		}
 	}()
@@ -57,29 +110,27 @@ func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err
 		c, err = conn.ConnectServerByHttpProxy(HttpProxy, fmt.Sprintf("%s:%d", addr, port))
 	}
 	if err != nil {
-		log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", p.Name, addr, port, err)
+		log.Error("ProxyName [%s], connect to server [%s:%d] error, %v", pc.Name, addr, port, err)
 		return
 	}
 
 	nowTime := time.Now().Unix()
 	req := &msg.ControlReq{
 		Type:          consts.NewWorkConn,
-		ProxyName:     p.Name,
-		PrivilegeMode: p.PrivilegeMode,
+		ProxyName:     pc.Name,
+		PrivilegeMode: pc.PrivilegeMode,
 		Timestamp:     nowTime,
 	}
-	if p.PrivilegeMode == true {
-		privilegeKey := pcrypto.GetAuthKey(p.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime))
-		req.PrivilegeKey = privilegeKey
+	if pc.PrivilegeMode == true {
+		req.PrivilegeKey = pcrypto.GetAuthKey(pc.Name + PrivilegeToken + fmt.Sprintf("%d", nowTime))
 	} else {
-		authKey := pcrypto.GetAuthKey(p.Name + p.AuthToken + fmt.Sprintf("%d", nowTime))
-		req.AuthKey = authKey
+		req.AuthKey = pcrypto.GetAuthKey(pc.Name + pc.AuthToken + fmt.Sprintf("%d", nowTime))
 	}
 
 	buf, _ := json.Marshal(req)
-	err = c.Write(string(buf) + "\n")
+	err = c.WriteString(string(buf) + "\n")
 	if err != nil {
-		log.Error("ProxyName [%s], write to server error, %v", p.Name, err)
+		log.Error("ProxyName [%s], write to server error, %v", pc.Name, err)
 		return
 	}
 
@@ -87,12 +138,12 @@ func (p *ProxyClient) GetRemoteConn(addr string, port int64) (c *conn.Conn, err
 	return
 }
 
-func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err error) {
-	localConn, err := p.GetLocalConn()
+func (pc *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err error) {
+	localConn, err := pc.GetLocalConn()
 	if err != nil {
 		return
 	}
-	remoteConn, err := p.GetRemoteConn(serverAddr, serverPort)
+	remoteConn, err := pc.GetRemoteConn(serverAddr, serverPort)
 	if err != nil {
 		return
 	}
@@ -101,7 +152,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro
 	log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(),
 		remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr())
 	needRecord := false
-	go msg.JoinMore(localConn, remoteConn, p.BaseConf, needRecord)
+	go msg.JoinMore(localConn, remoteConn, pc.BaseConf, needRecord)
 
 	return nil
 }

+ 1 - 1
src/models/client/config.go

@@ -130,7 +130,7 @@ func LoadConf(confFile string) (err error) {
 			proxyClient.Type = "tcp"
 			tmpStr, ok = section["type"]
 			if ok {
-				if tmpStr != "tcp" && tmpStr != "http" && tmpStr != "https" {
+				if tmpStr != "tcp" && tmpStr != "http" && tmpStr != "https" && tmpStr != "udp" {
 					return fmt.Errorf("Parse conf error: proxy [%s] type error", proxyClient.Name)
 				}
 				proxyClient.Type = tmpStr

+ 153 - 0
src/models/client/process_udp.go

@@ -0,0 +1,153 @@
+// Copyright 2016 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package client
+
+import (
+	"fmt"
+	"io"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/fatedier/frp/src/models/msg"
+	"github.com/fatedier/frp/src/utils/conn"
+	"github.com/fatedier/frp/src/utils/pool"
+)
+
+type UdpProcesser struct {
+	tcpConn *conn.Conn
+	closeCh chan struct{}
+
+	localAddr string
+
+	// cache local udp connections
+	// key is remoteAddr
+	localUdpConns map[string]*net.UDPConn
+	mutex         sync.RWMutex
+	tcpConnMutex  sync.RWMutex
+}
+
+func NewUdpProcesser(c *conn.Conn, localIp string, localPort int64) *UdpProcesser {
+	return &UdpProcesser{
+		tcpConn:       c,
+		closeCh:       make(chan struct{}),
+		localAddr:     fmt.Sprintf("%s:%d", localIp, localPort),
+		localUdpConns: make(map[string]*net.UDPConn),
+	}
+}
+
+func (up *UdpProcesser) UpdateTcpConn(c *conn.Conn) {
+	up.tcpConnMutex.Lock()
+	defer up.tcpConnMutex.Unlock()
+	up.tcpConn = c
+}
+
+func (up *UdpProcesser) Run() {
+	go up.ReadLoop()
+}
+
+func (up *UdpProcesser) ReadLoop() {
+	var (
+		buf string
+		err error
+	)
+	for {
+		udpPacket := &msg.UdpPacket{}
+
+		// read udp package from frps
+		buf, err = up.tcpConn.ReadLine()
+		if err != nil {
+			if err == io.EOF {
+				return
+			} else {
+				continue
+			}
+		}
+		err = udpPacket.UnPack([]byte(buf))
+		if err != nil {
+			continue
+		}
+
+		// write to local udp port
+		sendConn, ok := up.GetUdpConn(udpPacket.SrcStr)
+		if !ok {
+			dstAddr, err := net.ResolveUDPAddr("udp", up.localAddr)
+			if err != nil {
+				continue
+			}
+			sendConn, err = net.DialUDP("udp", nil, dstAddr)
+			if err != nil {
+				continue
+			}
+
+			up.SetUdpConn(udpPacket.SrcStr, sendConn)
+		}
+
+		_, err = sendConn.Write(udpPacket.Content)
+		if err != nil {
+			sendConn.Close()
+			continue
+		}
+
+		if !ok {
+			go up.Forward(udpPacket, sendConn)
+		}
+	}
+}
+
+func (up *UdpProcesser) Forward(udpPacket *msg.UdpPacket, singleConn *net.UDPConn) {
+	addr := udpPacket.SrcStr
+	defer up.RemoveUdpConn(addr)
+
+	buf := pool.GetBuf(2048)
+	for {
+		singleConn.SetReadDeadline(time.Now().Add(120 * time.Second))
+		n, remoteAddr, err := singleConn.ReadFromUDP(buf)
+		if err != nil {
+			return
+		}
+
+		// forward to frps
+		forwardPacket := msg.NewUdpPacket(buf[0:n], remoteAddr, udpPacket.Src)
+		up.tcpConnMutex.RLock()
+		err = up.tcpConn.WriteString(string(forwardPacket.Pack()) + "\n")
+		up.tcpConnMutex.RUnlock()
+		if err != nil {
+			return
+		}
+	}
+}
+
+func (up *UdpProcesser) GetUdpConn(addr string) (singleConn *net.UDPConn, ok bool) {
+	up.mutex.RLock()
+	defer up.mutex.RUnlock()
+	singleConn, ok = up.localUdpConns[addr]
+	return
+}
+
+func (up *UdpProcesser) SetUdpConn(addr string, conn *net.UDPConn) {
+	up.mutex.Lock()
+	defer up.mutex.Unlock()
+	up.localUdpConns[addr] = conn
+}
+
+func (up *UdpProcesser) RemoveUdpConn(addr string) {
+	up.mutex.Lock()
+	defer up.mutex.Unlock()
+	if c, ok := up.localUdpConns[addr]; ok {
+		c.Close()
+	}
+	delete(up.localUdpConns, addr)
+}

+ 1 - 0
src/models/consts/consts.go

@@ -37,4 +37,5 @@ const (
 	NewCtlConnRes
 	HeartbeatReq
 	HeartbeatRes
+	NewWorkConnUdp
 )

+ 7 - 7
src/models/msg/process.go

@@ -53,9 +53,9 @@ func Join(c1 *conn.Conn, c2 *conn.Conn) {
 }
 
 // join two connections and do some operations
-func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord bool) {
+func JoinMore(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, conf config.BaseConf, needRecord bool) {
 	var wait sync.WaitGroup
-	encryptPipe := func(from *conn.Conn, to *conn.Conn) {
+	encryptPipe := func(from io.ReadCloser, to io.WriteCloser) {
 		defer from.Close()
 		defer to.Close()
 		defer wait.Done()
@@ -64,7 +64,7 @@ func JoinMore(c1 *conn.Conn, c2 *conn.Conn, conf config.BaseConf, needRecord boo
 		pipeEncrypt(from, to, conf, needRecord)
 	}
 
-	decryptPipe := func(to *conn.Conn, from *conn.Conn) {
+	decryptPipe := func(to io.ReadCloser, from io.WriteCloser) {
 		defer from.Close()
 		defer to.Close()
 		defer wait.Done()
@@ -109,7 +109,7 @@ func unpkgMsg(data []byte) (int, []byte, []byte) {
 }
 
 // decrypt msg from reader, then write into writer
-func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
+func pipeDecrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) {
 	laes := new(pcrypto.Pcrypto)
 	key := conf.AuthToken
 	if conf.PrivilegeMode {
@@ -175,7 +175,7 @@ func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo
 			}
 		}
 
-		_, err = w.WriteBytes(res)
+		_, err = w.Write(res)
 		if err != nil {
 			return err
 		}
@@ -192,7 +192,7 @@ func pipeDecrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo
 }
 
 // recvive msg from reader, then encrypt msg into writer
-func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bool) (err error) {
+func pipeEncrypt(r io.Reader, w io.Writer, conf config.BaseConf, needRecord bool) (err error) {
 	laes := new(pcrypto.Pcrypto)
 	key := conf.AuthToken
 	if conf.PrivilegeMode {
@@ -247,7 +247,7 @@ func pipeEncrypt(r *conn.Conn, w *conn.Conn, conf config.BaseConf, needRecord bo
 		}
 
 		res = pkgMsg(res)
-		_, err = w.WriteBytes(res)
+		_, err = w.Write(res)
 		if err != nil {
 			return err
 		}

+ 72 - 0
src/models/msg/udp.go

@@ -0,0 +1,72 @@
+// Copyright 2016 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package msg
+
+import (
+	"encoding/base64"
+	"encoding/json"
+	"net"
+)
+
+type UdpPacket struct {
+	Content []byte       `json:"-"`
+	Src     *net.UDPAddr `json:"-"`
+	Dst     *net.UDPAddr `json:"-"`
+
+	EncodeContent string `json:"content"`
+	SrcStr        string `json:"src"`
+	DstStr        string `json:"dst"`
+}
+
+func NewUdpPacket(content []byte, src, dst *net.UDPAddr) *UdpPacket {
+	up := &UdpPacket{
+		Src:           src,
+		Dst:           dst,
+		EncodeContent: base64.StdEncoding.EncodeToString(content),
+		SrcStr:        src.String(),
+		DstStr:        dst.String(),
+	}
+	return up
+}
+
+// parse one udp packet struct to bytes
+func (up *UdpPacket) Pack() []byte {
+	b, _ := json.Marshal(up)
+	return b
+}
+
+// parse from bytes to UdpPacket struct
+func (up *UdpPacket) UnPack(packet []byte) error {
+	err := json.Unmarshal(packet, &up)
+	if err != nil {
+		return err
+	}
+
+	up.Content, err = base64.StdEncoding.DecodeString(up.EncodeContent)
+	if err != nil {
+		return err
+	}
+
+	up.Src, err = net.ResolveUDPAddr("udp", up.SrcStr)
+	if err != nil {
+		return err
+	}
+
+	up.Dst, err = net.ResolveUDPAddr("udp", up.DstStr)
+	if err != nil {
+		return err
+	}
+	return nil
+}

+ 50 - 0
src/models/msg/udp_test.go

@@ -0,0 +1,50 @@
+// Copyright 2016 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package msg
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+var (
+	content string = "udp packet test"
+	src     string = "1.1.1.1:1000"
+	dst     string = "2.2.2.2:2000"
+
+	udpMsg *UdpPacket
+)
+
+func init() {
+	srcAddr, _ := net.ResolveUDPAddr("udp", src)
+	dstAddr, _ := net.ResolveUDPAddr("udp", dst)
+	udpMsg = NewUdpPacket([]byte(content), srcAddr, dstAddr)
+}
+
+func TestPack(t *testing.T) {
+	assert := assert.New(t)
+	msg := udpMsg.Pack()
+	assert.Equal(string(msg), `{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`)
+}
+
+func TestUnpack(t *testing.T) {
+	assert := assert.New(t)
+	udpMsg.UnPack([]byte(`{"content":"dWRwIHBhY2tldCB0ZXN0","src":"1.1.1.1:1000","dst":"2.2.2.2:2000"}`))
+	assert.Equal(content, string(udpMsg.Content))
+	assert.Equal(src, udpMsg.Src.String())
+	assert.Equal(dst, udpMsg.Dst.String())
+}

+ 3 - 3
src/models/server/config.go

@@ -240,7 +240,7 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
 
 			proxyServer.Type, ok = section["type"]
 			if ok {
-				if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" {
+				if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" && proxyServer.Type != "udp" {
 					return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] type error", proxyServer.Name)
 				}
 			} else {
@@ -252,8 +252,8 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
 				return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] no auth_token found", proxyServer.Name)
 			}
 
-			// for tcp
-			if proxyServer.Type == "tcp" {
+			// for tcp and udp
+			if proxyServer.Type == "tcp" || proxyServer.Type == "udp" {
 				proxyServer.BindAddr, ok = section["bind_addr"]
 				if !ok {
 					proxyServer.BindAddr = "0.0.0.0"

+ 125 - 32
src/models/server/server.go

@@ -16,6 +16,7 @@ package server
 
 import (
 	"fmt"
+	"net"
 	"sync"
 	"time"
 
@@ -25,6 +26,7 @@ import (
 	"github.com/fatedier/frp/src/models/msg"
 	"github.com/fatedier/frp/src/utils/conn"
 	"github.com/fatedier/frp/src/utils/log"
+	"github.com/fatedier/frp/src/utils/pool"
 )
 
 type Listener interface {
@@ -38,13 +40,17 @@ type ProxyServer struct {
 	ListenPort    int64
 	CustomDomains []string
 
-	Status       int64
-	CtlConn      *conn.Conn      // control connection with frpc
-	listeners    []Listener      // accept new connection from remote users
-	ctlMsgChan   chan int64      // every time accept a new user conn, put "1" to the channel
-	workConnChan chan *conn.Conn // get new work conns from control goroutine
-	mutex        sync.RWMutex
-	closeChan    chan struct{} // for notify other goroutines that the proxy is closed by close this channel
+	Status      int64
+	CtlConn     *conn.Conn // control connection with frpc
+	WorkConnUdp *conn.Conn // work connection for udp
+
+	udpConn       *net.UDPConn
+	listeners     []Listener      // accept new connection from remote users
+	ctlMsgChan    chan int64      // every time accept a new user conn, put "1" to the channel
+	workConnChan  chan *conn.Conn // get new work conns from control goroutine
+	udpSenderChan chan *msg.UdpPacket
+	mutex         sync.RWMutex
+	closeChan     chan struct{} // close this channel for notifying other goroutines that the proxy is closed
 }
 
 func NewProxyServer() (p *ProxyServer) {
@@ -83,6 +89,7 @@ func (p *ProxyServer) Init() {
 	metric.SetStatus(p.Name, p.Status)
 	p.workConnChan = make(chan *conn.Conn, p.PoolCount+10)
 	p.ctlMsgChan = make(chan int64, p.PoolCount+10)
+	p.udpSenderChan = make(chan *msg.UdpPacket, 1024)
 	p.listeners = make([]Listener, 0)
 	p.closeChan = make(chan struct{})
 	p.Unlock()
@@ -150,41 +157,68 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		go p.connectionPoolManager(p.closeChan)
 	}
 
-	// start a goroutine for every listener to accept user connection
-	for _, listener := range p.listeners {
-		go func(l Listener) {
+	if p.Type == "udp" {
+		// udp is special
+		p.udpConn, err = conn.ListenUDP(p.BindAddr, p.ListenPort)
+		if err != nil {
+			log.Warn("ProxyName [%s], listen udp port error: %v", p.Name, err)
+			return err
+		}
+		go func() {
 			for {
-				// block
-				// if listener is closed, err returned
-				c, err := l.Accept()
+				buf := pool.GetBuf(2048)
+				n, remoteAddr, err := p.udpConn.ReadFromUDP(buf)
 				if err != nil {
-					log.Info("ProxyName [%s], listener is closed", p.Name)
+					log.Info("ProxyName [%s], udp listener is closed", p.Name)
 					return
 				}
-				log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr())
-
-				if p.Status != consts.Working {
-					log.Debug("ProxyName [%s] is not working, new user conn close", p.Name)
-					c.Close()
-					return
+				localAddr, _ := net.ResolveUDPAddr("udp", p.udpConn.LocalAddr().String())
+				udpPacket := msg.NewUdpPacket(buf[0:n], remoteAddr, localAddr)
+				select {
+				case p.udpSenderChan <- udpPacket:
+				default:
+					log.Warn("ProxyName [%s], udp sender channel is full", p.Name)
 				}
-
-				go func(userConn *conn.Conn) {
-					workConn, err := p.getWorkConn()
+				pool.PutBuf(buf)
+			}
+		}()
+	} else {
+		// start a goroutine for every listener to accept user connection
+		for _, listener := range p.listeners {
+			go func(l Listener) {
+				for {
+					// block
+					// if listener is closed, err returned
+					c, err := l.Accept()
 					if err != nil {
+						log.Info("ProxyName [%s], listener is closed", p.Name)
 						return
 					}
+					log.Debug("ProxyName [%s], get one new user conn [%s]", p.Name, c.GetRemoteAddr())
 
-					// message will be transferred to another without modifying
-					// l means local, r means remote
-					log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
-						userConn.GetLocalAddr(), userConn.GetRemoteAddr())
+					if p.Status != consts.Working {
+						log.Debug("ProxyName [%s] is not working, new user conn close", p.Name)
+						c.Close()
+						return
+					}
 
-					needRecord := true
-					go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
-				}(c)
-			}
-		}(listener)
+					go func(userConn *conn.Conn) {
+						workConn, err := p.getWorkConn()
+						if err != nil {
+							return
+						}
+
+						// message will be transferred to another without modifying
+						// l means local, r means remote
+						log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
+							userConn.GetLocalAddr(), userConn.GetRemoteAddr())
+
+						needRecord := true
+						go msg.JoinMore(userConn, workConn, p.BaseConf, needRecord)
+					}(c)
+				}
+			}(listener)
+		}
 	}
 	return nil
 }
@@ -200,10 +234,18 @@ func (p *ProxyServer) Close() {
 		}
 		close(p.ctlMsgChan)
 		close(p.workConnChan)
+		close(p.udpSenderChan)
 		close(p.closeChan)
 		if p.CtlConn != nil {
 			p.CtlConn.Close()
 		}
+		if p.WorkConnUdp != nil {
+			p.WorkConnUdp.Close()
+		}
+		if p.udpConn != nil {
+			p.udpConn.Close()
+			p.udpConn = nil
+		}
 	}
 	metric.SetStatus(p.Name, p.Status)
 	// if the proxy created by PrivilegeMode, delete it when closed
@@ -228,7 +270,58 @@ func (p *ProxyServer) RegisterNewWorkConn(c *conn.Conn) {
 	case p.workConnChan <- c:
 	default:
 		log.Debug("ProxyName [%s], workConnChan is full, so close this work connection", p.Name)
+		c.Close()
+	}
+}
+
+// create a tcp connection for forwarding udp packages
+func (p *ProxyServer) RegisterNewWorkConnUdp(c *conn.Conn) {
+	if p.WorkConnUdp != nil && !p.WorkConnUdp.IsClosed() {
+		p.WorkConnUdp.Close()
 	}
+	p.WorkConnUdp = c
+
+	// read
+	go func() {
+		var (
+			buf string
+			err error
+		)
+		for {
+			buf, err = c.ReadLine()
+			if err != nil {
+				log.Warn("ProxyName [%s], work connection for udp closed", p.Name)
+				return
+			}
+			udpPacket := &msg.UdpPacket{}
+			err = udpPacket.UnPack([]byte(buf))
+			if err != nil {
+				log.Warn("ProxyName [%s], unpack udp packet error: %v", p.Name, err)
+				continue
+			}
+
+			// send to user
+			_, err = p.udpConn.WriteToUDP(udpPacket.Content, udpPacket.Dst)
+			if err != nil {
+				continue
+			}
+		}
+	}()
+
+	// write
+	go func() {
+		for {
+			udpPacket, ok := <-p.udpSenderChan
+			if !ok {
+				return
+			}
+			err := c.WriteString(string(udpPacket.Pack()) + "\n")
+			if err != nil {
+				log.Debug("ProxyName [%s], write to work connection for udp error: %v", p.Name, err)
+				return
+			}
+		}
+	}()
 }
 
 // When frps get one user connection, we get one work connection from the pool and return it.

+ 5 - 5
src/utils/conn/conn.go

@@ -202,12 +202,12 @@ func (c *Conn) ReadLine() (buff string, err error) {
 	return buff, err
 }
 
-func (c *Conn) WriteBytes(content []byte) (n int, err error) {
+func (c *Conn) Write(content []byte) (n int, err error) {
 	n, err = c.TcpConn.Write(content)
 	return
 }
 
-func (c *Conn) Write(content string) (err error) {
+func (c *Conn) WriteString(content string) (err error) {
 	_, err = c.TcpConn.Write([]byte(content))
 	return err
 }
@@ -220,13 +220,14 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
 	return c.TcpConn.SetReadDeadline(t)
 }
 
-func (c *Conn) Close() {
+func (c *Conn) Close() error {
 	c.mutex.Lock()
+	defer c.mutex.Unlock()
 	if c.TcpConn != nil && c.closeFlag == false {
 		c.closeFlag = true
 		c.TcpConn.Close()
 	}
-	c.mutex.Unlock()
+	return nil
 }
 
 func (c *Conn) IsClosed() (closeFlag bool) {
@@ -245,7 +246,6 @@ func (c *Conn) CheckClosed() bool {
 	}
 	c.mutex.RUnlock()
 
-	// err := c.TcpConn.SetReadDeadline(time.Now().Add(100 * time.Microsecond))
 	err := c.TcpConn.SetReadDeadline(time.Now().Add(time.Millisecond))
 	if err != nil {
 		c.Close()

+ 29 - 0
src/utils/conn/udp_conn.go

@@ -0,0 +1,29 @@
+// Copyright 2016 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package conn
+
+import (
+	"fmt"
+	"net"
+)
+
+func ListenUDP(bindAddr string, bindPort int64) (conn *net.UDPConn, err error) {
+	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
+	if err != nil {
+		return conn, err
+	}
+	conn, err = net.ListenUDP("udp", udpAddr)
+	return
+}

+ 1 - 1
test/echo_server.go

@@ -40,6 +40,6 @@ func echoWorker(c *conn.Conn) {
 			return
 		}
 
-		c.Write(buff)
+		c.WriteString(buff)
 	}
 }

+ 1 - 1
test/func_test.go

@@ -26,7 +26,7 @@ func TestEchoServer(t *testing.T) {
 	timer := time.Now().Add(time.Duration(5) * time.Second)
 	c.SetDeadline(timer)
 
-	c.Write(ECHO_TEST_STR)
+	c.WriteString(ECHO_TEST_STR)
 
 	buff, err := c.ReadLine()
 	if err != nil {