|
@@ -15,25 +15,13 @@
|
|
|
package vhost
|
|
|
|
|
|
import (
|
|
|
- _ "bufio"
|
|
|
- _ "bytes"
|
|
|
- _ "crypto/tls"
|
|
|
- "errors"
|
|
|
"fmt"
|
|
|
- "frp/utils/conn"
|
|
|
- "frp/utils/log"
|
|
|
"io"
|
|
|
- _ "io/ioutil"
|
|
|
"net"
|
|
|
- _ "net/http"
|
|
|
"strings"
|
|
|
- _ "sync"
|
|
|
"time"
|
|
|
-)
|
|
|
|
|
|
-var (
|
|
|
- maxHandshake int64 = 65536 // maximum handshake we support (protocol max is 16 MB)
|
|
|
- VhostHttpsPort int64 = 443
|
|
|
+ "frp/utils/conn"
|
|
|
)
|
|
|
|
|
|
const (
|
|
@@ -58,160 +46,140 @@ type HttpsMuxer struct {
|
|
|
*VhostMuxer
|
|
|
}
|
|
|
|
|
|
-/*
|
|
|
- RFC document: http://tools.ietf.org/html/rfc5246
|
|
|
-*/
|
|
|
-
|
|
|
-func errMsgToLog(format string, a ...interface{}) error {
|
|
|
- errMsg := fmt.Sprintf(format, a...)
|
|
|
- log.Warn(errMsg)
|
|
|
- return errors.New(errMsg)
|
|
|
+func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
|
|
|
+ mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
|
|
|
+ return &HttpsMuxer{mux}, err
|
|
|
}
|
|
|
|
|
|
-func readHandshake(rd io.Reader) (string, error) {
|
|
|
-
|
|
|
+func readHandshake(rd io.Reader) (host string, err error) {
|
|
|
data := make([]byte, 1024)
|
|
|
length, err := rd.Read(data)
|
|
|
if err != nil {
|
|
|
- return "", errMsgToLog("read err:%v", err)
|
|
|
+ return
|
|
|
} else {
|
|
|
if length < 47 {
|
|
|
- return "", errMsgToLog("readHandshake: proto length[%d] is too short", length)
|
|
|
+ err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
|
|
|
+ return
|
|
|
}
|
|
|
}
|
|
|
data = data[:length]
|
|
|
- //log.Warn("data: %+v", data)
|
|
|
if uint8(data[5]) != typeClientHello {
|
|
|
- return "", errMsgToLog("readHandshake: type[%d] is not clientHello", uint16(data[5]))
|
|
|
+ err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- //version and random
|
|
|
- //tlsVersion := uint16(data[9])<<8 | uint16(data[10])
|
|
|
- //random := data[11:43]
|
|
|
-
|
|
|
- //session
|
|
|
+ // session
|
|
|
sessionIdLen := int(data[43])
|
|
|
if sessionIdLen > 32 || len(data) < 44+sessionIdLen {
|
|
|
- return "", errMsgToLog("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
|
|
|
+ err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
|
|
|
+ return
|
|
|
}
|
|
|
data = data[44+sessionIdLen:]
|
|
|
if len(data) < 2 {
|
|
|
- return "", errMsgToLog("readHandshake: dataLen[%d] after session is short", len(data))
|
|
|
+ err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
// cipher suite numbers
|
|
|
cipherSuiteLen := int(data[0])<<8 | int(data[1])
|
|
|
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
|
|
|
- //return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", sessionIdLen)
|
|
|
- return "", errMsgToLog("readHandshake: dataLen[%d] after cipher suite is short", len(data))
|
|
|
+ err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
|
|
|
+ return
|
|
|
}
|
|
|
data = data[2+cipherSuiteLen:]
|
|
|
if len(data) < 1 {
|
|
|
- return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
|
|
|
+ err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- //compression method
|
|
|
+ // compression method
|
|
|
compressionMethodsLen := int(data[0])
|
|
|
if len(data) < 1+compressionMethodsLen {
|
|
|
- return "", errMsgToLog("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
|
|
|
- //return false
|
|
|
+ err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
data = data[1+compressionMethodsLen:]
|
|
|
-
|
|
|
if len(data) == 0 {
|
|
|
// ClientHello is optionally followed by extension data
|
|
|
- //return true
|
|
|
- return "", errMsgToLog("readHandshake: there is no extension data to get servername")
|
|
|
+ err = fmt.Errorf("readHandshake: there is no extension data to get servername")
|
|
|
+ return
|
|
|
}
|
|
|
if len(data) < 2 {
|
|
|
- return "", errMsgToLog("readHandshake: extension dataLen[%d] is too short")
|
|
|
+ err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short")
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
extensionsLength := int(data[0])<<8 | int(data[1])
|
|
|
data = data[2:]
|
|
|
if extensionsLength != len(data) {
|
|
|
- return "", errMsgToLog("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
|
|
|
+ err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
|
|
|
+ return
|
|
|
}
|
|
|
for len(data) != 0 {
|
|
|
if len(data) < 4 {
|
|
|
- return "", errMsgToLog("readHandshake: extensionsDataLen[%d] is too short", len(data))
|
|
|
+ err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
|
|
|
+ return
|
|
|
}
|
|
|
extension := uint16(data[0])<<8 | uint16(data[1])
|
|
|
length := int(data[2])<<8 | int(data[3])
|
|
|
data = data[4:]
|
|
|
if len(data) < length {
|
|
|
- return "", errMsgToLog("readHandshake: extensionLen[%d] is long", length)
|
|
|
- //return false
|
|
|
+ err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
switch extension {
|
|
|
case extensionRenegotiationInfo:
|
|
|
if length != 1 || data[0] != 0 {
|
|
|
- return "", errMsgToLog("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
|
|
|
+ err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
|
|
|
+ return
|
|
|
}
|
|
|
case extensionNextProtoNeg:
|
|
|
case extensionStatusRequest:
|
|
|
case extensionServerName:
|
|
|
d := data[:length]
|
|
|
if len(d) < 2 {
|
|
|
- return "", errMsgToLog("readHandshake: remiaining dataLen[%d] is short", len(d))
|
|
|
+ err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
|
|
|
+ return
|
|
|
}
|
|
|
namesLen := int(d[0])<<8 | int(d[1])
|
|
|
d = d[2:]
|
|
|
if len(d) != namesLen {
|
|
|
- return "", errMsgToLog("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
|
|
|
+ err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
|
|
|
+ return
|
|
|
}
|
|
|
for len(d) > 0 {
|
|
|
if len(d) < 3 {
|
|
|
- return "", errMsgToLog("readHandshake: extension serverNameLen[%d] is short", len(d))
|
|
|
+ err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
|
|
|
+ return
|
|
|
}
|
|
|
nameType := d[0]
|
|
|
nameLen := int(d[1])<<8 | int(d[2])
|
|
|
d = d[3:]
|
|
|
if len(d) < nameLen {
|
|
|
- return "", errMsgToLog("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
|
|
|
+ err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
|
|
|
+ return
|
|
|
}
|
|
|
if nameType == 0 {
|
|
|
- suffix := ""
|
|
|
- if VhostHttpsPort != 443 {
|
|
|
- suffix = fmt.Sprintf(":%d", VhostHttpsPort)
|
|
|
- }
|
|
|
serverName := string(d[:nameLen])
|
|
|
- domain := strings.ToLower(strings.TrimSpace(serverName)) + suffix
|
|
|
- return domain, nil
|
|
|
- break
|
|
|
+ host = strings.TrimSpace(serverName)
|
|
|
+ return host, nil
|
|
|
}
|
|
|
d = d[nameLen:]
|
|
|
}
|
|
|
}
|
|
|
data = data[length:]
|
|
|
}
|
|
|
- //return "test.codermao.com:8082", nil
|
|
|
- return "", errMsgToLog("Unknow error")
|
|
|
+ err = fmt.Errorf("Unknow error")
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
func GetHttpsHostname(c *conn.Conn) (sc net.Conn, routerName string, err error) {
|
|
|
- log.Info("GetHttpsHostname")
|
|
|
sc, rd := newShareConn(c.TcpConn)
|
|
|
-
|
|
|
host, err := readHandshake(rd)
|
|
|
if err != nil {
|
|
|
return sc, "", err
|
|
|
}
|
|
|
- /*
|
|
|
- if _, ok := c.TcpConn.(*tls.Conn); ok {
|
|
|
- log.Warn("convert to tlsConn success")
|
|
|
- } else {
|
|
|
- log.Warn("convert to tlsConn error")
|
|
|
- }*/
|
|
|
- //tcpConn.
|
|
|
- log.Debug("GetHttpsHostname: %s", host)
|
|
|
-
|
|
|
return sc, host, nil
|
|
|
}
|
|
|
-
|
|
|
-func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
|
|
|
- mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
|
|
|
- return &HttpsMuxer{mux}, err
|
|
|
-}
|