Browse Source

utils/vhost: update for vhost_https

fatedier 8 years ago
parent
commit
b14441d5cd

+ 4 - 2
conf/frps.ini

@@ -21,12 +21,14 @@ bind_addr = 0.0.0.0
 listen_port = 6000
 
 [web01]
-type = https
+# if type equals http, vhost_http_port must be set
+type = http
 auth_token = 123
 # if proxy type equals http, custom_domains must be set separated by commas
 custom_domains = web01.yourdomain.com,web01.yourdomain2.com
 
 [web02]
-type = http
+# if type equals https, vhost_https_port must be set
+type = https
 auth_token = 123
 custom_domains = web02.yourdomain.com

+ 6 - 1
src/frp/cmd/frps/control.go

@@ -225,11 +225,16 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string) {
 		}
 
 		// check if vhost_port is set
-		if s.Type == "http" && server.VhostMuxer == nil {
+		if s.Type == "http" && server.VhostHttpMuxer == nil {
 			info = fmt.Sprintf("ProxyName [%s], type [http] not support when vhost_http_port is not set", req.ProxyName)
 			log.Warn(info)
 			return
 		}
+		if s.Type == "https" && server.VhostHttpsMuxer == nil {
+			info = fmt.Sprintf("ProxyName [%s], type [https] not support when vhost_https_port is not set", req.ProxyName)
+			log.Warn(info)
+			return
+		}
 
 		// set infomations from frpc
 		s.UseEncryption = req.UseEncryption

+ 5 - 15
src/frp/models/server/config.go

@@ -31,8 +31,8 @@ var (
 	ConfigFile       string = "./frps.ini"
 	BindAddr         string = "0.0.0.0"
 	BindPort         int64  = 7000
-	VhostHttpPort    int64  = 0 // if VhostHttpPort equals 0, don't listen a public port for http
-	VhostHttpsPort   int64  = 0 // if VhostHttpsPort equals 0, don't listen a public port for http
+	VhostHttpPort    int64  = 0 // if VhostHttpPort equals 0, don't listen a public port for http protocol
+	VhostHttpsPort   int64  = 0 // if VhostHttpsPort equals 0, don't listen a public port for https protocol
 	DashboardPort    int64  = 0 // if DashboardPort equals 0, dashboard is not available
 	LogFile          string = "console"
 	LogWay           string = "console" // console or file
@@ -102,7 +102,6 @@ func loadCommonConf(confFile string) error {
 	} else {
 		VhostHttpsPort = 0
 	}
-	vhost.VhostHttpsPort = VhostHttpsPort
 
 	tmpStr, ok = conf.Get("common", "dashboard_port")
 	if ok {
@@ -183,34 +182,25 @@ func loadProxyConf(confFile string) (proxyServers map[string]*ProxyServer, err e
 				// for http
 				domainStr, ok := section["custom_domains"]
 				if ok {
-					var suffix string
-					if VhostHttpPort != 80 {
-						suffix = fmt.Sprintf(":%d", VhostHttpPort)
-					}
 					proxyServer.CustomDomains = strings.Split(domainStr, ",")
 					if len(proxyServer.CustomDomains) == 0 {
 						return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyServer.Name)
 					}
 					for i, domain := range proxyServer.CustomDomains {
-						proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix
+						proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain))
 					}
 				}
 			} else if proxyServer.Type == "https" {
 				// for https
 				domainStr, ok := section["custom_domains"]
 				if ok {
-					var suffix string
-					if VhostHttpsPort != 443 {
-						suffix = fmt.Sprintf(":%d", VhostHttpsPort)
-					}
 					proxyServer.CustomDomains = strings.Split(domainStr, ",")
 					if len(proxyServer.CustomDomains) == 0 {
-						return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals http", proxyServer.Name)
+						return proxyServers, fmt.Errorf("Parse conf error: proxy [%s] custom_domains must be set when type equals https", proxyServer.Name)
 					}
 					for i, domain := range proxyServer.CustomDomains {
-						proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + suffix
+						proxyServer.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain))
 					}
-					log.Info("proxyServer: %+v", proxyServer.CustomDomains)
 				}
 			}
 			proxyServers[proxyServer.Name] = proxyServer

+ 47 - 0
src/frp/utils/vhost/http.go

