|
@@ -21,6 +21,7 @@ import (
|
|
"io"
|
|
"io"
|
|
"net"
|
|
"net"
|
|
"sync"
|
|
"sync"
|
|
|
|
+ "sync/atomic"
|
|
"time"
|
|
"time"
|
|
|
|
|
|
"github.com/fatedier/frp/utils/log"
|
|
"github.com/fatedier/frp/utils/log"
|
|
@@ -178,6 +179,7 @@ func (sc *SharedConn) WriteBuff(buffer []byte) (err error) {
|
|
type StatsConn struct {
|
|
type StatsConn struct {
|
|
Conn
|
|
Conn
|
|
|
|
|
|
|
|
+ closed int64 // 1 means closed
|
|
totalRead int64
|
|
totalRead int64
|
|
totalWrite int64
|
|
totalWrite int64
|
|
statsFunc func(totalRead, totalWrite int64)
|
|
statsFunc func(totalRead, totalWrite int64)
|
|
@@ -203,9 +205,12 @@ func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
|
|
}
|
|
}
|
|
|
|
|
|
func (statsConn *StatsConn) Close() (err error) {
|
|
func (statsConn *StatsConn) Close() (err error) {
|
|
- err = statsConn.Conn.Close()
|
|
|
|
- if statsConn.statsFunc != nil {
|
|
|
|
- statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite)
|
|
|
|
|
|
+ old := atomic.SwapInt64(&statsConn.closed, 1)
|
|
|
|
+ if old != 1 {
|
|
|
|
+ err = statsConn.Conn.Close()
|
|
|
|
+ if statsConn.statsFunc != nil {
|
|
|
|
+ statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite)
|
|
|
|
+ }
|
|
}
|
|
}
|
|
return
|
|
return
|
|
}
|
|
}
|