1
0

websocket.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. package net
  import (
  	"errors"
  	"net"
  	"net/http"
  	"strconv"
  	"sync"
  	"time"

  	"golang.org/x/net/websocket"
  )
  // ErrWebsocketListenerClosed signals that a WebsocketListener has been closed
  // and will accept no more connections.
  var ErrWebsocketListenerClosed = errors.New("websocket listener closed")

  // FrpWebsocketPath is the HTTP path on which WebsocketListener accepts
  // WebSocket upgrades.
  const FrpWebsocketPath = "/~!frp"
  13. type WebsocketListener struct {
  14. ln net.Listener
  15. acceptCh chan net.Conn
  16. server *http.Server
  17. }
  18. // NewWebsocketListener to handle websocket connections
  19. // ln: tcp listener for websocket connections
  20. func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
  21. wl = &WebsocketListener{
  22. acceptCh: make(chan net.Conn),
  23. }
  24. muxer := http.NewServeMux()
  25. muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) {
  26. notifyCh := make(chan struct{})
  27. conn := WrapCloseNotifyConn(c, func() {
  28. close(notifyCh)
  29. })
  30. wl.acceptCh <- conn
  31. <-notifyCh
  32. }))
  33. wl.server = &http.Server{
  34. Addr: ln.Addr().String(),
  35. Handler: muxer,
  36. }
  37. go func() {
  38. _ = wl.server.Serve(ln)
  39. }()
  40. return
  41. }
  42. func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
  43. tcpLn, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
  44. if err != nil {
  45. return nil, err
  46. }
  47. l := NewWebsocketListener(tcpLn)
  48. return l, nil
  49. }
  50. func (p *WebsocketListener) Accept() (net.Conn, error) {
  51. c, ok := <-p.acceptCh
  52. if !ok {
  53. return nil, ErrWebsocketListenerClosed
  54. }
  55. return c, nil
  56. }
  57. func (p *WebsocketListener) Close() error {
  58. return p.server.Close()
  59. }
  60. func (p *WebsocketListener) Addr() net.Addr {
  61. return p.ln.Addr()
  62. }