conn.go 7.5 KB


  1. package utils
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "github.com/golang/snappy"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httputil"
  13. "net/url"
  14. "strconv"
  15. "strings"
  16. "time"
  17. )
  // Protocol-wide constants shared by the connection wrappers in this file.
const (
	// cryptKey is the AES key passed to AesEncrypt/AesDecrypt by CryptConn and
	// SnappyConn.
	// NOTE(review): a key compiled into the binary gives no real
	// confidentiality; consider making it configurable. Changing it breaks
	// compatibility with already-deployed peers.
	cryptKey = "1234567812345678"

	// poolSize is the largest length, in bytes, that Conn.ReadLen accepts.
	poolSize = 64 * 1024
)
  // CryptConn sends each Write as one length-prefixed frame whose payload is
// optionally AES-encrypted, and reverses that framing on Read.
type CryptConn struct {
	conn  net.Conn // underlying transport
	crypt bool     // whether payloads are AES-encrypted with cryptKey

	// rb holds decoded plaintext that has not yet been handed to a caller of
	// Read; rn is the number of pending bytes at the front of rb.
	rb []byte
	rn int
}
  26. func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
  27. c := new(CryptConn)
  28. c.conn = conn
  29. c.crypt = crypt
  30. c.rb = make([]byte, poolSize)
  31. return c
  32. }
  33. //加密写
  34. func (s *CryptConn) Write(b []byte) (n int, err error) {
  35. n = len(b)
  36. if s.crypt {
  37. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  38. return
  39. }
  40. }
  41. if b, err = GetLenBytes(b); err != nil {
  42. return
  43. }
  44. _, err = s.conn.Write(b)
  45. return
  46. }
  47. //解密读
  48. func (s *CryptConn) Read(b []byte) (n int, err error) {
  49. read:
  50. if len(s.rb) > 0 {
  51. if len(b) >= s.rn {
  52. n = s.rn
  53. copy(b, s.rb[:s.rn])
  54. s.rn = 0
  55. s.rb = s.rb[:0]
  56. } else {
  57. n = len(b)
  58. copy(b, s.rb[:len(b)])
  59. s.rn = n - len(b)
  60. s.rb = s.rb[len(b):]
  61. }
  62. return
  63. }
  64. defer func() {
  65. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  66. err = io.EOF
  67. n = 0
  68. }
  69. }()
  70. var lens int
  71. var buf []byte
  72. c := NewConn(s.conn)
  73. if lens, err = c.GetLen(); err != nil {
  74. return
  75. }
  76. if buf, err = c.ReadLen(lens); err != nil {
  77. return
  78. }
  79. if s.crypt {
  80. if s.rb, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
  81. return
  82. }
  83. } else {
  84. s.rb = buf
  85. }
  86. s.rn = len(s.rb)
  87. goto read
  88. }
  89. type SnappyConn struct {
  90. w *snappy.Writer
  91. r *snappy.Reader
  92. crypt bool
  93. }
  94. func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
  95. c := new(SnappyConn)
  96. c.w = snappy.NewBufferedWriter(conn)
  97. c.r = snappy.NewReader(conn)
  98. c.crypt = crypt
  99. return c
  100. }
  101. //snappy压缩写 包含加密
  102. func (s *SnappyConn) Write(b []byte) (n int, err error) {
  103. n = len(b)
  104. if s.crypt {
  105. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  106. log.Println("encode crypt error:", err)
  107. return
  108. }
  109. }
  110. if _, err = s.w.Write(b); err != nil {
  111. return
  112. }
  113. err = s.w.Flush()
  114. return
  115. }
  116. //snappy压缩读 包含解密
  117. func (s *SnappyConn) Read(b []byte) (n int, err error) {
  118. defer func() {
  119. if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
  120. err = io.EOF
  121. n = 0
  122. }
  123. }()
  124. buf := bufPool.Get().([]byte)
  125. if n, err = s.r.Read(buf); err != nil {
  126. return
  127. }
  128. var bs []byte
  129. if s.crypt {
  130. if bs, err = AesDecrypt(buf[:n], []byte(cryptKey)); err != nil {
  131. log.Println("decode crypt error:", err)
  132. return
  133. }
  134. } else {
  135. bs = buf[:n]
  136. }
  137. n = len(bs)
  138. copy(b, bs)
  139. return
  140. }
  // Conn wraps a net.Conn. The methods defined on it provide framed reads and
