1
0

bridge.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package lib
  2. import (
  3. "errors"
  4. "log"
  5. "net"
  6. "sync"
  7. "time"
  8. )
  9. type list struct {
  10. connList chan *Conn
  11. }
  12. func (l *list) Add(c *Conn) {
  13. l.connList <- c
  14. }
  15. func (l *list) Pop() *Conn {
  16. return <-l.connList
  17. }
  18. func (l *list) Len() int {
  19. return len(l.connList)
  20. }
  21. func newList() *list {
  22. l := new(list)
  23. l.connList = make(chan *Conn, 1000)
  24. return l
  25. }
  26. type Tunnel struct {
  27. tunnelPort int //通信隧道端口
  28. listener *net.TCPListener //server端监听
  29. signalList map[string]*list //通信
  30. tunnelList map[string]*list //隧道
  31. lock sync.Mutex
  32. tunnelLock sync.Mutex
  33. }
  34. func newTunnel(tunnelPort int) *Tunnel {
  35. t := new(Tunnel)
  36. t.tunnelPort = tunnelPort
  37. t.signalList = make(map[string]*list)
  38. t.tunnelList = make(map[string]*list)
  39. return t
  40. }
  41. func (s *Tunnel) StartTunnel() error {
  42. var err error
  43. s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tunnelPort, ""})
  44. if err != nil {
  45. return err
  46. }
  47. go s.tunnelProcess()
  48. return nil
  49. }
  50. //tcp server
  51. func (s *Tunnel) tunnelProcess() error {
  52. var err error
  53. for {
  54. conn, err := s.listener.Accept()
  55. if err != nil {
  56. log.Println(err)
  57. continue
  58. }
  59. go s.cliProcess(NewConn(conn))
  60. }
  61. return err
  62. }
  63. //验证失败,返回错误验证flag,并且关闭连接
  64. func (s *Tunnel) verifyError(c *Conn) {
  65. c.conn.Write([]byte(VERIFY_EER))
  66. c.conn.Close()
  67. }
  68. func (s *Tunnel) cliProcess(c *Conn) error {
  69. c.conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
  70. vval := make([]byte, 32)
  71. if _, err := c.conn.Read(vval); err != nil {
  72. log.Println("客户端读超时。客户端地址为::", c.conn.RemoteAddr())
  73. c.conn.Close()
  74. return err
  75. }
  76. if !verify(string(vval)) {
  77. log.Println("当前客户端连接校验错误,关闭此客户端:", c.conn.RemoteAddr())
  78. s.verifyError(c)
  79. return err
  80. }
  81. c.conn.(*net.TCPConn).SetReadDeadline(time.Time{})
  82. //做一个判断 添加到对应的channel里面以供使用
  83. if flag, err := c.ReadFlag(); err != nil {
  84. return err
  85. } else {
  86. return s.typeDeal(flag, c, string(vval))
  87. }
  88. }
  89. //tcp连接类型区分
  90. func (s *Tunnel) typeDeal(typeVal string, c *Conn, cFlag string) error {
  91. switch typeVal {
  92. case WORK_MAIN:
  93. s.addList(s.signalList, c, cFlag)
  94. case WORK_CHAN:
  95. s.addList(s.tunnelList, c, cFlag)
  96. default:
  97. return errors.New("无法识别")
  98. }
  99. c.SetAlive()
  100. return nil
  101. }
  102. //加到对应的list中
  103. func (s *Tunnel) addList(m map[string]*list, c *Conn, cFlag string) {
  104. s.lock.Lock()
  105. if v, ok := m[cFlag]; ok {
  106. v.Add(c)
  107. } else {
  108. l := newList()
  109. l.Add(c)
  110. m[cFlag] = l
  111. }
  112. s.lock.Unlock()
  113. }
  114. //新建隧道
  115. func (s *Tunnel) newChan(cFlag string) error {
  116. if err := s.wait(s.signalList, cFlag); err != nil {
  117. return err
  118. }
  119. retry:
  120. connPass := s.signalList[cFlag].Pop()
  121. _, err := connPass.conn.Write([]byte("chan"))
  122. if err != nil {
  123. log.Println(err)
  124. goto retry
  125. }
  126. s.signalList[cFlag].Add(connPass)
  127. return nil
  128. }
  129. //得到一个tcp隧道
  130. func (s *Tunnel) GetTunnel(cFlag string, en, de int, crypt, mux bool) (c *Conn, err error) {
  131. s.tunnelLock.Lock()
  132. if v, ok := s.tunnelList[cFlag]; !ok || v.Len() < 3 { //新建通道
  133. go s.newChan(cFlag)
  134. }
  135. retry:
  136. if err = s.wait(s.tunnelList, cFlag); err != nil {
  137. return
  138. }
  139. c = s.tunnelList[cFlag].Pop()
  140. if _, err = c.wTest(); err != nil {
  141. c.Close()
  142. goto retry
  143. }
  144. c.WriteConnInfo(en, de, crypt, mux)
  145. s.tunnelLock.Unlock()
  146. return
  147. }
  148. //得到一个通信通道
  149. func (s *Tunnel) GetSignal(cFlag string) (err error, conn *Conn) {
  150. if v, ok := s.signalList[cFlag]; !ok || v.Len() == 0 {
  151. err = errors.New("客户端未连接")
  152. return
  153. }
  154. conn = s.signalList[cFlag].Pop()
  155. return
  156. }
  157. //重回slice 复用
  158. func (s *Tunnel) ReturnSignal(conn *Conn, cFlag string) {
  159. if v, ok := s.signalList[cFlag]; ok {
  160. v.Add(conn)
  161. }
  162. }
  163. //重回slice 复用
  164. func (s *Tunnel) ReturnTunnel(conn *Conn, cFlag string) {
  165. if v, ok := s.tunnelList[cFlag]; ok {
  166. FlushConn(conn.conn)
  167. v.Add(conn)
  168. }
  169. }
  170. //删除通信通道
  171. func (s *Tunnel) DelClientSignal(cFlag string) {
  172. s.delClient(cFlag, s.signalList)
  173. }
  174. //删除隧道
  175. func (s *Tunnel) DelClientTunnel(cFlag string) {
  176. s.delClient(cFlag, s.tunnelList)
  177. }
  178. func (s *Tunnel) delClient(cFlag string, l map[string]*list) {
  179. if t := l[getverifyval(cFlag)]; t != nil {
  180. for {
  181. if t.Len() <= 0 {
  182. break
  183. }
  184. t.Pop().Close()
  185. }
  186. delete(l, getverifyval(cFlag))
  187. }
  188. }
  189. //等待
  190. func (s *Tunnel) wait(m map[string]*list, cFlag string) error {
  191. ticker := time.NewTicker(time.Millisecond * 100)
  192. stop := time.After(time.Second * 10)
  193. loop:
  194. for {
  195. select {
  196. case <-ticker.C:
  197. if _, ok := m[cFlag]; ok {
  198. ticker.Stop()
  199. break loop
  200. }
  201. case <-stop:
  202. return errors.New("client key: " + cFlag + ",err: get client conn timeout")
  203. }
  204. }
  205. return nil
  206. }