|
@@ -16,6 +16,7 @@ package net
|
|
|
|
|
|
import (
|
|
|
"crypto/tls"
|
|
|
+ "fmt"
|
|
|
"net"
|
|
|
"time"
|
|
|
|
|
@@ -32,7 +33,7 @@ func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out net.Conn) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, timeout time.Duration) (out net.Conn, err error) {
|
|
|
+func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, tlsOnly bool, timeout time.Duration) (out net.Conn, err error) {
|
|
|
sc, r := gnet.NewSharedConnSize(c, 2)
|
|
|
buf := make([]byte, 1)
|
|
|
var n int
|
|
@@ -46,6 +47,10 @@ func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, t
|
|
|
if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE {
|
|
|
out = tls.Server(c, tlsConfig)
|
|
|
} else {
|
|
|
+ if tlsOnly {
|
|
|
+ err = fmt.Errorf("non-TLS connection received on a TlsOnly server")
|
|
|
+ return
|
|
|
+ }
|
|
|
out = sc
|
|
|
}
|
|
|
return
|