websocket.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. package net
  import (
  	"errors"
  	"fmt"
  	"net"
  	"net/http"
  	"net/url"
  	"sync"
  	"time"

  	"github.com/fatedier/frp/utils/log"
  	"golang.org/x/net/websocket"
  )
  // ErrWebsocketListenerClosed is the error WebsocketListener.Accept reports
  // when the listener can no longer deliver connections because it is closed.
  var ErrWebsocketListenerClosed = errors.New("websocket listener closed")

  // FrpWebsocketPath is the HTTP path used both by the listener to register its
  // websocket handler and by ConnectWebsocketServer to build the dial URL.
  const FrpWebsocketPath = "/~!frp"
  18. type WebsocketListener struct {
  19. net.Addr
  20. ln net.Listener
  21. accept chan Conn
  22. log.Logger
  23. server *http.Server
  24. httpMutex *http.ServeMux
  25. }
  26. // NewWebsocketListener to handle websocket connections
  27. // ln: tcp listener for websocket connections
  28. func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
  29. wl = &WebsocketListener{
  30. Addr: ln.Addr(),
  31. accept: make(chan Conn),
  32. Logger: log.NewPrefixLogger(""),
  33. }
  34. muxer := http.NewServeMux()
  35. muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
  36. notifyCh := make(chan struct{})
  37. conn := WrapCloseNotifyConn(c, func() {
  38. close(notifyCh)
  39. })
  40. wl.accept <- conn
  41. <-notifyCh
  42. }))
  43. wl.server = &http.Server{
  44. Addr: ln.Addr().String(),
  45. Handler: muxer,
  46. }
  47. go wl.server.Serve(ln)
  48. return
  49. }
  50. func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
  51. tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
  52. if err != nil {
  53. return nil, err
  54. }
  55. l := NewWebsocketListener(tcpLn)
  56. return l, nil
  57. }
  58. func (p *WebsocketListener) Accept() (Conn, error) {
  59. c, ok := <-p.accept
  60. if !ok {
  61. return nil, ErrWebsocketListenerClosed
  62. }
  63. return c, nil
  64. }
  65. func (p *WebsocketListener) Close() error {
  66. return p.server.Close()
  67. }
  68. // addr: domain:port
  69. func ConnectWebsocketServer(addr string) (Conn, error) {
  70. addr = "ws://" + addr + FrpWebsocketPath
  71. uri, err := url.Parse(addr)
  72. if err != nil {
  73. return nil, err
  74. }
  75. origin := "http://" + uri.Host
  76. cfg, err := websocket.NewConfig(addr, origin)
  77. if err != nil {
  78. return nil, err
  79. }
  80. cfg.Dialer = &net.Dialer{
  81. Timeout: 10 * time.Second,
  82. }
  83. conn, err := websocket.DialConfig(cfg)
  84. if err != nil {
  85. return nil, err
  86. }
  87. c := WrapConn(conn)
  88. return c, nil
  89. }