conn.go 7.4 KB


  1. package utils
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "github.com/golang/snappy"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/url"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  const (
  	// cryptKey is the key passed to AesEncrypt/AesDecrypt by CryptConn and SnappyConn.
  	// NOTE(review): hard-coded and identical in every build, so it gives no
  	// confidentiality against anyone who has this source; consider making it configurable.
  	cryptKey = "1234567812345678"
  	// poolSize is the largest length ReadLen accepts; presumably also the size of
  	// the buffers held by bufPool (declared elsewhere) — confirm.
  	poolSize = 64 * 1024
  	// poolSizeSmall is the largest read ReadLen serves from bufPoolSmall.
  	poolSizeSmall = 10
  )
  // CryptConn wraps a net.Conn so that every Write is sent as one frame
  // (4-byte little-endian length + payload) and every Read consumes one frame.
  // When crypt is true the payload goes through AesEncrypt/AesDecrypt with cryptKey.
  type CryptConn struct {
  	conn  net.Conn // underlying transport
  	crypt bool     // encrypt on Write / decrypt on Read
  }

  // NewCryptConn returns a CryptConn over conn; crypt enables payload encryption.
  func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
  	return &CryptConn{conn: conn, crypt: crypt}
  }
  30. //加密写
  31. func (s *CryptConn) Write(b []byte) (n int, err error) {
  32. n = len(b)
  33. if s.crypt {
  34. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  35. return
  36. }
  37. }
  38. if b, err = GetLenBytes(b); err != nil {
  39. return
  40. }
  41. _, err = s.conn.Write(b)
  42. return
  43. }
  44. //解密读
  45. func (s *CryptConn) Read(b []byte) (n int, err error) {
  46. defer func() {
  47. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  48. err = io.EOF
  49. n = 0
  50. }
  51. }()
  52. var lens int
  53. var buf []byte
  54. var rb []byte
  55. c := NewConn(s.conn)
  56. if lens, err = c.GetLen(); err != nil {
  57. return
  58. }
  59. if buf, err = c.ReadLen(lens); err != nil {
  60. return
  61. }
  62. if s.crypt {
  63. if rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
  64. return
  65. }
  66. } else {
  67. rb = buf
  68. }
  69. copy(b, rb)
  70. n = len(rb)
  71. return
  72. }
  73. type SnappyConn struct {
  74. w *snappy.Writer
  75. r *snappy.Reader
  76. crypt bool
  77. }
  78. func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
  79. c := new(SnappyConn)
  80. c.w = snappy.NewBufferedWriter(conn)
  81. c.r = snappy.NewReader(conn)
  82. c.crypt = crypt
  83. return c
  84. }
  85. //snappy压缩写 包含加密
  86. func (s *SnappyConn) Write(b []byte) (n int, err error) {
  87. n = len(b)
  88. if s.crypt {
  89. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  90. log.Println("encode crypt error:", err)
  91. return
  92. }
  93. }
  94. if _, err = s.w.Write(b); err != nil {
  95. return
  96. }
  97. err = s.w.Flush()
  98. return
  99. }
  100. //snappy压缩读 包含解密
  101. func (s *SnappyConn) Read(b []byte) (n int, err error) {
  102. defer func() {
  103. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  104. err = io.EOF
  105. n = 0
  106. }
  107. }()
  108. buf := bufPool.Get().([]byte)
  109. if n, err = s.r.Read(buf); err != nil {
  110. return
  111. }
  112. var bs []byte
  113. if s.crypt {
  114. if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
  115. log.Println("decode crypt error:", err)
  116. return
  117. }
  118. } else {
  119. bs = buf[:n]
  120. }
  121. n = len(bs)
  122. copy(b, bs)
  123. return
  124. }
  // Conn wraps a net.Conn with helpers for length-prefixed framing and the
  // fixed protocol messages exchanged with the peer.
  type Conn struct {
  	Conn net.Conn // the wrapped connection; exported so callers can reach it
  }

  // NewConn wraps conn in a Conn.
  func NewConn(conn net.Conn) *Conn {
  	return &Conn{Conn: conn}
  }
  134. //读取指定长度内容
  135. func (s *Conn) ReadLen(cLen int) ([]byte, error) {
  136. if cLen > poolSize {
  137. return nil, errors.New("长度错误" + strconv.Itoa(cLen))
  138. }
  139. var buf []byte
  140. if cLen <= poolSizeSmall {
  141. buf = bufPoolSmall.Get().([]byte)[:cLen]
  142. } else {
  143. buf = bufPool.Get().([]byte)[:cLen]
  144. }
  145. if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
  146. return buf, errors.New("读取指定长度错误" + err.Error())
  147. }
  148. return buf, nil
  149. }
  150. //获取长度
  151. func (s *Conn) GetLen() (int, error) {
  152. val, err := s.ReadLen(4)
  153. if err != nil {
  154. return 0, err
  155. }
  156. return GetLenByBytes(val)
  157. }
  158. //写入长度+内容 粘包
  159. func (s *Conn) WriteLen(buf []byte) (int, error) {
  160. var b []byte
  161. var err error
  162. if b, err = GetLenBytes(buf); err != nil {
  163. return 0, err
  164. }
  165. return s.Write(b)
  166. }
  167. //读取flag
  168. func (s *Conn) ReadFlag() (string, error) {
  169. val, err := s.ReadLen(4)
  170. if err != nil {
  171. return "", err
  172. }
  173. return string(val), err
  174. }
  175. //读取host 连接地址 压缩类型
  176. func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
  177. retry:
  178. lType, err := s.ReadLen(3)
  179. if err != nil {
  180. return
  181. }
  182. if typeStr = string(lType); typeStr == TEST_FLAG {
  183. en, de, crypt, mux = s.GetConnInfoFromConn()
  184. goto retry
  185. } else if typeStr != CONN_TCP && typeStr != CONN_UDP {
  186. err = errors.New("unknown conn type")
  187. return
  188. }
  189. cLen, err := s.GetLen()
  190. if err != nil || cLen > poolSize {
  191. return
  192. }
  193. hostByte, err := s.ReadLen(cLen)
  194. if err != nil {
  195. return
  196. }
  197. host = string(hostByte)
  198. return
  199. }
  200. //写连接类型 和 host地址
  201. func (s *Conn) WriteHost(ltype string, host string) (int, error) {
  202. raw := bytes.NewBuffer([]byte{})
  203. binary.Write(raw, binary.LittleEndian, []byte(ltype))
  204. binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
  205. binary.Write(raw, binary.LittleEndian, []byte(host))
  206. return s.Write(raw.Bytes())
  207. }
  208. //设置连接为长连接
  209. func (s *Conn) SetAlive() {
  210. conn := s.Conn.(*net.TCPConn)
  211. conn.SetReadDeadline(time.Time{})
  212. conn.SetKeepAlive(true)
  213. conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
  214. }
  215. //从tcp报文中解析出host,连接类型等 TODO 多种情况
  216. func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
  217. var b [32 * 1024]byte
  218. var n int
  219. if n, err = s.Read(b[:]); err != nil {
  220. return
  221. }
  222. rb = b[:n]
  223. r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
  224. if err != nil {
  225. return
  226. }
  227. hostPortURL, err := url.Parse(r.Host)
  228. if err != nil {
  229. address = r.Host
  230. err = nil
  231. return
  232. }
  233. if hostPortURL.Opaque == "443" { //https访问
  234. address = r.Host + ":443"
  235. } else { //http访问
  236. if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
  237. address = r.Host + ":80"
  238. } else {
  239. address = r.Host
  240. }
  241. }
  242. return
  243. }
  244. //单独读(加密|压缩)
  245. func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
  246. if COMPRESS_SNAPY_DECODE == compress {
  247. return NewSnappyConn(s.Conn, crypt).Read(b)
  248. }
  249. return NewCryptConn(s.Conn, crypt).Read(b)
  250. }
  251. //单独写(加密|压缩)
  252. func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
  253. if COMPRESS_SNAPY_ENCODE == compress {
  254. return NewSnappyConn(s.Conn, crypt).Write(b)
  255. }
  256. return NewCryptConn(s.Conn, crypt).Write(b)
  257. }
  258. //写压缩方式,加密
  259. func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
  260. s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
  261. }
  262. //获取压缩方式,是否加密
  263. func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
  264. buf, err := s.ReadLen(4)
  265. if err != nil {
  266. return
  267. }
  268. en, _ = strconv.Atoi(string(buf[0]))
  269. de, _ = strconv.Atoi(string(buf[1]))
  270. crypt = GetBoolByStr(string(buf[2]))
  271. mux = GetBoolByStr(string(buf[3]))
  272. return
  273. }
  274. //close
  275. func (s *Conn) Close() error {
  276. return s.Conn.Close()
  277. }
  278. //write
  279. func (s *Conn) Write(b []byte) (int, error) {
  280. return s.Conn.Write(b)
  281. }
  282. //read
  283. func (s *Conn) Read(b []byte) (int, error) {
  284. return s.Conn.Read(b)
  285. }
  286. //write error
  287. func (s *Conn) WriteError() (int, error) {
  288. return s.Write([]byte(RES_MSG))
  289. }
  290. //write sign flag
  291. func (s *Conn) WriteSign() (int, error) {
  292. return s.Write([]byte(RES_SIGN))
  293. }
  294. //write main
  295. func (s *Conn) WriteMain() (int, error) {
  296. return s.Write([]byte(WORK_MAIN))
  297. }
  298. //write chan
  299. func (s *Conn) WriteChan() (int, error) {
  300. return s.Write([]byte(WORK_CHAN))
  301. }
  302. //write test
  303. func (s *Conn) WriteTest() (int, error) {
  304. return s.Write([]byte(TEST_FLAG))
  305. }
  306. //write test
  307. func (s *Conn) WriteSuccess() (int, error) {
  308. return s.Write([]byte(CONN_SUCCESS))
  309. }
  310. //write test
  311. func (s *Conn) WriteFail() (int, error) {
  312. return s.Write([]byte(CONN_ERROR))
  313. }
  // GetLenBytes frames buf as a 4-byte little-endian int32 length followed by
  // buf itself; GetLenByBytes / Conn.GetLen decode the header.
  //
  // It returns an error when len(buf) does not fit in an int32; previously the
  // length silently wrapped and produced a corrupt header. The frame is built
  // in a single allocation.
  func GetLenBytes(buf []byte) (b []byte, err error) {
  	if len(buf) > 1<<31-1 {
  		return nil, errors.New("payload too large for int32 length prefix")
  	}
  	b = make([]byte, 4+len(buf))
  	binary.LittleEndian.PutUint32(b, uint32(len(buf)))
  	copy(b[4:], buf)
  	return b, nil
  }
  // GetLenByBytes decodes the 4-byte little-endian length header produced by
  // GetLenBytes (only buf[:4] is examined).
  //
  // It returns an error instead of panicking when buf is shorter than 4 bytes.
  // It also rejects a zero length and lengths above MaxInt32: the writer never
  // emits those, and on 32-bit platforms they would become negative ints.
  func GetLenByBytes(buf []byte) (int, error) {
  	if len(buf) < 4 {
  		return 0, errors.New("length header too short: " + strconv.Itoa(len(buf)) + " bytes")
  	}
  	nlen := binary.LittleEndian.Uint32(buf)
  	if nlen == 0 || nlen > 1<<31-1 {
  		return 0, errors.New("数据长度错误")
  	}
  	return int(nlen), nil
  }