stcp.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. // Copyright 2017 fatedier, fatedier@gmail.com
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package visitor
  15. import (
  16. "io"
  17. "net"
  18. "strconv"
  19. "time"
  20. libio "github.com/fatedier/golib/io"
  21. v1 "github.com/fatedier/frp/pkg/config/v1"
  22. "github.com/fatedier/frp/pkg/msg"
  23. "github.com/fatedier/frp/pkg/util/util"
  24. "github.com/fatedier/frp/pkg/util/xlog"
  25. )
  26. type STCPVisitor struct {
  27. *BaseVisitor
  28. cfg *v1.STCPVisitorConfig
  29. }
  30. func (sv *STCPVisitor) Run() (err error) {
  31. if sv.cfg.BindPort > 0 {
  32. sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
  33. if err != nil {
  34. return
  35. }
  36. go sv.worker()
  37. }
  38. go sv.internalConnWorker()
  39. if sv.plugin != nil {
  40. sv.plugin.Start()
  41. }
  42. return
  43. }
  44. func (sv *STCPVisitor) Close() {
  45. sv.BaseVisitor.Close()
  46. }
  47. func (sv *STCPVisitor) worker() {
  48. xl := xlog.FromContextSafe(sv.ctx)
  49. for {
  50. conn, err := sv.l.Accept()
  51. if err != nil {
  52. xl.Warnf("stcp local listener closed")
  53. return
  54. }
  55. go sv.handleConn(conn)
  56. }
  57. }
  58. func (sv *STCPVisitor) internalConnWorker() {
  59. xl := xlog.FromContextSafe(sv.ctx)
  60. for {
  61. conn, err := sv.internalLn.Accept()
  62. if err != nil {
  63. xl.Warnf("stcp internal listener closed")
  64. return
  65. }
  66. go sv.handleConn(conn)
  67. }
  68. }
  69. func (sv *STCPVisitor) handleConn(userConn net.Conn) {
  70. xl := xlog.FromContextSafe(sv.ctx)
  71. defer userConn.Close()
  72. xl.Debugf("get a new stcp user connection")
  73. visitorConn, err := sv.helper.ConnectServer()
  74. if err != nil {
  75. return
  76. }
  77. defer visitorConn.Close()
  78. now := time.Now().Unix()
  79. newVisitorConnMsg := &msg.NewVisitorConn{
  80. RunID: sv.helper.RunID(),
  81. ProxyName: sv.cfg.ServerName,
  82. SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
  83. Timestamp: now,
  84. UseEncryption: sv.cfg.Transport.UseEncryption,
  85. UseCompression: sv.cfg.Transport.UseCompression,
  86. }
  87. err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
  88. if err != nil {
  89. xl.Warnf("send newVisitorConnMsg to server error: %v", err)
  90. return
  91. }
  92. var newVisitorConnRespMsg msg.NewVisitorConnResp
  93. _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
  94. err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
  95. if err != nil {
  96. xl.Warnf("get newVisitorConnRespMsg error: %v", err)
  97. return
  98. }
  99. _ = visitorConn.SetReadDeadline(time.Time{})
  100. if newVisitorConnRespMsg.Error != "" {
  101. xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
  102. return
  103. }
  104. var remote io.ReadWriteCloser
  105. remote = visitorConn
  106. if sv.cfg.Transport.UseEncryption {
  107. remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
  108. if err != nil {
  109. xl.Errorf("create encryption stream error: %v", err)
  110. return
  111. }
  112. }
  113. if sv.cfg.Transport.UseCompression {
  114. var recycleFn func()
  115. remote, recycleFn = libio.WithCompressionFromPool(remote)
  116. defer recycleFn()
  117. }
  118. libio.Join(userConn, remote)
  119. }