Browse Source

vhost: use new readClientHello function (#2504)

fatedier 3 years ago
parent
commit
86b2e686a5
3 changed files with 80 additions and 155 deletions
  1. 38 151
      pkg/util/vhost/https.go
  2. 38 0
      pkg/util/vhost/https_test.go
  3. 4 4
      test/e2e/mock/server/httpserver/server.go

+ 38 - 151
pkg/util/vhost/https.go

@@ -15,32 +15,12 @@
 package vhost
 
 import (
-	"fmt"
+	"crypto/tls"
 	"io"
 	"net"
-	"strings"
 	"time"
 
 	gnet "github.com/fatedier/golib/net"
-	"github.com/fatedier/golib/pool"
-)
-
-const (
-	typeClientHello uint8 = 1 // Type client hello
-)
-
-// TLS extension numbers
-const (
-	extensionServerName          uint16 = 0
-	extensionStatusRequest       uint16 = 5
-	extensionSupportedCurves     uint16 = 10
-	extensionSupportedPoints     uint16 = 11
-	extensionSignatureAlgorithms uint16 = 13
-	extensionALPN                uint16 = 16
-	extensionSCT                 uint16 = 18
-	extensionSessionTicket       uint16 = 35
-	extensionNextProtoNeg        uint16 = 13172 // not IANA assigned
-	extensionRenegotiationInfo   uint16 = 0xff01
 )
 
 type HTTPSMuxer struct {
@@ -52,142 +32,49 @@ func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, e
 	return &HTTPSMuxer{mux}, err
 }
 
-func readHandshake(rd io.Reader) (host string, err error) {
-	data := pool.GetBuf(1024)
-	origin := data
-	defer pool.PutBuf(origin)
-
-	_, err = io.ReadFull(rd, data[:47])
-	if err != nil {
-		return
-	}
-
-	length, err := rd.Read(data[47:])
-	if err != nil {
-		return
-	}
-	length += 47
-	data = data[:length]
-	if uint8(data[5]) != typeClientHello {
-		err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
-		return
-	}
-
-	// session
-	sessionIDLen := int(data[43])
-	if sessionIDLen > 32 || len(data) < 44+sessionIDLen {
-		err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIDLen)
-		return
-	}
-	data = data[44+sessionIDLen:]
-	if len(data) < 2 {
-		err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
-		return
-	}
-
-	// cipher suite numbers
-	cipherSuiteLen := int(data[0])<<8 | int(data[1])
-	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
-		err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
-		return
-	}
-	data = data[2+cipherSuiteLen:]
-	if len(data) < 1 {
-		err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
-		return
-	}
-
-	// compression method
-	compressionMethodsLen := int(data[0])
-	if len(data) < 1+compressionMethodsLen {
-		err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
-		return
-	}
-
-	data = data[1+compressionMethodsLen:]
-	if len(data) == 0 {
-		// ClientHello is optionally followed by extension data
-		err = fmt.Errorf("readHandshake: there is no extension data to get servername")
-		return
-	}
-	if len(data) < 2 {
-		err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short", len(data))
-		return
-	}
-
-	extensionsLength := int(data[0])<<8 | int(data[1])
-	data = data[2:]
-	if extensionsLength != len(data) {
-		err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
-		return
-	}
-	for len(data) != 0 {
-		if len(data) < 4 {
-			err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
-			return
-		}
-		extension := uint16(data[0])<<8 | uint16(data[1])
-		length := int(data[2])<<8 | int(data[3])
-		data = data[4:]
-		if len(data) < length {
-			err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
-			return
-		}
-
-		switch extension {
-		case extensionRenegotiationInfo:
-			if length != 1 || data[0] != 0 {
-				err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
-				return
-			}
-		case extensionNextProtoNeg:
-		case extensionStatusRequest:
-		case extensionServerName:
-			d := data[:length]
-			if len(d) < 2 {
-				err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
-				return
-			}
-			namesLen := int(d[0])<<8 | int(d[1])
-			d = d[2:]
-			if len(d) != namesLen {
-				err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
-				return
-			}
-			for len(d) > 0 {
-				if len(d) < 3 {
-					err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
-					return
-				}
-				nameType := d[0]
-				nameLen := int(d[1])<<8 | int(d[2])
-				d = d[3:]
-				if len(d) < nameLen {
-					err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
-					return
-				}
-				if nameType == 0 {
-					serverName := string(d[:nameLen])
-					host = strings.TrimSpace(serverName)
-					return host, nil
-				}
-				d = d[nameLen:]
-			}
-		}
-		data = data[length:]
-	}
-	err = fmt.Errorf("Unknown error")
-	return
-}
-
 func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
 	reqInfoMap := make(map[string]string, 0)
 	sc, rd := gnet.NewSharedConn(c)
