service.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. package ssh
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "net"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. gerror "github.com/fatedier/golib/errors"
  15. "golang.org/x/crypto/ssh"
  16. v1 "github.com/fatedier/frp/pkg/config/v1"
  17. "github.com/fatedier/frp/pkg/util/log"
  18. )
const (
	// SSH protocol identifiers.
	// https://datatracker.ietf.org/doc/html/rfc4254#page-16

	// ChannelTypeServerOpenChannel is the SSH channel type name "forwarded-tcpip".
	ChannelTypeServerOpenChannel = "forwarded-tcpip"
	// RequestTypeForward is the global request type carrying a port-forward
	// request; loopParseCmdPayload decodes its payload into a CmdPayload.
	RequestTypeForward = "tcpip-forward"

	// Identifiers used by the Go ssh package.
	// https://pkg.go.dev/golang.org/x/crypto/ssh

	// RequestTypeHeartbeat is the global request type that
	// loopParseCmdPayload treats as a keepalive and only logs at debug level.
	RequestTypeHeartbeat = "keepalive@openssh.com"
)
// VProxyError is the error value reported when a proxy fails.
// loopReply closes the service when it receives a *VProxyError on replyCh.
type VProxyError struct{}
// forwardedTCPPayload follows the SSH protocol definition, see
// https://datatracker.ietf.org/doc/html/rfc4254#page-16
//
// NOTE(review): presumably the payload of a "forwarded-tcpip" channel open;
// the type is not referenced in the visible part of this file — confirm.
type forwardedTCPPayload struct {
	Addr string
	Port uint32
	// OriginAddr and OriginPort may be left at their zero values, but do not
	// delete them: the SSH protocol reserves these fields.
	OriginAddr string
	OriginPort uint32
}
// CmdPayload is a custom type holding the address and port requested by the
// SSH client. loopParseCmdPayload fills it by running ssh.Unmarshal on the
// payload of a RequestTypeForward global request, so the field order matters.
type CmdPayload struct {
	Address string
	Port    uint32
}
// ExtraPayload is a custom type carrying the frp control command that the
// client sent on an SSH channel, as produced by parseSSHExtraMessage.
type ExtraPayload struct {
	// Type is the proxy type; parseSSHExtraMessage sets it to "tcp" or "http".
	Type string

	// TODO port can be set by extra message and priority to ssh raw cmd
	Address string
	Port    uint32
}
// Service handles one SSH client connection: it reads the client's forward
// request and extra command, and emits a proxy configuration built from both.
type Service struct {
	tcpConn net.Conn          // connection handed to NewSSHService
	cfg     *ssh.ServerConfig // server config used for the handshake

	sshConn  *ssh.ServerConn       // result of the handshake on tcpConn
	gChannel <-chan ssh.NewChannel // new-channel requests from the handshake
	gReq     <-chan *ssh.Request   // global requests from the handshake

	addrPayloadCh  chan CmdPayload          // loopParseCmdPayload -> loopGenerateProxy
	extraPayloadCh chan ExtraPayload        // loopParseExtraPayload -> loopGenerateProxy
	proxyPayloadCh chan v1.ProxyConfigurer  // generated proxy configs, supplied by the caller
	replyCh        chan interface{}         // messages consumed by loopReply, supplied by the caller

	closeCh chan struct{} // closed by Close; returned by Exit
	exit    int32         // set to 1 by Close; accessed with sync/atomic
}
  68. func NewSSHService(
  69. tcpConn net.Conn,
  70. cfg *ssh.ServerConfig,
  71. proxyPayloadCh chan v1.ProxyConfigurer,
  72. replyCh chan interface{},
  73. ) (ss *Service, err error) {
  74. ss = &Service{
  75. tcpConn: tcpConn,
  76. cfg: cfg,
  77. addrPayloadCh: make(chan CmdPayload),
  78. extraPayloadCh: make(chan ExtraPayload),
  79. proxyPayloadCh: proxyPayloadCh,
  80. replyCh: replyCh,
  81. closeCh: make(chan struct{}),
  82. exit: 0,
  83. }
  84. ss.sshConn, ss.gChannel, ss.gReq, err = ssh.NewServerConn(tcpConn, cfg)
  85. if err != nil {
  86. log.Error("ssh handshake error: %v", err)
  87. return nil, err
  88. }
  89. log.Info("ssh connection success")
  90. return ss, nil
  91. }
  92. func (ss *Service) Run() {
  93. go ss.loopGenerateProxy()
  94. go ss.loopParseCmdPayload()
  95. go ss.loopParseExtraPayload()
  96. go ss.loopReply()
  97. }
