瀏覽代碼

Merge pull request #33 from fatedier/mydev

User can set use aes or gzip
Pan Hao 8 年之前
父節點
當前提交
1987a399c1

+ 2 - 6
src/frp/models/client/client.go

@@ -32,7 +32,7 @@ type ProxyClient struct {
 	LocalIp       string
 	LocalPort     int64
 	Type          string
-	UseEncryption bool
+	UseEncryption int
 }
 
 func (p *ProxyClient) GetLocalConn() (c *conn.Conn, err error) {
@@ -89,11 +89,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro
 	// l means local, r means remote
 	log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(),
 		remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr())
-	if p.UseEncryption {
-		go conn.JoinMore(localConn, remoteConn, p.AuthToken)
-	} else {
-		go conn.Join(localConn, remoteConn)
-	}
+	go conn.JoinMore(localConn, remoteConn, p.AuthToken, p.UseEncryption)
 
 	return nil
 }

+ 8 - 3
src/frp/models/client/config.go

@@ -122,10 +122,15 @@ func LoadConf(confFile string) (err error) {
 			}
 
 			// use_encryption
-			proxyClient.UseEncryption = false
+			proxyClient.UseEncryption = 0
 			useEncryptionStr, ok := section["use_encryption"]
-			if ok && useEncryptionStr == "true" {
-				proxyClient.UseEncryption = true
+			if ok {
+				tmpRes, err := strconv.Atoi(useEncryptionStr)
+				if err != nil {
+					proxyClient.UseEncryption = 0
+				}
+
+				proxyClient.UseEncryption = tmpRes
 			}
 
 			ProxyClients[proxyClient.Name] = proxyClient

+ 1 - 1
src/frp/models/msg/msg.go

@@ -24,7 +24,7 @@ type ControlReq struct {
 	Type          int64  `json:"type"`
 	ProxyName     string `json:"proxy_name,omitempty"`
 	AuthKey       string `json:"auth_key, omitempty"`
-	UseEncryption bool   `json:"use_encryption, omitempty"`
+	UseEncryption int    `json:"use_encryption, omitempty"`
 	Timestamp     int64  `json:"timestamp, omitempty"`
 }
 

+ 2 - 6
src/frp/models/server/server.go

@@ -38,7 +38,7 @@ type ProxyServer struct {
 	CustomDomains []string
 
 	// configure in frpc.ini
-	UseEncryption bool
+	UseEncryption int
 
 	Status       int64
 	CtlConn      *conn.Conn      // control connection with frpc
@@ -144,11 +144,7 @@ func (p *ProxyServer) Start(c *conn.Conn) (err error) {
 					log.Debug("Join two connections, (l[%s] r[%s]) (l[%s] r[%s])", workConn.GetLocalAddr(), workConn.GetRemoteAddr(),
 						userConn.GetLocalAddr(), userConn.GetRemoteAddr())
 
-					if p.UseEncryption {
-						go conn.JoinMore(userConn, workConn, p.AuthToken)
-					} else {
-						go conn.Join(userConn, workConn)
-					}
+					go conn.JoinMore(userConn, workConn, p.AuthToken, p.UseEncryption)
 				}()
 			}
 		}(listener)

+ 46 - 13
src/frp/utils/conn/conn.go

