|
@@ -31,6 +31,7 @@ import (
|
|
|
"github.com/fatedier/golib/crypto"
|
|
|
libdial "github.com/fatedier/golib/net/dial"
|
|
|
fmux "github.com/hashicorp/yamux"
|
|
|
+ quic "github.com/lucas-clemente/quic-go"
|
|
|
|
|
|
"github.com/fatedier/frp/assets"
|
|
|
"github.com/fatedier/frp/pkg/auth"
|
|
@@ -127,7 +128,7 @@ func (svr *Service) Run() error {
|
|
|
|
|
|
// login to frps
|
|
|
for {
|
|
|
- conn, session, err := svr.login()
|
|
|
+ conn, cm, err := svr.login()
|
|
|
if err != nil {
|
|
|
xl.Warn("login to server failed: %v", err)
|
|
|
|
|
@@ -139,7 +140,7 @@ func (svr *Service) Run() error {
|
|
|
util.RandomSleep(10*time.Second, 0.9, 1.1)
|
|
|
} else {
|
|
|
// login success
|
|
|
- ctl := NewControl(svr.ctx, svr.runID, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter)
|
|
|
+ ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter)
|
|
|
ctl.Run()
|
|
|
svr.ctlMu.Lock()
|
|
|
svr.ctl = ctl
|
|
@@ -207,7 +208,7 @@ func (svr *Service) keepControllerWorking() {
|
|
|
}
|
|
|
|
|
|
xl.Info("try to reconnect to server...")
|
|
|
- conn, session, err := svr.login()
|
|
|
+ conn, cm, err := svr.login()
|
|
|
if err != nil {
|
|
|
xl.Warn("reconnect to server error: %v, wait %v for another retry", err, delayTime)
|
|
|
util.RandomSleep(delayTime, 0.9, 1.1)
|
|
@@ -221,7 +222,7 @@ func (svr *Service) keepControllerWorking() {
|
|
|
// reconnect success, init delayTime
|
|
|
delayTime = time.Second
|
|
|
|
|
|
- ctl := NewControl(svr.ctx, svr.runID, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter)
|
|
|
+ ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort, svr.authSetter)
|
|
|
ctl.Run()
|
|
|
svr.ctlMu.Lock()
|
|
|
if svr.ctl != nil {
|
|
@@ -237,83 +238,23 @@ func (svr *Service) keepControllerWorking() {
|
|
|
// login creates a connection to frps and registers it self as a client
|
|
|
// conn: control connection
|
|
|
// session: if it's not nil, using tcp mux
|
|
|
-func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
|
|
|
+func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) {
|
|
|
xl := xlog.FromContextSafe(svr.ctx)
|
|
|
- var tlsConfig *tls.Config
|
|
|
- if svr.cfg.TLSEnable {
|
|
|
- sn := svr.cfg.TLSServerName
|
|
|
- if sn == "" {
|
|
|
- sn = svr.cfg.ServerAddr
|
|
|
- }
|
|
|
+ cm = NewConnectionManager(svr.ctx, &svr.cfg)
|
|
|
|
|
|
- tlsConfig, err = transport.NewClientTLSConfig(
|
|
|
- svr.cfg.TLSCertFile,
|
|
|
- svr.cfg.TLSKeyFile,
|
|
|
- svr.cfg.TLSTrustedCaFile,
|
|
|
- sn)
|
|
|
- if err != nil {
|
|
|
- xl.Warn("fail to build tls configuration when service login, err: %v", err)
|
|
|
- return
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- proxyType, addr, auth, err := libdial.ParseProxyURL(svr.cfg.HTTPProxy)
|
|
|
- if err != nil {
|
|
|
- xl.Error("fail to parse proxy url")
|
|
|
- return
|
|
|
- }
|
|
|
- dialOptions := []libdial.DialOption{}
|
|
|
- protocol := svr.cfg.Protocol
|
|
|
- if protocol == "websocket" {
|
|
|
- protocol = "tcp"
|
|
|
- dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: frpNet.DialHookWebsocket()}))
|
|
|
- }
|
|
|
- if svr.cfg.ConnectServerLocalIP != "" {
|
|
|
- dialOptions = append(dialOptions, libdial.WithLocalAddr(svr.cfg.ConnectServerLocalIP))
|
|
|
- }
|
|
|
- dialOptions = append(dialOptions,
|
|
|
- libdial.WithProtocol(protocol),
|
|
|
- libdial.WithTimeout(time.Duration(svr.cfg.DialServerTimeout)*time.Second),
|
|
|
- libdial.WithKeepAlive(time.Duration(svr.cfg.DialServerKeepAlive)*time.Second),
|
|
|
- libdial.WithProxy(proxyType, addr),
|
|
|
- libdial.WithProxyAuth(auth),
|
|
|
- libdial.WithTLSConfig(tlsConfig),
|
|
|
- libdial.WithAfterHook(libdial.AfterHook{
|
|
|
- Hook: frpNet.DialHookCustomTLSHeadByte(tlsConfig != nil, svr.cfg.DisableCustomTLSFirstByte),
|
|
|
- }),
|
|
|
- )
|
|
|
- conn, err = libdial.Dial(
|
|
|
- net.JoinHostPort(svr.cfg.ServerAddr, strconv.Itoa(svr.cfg.ServerPort)),
|
|
|
- dialOptions...,
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
+ if err = cm.OpenConnection(); err != nil {
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
|
|
|
defer func() {
|
|
|
if err != nil {
|
|
|
- conn.Close()
|
|
|
- if session != nil {
|
|
|
- session.Close()
|
|
|
- }
|
|
|
+ cm.Close()
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- if svr.cfg.TCPMux {
|
|
|
- fmuxCfg := fmux.DefaultConfig()
|
|
|
- fmuxCfg.KeepAliveInterval = time.Duration(svr.cfg.TCPMuxKeepaliveInterval) * time.Second
|
|
|
- fmuxCfg.LogOutput = io.Discard
|
|
|
- session, err = fmux.Client(conn, fmuxCfg)
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
- stream, errRet := session.OpenStream()
|
|
|
- if errRet != nil {
|
|
|
- session.Close()
|
|
|
- err = errRet
|
|
|
- return
|
|
|
- }
|
|
|
- conn = stream
|
|
|
+ conn, err = cm.Connect()
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
loginMsg := &msg.Login{
|
|
@@ -389,3 +330,155 @@ func (svr *Service) GracefulClose(d time.Duration) {
|
|
|
|
|
|
svr.cancel()
|
|
|
}
|
|
|
+
|
|
|
+type ConnectionManager struct {
|
|
|
+ ctx context.Context
|
|
|
+ cfg *config.ClientCommonConf
|
|
|
+
|
|
|
+ muxSession *fmux.Session
|
|
|
+ quicConn quic.Connection
|
|
|
+}
|
|
|
+
|
|
|
+func NewConnectionManager(ctx context.Context, cfg *config.ClientCommonConf) *ConnectionManager {
|
|
|
+ return &ConnectionManager{
|
|
|
+ ctx: ctx,
|
|
|
+ cfg: cfg,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (cm *ConnectionManager) OpenConnection() error {
|
|
|
+ xl := xlog.FromContextSafe(cm.ctx)
|
|
|
+
|
|
|
+ // special for quic
|
|
|
+ if strings.EqualFold(cm.cfg.Protocol, "quic") {
|
|
|
+ var tlsConfig *tls.Config
|
|
|
+ var err error
|
|
|
+ sn := cm.cfg.TLSServerName
|
|
|
+ if sn == "" {
|
|
|
+ sn = cm.cfg.ServerAddr
|
|
|
+ }
|
|
|
+ if cm.cfg.TLSEnable {
|
|
|
+ tlsConfig, err = transport.NewClientTLSConfig(
|
|
|
+ cm.cfg.TLSCertFile,
|
|
|
+ cm.cfg.TLSKeyFile,
|
|
|
+ cm.cfg.TLSTrustedCaFile,
|
|
|
+ sn)
|
|
|
+ } else {
|
|
|
+ tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ xl.Warn("fail to build tls configuration, err: %v", err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ tlsConfig.NextProtos = []string{"frp"}
|
|
|
+
|
|
|
+ conn, err := quic.DialAddr(
|
|
|
+ net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
|
|
+ tlsConfig, nil)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ cm.quicConn = conn
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if !cm.cfg.TCPMux {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ conn, err := cm.realConnect()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ fmuxCfg := fmux.DefaultConfig()
|
|
|
+ fmuxCfg.KeepAliveInterval = time.Duration(cm.cfg.TCPMuxKeepaliveInterval) * time.Second
|
|
|
+ fmuxCfg.LogOutput = io.Discard
|
|
|
+ session, err := fmux.Client(conn, fmuxCfg)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ cm.muxSession = session
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (cm *ConnectionManager) Connect() (net.Conn, error) {
|
|
|
+ if cm.quicConn != nil {
|
|
|
+ stream, err := cm.quicConn.OpenStreamSync(context.Background())
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return frpNet.QuicStreamToNetConn(stream, cm.quicConn), nil
|
|
|
+ } else if cm.muxSession != nil {
|
|
|
+ stream, err := cm.muxSession.OpenStream()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return stream, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return cm.realConnect()
|
|
|
+}
|
|
|
+
|
|
|
+func (cm *ConnectionManager) realConnect() (net.Conn, error) {
|
|
|
+ xl := xlog.FromContextSafe(cm.ctx)
|
|
|
+ var tlsConfig *tls.Config
|
|
|
+ var err error
|
|
|
+ if cm.cfg.TLSEnable {
|
|
|
+ sn := cm.cfg.TLSServerName
|
|
|
+ if sn == "" {
|
|
|
+ sn = cm.cfg.ServerAddr
|
|
|
+ }
|
|
|
+
|
|
|
+ tlsConfig, err = transport.NewClientTLSConfig(
|
|
|
+ cm.cfg.TLSCertFile,
|
|
|
+ cm.cfg.TLSKeyFile,
|
|
|
+ cm.cfg.TLSTrustedCaFile,
|
|
|
+ sn)
|
|
|
+ if err != nil {
|
|
|
+ xl.Warn("fail to build tls configuration, err: %v", err)
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ proxyType, addr, auth, err := libdial.ParseProxyURL(cm.cfg.HTTPProxy)
|
|
|
+ if err != nil {
|
|
|
+ xl.Error("fail to parse proxy url")
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ dialOptions := []libdial.DialOption{}
|
|
|
+ protocol := cm.cfg.Protocol
|
|
|
+ if protocol == "websocket" {
|
|
|
+ protocol = "tcp"
|
|
|
+ dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: frpNet.DialHookWebsocket()}))
|
|
|
+ }
|
|
|
+ if cm.cfg.ConnectServerLocalIP != "" {
|
|
|
+ dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.ConnectServerLocalIP))
|
|
|
+ }
|
|
|
+ dialOptions = append(dialOptions,
|
|
|
+ libdial.WithProtocol(protocol),
|
|
|
+ libdial.WithTimeout(time.Duration(cm.cfg.DialServerTimeout)*time.Second),
|
|
|
+ libdial.WithKeepAlive(time.Duration(cm.cfg.DialServerKeepAlive)*time.Second),
|
|
|
+ libdial.WithProxy(proxyType, addr),
|
|
|
+ libdial.WithProxyAuth(auth),
|
|
|
+ libdial.WithTLSConfig(tlsConfig),
|
|
|
+ libdial.WithAfterHook(libdial.AfterHook{
|
|
|
+ Hook: frpNet.DialHookCustomTLSHeadByte(tlsConfig != nil, cm.cfg.DisableCustomTLSFirstByte),
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ conn, err := libdial.Dial(
|
|
|
+ net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
|
|
|
+ dialOptions...,
|
|
|
+ )
|
|
|
+ return conn, err
|
|
|
+}
|
|
|
+
|
|
|
+func (cm *ConnectionManager) Close() error {
|
|
|
+ if cm.quicConn != nil {
|
|
|
+ _ = cm.quicConn.CloseWithError(0, "")
|
|
|
+ }
|
|
|
+ if cm.muxSession != nil {
|
|
|
+ _ = cm.muxSession.Close()
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|