// Exit returns the service's close channel; Close closes it, so a receive
// from the returned channel unblocks once the service has been closed.
func (ss *Service) Exit() <-chan struct{} {
	return ss.closeCh
}
  101. func (ss *Service) Close() {
  102. if atomic.LoadInt32(&ss.exit) == 1 {
  103. return
  104. }
  105. select {
  106. case <-ss.closeCh:
  107. return
  108. default:
  109. }
  110. close(ss.closeCh)
  111. close(ss.addrPayloadCh)
  112. close(ss.extraPayloadCh)
  113. _ = ss.sshConn.Wait()
  114. ss.sshConn.Close()
  115. ss.tcpConn.Close()
  116. atomic.StoreInt32(&ss.exit, 1)
  117. log.Info("ssh service close")
  118. }
  119. func (ss *Service) loopParseCmdPayload() {
  120. for {
  121. select {
  122. case req, ok := <-ss.gReq:
  123. if !ok {
  124. log.Info("global request is close")
  125. ss.Close()
  126. return
  127. }
  128. switch req.Type {
  129. case RequestTypeForward:
  130. var addrPayload CmdPayload
  131. if err := ssh.Unmarshal(req.Payload, &addrPayload); err != nil {
  132. log.Error("ssh unmarshal error: %v", err)
  133. return
  134. }
  135. _ = gerror.PanicToError(func() {
  136. ss.addrPayloadCh <- addrPayload
  137. })
  138. default:
  139. if req.Type == RequestTypeHeartbeat {
  140. log.Debug("ssh heartbeat data")
  141. } else {
  142. log.Info("default req, data: %v", req)
  143. }
  144. }
  145. if req.WantReply {
  146. err := req.Reply(true, nil)
  147. if err != nil {
  148. log.Error("reply to ssh client error: %v", err)
  149. }
  150. }
  151. case <-ss.closeCh:
  152. log.Info("loop parse cmd payload close")
  153. return
  154. }
  155. }
  156. }
  157. func (ss *Service) loopSendHeartbeat(ch ssh.Channel) {
  158. tk := time.NewTicker(time.Second * 60)
  159. defer tk.Stop()
  160. for {
  161. select {
  162. case <-tk.C:
  163. ok, err := ch.SendRequest("heartbeat", false, nil)
  164. if err != nil {
  165. log.Error("channel send req error: %v", err)
  166. if err == io.EOF {
  167. ss.Close()
  168. return
  169. }
  170. continue
  171. }
  172. log.Debug("heartbeat send success, ok: %v", ok)
  173. case <-ss.closeCh:
  174. return
  175. }
  176. }
  177. }
  178. func (ss *Service) loopParseExtraPayload() {
  179. log.Info("loop parse extra payload start")
  180. for newChannel := range ss.gChannel {
  181. ch, req, err := newChannel.Accept()
  182. if err != nil {
  183. log.Error("channel accept error: %v", err)
  184. return
  185. }
  186. go ss.loopSendHeartbeat(ch)
  187. go func(req <-chan *ssh.Request) {
  188. for r := range req {
  189. if len(r.Payload) <= 4 {
  190. log.Info("r.payload is less than 4")
  191. continue
  192. }
  193. if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") {
  194. log.Info("ssh protocol exchange data")
  195. continue
  196. }
  197. // [4byte data_len|data]
  198. end := 4 + binary.BigEndian.Uint32(r.Payload[:4])
  199. if end > uint32(len(r.Payload)) {
  200. end = uint32(len(r.Payload))
  201. }
  202. p := string(r.Payload[4:end])
  203. msg, err := parseSSHExtraMessage(p)
  204. if err != nil {
  205. log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload)
  206. continue
  207. }
  208. _ = gerror.PanicToError(func() {
  209. ss.extraPayloadCh <- msg
  210. })
  211. return
  212. }
  213. }(req)
  214. }
  215. }
