Browse Source

return ssl alert unrecognized_name when https domain not registered (#3620)

Zeyu Dong 1 year ago
parent
commit
5c8ea51eb5

+ 11 - 1
pkg/util/tcpmux/httpconnect.go

@@ -40,7 +40,8 @@ func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout tim
 	ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
 	mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
 	mux.SetCheckAuthFunc(ret.auth).
-		SetSuccessHookFunc(ret.sendConnectResponse)
+		SetSuccessHookFunc(ret.sendConnectResponse).
+		SetFailHookFunc(vhostFailed)
 	ret.Muxer = mux
 	return ret, err
 }
@@ -92,6 +93,15 @@ func (muxer *HTTPConnectTCPMuxer) auth(c net.Conn, username, password string, re
 	return false, nil
 }
 
+func vhostFailed(c net.Conn) {
+	res := vhost.NotFoundResponse()
+	if res.Body != nil {
+		defer res.Body.Close()
+	}
+	_ = res.Write(c)
+	_ = c.Close()
+}
+
 func (muxer *HTTPConnectTCPMuxer) getHostFromHTTPConnect(c net.Conn) (net.Conn, map[string]string, error) {
 	reqInfoMap := make(map[string]string, 0)
 	sc, rd := libnet.NewSharedConn(c)

+ 1 - 1
pkg/util/vhost/http.go

@@ -251,7 +251,7 @@ func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Req
 
 	remote, err := rp.CreateConnection(req.Context().Value(RouteInfoKey).(*RequestRouteInfo), false)
 	if err != nil {
-		_ = notFoundResponse().Write(client)
+		_ = NotFoundResponse().Write(client)
 		client.Close()
 		return
 	}

+ 7 - 0
pkg/util/vhost/https.go

@@ -29,6 +29,7 @@ type HTTPSMuxer struct {
 
 func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
 	mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
+	mux.SetFailHookFunc(vhostFailed)
 	if err != nil {
 		return nil, err
 	}
@@ -69,6 +70,12 @@ func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
 	return hello, nil
 }
 
+func vhostFailed(c net.Conn) {
+	// Alert with alertUnrecognizedName
+	_ = tls.Server(c, &tls.Config{}).Handshake()
+	c.Close()
+}
+
 type readOnlyConn struct {
 	reader io.Reader
 }

+ 1 - 1
pkg/util/vhost/resource.go

@@ -67,7 +67,7 @@ func getNotFoundPageContent() []byte {
 	return buf
 }
 
-func notFoundResponse() *http.Response {
+func NotFoundResponse() *http.Response {
 	header := make(http.Header)
 	header.Set("server", "frp/"+version.Full())
 	header.Set("Content-Type", "text/html")

+ 8 - 6
pkg/util/vhost/vhost.go

@@ -46,6 +46,7 @@ type (
 	authFunc        func(conn net.Conn, username, password string, reqInfoMap map[string]string) (bool, error)
 	hostRewriteFunc func(net.Conn, string) (net.Conn, error)
 	successHookFunc func(net.Conn, map[string]string) error
+	failHookFunc    func(net.Conn)
 )
 
 // Muxer is a functional component used for https and tcpmux proxies.
@@ -58,6 +59,7 @@ type Muxer struct {
 	vhostFunc      muxFunc
 	checkAuth      authFunc
 	successHook    successHookFunc
+	failHook       failHookFunc
 	rewriteHost    hostRewriteFunc
 	registryRouter *Routers
 }
@@ -87,6 +89,11 @@ func (v *Muxer) SetSuccessHookFunc(f successHookFunc) *Muxer {
 	return v
 }
 
+func (v *Muxer) SetFailHookFunc(f failHookFunc) *Muxer {
+	v.failHook = f
+	return v
+}
+
 func (v *Muxer) SetRewriteHostFunc(f hostRewriteFunc) *Muxer {
 	v.rewriteHost = f
 	return v
@@ -206,13 +213,8 @@ func (v *Muxer) handle(c net.Conn) {
 	httpUser := reqInfoMap["HTTPUser"]
 	l, ok := v.getListener(name, path, httpUser)
 	if !ok {
-		res := notFoundResponse()
-		if res.Body != nil {
-			defer res.Body.Close()
-		}
-		_ = res.Write(c)
 		log.Debug("http request for host [%s] path [%s] httpUser [%s] not found", name, path, httpUser)
-		_ = c.Close()
+		v.failHook(sConn)
 		return
 	}