1
0

websocket.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package net
  import (
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/fatedier/frp/utils/log"
	"golang.org/x/net/websocket"
)
  // ErrWebsocketListenerClosed is returned by WebsocketListener.Accept when the
// listener has been shut down and no further connections will be delivered.
var ErrWebsocketListenerClosed = errors.New("websocket listener closed")

// FrpWebsocketPath is the HTTP path on which frp websocket connections are
// served (NewWebsocketListener) and dialed (ConnectWebsocketServer).
//
// NOTE(review): '#' normally introduces a URL fragment; this value only works
// while client and server encode it the same way. Keep it in sync with peers
// before changing it.
const FrpWebsocketPath = "/#frp"
  18. type WebsocketListener struct {
  19. net.Addr
  20. ln net.Listener
  21. accept chan Conn
  22. log.Logger
  23. server *http.Server
  24. httpMutex *http.ServeMux
  25. }
  26. // ln: tcp listener for websocket connections
  27. func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
  28. wl = &WebsocketListener{
  29. Addr: ln.Addr(),
  30. accept: make(chan Conn),
  31. Logger: log.NewPrefixLogger(""),
  32. }
  33. muxer := http.NewServeMux()
  34. muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
  35. notifyCh := make(chan struct{})
  36. conn := WrapCloseNotifyConn(c, func() {
  37. close(notifyCh)
  38. })
  39. wl.accept <- conn
  40. <-notifyCh
  41. }))
  42. wl.server = &http.Server{
  43. Addr: ln.Addr().String(),
  44. Handler: muxer,
  45. }
  46. go wl.server.Serve(ln)
  47. return
  48. }
  49. func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
  50. tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
  51. if err != nil {
  52. return nil, err
  53. }
  54. l := NewWebsocketListener(tcpLn)
  55. return l, nil
  56. }
  57. func (p *WebsocketListener) Accept() (Conn, error) {
  58. c, ok := <-p.accept
  59. if !ok {
  60. return nil, ErrWebsocketListenerClosed
  61. }
  62. return c, nil
  63. }
  64. func (p *WebsocketListener) Close() error {
  65. return p.server.Close()
  66. }
  67. // addr: domain:port
  68. func ConnectWebsocketServer(addr string) (Conn, error) {
  69. addr = "ws://" + addr + FrpWebsocketPath
  70. uri, err := url.Parse(addr)
  71. if err != nil {
  72. return nil, err
  73. }
  74. origin := "http://" + uri.Host
  75. cfg, err := websocket.NewConfig(addr, origin)
  76. if err != nil {
  77. return nil, err
  78. }
  79. cfg.Dialer = &net.Dialer{
  80. Timeout: 10 * time.Second,
  81. }
  82. conn, err := websocket.DialConfig(cfg)
  83. if err != nil {
  84. return nil, err
  85. }
  86. c := WrapConn(conn)
  87. return c, nil
  88. }