|
@@ -1,127 +1,105 @@
|
|
|
package net
|
|
|
|
|
|
import (
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/url"
|
|
|
- "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"github.com/fatedier/frp/utils/log"
|
|
|
+
|
|
|
"golang.org/x/net/websocket"
|
|
|
)
|
|
|
|
|
|
+var (
|
|
|
+ ErrWebsocketListenerClosed = errors.New("websocket listener closed")
|
|
|
+)
|
|
|
+
|
|
|
+const (
|
|
|
+ FrpWebsocketPath = "/#frp"
|
|
|
+)
|
|
|
+
|
|
|
type WebsocketListener struct {
|
|
|
+ net.Addr
|
|
|
+ ln net.Listener
|
|
|
+ accept chan Conn
|
|
|
log.Logger
|
|
|
+
|
|
|
server *http.Server
|
|
|
httpMutex *http.ServeMux
|
|
|
- connChan chan *WebsocketConn
|
|
|
- closeFlag bool
|
|
|
}
|
|
|
|
|
|
-func NewWebsocketListener(ln net.Listener,
|
|
|
- filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) {
|
|
|
- l = &WebsocketListener{
|
|
|
- httpMutex: http.NewServeMux(),
|
|
|
- connChan: make(chan *WebsocketConn),
|
|
|
- Logger: log.NewPrefixLogger(""),
|
|
|
+// ln: tcp listener for websocket connections
|
|
|
+func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
|
|
|
+ wl = &WebsocketListener{
|
|
|
+ Addr: ln.Addr(),
|
|
|
+ accept: make(chan Conn),
|
|
|
+ Logger: log.NewPrefixLogger(""),
|
|
|
}
|
|
|
- l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) {
|
|
|
- conn := NewWebScoketConn(c)
|
|
|
- l.connChan <- conn
|
|
|
- conn.waitClose()
|
|
|
+
|
|
|
+ muxer := http.NewServeMux()
|
|
|
+ muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
|
|
|
+ notifyCh := make(chan struct{})
|
|
|
+ conn := WrapCloseNotifyConn(c, func() {
|
|
|
+ close(notifyCh)
|
|
|
+ })
|
|
|
+ wl.accept <- conn
|
|
|
+ <-notifyCh
|
|
|
}))
|
|
|
- l.server = &http.Server{
|
|
|
- Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
- if filter != nil && !filter(w, r) {
|
|
|
- return
|
|
|
- }
|
|
|
- l.httpMutex.ServeHTTP(w, r)
|
|
|
- }),
|
|
|
+
|
|
|
+ wl.server = &http.Server{
|
|
|
+ Addr: ln.Addr().String(),
|
|
|
+ Handler: muxer,
|
|
|
}
|
|
|
- ch := make(chan struct{})
|
|
|
- go func() {
|
|
|
- close(ch)
|
|
|
- err = l.server.Serve(ln)
|
|
|
- }()
|
|
|
- <-ch
|
|
|
- <-time.After(time.Millisecond)
|
|
|
+
|
|
|
+ go wl.server.Serve(ln)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) {
|
|
|
- ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
|
|
|
+func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
|
|
|
+ tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
|
|
|
if err != nil {
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- l, err = NewWebsocketListener(ln, nil)
|
|
|
- return
|
|
|
+ l := NewWebsocketListener(tcpLn)
|
|
|
+ return l, nil
|
|
|
}
|
|
|
|
|
|
func (p *WebsocketListener) Accept() (Conn, error) {
|
|
|
- c := <-p.connChan
|
|
|
+ c, ok := <-p.accept
|
|
|
+ if !ok {
|
|
|
+ return nil, ErrWebsocketListenerClosed
|
|
|
+ }
|
|
|
return c, nil
|
|
|
}
|
|
|
|
|
|
func (p *WebsocketListener) Close() error {
|
|
|
- if !p.closeFlag {
|
|
|
- p.closeFlag = true
|
|
|
- p.server.Close()
|
|
|
- }
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-type WebsocketConn struct {
|
|
|
- net.Conn
|
|
|
- log.Logger
|
|
|
- closed int32
|
|
|
- wait chan struct{}
|
|
|
-}
|
|
|
-
|
|
|
-func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) {
|
|
|
- c = &WebsocketConn{
|
|
|
- Conn: conn,
|
|
|
- Logger: log.NewPrefixLogger(""),
|
|
|
- wait: make(chan struct{}),
|
|
|
- }
|
|
|
- return
|
|
|
+ return p.server.Close()
|
|
|
}
|
|
|
|
|
|
-func (p *WebsocketConn) Close() error {
|
|
|
- if atomic.SwapInt32(&p.closed, 1) == 1 {
|
|
|
- return nil
|
|
|
- }
|
|
|
- close(p.wait)
|
|
|
- return p.Conn.Close()
|
|
|
-}
|
|
|
-
|
|
|
-func (p *WebsocketConn) waitClose() {
|
|
|
- <-p.wait
|
|
|
-}
|
|
|
-
|
|
|
-// ConnectWebsocketServer :
|
|
|
-// addr: ws://domain:port
|
|
|
-func ConnectWebsocketServer(addr string) (c Conn, err error) {
|
|
|
- addr = "ws://" + addr
|
|
|
+// addr: domain:port
|
|
|
+func ConnectWebsocketServer(addr string) (Conn, error) {
|
|
|
+ addr = "ws://" + addr + FrpWebsocketPath
|
|
|
uri, err := url.Parse(addr)
|
|
|
if err != nil {
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
origin := "http://" + uri.Host
|
|
|
cfg, err := websocket.NewConfig(addr, origin)
|
|
|
if err != nil {
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
cfg.Dialer = &net.Dialer{
|
|
|
- Timeout: time.Second * 10,
|
|
|
+ Timeout: 10 * time.Second,
|
|
|
}
|
|
|
|
|
|
conn, err := websocket.DialConfig(cfg)
|
|
|
if err != nil {
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- c = NewWebScoketConn(conn)
|
|
|
- return
|
|
|
+ c := WrapConn(conn)
|
|
|
+ return c, nil
|
|
|
}
|