udp.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. // Copyright 2017 fatedier, fatedier@gmail.com
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package udp
  15. import (
  16. "encoding/base64"
  17. "net"
  18. "sync"
  19. "time"
  20. "github.com/fatedier/golib/errors"
  21. "github.com/fatedier/golib/pool"
  22. "github.com/fatedier/frp/pkg/msg"
  23. netpkg "github.com/fatedier/frp/pkg/util/net"
  24. )
  25. func NewUDPPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UDPPacket {
  26. return &msg.UDPPacket{
  27. Content: base64.StdEncoding.EncodeToString(buf),
  28. LocalAddr: laddr,
  29. RemoteAddr: raddr,
  30. }
  31. }
  32. func GetContent(m *msg.UDPPacket) (buf []byte, err error) {
  33. buf, err = base64.StdEncoding.DecodeString(m.Content)
  34. return
  35. }
  36. func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UDPPacket, sendCh chan<- *msg.UDPPacket, bufSize int) {
  37. // read
  38. go func() {
  39. for udpMsg := range readCh {
  40. buf, err := GetContent(udpMsg)
  41. if err != nil {
  42. continue
  43. }
  44. _, _ = udpConn.WriteToUDP(buf, udpMsg.RemoteAddr)
  45. }
  46. }()
  47. // write
  48. buf := pool.GetBuf(bufSize)
  49. defer pool.PutBuf(buf)
  50. for {
  51. n, remoteAddr, err := udpConn.ReadFromUDP(buf)
  52. if err != nil {
  53. return
  54. }
  55. // buf[:n] will be encoded to string, so the bytes can be reused
  56. udpMsg := NewUDPPacket(buf[:n], nil, remoteAddr)
  57. select {
  58. case sendCh <- udpMsg:
  59. default:
  60. }
  61. }
  62. }
  63. func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<- msg.Message, bufSize int, proxyProtocolVersion string) {
  64. var mu sync.RWMutex
  65. udpConnMap := make(map[string]*net.UDPConn)
  66. // read from dstAddr and write to sendCh
  67. writerFn := func(raddr *net.UDPAddr, udpConn *net.UDPConn) {
  68. addr := raddr.String()
  69. defer func() {
  70. mu.Lock()
  71. delete(udpConnMap, addr)
  72. mu.Unlock()
  73. udpConn.Close()
  74. }()
  75. buf := pool.GetBuf(bufSize)
  76. for {
  77. _ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
  78. n, _, err := udpConn.ReadFromUDP(buf)
  79. if err != nil {
  80. return
  81. }
  82. udpMsg := NewUDPPacket(buf[:n], nil, raddr)
  83. if err = errors.PanicToError(func() {
  84. select {
  85. case sendCh <- udpMsg:
  86. default:
  87. }
  88. }); err != nil {
  89. return
  90. }
  91. }
  92. }
  93. // read from readCh
  94. go func() {
  95. for udpMsg := range readCh {
  96. buf, err := GetContent(udpMsg)
  97. if err != nil {
  98. continue
  99. }
  100. mu.Lock()
  101. udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()]
  102. if !ok {
  103. udpConn, err = net.DialUDP("udp", nil, dstAddr)
  104. if err != nil {
  105. mu.Unlock()
  106. continue
  107. }
  108. udpConnMap[udpMsg.RemoteAddr.String()] = udpConn
  109. }
  110. mu.Unlock()
  111. // Add proxy protocol header if configured
  112. if proxyProtocolVersion != "" && udpMsg.RemoteAddr != nil {
  113. ppBuf, err := netpkg.BuildProxyProtocolHeader(udpMsg.RemoteAddr, dstAddr, proxyProtocolVersion)
  114. if err == nil {
  115. // Prepend proxy protocol header to the UDP payload
  116. finalBuf := make([]byte, len(ppBuf)+len(buf))
  117. copy(finalBuf, ppBuf)
  118. copy(finalBuf[len(ppBuf):], buf)
  119. buf = finalBuf
  120. }
  121. }
  122. _, err = udpConn.Write(buf)
  123. if err != nil {
  124. udpConn.Close()
  125. }
  126. if !ok {
  127. go writerFn(udpMsg.RemoteAddr, udpConn)
  128. }
  129. }
  130. }()
  131. }