// SSHConn returns the SSH server connection established by NewSSHService.
func (ss *Service) SSHConn() *ssh.ServerConn {
	return ss.sshConn
}
// TCPConn returns the network connection that was passed to NewSSHService.
func (ss *Service) TCPConn() net.Conn {
	return ss.tcpConn
}
  222. func (ss *Service) loopReply() {
  223. for {
  224. select {
  225. case <-ss.closeCh:
  226. log.Info("loop reply close")
  227. return
  228. case req := <-ss.replyCh:
  229. switch req.(type) {
  230. case *VProxyError:
  231. log.Error("run frp proxy error, close ssh service")
  232. ss.Close()
  233. default:
  234. // TODO
  235. }
  236. }
  237. }
  238. }
  239. func (ss *Service) loopGenerateProxy() {
  240. log.Info("loop generate proxy start")
  241. for {
  242. if atomic.LoadInt32(&ss.exit) == 1 {
  243. return
  244. }
  245. wg := new(sync.WaitGroup)
  246. wg.Add(2)
  247. var p1 CmdPayload
  248. var p2 ExtraPayload
  249. go func() {
  250. defer wg.Done()
  251. for {
  252. select {
  253. case <-ss.closeCh:
  254. return
  255. case p1 = <-ss.addrPayloadCh:
  256. return
  257. }
  258. }
  259. }()
  260. go func() {
  261. defer wg.Done()
  262. for {
  263. select {
  264. case <-ss.closeCh:
  265. return
  266. case p2 = <-ss.extraPayloadCh:
  267. return
  268. }
  269. }
  270. }()
  271. wg.Wait()
  272. if atomic.LoadInt32(&ss.exit) == 1 {
  273. return
  274. }
  275. switch p2.Type {
  276. case "http":
  277. case "tcp":
  278. ss.proxyPayloadCh <- &v1.TCPProxyConfig{
  279. ProxyBaseConfig: v1.ProxyBaseConfig{
  280. Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()),
  281. Type: p2.Type,
  282. ProxyBackend: v1.ProxyBackend{
  283. LocalIP: p1.Address,
  284. },
  285. },
  286. RemotePort: int(p1.Port),
  287. }
  288. default:
  289. log.Warn("invalid frp proxy type: %v", p2.Type)
  290. }
  291. }
  292. }
  293. func parseSSHExtraMessage(s string) (p ExtraPayload, err error) {
  294. sn := len(s)
  295. log.Info("parse ssh extra message: %v", s)
  296. ss := strings.Fields(s)
  297. if len(ss) == 0 {
  298. if sn != 0 {
  299. ss = append(ss, s)
  300. } else {
  301. return p, fmt.Errorf("invalid ssh input, args: %v", ss)
  302. }
  303. }
  304. for i, v := range ss {
  305. ss[i] = strings.TrimSpace(v)
  306. }
  307. if ss[0] != "tcp" && ss[0] != "http" {
  308. return p, fmt.Errorf("only support tcp/http now")
  309. }
  310. switch ss[0] {
  311. case "tcp":
  312. tcpCmd, err := ParseTCPCommand(ss)
  313. if err != nil {
  314. return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err)
  315. }
  316. port, _ := strconv.Atoi(tcpCmd.Port)
  317. p = ExtraPayload{
  318. Type: "tcp",
  319. Address: tcpCmd.Address,
  320. Port: uint32(port),
  321. }
  322. case "http":
  323. httpCmd, err := ParseHTTPCommand(ss)
  324. if err != nil {
  325. return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err)
  326. }
  327. _ = httpCmd
  328. p = ExtraPayload{
  329. Type: "http",
  330. }
  331. }
  332. return p, nil
  333. }
// HTTPCommand holds the options parsed from an "http" command by
// ParseHTTPCommand.
type HTTPCommand struct {
	Domain        string // value of the -domain flag
	BasicAuthUser string // part of -basic-auth before the first ':'
	BasicAuthPass string // part of -basic-auth after the first ':', if any
}
  339. func ParseHTTPCommand(params []string) (*HTTPCommand, error) {
  340. if len(params) < 2 {
  341. return nil, errors.New("invalid HTTP command")
  342. }
  343. var (
  344. basicAuth string
  345. domainURL string
  346. basicAuthUser string
  347. basicAuthPass string
  348. )
  349. fs := flag.NewFlagSet("http", flag.ContinueOnError)
  350. fs.StringVar(&basicAuth, "basic-auth", "", "")
  351. fs.StringVar(&domainURL, "domain", "", "")
  352. fs.SetOutput(&nullWriter{}) // Disables usage output
  353. err := fs.Parse(params[2:])
  354. if err != nil {
  355. if !errors.Is(err, flag.ErrHelp) {
  356. return nil, err
  357. }
  358. }
  359. if basicAuth != "" {
  360. authParts := strings.SplitN(basicAuth, ":", 2)
  361. basicAuthUser = authParts[0]
  362. if len(authParts) > 1 {
  363. basicAuthPass = authParts[1]
  364. }
  365. }
  366. httpCmd := &HTTPCommand{
  367. Domain: domainURL,
  368. BasicAuthUser: basicAuthUser,
  369. BasicAuthPass: basicAuthPass,
  370. }
  371. return httpCmd, nil
  372. }
  373. type TCPCommand struct {
  374. Address string
  375. Port string
  376. }
  377. func ParseTCPCommand(params []string) (*TCPCommand, error) {
  378. if len(params) == 0 || params[0] != "tcp" {
  379. return nil, errors.New("invalid TCP command")
  380. }
  381. if len(params) == 1 {
  382. return &TCPCommand{}, nil
  383. }
  384. var (
  385. address string
  386. port string
  387. )
  388. fs := flag.NewFlagSet("tcp", flag.ContinueOnError)
  389. fs.StringVar(&address, "address", "", "The IP address to listen on")
  390. fs.StringVar(&port, "port", "", "The port to listen on")
  391. fs.SetOutput(&nullWriter{}) // Disables usage output
  392. args := params[1:]
  393. err := fs.Parse(args)
  394. if err != nil {
  395. if !errors.Is(err, flag.ErrHelp) {
  396. return nil, err
  397. }
  398. }
  399. parsedAddr, err := net.ResolveIPAddr("ip", address)
  400. if err != nil {
  401. return nil, err
  402. }
  403. if _, err := net.LookupPort("tcp", port); err != nil {
  404. return nil, err
  405. }
  406. tcpCmd := &TCPCommand{
  407. Address: parsedAddr.String(),
  408. Port: port,
  409. }
  410. return tcpCmd, nil
  411. }
  412. type nullWriter struct{}
  413. func (w *nullWriter) Write(p []byte) (n int, err error) { return len(p), nil }