1
0

conn.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. package lib
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "github.com/golang/snappy"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. )
  17. const cryptKey = "1234567812345678"
  18. type CryptConn struct {
  19. conn net.Conn
  20. crypt bool
  21. rate *Rate
  22. }
  23. func NewCryptConn(conn net.Conn, crypt bool, rate *Rate) *CryptConn {
  24. c := new(CryptConn)
  25. c.conn = conn
  26. c.crypt = crypt
  27. c.rate = rate
  28. return c
  29. }
  30. //加密写
  31. func (s *CryptConn) Write(b []byte) (n int, err error) {
  32. n = len(b)
  33. if s.crypt {
  34. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  35. return
  36. }
  37. }
  38. if b, err = GetLenBytes(b); err != nil {
  39. return
  40. }
  41. _, err = s.conn.Write(b)
  42. if s.rate != nil {
  43. s.rate.Get(int64(n))
  44. }
  45. return
  46. }
  47. //解密读
  48. func (s *CryptConn) Read(b []byte) (n int, err error) {
  49. var lens int
  50. var buf []byte
  51. var rb []byte
  52. c := NewConn(s.conn)
  53. if lens, err = c.GetLen(); err != nil {
  54. return
  55. }
  56. if buf, err = c.ReadLen(lens); err != nil {
  57. return
  58. }
  59. if s.crypt {
  60. if rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
  61. return
  62. }
  63. } else {
  64. rb = buf
  65. }
  66. copy(b, rb)
  67. n = len(rb)
  68. if s.rate != nil {
  69. s.rate.Get(int64(n))
  70. }
  71. return
  72. }
  73. type SnappyConn struct {
  74. w *snappy.Writer
  75. r *snappy.Reader
  76. crypt bool
  77. rate *Rate
  78. }
  79. func NewSnappyConn(conn net.Conn, crypt bool, rate *Rate) *SnappyConn {
  80. c := new(SnappyConn)
  81. c.w = snappy.NewBufferedWriter(conn)
  82. c.r = snappy.NewReader(conn)
  83. c.crypt = crypt
  84. c.rate = rate
  85. return c
  86. }
  87. //snappy压缩写 包含加密
  88. func (s *SnappyConn) Write(b []byte) (n int, err error) {
  89. n = len(b)
  90. if s.crypt {
  91. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  92. Println("encode crypt error:", err)
  93. return
  94. }
  95. }
  96. if _, err = s.w.Write(b); err != nil {
  97. return
  98. }
  99. if err = s.w.Flush(); err != nil {
  100. return
  101. }
  102. if s.rate != nil {
  103. s.rate.Get(int64(n))
  104. }
  105. return
  106. }
  107. //snappy压缩读 包含解密
  108. func (s *SnappyConn) Read(b []byte) (n int, err error) {
  109. buf := BufPool.Get().([]byte)
  110. defer BufPool.Put(buf)
  111. if n, err = s.r.Read(buf); err != nil {
  112. return
  113. }
  114. var bs []byte
  115. if s.crypt {
  116. if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
  117. Println("decode crypt error:", err)
  118. return
  119. }
  120. } else {
  121. bs = buf[:n]
  122. }
  123. n = len(bs)
  124. copy(b, bs)
  125. if s.rate != nil {
  126. s.rate.Get(int64(n))
  127. }
  128. return
  129. }
  // Conn wraps a net.Conn with the protocol helpers used by this package.
// The embedded Mutex is held by helpers that must send or receive a
// multi-part message without interleaving, such as SendMsg, GetMsgContent,
// SendLinkInfo, GetLinkInfo, WriteSuccess/WriteFail and WriteMain/WriteChan.
// The raw Read and Write methods do not lock.
type Conn struct {
	Conn net.Conn
	sync.Mutex
}