@@ -0,0 +1,47 @@
+// Copyright 2016 fatedier, fatedier@gmail.com
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vhost
+
+import (
+	"bufio"
+	"net"
+	"net/http"
+	"strings"
+	"time"
+
+	"frp/utils/conn"
+)
+
+type HttpMuxer struct {
+	*VhostMuxer
+}
+
+func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
+	sc, rd := newShareConn(c.TcpConn)
+
+	request, err := http.ReadRequest(bufio.NewReader(rd))
+	if err != nil {
+		return sc, "", err
+	}
+	tmpArr := strings.Split(request.Host, ":")
+	routerName = tmpArr[0]
+	request.Body.Close()
+	return sc, routerName, nil
+}
+
+func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
+	mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
+	return &HttpMuxer{mux}, err
+}

+ 46 - 78
src/frp/utils/vhost/vhost_https.go → src/frp/utils/vhost/https.go

@@ -15,25 +15,13 @@
 package vhost
 
 import (
-	_ "bufio"
-	_ "bytes"
-	_ "crypto/tls"
-	"errors"
 	"fmt"
-	"frp/utils/conn"
-	"frp/utils/log"
 	"io"
-	_ "io/ioutil"
 	"net"
-	_ "net/http"
 	"strings"
-	_ "sync"
 	"time"
-)
 
