Browse Source

add https proto for reverse proxy

Maodanping 8 years ago
parent
commit
f3876d69bb

+ 2 - 1
conf/frps.ini

@@ -4,6 +4,7 @@ bind_addr = 0.0.0.0
 bind_port = 7000
 # if you want to support virtual host, you must set the http port for listening (optional)
 vhost_http_port = 80
+vhost_https_port = 443
 # if you want to configure or reload frps by dashboard, dashboard_port must be set
 dashboard_port = 7500
 # console or real logFile path like ./frps.log
@@ -20,7 +21,7 @@ bind_addr = 0.0.0.0
 listen_port = 6000
 
 [web01]
-type = http
+type = https
 auth_token = 123
 # if proxy type equals http, custom_domains must be set separated by commas
 custom_domains = web01.yourdomain.com,web01.yourdomain2.com

+ 14 - 1
src/frp/cmd/frps/main.go

@@ -143,12 +143,25 @@ func main() {
 			log.Error("Create vhost http listener error, %v", err)
 			os.Exit(1)
 		}
-		server.VhostMuxer, err = vhost.NewHttpMuxer(vhostListener, 30*time.Second)
+		server.VhostHttpMuxer, err = vhost.NewHttpMuxer(vhostListener, 30*time.Second)
 		if err != nil {
 			log.Error("Create vhost httpMuxer error, %v", err)
 		}
 	}
 