// writes, protocol markers, and encrypted or compressed transfers.
type Conn struct {
	Conn net.Conn
}

// NewConn wraps conn in a Conn.
func NewConn(conn net.Conn) *Conn {
	return &Conn{Conn: conn}
}
  150. //读取指定长度内容
  151. func (s *Conn) ReadLen(cLen int) ([]byte, error) {
  152. if cLen > poolSize {
  153. return nil, errors.New("长度错误" + strconv.Itoa(cLen))
  154. }
  155. buf := bufPool.Get().([]byte)[:cLen]
  156. if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
  157. return buf, errors.New("读取指定长度错误" + err.Error())
  158. }
  159. return buf, nil
  160. }
  161. //获取长度
  162. func (s *Conn) GetLen() (int, error) {
  163. val, err := s.ReadLen(4)
  164. if err != nil {
  165. return 0, err
  166. }
  167. return GetLenByBytes(val)
  168. }
  169. //写入长度+内容 粘包
  170. func (s *Conn) WriteLen(buf []byte) (int, error) {
  171. var b []byte
  172. var err error
  173. if b, err = GetLenBytes(buf); err != nil {
  174. return 0, err
  175. }
  176. return s.Write(b)
  177. }
  178. //读取flag
  179. func (s *Conn) ReadFlag() (string, error) {
  180. val, err := s.ReadLen(4)
  181. if err != nil {
  182. return "", err
  183. }
  184. return string(val), err
  185. }
  186. //读取host 连接地址 压缩类型
  187. func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
  188. retry:
  189. lType, err := s.ReadLen(3)
  190. if err != nil {
  191. return
  192. }
  193. if typeStr = string(lType); typeStr == TEST_FLAG {
  194. en, de, crypt, mux = s.GetConnInfoFromConn()
  195. goto retry
  196. } else if typeStr != CONN_TCP && typeStr != CONN_UDP {
  197. err = errors.New("unknown conn type")
  198. return
  199. }
  200. cLen, err := s.GetLen()
  201. if err != nil || cLen > poolSize {
  202. return
  203. }
  204. hostByte, err := s.ReadLen(cLen)
  205. if err != nil {
  206. return
  207. }
  208. host = string(hostByte)
  209. return
  210. }
  211. //写连接类型 和 host地址
  212. func (s *Conn) WriteHost(ltype string, host string) (int, error) {
  213. raw := bytes.NewBuffer([]byte{})
  214. binary.Write(raw, binary.LittleEndian, []byte(ltype))
  215. binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
  216. binary.Write(raw, binary.LittleEndian, []byte(host))
  217. return s.Write(raw.Bytes())
  218. }
  219. //设置连接为长连接
  220. func (s *Conn) SetAlive() {
  221. conn := s.Conn.(*net.TCPConn)
  222. conn.SetReadDeadline(time.Time{})
  223. conn.SetKeepAlive(true)
  224. conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
  225. }
  226. //从tcp报文中解析出host,连接类型等 TODO 多种情况
  227. func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
  228. r, err = http.ReadRequest(bufio.NewReader(s))
  229. if err != nil {
  230. return
  231. }
  232. rb, err = httputil.DumpRequest(r, true)
  233. if err != nil {
  234. return
  235. }
  236. hostPortURL, err := url.Parse(r.Host)
  237. if err != nil {
  238. return
  239. }
  240. if hostPortURL.Opaque == "443" { //https访问
  241. address = r.Host + ":443"
  242. } else { //http访问
  243. if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
  244. address = r.Host + ":80"
  245. } else {
  246. address = r.Host
  247. }
  248. }
  249. return
  250. }
  251. //单独读(加密|压缩)
  252. func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
  253. if COMPRESS_SNAPY_DECODE == compress {
  254. return NewSnappyConn(s.Conn, crypt).Read(b)
  255. }
  256. return NewCryptConn(s.Conn, crypt).Read(b)
  257. }
  258. //单独写(加密|压缩)
  259. func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
  260. if COMPRESS_SNAPY_ENCODE == compress {
  261. return NewSnappyConn(s.Conn, crypt).Write(b)
  262. }
  263. return NewCryptConn(s.Conn, crypt).Write(b)
  264. }
  265. //写压缩方式,加密
  266. func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
  267. s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
  268. }
  269. //获取压缩方式,是否加密
  270. func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
  271. buf, err := s.ReadLen(4)
  272. if err != nil {
  273. return
  274. }
  275. en, _ = strconv.Atoi(string(buf[0]))
  276. de, _ = strconv.Atoi(string(buf[1]))
  277. crypt = GetBoolByStr(string(buf[2]))
  278. mux = GetBoolByStr(string(buf[3]))
  279. return
  280. }
  281. //close
  282. func (s *Conn) Close() error {
  283. return s.Conn.Close()
  284. }
  285. //write
  286. func (s *Conn) Write(b []byte) (int, error) {
  287. return s.Conn.Write(b)
  288. }
  289. //read
  290. func (s *Conn) Read(b []byte) (int, error) {
  291. return s.Conn.Read(b)
  292. }
  293. //write error
  294. func (s *Conn) WriteError() (int, error) {
  295. return s.Write([]byte(RES_MSG))
  296. }
  297. //write sign flag
  298. func (s *Conn) WriteSign() (int, error) {
  299. return s.Write([]byte(RES_SIGN))
  300. }
  301. //write main
  302. func (s *Conn) WriteMain() (int, error) {
  303. return s.Write([]byte(WORK_MAIN))
  304. }
  305. //write chan
  306. func (s *Conn) WriteChan() (int, error) {
  307. return s.Write([]byte(WORK_CHAN))
  308. }
  309. //write test
  310. func (s *Conn) WriteTest() (int, error) {
  311. return s.Write([]byte(TEST_FLAG))
  312. }
  313. //write test
  314. func (s *Conn) WriteSuccess() (int, error) {
  315. return s.Write([]byte(CONN_SUCCESS))
  316. }
  317. //write test
  318. func (s *Conn) WriteFail() (int, error) {
  319. return s.Write([]byte(CONN_ERROR))
  320. }
  // GetLenBytes returns buf prefixed with its length, encoded as a 4-byte
