|
@@ -86,16 +86,14 @@ func NewService(
|
|
|
visitorCfgs map[string]config.VisitorConf,
|
|
|
cfgFile string,
|
|
|
) (svr *Service, err error) {
|
|
|
- ctx, cancel := context.WithCancel(context.Background())
|
|
|
svr = &Service{
|
|
|
authSetter: auth.NewAuthSetter(cfg.ClientConfig),
|
|
|
cfg: cfg,
|
|
|
cfgFile: cfgFile,
|
|
|
pxyCfgs: pxyCfgs,
|
|
|
visitorCfgs: visitorCfgs,
|
|
|
+ ctx: context.Background(),
|
|
|
exit: 0,
|
|
|
- ctx: xlog.NewContext(ctx, xlog.New()),
|
|
|
- cancel: cancel,
|
|
|
}
|
|
|
return
|
|
|
}
|
|
@@ -106,7 +104,11 @@ func (svr *Service) GetController() *Control {
|
|
|
return svr.ctl
|
|
|
}
|
|
|
|
|
|
-func (svr *Service) Run() error {
|
|
|
+func (svr *Service) Run(ctx context.Context) error {
|
|
|
+ ctx, cancel := context.WithCancel(ctx)
|
|
|
+ svr.ctx = xlog.NewContext(ctx, xlog.New())
|
|
|
+ svr.cancel = cancel
|
|
|
+
|
|
|
xl := xlog.FromContextSafe(svr.ctx)
|
|
|
|
|
|
// set custom DNSServer
|
|
@@ -135,7 +137,7 @@ func (svr *Service) Run() error {
|
|
|
if svr.cfg.LoginFailExit {
|
|
|
return err
|
|
|
}
|
|
|
- util.RandomSleep(10*time.Second, 0.9, 1.1)
|
|
|
+ util.RandomSleep(5*time.Second, 0.9, 1.1)
|
|
|
} else {
|
|
|
// login success
|
|
|
ctl := NewControl(svr.ctx, svr.runID, conn, cm, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter)
|
|
@@ -161,6 +163,10 @@ func (svr *Service) Run() error {
|
|
|
log.Info("admin server listen on %s:%d", svr.cfg.AdminAddr, svr.cfg.AdminPort)
|
|
|
}
|
|
|
<-svr.ctx.Done()
|
|
|
+ // service context may not be canceled by svr.Close(), we should call it here to release resources
|
|
|
+ if atomic.LoadUint32(&svr.exit) == 0 {
|
|
|
+ svr.Close()
|
|
|
+ }
|
|
|
return nil
|
|
|
}
|
|
|
|
|
@@ -182,7 +188,7 @@ func (svr *Service) keepControllerWorking() {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // the first three retry with no delay
|
|
|
+ // the first three attempts with a low delay
|
|
|
if reconnectCounts > 3 {
|
|
|
util.RandomSleep(reconnectDelay, 0.9, 1.1)
|
|
|
xl.Info("wait %v to reconnect", reconnectDelay)
|
|
@@ -322,10 +328,13 @@ func (svr *Service) GracefulClose(d time.Duration) {
|
|
|
svr.ctlMu.RLock()
|
|
|
if svr.ctl != nil {
|
|
|
svr.ctl.GracefulClose(d)
|
|
|
+ svr.ctl = nil
|
|
|
}
|
|
|
svr.ctlMu.RUnlock()
|
|
|
|
|
|
- svr.cancel()
|
|
|
+ if svr.cancel != nil {
|
|
|
+ svr.cancel()
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
type ConnectionManager struct {
|
|
@@ -427,7 +436,11 @@ func (cm *ConnectionManager) realConnect() (net.Conn, error) {
|
|
|
xl := xlog.FromContextSafe(cm.ctx)
|
|
|
var tlsConfig *tls.Config
|
|
|
var err error
|
|
|
- if cm.cfg.TLSEnable {
|
|
|
+ tlsEnable := cm.cfg.TLSEnable
|
|
|
+ if cm.cfg.Protocol == "wss" {
|
|
|
+ tlsEnable = true
|
|
|
+ }
|
|
|
+ if tlsEnable {
|
|
|
sn := cm.cfg.TLSServerName
|
|
|
if sn == "" {
|
|
|
sn = cm.cfg.ServerAddr
|
|
@@ -451,10 +464,23 @@ func (cm *ConnectionManager) realConnect() (net.Conn, error) {
|
|
|
}
|
|
|
dialOptions := []libdial.DialOption{}
|
|
|
protocol := cm.cfg.Protocol
|
|
|
- if protocol == "websocket" {
|
|
|
+ switch protocol {
|
|
|
+ case "websocket":
|
|
|
protocol = "tcp"
|
|
|
- dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket()}))
|
|
|
+ dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")}))
|
|
|
+ dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{
|
|
|
+ Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, cm.cfg.DisableCustomTLSFirstByte),
|
|
|
+ }))
|
|
|
+ dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
|
|
+ case "wss":
|
|
|
+ protocol = "tcp"
|
|
|
+ dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig))
|
|
|
+ // Make sure that if it is wss, the websocket hook is executed after the tls hook.
|
|
|
+ dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110}))
|
|
|
+ default:
|
|
|
+ dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig))
|
|
|
}
|
|
|
+
|
|
|
if cm.cfg.ConnectServerLocalIP != "" {
|
|
|
dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.ConnectServerLocalIP))
|
|
|
}
|
|
@@ -464,10 +490,6 @@ func (cm *ConnectionManager) realConnect() (net.Conn, error) {
|
|
|
libdial.WithKeepAlive(time.Duration(cm.cfg.DialServerKeepAlive)*time.Second),
|
|
|
libdial.WithProxy(proxyType, addr),
|
|
|
libdial.WithProxyAuth(auth),
|
|
|
- libdial.WithTLSConfig(tlsConfig),
|
|
|
- libdial.WithAfterHook(libdial.AfterHook{
|
|
|
- Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, cm.cfg.DisableCustomTLSFirstByte),
|
|
|
- }),
|
|
|
)
|
|
|
conn, err := libdial.DialContext(
|
|
|
cm.ctx,
|