|
@@ -16,30 +16,22 @@ package client
|
|
|
|
|
|
import (
|
|
import (
|
|
"context"
|
|
"context"
|
|
- "crypto/tls"
|
|
|
|
"errors"
|
|
"errors"
|
|
"fmt"
|
|
"fmt"
|
|
- "io"
|
|
|
|
"net"
|
|
"net"
|
|
"runtime"
|
|
"runtime"
|
|
"strconv"
|
|
"strconv"
|
|
- "strings"
|
|
|
|
"sync"
|
|
"sync"
|
|
"time"
|
|
"time"
|
|
|
|
|
|
"github.com/fatedier/golib/crypto"
|
|
"github.com/fatedier/golib/crypto"
|
|
- libdial "github.com/fatedier/golib/net/dial"
|
|
|
|
- fmux "github.com/hashicorp/yamux"
|
|
|
|
- quic "github.com/quic-go/quic-go"
|
|
|
|
"github.com/samber/lo"
|
|
"github.com/samber/lo"
|
|
|
|
|
|
"github.com/fatedier/frp/assets"
|
|
"github.com/fatedier/frp/assets"
|
|
"github.com/fatedier/frp/pkg/auth"
|
|
"github.com/fatedier/frp/pkg/auth"
|
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
"github.com/fatedier/frp/pkg/msg"
|
|
"github.com/fatedier/frp/pkg/msg"
|
|
- "github.com/fatedier/frp/pkg/transport"
|
|
|
|
"github.com/fatedier/frp/pkg/util/log"
|
|
"github.com/fatedier/frp/pkg/util/log"
|
|
- utilnet "github.com/fatedier/frp/pkg/util/net"
|
|
|
|
"github.com/fatedier/frp/pkg/util/version"
|
|
"github.com/fatedier/frp/pkg/util/version"
|
|
"github.com/fatedier/frp/pkg/util/wait"
|
|
"github.com/fatedier/frp/pkg/util/wait"
|
|
"github.com/fatedier/frp/pkg/util/xlog"
|
|
"github.com/fatedier/frp/pkg/util/xlog"
|
|
@@ -75,6 +67,9 @@ type Service struct {
|
|
// call cancel to stop service
|
|
// call cancel to stop service
|
|
cancel context.CancelFunc
|
|
cancel context.CancelFunc
|
|
gracefulDuration time.Duration
|
|
gracefulDuration time.Duration
|
|
|
|
+
|
|
|
|
+ connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
|
|
|
|
+ inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool
|
|
}
|
|
}
|
|
|
|
|
|
func NewService(
|
|
func NewService(
|
|
@@ -84,15 +79,24 @@ func NewService(
|
|
cfgFile string,
|
|
cfgFile string,
|
|
) *Service {
|
|
) *Service {
|
|
return &Service{
|
|
return &Service{
|
|
- authSetter: auth.NewAuthSetter(cfg.Auth),
|
|
|
|
- cfg: cfg,
|
|
|
|
- cfgFile: cfgFile,
|
|
|
|
- pxyCfgs: pxyCfgs,
|
|
|
|
- visitorCfgs: visitorCfgs,
|
|
|
|
- ctx: context.Background(),
|
|
|
|
|
|
+ authSetter: auth.NewAuthSetter(cfg.Auth),
|
|
|
|
+ cfg: cfg,
|
|
|
|
+ cfgFile: cfgFile,
|
|
|
|
+ pxyCfgs: pxyCfgs,
|
|
|
|
+ visitorCfgs: visitorCfgs,
|
|
|
|
+ ctx: context.Background(),
|
|
|
|
+ connectorCreator: NewConnector,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func (svr *Service) SetConnectorCreator(h func(context.Context, *v1.ClientCommonConfig) Connector) {
|
|
|
|
+ svr.connectorCreator = h
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (svr *Service) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) {
|
|
|
|
+ svr.inWorkConnCallback = cb
|
|
|
|
+}
|
|
|
|
+
|
|
func (svr *Service) GetController() *Control {
|
|
func (svr *Service) GetController() *Control {
|
|
svr.ctlMu.RLock()
|
|
svr.ctlMu.RLock()
|
|
defer svr.ctlMu.RUnlock()
|
|
defer svr.ctlMu.RUnlock()
|
|
@@ -101,7 +105,7 @@ func (svr *Service) GetController() *Control {
|
|
|
|
|
|
func (svr *Service) Run(ctx context.Context) error {
|
|
func (svr *Service) Run(ctx context.Context) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
- svr.ctx = xlog.NewContext(ctx, xlog.New())
|
|
|
|
|
|
+ svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx))
|
|
svr.cancel = cancel
|
|
svr.cancel = cancel
|
|
|
|
|
|
// set custom DNSServer
|
|
// set custom DNSServer
|
|
@@ -173,21 +177,20 @@ func (svr *Service) keepControllerWorking() {
|
|
// login creates a connection to frps and registers it self as a client
|
|
// login creates a connection to frps and registers it self as a client
|
|
// conn: control connection
|
|
// conn: control connection
|
|
// session: if it's not nil, using tcp mux
|
|
// session: if it's not nil, using tcp mux
|
|
-func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
|
|
|
|
|
|
+func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
|
|
xl := xlog.FromContextSafe(svr.ctx)
|
|
xl := xlog.FromContextSafe(svr.ctx)
|
|
- cm = NewConnectionManager(svr.ctx, svr.cfg)
|
|
|
|
-
|
|
|
|
- if err = cm.OpenConnection(); err != nil {
|
|
|
|
|
|
+ connector = svr.connectorCreator(svr.ctx, svr.cfg)
|
|
|
|
+ if err = connector.Open(); err != nil {
|
|
return nil, nil, err
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
|
|
defer func() {
|
|
defer func() {
|
|
if err != nil {
|
|
if err != nil {
|
|
- cm.Close()
|
|
|
|
|
|
+ connector.Close()
|
|
}
|
|
}
|
|
}()
|
|
}()
|
|
|
|
|
|
- conn, err = cm.Connect()
|
|
|
|
|
|
+ conn, err = connector.Connect()
|
|
if err != nil {
|
|
if err != nil {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
@@ -226,8 +229,7 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
|
|
}
|
|
}
|
|
|
|
|
|
svr.runID = loginRespMsg.RunID
|
|
svr.runID = loginRespMsg.RunID
|
|
- xl.ResetPrefixes()
|
|
|
|
- xl.AppendPrefix(svr.runID)
|
|
|
|
|
|
+ xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
|
|
|
|
|
xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID)
|
|
xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID)
|
|
return
|
|
return
|
|
@@ -239,7 +241,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
|
|
|
|
|
loginFunc := func() error {
|
|
loginFunc := func() error {
|
|
xl.Info("try to connect to server...")
|
|
xl.Info("try to connect to server...")
|
|
- conn, cm, err := svr.login()
|
|
|
|
|
|
+ conn, connector, err := svr.login()
|
|
if err != nil {
|
|
if err != nil {
|
|
xl.Warn("connect to server error: %v", err)
|
|
xl.Warn("connect to server error: %v", err)
|
|
if firstLoginExit {
|
|
if firstLoginExit {
|
|
@@ -248,13 +250,14 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
|
|
|
|
- ctl, err := NewControl(svr.ctx, svr.runID, conn, cm,
|
|
|
|
|
|
+ ctl, err := NewControl(svr.ctx, svr.runID, conn, connector,
|
|
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
|
|
svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
|
|
if err != nil {
|
|
if err != nil {
|
|
conn.Close()
|
|
conn.Close()
|
|
xl.Error("NewControl error: %v", err)
|
|
xl.Error("NewControl error: %v", err)
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
|
|
+ ctl.SetInWorkConnCallback(svr.inWorkConnCallback)
|
|
|
|
|
|
ctl.Run()
|
|
ctl.Run()
|
|
// close and replace previous control
|
|
// close and replace previous control
|
|
@@ -314,184 +317,3 @@ func (svr *Service) stop() {
|
|
svr.ctl = nil
|
|
svr.ctl = nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
-// ConnectionManager is a wrapper for establishing connections to the server.
|
|
|
|
-type ConnectionManager struct {
|
|
|
|
- ctx context.Context
|
|
|
|
- cfg *v1.ClientCommonConfig
|
|
|
|
-
|
|
|
|
- muxSession *fmux.Session
|
|
|
|
- quicConn quic.Connection
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
-func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *ConnectionManager {
|
|
|
|
- return &ConnectionManager{
|
|
|
|
- ctx: ctx,
|
|
|
|
- cfg: cfg,
|
|
|
|
- }
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
-// OpenConnection opens a underlying connection to the server.
|
|
|
|
-// The underlying connection is either a TCP connection or a QUIC connection.
|
|
|
|
-// After the underlying connection is established, you can call Connect() to get a stream.
|
|
|
|
-// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect().
|
|
|
|
-func (cm *ConnectionManager) OpenConnection() error {
|
|
|
|
- xl := xlog.FromContextSafe(cm.ctx)
|
|
|
|
-
|
|
|
|
- // special for quic
|
|
|
|
- if strings.EqualFold(cm.cfg.Transport.Protocol, "quic") {
|
|
|
|
- var tlsConfig *tls.Config
|
|
|
|
- var err error
|
|
|
|
- sn := cm.cfg.Transport.TLS.ServerName
|
|
|
|
- if sn == "" {
|
|
|
|
- sn = cm.cfg.ServerAddr
|
|
|
|
- }
|
|
|
|
- if lo.FromPtr(cm.cfg.Transport.TLS.Enable) {
|
|
|
|
- tlsConfig, err = transport.NewClientTLSConfig(
|
|
|
|
- cm.cfg.Transport.TLS.CertFile,
|
|
|
|
- cm.cfg.Transport.TLS.KeyFile,
|
|
|
|
- cm.cfg.Transport.TLS.TrustedCaFile,
|
|
|
|
- sn)
|
|
|
|
- } else {
|
|
|
|
- tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
|
|
|
|
- }
|
|
|
|
- if err != nil {
|
|
|
|
- xl.Warn("fail to build tls configuration, err: %v", err)
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
- tlsConfig.NextProtos = []string{"frp"}
|
|
|
|
-
|
|
|
|
- conn, err := quic.DialAddr(
|
|
|
|
- cm.ctx,
|
|
|
|
- net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
|
|
|
- tlsConfig, &quic.Config{
|
|
|
|
- MaxIdleTimeout: time.Duration(cm.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
|
|
|
|
- MaxIncomingStreams: int64(cm.cfg.Transport.QUIC.MaxIncomingStreams),
|
|
|
|
- KeepAlivePeriod: time.Duration(cm.cfg.Transport.QUIC.KeepalivePeriod) * time.Second,
|
|
|
|
- })
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
- cm.quicConn = conn
|
|
|
|
- return nil
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if !lo.FromPtr(cm.cfg.Transport.TCPMux) {
|
|
|
|
- return nil
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- conn, err := cm.realConnect()
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- fmuxCfg := fmux.DefaultConfig()
|
|
|
|
- fmuxCfg.KeepAliveInterval = time.Duration(cm.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
|
|
|
- fmuxCfg.LogOutput = io.Discard
|
|
|
|
- fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
|
|
|
|
- session, err := fmux.Client(conn, fmuxCfg)
|
|
|
|
- if err != nil {
|
|
|
|
- return err
|
|
|
|
- }
|
|
|
|
- cm.muxSession = session
|
|
|
|
- return nil
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
-// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled.
|
|
|
|
-func (cm *ConnectionManager) Connect() (net.Conn, error) {
|
|
|
|
- if cm.quicConn != nil {
|
|
|
|
- stream, err := cm.quicConn.OpenStreamSync(context.Background())
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, err
|
|
|
|
- }
|
|
|
|
- return utilnet.QuicStreamToNetConn(stream, cm.quicConn), nil
|
|
|
|
- } else if cm.muxSession != nil {
|
|
|
|
- stream, err := cm.muxSession.OpenStream()
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, err
|
|
|
|
- }
|
|
|
|
- return stream, nil
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return cm.realConnect()
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
-func (cm *ConnectionManager) realConnect() (net.Conn, error) {
|
|
|
|
- xl := xlog.FromContextSafe(cm.ctx)
|
|
|
|
- var tlsConfig *tls.Config
|
|
|
|
- var err error
|
|
|
|
- tlsEnable := lo.FromPtr(cm.cfg.Transport.TLS.Enable)
|
|
|
|
- if cm.cfg.Transport.Protocol == "wss" {
|
|
|
|
- tlsEnable = true
|
|
|
|
- }
|
|
|
|
- if tlsEnable {
|
|
|
|
- sn := cm.cfg.Transport.TLS.ServerName
|
|
|
|
- if sn == "" {
|
|
|
|
- sn = cm.cfg.ServerAddr
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- tlsConfig, err = transport.NewClientTLSConfig(
|
|
|
|
- cm.cfg.Transport.TLS.CertFile,
|
|
|
|
- cm.cfg.Transport.TLS.KeyFile,
|
|
|
|
- cm.cfg.Transport.TLS.TrustedCaFile,
|
|
|
|
- sn)
|
|
|
|
- if err != nil {
|
|
|
|
- xl.Warn("fail to build tls configuration, err: %v", err)
|
|
|
|
- return nil, err
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- proxyType, addr, auth, err := libdial.ParseProxyURL(cm.cfg.Transport.ProxyURL)
|
|
|
|
- if err != nil {
|
|
|
|
- xl.Error("fail to parse proxy url")
|
|
|
|
- return nil, err
|
|
|
|
- }
|
|
|
|
- dialOptions := []libdial.DialOption{}
|
|
|
|
- protocol := cm.cfg.Transport.Protocol
|
|
|
|
- switch protocol {
|
|
|
|
- case "websocket":
|
|
|
|
- protocol = "tcp"
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
|
|
|
- Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
|
|
|
- }))
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
|
|
|
- case "wss":
|
|
|
|
- protocol = "tcp"
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
|
|
|
|
- // Make sure that if it is wss, the websocket hook is executed after the tls hook.
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
|
|
|
|
- default:
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
|
|
|
- Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)),
|
|
|
|
- }))
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if cm.cfg.Transport.ConnectServerLocalIP != "" {
|
|
|
|
- dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.Transport.ConnectServerLocalIP))
|
|
|
|
- }
|
|
|
|
- dialOptions = append(dialOptions,
|
|
|
|
- libdial.WithProtocol(protocol),
|
|
|
|
- libdial.WithTimeout(time.Duration(cm.cfg.Transport.DialServerTimeout)*time.Second),
|
|
|
|
- libdial.WithKeepAlive(time.Duration(cm.cfg.Transport.DialServerKeepAlive)*time.Second),
|
|
|
|
- libdial.WithProxy(proxyType, addr),
|
|
|
|
- libdial.WithProxyAuth(auth),
|
|
|
|
- )
|
|
|
|
- conn, err := libdial.DialContext(
|
|
|
|
- cm.ctx,
|
|
|
|
- net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
|
|
|
- dialOptions...,
|
|
|
|
- )
|
|
|
|
- return conn, err
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
-func (cm *ConnectionManager) Close() error {
|
|
|
|
- if cm.quicConn != nil {
|
|
|
|
- _ = cm.quicConn.CloseWithError(0, "")
|
|
|
|
- }
|
|
|
|
- if cm.muxSession != nil {
|
|
|
|
- _ = cm.muxSession.Close()
|
|
|
|
- }
|
|
|
|
- return nil
|
|
|
|
-}
|
|
|