-var (
-	maxHandshake   int64 = 65536 // maximum handshake we support (protocol max is 16 MB)
-	VhostHttpsPort int64 = 443
+	"frp/utils/conn"
 )
 
 const (
@@ -58,160 +46,140 @@ type HttpsMuxer struct {
 	*VhostMuxer
 }
 
-/*
-   RFC document: http://tools.ietf.org/html/rfc5246
-*/
-
-func errMsgToLog(format string, a ...interface{}) error {
-	errMsg := fmt.Sprintf(format, a...)
-	log.Warn(errMsg)
-	return errors.New(errMsg)
+func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
+	mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
+	return &HttpsMuxer{mux}, err
 }
 
-func readHandshake(rd io.Reader) (string, error) {
-
+func readHandshake(rd io.Reader) (host string, err error) {
 	data := make([]byte, 1024)
 	length, err := rd.Read(data)
 	if err != nil {
-		return "", errMsgToLog("read err:%v", err)
+		return
 	} else {
 		if length < 47 {
-			return "", errMsgToLog("readHandshake: proto length[%d] is too short", length)
+			err = fmt.Errorf("readHandshake: proto length[%d] is too short", length)
+			return
 		}
 	}
 	data = data[:length]
-	//log.Warn("data: %+v", data)
 	if uint8(data[5]) != typeClientHello {
-		return "", errMsgToLog("readHandshake: type[%d] is not clientHello", uint16(data[5]))
+		err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
+		return
 	}
 
-	//version and random
-	//tlsVersion := uint16(data[9])<<8 | uint16(data[10])
-	//random := data[11:43]
-
-	//session
+	// session
 	sessionIdLen := int(data[43])
 	if sessionIdLen > 32 || len(data) < 44+sessionIdLen {
-		return "", errMsgToLog("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
+		err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
+		return
 	}
 	data = data[44+sessionIdLen:]
 	if len(data) < 2 {
-		return "", errMsgToLog("readHandshake: dataLen[%d] after session is short", len(data))
+		err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
+		return
 	}
 
 	// cipher suite numbers
 	cipherSuiteLen := int(data[0])<<8 | int(data[1])
 	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
-		//return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", sessionIdLen)
-		return "", errMsgToLog("readHandshake: dataLen[%d] after cipher suite is short", len(data))
+		err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
+		return
 	}
 	data = data[2+cipherSuiteLen:]
 	if len(data) < 1 {
-		return "", errMsgToLog("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
+		err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
+		return
 	}
 
-	//compression method
+	// compression method
 	compressionMethodsLen := int(data[0])
 	if len(data) < 1+compressionMethodsLen {
-		return "", errMsgToLog("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
-		//return false
+		err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
+		return
 	}
 
 	data = data[1+compressionMethodsLen:]
-
 	if len(data) == 0 {
 		// ClientHello is optionally followed by extension data
-		//return true
-		return "", errMsgToLog("readHandshake: there is no extension data to get servername")
+		err = fmt.Errorf("readHandshake: there is no extension data to get servername")
+		return
 	}
 	if len(data) < 2 {
-		return "", errMsgToLog("readHandshake: extension dataLen[%d] is too short")
+		err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short")
+		return
 	}
 
 	extensionsLength := int(data[0])<<8 | int(data[1])
 	data = data[2:]
 	if extensionsLength != len(data) {
-		return "", errMsgToLog("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
+		err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
+		return
 	}
 	for len(data) != 0 {
 		if len(data) < 4 {
-			return "", errMsgToLog("readHandshake: extensionsDataLen[%d] is too short", len(data))
+			err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
+			return
 		}
 		extension := uint16(data[0])<<8 | uint16(data[1])
 		length := int(data[2])<<8 | int(data[3])
 		data = data[4:]
 		if len(data) < length {
-			return "", errMsgToLog("readHandshake: extensionLen[%d] is long", length)
-			//return false
+			err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
+			return
 		}
 
 		switch extension {
 		case extensionRenegotiationInfo:
 			if length != 1 || data[0] != 0 {
-				return "", errMsgToLog("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
+				err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
+				return
 			}
 		case extensionNextProtoNeg:
 		case extensionStatusRequest:
 		case extensionServerName:
 			d := data[:length]
 			if len(d) < 2 {
-				return "", errMsgToLog("readHandshake: remiaining dataLen[%d] is short", len(d))
+				err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
+				return
 			}
 			namesLen := int(d[0])<<8 | int(d[1])
 			d = d[2:]
 			if len(d) != namesLen {
-				return "", errMsgToLog("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
+				err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
+				return
 			}
 			for len(d) > 0 {
 				if len(d) < 3 {
-					return "", errMsgToLog("readHandshake: extension serverNameLen[%d] is short", len(d))
+					err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
+					return
 				}
 				nameType := d[0]
 				nameLen := int(d[1])<<8 | int(d[2])
 				d = d[3:]
 				if len(d) < nameLen {
-					return "", errMsgToLog("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
+					err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
+					return
 				}
 				if nameType == 0 {
-					suffix := ""
-					if VhostHttpsPort != 443 {
-						suffix = fmt.Sprintf(":%d", VhostHttpsPort)
-					}
 					serverName := string(d[:nameLen])
-					domain := strings.ToLower(strings.TrimSpace(serverName)) + suffix
-					return domain, nil
-					break
+					host = strings.TrimSpace(serverName)
+					return host, nil
 				}
 				d = d[nameLen:]
 			}
 		}
 		data = data[length:]
 	}
-	//return "test.codermao.com:8082", nil
-	return "", errMsgToLog("Unknow error")
+	err = fmt.Errorf("Unknow error")
+	return
 }
 
 func GetHttpsHostname(c *conn.Conn) (sc net.Conn, routerName string, err error) {
-	log.Info("GetHttpsHostname")
 	sc, rd := newShareConn(c.TcpConn)
-
 	host, err := readHandshake(rd)
 	if err != nil {
 		return sc, "", err
 	}
-	/*
-		if _, ok := c.TcpConn.(*tls.Conn); ok {
-			log.Warn("convert to tlsConn success")
-		} else {
-			log.Warn("convert to tlsConn error")
-		}*/
-	//tcpConn.
-	log.Debug("GetHttpsHostname: %s", host)
-
 	return sc, host, nil
 }
-
-func NewHttpsMuxer(listener *conn.Listener, timeout time.Duration) (*HttpsMuxer, error) {
-	mux, err := NewVhostMuxer(listener, GetHttpsHostname, timeout)
-	return &HttpsMuxer{mux}, err
-}

+ 0 - 25
src/frp/utils/vhost/vhost.go

@@ -15,12 +15,10 @@
 package vhost
 
 import (
-	"bufio"
 	"bytes"
 	"fmt"
 	"io"
 	"net"
-	"net/http"
 	"strings"
 	"sync"
 	"time"
@@ -99,7 +97,6 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
 	}
 
 	name = strings.ToLower(name)
-
 	l, ok := v.getListener(name)
 	if !ok {
 		return
@@ -113,28 +110,6 @@ func (v *VhostMuxer) handle(c *conn.Conn) {
 	l.accept <- c
 }
 
-type HttpMuxer struct {
-	*VhostMuxer
-}
-
-func GetHttpHostname(c *conn.Conn) (_ net.Conn, routerName string, err error) {
-	sc, rd := newShareConn(c.TcpConn)
-
-	request, err := http.ReadRequest(bufio.NewReader(rd))
-	if err != nil {
-		return sc, "", err
-	}
-	routerName = request.Host
-	request.Body.Close()
-
-	return sc, routerName, nil
-}
-
-func NewHttpMuxer(listener *conn.Listener, timeout time.Duration) (*HttpMuxer, error) {
-	mux, err := NewVhostMuxer(listener, GetHttpHostname, timeout)
-	return &HttpMuxer{mux}, err
-}
-
 type Listener struct {
 	name   string
 	mux    *VhostMuxer // for closing VhostMuxer