// little-endian value. This is the frame format GetLenByBytes decodes.
//
// The header is written with PutUint32 into a single exact-size allocation
// instead of reflection-based binary.Write through a bytes.Buffer. Payloads
// longer than MaxInt32 are rejected because the peer treats the header as an
// int32; previously their length was silently truncated.
func GetLenBytes(buf []byte) (b []byte, err error) {
	if len(buf) > 1<<31-1 {
		return nil, errors.New("payload too large for an int32 length prefix")
	}
	b = make([]byte, 4+len(buf))
	binary.LittleEndian.PutUint32(b, uint32(len(buf)))
	copy(b[4:], buf)
	return b, nil
}
  // GetLenByBytes decodes a 4-byte little-endian length header.
//
// It returns an error ("数据长度错误", i.e. "invalid data length") in three
// cases:
//   - buf is shorter than 4 bytes (previously a panic in Uint32);
//   - the length is zero;
//   - the length exceeds MaxInt32. GetLenBytes never writes such a value,
//     and it would become a negative int on 32-bit platforms.
func GetLenByBytes(buf []byte) (int, error) {
	if len(buf) < 4 {
		return 0, errors.New("数据长度错误")
	}
	nlen := binary.LittleEndian.Uint32(buf)
	if nlen == 0 || nlen > 1<<31-1 {
		return 0, errors.New("数据长度错误")
	}
	return int(nlen), nil
}