websocket.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package net
import (
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"golang.org/x/net/websocket"
)
  11. var (
  12. ErrWebsocketListenerClosed = errors.New("websocket listener closed")
  13. )
  14. const (
  15. FrpWebsocketPath = "/~!frp"
  16. )
  17. type WebsocketListener struct {
  18. ln net.Listener
  19. acceptCh chan net.Conn
  20. server *http.Server
  21. httpMutex *http.ServeMux
  22. }
  23. // NewWebsocketListener to handle websocket connections
  24. // ln: tcp listener for websocket connections
  25. func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
  26. wl = &WebsocketListener{
  27. acceptCh: make(chan net.Conn),
  28. }
  29. muxer := http.NewServeMux()
  30. muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
  31. notifyCh := make(chan struct{})
  32. conn := WrapCloseNotifyConn(c, func() {
  33. close(notifyCh)
  34. })
  35. wl.acceptCh <- conn
  36. <-notifyCh
  37. }))
  38. wl.server = &http.Server{
  39. Addr: ln.Addr().String(),
  40. Handler: muxer,
  41. }
  42. go wl.server.Serve(ln)
  43. return
  44. }
  45. func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
  46. tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
  47. if err != nil {
  48. return nil, err
  49. }
  50. l := NewWebsocketListener(tcpLn)
  51. return l, nil
  52. }
  53. func (p *WebsocketListener) Accept() (net.Conn, error) {
  54. c, ok := <-p.acceptCh
  55. if !ok {
  56. return nil, ErrWebsocketListenerClosed
  57. }
  58. return c, nil
  59. }
  60. func (p *WebsocketListener) Close() error {
  61. return p.server.Close()
  62. }
  63. func (p *WebsocketListener) Addr() net.Addr {
  64. return p.ln.Addr()
  65. }
  66. // addr: domain:port
  67. func ConnectWebsocketServer(addr string) (net.Conn, error) {
  68. addr = "ws://" + addr + FrpWebsocketPath
  69. uri, err := url.Parse(addr)
  70. if err != nil {
  71. return nil, err
  72. }
  73. origin := "http://" + uri.Host
  74. cfg, err := websocket.NewConfig(addr, origin)
  75. if err != nil {
  76. return nil, err
  77. }
  78. cfg.Dialer = &net.Dialer{
  79. Timeout: 10 * time.Second,
  80. }
  81. conn, err := websocket.DialConfig(cfg)
  82. if err != nil {
  83. return nil, err
  84. }
  85. return conn, nil
  86. }