Jelajahi Sumber

utils/conn: add func joinMore

fatedier 9 tahun lalu
induk
melakukan
45c21b2705

+ 1 - 1
conf/frpc.ini

@@ -11,4 +11,4 @@ log_level = debug
 [test1]
 passwd = 123
 local_ip = 127.0.0.1
-local_port = 8000
+local_port = 22

+ 1 - 2
src/frp/models/client/client.go

@@ -82,8 +82,7 @@ func (p *ProxyClient) StartTunnel(serverAddr string, serverPort int64) (err erro
 	log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", localConn.GetLocalAddr(), localConn.GetRemoteAddr(),
 		remoteConn.GetLocalAddr(), remoteConn.GetRemoteAddr())
 	// go conn.Join(localConn, remoteConn)
-	go conn.PipeEncryptoWriter(localConn.TcpConn, remoteConn.TcpConn, p.Passwd)
-	go conn.PipeDecryptoReader(remoteConn.TcpConn, localConn.TcpConn, p.Passwd)
+	go conn.JoinMore(localConn, remoteConn, p.Passwd)
 
 	return nil
 }

+ 1 - 2
src/frp/models/server/server.go

@@ -133,8 +133,7 @@ func (p *ProxyServer) Start() (err error) {
 			log.Debug("Join two conns, (l[%s] r[%s]) (l[%s] r[%s])", cliConn.GetLocalAddr(), cliConn.GetRemoteAddr(),
 				userConn.GetLocalAddr(), userConn.GetRemoteAddr())
 			// go conn.Join(cliConn, userConn)
-			go conn.PipeEncryptoWriter(userConn.TcpConn, cliConn.TcpConn, p.Passwd)
-			go conn.PipeDecryptoReader(cliConn.TcpConn, userConn.TcpConn, p.Passwd)
+			go conn.JoinMore(userConn, cliConn, p.Passwd)
 		}
 	}()
 

+ 47 - 25
src/frp/utils/conn/conn.go

@@ -164,49 +164,72 @@ func Join(c1 *Conn, c2 *Conn) {
 	return
 }
 
-// decrypto msg from reader, then write into writer
-func PipeDecryptoReader(r net.Conn, w net.Conn, key string) {
-	defer r.Close()
-	defer w.Close()
+func JoinMore(local *Conn, remote *Conn, cryptoKey string) {
+	var wait sync.WaitGroup
+	encrypPipe := func(from *Conn, to *Conn, key string) {
+		defer from.Close()
+		defer to.Close()
+		defer wait.Done()
 
+		err := PipeEncryptoWriter(from.TcpConn, to.TcpConn, key)
+		if err != nil {
+			log.Warn("join conns error, %v", err)
+		}
+	}
+
+	decryptoPipe := func(to *Conn, from *Conn, key string) {
+		defer from.Close()
+		defer to.Close()
+		defer wait.Done()
+
+		err := PipeDecryptoReader(to.TcpConn, from.TcpConn, key)
+		if err != nil {
+			log.Warn("join conns error, %v", err)
+		}
+	}
+
+	wait.Add(2)
+	go encrypPipe(local, remote, cryptoKey)
+	go decryptoPipe(remote, local, cryptoKey)
+	wait.Wait()
+	return
+}
+
+// decrypto msg from reader, then write into writer
+func PipeDecryptoReader(r net.Conn, w net.Conn, key string) error {
 	laes := new(pcrypto.Pcrypto)
 	if err := laes.Init([]byte(key)); err != nil {
-		log.Error("Pcrypto Init error, [%v]", err)
-		return
+		log.Error("Pcrypto Init error: %v", err)
+		return fmt.Errorf("Pcrypto Init error: %v", err)
 	}
 
 	nreader := bufio.NewReader(r)
-
 	for {
 		buf, err := nreader.ReadBytes('\n')
 		if err != nil {
-			log.Error("Conn ReadBytes error, [%v]", err)
-			return
+			return err
 		}
 
 		res, err := laes.Decrypto(buf)
 		if err != nil {
-			log.Error("Decrypto error, [%s] [%s]", err, string(buf))
-			return
+			log.Error("Decrypto [%s] error, %v", string(buf), err)
+			return fmt.Errorf("Decrypto [%s] error: %v", string(buf), err)
 		}
 
 		_, err = w.Write(res)
 		if err != nil {
-			log.Error("net.Conn Write error, [%v]", err)
-			return
+			return err
 		}
 	}
+	return nil
 }
 
 // recvive msg from reader, then encrypto msg into write
-func PipeEncryptoWriter(r net.Conn, w net.Conn, key string) {
-	defer r.Close()
-	defer w.Close()
-
+func PipeEncryptoWriter(r net.Conn, w net.Conn, key string) error {
 	laes := new(pcrypto.Pcrypto)
 	if err := laes.Init([]byte(key)); err != nil {
-		log.Error("Pcrypto Init error, [%v]", err)
-		return
+		log.Error("Pcrypto Init error: %v", err)
+		return fmt.Errorf("Pcrypto Init error: %v", err)
 	}
 
 	nreader := bufio.NewReader(r)
@@ -215,20 +238,19 @@ func PipeEncryptoWriter(r net.Conn, w net.Conn, key string) {
 	for {
 		n, err := nreader.Read(buf)
 		if err != nil {
-			log.Error("Conn ReadLine error, [%v]", err)
-			return
+			return err
 		}
 		res, err := laes.Encrypto(buf[:n])
 		if err != nil {
-			log.Error("Encrypto error, [%v]", err)
-			return
+			log.Error("Encrypto error: %v", err)
+			return fmt.Errorf("Encrypto error: %v", err)
 		}
 
 		res = append(res, '\n')
 		_, err = w.Write(res)
 		if err != nil {
-			log.Error("net.Conn Write error, [%v]", err)
-			return
+			return err
 		}
 	}
+	return nil
 }