Browse Source

Add tls configuration to both client and server (#1974)

yuyulei 4 years ago
parent
commit
4fff3c7472
6 changed files with 247 additions and 34 deletions
  1. 10 2
      client/control.go
  2. 9 2
      client/service.go
  3. 43 1
      models/config/client_common.go
  4. 39 2
      models/config/server_common.go
  5. 136 0
      models/transport/tls.go
  6. 10 27
      server/service.go

+ 10 - 2
client/control.go

@@ -28,6 +28,7 @@ import (
 	"github.com/fatedier/frp/models/auth"
 	"github.com/fatedier/frp/models/config"
 	"github.com/fatedier/frp/models/msg"
+	"github.com/fatedier/frp/models/transport"
 	frpNet "github.com/fatedier/frp/utils/net"
 	"github.com/fatedier/frp/utils/xlog"
 
@@ -208,9 +209,16 @@ func (ctl *Control) connectServer() (conn net.Conn, err error) {
 		conn = stream
 	} else {
 		var tlsConfig *tls.Config
+
 		if ctl.clientCfg.TLSEnable {
-			tlsConfig = &tls.Config{
-				InsecureSkipVerify: true,
+			tlsConfig, err = transport.NewServerTLSConfig(
+				ctl.clientCfg.TLSCertFile,
+				ctl.clientCfg.TLSKeyFile,
+				ctl.clientCfg.TLSTrustedCaFile)
+
+			if err != nil {
+				xl.Warn("fail to build tls configuration when connecting to server, err: %v", err)
+				return
 			}
 		}
 		conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HTTPProxy, ctl.clientCfg.Protocol,

+ 9 - 2
client/service.go

@@ -18,6 +18,7 @@ import (
 	"context"
 	"crypto/tls"
 	"fmt"
+	"github.com/fatedier/frp/models/transport"
 	"io/ioutil"
 	"net"
 	"runtime"
@@ -204,8 +205,14 @@ func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) {
 	xl := xlog.FromContextSafe(svr.ctx)
 	var tlsConfig *tls.Config
 	if svr.cfg.TLSEnable {
-		tlsConfig = &tls.Config{
-			InsecureSkipVerify: true,
+		tlsConfig, err = transport.NewClientTLSConfig(
+			svr.cfg.TLSCertFile,
+			svr.cfg.TLSKeyFile,
+			svr.cfg.TLSTrustedCaFile,
+			svr.cfg.ServerAddr)
+		if err != nil {
+			xl.Warn("fail to build tls configuration when service login, err: %v", err)
+			return
 		}
 	}
 	conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HTTPProxy, svr.cfg.Protocol,

+ 43 - 1
models/config/client_common.go

@@ -104,8 +104,20 @@ type ClientCommonConf struct {
 	// is "tcp".
 	Protocol string `json:"protocol"`
 	// TLSEnable specifies whether or not TLS should be used when communicating
-	// with the server.
+	// with the server. If "tls_cert_file" and "tls_key_file" are valid,
+	// client will load the supplied tls configuration.
 	TLSEnable bool `json:"tls_enable"`
+	// ClientTLSCertPath specifies the path of the cert file that client will
+	// load. It only works when "tls_enable" is true and "tls_key_file" is valid.
+	TLSCertFile string `json:"tls_cert_file"`
+	// ClientTLSKeyPath specifies the path of the secret key file that client
+	// will load. It only works when "tls_enable" is true and "tls_cert_file"
+	// are valid.
+	TLSKeyFile string `json:"tls_key_file"`
+	// TrustedCaFile specifies the path of the trusted ca file that will load.
+	// It only works when "tls_enable" is valid and tls configuration of server
+	// has been specified.
+	TLSTrustedCaFile string `json:"tls_trusted_ca_file"`
 	// HeartBeatInterval specifies at what interval heartbeats are sent to the
 	// server, in seconds. It is not recommended to change this value. By
 	// default, this value is 30.
@@ -145,6 +157,9 @@ func GetDefaultClientConf() ClientCommonConf {
 		Start:             make(map[string]struct{}),
 		Protocol:          "tcp",
 		TLSEnable:         false,
+		TLSCertFile:       "",
+		TLSKeyFile:        "",
+		TLSTrustedCaFile:  "",
 		HeartBeatInterval: 30,
 		HeartBeatTimeout:  90,
 		Metas:             make(map[string]string),
@@ -280,6 +295,18 @@ func UnmarshalClientConfFromIni(content string) (cfg ClientCommonConf, err error
 		cfg.TLSEnable = false
 	}
 
+	if tmpStr, ok = conf.Get("common", "tls_cert_file"); ok {
+		cfg.TLSCertFile = tmpStr
+	}
+
+	if tmpStr, ok := conf.Get("common", "tls_key_file"); ok {
+		cfg.TLSKeyFile = tmpStr
+	}
+
+	if tmpStr, ok := conf.Get("common", "tls_trusted_ca_file"); ok {
+		cfg.TLSTrustedCaFile = tmpStr
+	}
+
 	if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok {
 		if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil {
 			err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout")
@@ -320,5 +347,20 @@ func (cfg *ClientCommonConf) Check() (err error) {
 		err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval")
 		return
 	}
+
+	if cfg.TLSEnable == false {
+		if cfg.TLSCertFile != "" {
+			fmt.Println("WARNING! Because tls_enable is not true, so tls_cert_file will not make sense")
+		}
+
+		if cfg.TLSKeyFile != "" {
+			fmt.Println("WARNING! Because tls_enable is not true, so tls_key_file will not make sense")
+		}
+
+		if cfg.TLSTrustedCaFile != "" {
+			fmt.Println("WARNING! Because tls_enable is not true, so tls_trusted_ca_file will not make sense")
+		}
+	}
+
 	return
 }

+ 39 - 2
models/config/server_common.go

@@ -133,9 +133,24 @@ type ServerCommonConf struct {
 	// may proxy to. If this value is 0, no limit will be applied. By default,
 	// this value is 0.
 	MaxPortsPerClient int64 `json:"max_ports_per_client"`
-	// TLSOnly specifies whether to only accept TLS-encrypted connections. By
-	// default, the value is false.
+	// TLSOnly specifies whether to only accept TLS-encrypted connections.
+	// By default, the value is false.
 	TLSOnly bool `json:"tls_only"`
+	// TLSCertFile specifies the path of the cert file that the server will
+	// load. If "tls_cert_file", "tls_key_file" are valid, the server will use this
+	// supplied tls configuration. Otherwise, the server will use the tls
+	// configuration generated by itself.
+	TLSCertFile string `json:"tls_cert_file"`
+	// TLSKeyFile specifies the path of the secret key that the server will
+	// load. If "tls_cert_file", "tls_key_file" are valid, the server will use this
+	// supplied tls configuration. Otherwise, the server will use the tls
+	// configuration generated by itself.
+	TLSKeyFile string `json:"tls_key_file"`
+	// TLSTrustedCaFile specifies the paths of the client cert files that the
+	// server will load. It only works when "tls_only" is true. If
+	// "tls_trusted_ca_file" is valid, the server will verify each client's
+	// certificate.
+	TLSTrustedCaFile string `json:"tls_trusted_ca_file"`
 	// HeartBeatTimeout specifies the maximum time to wait for a heartbeat
 	// before terminating the connection. It is not recommended to change this
 	// value. By default, this value is 90.
@@ -181,6 +196,9 @@ func GetDefaultServerConf() ServerCommonConf {
 		MaxPoolCount:           5,
 		MaxPortsPerClient:      0,
 		TLSOnly:                false,
+		TLSCertFile:            "",
+		TLSKeyFile:             "",
+		TLSTrustedCaFile:       "",
 		HeartBeatTimeout:       90,
 		UserConnTimeout:        10,
 		Custom404Page:          "",
@@ -419,6 +437,19 @@ func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error
 		}
 		cfg.UDPPacketSize = v
 	}
+
+	if tmpStr, ok := conf.Get("common", "tls_cert_file"); ok {
+		cfg.TLSCertFile = tmpStr
+	}
+
+	if tmpStr, ok := conf.Get("common", "tls_key_file"); ok {
+		cfg.TLSKeyFile = tmpStr
+	}
+
+	if tmpStr, ok := conf.Get("common", "tls_trusted_ca_file"); ok {
+		cfg.TLSTrustedCaFile = tmpStr
+	}
+
 	return
 }
 
@@ -441,5 +472,11 @@ func UnmarshalPluginsFromIni(sections ini.File, cfg *ServerCommonConf) {
 }
 
 func (cfg *ServerCommonConf) Check() (err error) {
+	if cfg.TLSOnly == false {
+		if cfg.TLSTrustedCaFile != "" {
+			err = fmt.Errorf("Parse conf error: forbidden tls_trusted_ca_file, it only works when tls_only is true")
+			return
+		}
+	}
 	return
 }

+ 136 - 0
models/transport/tls.go

@@ -0,0 +1,136 @@
+package transport
+
+import (
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/tls"
+	"crypto/x509"
+	"encoding/pem"
+	"io/ioutil"
+	"math/big"
+)
+
+/*
+	Example for self-signed certificates by openssl:
+
+	Self CA:
+	openssl genrsa -out ca.key 2048
+	openssl req -x509 -new -nodes -key ca.key -subj "/CN=example.ca.com" -days 5000 -out ca.crt
+
+	Server:
+	openssl genrsa -out server.key 2048
+	openssl req -new -key server.key -subj "/CN=example.server.com" -out server.csr
+	openssl x509 -req -in server.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out server.crt -days 5000
+
+	Client:
+	openssl genrsa -out client.key 2048
+	openssl req -new -key client.key -subj "/CN=example.client.com" -out client.csr
+	openssl x509 -req -in client.csr -CA ca.crt -CAkey ca.key -CAcreateserial -out client.crt -days 5000
+
+*/
+
+func newCustomTLSKeyPair(certfile, keyfile string) (*tls.Certificate, error) {
+	tlsCert, err := tls.LoadX509KeyPair(certfile, keyfile)
+	if err != nil {
+		return nil, err
+	}
+	return &tlsCert, nil
+}
+
+func newRandomTLSKeyPair() *tls.Certificate {
+	key, err := rsa.GenerateKey(rand.Reader, 1024)
+	if err != nil {
+		panic(err)
+	}
+	template := x509.Certificate{SerialNumber: big.NewInt(1)}
+	certDER, err := x509.CreateCertificate(
+		rand.Reader,
+		&template,
+		&template,
+		&key.PublicKey,
+		key)
+	if err != nil {
+		panic(err)
+	}
+	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
+	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
+
+	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
+	if err != nil {
+		panic(err)
+	}
+	return &tlsCert
+}
+
+// Only supprt one ca file to add
+func newCertPool(caPath string) (*x509.CertPool, error) {
+	pool := x509.NewCertPool()
+
+	caCrt, err := ioutil.ReadFile(caPath)
+	if err != nil {
+		return nil, err
+	}
+
+	pool.AppendCertsFromPEM(caCrt)
+
+	return pool, nil
+}
+
+func NewServerTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) {
+	var base = &tls.Config{}
+
+	if certPath == "" || keyPath == "" {
+		// server will generate tls conf by itself
+		cert := newRandomTLSKeyPair()
+		base.Certificates = []tls.Certificate{*cert}
+	} else {
+		cert, err := newCustomTLSKeyPair(certPath, keyPath)
+		if err != nil {
+			return nil, err
+		}
+
+		base.Certificates = []tls.Certificate{*cert}
+	}
+
+	if caPath != "" {
+		pool, err := newCertPool(caPath)
+		if err != nil {
+			return nil, err
+		}
+
+		base.ClientAuth = tls.RequireAndVerifyClientCert
+		base.ClientCAs = pool
+	}
+
+	return base, nil
+}
+
+func NewClientTLSConfig(certPath, keyPath, caPath, servearName string) (*tls.Config, error) {
+	var base = &tls.Config{}
+
+	if certPath == "" || keyPath == "" {
+		// client will not generate tls conf by itself
+	} else {
+		cert, err := newCustomTLSKeyPair(certPath, keyPath)
+		if err != nil {
+			return nil, err
+		}
+
+		base.Certificates = []tls.Certificate{*cert}
+	}
+
+	if caPath != "" {
+		pool, err := newCertPool(caPath)
+		if err != nil {
+			return nil, err
+		}
+
+		base.RootCAs = pool
+		base.ServerName = servearName
+		base.InsecureSkipVerify = false
+	} else {
+		base.InsecureSkipVerify = true
+	}
+
+	return base, nil
+}

+ 10 - 27
server/service.go

@@ -17,14 +17,9 @@ package server
 import (
 	"bytes"
 	"context"
-	"crypto/rand"
-	"crypto/rsa"
 	"crypto/tls"
-	"crypto/x509"
-	"encoding/pem"
 	"fmt"
 	"io/ioutil"
-	"math/big"
 	"net"
 	"net/http"
 	"sort"
@@ -37,6 +32,7 @@ import (
 	"github.com/fatedier/frp/models/msg"
 	"github.com/fatedier/frp/models/nathole"
 	plugin "github.com/fatedier/frp/models/plugin/server"
+	"github.com/fatedier/frp/models/transport"
 	"github.com/fatedier/frp/server/controller"
 	"github.com/fatedier/frp/server/group"
 	"github.com/fatedier/frp/server/metrics"
@@ -101,6 +97,14 @@ type Service struct {
 }
 
 func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
+	tlsConfig, err := transport.NewServerTLSConfig(
+		cfg.TLSCertFile,
+		cfg.TLSKeyFile,
+		cfg.TLSTrustedCaFile)
+	if err != nil {
+		return
+	}
+
 	svr = &Service{
 		ctlManager:    NewControlManager(),
 		pxyManager:    proxy.NewManager(),
@@ -112,7 +116,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
 		},
 		httpVhostRouter: vhost.NewRouters(),
 		authVerifier:    auth.NewAuthVerifier(cfg.ServerConfig),
-		tlsConfig:       generateTLSConfig(),
+		tlsConfig:       tlsConfig,
 		cfg:             cfg,
 	}
 
@@ -506,24 +510,3 @@ func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVis
 	return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey,
 		newMsg.UseEncryption, newMsg.UseCompression)
 }
-
-// Setup a bare-bones TLS config for the server
-func generateTLSConfig() *tls.Config {
-	key, err := rsa.GenerateKey(rand.Reader, 1024)
-	if err != nil {
-		panic(err)
-	}
-	template := x509.Certificate{SerialNumber: big.NewInt(1)}
-	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
-	if err != nil {
-		panic(err)
-	}
-	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
-	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
-
-	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
-	if err != nil {
-		panic(err)
-	}
-	return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
-}