mux.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. package mux
  2. import (
  3. "errors"
  4. "io"
  5. "math"
  6. "net"
  7. "sync/atomic"
  8. "time"
  9. "github.com/astaxie/beego/logs"
  10. "github.com/cnlh/nps/lib/common"
  11. )
// Mux multiplexes many logical connections (*conn) over one underlying
// net.Conn. It also serves as a listener for connections opened by the remote
// side: Accept, Addr and Close are all defined on *Mux itself. The embedded
// net.Listener is never assigned in NewMux and presumably exists only so that
// *Mux satisfies the net.Listener interface.
type Mux struct {
	// latency holds the latency estimate produced by counter.Latency (in
	// seconds), stored as math.Float64bits and accessed only via sync/atomic.
	// NOTE(review): keep this the first field — presumably so it stays 64-bit
	// aligned for atomic access on 32-bit platforms.
	latency uint64 // we store latency in bits, but it's float64
	net.Listener
	conn          net.Conn        // underlying transport that carries every frame
	connMap       *connMap        // logical connections keyed by connection id
	newConnCh     chan *conn      // remote-opened connections handed to Accept; closed by Close
	id            int32           // last allocated connection id (see getId)
	closeChan     chan struct{}   // receives one value from Close; consumed by pingReturn
	IsClose       bool            // set by Close and polled by the goroutines (not synchronized)
	pingOk        int             // pings sent since readSession last received a frame (kcp check in ping)
	counter       *latencyCounter // ring buffer of latency samples
	bw            *bandwidth      // inbound bandwidth estimator fed by readSession
	pingCh        chan []byte     // ping-reply payloads forwarded from readSession to pingReturn
	pingCheckTime uint32          // ping ticks since the last reply, accessed atomically; ping closes the mux at 60
	connType      string          // transport type; "kcp" enables an extra liveness check in ping
	writeQueue    PriorityQueue   // outbound packs, drained by packBuf
	newConnQueue  ConnQueue       // remote MUX_NEW_CONN requests waiting for the acceptor goroutine
}
  30. func NewMux(c net.Conn, connType string) *Mux {
  31. //c.(*net.TCPConn).SetReadBuffer(0)
  32. //c.(*net.TCPConn).SetWriteBuffer(0)
  33. m := &Mux{
  34. conn: c,
  35. connMap: NewConnMap(),
  36. id: 0,
  37. closeChan: make(chan struct{}, 1),
  38. newConnCh: make(chan *conn),
  39. bw: new(bandwidth),
  40. IsClose: false,
  41. connType: connType,
  42. pingCh: make(chan []byte),
  43. counter: newLatencyCounter(),
  44. }
  45. m.writeQueue.New()
  46. m.newConnQueue.New()
  47. //read session by flag
  48. m.readSession()
  49. //ping
  50. m.ping()
  51. m.pingReturn()
  52. m.writeSession()
  53. return m
  54. }
  55. func (s *Mux) NewConn() (*conn, error) {
  56. if s.IsClose {
  57. return nil, errors.New("the mux has closed")
  58. }
  59. conn := NewConn(s.getId(), s, "nps ")
  60. //it must be set before send
  61. s.connMap.Set(conn.connId, conn)
  62. s.sendInfo(common.MUX_NEW_CONN, conn.connId, nil)
  63. //set a timer timeout 30 second
  64. timer := time.NewTimer(time.Minute * 2)
  65. defer timer.Stop()
  66. select {
  67. case <-conn.connStatusOkCh:
  68. return conn, nil
  69. case <-conn.connStatusFailCh:
  70. case <-timer.C:
  71. }
  72. return nil, errors.New("create connection fail,the server refused the connection")
  73. }
  74. func (s *Mux) Accept() (net.Conn, error) {
  75. if s.IsClose {
  76. return nil, errors.New("accpet error,the mux has closed")
  77. }
  78. conn := <-s.newConnCh
  79. if conn == nil {
  80. return nil, errors.New("accpet error,the conn has closed")
  81. }
  82. return conn, nil
  83. }
  84. func (s *Mux) Addr() net.Addr {
  85. return s.conn.LocalAddr()
  86. }
  87. func (s *Mux) sendInfo(flag uint8, id int32, data ...interface{}) {
  88. if s.IsClose {
  89. return
  90. }
  91. var err error
  92. pack := common.MuxPack.Get()
  93. err = pack.NewPac(flag, id, data...)
  94. if err != nil {
  95. common.MuxPack.Put(pack)
  96. logs.Error("mux: new pack err")
  97. s.Close()
  98. return
  99. }
  100. s.writeQueue.Push(pack)
  101. return
  102. }
  103. func (s *Mux) writeSession() {
  104. go s.packBuf()
  105. //go s.writeBuf()
  106. }
  107. func (s *Mux) packBuf() {
  108. buffer := common.BuffPool.Get()
  109. for {
  110. if s.IsClose {
  111. break
  112. }
  113. buffer.Reset()
  114. pack := s.writeQueue.Pop()
  115. //buffer := common.BuffPool.Get()
  116. err := pack.Pack(buffer)
  117. common.MuxPack.Put(pack)
  118. if err != nil {
  119. logs.Error("mux: pack err", err)
  120. common.BuffPool.Put(buffer)
  121. break
  122. }
  123. //logs.Warn(buffer.String())
  124. //s.bufQueue.Push(buffer)
  125. l := buffer.Len()
  126. n, err := buffer.WriteTo(s.conn)
  127. //common.BuffPool.Put(buffer)
  128. if err != nil || int(n) != l {
  129. logs.Error("mux: close from write session fail ", err, n, l)
  130. s.Close()
  131. break
  132. }
  133. }
  134. }
  135. //func (s *Mux) writeBuf() {
  136. // for {
  137. // if s.IsClose {
  138. // break
  139. // }
  140. // buffer, err := s.bufQueue.Pop()
  141. // if err != nil {
  142. // break
  143. // }
  144. // l := buffer.Len()
  145. // n, err := buffer.WriteTo(s.conn)
  146. // common.BuffPool.Put(buffer)
  147. // if err != nil || int(n) != l {
  148. // logs.Warn("close from write session fail ", err, n, l)
  149. // s.Close()
  150. // break
  151. // }
  152. // }
  153. //}
  154. func (s *Mux) ping() {
  155. go func() {
  156. now, _ := time.Now().UTC().MarshalText()
  157. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  158. // send the ping flag and get the latency first
  159. ticker := time.NewTicker(time.Second * 5)
  160. for {
  161. if s.IsClose {
  162. ticker.Stop()
  163. break
  164. }
  165. select {
  166. case <-ticker.C:
  167. }
  168. if atomic.LoadUint32(&s.pingCheckTime) >= 60 {
  169. logs.Error("mux: ping time out")
  170. s.Close()
  171. // more than 5 minutes not receive the ping return package,
  172. // mux conn is damaged, maybe a packet drop, close it
  173. break
  174. }
  175. now, _ := time.Now().UTC().MarshalText()
  176. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  177. atomic.AddUint32(&s.pingCheckTime, 1)
  178. if s.pingOk > 10 && s.connType == "kcp" {
  179. logs.Error("mux: kcp ping err")
  180. s.Close()
  181. break
  182. }
  183. s.pingOk++
  184. }
  185. }()
  186. }
  187. func (s *Mux) pingReturn() {
  188. go func() {
  189. var now time.Time
  190. var data []byte
  191. for {
  192. if s.IsClose {
  193. break
  194. }
  195. select {
  196. case data = <-s.pingCh:
  197. atomic.StoreUint32(&s.pingCheckTime, 0)
  198. case <-s.closeChan:
  199. break
  200. }
  201. _ = now.UnmarshalText(data)
  202. latency := time.Now().UTC().Sub(now).Seconds() / 2
  203. if latency > 0 {
  204. atomic.StoreUint64(&s.latency, math.Float64bits(s.counter.Latency(latency)))
  205. // convert float64 to bits, store it atomic
  206. }
  207. //logs.Warn("latency", math.Float64frombits(atomic.LoadUint64(&s.latency)))
  208. common.WindowBuff.Put(data)
  209. }
  210. }()
  211. }
  212. func (s *Mux) readSession() {
  213. go func() {
  214. var connection *conn
  215. for {
  216. connection = s.newConnQueue.Pop()
  217. s.connMap.Set(connection.connId, connection) //it has been set before send ok
  218. s.newConnCh <- connection
  219. s.sendInfo(common.MUX_NEW_CONN_OK, connection.connId, nil)
  220. }
  221. }()
  222. go func() {
  223. pack := common.MuxPack.Get()
  224. var l uint16
  225. var err error
  226. for {
  227. if s.IsClose {
  228. break
  229. }
  230. pack = common.MuxPack.Get()
  231. s.bw.StartRead()
  232. if l, err = pack.UnPack(s.conn); err != nil {
  233. logs.Error("mux: read session unpack from connection err")
  234. s.Close()
  235. break
  236. }
  237. s.bw.SetCopySize(l)
  238. s.pingOk = 0
  239. switch pack.Flag {
  240. case common.MUX_NEW_CONN: //new connection
  241. connection := NewConn(pack.Id, s)
  242. s.newConnQueue.Push(connection)
  243. continue
  244. case common.MUX_PING_FLAG: //ping
  245. s.sendInfo(common.MUX_PING_RETURN, common.MUX_PING, pack.Content)
  246. common.WindowBuff.Put(pack.Content)
  247. continue
  248. case common.MUX_PING_RETURN:
  249. //go func(content []byte) {
  250. s.pingCh <- pack.Content
  251. //}(pack.Content)
  252. continue
  253. }
  254. if connection, ok := s.connMap.Get(pack.Id); ok && !connection.isClose {
  255. switch pack.Flag {
  256. case common.MUX_NEW_MSG, common.MUX_NEW_MSG_PART: //new msg from remote connection
  257. err = s.newMsg(connection, pack)
  258. if err != nil {
  259. logs.Error("mux: read session connection new msg err")
  260. connection.Close()
  261. }
  262. continue
  263. case common.MUX_NEW_CONN_OK: //connection ok
  264. connection.connStatusOkCh <- struct{}{}
  265. continue
  266. case common.MUX_NEW_CONN_Fail:
  267. connection.connStatusFailCh <- struct{}{}
  268. continue
  269. case common.MUX_MSG_SEND_OK:
  270. if connection.isClose {
  271. continue
  272. }
  273. connection.sendWindow.SetSize(pack.Window, pack.ReadLength)
  274. continue
  275. case common.MUX_CONN_CLOSE: //close the connection
  276. s.connMap.Delete(pack.Id)
  277. //go func(connection *conn) {
  278. connection.closeFlag = true
  279. connection.receiveWindow.Stop() // close signal to receive window
  280. //}(connection)
  281. continue
  282. }
  283. } else if pack.Flag == common.MUX_CONN_CLOSE {
  284. continue
  285. }
  286. common.MuxPack.Put(pack)
  287. }
  288. common.MuxPack.Put(pack)
  289. s.Close()
  290. }()
  291. }
  292. func (s *Mux) newMsg(connection *conn, pack *common.MuxPackager) (err error) {
  293. if connection.isClose {
  294. err = io.ErrClosedPipe
  295. return
  296. }
  297. //logs.Warn("read session receive new msg", pack.Length)
  298. //go func(connection *conn, pack *common.MuxPackager) { // do not block read session
  299. //insert into queue
  300. if pack.Flag == common.MUX_NEW_MSG_PART {
  301. err = connection.receiveWindow.Write(pack.Content, pack.Length, true, pack.Id)
  302. }
  303. if pack.Flag == common.MUX_NEW_MSG {
  304. err = connection.receiveWindow.Write(pack.Content, pack.Length, false, pack.Id)
  305. }
  306. //logs.Warn("read session write success", pack.Length)
  307. return
  308. }
