|
@@ -17,6 +17,7 @@ package net
|
|
|
import (
|
|
|
"crypto/tls"
|
|
|
"net"
|
|
|
+ "time"
|
|
|
|
|
|
gnet "github.com/fatedier/golib/net"
|
|
|
)
|
|
@@ -31,10 +32,17 @@ func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out Conn) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func CheckAndEnableTLSServerConn(c net.Conn, tlsConfig *tls.Config) (out Conn) {
|
|
|
- sc, r := gnet.NewSharedConnSize(c, 1)
|
|
|
+func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, timeout time.Duration) (out Conn, err error) {
|
|
|
+ sc, r := gnet.NewSharedConnSize(c, 2)
|
|
|
buf := make([]byte, 1)
|
|
|
- n, _ := r.Read(buf)
|
|
|
+ var n int
|
|
|
+ c.SetReadDeadline(time.Now().Add(timeout))
|
|
|
+ n, err = r.Read(buf)
|
|
|
+ c.SetReadDeadline(time.Time{})
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE {
|
|
|
out = WrapConn(tls.Server(c, tlsConfig))
|
|
|
} else {
|