mux.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. package mux
  2. import (
  3. "errors"
  4. "io"
  5. "math"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/astaxie/beego/logs"
  11. "github.com/cnlh/nps/lib/common"
  12. )
// Mux multiplexes many logical connections (*conn) over a single underlying
// net.Conn. Outgoing packets are queued in writeQueue and written by one
// writer goroutine (packBuf); incoming packets are decoded and dispatched by
// one reader goroutine (readSession). Mux defines Accept, Addr and Close
// itself, so it can be used where a net.Listener is expected.
type Mux struct {
	// Embedded but never assigned by NewMux, so it stays nil; the listener
	// methods callers use are the ones defined on *Mux.
	net.Listener
	// Underlying transport: read by readSession, written by packBuf.
	conn net.Conn
	// Logical connections keyed by connection id.
	connMap *connMap
	// Connections opened by the peer, handed to Accept; closed by Close.
	newConnCh chan *conn
	// Last connection id handed out by getId.
	id int32
	// Signalled once by Close; received by the pingReturn goroutine.
	closeChan chan struct{}
	// Set by Close and polled by the background loops.
	// NOTE(review): read and written from several goroutines without
	// synchronization — confirm whether that is acceptable.
	IsClose bool
	// Ping ticks since readSession last received any packet; ping closes a
	// "kcp" mux once this exceeds 10.
	pingOk int
	// Half of the last measured ping round trip, in seconds (only values in
	// (0, 0.5) are stored).
	latency float64
	// Inbound bandwidth estimator fed by readSession.
	bw *bandwidth
	// Ping reply payloads passed from readSession to pingReturn.
	pingCh chan []byte
	// Ping watchdog: re-armed for 15 s by ping; if it expires, pingReturn
	// closes the mux.
	pingTimer *time.Timer
	// Transport kind passed to NewMux; "kcp" enables the idle close in ping.
	connType string
	// Outgoing packets: pushed by sendInfo, popped by packBuf.
	writeQueue PriorityQueue
	// bufQueue BytesQueue (disabled; see the commented-out writeBuf)
	// NOTE(review): not locked anywhere in this file; presumably used by
	// other files in the package — confirm before removing.
	sync.Mutex
}
  31. func NewMux(c net.Conn, connType string) *Mux {
  32. //c.(*net.TCPConn).SetReadBuffer(0)
  33. //c.(*net.TCPConn).SetWriteBuffer(0)
  34. m := &Mux{
  35. conn: c,
  36. connMap: NewConnMap(),
  37. id: 0,
  38. closeChan: make(chan struct{}, 1),
  39. newConnCh: make(chan *conn, 10),
  40. bw: new(bandwidth),
  41. IsClose: false,
  42. connType: connType,
  43. pingCh: make(chan []byte),
  44. pingTimer: time.NewTimer(15 * time.Second),
  45. }
  46. m.writeQueue.New()
  47. //m.bufQueue.New()
  48. //read session by flag
  49. m.readSession()
  50. //ping
  51. m.ping()
  52. m.pingReturn()
  53. m.writeSession()
  54. return m
  55. }
  56. func (s *Mux) NewConn() (*conn, error) {
  57. if s.IsClose {
  58. return nil, errors.New("the mux has closed")
  59. }
  60. conn := NewConn(s.getId(), s, "nps ")
  61. //it must be set before send
  62. s.connMap.Set(conn.connId, conn)
  63. s.sendInfo(common.MUX_NEW_CONN, conn.connId, nil)
  64. //set a timer timeout 30 second
  65. timer := time.NewTimer(time.Minute * 2)
  66. defer timer.Stop()
  67. select {
  68. case <-conn.connStatusOkCh:
  69. return conn, nil
  70. case <-conn.connStatusFailCh:
  71. case <-timer.C:
  72. }
  73. return nil, errors.New("create connection fail,the server refused the connection")
  74. }
  75. func (s *Mux) Accept() (net.Conn, error) {
  76. if s.IsClose {
  77. return nil, errors.New("accpet error,the mux has closed")
  78. }
  79. conn := <-s.newConnCh
  80. if conn == nil {
  81. return nil, errors.New("accpet error,the conn has closed")
  82. }
  83. return conn, nil
  84. }
  85. func (s *Mux) Addr() net.Addr {
  86. return s.conn.LocalAddr()
  87. }
  88. func (s *Mux) sendInfo(flag uint8, id int32, data ...interface{}) {
  89. if s.IsClose {
  90. return
  91. }
  92. var err error
  93. pack := common.MuxPack.Get()
  94. err = pack.NewPac(flag, id, data...)
  95. if err != nil {
  96. common.MuxPack.Put(pack)
  97. return
  98. }
  99. s.writeQueue.Push(pack)
  100. return
  101. }
  102. func (s *Mux) writeSession() {
  103. go s.packBuf()
  104. //go s.writeBuf()
  105. }
  106. func (s *Mux) packBuf() {
  107. buffer := common.BuffPool.Get()
  108. for {
  109. if s.IsClose {
  110. break
  111. }
  112. buffer.Reset()
  113. pack := s.writeQueue.Pop()
  114. //buffer := common.BuffPool.Get()
  115. err := pack.Pack(buffer)
  116. common.MuxPack.Put(pack)
  117. if err != nil {
  118. logs.Warn("pack err", err)
  119. common.BuffPool.Put(buffer)
  120. break
  121. }
  122. //logs.Warn(buffer.String())
  123. //s.bufQueue.Push(buffer)
  124. l := buffer.Len()
  125. n, err := buffer.WriteTo(s.conn)
  126. //common.BuffPool.Put(buffer)
  127. if err != nil || int(n) != l {
  128. logs.Warn("close from write session fail ", err, n, l)
  129. s.Close()
  130. break
  131. }
  132. }
  133. }
  134. //func (s *Mux) writeBuf() {
  135. // for {
  136. // if s.IsClose {
  137. // break
  138. // }
  139. // buffer, err := s.bufQueue.Pop()
  140. // if err != nil {
  141. // break
  142. // }
  143. // l := buffer.Len()
  144. // n, err := buffer.WriteTo(s.conn)
  145. // common.BuffPool.Put(buffer)
  146. // if err != nil || int(n) != l {
  147. // logs.Warn("close from write session fail ", err, n, l)
  148. // s.Close()
  149. // break
  150. // }
  151. // }
  152. //}
  153. func (s *Mux) ping() {
  154. go func() {
  155. now, _ := time.Now().UTC().MarshalText()
  156. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  157. // send the ping flag and get the latency first
  158. ticker := time.NewTicker(time.Second * 5)
  159. for {
  160. if s.IsClose {
  161. ticker.Stop()
  162. if !s.pingTimer.Stop() {
  163. <-s.pingTimer.C
  164. }
  165. break
  166. }
  167. select {
  168. case <-ticker.C:
  169. }
  170. now, _ := time.Now().UTC().MarshalText()
  171. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  172. if !s.pingTimer.Stop() {
  173. <-s.pingTimer.C
  174. }
  175. s.pingTimer.Reset(15 * time.Second)
  176. if s.pingOk > 10 && s.connType == "kcp" {
  177. s.Close()
  178. break
  179. }
  180. s.pingOk++
  181. }
  182. }()
  183. }
  184. func (s *Mux) pingReturn() {
  185. go func() {
  186. var now time.Time
  187. var data []byte
  188. for {
  189. if s.IsClose {
  190. break
  191. }
  192. select {
  193. case data = <-s.pingCh:
  194. case <-s.closeChan:
  195. break
  196. case <-s.pingTimer.C:
  197. logs.Error("mux: ping time out")
  198. s.Close()
  199. break
  200. }
  201. _ = now.UnmarshalText(data)
  202. latency := time.Now().UTC().Sub(now).Seconds() / 2
  203. if latency < 0.5 && latency > 0 {
  204. s.latency = latency
  205. }
  206. //logs.Warn("latency", s.latency)
  207. common.WindowBuff.Put(data)
  208. }
  209. }()
  210. }
  211. func (s *Mux) readSession() {
  212. go func() {
  213. pack := common.MuxPack.Get()
  214. var l uint16
  215. var err error
  216. for {
  217. if s.IsClose {
  218. break
  219. }
  220. pack = common.MuxPack.Get()
  221. s.bw.StartRead()
  222. if l, err = pack.UnPack(s.conn); err != nil {
  223. break
  224. }
  225. s.bw.SetCopySize(l)
  226. s.pingOk = 0
  227. switch pack.Flag {
  228. case common.MUX_NEW_CONN: //new connection
  229. connection := NewConn(pack.Id, s, "npc ")
  230. s.connMap.Set(pack.Id, connection) //it has been set before send ok
  231. s.newConnCh <- connection
  232. s.sendInfo(common.MUX_NEW_CONN_OK, connection.connId, nil)
  233. continue
  234. case common.MUX_PING_FLAG: //ping
  235. s.sendInfo(common.MUX_PING_RETURN, common.MUX_PING, pack.Content)
  236. common.WindowBuff.Put(pack.Content)
  237. continue
  238. case common.MUX_PING_RETURN:
  239. s.pingCh <- pack.Content
  240. continue
  241. }
  242. if connection, ok := s.connMap.Get(pack.Id); ok && !connection.isClose {
  243. switch pack.Flag {
  244. case common.MUX_NEW_MSG, common.MUX_NEW_MSG_PART: //new msg from remote connection
  245. err = s.newMsg(connection, pack)
  246. if err != nil {
  247. connection.Close()
  248. }
  249. continue
  250. case common.MUX_NEW_CONN_OK: //connection ok
  251. connection.connStatusOkCh <- struct{}{}
  252. continue
  253. case common.MUX_NEW_CONN_Fail:
  254. connection.connStatusFailCh <- struct{}{}
  255. continue
  256. case common.MUX_MSG_SEND_OK:
  257. if connection.isClose {
  258. continue
  259. }
  260. connection.sendWindow.SetSize(pack.Window, pack.ReadLength)
  261. continue
  262. case common.MUX_CONN_CLOSE: //close the connection
  263. s.connMap.Delete(pack.Id)
  264. connection.closeFlag = true
  265. connection.receiveWindow.Stop() // close signal to receive window
  266. continue
  267. }
  268. } else if pack.Flag == common.MUX_CONN_CLOSE {
  269. continue
  270. }
  271. common.MuxPack.Put(pack)
  272. }
  273. common.MuxPack.Put(pack)
  274. s.Close()
  275. }()
  276. }
  277. func (s *Mux) newMsg(connection *conn, pack *common.MuxPackager) (err error) {
  278. if connection.isClose {
  279. err = io.ErrClosedPipe
  280. return
  281. }
  282. //logs.Warn("read session receive new msg", pack.Length)
  283. //go func(connection *conn, pack *common.MuxPackager) { // do not block read session
  284. //insert into queue
  285. if pack.Flag == common.MUX_NEW_MSG_PART {
  286. err = connection.receiveWindow.Write(pack.Content, pack.Length, true, pack.Id)
  287. }
  288. if pack.Flag == common.MUX_NEW_MSG {
  289. err = connection.receiveWindow.Write(pack.Content, pack.Length, false, pack.Id)
  290. }
  291. //logs.Warn("read session write success", pack.Length)
  292. return
  293. }
  294. func (s *Mux) Close() error {
  295. logs.Warn("close mux")
  296. if s.IsClose {
  297. return errors.New("the mux has closed")
  298. }
  299. s.IsClose = true
  300. s.connMap.Close()
  301. //s.bufQueue.Stop()
  302. s.closeChan <- struct{}{}
  303. close(s.newConnCh)
  304. return s.conn.Close()
  305. }
  306. //get new connId as unique flag
  307. func (s *Mux) getId() (id int32) {
  308. //Avoid going beyond the scope
  309. if (math.MaxInt32 - s.id) < 10000 {
  310. atomic.StoreInt32(&s.id, 0)
  311. }
  312. id = atomic.AddInt32(&s.id, 1)
  313. if _, ok := s.connMap.Get(id); ok {
  314. return s.getId()
  315. }
  316. return
  317. }
  318. type bandwidth struct {
  319. readStart time.Time
  320. lastReadStart time.Time
  321. bufLength uint16
  322. readBandwidth float64
  323. }
  324. func (Self *bandwidth) StartRead() {
  325. if Self.readStart.IsZero() {
  326. Self.readStart = time.Now()
  327. }
  328. if Self.bufLength >= 16384 {
  329. Self.lastReadStart, Self.readStart = Self.readStart, time.Now()
  330. Self.calcBandWidth()
  331. }
  332. }
  333. func (Self *bandwidth) SetCopySize(n uint16) {
  334. Self.bufLength += n
  335. }
  336. func (Self *bandwidth) calcBandWidth() {
  337. t := Self.readStart.Sub(Self.lastReadStart)
  338. Self.readBandwidth = float64(Self.bufLength) / t.Seconds()
  339. Self.bufLength = 0
  340. }
  341. func (Self *bandwidth) Get() (bw float64) {
  342. // The zero value, 0 for numeric types
  343. if Self.readBandwidth <= 0 {
  344. Self.readBandwidth = 100
  345. }
  346. return Self.readBandwidth
  347. }