conn.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. package utils
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "github.com/golang/snappy"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/url"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  // Shared settings for the connection wrappers in this file.
  const (
  	// cryptKey is the fixed key handed to AesEncrypt and AesDecrypt.
  	// NOTE(review): a key compiled into the binary gives no real secrecy;
  	// confirm whether it is meant to be configurable.
  	cryptKey = "1234567812345678"
  	// poolSize is 64 KiB: the largest length ReadLen accepts and the size of
  	// the buffer NewCryptConn allocates.
  	poolSize = 64 * 1024
  )
  // CryptConn wraps a net.Conn and exchanges length-prefixed frames on it,
  // optionally passing each payload through AesEncrypt / AesDecrypt.
  type CryptConn struct {
  	conn  net.Conn // underlying transport
  	crypt bool     // true when payloads are encrypted and decrypted
  	rb    []byte   // decoded payload that Read has not handed out yet
  	rn    int      // number of pending bytes tracked for rb
  }
  25. func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
  26. c := new(CryptConn)
  27. c.conn = conn
  28. c.crypt = crypt
  29. c.rb = make([]byte, poolSize)
  30. return c
  31. }
  32. //加密写
  33. func (s *CryptConn) Write(b []byte) (n int, err error) {
  34. n = len(b)
  35. if s.crypt {
  36. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  37. return
  38. }
  39. }
  40. if b, err = GetLenBytes(b); err != nil {
  41. return
  42. }
  43. _, err = s.conn.Write(b)
  44. return
  45. }
  46. //解密读
  47. func (s *CryptConn) Read(b []byte) (n int, err error) {
  48. read:
  49. if len(s.rb) > 0 {
  50. if len(b) >= s.rn {
  51. n = s.rn
  52. copy(b, s.rb[:s.rn])
  53. s.rn = 0
  54. s.rb = s.rb[:0]
  55. } else {
  56. n = len(b)
  57. copy(b, s.rb[:len(b)])
  58. s.rn = n - len(b)
  59. s.rb = s.rb[len(b):]
  60. }
  61. return
  62. }
  63. defer func() {
  64. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  65. err = io.EOF
  66. n = 0
  67. }
  68. }()
  69. var lens int
  70. var buf []byte
  71. c := NewConn(s.conn)
  72. if lens, err = c.GetLen(); err != nil {
  73. return
  74. }
  75. if buf, err = c.ReadLen(lens); err != nil {
  76. return
  77. }
  78. if s.crypt {
  79. if s.rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
  80. return
  81. }
  82. } else {
  83. s.rb = buf
  84. }
  85. s.rn = len(s.rb)
  86. goto read
  87. }
  88. type SnappyConn struct {
  89. w *snappy.Writer
  90. r *snappy.Reader
  91. crypt bool
  92. }
  93. func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
  94. c := new(SnappyConn)
  95. c.w = snappy.NewBufferedWriter(conn)
  96. c.r = snappy.NewReader(conn)
  97. c.crypt = crypt
  98. return c
  99. }
  100. //snappy压缩写 包含加密
  101. func (s *SnappyConn) Write(b []byte) (n int, err error) {
  102. n = len(b)
  103. if s.crypt {
  104. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  105. log.Println("encode crypt error:", err)
  106. return
  107. }
  108. }
  109. if _, err = s.w.Write(b); err != nil {
  110. return
  111. }
  112. err = s.w.Flush()
  113. return
  114. }
  115. //snappy压缩读 包含解密
  116. func (s *SnappyConn) Read(b []byte) (n int, err error) {
  117. defer func() {
  118. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  119. err = io.EOF
  120. n = 0
  121. }
  122. }()
  123. buf := bufPool.Get().([]byte)
  124. if n, err = s.r.Read(buf); err != nil {
  125. return
  126. }
  127. var bs []byte
  128. if s.crypt {
  129. if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
  130. log.Println("decode crypt error:", err)
  131. return
  132. }
  133. } else {
  134. bs = buf[:n]
  135. }
  136. n = len(bs)
  137. copy(b, bs)
  138. return
  139. }
  // Conn wraps a net.Conn and carries the protocol helper methods defined in
  // this file (length-prefixed reads and writes, flags, host parsing).
  type Conn struct {
  	// Conn is the wrapped connection; every helper reads from or writes to it.
  	Conn net.Conn
  }
  143. //new conn
  144. func NewConn(conn net.Conn) *Conn {
  145. c := new(Conn)
  146. c.Conn = conn
  147. return c
  148. }
  149. //读取指定长度内容
  150. func (s *Conn) ReadLen(cLen int) ([]byte, error) {
  151. if cLen > poolSize {
  152. return nil, errors.New("长度错误" + strconv.Itoa(cLen))
  153. }
  154. buf := bufPool.Get().([]byte)[:cLen]
  155. if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
  156. return buf, errors.New("读取指定长度错误" + err.Error())
  157. }
  158. return buf, nil
  159. }
  160. //获取长度
  161. func (s *Conn) GetLen() (int, error) {
  162. val, err := s.ReadLen(4)
  163. if err != nil {
  164. return 0, err
  165. }
  166. return GetLenByBytes(val)
  167. }
  168. //写入长度+内容 粘包
  169. func (s *Conn) WriteLen(buf []byte) (int, error) {
  170. var b []byte
  171. var err error
  172. if b, err = GetLenBytes(buf); err != nil {
  173. return 0, err
  174. }
  175. return s.Write(b)
  176. }
  177. //读取flag
  178. func (s *Conn) ReadFlag() (string, error) {
  179. val, err := s.ReadLen(4)
  180. if err != nil {
  181. return "", err
  182. }
  183. return string(val), err
  184. }
  185. //读取host 连接地址 压缩类型
  186. func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
  187. retry:
  188. lType, err := s.ReadLen(3)
  189. if err != nil {
  190. return
  191. }
  192. if typeStr = string(lType); typeStr == TEST_FLAG {
  193. en, de, crypt, mux = s.GetConnInfoFromConn()
  194. goto retry
  195. } else if typeStr != CONN_TCP && typeStr != CONN_UDP {
  196. err = errors.New("unknown conn type")
  197. return
  198. }
  199. cLen, err := s.GetLen()
  200. if err != nil || cLen > poolSize {
  201. return
  202. }
  203. hostByte, err := s.ReadLen(cLen)
  204. if err != nil {
  205. return
  206. }
  207. host = string(hostByte)
  208. return
  209. }
  210. //写连接类型 和 host地址
  211. func (s *Conn) WriteHost(ltype string, host string) (int, error) {
  212. raw := bytes.NewBuffer([]byte{})
  213. binary.Write(raw, binary.LittleEndian, []byte(ltype))
  214. binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
  215. binary.Write(raw, binary.LittleEndian, []byte(host))
  216. return s.Write(raw.Bytes())
  217. }
  218. //设置连接为长连接
  219. func (s *Conn) SetAlive() {
  220. conn := s.Conn.(*net.TCPConn)
  221. conn.SetReadDeadline(time.Time{})
  222. conn.SetKeepAlive(true)
  223. conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
  224. }
  225. //从tcp报文中解析出host,连接类型等 TODO 多种情况
  226. func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
  227. var b [32 * 1024]byte
  228. var n int
  229. if n, err = s.Read(b[:]); err != nil {
  230. return
  231. }
  232. rb = b[:n]
  233. r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
  234. if err != nil {
  235. return
  236. }
  237. hostPortURL, err := url.Parse(r.Host)
  238. if err != nil {
  239. address = r.Host
  240. err = nil
  241. return
  242. }
  243. if hostPortURL.Opaque == "443" { //https访问
  244. address = r.Host + ":443"
  245. } else { //http访问
  246. if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
  247. address = r.Host + ":80"
  248. } else {
  249. address = r.Host
  250. }
  251. }
  252. return
  253. }
  254. //单独读(加密|压缩)
  255. func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
  256. if COMPRESS_SNAPY_DECODE == compress {
  257. return NewSnappyConn(s.Conn, crypt).Read(b)
  258. }
  259. return NewCryptConn(s.Conn, crypt).Read(b)
  260. }
  261. //单独写(加密|压缩)
  262. func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
  263. if COMPRESS_SNAPY_ENCODE == compress {
  264. return NewSnappyConn(s.Conn, crypt).Write(b)
  265. }
  266. return NewCryptConn(s.Conn, crypt).Write(b)
  267. }
  268. //写压缩方式,加密
  269. func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
  270. s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
  271. }
  272. //获取压缩方式,是否加密
  273. func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
  274. buf, err := s.ReadLen(4)
  275. if err != nil {
  276. return
  277. }
  278. en, _ = strconv.Atoi(string(buf[0]))
  279. de, _ = strconv.Atoi(string(buf[1]))
  280. crypt = GetBoolByStr(string(buf[2]))
  281. mux = GetBoolByStr(string(buf[3]))
  282. return
  283. }
  284. //close
  285. func (s *Conn) Close() error {
  286. return s.Conn.Close()
  287. }
  288. //write
  289. func (s *Conn) Write(b []byte) (int, error) {
  290. return s.Conn.Write(b)
  291. }
  292. //read
  293. func (s *Conn) Read(b []byte) (int, error) {
  294. return s.Conn.Read(b)
  295. }
  296. //write error
  297. func (s *Conn) WriteError() (int, error) {
  298. return s.Write([]byte(RES_MSG))
  299. }
  300. //write sign flag
  301. func (s *Conn) WriteSign() (int, error) {
  302. return s.Write([]byte(RES_SIGN))
  303. }
  304. //write main
  305. func (s *Conn) WriteMain() (int, error) {
  306. return s.Write([]byte(WORK_MAIN))
  307. }
  308. //write chan
  309. func (s *Conn) WriteChan() (int, error) {
  310. return s.Write([]byte(WORK_CHAN))
  311. }
  312. //write test
  313. func (s *Conn) WriteTest() (int, error) {
  314. return s.Write([]byte(TEST_FLAG))
  315. }
  316. //write test
  317. func (s *Conn) WriteSuccess() (int, error) {
  318. return s.Write([]byte(CONN_SUCCESS))
  319. }
  320. //write test
  321. func (s *Conn) WriteFail() (int, error) {
  322. return s.Write([]byte(CONN_ERROR))
  323. }
  // GetLenBytes returns a frame made of len(buf) encoded as a little-endian
  // 32-bit integer followed by the bytes of buf. A nil or empty buf yields
  // just the 4-byte zero header. The returned error is always nil for byte
  // slices; it is kept so the signature stays the same.
  func GetLenBytes(buf []byte) (b []byte, err error) {
  	b = make([]byte, 4, 4+len(buf))
  	binary.LittleEndian.PutUint32(b, uint32(int32(len(buf))))
  	b = append(b, buf...)
  	return b, nil
  }
  // GetLenByBytes decodes the first 4 bytes of buf as a little-endian
  // unsigned 32-bit length.
  //
  // It returns an error when buf holds fewer than 4 bytes (the previous code
  // panicked in that case), when the length is zero, or when it exceeds the
  // largest positive 32-bit signed value. The last check matters because
  // GetLenBytes encodes the length as an int32, and a larger value would not
  // fit in int on 32-bit platforms.
  func GetLenByBytes(buf []byte) (int, error) {
  	if len(buf) < 4 {
  		return 0, errors.New("数据长度错误")
  	}
  	nlen := binary.LittleEndian.Uint32(buf)
  	if nlen == 0 || nlen > 1<<31-1 {
  		return 0, errors.New("数据长度错误")
  	}
  	return int(nlen), nil
  }