mux.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. package mux
  2. import (
  3. "bytes"
  4. "errors"
  5. "io"
  6. "math"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/astaxie/beego/logs"
  12. "github.com/cnlh/nps/lib/common"
  13. )
  14. type Mux struct {
  15. net.Listener
  16. conn net.Conn
  17. connMap *connMap
  18. newConnCh chan *conn
  19. id int32
  20. closeChan chan struct{}
  21. IsClose bool
  22. pingOk int
  23. latency float64
  24. bw *bandwidth
  25. pingCh chan []byte
  26. pingTimer *time.Timer
  27. connType string
  28. writeQueue PriorityQueue
  29. bufCh chan *bytes.Buffer
  30. sync.Mutex
  31. }
  32. func NewMux(c net.Conn, connType string) *Mux {
  33. m := &Mux{
  34. conn: c,
  35. connMap: NewConnMap(),
  36. id: 0,
  37. closeChan: make(chan struct{}, 3),
  38. newConnCh: make(chan *conn),
  39. bw: new(bandwidth),
  40. IsClose: false,
  41. connType: connType,
  42. bufCh: make(chan *bytes.Buffer),
  43. pingCh: make(chan []byte),
  44. pingTimer: time.NewTimer(15 * time.Second),
  45. }
  46. m.writeQueue.New()
  47. //read session by flag
  48. m.readSession()
  49. //ping
  50. m.ping()
  51. m.pingReturn()
  52. m.writeSession()
  53. return m
  54. }
  55. func (s *Mux) NewConn() (*conn, error) {
  56. if s.IsClose {
  57. return nil, errors.New("the mux has closed")
  58. }
  59. conn := NewConn(s.getId(), s, "nps ")
  60. //it must be set before send
  61. s.connMap.Set(conn.connId, conn)
  62. s.sendInfo(common.MUX_NEW_CONN, conn.connId, nil)
  63. //set a timer timeout 30 second
  64. timer := time.NewTimer(time.Minute * 2)
  65. defer timer.Stop()
  66. select {
  67. case <-conn.connStatusOkCh:
  68. return conn, nil
  69. case <-conn.connStatusFailCh:
  70. case <-timer.C:
  71. }
  72. return nil, errors.New("create connection fail,the server refused the connection")
  73. }
  74. func (s *Mux) Accept() (net.Conn, error) {
  75. if s.IsClose {
  76. return nil, errors.New("accpet error,the mux has closed")
  77. }
  78. conn := <-s.newConnCh
  79. if conn == nil {
  80. return nil, errors.New("accpet error,the conn has closed")
  81. }
  82. return conn, nil
  83. }
  84. func (s *Mux) Addr() net.Addr {
  85. return s.conn.LocalAddr()
  86. }
  87. func (s *Mux) sendInfo(flag uint8, id int32, data ...interface{}) {
  88. if s.IsClose {
  89. return
  90. }
  91. var err error
  92. pack := common.MuxPack.Get()
  93. err = pack.NewPac(flag, id, data...)
  94. if err != nil {
  95. common.MuxPack.Put(pack)
  96. return
  97. }
  98. s.writeQueue.Push(pack)
  99. return
  100. }
  101. func (s *Mux) writeSession() {
  102. go s.packBuf()
  103. go s.writeBuf()
  104. }
  105. func (s *Mux) packBuf() {
  106. for {
  107. if s.IsClose {
  108. break
  109. }
  110. pack := s.writeQueue.Pop()
  111. buffer := common.BuffPool.Get()
  112. err := pack.Pack(buffer)
  113. common.MuxPack.Put(pack)
  114. if err != nil {
  115. logs.Warn("pack err", err)
  116. common.BuffPool.Put(buffer)
  117. break
  118. }
  119. //logs.Warn(buffer.String())
  120. select {
  121. case s.bufCh <- buffer:
  122. case <-s.closeChan:
  123. break
  124. }
  125. }
  126. }
  127. func (s *Mux) writeBuf() {
  128. for {
  129. if s.IsClose {
  130. break
  131. }
  132. select {
  133. case buffer := <-s.bufCh:
  134. l := buffer.Len()
  135. n, err := buffer.WriteTo(s.conn)
  136. common.BuffPool.Put(buffer)
  137. if err != nil || int(n) != l {
  138. logs.Warn("close from write session fail ", err, n, l)
  139. s.Close()
  140. break
  141. }
  142. case <-s.closeChan:
  143. break
  144. }
  145. }
  146. }
  147. func (s *Mux) ping() {
  148. go func() {
  149. now, _ := time.Now().UTC().MarshalText()
  150. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  151. // send the ping flag and get the latency first
  152. ticker := time.NewTicker(time.Second * 5)
  153. for {
  154. if s.IsClose {
  155. ticker.Stop()
  156. if !s.pingTimer.Stop() {
  157. <-s.pingTimer.C
  158. }
  159. break
  160. }
  161. select {
  162. case <-ticker.C:
  163. }
  164. //Avoid going beyond the scope
  165. if (math.MaxInt32 - s.id) < 10000 {
  166. s.id = 0
  167. }
  168. now, _ := time.Now().UTC().MarshalText()
  169. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  170. if !s.pingTimer.Stop() {
  171. <-s.pingTimer.C
  172. }
  173. s.pingTimer.Reset(15 * time.Second)
  174. if s.pingOk > 10 && s.connType == "kcp" {
  175. s.Close()
  176. break
  177. }
  178. s.pingOk++
  179. }
  180. }()
  181. }
  182. func (s *Mux) pingReturn() {
  183. go func() {
  184. var now time.Time
  185. var data []byte
  186. for {
  187. if s.IsClose {
  188. break
  189. }
  190. select {
  191. case data = <-s.pingCh:
  192. case <-s.closeChan:
  193. break
  194. case <-s.pingTimer.C:
  195. logs.Error("mux: ping time out")
  196. s.Close()
  197. break
  198. }
  199. _ = now.UnmarshalText(data)
  200. latency := time.Now().UTC().Sub(now).Seconds() / 2
  201. if latency < 0.5 && latency > 0 {
  202. s.latency = latency
  203. }
  204. //logs.Warn("latency", s.latency)
  205. common.WindowBuff.Put(data)
  206. }
  207. }()
  208. }
  209. func (s *Mux) readSession() {
  210. go func() {
  211. pack := common.MuxPack.Get()
  212. var l uint16
  213. var err error
  214. for {
  215. if s.IsClose {
  216. break
  217. }
  218. pack = common.MuxPack.Get()
  219. s.bw.StartRead()
  220. if l, err = pack.UnPack(s.conn); err != nil {
  221. break
  222. }
  223. s.bw.SetCopySize(l)
  224. s.pingOk = 0
  225. switch pack.Flag {
  226. case common.MUX_NEW_CONN: //new connection
  227. connection := NewConn(pack.Id, s, "npc ")
  228. s.connMap.Set(pack.Id, connection) //it has been set before send ok
  229. s.newConnCh <- connection
  230. s.sendInfo(common.MUX_NEW_CONN_OK, connection.connId, nil)
  231. continue
  232. case common.MUX_PING_FLAG: //ping
  233. s.sendInfo(common.MUX_PING_RETURN, common.MUX_PING, pack.Content)
  234. common.WindowBuff.Put(pack.Content)
  235. continue
  236. case common.MUX_PING_RETURN:
  237. s.pingCh <- pack.Content
  238. continue
  239. }
  240. if connection, ok := s.connMap.Get(pack.Id); ok && !connection.isClose {
  241. switch pack.Flag {
  242. case common.MUX_NEW_MSG, common.MUX_NEW_MSG_PART: //new msg from remote connection
  243. err = s.newMsg(connection, pack)
  244. if err != nil {
  245. connection.Close()
  246. }
  247. continue
  248. case common.MUX_NEW_CONN_OK: //connection ok
  249. connection.connStatusOkCh <- struct{}{}
  250. continue
  251. case common.MUX_NEW_CONN_Fail:
  252. connection.connStatusFailCh <- struct{}{}
  253. continue
  254. case common.MUX_MSG_SEND_OK:
  255. if connection.isClose {
  256. continue
  257. }
  258. connection.sendWindow.SetSize(pack.Window, pack.ReadLength)
  259. continue
  260. case common.MUX_CONN_CLOSE: //close the connection
  261. s.connMap.Delete(pack.Id)
  262. connection.closeFlag = true
  263. connection.receiveWindow.Stop() // close signal to receive window
  264. continue
  265. }
  266. } else if pack.Flag == common.MUX_CONN_CLOSE {
  267. continue
  268. }
  269. common.MuxPack.Put(pack)
  270. }
  271. common.MuxPack.Put(pack)
  272. s.Close()
  273. }()
  274. }
  275. func (s *Mux) newMsg(connection *conn, pack *common.MuxPackager) (err error) {
  276. if connection.isClose {
  277. err = io.ErrClosedPipe
  278. return
  279. }
  280. //logs.Warn("read session receive new msg", pack.Length)
  281. //go func(connection *conn, pack *common.MuxPackager) { // do not block read session
  282. //insert into queue
  283. if pack.Flag == common.MUX_NEW_MSG_PART {
  284. err = connection.receiveWindow.Write(pack.Content, pack.Length, true, pack.Id)
  285. }
  286. if pack.Flag == common.MUX_NEW_MSG {
  287. err = connection.receiveWindow.Write(pack.Content, pack.Length, false, pack.Id)
  288. }
  289. //logs.Warn("read session write success", pack.Length)
  290. return
  291. }
  292. func (s *Mux) Close() error {
  293. logs.Warn("close mux")
  294. if s.IsClose {
  295. return errors.New("the mux has closed")
  296. }
  297. s.IsClose = true
  298. s.connMap.Close()
  299. s.closeChan <- struct{}{}
  300. s.closeChan <- struct{}{}
  301. s.closeChan <- struct{}{}
  302. close(s.newConnCh)
  303. return s.conn.Close()
  304. }
  305. //get new connId as unique flag
  306. func (s *Mux) getId() (id int32) {
  307. id = atomic.AddInt32(&s.id, 1)
  308. if _, ok := s.connMap.Get(id); ok {
  309. s.getId()
  310. }
  311. return
  312. }
  // bandwidth estimates inbound read bandwidth.
