Browse Source

refactor: reorganize security policy into dedicated packag

fatedier 2 days ago
parent
commit
33ab7eedd6

+ 3 - 2
client/service.go

@@ -31,6 +31,7 @@ import (
 	"github.com/fatedier/frp/pkg/auth"
 	v1 "github.com/fatedier/frp/pkg/config/v1"
 	"github.com/fatedier/frp/pkg/msg"
+	"github.com/fatedier/frp/pkg/policy/security"
 	httppkg "github.com/fatedier/frp/pkg/util/http"
 	"github.com/fatedier/frp/pkg/util/log"
 	netpkg "github.com/fatedier/frp/pkg/util/net"
@@ -64,7 +65,7 @@ type ServiceOptions struct {
 	ProxyCfgs   []v1.ProxyConfigurer
 	VisitorCfgs []v1.VisitorConfigurer
 
-	UnsafeFeatures v1.UnsafeFeatures
+	UnsafeFeatures *security.UnsafeFeatures
 
 	// ConfigFilePath is the path to the configuration file used to initialize.
 	// If it is empty, it means that the configuration file is not used for initialization.
@@ -124,7 +125,7 @@ type Service struct {
 	visitorCfgs []v1.VisitorConfigurer
 	clientSpec  *msg.ClientSpec
 
-	unsafeFeatures v1.UnsafeFeatures
+	unsafeFeatures *security.UnsafeFeatures
 
 	// The configuration file used to initialize this client, or an empty
 	// string if no configuration file was used.

+ 3 - 2
cmd/frpc/sub/proxy.go

@@ -24,6 +24,7 @@ import (
 	"github.com/fatedier/frp/pkg/config"
 	v1 "github.com/fatedier/frp/pkg/config/v1"
 	"github.com/fatedier/frp/pkg/config/v1/validation"
+	"github.com/fatedier/frp/pkg/policy/security"
 )
 
 var proxyTypes = []v1.ProxyType{
@@ -78,7 +79,7 @@ func NewProxyCommand(name string, c v1.ProxyConfigurer, clientCfg *v1.ClientComm
 				os.Exit(1)
 			}
 
-			unsafeFeatures := v1.NewUnsafeFeatures(allowUnsafe)
+			unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
 			if _, err := validation.ValidateClientCommonConfig(clientCfg, unsafeFeatures); err != nil {
 				fmt.Println(err)
 				os.Exit(1)
@@ -108,7 +109,7 @@ func NewVisitorCommand(name string, c v1.VisitorConfigurer, clientCfg *v1.Client
 				fmt.Println(err)
 				os.Exit(1)
 			}
-			unsafeFeatures := v1.NewUnsafeFeatures(allowUnsafe)
+			unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
 			if _, err := validation.ValidateClientCommonConfig(clientCfg, unsafeFeatures); err != nil {
 				fmt.Println(err)
 				os.Exit(1)

+ 10 - 6
cmd/frpc/sub/root.go

@@ -21,6 +21,7 @@ import (
 	"os"
 	"os/signal"
 	"path/filepath"
+	"strings"
 	"sync"
 	"syscall"
 	"time"
@@ -31,7 +32,8 @@ import (
 	"github.com/fatedier/frp/pkg/config"
 	v1 "github.com/fatedier/frp/pkg/config/v1"
 	"github.com/fatedier/frp/pkg/config/v1/validation"
-	"github.com/fatedier/frp/pkg/featuregate"
+	"github.com/fatedier/frp/pkg/policy/featuregate"
+	"github.com/fatedier/frp/pkg/policy/security"
 	"github.com/fatedier/frp/pkg/util/log"
 	"github.com/fatedier/frp/pkg/util/version"
 )
@@ -49,7 +51,9 @@ func init() {
 	rootCmd.PersistentFlags().StringVarP(&cfgDir, "config_dir", "", "", "config directory, run one frpc service for each file in config directory")
 	rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frpc")
 	rootCmd.PersistentFlags().BoolVarP(&strictConfigMode, "strict_config", "", true, "strict config parsing mode, unknown fields will cause an errors")
-	rootCmd.PersistentFlags().StringSliceVarP(&allowUnsafe, "allow-unsafe", "", []string{}, "allowed unsafe features, one or more of: TokenSourceExec")
+
+	rootCmd.PersistentFlags().StringSliceVarP(&allowUnsafe, "allow-unsafe", "", []string{},
+		fmt.Sprintf("allowed unsafe features, one or more of: %s", strings.Join(security.ClientUnsafeFeatures, ", ")))
 }
 
 var rootCmd = &cobra.Command{
@@ -61,7 +65,7 @@ var rootCmd = &cobra.Command{
 			return nil
 		}
 
-		unsafeFeatures := v1.NewUnsafeFeatures(allowUnsafe)
+		unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
 
 		// If cfgDir is not empty, run multiple frpc service for each config file in cfgDir.
 		// Note that it's only designed for testing. It's not guaranteed to be stable.
@@ -80,7 +84,7 @@ var rootCmd = &cobra.Command{
 	},
 }
 
-func runMultipleClients(cfgDir string, unsafeFeatures v1.UnsafeFeatures) error {
+func runMultipleClients(cfgDir string, unsafeFeatures *security.UnsafeFeatures) error {
 	var wg sync.WaitGroup
 	err := filepath.WalkDir(cfgDir, func(path string, d fs.DirEntry, err error) error {
 		if err != nil || d.IsDir() {
@@ -115,7 +119,7 @@ func handleTermSignal(svr *client.Service) {
 	svr.GracefulClose(500 * time.Millisecond)
 }
 
-func runClient(cfgFilePath string, unsafeFeatures v1.UnsafeFeatures) error {
+func runClient(cfgFilePath string, unsafeFeatures *security.UnsafeFeatures) error {
 	cfg, proxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath, strictConfigMode)
 	if err != nil {
 		return err
@@ -145,7 +149,7 @@ func startService(
 	cfg *v1.ClientCommonConfig,
 	proxyCfgs []v1.ProxyConfigurer,
 	visitorCfgs []v1.VisitorConfigurer,
-	unsafeFeatures v1.UnsafeFeatures,
+	unsafeFeatures *security.UnsafeFeatures,
 	cfgFile string,
 ) error {
 	log.InitLogger(cfg.Log.To, cfg.Log.Level, int(cfg.Log.MaxDays), cfg.Log.DisablePrintColor)

+ 2 - 2
cmd/frpc/sub/verify.go

@@ -21,8 +21,8 @@ import (
 	"github.com/spf13/cobra"
 
 	"github.com/fatedier/frp/pkg/config"
-	v1 "github.com/fatedier/frp/pkg/config/v1"
 	"github.com/fatedier/frp/pkg/config/v1/validation"
+	"github.com/fatedier/frp/pkg/policy/security"
 )
 
 func init() {
@@ -43,7 +43,7 @@ var verifyCmd = &cobra.Command{
 			fmt.Println(err)
 			os.Exit(1)
 		}
-		unsafeFeatures := v1.NewUnsafeFeatures(allowUnsafe)
+		unsafeFeatures := security.NewUnsafeFeatures(allowUnsafe)
 		warning, err := validation.ValidateAllClientConfig(cliCfg, proxyCfgs, visitorCfgs, unsafeFeatures)
 		if warning != nil {
 			fmt.Printf("WARNING: %v\n", warning)

+ 0 - 20
pkg/config/v1/client.go

@@ -248,23 +248,3 @@ type AuthOIDCClientConfig struct {
 type VirtualNetConfig struct {
 	Address string `json:"address,omitempty"`
 }
-
-const (
-	UnsafeFeatureTokenSourceExec = "TokenSourceExec"
-)
-
-type UnsafeFeatures struct {
-	features map[string]bool
-}
-
-func NewUnsafeFeatures(allowed []string) UnsafeFeatures {
-	features := make(map[string]bool)
-	for _, f := range allowed {
-		features[f] = true
-	}
-	return UnsafeFeatures{features: features}
-}
-
-func (u UnsafeFeatures) IsEnabled(feature string) bool {
-	return u.features[feature]
-}

+ 81 - 36
pkg/config/v1/validation/client.go

@@ -23,70 +23,111 @@ import (
 	"github.com/samber/lo"
 
 	v1 "github.com/fatedier/frp/pkg/config/v1"
-	"github.com/fatedier/frp/pkg/featuregate"
+	"github.com/fatedier/frp/pkg/policy/featuregate"
+	"github.com/fatedier/frp/pkg/policy/security"
 )
 
-func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures v1.UnsafeFeatures) (Warning, error) {
+func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
 	var (
 		warnings Warning
 		errs     error
 	)
-	// validate feature gates
+
+	validators := []func() (Warning, error){
+		func() (Warning, error) { return validateFeatureGates(c) },
+		func() (Warning, error) { return validateAuthConfig(&c.Auth, unsafeFeatures) },
+		func() (Warning, error) { return nil, validateLogConfig(&c.Log) },
+		func() (Warning, error) { return nil, validateWebServerConfig(&c.WebServer) },
+		func() (Warning, error) { return validateTransportConfig(&c.Transport) },
+		func() (Warning, error) { return validateIncludeFiles(c.IncludeConfigFiles) },
+	}
+
+	for _, v := range validators {
+		w, err := v()
+		warnings = AppendError(warnings, w)
+		errs = AppendError(errs, err)
+	}
+	return warnings, errs
+}
+
+func validateFeatureGates(c *v1.ClientCommonConfig) (Warning, error) {
 	if c.VirtualNet.Address != "" {
 		if !featuregate.Enabled(featuregate.VirtualNet) {
-			return warnings, fmt.Errorf("VirtualNet feature is not enabled; enable it by setting the appropriate feature gate flag")
+			return nil, fmt.Errorf("VirtualNet feature is not enabled; enable it by setting the appropriate feature gate flag")
 		}
 	}
+	return nil, nil
+}
 
-	if !slices.Contains(SupportedAuthMethods, c.Auth.Method) {
+func validateAuthConfig(c *v1.AuthClientConfig, unsafeFeatures *security.UnsafeFeatures) (Warning, error) {
+	var errs error
+	if !slices.Contains(SupportedAuthMethods, c.Method) {
 		errs = AppendError(errs, fmt.Errorf("invalid auth method, optional values are %v", SupportedAuthMethods))
 	}
-	if !lo.Every(SupportedAuthAdditionalScopes, c.Auth.AdditionalScopes) {
+	if !lo.Every(SupportedAuthAdditionalScopes, c.AdditionalScopes) {
 		errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
 	}
 
 	// Validate token/tokenSource mutual exclusivity
-	if c.Auth.Token != "" && c.Auth.TokenSource != nil {
+	if c.Token != "" && c.TokenSource != nil {
 		errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
 	}
 
 	// Validate tokenSource if specified
-	if c.Auth.TokenSource != nil {
-		if c.Auth.TokenSource.Type == "exec" && !unsafeFeatures.IsEnabled(v1.UnsafeFeatureTokenSourceExec) {
-			errs = AppendError(errs, fmt.Errorf("unsafe 'exec' not allowed for auth.tokenSource.type"))
+	if c.TokenSource != nil {
+		if c.TokenSource.Type == "exec" {
+			if !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
+				errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+
+					"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
+			}
 		}
-		if err := c.Auth.TokenSource.Validate(); err != nil {
+		if err := c.TokenSource.Validate(); err != nil {
 			errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
 		}
 	}
 
-	if c.Auth.OIDC.TokenSource != nil {
-		// Validate oidc.tokenSource mutual exclusivity with other fields of oidc
-		if c.Auth.OIDC.ClientID != "" || c.Auth.OIDC.ClientSecret != "" || c.Auth.OIDC.Audience != "" ||
-			c.Auth.OIDC.Scope != "" || c.Auth.OIDC.TokenEndpointURL != "" || len(c.Auth.OIDC.AdditionalEndpointParams) > 0 ||
-			c.Auth.OIDC.TrustedCaFile != "" || c.Auth.OIDC.InsecureSkipVerify || c.Auth.OIDC.ProxyURL != "" {
-			errs = AppendError(errs, fmt.Errorf("cannot specify both auth.oidc.tokenSource and any other field of auth.oidc"))
-		}
-		if c.Auth.OIDC.TokenSource.Type == "exec" && !unsafeFeatures.IsEnabled(v1.UnsafeFeatureTokenSourceExec) {
-			errs = AppendError(errs, fmt.Errorf("unsafe 'exec' not allowed for auth.oidc.tokenSource.type"))
-		}
-	}
-
-	if err := validateLogConfig(&c.Log); err != nil {
+	if err := validateOIDCConfig(&c.OIDC, unsafeFeatures); err != nil {
 		errs = AppendError(errs, err)
 	}
+	return nil, errs
+}
 
-	if err := validateWebServerConfig(&c.WebServer); err != nil {
-		errs = AppendError(errs, err)
+func validateOIDCConfig(c *v1.AuthOIDCClientConfig, unsafeFeatures *security.UnsafeFeatures) error {
+	if c.TokenSource == nil {
+		return nil
+	}
+	var errs error
+	// Validate oidc.tokenSource mutual exclusivity with other fields of oidc
+	if c.ClientID != "" || c.ClientSecret != "" || c.Audience != "" ||
+		c.Scope != "" || c.TokenEndpointURL != "" || len(c.AdditionalEndpointParams) > 0 ||
+		c.TrustedCaFile != "" || c.InsecureSkipVerify || c.ProxyURL != "" {
+		errs = AppendError(errs, fmt.Errorf("cannot specify both auth.oidc.tokenSource and any other field of auth.oidc"))
+	}
+	if c.TokenSource.Type == "exec" {
+		if !unsafeFeatures.IsEnabled(security.TokenSourceExec) {
+			errs = AppendError(errs, fmt.Errorf("unsafe feature %q is not enabled. "+
+				"To enable it, start frpc with '--allow-unsafe %s'", security.TokenSourceExec, security.TokenSourceExec))
+		}
 	}
+	if err := c.TokenSource.Validate(); err != nil {
+		errs = AppendError(errs, fmt.Errorf("invalid auth.oidc.tokenSource: %v", err))
+	}
+	return errs
+}
+
+func validateTransportConfig(c *v1.ClientTransportConfig) (Warning, error) {
+	var (
+		warnings Warning
+		errs     error
+	)
 
-	if c.Transport.HeartbeatTimeout > 0 && c.Transport.HeartbeatInterval > 0 {
-		if c.Transport.HeartbeatTimeout < c.Transport.HeartbeatInterval {
+	if c.HeartbeatTimeout > 0 && c.HeartbeatInterval > 0 {
+		if c.HeartbeatTimeout < c.HeartbeatInterval {
 			errs = AppendError(errs, fmt.Errorf("invalid transport.heartbeatTimeout, heartbeat timeout should not less than heartbeat interval"))
 		}
 	}
 
-	if !lo.FromPtr(c.Transport.TLS.Enable) {
+	if !lo.FromPtr(c.TLS.Enable) {
 		checkTLSConfig := func(name string, value string) Warning {
 			if value != "" {
 				return fmt.Errorf("%s is invalid when transport.tls.enable is false", name)
@@ -94,16 +135,20 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures v1.Unsa
 			return nil
 		}
 
-		warnings = AppendError(warnings, checkTLSConfig("transport.tls.certFile", c.Transport.TLS.CertFile))
-		warnings = AppendError(warnings, checkTLSConfig("transport.tls.keyFile", c.Transport.TLS.KeyFile))
-		warnings = AppendError(warnings, checkTLSConfig("transport.tls.trustedCaFile", c.Transport.TLS.TrustedCaFile))
+		warnings = AppendError(warnings, checkTLSConfig("transport.tls.certFile", c.TLS.CertFile))
+		warnings = AppendError(warnings, checkTLSConfig("transport.tls.keyFile", c.TLS.KeyFile))
+		warnings = AppendError(warnings, checkTLSConfig("transport.tls.trustedCaFile", c.TLS.TrustedCaFile))
 	}
 
-	if !slices.Contains(SupportedTransportProtocols, c.Transport.Protocol) {
+	if !slices.Contains(SupportedTransportProtocols, c.Protocol) {
 		errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", SupportedTransportProtocols))
 	}
+	return warnings, errs
+}
 
-	for _, f := range c.IncludeConfigFiles {
+func validateIncludeFiles(files []string) (Warning, error) {
+	var errs error
+	for _, f := range files {
 		absDir, err := filepath.Abs(filepath.Dir(f))
 		if err != nil {
 			errs = AppendError(errs, fmt.Errorf("include: parse directory of %s failed: %v", f, err))
@@ -113,14 +158,14 @@ func ValidateClientCommonConfig(c *v1.ClientCommonConfig, unsafeFeatures v1.Unsa
 			errs = AppendError(errs, fmt.Errorf("include: directory of %s not exist", f))
 		}
 	}
-	return warnings, errs
+	return nil, errs
 }
 
 func ValidateAllClientConfig(
 	c *v1.ClientCommonConfig,
 	proxyCfgs []v1.ProxyConfigurer,
 	visitorCfgs []v1.VisitorConfigurer,
-	unsafeFeatures v1.UnsafeFeatures,
+	unsafeFeatures *security.UnsafeFeatures,
 ) (Warning, error) {
 	var warnings Warning
 	if c != nil {

+ 0 - 0
pkg/featuregate/feature_gate.go → pkg/policy/featuregate/feature_gate.go


+ 34 - 0
pkg/policy/security/unsafe.go

@@ -0,0 +1,34 @@
+package security
+
+const (
+	TokenSourceExec = "TokenSourceExec"
+)
+
+var (
+	ClientUnsafeFeatures = []string{
+		TokenSourceExec,
+	}
+
+	ServerUnsafeFeatures = []string{
+		TokenSourceExec,
+	}
+)
+
+type UnsafeFeatures struct {
+	features map[string]bool
+}
+
+func NewUnsafeFeatures(allowed []string) *UnsafeFeatures {
+	features := make(map[string]bool)
+	for _, f := range allowed {
+		features[f] = true
+	}
+	return &UnsafeFeatures{features: features}
+}
+
+func (u *UnsafeFeatures) IsEnabled(feature string) bool {
+	if u == nil {
+		return false
+	}
+	return u.features[feature]
+}