Kaynağa Gözat

use constant time comparison (#3452)

fatedier 1 yıl önce
ebeveyn
işleme
4915852b9c

+ 1 - 1
client/admin.go

@@ -48,7 +48,7 @@ func (svr *Service) RunAdminServer(address string) (err error) {
 
 	subRouter := router.NewRoute().Subrouter()
 	user, passwd := svr.cfg.AdminUser, svr.cfg.AdminPwd
-	subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).Middleware)
+	subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware)
 
 	// api, see admin_api.go
 	subRouter.HandleFunc("/api/reload", svr.apiReload).Methods("GET")

+ 6 - 6
pkg/auth/token.go

@@ -73,30 +73,30 @@ func (auth *TokenAuthSetterVerifier) SetNewWorkConn(newWorkConnMsg *msg.NewWorkC
 	return nil
 }
 
-func (auth *TokenAuthSetterVerifier) VerifyLogin(loginMsg *msg.Login) error {
-	if util.GetAuthKey(auth.token, loginMsg.Timestamp) != loginMsg.PrivilegeKey {
+func (auth *TokenAuthSetterVerifier) VerifyLogin(m *msg.Login) error {
+	if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) {
 		return fmt.Errorf("token in login doesn't match token from configuration")
 	}
 	return nil
 }
 
-func (auth *TokenAuthSetterVerifier) VerifyPing(pingMsg *msg.Ping) error {
+func (auth *TokenAuthSetterVerifier) VerifyPing(m *msg.Ping) error {
 	if !auth.AuthenticateHeartBeats {
 		return nil
 	}
 
-	if util.GetAuthKey(auth.token, pingMsg.Timestamp) != pingMsg.PrivilegeKey {
+	if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) {
 		return fmt.Errorf("token in heartbeat doesn't match token from configuration")
 	}
 	return nil
 }
 
-func (auth *TokenAuthSetterVerifier) VerifyNewWorkConn(newWorkConnMsg *msg.NewWorkConn) error {
+func (auth *TokenAuthSetterVerifier) VerifyNewWorkConn(m *msg.NewWorkConn) error {
 	if !auth.AuthenticateNewWorkConns {
 		return nil
 	}
 
-	if util.GetAuthKey(auth.token, newWorkConnMsg.Timestamp) != newWorkConnMsg.PrivilegeKey {
+	if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) {
 		return fmt.Errorf("token in NewWorkConn doesn't match token from configuration")
 	}
 	return nil

+ 1 - 1
pkg/nathole/controller.go

@@ -174,7 +174,7 @@ func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.
 		if !ok {
 			return fmt.Errorf("xtcp server for [%s] doesn't exist", m.ProxyName)
 		}
-		if m.SignKey != util.GetAuthKey(clientCfg.sk, m.Timestamp) {
+		if !util.ConstantTimeEqString(m.SignKey, util.GetAuthKey(clientCfg.sk, m.Timestamp)) {
 			return fmt.Errorf("xtcp connection of [%s] auth failed", m.ProxyName)
 		}
 		c.sessions[sid] = session

+ 5 - 1
pkg/plugin/client/http_proxy.go

@@ -21,11 +21,13 @@ import (
 	"net"
 	"net/http"
 	"strings"
+	"time"
 
 	frpIo "github.com/fatedier/golib/io"
 	gnet "github.com/fatedier/golib/net"
 
 	frpNet "github.com/fatedier/frp/pkg/util/net"
+	"github.com/fatedier/frp/pkg/util/util"
 )
 
 const PluginHTTPProxy = "http_proxy"
@@ -179,7 +181,9 @@ func (hp *HTTPProxy) Auth(req *http.Request) bool {
 		return false
 	}
 
-	if pair[0] != hp.AuthUser || pair[1] != hp.AuthPasswd {
+	if !util.ConstantTimeEqString(pair[0], hp.AuthUser) ||
+		!util.ConstantTimeEqString(pair[1], hp.AuthPasswd) {
+		time.Sleep(200 * time.Millisecond)
 		return false
 	}
 	return true

+ 2 - 1
pkg/plugin/client/static_file.go

@@ -18,6 +18,7 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"time"
 
 	"github.com/gorilla/mux"
 
@@ -64,7 +65,7 @@ func NewStaticFilePlugin(params map[string]string) (Plugin, error) {
 	}
 
 	router := mux.NewRouter()
-	router.Use(frpNet.NewHTTPAuthMiddleware(httpUser, httpPasswd).Middleware)
+	router.Use(frpNet.NewHTTPAuthMiddleware(httpUser, httpPasswd).SetAuthFailDelay(200 * time.Millisecond).Middleware)
 	router.PathPrefix(prefix).Handler(frpNet.MakeHTTPGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(localPath))))).Methods("GET")
 	sp.s = &http.Server{
 		Handler: router,

+ 16 - 16
pkg/util/net/http.go

@@ -19,6 +19,9 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"time"
+
+	"github.com/fatedier/frp/pkg/util/util"
 )
 
 type HTTPAuthWraper struct {
@@ -46,8 +49,9 @@ func (aw *HTTPAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 }
 
 type HTTPAuthMiddleware struct {
-	user   string
-	passwd string
+	user          string
+	passwd        string
+	authFailDelay time.Duration
 }
 
 func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware {
@@ -57,32 +61,28 @@ func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware {
 	}
 }
 
+func (authMid *HTTPAuthMiddleware) SetAuthFailDelay(delay time.Duration) *HTTPAuthMiddleware {
+	authMid.authFailDelay = delay
+	return authMid
+}
+
 func (authMid *HTTPAuthMiddleware) Middleware(next http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		reqUser, reqPasswd, hasAuth := r.BasicAuth()
 		if (authMid.user == "" && authMid.passwd == "") ||
-			(hasAuth && reqUser == authMid.user && reqPasswd == authMid.passwd) {
+			(hasAuth && util.ConstantTimeEqString(reqUser, authMid.user) &&
+				util.ConstantTimeEqString(reqPasswd, authMid.passwd)) {
 			next.ServeHTTP(w, r)
 		} else {
+			if authMid.authFailDelay > 0 {
+				time.Sleep(authMid.authFailDelay)
+			}
 			w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
 			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
 		}
 	})
 }
 
-func HTTPBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc {
-	return func(w http.ResponseWriter, r *http.Request) {
-		reqUser, reqPasswd, hasAuth := r.BasicAuth()
-		if (user == "" && passwd == "") ||
-			(hasAuth && reqUser == user && reqPasswd == passwd) {
-			h.ServeHTTP(w, r)
-		} else {
-			w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
-			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
-		}
-	}
-}
-
 type HTTPGzipWraper struct {
 	h http.Handler
 }

+ 5 - 0
pkg/util/util/util.go

@@ -17,6 +17,7 @@ package util
 import (
 	"crypto/md5"
 	"crypto/rand"
+	"crypto/subtle"
 	"encoding/hex"
 	"fmt"
 	mathrand "math/rand"
@@ -139,3 +140,7 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati
 	time.Sleep(d)
 	return d
 }
+
+func ConstantTimeEqString(a, b string) bool {
+	return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
+}

+ 1 - 1
server/dashboard.go

@@ -50,7 +50,7 @@ func (svr *Service) RunDashboardServer(address string) (err error) {
 	subRouter := router.NewRoute().Subrouter()
 
 	user, passwd := svr.cfg.DashboardUser, svr.cfg.DashboardPwd
-	subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).Middleware)
+	subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware)
 
 	// metrics
 	if svr.cfg.EnablePrometheus {

+ 4 - 4
test/e2e/framework/framework.go

@@ -66,8 +66,8 @@ func NewDefaultFramework() *Framework {
 	options := Options{
 		TotalParallelNode: suiteConfig.ParallelTotal,
 		CurrentNodeIndex:  suiteConfig.ParallelProcess,
-		FromPortIndex:     20000,
-		ToPortIndex:       50000,
+		FromPortIndex:     10000,
+		ToPortIndex:       60000,
 	}
 	return NewFramework(options)
 }
@@ -118,14 +118,14 @@ func (f *Framework) AfterEach() {
 	// stop processor
 	for _, p := range f.serverProcesses {
 		_ = p.Stop()
-		if TestContext.Debug {
+		if TestContext.Debug || ginkgo.CurrentSpecReport().Failed() {
 			fmt.Println(p.ErrorOutput())
 			fmt.Println(p.StdOutput())
 		}
 	}
 	for _, p := range f.clientProcesses {
 		_ = p.Stop()
-		if TestContext.Debug {
+		if TestContext.Debug || ginkgo.CurrentSpecReport().Failed() {
 			fmt.Println(p.ErrorOutput())
 			fmt.Println(p.StdOutput())
 		}

+ 2 - 2
test/e2e/framework/process.go

@@ -38,7 +38,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str
 		err = p.Start()
 		ExpectNoError(err)
 	}
-	time.Sleep(2 * time.Second)
+	time.Sleep(1 * time.Second)
 
 	currentClientProcesses := make([]*process.Process, 0, len(clientTemplates))
 	for i := range clientTemplates {
@@ -56,7 +56,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str
 		ExpectNoError(err)
 		time.Sleep(500 * time.Millisecond)
 	}
-	time.Sleep(5 * time.Second)
+	time.Sleep(2 * time.Second)
 
 	return currentServerProcesses, currentClientProcesses
 }

+ 2 - 2
test/e2e/pkg/port/port.go

@@ -58,7 +58,7 @@ func (pa *Allocator) GetByName(portName string) int {
 			return 0
 		}
 
-		l, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
+		l, err := net.Listen("tcp", net.JoinHostPort("0.0.0.0", strconv.Itoa(port)))
 		if err != nil {
 			// Maybe not controlled by us, mark it used.
 			pa.used.Insert(port)
@@ -66,7 +66,7 @@ func (pa *Allocator) GetByName(portName string) int {
 		}
 		l.Close()
 
-		udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
+		udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.Itoa(port)))
 		if err != nil {
 			continue
 		}