conn.go 7.0 KB


  1. package utils
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "github.com/golang/snappy"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httputil"
  13. "net/url"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  // cryptKey is the fixed 16-byte key handed to AesEncrypt/AesDecrypt by
// CryptConn and SnappyConn.
//
// SECURITY NOTE(review): the key is hard-coded and shared by every build, so
// presumably anyone holding this source can decrypt the traffic. Changing it
// breaks compatibility with existing peers; consider making it configurable.
const (
	cryptKey = "1234567812345678"
)
  // CryptConn exchanges length-prefixed packets over conn, optionally
// AES-encrypting each payload with cryptKey.
type CryptConn struct {
	conn  net.Conn // underlying transport
	crypt bool     // encrypt/decrypt payloads with cryptKey when true
}
  23. func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
  24. c := new(CryptConn)
  25. c.conn = conn
  26. c.crypt = crypt
  27. return c
  28. }
  29. //加密写
  30. func (s *CryptConn) Write(b []byte) (n int, err error) {
  31. n = len(b)
  32. if s.crypt {
  33. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  34. return
  35. }
  36. }
  37. if b, err = GetLenBytes(b); err != nil {
  38. return
  39. }
  40. _, err = s.conn.Write(b)
  41. return
  42. }
  43. //解密读
  44. func (s *CryptConn) Read(b []byte) (n int, err error) {
  45. defer func() {
  46. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  47. err = io.EOF
  48. n = 0
  49. }
  50. }()
  51. var lens int
  52. var buf, bs []byte
  53. c := NewConn(s.conn)
  54. if lens, err = c.GetLen(); err != nil {
  55. return
  56. }
  57. if buf, err = c.ReadLen(lens); err != nil {
  58. return
  59. }
  60. if s.crypt {
  61. if bs, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
  62. return
  63. }
  64. } else {
  65. bs = buf
  66. }
  67. n = len(bs)
  68. copy(b, bs)
  69. return
  70. }
  71. type SnappyConn struct {
  72. w *snappy.Writer
  73. r *snappy.Reader
  74. crypt bool
  75. }
  76. func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
  77. c := new(SnappyConn)
  78. c.w = snappy.NewBufferedWriter(conn)
  79. c.r = snappy.NewReader(conn)
  80. c.crypt = crypt
  81. return c
  82. }
  83. //snappy压缩写 包含加密
  84. func (s *SnappyConn) Write(b []byte) (n int, err error) {
  85. n = len(b)
  86. if s.crypt {
  87. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  88. log.Println("encode crypt error:", err)
  89. return
  90. }
  91. }
  92. if _, err = s.w.Write(b); err != nil {
  93. return
  94. }
  95. err = s.w.Flush()
  96. return
  97. }
  98. //snappy压缩读 包含解密
  99. func (s *SnappyConn) Read(b []byte) (n int, err error) {
  100. defer func() {
  101. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  102. err = io.EOF
  103. n = 0
  104. }
  105. }()
  106. if n, err = s.r.Read(b); err != nil {
  107. return
  108. }
  109. if s.crypt {
  110. var bs []byte
  111. if bs, err = AesDecrypt(b[:n], []byte(cryptKey)); err != nil {
  112. log.Println("decode crypt error:", err)
  113. return
  114. }
  115. n = len(bs)
  116. copy(b, bs)
  117. }
  118. return
  119. }
  // Conn wraps a net.Conn with this package's framing and handshake helpers.
