websocket.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package net
  2. import (
  3. "fmt"
  4. "net"
  5. "net/http"
  6. "net/url"
  7. "sync/atomic"
  8. "time"
  9. "github.com/fatedier/frp/utils/log"
  10. "golang.org/x/net/websocket"
  11. )
  12. type WebsocketListener struct {
  13. log.Logger
  14. server *http.Server
  15. httpMutex *http.ServeMux
  16. connChan chan *WebsocketConn
  17. closeFlag bool
  18. }
  19. func NewWebsocketListener(ln net.Listener,
  20. filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) {
  21. l = &WebsocketListener{
  22. httpMutex: http.NewServeMux(),
  23. connChan: make(chan *WebsocketConn),
  24. Logger: log.NewPrefixLogger(""),
  25. }
  26. l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) {
  27. conn := NewWebScoketConn(c)
  28. l.connChan <- conn
  29. conn.waitClose()
  30. }))
  31. l.server = &http.Server{
  32. Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  33. if filter != nil && !filter(w, r) {
  34. return
  35. }
  36. l.httpMutex.ServeHTTP(w, r)
  37. }),
  38. }
  39. ch := make(chan struct{})
  40. go func() {
  41. close(ch)
  42. err = l.server.Serve(ln)
  43. }()
  44. <-ch
  45. <-time.After(time.Millisecond)
  46. return
  47. }
  48. func ListenWebsocket(bindAddr string, bindPort int) (l *WebsocketListener, err error) {
  49. ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
  50. if err != nil {
  51. return
  52. }
  53. l, err = NewWebsocketListener(ln, nil)
  54. return
  55. }
  56. func (p *WebsocketListener) Accept() (Conn, error) {
  57. c := <-p.connChan
  58. return c, nil
  59. }
  60. func (p *WebsocketListener) Close() error {
  61. if !p.closeFlag {
  62. p.closeFlag = true
  63. p.server.Close()
  64. }
  65. return nil
  66. }
  67. type WebsocketConn struct {
  68. net.Conn
  69. log.Logger
  70. closed int32
  71. wait chan struct{}
  72. }
  73. func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) {
  74. c = &WebsocketConn{
  75. Conn: conn,
  76. Logger: log.NewPrefixLogger(""),
  77. wait: make(chan struct{}),
  78. }
  79. return
  80. }
  81. func (p *WebsocketConn) Close() error {
  82. if atomic.SwapInt32(&p.closed, 1) == 1 {
  83. return nil
  84. }
  85. close(p.wait)
  86. return p.Conn.Close()
  87. }
  88. func (p *WebsocketConn) waitClose() {
  89. <-p.wait
  90. }
  91. // ConnectWebsocketServer :
  92. // addr: ws://domain:port
  93. func ConnectWebsocketServer(addr string) (c Conn, err error) {
  94. addr = "ws://" + addr
  95. uri, err := url.Parse(addr)
  96. if err != nil {
  97. return
  98. }
  99. origin := "http://" + uri.Host
  100. cfg, err := websocket.NewConfig(addr, origin)
  101. if err != nil {
  102. return
  103. }
  104. cfg.Dialer = &net.Dialer{
  105. Timeout: time.Second * 10,
  106. }
  107. conn, err := websocket.DialConfig(cfg)
  108. if err != nil {
  109. return
  110. }
  111. c = NewWebScoketConn(conn)
  112. return
  113. }