1
0

stcp.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. // Copyright 2017 fatedier, fatedier@gmail.com
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package visitor
  15. import (
  16. "fmt"
  17. "io"
  18. "net"
  19. "strconv"
  20. "time"
  21. libio "github.com/fatedier/golib/io"
  22. v1 "github.com/fatedier/frp/pkg/config/v1"
  23. "github.com/fatedier/frp/pkg/msg"
  24. "github.com/fatedier/frp/pkg/util/util"
  25. "github.com/fatedier/frp/pkg/util/xlog"
  26. )
  27. type STCPVisitor struct {
  28. *BaseVisitor
  29. cfg *v1.STCPVisitorConfig
  30. }
  31. func (sv *STCPVisitor) Run() (err error) {
  32. if sv.cfg.BindPort > 0 {
  33. sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
  34. if err != nil {
  35. return
  36. }
  37. go sv.worker()
  38. }
  39. go sv.internalConnWorker()
  40. if sv.plugin != nil {
  41. sv.plugin.Start()
  42. }
  43. return
  44. }
  45. func (sv *STCPVisitor) Close() {
  46. sv.BaseVisitor.Close()
  47. }
  48. func (sv *STCPVisitor) worker() {
  49. xl := xlog.FromContextSafe(sv.ctx)
  50. for {
  51. conn, err := sv.l.Accept()
  52. if err != nil {
  53. xl.Warnf("stcp local listener closed")
  54. return
  55. }
  56. go sv.handleConn(conn)
  57. }
  58. }
  59. func (sv *STCPVisitor) internalConnWorker() {
  60. xl := xlog.FromContextSafe(sv.ctx)
  61. for {
  62. conn, err := sv.internalLn.Accept()
  63. if err != nil {
  64. xl.Warnf("stcp internal listener closed")
  65. return
  66. }
  67. go sv.handleConn(conn)
  68. }
  69. }
  70. func (sv *STCPVisitor) handleConn(userConn net.Conn) {
  71. xl := xlog.FromContextSafe(sv.ctx)
  72. var tunnelErr error
  73. defer func() {
  74. // If there was an error and connection supports CloseWithError, use it
  75. if tunnelErr != nil {
  76. if eConn, ok := userConn.(interface{ CloseWithError(error) error }); ok {
  77. _ = eConn.CloseWithError(tunnelErr)
  78. return
  79. }
  80. }
  81. userConn.Close()
  82. }()
  83. xl.Debugf("get a new stcp user connection")
  84. visitorConn, err := sv.helper.ConnectServer()
  85. if err != nil {
  86. tunnelErr = err
  87. return
  88. }
  89. defer visitorConn.Close()
  90. now := time.Now().Unix()
  91. newVisitorConnMsg := &msg.NewVisitorConn{
  92. RunID: sv.helper.RunID(),
  93. ProxyName: sv.cfg.ServerName,
  94. SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
  95. Timestamp: now,
  96. UseEncryption: sv.cfg.Transport.UseEncryption,
  97. UseCompression: sv.cfg.Transport.UseCompression,
  98. }
  99. err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
  100. if err != nil {
  101. xl.Warnf("send newVisitorConnMsg to server error: %v", err)
  102. tunnelErr = err
  103. return
  104. }
  105. var newVisitorConnRespMsg msg.NewVisitorConnResp
  106. _ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
  107. err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
  108. if err != nil {
  109. xl.Warnf("get newVisitorConnRespMsg error: %v", err)
  110. tunnelErr = err
  111. return
  112. }
  113. _ = visitorConn.SetReadDeadline(time.Time{})
  114. if newVisitorConnRespMsg.Error != "" {
  115. xl.Warnf("start new visitor connection error: %s", newVisitorConnRespMsg.Error)
  116. tunnelErr = fmt.Errorf("%s", newVisitorConnRespMsg.Error)
  117. return
  118. }
  119. var remote io.ReadWriteCloser
  120. remote = visitorConn
  121. if sv.cfg.Transport.UseEncryption {
  122. remote, err = libio.WithEncryption(remote, []byte(sv.cfg.SecretKey))
  123. if err != nil {
  124. xl.Errorf("create encryption stream error: %v", err)
  125. tunnelErr = err
  126. return
  127. }
  128. }
  129. if sv.cfg.Transport.UseCompression {
  130. var recycleFn func()
  131. remote, recycleFn = libio.WithCompressionFromPool(remote)
  132. defer recycleFn()
  133. }
  134. libio.Join(userConn, remote)
  135. }