-	host, err := readHandshake(rd)
+
+	clientHello, err := readClientHello(rd)
 	if err != nil {
 		return nil, reqInfoMap, err
 	}
-	reqInfoMap["Host"] = host
+
+	reqInfoMap["Host"] = clientHello.ServerName
 	reqInfoMap["Scheme"] = "https"
 	return sc, reqInfoMap, nil
 }
+
+func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
+	var hello *tls.ClientHelloInfo
+
+	// Note that Handshake always fails because the readOnlyConn is not a real connection.
+	// As long as the Client Hello is successfully read, the failure should only happen after GetConfigForClient is called,
+	// so we only care about the error if hello was never set.
+	err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
+		GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
+			hello = &tls.ClientHelloInfo{}
+			*hello = *argHello
+			return nil, nil
+		},
+	}).Handshake()
+
+	if hello == nil {
+		return nil, err
+	}
+	return hello, nil
+}
+
+type readOnlyConn struct {
+	reader io.Reader
+}
+
+func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
+func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
+func (conn readOnlyConn) Close() error                       { return nil }
+func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
+func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
+func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
+func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
+func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

+ 38 - 0
pkg/util/vhost/https_test.go

@@ -0,0 +1,38 @@
+package vhost
+
+import (
+	"crypto/tls"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestGetHTTPSHostname(t *testing.T) {
+	require := require.New(t)
+
+	l, err := net.Listen("tcp", ":")
+	require.NoError(err)
+	defer l.Close()
+
+	var conn net.Conn
+	go func() {
+		conn, _ = l.Accept()
+		require.NotNil(conn)
+	}()
+
+	go func() {
+		time.Sleep(100 * time.Millisecond)
+		tls.Dial("tcp", l.Addr().String(), &tls.Config{
+			InsecureSkipVerify: true,
+			ServerName:         "example.com",
+		})
+	}()
+
+	time.Sleep(200 * time.Millisecond)
+	_, infos, err := GetHTTPSHostname(conn)
+	require.NoError(err)
+	require.Equal("example.com", infos["Host"])
+	require.Equal("https", infos["Scheme"])
+}

+ 4 - 4
test/e2e/mock/server/httpserver/server.go

@@ -11,7 +11,7 @@ import (
 type Server struct {
 	bindAddr string
 	bindPort int
-	hanlder  http.Handler
+	handler  http.Handler
 
 	l         net.Listener
 	tlsConfig *tls.Config
@@ -54,14 +54,14 @@ func WithTlsConfig(tlsConfig *tls.Config) Option {
 
 func WithHandler(h http.Handler) Option {
 	return func(s *Server) *Server {
-		s.hanlder = h
+		s.handler = h
 		return s
 	}
 }
 
 func WithResponse(resp []byte) Option {
 	return func(s *Server) *Server {
-		s.hanlder = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		s.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			w.Write(resp)
 		})
 		return s
@@ -76,7 +76,7 @@ func (s *Server) Run() error {
 	addr := net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort))
 	hs := &http.Server{
 		Addr:      addr,
-		Handler:   s.hanlder,
+		Handler:   s.handler,
 		TLSConfig: s.tlsConfig,
 	}