// Close shuts the mux down, in this order:
//   - marks it closed (IsClose), which the background loops poll;
//   - calls connMap.Close (presumably closing the logical connections —
//     confirm in connMap);
//   - signals pingReturn through closeChan (buffered, so this send does not
//     block);
//   - closes newConnCh, which wakes a blocked Accept;
//   - closes the underlying connection and returns its error.
//
// Calling Close on an already-closed mux returns an error and changes nothing.
//
// NOTE(review): IsClose is a plain bool checked and set without
// synchronization. Two goroutines calling Close at the same time can both
// pass the check, and the second close(newConnCh) would then panic. An atomic
// flag or sync.Once on the struct would be needed to make this safe.
func (s *Mux) Close() error {
	logs.Warn("close mux")
	if s.IsClose {
		return errors.New("the mux has closed")
	}
	s.IsClose = true
	s.connMap.Close()
	//s.bufQueue.Stop()
	s.closeChan <- struct{}{}
	close(s.newConnCh)
	return s.conn.Close()
}
  321. //get new connId as unique flag
  322. func (s *Mux) getId() (id int32) {
  323. //Avoid going beyond the scope
  324. if (math.MaxInt32 - s.id) < 10000 {
  325. atomic.StoreInt32(&s.id, 0)
  326. }
  327. id = atomic.AddInt32(&s.id, 1)
  328. if _, ok := s.connMap.Get(id); ok {
  329. return s.getId()
  330. }
  331. return
  332. }
