123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- package net
- import (
- "fmt"
- "net"
- "net/http"
- "net/url"
- "sync/atomic"
- "time"
- "github.com/fatedier/frp/utils/log"
- "golang.org/x/net/websocket"
- )
- type WebsocketListener struct {
- log.Logger
- server *http.Server
- httpMutex *http.ServeMux
- connChan chan *WebsocketConn
- closeFlag bool
- }
- func NewWebsocketListener(ln net.Listener,
- filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) {
- l = &WebsocketListener{
- httpMutex: http.NewServeMux(),
- connChan: make(chan *WebsocketConn),
- Logger: log.NewPrefixLogger(""),
- }
- l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) {
- conn := NewWebScoketConn(c)
- l.connChan <- conn
- conn.waitClose()
- }))
- l.server = &http.Server{
- Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if filter != nil && !filter(w, r) {
- return
- }
- l.httpMutex.ServeHTTP(w, r)
- }),
- }
- ch := make(chan struct{})
- go func() {
- close(ch)
- err = l.server.Serve(ln)
- }()
- <-ch
- <-time.After(time.Millisecond)
- return
- }
- func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) {
- ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
- if err != nil {
- return
- }
- l, err = NewWebsocketListener(ln, nil)
- return
- }
- func (p *WebsocketListener) Accept() (Conn, error) {
- c := <-p.connChan
- return c, nil
- }
- func (p *WebsocketListener) Close() error {
- if !p.closeFlag {
- p.closeFlag = true
- p.server.Close()
- }
- return nil
- }
- type WebsocketConn struct {
- net.Conn
- log.Logger
- closed int32
- wait chan struct{}
- }
- func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) {
- c = &WebsocketConn{
- Conn: conn,
- Logger: log.NewPrefixLogger(""),
- wait: make(chan struct{}),
- }
- return
- }
- func (p *WebsocketConn) Close() error {
- if atomic.SwapInt32(&p.closed, 1) == 1 {
- return nil
- }
- close(p.wait)
- return p.Conn.Close()
- }
- func (p *WebsocketConn) waitClose() {
- <-p.wait
- }
- // ConnectWebsocketServer :
- // addr: ws://domain:port
- func ConnectWebsocketServer(addr string) (c Conn, err error) {
- addr = "ws://" + addr
- uri, err := url.Parse(addr)
- if err != nil {
- return
- }
- origin := "http://" + uri.Host
- cfg, err := websocket.NewConfig(addr, origin)
- if err != nil {
- return
- }
- cfg.Dialer = &net.Dialer{
- Timeout: time.Second * 10,
- }
- conn, err := websocket.DialConfig(cfg)
- if err != nil {
- return
- }
- c = NewWebScoketConn(conn)
- return
- }
|