//
// The caller invokes StartRead before each read and SetCopySize with the
// amount read. Once at least 16384 units have accumulated, the next StartRead
// closes the interval and computes the rate per second.
type bandwidth struct {
	readStart     time.Time // start of the current measurement interval
	lastReadStart time.Time // start of the previous interval
	bufLength     uint16    // amount counted in the current interval; saturates at math.MaxUint16
	readBandwidth float64   // last computed rate, per second
}

// StartRead marks the start of a read. The first call starts the first
// interval. Once 16384 or more units have been counted, the call closes the
// current interval and recomputes the bandwidth.
func (b *bandwidth) StartRead() {
	if b.readStart.IsZero() {
		b.readStart = time.Now()
	}
	if b.bufLength >= 16384 {
		b.lastReadStart, b.readStart = b.readStart, time.Now()
		b.calcBandWidth()
	}
}

// SetCopySize adds n to the current interval's count. The counter saturates
// instead of wrapping around, so a large read can never make the running
// total look smaller than it is.
func (b *bandwidth) SetCopySize(n uint16) {
	if b.bufLength > math.MaxUint16-n {
		b.bufLength = math.MaxUint16
		return
	}
	b.bufLength += n
}

// calcBandWidth computes the rate over [lastReadStart, readStart] and resets
// the counter. A non-positive interval (coarse clock, zero start time) would
// yield +Inf or NaN, so in that case the previous estimate is kept.
func (b *bandwidth) calcBandWidth() {
	elapsed := b.readStart.Sub(b.lastReadStart).Seconds()
	if elapsed > 0 {
		b.readBandwidth = float64(b.bufLength) / elapsed
	}
	b.bufLength = 0
}

// Get returns the latest bandwidth estimate. It returns 100 (and stores it
// as the estimate) when no positive estimate exists yet.
func (b *bandwidth) Get() (bw float64) {
	// readBandwidth is 0 (the zero value) until the first measurement.
	if b.readBandwidth <= 0 {
		b.readBandwidth = 100
	}
	return b.readBandwidth
}