type Conn struct {
	Conn net.Conn // the wrapped network connection; all I/O goes through it
}
  123. //new conn
  124. func NewConn(conn net.Conn) *Conn {
  125. c := new(Conn)
  126. c.Conn = conn
  127. return c
  128. }
  129. //读取指定长度内容
  130. func (s *Conn) ReadLen(cLen int) ([]byte, error) {
  131. if cLen > 65535 {
  132. return nil, errors.New("长度错误")
  133. }
  134. buf := bufPool.Get().([]byte)[:cLen]
  135. if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
  136. return buf, errors.New("读取指定长度错误" + err.Error())
  137. }
  138. return buf, nil
  139. }
  140. //获取长度
  141. func (s *Conn) GetLen() (int, error) {
  142. val, err := s.ReadLen(4)
  143. if err != nil {
  144. return 0, err
  145. }
  146. return GetLenByBytes(val)
  147. }
  148. //写入长度+内容 粘包
  149. func (s *Conn) WriteLen(buf []byte) (int, error) {
  150. var b []byte
  151. var err error
  152. if b, err = GetLenBytes(buf); err != nil {
  153. return 0, err
  154. }
  155. return s.Write(b)
  156. }
  157. //读取flag
  158. func (s *Conn) ReadFlag() (string, error) {
  159. val, err := s.ReadLen(4)
  160. if err != nil {
  161. return "", err
  162. }
  163. return string(val), err
  164. }
  165. //读取host 连接地址 压缩类型
  166. func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
  167. retry:
  168. lType, err := s.ReadLen(3)
  169. if err != nil {
  170. return
  171. }
  172. if typeStr = string(lType); typeStr == TEST_FLAG {
  173. en, de, crypt, mux = s.GetConnInfoFromConn()
  174. goto retry
  175. }
  176. cLen, err := s.GetLen()
  177. if err != nil {
  178. return
  179. }
  180. hostByte, err := s.ReadLen(cLen)
  181. if err != nil {
  182. return
  183. }
  184. host = string(hostByte)
  185. return
  186. }
  187. //写连接类型 和 host地址
  188. func (s *Conn) WriteHost(ltype string, host string) (int, error) {
  189. raw := bytes.NewBuffer([]byte{})
  190. binary.Write(raw, binary.LittleEndian, []byte(ltype))
  191. binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
  192. binary.Write(raw, binary.LittleEndian, []byte(host))
  193. return s.Write(raw.Bytes())
  194. }
  195. //设置连接为长连接
  196. func (s *Conn) SetAlive() {
  197. conn := s.Conn.(*net.TCPConn)
  198. conn.SetReadDeadline(time.Time{})
  199. conn.SetKeepAlive(true)
  200. conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
  201. }
  202. //从tcp报文中解析出host,连接类型等 TODO 多种情况
  203. func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
  204. r, err = http.ReadRequest(bufio.NewReader(s))
  205. if err != nil {
  206. return
  207. }
  208. rb, err = httputil.DumpRequest(r, true)
  209. if err != nil {
  210. return
  211. }
  212. hostPortURL, err := url.Parse(r.Host)
  213. if err != nil {
  214. return
  215. }
  216. if hostPortURL.Opaque == "443" { //https访问
  217. address = r.Host + ":443"
  218. } else { //http访问
  219. if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
  220. address = r.Host + ":80"
  221. } else {
  222. address = r.Host
  223. }
  224. }
  225. return
  226. }
  227. //单独读(加密|压缩)
  228. func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
  229. if COMPRESS_SNAPY_DECODE == compress {
  230. return NewSnappyConn(s.Conn, crypt).Read(b)
  231. }
  232. return NewCryptConn(s.Conn, crypt).Read(b)
  233. }
  234. //单独写(加密|压缩)
  235. func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
  236. if COMPRESS_SNAPY_ENCODE == compress {
  237. return NewSnappyConn(s.Conn, crypt).Write(b)
  238. }
  239. return NewCryptConn(s.Conn, crypt).Write(b)
  240. }
  241. //写压缩方式,加密
  242. func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
  243. s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
  244. }
  245. //获取压缩方式,是否加密
  246. func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
  247. buf, err := s.ReadLen(4)
  248. if err != nil {
  249. return
  250. }
  251. en, _ = strconv.Atoi(string(buf[0]))
  252. de, _ = strconv.Atoi(string(buf[1]))
  253. crypt = GetBoolByStr(string(buf[2]))
  254. mux = GetBoolByStr(string(buf[3]))
  255. return
  256. }
  257. //close
  258. func (s *Conn) Close() error {
  259. return s.Conn.Close()
  260. }
  261. //write
  262. func (s *Conn) Write(b []byte) (int, error) {
  263. return s.Conn.Write(b)
  264. }
  265. //read
  266. func (s *Conn) Read(b []byte) (int, error) {
  267. return s.Conn.Read(b)
  268. }
  269. //write error
  270. func (s *Conn) WriteError() (int, error) {
  271. return s.Write([]byte(RES_MSG))
  272. }
  273. //write sign flag
  274. func (s *Conn) WriteSign() (int, error) {
  275. return s.Write([]byte(RES_SIGN))
  276. }
  277. //write main
  278. func (s *Conn) WriteMain() (int, error) {
  279. return s.Write([]byte(WORK_MAIN))
  280. }
  281. //write chan
  282. func (s *Conn) WriteChan() (int, error) {
  283. return s.Write([]byte(WORK_CHAN))
  284. }
  285. //write test
  286. func (s *Conn) WriteTest() (int, error) {
  287. return s.Write([]byte(TEST_FLAG))
  288. }
  289. //write test
  290. func (s *Conn) WriteSuccess() (int, error) {
  291. return s.Write([]byte(CONN_SUCCESS))
  292. }
  293. //write test
  294. func (s *Conn) WriteFail() (int, error) {
  295. return s.Write([]byte(CONN_ERROR))
  296. }
  // GetLenBytes returns buf prefixed with its length as a 4-byte little-endian
// integer: the frame format read back by GetLen/GetLenByBytes. The error is
// always nil.
func GetLenBytes(buf []byte) (b []byte, err error) {
	var header [4]byte
	binary.LittleEndian.PutUint32(header[:], uint32(len(buf)))

	var out bytes.Buffer
	out.Grow(len(header) + len(buf))
	out.Write(header[:])
	out.Write(buf)
	return out.Bytes(), nil
}
  // GetLenByBytes decodes the 4-byte little-endian length prefix produced by
// GetLenBytes. Bytes after the first four are ignored.
//
// It returns an error in three cases:
//   - buf is shorter than 4 bytes (previously this panicked inside
//     binary.LittleEndian.Uint32);
//   - the decoded length is zero;
//   - the length does not fit in a positive int, which can happen on 32-bit
//     platforms.
func GetLenByBytes(buf []byte) (int, error) {
	if len(buf) < 4 {
		return 0, errors.New("数据长度错误")
	}
	// The field is unsigned, so "<= 0" after conversion catches both zero and
	// values that wrap negative when int is 32 bits wide.
	n := int(binary.LittleEndian.Uint32(buf))
	if n <= 0 {
		return 0, errors.New("数据长度错误")
	}
	return n, nil
}