// bandwidth estimates the inbound throughput of the mux. readSession calls
// StartRead before reading each frame and SetCopySize with the frame length
// after it. Once at least 16384 units have been counted, StartRead closes the
// measurement window and calcBandWidth derives a per-second rate, which Get
// reports.
type bandwidth struct {
	readStart     time.Time // start of the current measurement window
	lastReadStart time.Time // start of the previous window, used by calcBandWidth
	bufLength     uint16    // amount read in the current window (sum of frame lengths)
	readBandwidth float64   // last computed rate (bufLength per second)
}
  339. func (Self *bandwidth) StartRead() {
  340. if Self.readStart.IsZero() {
  341. Self.readStart = time.Now()
  342. }
  343. if Self.bufLength >= 16384 {
  344. Self.lastReadStart, Self.readStart = Self.readStart, time.Now()
  345. Self.calcBandWidth()
  346. }
  347. }
  348. func (Self *bandwidth) SetCopySize(n uint16) {
  349. Self.bufLength += n
  350. }
  351. func (Self *bandwidth) calcBandWidth() {
  352. t := Self.readStart.Sub(Self.lastReadStart)
  353. Self.readBandwidth = float64(Self.bufLength) / t.Seconds()
  354. Self.bufLength = 0
  355. }
  356. func (Self *bandwidth) Get() (bw float64) {
  357. // The zero value, 0 for numeric types
  358. if Self.readBandwidth <= 0 {
  359. Self.readBandwidth = 100
  360. }
  361. return Self.readBandwidth
  362. }
  363. const counterBits = 4
  364. const counterMask = 1<<counterBits - 1
  365. func newLatencyCounter() *latencyCounter {
  366. return &latencyCounter{
  367. buf: make([]float64, 1<<counterBits, 1<<counterBits),
  368. headMin: 0,
  369. }
  370. }
