mux.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package mux
  2. import (
  3. "bytes"
  4. "errors"
  5. "io"
  6. "math"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/astaxie/beego/logs"
  12. "github.com/cnlh/nps/lib/common"
  13. )
  14. type Mux struct {
  15. net.Listener
  16. conn net.Conn
  17. connMap *connMap
  18. newConnCh chan *conn
  19. id int32
  20. closeChan chan struct{}
  21. IsClose bool
  22. pingOk int
  23. latency float64
  24. pingCh chan []byte
  25. connType string
  26. writeQueue PriorityQueue
  27. bufCh chan *bytes.Buffer
  28. sync.Mutex
  29. }
  30. func NewMux(c net.Conn, connType string) *Mux {
  31. m := &Mux{
  32. conn: c,
  33. connMap: NewConnMap(),
  34. id: 0,
  35. closeChan: make(chan struct{}),
  36. newConnCh: make(chan *conn),
  37. IsClose: false,
  38. connType: connType,
  39. bufCh: make(chan *bytes.Buffer),
  40. pingCh: make(chan []byte),
  41. }
  42. m.writeQueue.New()
  43. //read session by flag
  44. m.readSession()
  45. //ping
  46. m.ping()
  47. m.pingReturn()
  48. m.writeSession()
  49. return m
  50. }
  51. func (s *Mux) NewConn() (*conn, error) {
  52. if s.IsClose {
  53. return nil, errors.New("the mux has closed")
  54. }
  55. conn := NewConn(s.getId(), s)
  56. //it must be set before send
  57. s.connMap.Set(conn.connId, conn)
  58. s.sendInfo(common.MUX_NEW_CONN, conn.connId, nil)
  59. //set a timer timeout 30 second
  60. timer := time.NewTimer(time.Minute * 2)
  61. defer timer.Stop()
  62. select {
  63. case <-conn.connStatusOkCh:
  64. return conn, nil
  65. case <-conn.connStatusFailCh:
  66. case <-timer.C:
  67. }
  68. return nil, errors.New("create connection fail,the server refused the connection")
  69. }
  70. func (s *Mux) Accept() (net.Conn, error) {
  71. if s.IsClose {
  72. return nil, errors.New("accpet error,the mux has closed")
  73. }
  74. conn := <-s.newConnCh
  75. if conn == nil {
  76. return nil, errors.New("accpet error,the conn has closed")
  77. }
  78. return conn, nil
  79. }
  80. func (s *Mux) Addr() net.Addr {
  81. return s.conn.LocalAddr()
  82. }
  83. func (s *Mux) sendInfo(flag uint8, id int32, data ...interface{}) {
  84. var err error
  85. pack := common.MuxPack.Get()
  86. err = pack.NewPac(flag, id, data...)
  87. if err != nil {
  88. common.MuxPack.Put(pack)
  89. return
  90. }
  91. s.writeQueue.Push(pack)
  92. return
  93. }
  94. func (s *Mux) writeSession() {
  95. go s.packBuf()
  96. go s.writeBuf()
  97. }
  98. func (s *Mux) packBuf() {
  99. for {
  100. if s.IsClose {
  101. break
  102. }
  103. pack := s.writeQueue.Pop()
  104. buffer := common.BuffPool.Get()
  105. err := pack.Pack(buffer)
  106. common.MuxPack.Put(pack)
  107. if err != nil {
  108. logs.Warn("pack err", err)
  109. common.BuffPool.Put(buffer)
  110. break
  111. }
  112. select {
  113. case s.bufCh <- buffer:
  114. case <-s.closeChan:
  115. break
  116. }
  117. }
  118. }
  119. func (s *Mux) writeBuf() {
  120. for {
  121. if s.IsClose {
  122. break
  123. }
  124. select {
  125. case buffer := <-s.bufCh:
  126. l := buffer.Len()
  127. n, err := buffer.WriteTo(s.conn)
  128. common.BuffPool.Put(buffer)
  129. if err != nil || int(n) != l {
  130. logs.Warn("close from write session fail ", err, n, l)
  131. s.Close()
  132. break
  133. }
  134. case <-s.closeChan:
  135. break
  136. }
  137. }
  138. }
  139. func (s *Mux) ping() {
  140. go func() {
  141. now, _ := time.Now().MarshalText()
  142. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  143. // send the ping flag and get the latency first
  144. ticker := time.NewTicker(time.Second * 15)
  145. for {
  146. if s.IsClose {
  147. ticker.Stop()
  148. break
  149. }
  150. select {
  151. case <-ticker.C:
  152. }
  153. //Avoid going beyond the scope
  154. if (math.MaxInt32 - s.id) < 10000 {
  155. s.id = 0
  156. }
  157. now, _ := time.Now().MarshalText()
  158. s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, now)
  159. if s.pingOk > 10 && s.connType == "kcp" {
  160. s.Close()
  161. break
  162. }
  163. s.pingOk++
  164. }
  165. }()
  166. }
  167. func (s *Mux) pingReturn() {
  168. go func() {
  169. var now time.Time
  170. var data []byte
  171. for {
  172. select {
  173. case data = <-s.pingCh:
  174. case <-s.closeChan:
  175. break
  176. }
  177. _ = now.UnmarshalText(data)
  178. s.latency = time.Since(now).Seconds()
  179. s.sendInfo(common.MUX_PING_RETURN, common.MUX_PING, nil)
  180. }
  181. }()
  182. }
  183. func (s *Mux) readSession() {
  184. go func() {
  185. pack := common.MuxPack.Get()
  186. for {
  187. if s.IsClose {
  188. break
  189. }
  190. pack = common.MuxPack.Get()
  191. if pack.UnPack(s.conn) != nil {
  192. break
  193. }
  194. s.pingOk = 0
  195. switch pack.Flag {
  196. case common.MUX_NEW_CONN: //new connection
  197. connection := NewConn(pack.Id, s)
  198. s.connMap.Set(pack.Id, connection) //it has been set before send ok
  199. s.newConnCh <- connection
  200. s.sendInfo(common.MUX_NEW_CONN_OK, connection.connId, nil)
  201. continue
  202. case common.MUX_PING_FLAG: //ping
  203. s.pingCh <- pack.Content
  204. continue
  205. case common.MUX_PING_RETURN:
  206. continue
  207. }
  208. if connection, ok := s.connMap.Get(pack.Id); ok && !connection.isClose {
  209. switch pack.Flag {
  210. case common.MUX_NEW_MSG, common.MUX_NEW_MSG_PART: //new msg from remote connection
  211. err := s.newMsg(connection, pack)
  212. if err != nil {
  213. connection.Close()
  214. }
  215. continue
  216. case common.MUX_NEW_CONN_OK: //connection ok
  217. connection.connStatusOkCh <- struct{}{}
  218. continue
  219. case common.MUX_NEW_CONN_Fail:
  220. connection.connStatusFailCh <- struct{}{}
  221. continue
  222. case common.MUX_MSG_SEND_OK:
  223. if connection.isClose {
  224. continue
  225. }
  226. connection.sendWindow.SetSize(pack.Window, pack.ReadLength)
  227. continue
  228. case common.MUX_CONN_CLOSE: //close the connection
  229. s.connMap.Delete(pack.Id)
  230. connection.closeFlag = true
  231. connection.receiveWindow.Stop() // close signal to receive window
  232. continue
  233. }
  234. } else if pack.Flag == common.MUX_CONN_CLOSE {
  235. continue
  236. }
  237. common.MuxPack.Put(pack)
  238. }
  239. common.MuxPack.Put(pack)
  240. s.Close()
  241. }()
  242. }
  243. func (s *Mux) newMsg(connection *conn, pack *common.MuxPackager) (err error) {
  244. if connection.isClose {
  245. err = io.ErrClosedPipe
  246. return
  247. }
  248. //logs.Warn("read session receive new msg", pack.Length)
  249. //go func(connection *conn, pack *common.MuxPackager) { // do not block read session
  250. //insert into queue
  251. if pack.Flag == common.MUX_NEW_MSG_PART {
  252. err = connection.receiveWindow.Write(pack.Content, pack.Length, true, pack.Id)
  253. }
  254. if pack.Flag == common.MUX_NEW_MSG {
  255. err = connection.receiveWindow.Write(pack.Content, pack.Length, false, pack.Id)
  256. }
  257. //logs.Warn("read session write success", pack.Length)
  258. return
  259. }
  260. func (s *Mux) Close() error {
  261. logs.Warn("close mux")
  262. if s.IsClose {
  263. return errors.New("the mux has closed")
  264. }
  265. s.IsClose = true
  266. s.connMap.Close()
  267. s.closeChan <- struct{}{}
  268. s.closeChan <- struct{}{}
  269. close(s.newConnCh)
  270. return s.conn.Close()
  271. }
  272. //get new connId as unique flag
  273. func (s *Mux) getId() (id int32) {
  274. id = atomic.AddInt32(&s.id, 1)
  275. if _, ok := s.connMap.Get(id); ok {
  276. s.getId()
  277. }
  278. return
  279. }