+	// create vhost if VhostHttpPort != 0
+	if server.VhostHttpsPort != 0 {
+		vhostListener, err := conn.Listen(server.BindAddr, server.VhostHttpsPort)
+		if err != nil {
+			log.Error("Create vhost https listener error, %v", err)
+			os.Exit(1)
+		}
+		server.VhostHttpsMuxer, err = vhost.NewHttpsMuxer(vhostListener, 30*time.Second)
+		if err != nil {
+			log.Error("Create vhost httpsMuxer error, %v", err)
+		}
+	}
+
 	// create dashboard web server if DashboardPort is set, so it won't be 0
 	if server.DashboardPort != 0 {
 		err := server.RunDashboardServer(server.BindAddr, server.DashboardPort)

+ 1 - 1
src/frp/models/client/config.go

@@ -115,7 +115,7 @@ func LoadConf(confFile string) (err error) {
 			proxyClient.Type = "tcp"
 			typeStr, ok := section["type"]
 			if ok {
-				if typeStr != "tcp" && typeStr != "http" {
+				if typeStr != "tcp" && typeStr != "http" && typeStr != "https" {
 					return fmt.Errorf("Parse ini file error: proxy [%s] type error", proxyClient.Name)
 				}
 				proxyClient.Type = typeStr

+ 29 - 2
src/frp/models/server/config.go

@@ -32,6 +32,7 @@ var (
 	BindAddr         string = "0.0.0.0"
 	BindPort         int64  = 7000
 	VhostHttpPort    int64  = 0 // if VhostHttpPort equals 0, don't listen a public port for http
+	VhostHttpsPort   int64  = 0 // if VhostHttpsPort equals 0, don't listen a public port for http
 	DashboardPort    int64  = 0 // if DashboardPort equals 0, dashboard is not available
 	LogFile          string = "console"
 	LogWay           string = "console" // console or file
@@ -40,7 +41,8 @@ var (
 	HeartBeatTimeout int64  = 90
 	UserConnTimeout  int64  = 10
 
-	VhostMuxer        *vhost.HttpMuxer
+	VhostHttpMuxer    *vhost.HttpMuxer
+	VhostHttpsMuxer   *vhost.HttpsMuxer
 	ProxyServers      map[string]*ProxyServer = make(map[string]*ProxyServer) // all proxy servers info and resources
 	ProxyServersMutex sync.RWMutex
 )
@@ -91,6 +93,14 @@ func loadCommonConf(confFile string) error {
 		VhostHttpPort = 0
 	}
 
+	tmpStr, ok = conf.Get("common", "vhost_https_port")
+	if ok {
+		VhostHttpsPort, _ = strconv.ParseInt(tmpStr, 10, 64)
+	} else {
+		VhostHttpsPort = 0
+	}
+	vhost.VhostHttpsPort = VhostHttpsPort
+
 	tmpStr, ok = conf.Get("common", "dashboard_port")
 	if ok {
 		DashboardPort, _ = strconv.ParseInt(tmpStr, 10, 64)
@@ -135,7 +145,7 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
 
 			proxyServer.Type, ok = section["type"]
 			if ok {
-				if proxyServer.Type != "tcp" && proxyServer.Type != "http" {
+				if proxyServer.Type != "tcp" && proxyServer.Type != "http" && proxyServer.Type != "https" {
 					return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] type error", proxyServer.Name)
 				}
 			} else {
@@ -179,6 +189,23 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
 						proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix
 					}
 				}
+			} else if proxyServer.Type == "https" {
+				// for https
+				domainStr, ok := section["custom_domains"]
+				if ok {
+					var suffix string
+					if VhostHttpsPort != 443 {
+						suffix = fmt.Sprintf(":%d", VhostHttpsPort)
+					}
+					proxyServer.CustomDomains = strings.Split(domainStr, ",")
+					if len(proxyServer.CustomDomains) == 0 {
+						return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyServer.Name)
+					}
+					for i, domain := range proxyServer.CustomDomains {
+						proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix
+					}
+					log.Info("proxyServer: %+v", proxyServer.CustomDomains)
+				}
 			}
 			proxyServers[proxyServer.Name] = proxyServer
 		}

+ 9 - 1
src/frp/models/server/server.go

@@ -100,7 +100,15 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 		p.listeners = append(p.listeners, l)
 	} else if p.Type == "http" {
 		for _, domain := range p.CustomDomains {
-			l, err := VhostMuxer.Listen(domain)
+			l, err := VhostHttpMuxer.Listen(domain)
+			if err != nil {
+				return err
+			}
+			p.listeners = append(p.listeners, l)
+		}
+	} else if p.Type == "https" {
+		for _, domain := range p.CustomDomains {
+			l, err := VhostHttpsMuxer.Listen(domain)
 			if err != nil {
 				return err
 			}

+ 217 - 0
src/frp/utils/vhost/vhost_https.go

@@ -0,0 +1,217 @@
+// Copyright 2016 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vhost
+
+import (
+	_ "bufio"
+	_ "bytes"
+	_ "crypto/tls"
+	"errors"
+	"fmt"
+	"frp/utils/conn"
+	"frp/utils/log"
+	"io"
+	_ "io/ioutil"
+	"net"
+	_ "net/http"
+	"strings"
+	_ "sync"
+	"time"
+)
+
+var (
+	maxHandshake   int64 = 65536 // maximum handshake we support (protocol max is 16 MB)
+	VhostHttpsPort int64 = 443
+)
+
+const (
+	typeClientHello uint8 = 1 // Type client hello
+)
+
+// TLS extension numbers
+const (
+	extensionServerName          uint16 = 0
+	extensionStatusRequest       uint16 = 5
+	extensionSupportedCurves     uint16 = 10
+	extensionSupportedPoints     uint16 = 11
+	extensionSignatureAlgorithms uint16 = 13
+	extensionALPN                uint16 = 16
+	extensionSCT                 uint16 = 18
+	extensionSessionTicket       uint16 = 35
+	extensionNextProtoNeg        uint16 = 13172 // not IANA assigned
+	extensionRenegotiationInfo   uint16 = 0xff01
+)
+
+type HttpsMuxer struct {
+	*VhostMuxer
+}
+
+/*
+   RFC document: http://tools.ietf.org/html/rfc5246
+*/
+
+func errMsgToLog(format string, a ...interface{}) error {
+	errMsg := fmt.Sprintf(format, a...)
+	log.Warn(errMsg)
+	return errors.New(errMsg)
+}
+
+func readHandshake(rd io.Reader) (string, error) {
+
+	data := make([]byte, 1024)
+	length, err := rd.Read(data)
+	if err != nil {
+		return "", errMsgToLog("read err:%v", err)
+	} else {
+		if length < 47 {
+			return "", errMsgToLog("readHandshake: proto length[%d] is too short", length)
+		}
+	}
+	data = data[:length]
+	//log.Warn("data: %+v", data)
+	if uint8(data[5]) != typeClientHello {
+		return "", errMsgToLog("readHandshake: type[%d] is not clientHello", uint16(data[5]))
+	}
+
+	//version and random
+	//tlsVersion := uint16(data[9])<<8 | uint16(data[10])
+	//random := data[11:43]
+
+	//session
+	sessionIdLen := int(data[43])
+	if sessionIdLen > 32 || len(data) < 44+sessionIdLen {
+		return "", errMsgToLog("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
+	}
+	data = data[44+sessionIdLen:]
+	if len(data) < 2 {
+		return "", errMsgToLog("readHandshake: dataLen[%d] after session is short", len(data))
+	}
+
+	// cipher suite numbers
+	cipherSuiteLen := int(data[0])<<8 | int(data[1])
+	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
+		//return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", sessionIdLen)
+		return "", errMsgToLog("readHandshake: dataLen[%d] after cipher suite is short", len(data))
+	}
+	data = data[2+cipherSuiteLen:]
+	if len(data) < 1 {
+		return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
+	}
+
+	//compression method
+	compressionMethodsLen := int(data[0])
+	if len(data) < 1+compressionMethodsLen {
+		return "", errMsgToLog("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
+		//return false
+	}
+
+	data = data[1+compressionMethodsLen:]
+
+	if len(data) == 0 {
+		// ClientHello is optionally followed by extension data
+		//return true
+		return "", errMsgToLog("readHandshake: there is no extension data to get servername")
+	}
+	if len(data) < 2 {
+		return "", errMsgToLog("readHandshake: extension dataLen[%d] is too short")
+	}
+
+	extensionsLength := int(data[0])<<8 | int(data[1])
+	data = data[2:]
+	if extensionsLength != len(data) {
+		return "", errMsgToLog("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
+	}
+	for len(data) != 0 {
+		if len(data) < 4 {
+			return "", errMsgToLog("readHandshake: extensionsDataLen[%d] is too short", len(data))
+		}
+		extension := uint16(data[0])<<8 | uint16(data[1])
+		length := int(data[2])<<8 | int(data[3])
+		data = data[4:]
+		if len(data) < length {
+			return "", errMsgToLog("readHandshake: extensionLen[%d] is long", length)
+			//return false
+		}
+
+		switch extension {
+		case extensionRenegotiationInfo:
+			if length != 1 || data[0] != 0 {
+				return "", errMsgToLog("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
+			}
+		case extensionNextProtoNeg:
+		case extensionStatusRequest:
+		case extensionServerName:
+			d := data[:length]
+			if len(d) < 2 {
+				return "", errMsgToLog("readHandshake: remiaining dataLen[%d] is short", len(d))
+			}
+			namesLen := int(d[0])<<8 | int(d[1])
+			d = d[2:]
+			if len(d) != namesLen {
+				return "", errMsgToLog("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
+			}
+			for len(d) > 0 {
+				if len(d) < 3 {
+					return "", errMsgToLog("readHandshake: extension serverNameLen[%d] is short", len(d))
+				}
+				nameType := d[0]
+				nameLen := int(d[1])<<8 | int(d[2])
+				d = d[3:]
+				if len(d) < nameLen {
+					return "", errMsgToLog("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
+				}
+				if nameType == 0 {
+					suffix := ""
+					if VhostHttpsPort != 443 {
+						suffix = fmt.Sprintf(":%d", VhostHttpsPort)
+					}
+					serverName := string(d[:nameLen])
+					domain := strings.ToLower(strings.TrimSpace(serverName)) + suffix
+					return domain, nil
+					break
+				}
+				d = d[nameLen:]
+			}
+		}
+		data = data[length:]
+	}
+	//return "test.codermao.com:8082", nil
+	return "", errMsgToLog("Unknow error")
+}
+
+func GetHttpsHostname(c *conn.Conn) (sc net.Conn, routerName string, err error) {
+	log.Info("GetHttpsHostname")
+	sc, rd := newShareConn(c.TcpConn)
+
+	host, err := readHandshake(rd)
+	if err != nil {
+		return sc, "", err
+	}
+	/*
+		if _, ok := c.TcpConn.(*tls.Conn); ok {
+			log.Warn("convert to tlsConn success")
+		} else {
+			log.Warn("convert to tlsConn error")
+		}*/
+	//tcpConn.
+	log.Debug("GetHttpsHostname: %s", host)
+
+	return sc, host, nil
+}
+
+func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
+	mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
+	return &HttpsMuxer{mux}, err
+}