@@ -16,6 +16,8 @@ package conn
 
 import (
 	"bufio"
+	"bytes"
+	"encoding/binary"
 	"fmt"
 	"io"
 	"net"
@@ -192,48 +194,70 @@ func Join(c1 *Conn, c2 *Conn) {
 
 // messages from c1 to c2 will be encrypted
 // and from c2 to c1 will be decrypted
-func JoinMore(c1 *Conn, c2 *Conn, cryptKey string) {
+func JoinMore(c1 *Conn, c2 *Conn, cryptKey string, ptype int) {
 	var wait sync.WaitGroup
-	encryptPipe := func(from *Conn, to *Conn, key string) {
+	encryptPipe := func(from *Conn, to *Conn, key string, ttype int) {
 		defer from.Close()
 		defer to.Close()
 		defer wait.Done()
 
 		// we don't care about errors here
-		PipeEncrypt(from.TcpConn, to.TcpConn, key)
+		PipeEncrypt(from.TcpConn, to.TcpConn, key, ttype)
 	}
 
-	decryptPipe := func(to *Conn, from *Conn, key string) {
+	decryptPipe := func(to *Conn, from *Conn, key string, ttype int) {
 		defer from.Close()
 		defer to.Close()
 		defer wait.Done()
 
 		// we don't care about errors here
-		PipeDecrypt(to.TcpConn, from.TcpConn, key)
+		PipeDecrypt(to.TcpConn, from.TcpConn, key, ttype)
 	}
 
 	wait.Add(2)
-	go encryptPipe(c1, c2, cryptKey)
-	go decryptPipe(c2, c1, cryptKey)
+	go encryptPipe(c1, c2, cryptKey, ptype)
+
+	go decryptPipe(c2, c1, cryptKey, ptype)
 	wait.Wait()
 	log.Debug("One tunnel stopped")
 	return
 }
 
+func unpkgMsg(data []byte) (int, []byte, []byte) {
+	if len(data) < 4 {
+		return -1, nil, nil
+	}
+	llen := int(binary.BigEndian.Uint32(data[0:4]))
+	// no complete
+	if len(data) < llen+4 {
+		return -1, nil, nil
+	}
+
+	return 0, data[4 : llen+4], data[llen+4:]
+}
+
 // decrypt msg from reader, then write into writer
-func PipeDecrypt(r net.Conn, w net.Conn, key string) error {
+func PipeDecrypt(r net.Conn, w net.Conn, key string, ptype int) error {
 	laes := new(pcrypto.Pcrypto)
-	if err := laes.Init([]byte(key)); err != nil {
+	if err := laes.Init([]byte(key), ptype); err != nil {
 		log.Error("Pcrypto Init error: %v", err)
 		return fmt.Errorf("Pcrypto Init error: %v", err)
 	}
 
+	buf := make([]byte, 10*1024)
+	var left []byte
 	nreader := bufio.NewReader(r)
 	for {
-		buf, err := nreader.ReadBytes('\n')
+		n, err := nreader.Read(buf)
 		if err != nil {
 			return err
 		}
+		left := append(left, buf[:n]...)
+		cnt, buf, left := unpkgMsg(left)
+
+		if cnt < 0 {
+			continue
+		}
 
 		res, err := laes.Decrypt(buf)
 		if err != nil {
@@ -249,10 +273,18 @@ func PipeDecrypt(r net.Conn, w net.Conn, key string) error {
 	return nil
 }
 
+func pkgMsg(data []byte) []byte {
+	llen := uint32(len(data))
+	buf := new(bytes.Buffer)
+	binary.Write(buf, binary.BigEndian, llen)
+	buf.Write(data)
+	return buf.Bytes()
+}
+
 // recvive msg from reader, then encrypt msg into write
-func PipeEncrypt(r net.Conn, w net.Conn, key string) error {
+func PipeEncrypt(r net.Conn, w net.Conn, key string, ptype int) error {
 	laes := new(pcrypto.Pcrypto)
-	if err := laes.Init([]byte(key)); err != nil {
+	if err := laes.Init([]byte(key), ptype); err != nil {
 		log.Error("Pcrypto Init error: %v", err)
 		return fmt.Errorf("Pcrypto Init error: %v", err)
 	}
@@ -271,11 +303,12 @@ func PipeEncrypt(r net.Conn, w net.Conn, key string) error {
 			return fmt.Errorf("Encrypt error: %v", err)
 		}
 
-		res = append(res, '\n')
+		res = pkgMsg(res)
 		_, err = w.Write(res)
 		if err != nil {
 			return err
 		}
 	}
+
 	return nil
 }

+ 48 - 38
src/frp/utils/pcrypto/pcrypto.go

@@ -20,7 +20,6 @@ import (
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/md5"
-	"encoding/base64"
 	"encoding/hex"
 	"errors"
 	"fmt"
@@ -30,69 +29,80 @@ import (
 type Pcrypto struct {
 	pkey []byte
 	paes cipher.Block
+	// 0: nono; 1:compress; 2: encrypt; 3: compress and encrypt
+	ptyp int
 }
 
-func (pc *Pcrypto) Init(key []byte) error {
+func (pc *Pcrypto) Init(key []byte, ptyp int) error {
 	var err error
 	pc.pkey = pKCS7Padding(key, aes.BlockSize)
 	pc.paes, err = aes.NewCipher(pc.pkey)
+	if ptyp == 1 || ptyp == 2 || ptyp == 3 {
+		pc.ptyp = ptyp
+	} else {
+		pc.ptyp = 0
+	}
 
 	return err
 }
 
 func (pc *Pcrypto) Encrypt(src []byte) ([]byte, error) {
-	// gzip
 	var zbuf bytes.Buffer
-	zwr, err := gzip.NewWriterLevel(&zbuf, -1)
-	if err != nil {
-		return nil, err
+
+	// gzip
+	if pc.ptyp == 1 || pc.ptyp == 3 {
+		zwr, err := gzip.NewWriterLevel(&zbuf, gzip.DefaultCompression)
+		if err != nil {
+			return nil, err
+		}
+		defer zwr.Close()
+		zwr.Write(src)
+		zwr.Flush()
+		src = zbuf.Bytes()
 	}
-	defer zwr.Close()
-	zwr.Write(src)
-	zwr.Flush()
 
 	// aes
-	src = pKCS7Padding(zbuf.Bytes(), aes.BlockSize)
-	blockMode := cipher.NewCBCEncrypter(pc.paes, pc.pkey)
-	crypted := make([]byte, len(src))
-	blockMode.CryptBlocks(crypted, src)
+	if pc.ptyp == 2 || pc.ptyp == 3 {
+		src = pKCS7Padding(src, aes.BlockSize)
+		blockMode := cipher.NewCBCEncrypter(pc.paes, pc.pkey)
+		crypted := make([]byte, len(src))
+		blockMode.CryptBlocks(crypted, src)
+		src = crypted
+	}
 
-	// base64
-	return []byte(base64.StdEncoding.EncodeToString(crypted)), nil
+	return src, nil
 }
 
 func (pc *Pcrypto) Decrypt(str []byte) ([]byte, error) {
-	// base64
-	data, err := base64.StdEncoding.DecodeString(string(str))
-	if err != nil {
-		return nil, err
-	}
-
 	// aes
-	decryptText, err := hex.DecodeString(fmt.Sprintf("%x", data))
-	if err != nil {
-		return nil, err
-	}
+	if pc.ptyp == 2 || pc.ptyp == 3 {
+		decryptText, err := hex.DecodeString(fmt.Sprintf("%x", str))
+		if err != nil {
+			return nil, err
+		}
 
-	if len(decryptText)%aes.BlockSize != 0 {
-		return nil, errors.New("crypto/cipher: ciphertext is not a multiple of the block size")
-	}
+		if len(decryptText)%aes.BlockSize != 0 {
+			return nil, errors.New("crypto/cipher: ciphertext is not a multiple of the block size")
+		}
 
-	blockMode := cipher.NewCBCDecrypter(pc.paes, pc.pkey)
+		blockMode := cipher.NewCBCDecrypter(pc.paes, pc.pkey)
 
-	blockMode.CryptBlocks(decryptText, decryptText)
-	decryptText = pKCS7UnPadding(decryptText)
+		blockMode.CryptBlocks(decryptText, decryptText)
+		str = pKCS7UnPadding(decryptText)
+	}
 
 	// gunzip
-	zbuf := bytes.NewBuffer(decryptText)
-	zrd, err := gzip.NewReader(zbuf)
-	if err != nil {
-		return nil, err
+	if pc.ptyp == 1 || pc.ptyp == 3 {
+		zbuf := bytes.NewBuffer(str)
+		zrd, err := gzip.NewReader(zbuf)
+		if err != nil {
+			return nil, err
+		}
+		defer zrd.Close()
+		str, _ = ioutil.ReadAll(zrd)
 	}
-	defer zrd.Close()
-	data, _ = ioutil.ReadAll(zrd)
 
-	return data, nil
+	return str, nil
 }
 
 func pKCS7Padding(ciphertext []byte, blockSize int) []byte {

+ 62 - 12
src/frp/utils/pcrypto/pcrypto_test.go

@@ -20,28 +20,78 @@ import (
 )
 
 func TestEncrypt(t *testing.T) {
+	return
 	pp := new(Pcrypto)
-	pp.Init([]byte("Hana"))
-	res, err := pp.Encrypt([]byte("Just One Test!"))
+	pp.Init([]byte("Hana"), 1)
+	res, err := pp.Encrypt([]byte("Test Encrypt!"))
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	fmt.Printf("[%x]\n", res)
+	fmt.Printf("Encrypt: len %d, [%x]\n", len(res), res)
 }
 
 func TestDecrypt(t *testing.T) {
-	pp := new(Pcrypto)
-	pp.Init([]byte("Hana"))
-	res, err := pp.Encrypt([]byte("Just One Test!"))
-	if err != nil {
-		t.Fatal(err)
+	fmt.Println("*****************************************************")
+	{
+		pp := new(Pcrypto)
+		pp.Init([]byte("Hana"), 0)
+		res, err := pp.Encrypt([]byte("Test Decrypt! 0"))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		res, err = pp.Decrypt(res)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		fmt.Printf("[%s]\n", string(res))
 	}
+	{
+		pp := new(Pcrypto)
+		pp.Init([]byte("Hana"), 1)
+		res, err := pp.Encrypt([]byte("Test Decrypt! 1"))
+		if err != nil {
+			t.Fatal(err)
+		}
 
-	res, err = pp.Decrypt(res)
-	if err != nil {
-		t.Fatal(err)
+		res, err = pp.Decrypt(res)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		fmt.Printf("[%s]\n", string(res))
+	}
+	{
+		pp := new(Pcrypto)
+		pp.Init([]byte("Hana"), 2)
+		res, err := pp.Encrypt([]byte("Test Decrypt! 2"))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		res, err = pp.Decrypt(res)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		fmt.Printf("[%s]\n", string(res))
+	}
+	{
+		pp := new(Pcrypto)
+		pp.Init([]byte("Hana"), 3)
+		res, err := pp.Encrypt([]byte("Test Decrypt! 3"))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		res, err = pp.Decrypt(res)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		fmt.Printf("[%s]\n", string(res))
 	}
 
-	fmt.Printf("[%s]\n", string(res))
 }