controller.go 8.9 KB


  1. // Copyright 2025 The frp Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package vnet
  15. import (
  16. "context"
  17. "encoding/base64"
  18. "fmt"
  19. "io"
  20. "net"
  21. "sync"
  22. "github.com/fatedier/golib/pool"
  23. "github.com/songgao/water/waterutil"
  24. "golang.org/x/net/ipv4"
  25. "golang.org/x/net/ipv6"
  26. v1 "github.com/fatedier/frp/pkg/config/v1"
  27. "github.com/fatedier/frp/pkg/util/log"
  28. "github.com/fatedier/frp/pkg/util/xlog"
  29. )
  30. const (
  31. maxPacketSize = 1420
  32. )
// Controller bridges a local TUN device and peer connections for the virtual
// network.
//
// Packets read from the TUN device (Run -> handlePacket) are forwarded to a
// peer chosen by clientRouter (destination inside a registered CIDR) or by
// serverRouter (destination previously seen as a source address on a server
// connection). Packets received from peers are written to the TUN device by
// the per-connection read loops.
type Controller struct {
	addr         string             // address handed to OpenTun by Init
	tun          io.ReadWriteCloser // TUN device; nil until Init succeeds
	clientRouter *clientRouter      // Route based on destination IP (client mode)
	serverRouter *serverRouter      // Route based on source IP (server mode)
}
  39. func NewController(cfg v1.VirtualNetConfig) *Controller {
  40. return &Controller{
  41. addr: cfg.Address,
  42. clientRouter: newClientRouter(),
  43. serverRouter: newServerRouter(),
  44. }
  45. }
  46. func (c *Controller) Init() error {
  47. tunDevice, err := OpenTun(context.Background(), c.addr)
  48. if err != nil {
  49. return err
  50. }
  51. c.tun = tunDevice
  52. return nil
  53. }
  54. func (c *Controller) Run() error {
  55. conn := c.tun
  56. for {
  57. buf := pool.GetBuf(maxPacketSize)
  58. n, err := conn.Read(buf)
  59. if err != nil {
  60. pool.PutBuf(buf)
  61. log.Warnf("vnet read from tun error: %v", err)
  62. return err
  63. }
  64. c.handlePacket(buf[:n])
  65. pool.PutBuf(buf)
  66. }
  67. }
  68. // handlePacket processes a single packet. The caller is responsible for managing the buffer.
  69. func (c *Controller) handlePacket(buf []byte) {
  70. log.Tracef("vnet read from tun [%d]: %s", len(buf), base64.StdEncoding.EncodeToString(buf))
  71. var src, dst net.IP
  72. switch {
  73. case waterutil.IsIPv4(buf):
  74. header, err := ipv4.ParseHeader(buf)
  75. if err != nil {
  76. log.Warnf("parse ipv4 header error:", err)
  77. return
  78. }
  79. src = header.Src
  80. dst = header.Dst
  81. log.Tracef("%s >> %s %d/%-4d %-4x %d",
  82. header.Src, header.Dst,
  83. header.Len, header.TotalLen, header.ID, header.Flags)
  84. case waterutil.IsIPv6(buf):
  85. header, err := ipv6.ParseHeader(buf)
  86. if err != nil {
  87. log.Warnf("parse ipv6 header error:", err)
  88. return
  89. }
  90. src = header.Src
  91. dst = header.Dst
  92. log.Tracef("%s >> %s %d %d",
  93. header.Src, header.Dst,
  94. header.PayloadLen, header.TrafficClass)
  95. default:
  96. log.Tracef("unknown packet, discarded(%d)", len(buf))
  97. return
  98. }
  99. targetConn, err := c.clientRouter.findConn(dst)
  100. if err == nil {
  101. if err := WriteMessage(targetConn, buf); err != nil {
  102. log.Warnf("write to client target conn error: %v", err)
  103. }
  104. return
  105. }
  106. targetConn, err = c.serverRouter.findConnBySrc(dst)
  107. if err == nil {
  108. if err := WriteMessage(targetConn, buf); err != nil {
  109. log.Warnf("write to server target conn error: %v", err)
  110. }
  111. return
  112. }
  113. log.Tracef("no route found for packet from %s to %s", src, dst)
  114. }
  115. func (c *Controller) Stop() error {
  116. return c.tun.Close()
  117. }
  118. // Client connection read loop
  119. func (c *Controller) readLoopClient(ctx context.Context, conn io.ReadWriteCloser) {
  120. xl := xlog.FromContextSafe(ctx)
  121. for {
  122. data, err := ReadMessage(conn)
  123. if err != nil {
  124. xl.Warnf("client read error: %v", err)
  125. return
  126. }
  127. if len(data) == 0 {
  128. continue
  129. }
  130. switch {
  131. case waterutil.IsIPv4(data):
  132. header, err := ipv4.ParseHeader(data)
  133. if err != nil {
  134. xl.Warnf("parse ipv4 header error: %v", err)
  135. continue
  136. }
  137. xl.Tracef("%s >> %s %d/%-4d %-4x %d",
  138. header.Src, header.Dst,
  139. header.Len, header.TotalLen, header.ID, header.Flags)
  140. case waterutil.IsIPv6(data):
  141. header, err := ipv6.ParseHeader(data)
  142. if err != nil {
  143. xl.Warnf("parse ipv6 header error: %v", err)
  144. continue
  145. }
  146. xl.Tracef("%s >> %s %d %d",
  147. header.Src, header.Dst,
  148. header.PayloadLen, header.TrafficClass)
  149. default:
  150. xl.Tracef("unknown packet, discarded(%d)", len(data))
  151. continue
  152. }
  153. xl.Tracef("vnet write to tun (client) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data))
  154. _, err = c.tun.Write(data)
  155. if err != nil {
  156. xl.Warnf("client write tun error: %v", err)
  157. }
  158. }
  159. }
  160. // Server connection read loop
  161. func (c *Controller) readLoopServer(ctx context.Context, conn io.ReadWriteCloser) {
  162. xl := xlog.FromContextSafe(ctx)
  163. for {
  164. data, err := ReadMessage(conn)
  165. if err != nil {
  166. xl.Warnf("server read error: %v", err)
  167. return
  168. }
  169. if len(data) == 0 {
  170. continue
  171. }
  172. // Register source IP to connection mapping
  173. if waterutil.IsIPv4(data) || waterutil.IsIPv6(data) {
  174. var src net.IP
  175. if waterutil.IsIPv4(data) {
  176. header, err := ipv4.ParseHeader(data)
  177. if err == nil {
  178. src = header.Src
  179. c.serverRouter.registerSrcIP(src, conn)
  180. }
  181. } else {
  182. header, err := ipv6.ParseHeader(data)
  183. if err == nil {
  184. src = header.Src
  185. c.serverRouter.registerSrcIP(src, conn)
  186. }
  187. }
  188. }
  189. xl.Tracef("vnet write to tun (server) [%d]: %s", len(data), base64.StdEncoding.EncodeToString(data))
  190. _, err = c.tun.Write(data)
  191. if err != nil {
  192. xl.Warnf("server write tun error: %v", err)
  193. }
  194. }
  195. }
  196. // RegisterClientRoute Register client route (based on destination IP CIDR)
  197. func (c *Controller) RegisterClientRoute(ctx context.Context, name string, routes []net.IPNet, conn io.ReadWriteCloser) error {
  198. if err := c.clientRouter.addRoute(name, routes, conn); err != nil {
  199. return err
  200. }
  201. go c.readLoopClient(ctx, conn)
  202. return nil
  203. }
  204. // RegisterServerConn Register server connection (dynamically associates with source IPs)
  205. func (c *Controller) RegisterServerConn(ctx context.Context, name string, conn io.ReadWriteCloser) error {
  206. if err := c.serverRouter.addConn(name, conn); err != nil {
  207. return err
  208. }
  209. go c.readLoopServer(ctx, conn)
  210. return nil
  211. }
  212. // UnregisterServerConn Remove server connection from routing table
  213. func (c *Controller) UnregisterServerConn(name string) {
  214. c.serverRouter.delConn(name)
  215. }
  216. // UnregisterClientRoute Remove client route from routing table
  217. func (c *Controller) UnregisterClientRoute(name string) {
  218. c.clientRouter.delRoute(name)
  219. }
  220. // ParseRoutes Convert route strings to IPNet objects
  221. func ParseRoutes(routeStrings []string) ([]net.IPNet, error) {
  222. routes := make([]net.IPNet, 0, len(routeStrings))
  223. for _, r := range routeStrings {
  224. _, ipNet, err := net.ParseCIDR(r)
  225. if err != nil {
  226. return nil, fmt.Errorf("parse route %s error: %v", r, err)
  227. }
  228. routes = append(routes, *ipNet)
  229. }
  230. return routes, nil
  231. }
  232. // Client router (based on destination IP routing)
  233. type clientRouter struct {
  234. routes map[string]*routeElement
  235. mu sync.RWMutex
  236. }
  237. func newClientRouter() *clientRouter {
  238. return &clientRouter{
  239. routes: make(map[string]*routeElement),
  240. }
  241. }
  242. func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWriteCloser) error {
  243. r.mu.Lock()
  244. defer r.mu.Unlock()
  245. r.routes[name] = &routeElement{
  246. name: name,
  247. routes: routes,
  248. conn: conn,
  249. }
  250. return nil
  251. }
  252. func (r *clientRouter) findConn(dst net.IP) (io.Writer, error) {
  253. r.mu.RLock()
  254. defer r.mu.RUnlock()
  255. for _, re := range r.routes {
  256. for _, route := range re.routes {
  257. if route.Contains(dst) {
  258. return re.conn, nil
  259. }
  260. }
  261. }
  262. return nil, fmt.Errorf("no route found for destination %s", dst)
  263. }
  264. func (r *clientRouter) delRoute(name string) {
  265. r.mu.Lock()
  266. defer r.mu.Unlock()
  267. delete(r.routes, name)
  268. }
  269. // Server router (based on source IP routing)
  270. type serverRouter struct {
  271. namedConns map[string]io.ReadWriteCloser // Name to connection mapping
  272. srcIPConns map[string]io.Writer // Source IP string to connection mapping
  273. mu sync.RWMutex
  274. }
  275. func newServerRouter() *serverRouter {
  276. return &serverRouter{
  277. namedConns: make(map[string]io.ReadWriteCloser),
  278. srcIPConns: make(map[string]io.Writer),
  279. }
  280. }
  281. func (r *serverRouter) addConn(name string, conn io.ReadWriteCloser) error {
  282. r.mu.Lock()
  283. original, ok := r.namedConns[name]
  284. r.namedConns[name] = conn
  285. r.mu.Unlock()
  286. if ok {
  287. // Close the original connection if it exists
  288. _ = original.Close()
  289. }
  290. return nil
  291. }
  292. func (r *serverRouter) findConnBySrc(src net.IP) (io.Writer, error) {
  293. r.mu.RLock()
  294. defer r.mu.RUnlock()
  295. conn, exists := r.srcIPConns[src.String()]
  296. if !exists {
  297. return nil, fmt.Errorf("no route found for source %s", src)
  298. }
  299. return conn, nil
  300. }
  301. func (r *serverRouter) registerSrcIP(src net.IP, conn io.Writer) {
  302. r.mu.Lock()
  303. defer r.mu.Unlock()
  304. r.srcIPConns[src.String()] = conn
  305. }
  306. func (r *serverRouter) delConn(name string) {
  307. r.mu.Lock()
  308. defer r.mu.Unlock()
  309. delete(r.namedConns, name)
  310. // Note: We don't delete mappings from srcIPConns because we don't know which source IPs are associated with this connection
  311. // This might cause dangling references, but they will be overwritten on new connections or restart
  312. }
// routeElement is one named client-mode route entry: the destination CIDRs
// served by a peer and the connection used to reach them.
type routeElement struct {
	name   string             // key under which the entry is stored in clientRouter.routes
	routes []net.IPNet        // destination networks reachable through conn
	conn   io.ReadWriteCloser // peer connection packets for these networks are written to
}