fec.go

package kcp

import (
	"encoding/binary"
	"sync/atomic"

	"github.com/klauspost/reedsolomon"
)

const (
	fecHeaderSize      = 6
	fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size
	typeData           = 0xf1
	typeParity         = 0xf2
)

// fecPacket is a decoded FEC packet
type fecPacket []byte

func (bts fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(bts) }
func (bts fecPacket) flag() uint16  { return binary.LittleEndian.Uint16(bts[4:]) }
func (bts fecPacket) data() []byte  { return bts[6:] }
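
// For illustration, a hand-built packet showing how the accessors above map
// onto the wire layout | SEQID(4B) | TYPE(2B) | DATA(...) | (little-endian;
// the values below are hypothetical):
//
//	raw := []byte{
//		0x2a, 0x00, 0x00, 0x00, // seqid = 42
//		0xf1, 0x00, // flag = typeData
//		0x05, 0x00, 'h', 'i', '!', // data(): 2B size field + payload
//	}
//	p := fecPacket(raw)
//	// p.seqid() == 42, p.flag() == typeData, len(p.data()) == 5
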
// fecDecoder for decoding incoming packets
type fecDecoder struct {
	rxlimit      int // queue size limit
	dataShards   int
	parityShards int
	shardSize    int
	rx           []fecPacket // ordered receive queue

	// caches
	decodeCache [][]byte
	flagCache   []bool

	// zeros
	zeros []byte

	// RS decoder
	codec reedsolomon.Encoder
}

func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
	if dataShards <= 0 || parityShards <= 0 {
		return nil
	}
	if rxlimit < dataShards+parityShards {
		return nil
	}

	dec := new(fecDecoder)
	dec.rxlimit = rxlimit
	dec.dataShards = dataShards
	dec.parityShards = parityShards
	dec.shardSize = dataShards + parityShards
	codec, err := reedsolomon.New(dataShards, parityShards)
	if err != nil {
		return nil
	}
	dec.codec = codec
	dec.decodeCache = make([][]byte, dec.shardSize)
	dec.flagCache = make([]bool, dec.shardSize)
	dec.zeros = make([]byte, mtuLimit)
	return dec
}
// decode a fec packet
func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) {
	// insertion
	n := len(dec.rx) - 1
	insertIdx := 0
	for i := n; i >= 0; i-- {
		if in.seqid() == dec.rx[i].seqid() { // de-duplicate
			return nil
		} else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion
			insertIdx = i + 1
			break
		}
	}

	// make a copy
	pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)])
	copy(pkt, in)

	// insert into ordered rx queue
	if insertIdx == n+1 {
		dec.rx = append(dec.rx, pkt)
	} else {
		dec.rx = append(dec.rx, fecPacket{})
		copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
		dec.rx[insertIdx] = pkt
	}

	// shard range for current packet
	shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize)
	shardEnd := shardBegin + uint32(dec.shardSize) - 1

	// max search range in ordered queue for current shard
	searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize))
	if searchBegin < 0 {
		searchBegin = 0
	}
	searchEnd := searchBegin + dec.shardSize - 1
	if searchEnd >= len(dec.rx) {
		searchEnd = len(dec.rx) - 1
	}
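
	// Worked example (illustrative numbers): with dataShards=10 and
	// parityShards=3 (shardSize=13), a packet with seqid 27 belongs to the
	// group covering seqids 26..38, so shardBegin=26 and shardEnd=38;
	// searchBegin/searchEnd then clamp the window of dec.rx slots that
	// could hold members of that group.
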
	// re-construct datashards
	if searchEnd-searchBegin+1 >= dec.dataShards {
		var numshard, numDataShard, first, maxlen int

		// zero caches
		shards := dec.decodeCache
		shardsflag := dec.flagCache
		for k := range dec.decodeCache {
			shards[k] = nil
			shardsflag[k] = false
		}

		// shard assembly
		for i := searchBegin; i <= searchEnd; i++ {
			seqid := dec.rx[i].seqid()
			if _itimediff(seqid, shardEnd) > 0 {
				break
			} else if _itimediff(seqid, shardBegin) >= 0 {
				shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data()
				shardsflag[seqid%uint32(dec.shardSize)] = true
				numshard++
				if dec.rx[i].flag() == typeData {
					numDataShard++
				}
				if numshard == 1 {
					first = i
				}
				if len(dec.rx[i].data()) > maxlen {
					maxlen = len(dec.rx[i].data())
				}
			}
		}

		if numDataShard == dec.dataShards {
			// case 1: no loss on data shards
			dec.rx = dec.freeRange(first, numshard, dec.rx)
		} else if numshard >= dec.dataShards {
			// case 2: loss on data shards, but it's recoverable from parity shards
			for k := range shards {
				if shards[k] != nil {
					dlen := len(shards[k])
					shards[k] = shards[k][:maxlen]
					copy(shards[k][dlen:], dec.zeros)
				} else {
					shards[k] = xmitBuf.Get().([]byte)[:0]
				}
			}
			if err := dec.codec.ReconstructData(shards); err == nil {
				for k := range shards[:dec.dataShards] {
					if !shardsflag[k] {
						// recovered data should be recycled
						recovered = append(recovered, shards[k])
					}
				}
			}
			dec.rx = dec.freeRange(first, numshard, dec.rx)
		}
	}

	// keep rxlimit
	if len(dec.rx) > dec.rxlimit {
		if dec.rx[0].flag() == typeData { // track the unrecoverable data
			atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
		}
		dec.rx = dec.freeRange(0, 1, dec.rx)
	}
	return
}
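
// Decoder-side usage, as a minimal sketch: the receive loop and 'rawPacket'
// are hypothetical; only newFECDecoder/decode above are part of this file.
//
//	dec := newFECDecoder(2048, 10, 3)
//	// for every incoming datagram whose FEC header starts at offset 0:
//	if recovered := dec.decode(fecPacket(rawPacket)); recovered != nil {
//		for _, r := range recovered {
//			// each recovered slice starts with the 2B size field written
//			// by the encoder; the caller parses it and is expected to
//			// return the buffer to xmitBuf afterwards.
//			_ = r
//		}
//	}
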
// free a range of fecPacket
func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket {
	for i := first; i < first+n; i++ { // recycle buffer
		xmitBuf.Put([]byte(q[i]))
	}
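	// fast path: freeing a small prefix only advances the slice start (no
	// copy); the compaction below handles the general case.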
	if first == 0 && n < cap(q)/2 {
		return q[n:]
	}
	copy(q[first:], q[first+n:])
	return q[:len(q)-n]
}
type (
	// fecEncoder for encoding outgoing packets
	fecEncoder struct {
		dataShards   int
		parityShards int
		shardSize    int
		paws         uint32 // Protect Against Wrapped Sequence numbers
		next         uint32 // next seqid

		shardCount int // count the number of datashards collected
		maxSize    int // track maximum data length in datashard

		headerOffset  int // FEC header offset
		payloadOffset int // FEC payload offset

		// caches
		shardCache  [][]byte
		encodeCache [][]byte

		// zeros
		zeros []byte

		// RS encoder
		codec reedsolomon.Encoder
	}
)

func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
	if dataShards <= 0 || parityShards <= 0 {
		return nil
	}
	enc := new(fecEncoder)
	enc.dataShards = dataShards
	enc.parityShards = parityShards
	enc.shardSize = dataShards + parityShards
	enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize)
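	// paws is the largest multiple of shardSize representable in 32 bits,
	// so a shard group never straddles the seqid wrap-around (for example,
	// shardSize=13 gives paws=4294967287 and seqids restart at 0 right
	// after a complete group).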
	enc.headerOffset = offset
	enc.payloadOffset = enc.headerOffset + fecHeaderSize

	codec, err := reedsolomon.New(dataShards, parityShards)
	if err != nil {
		return nil
	}
	enc.codec = codec

	// caches
	enc.encodeCache = make([][]byte, enc.shardSize)
	enc.shardCache = make([][]byte, enc.shardSize)
	for k := range enc.shardCache {
		enc.shardCache[k] = make([]byte, mtuLimit)
	}
	enc.zeros = make([]byte, mtuLimit)
	return enc
}
// encode encodes a packet and outputs the parity shards once a full group of
// data shards has been collected.
// Note: the contents of 'ps' will be overwritten by subsequent calls.
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
	// The header format:
	// | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) |
	// |<-headerOffset |<-payloadOffset
	enc.markData(b[enc.headerOffset:])
	binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))

	// copy data from payloadOffset to fec shard cache
	sz := len(b)
	enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
	copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:])
	enc.shardCount++

	// track max datashard length
	if sz > enc.maxSize {
		enc.maxSize = sz
	}

	// Generation of Reed-Solomon Erasure Code
	if enc.shardCount == enc.dataShards {
		// fill '0' into the tail of each datashard
		for i := 0; i < enc.dataShards; i++ {
			shard := enc.shardCache[i]
			slen := len(shard)
			copy(shard[slen:enc.maxSize], enc.zeros)
		}

		// construct equal-sized slices with the header stripped
		cache := enc.encodeCache
		for k := range cache {
			cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
		}

		// encoding
		if err := enc.codec.Encode(cache); err == nil {
			ps = enc.shardCache[enc.dataShards:]
			for k := range ps {
				enc.markParity(ps[k][enc.headerOffset:])
				ps[k] = ps[k][:enc.maxSize]
			}
		}

		// reset counters
		enc.shardCount = 0
		enc.maxSize = 0
	}
	return
}
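
// Sender-side usage, as a minimal sketch: 'sendOnWire' and the outgoing loop
// are hypothetical; only newFECEncoder/encode above are part of this file.
// Each packet is expected to reserve fecHeaderSizePlus2 bytes in front of its
// payload when headerOffset is 0.
//
//	enc := newFECEncoder(10, 3, 0)
//	for _, pkt := range outgoing {
//		ps := enc.encode(pkt) // stamps seqid/type/size into pkt in place
//		sendOnWire(pkt)       // the data shard itself
//		for _, parity := range ps {
//			sendOnWire(parity) // non-empty once per full data-shard group
//		}
//	}
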
func (enc *fecEncoder) markData(data []byte) {
	binary.LittleEndian.PutUint32(data, enc.next)
	binary.LittleEndian.PutUint16(data[4:], typeData)
	enc.next++
}

func (enc *fecEncoder) markParity(data []byte) {
	binary.LittleEndian.PutUint32(data, enc.next)
	binary.LittleEndian.PutUint16(data[4:], typeParity)
	// sequence wrap will only happen at parity shard
	enc.next = (enc.next + 1) % enc.paws
}