// NewConn wraps conn in an unlocked Conn.
func NewConn(conn net.Conn) *Conn {
	return &Conn{Conn: conn}
}
  140. //从tcp报文中解析出host,连接类型等
  141. func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
  142. var b [32 * 1024]byte
  143. var n int
  144. if n, err = s.Read(b[:]); err != nil {
  145. return
  146. }
  147. rb = b[:n]
  148. r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
  149. if err != nil {
  150. return
  151. }
  152. hostPortURL, err := url.Parse(r.Host)
  153. if err != nil {
  154. address = r.Host
  155. err = nil
  156. return
  157. }
  158. if hostPortURL.Opaque == "443" { //https访问
  159. if strings.Index(r.Host, ":") == -1 { //host不带端口, 默认80
  160. address = r.Host + ":443"
  161. } else {
  162. address = r.Host
  163. }
  164. } else { //http访问
  165. if strings.Index(r.Host, ":") == -1 { //host不带端口, 默认80
  166. address = r.Host + ":80"
  167. } else {
  168. address = r.Host
  169. }
  170. }
  171. return
  172. }
  173. //读取指定长度内容
  174. func (s *Conn) ReadLen(cLen int) ([]byte, error) {
  175. if cLen > poolSize {
  176. return nil, errors.New("长度错误" + strconv.Itoa(cLen))
  177. }
  178. var buf []byte
  179. if cLen <= poolSizeSmall {
  180. buf = BufPoolSmall.Get().([]byte)[:cLen]
  181. defer BufPoolSmall.Put(buf)
  182. } else {
  183. buf = BufPoolMax.Get().([]byte)[:cLen]
  184. defer BufPoolMax.Put(buf)
  185. }
  186. if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
  187. return buf, errors.New("读取指定长度错误" + err.Error())
  188. }
  189. return buf, nil
  190. }
  191. //read length or id (content length=4)
  192. func (s *Conn) GetLen() (int, error) {
  193. val, err := s.ReadLen(4)
  194. if err != nil {
  195. return 0, err
  196. }
  197. return GetLenByBytes(val)
  198. }
  199. //read flag
  200. func (s *Conn) ReadFlag() (string, error) {
  201. val, err := s.ReadLen(4)
  202. if err != nil {
  203. return "", err
  204. }
  205. return string(val), err
  206. }
  207. //read connect status
  208. func (s *Conn) GetConnStatus() (id int, status bool, err error) {
  209. id, err = s.GetLen()
  210. if err != nil {
  211. return
  212. }
  213. var b []byte
  214. if b, err = s.ReadLen(1); err != nil {
  215. return
  216. } else {
  217. status = GetBoolByStr(string(b[0]))
  218. }
  219. return
  220. }
  221. //设置连接为长连接
  222. func (s *Conn) SetAlive() {
  223. conn := s.Conn.(*net.TCPConn)
  224. conn.SetReadDeadline(time.Time{})
  225. conn.SetKeepAlive(true)
  226. conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
  227. }
  228. //set read dead time
  229. func (s *Conn) SetReadDeadline(t time.Duration) {
  230. s.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
  231. }
  232. //单独读(加密|压缩)
  233. func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *Rate) (int, error) {
  234. if COMPRESS_SNAPY_DECODE == compress {
  235. return NewSnappyConn(s.Conn, crypt, rate).Read(b)
  236. }
  237. return NewCryptConn(s.Conn, crypt, rate).Read(b)
  238. }
  239. //单独写(加密|压缩)
  240. func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *Rate) (n int, err error) {
  241. if COMPRESS_SNAPY_ENCODE == compress {
  242. return NewSnappyConn(s.Conn, crypt, rate).Write(b)
  243. }
  244. return NewCryptConn(s.Conn, crypt, rate).Write(b)
  245. }
  246. //send msg
  247. func (s *Conn) SendMsg(content []byte, link *Link) (n int, err error) {
  248. /*
  249. The msg info is formed as follows:
  250. +----+--------+
  251. |id | content |
  252. +----+--------+
  253. | 4 | ... |
  254. +----+--------+
  255. */
  256. s.Lock()
  257. defer s.Unlock()
  258. raw := bytes.NewBuffer([]byte{})
  259. binary.Write(raw, binary.LittleEndian, int32(link.Id))
  260. if n, err = s.Write(raw.Bytes()); err != nil {
  261. return
  262. }
  263. raw.Reset()
  264. binary.Write(raw, binary.LittleEndian, content)
  265. n, err = s.WriteTo(raw.Bytes(), link.En, link.Crypt, link.Rate)
  266. return
  267. }
  268. //get msg content from conn
  269. func (s *Conn) GetMsgContent(link *Link) (content []byte, err error) {
  270. s.Lock()
  271. defer s.Unlock()
  272. buf := BufPoolCopy.Get().([]byte)
  273. if n, err := s.ReadFrom(buf, link.De, link.Crypt, link.Rate); err == nil && n > 4 {
  274. content = buf[:n]
  275. }
  276. return
  277. }
  278. //send info for link
  279. func (s *Conn) SendLinkInfo(link *Link) (int, error) {
  280. /*
  281. The link info is formed as follows:
  282. +----------+------+----------+------+----------+-----+
  283. | id | len | type | hostlen | host | en | de |crypt |
  284. +----------+------+----------+------+---------+------+
  285. | 4 | 4 | 3 | 4 | host | 1 | 1 | 1 |
  286. +----------+------+----------+------+----+----+------+
  287. */
  288. raw := bytes.NewBuffer([]byte{})
  289. binary.Write(raw, binary.LittleEndian, []byte(NEW_CONN))
  290. binary.Write(raw, binary.LittleEndian, int32(14+len(link.Host)))
  291. binary.Write(raw, binary.LittleEndian, int32(link.Id))
  292. binary.Write(raw, binary.LittleEndian, []byte(link.ConnType))
  293. binary.Write(raw, binary.LittleEndian, int32(len(link.Host)))
  294. binary.Write(raw, binary.LittleEndian, []byte(link.Host))
  295. binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.En)))
  296. binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.De)))
  297. binary.Write(raw, binary.LittleEndian, []byte(GetStrByBool(link.Crypt)))
  298. s.Lock()
  299. defer s.Unlock()
  300. return s.Write(raw.Bytes())
  301. }
  302. func (s *Conn) GetLinkInfo() (link *Link, err error) {
  303. s.Lock()
  304. defer s.Unlock()
  305. var hostLen, n int
  306. var buf []byte
  307. if n, err = s.GetLen(); err != nil {
  308. return
  309. }
  310. link = new(Link)
  311. if buf, err = s.ReadLen(n); err != nil {
  312. return
  313. }
  314. if link.Id, err = GetLenByBytes(buf[:4]); err != nil {
  315. return
  316. }
  317. link.ConnType = string(buf[4:7])
  318. if hostLen, err = GetLenByBytes(buf[7:11]); err != nil {
  319. return
  320. } else {
  321. link.Host = string(buf[11 : 11+hostLen])
  322. link.En = GetIntNoErrByStr(string(buf[11+hostLen]))
  323. link.De = GetIntNoErrByStr(string(buf[12+hostLen]))
  324. link.Crypt = GetBoolByStr(string(buf[13+hostLen]))
  325. }
  326. return
  327. }
  328. //write connect success
  329. func (s *Conn) WriteSuccess(id int) (int, error) {
  330. raw := bytes.NewBuffer([]byte{})
  331. binary.Write(raw, binary.LittleEndian, int32(id))
  332. binary.Write(raw, binary.LittleEndian, []byte("1"))
  333. s.Lock()
  334. defer s.Unlock()
  335. return s.Write(raw.Bytes())
  336. }
  337. //write connect fail
  338. func (s *Conn) WriteFail(id int) (int, error) {
  339. raw := bytes.NewBuffer([]byte{})
  340. binary.Write(raw, binary.LittleEndian, int32(id))
  341. binary.Write(raw, binary.LittleEndian, []byte("0"))
  342. s.Lock()
  343. defer s.Unlock()
  344. return s.Write(raw.Bytes())
  345. }
  346. //close
  347. func (s *Conn) Close() error {
  348. return s.Conn.Close()
  349. }
  350. //write
  351. func (s *Conn) Write(b []byte) (int, error) {
  352. return s.Conn.Write(b)
  353. }
  354. //read
  355. func (s *Conn) Read(b []byte) (int, error) {
  356. return s.Conn.Read(b)
  357. }
  358. //write error
  359. func (s *Conn) WriteError() (int, error) {
  360. return s.Write([]byte(RES_MSG))
  361. }
  362. //write sign flag
  363. func (s *Conn) WriteSign() (int, error) {
  364. return s.Write([]byte(RES_SIGN))
  365. }
  366. //write sign flag
  367. func (s *Conn) WriteClose() (int, error) {
  368. return s.Write([]byte(RES_CLOSE))
  369. }
  370. //write main
  371. func (s *Conn) WriteMain() (int, error) {
  372. s.Lock()
  373. defer s.Unlock()
  374. return s.Write([]byte(WORK_MAIN))
  375. }
  376. //write chan
  377. func (s *Conn) WriteChan() (int, error) {
  378. s.Lock()
  379. defer s.Unlock()
  380. return s.Write([]byte(WORK_CHAN))
  381. }
  // GetLenBytes returns a new slice holding len(buf) as a 4-byte little-endian
// int32, followed by a copy of buf. buf itself is not modified.
//
// The frame is built with one exact-size allocation and PutUint32, instead of
// growing a bytes.Buffer through two binary.Write calls. This runs on every
// framed write. err is always nil and is kept only for API compatibility.
func GetLenBytes(buf []byte) (b []byte, err error) {
	b = make([]byte, 4+len(buf))
	binary.LittleEndian.PutUint32(b, uint32(int32(len(buf))))
	copy(b[4:], buf)
	return b, nil
}
  // GetLenByBytes decodes a little-endian uint32 length from the first 4 bytes
// of buf. Extra trailing bytes are ignored.
//
// It returns an error instead of panicking when buf holds fewer than 4 bytes.
// It also rejects a zero length, and a value that wraps negative when
// converted to int on 32-bit platforms; the old `nlen <= 0` test on a uint32
// could only ever catch zero.
func GetLenByBytes(buf []byte) (int, error) {
	if len(buf) < 4 {
		return 0, errors.New("length prefix shorter than 4 bytes")
	}
	nlen := int(binary.LittleEndian.Uint32(buf))
	if nlen <= 0 {
		return 0, errors.New("数据长度错误")
	}
	return nlen, nil
}