// latencyCounter keeps the most recent 1<<counterBits latency samples in a
// ring buffer and tracks which slot holds the minimal sample. Latency
// combines that minimum with the share of samples close to it (see
// countSuccess).
type latencyCounter struct {
	// buf is a fixed-length ring buffer. Once it is full, each new value
	// replaces the oldest one. Zero marks an empty slot.
	buf []float64
	// headMin packs two indices into one byte (see pack/unpack). The high
	// counterBits bits are the head, i.e. the slot the next sample will
	// replace. The low counterBits bits are the min indicator, i.e. the slot
	// whose value is minimal in the list.
	headMin uint8
}
  378. func (Self *latencyCounter) unpack(idxs uint8) (head, min uint8) {
  379. head = uint8((idxs >> counterBits) & counterMask)
  380. // we set head is 4 bits
  381. min = uint8(idxs & counterMask)
  382. return
  383. }
  384. func (Self *latencyCounter) pack(head, min uint8) uint8 {
  385. return uint8(head<<counterBits) |
  386. uint8(min&counterMask)
  387. }
  388. func (Self *latencyCounter) add(value float64) {
  389. head, min := Self.unpack(Self.headMin)
  390. Self.buf[head] = value
  391. if head == min {
  392. min = Self.minimal()
  393. //if head equals min, means the min slot already be replaced,
  394. // so we need to find another minimal value in the list,
  395. // and change the min indicator
  396. }
  397. if Self.buf[min] > value {
  398. min = head
  399. }
  400. head++
  401. Self.headMin = Self.pack(head, min)
  402. }
  403. func (Self *latencyCounter) minimal() (min uint8) {
  404. var val float64
  405. var i uint8
  406. for i = 0; i < counterMask; i++ {
  407. if Self.buf[i] > 0 {
  408. if val > Self.buf[i] {
  409. val = Self.buf[i]
  410. min = i
  411. }
  412. }
  413. }
  414. return
  415. }
  416. func (Self *latencyCounter) Latency(value float64) (latency float64) {
  417. Self.add(value)
  418. _, min := Self.unpack(Self.headMin)
  419. latency = Self.buf[min] * Self.countSuccess()
  420. return
  421. }
  422. const lossRatio = 1.6
  423. func (Self *latencyCounter) countSuccess() (successRate float64) {
  424. var success, loss, i uint8
  425. _, min := Self.unpack(Self.headMin)
  426. for i = 0; i < counterMask; i++ {
  427. if Self.buf[i] > lossRatio*Self.buf[min] && Self.buf[i] > 0 {
  428. loss++
  429. }
  430. if Self.buf[i] <= lossRatio*Self.buf[min] && Self.buf[i] > 0 {
  431. success++
  432. }
  433. }
  434. // counting all the data in the ring buf, except zero
  435. successRate = float64(success) / float64(loss+success)
  436. return
  437. }