conn.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. package lib
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "github.com/golang/snappy"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/url"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  17. type CryptConn struct {
  18. conn net.Conn
  19. crypt bool
  20. }
  21. func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
  22. c := new(CryptConn)
  23. c.conn = conn
  24. c.crypt = crypt
  25. return c
  26. }
  27. //加密写
  28. func (s *CryptConn) Write(b []byte) (n int, err error) {
  29. n = len(b)
  30. if s.crypt {
  31. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  32. return
  33. }
  34. if b, err = GetLenBytes(b); err != nil {
  35. return
  36. }
  37. }
  38. _, err = s.conn.Write(b)
  39. return
  40. }
  41. //解密读
  42. func (s *CryptConn) Read(b []byte) (n int, err error) {
  43. if s.crypt {
  44. var lens int
  45. var buf, bs []byte
  46. c := NewConn(s.conn)
  47. if lens, err = c.GetLen(); err != nil {
  48. return
  49. }
  50. if buf, err = c.ReadLen(lens); err != nil {
  51. return
  52. }
  53. if bs, err = AesDecrypt(buf, []byte(cryptKey)); err != nil {
  54. return
  55. }
  56. n = len(bs)
  57. copy(b, bs)
  58. return
  59. }
  60. return s.conn.Read(b)
  61. }
  62. type SnappyConn struct {
  63. w *snappy.Writer
  64. r *snappy.Reader
  65. crypt bool
  66. }
  67. func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
  68. c := new(SnappyConn)
  69. c.w = snappy.NewBufferedWriter(conn)
  70. c.r = snappy.NewReader(conn)
  71. c.crypt = crypt
  72. return c
  73. }
  74. //snappy压缩写 包含加密
  75. func (s *SnappyConn) Write(b []byte) (n int, err error) {
  76. n = len(b)
  77. if s.crypt {
  78. if b, err = AesEncrypt(b, []byte(cryptKey)); err != nil {
  79. log.Println("encode crypt error:", err)
  80. return
  81. }
  82. }
  83. if _, err = s.w.Write(b); err != nil {
  84. return
  85. }
  86. err = s.w.Flush()
  87. return
  88. }
  89. //snappy压缩读 包含解密
  90. func (s *SnappyConn) Read(b []byte) (n int, err error) {
  91. if n, err = s.r.Read(b); err != nil {
  92. return
  93. }
  94. if s.crypt {
  95. var bs []byte
  96. if bs, err = AesDecrypt(b[:n], []byte(cryptKey)); err != nil {
  97. log.Println("decode crypt error:", err)
  98. return
  99. }
  100. n = len(bs)
  101. copy(b, bs)
  102. }
  103. return
  104. }
  105. type Conn struct {
  106. conn net.Conn
  107. }
  108. //new conn
  109. func NewConn(conn net.Conn) *Conn {
  110. c := new(Conn)
  111. c.conn = conn
  112. return c
  113. }
  114. //读取指定长度内容
  115. func (s *Conn) ReadLen(len int) ([]byte, error) {
  116. buf := make([]byte, len)
  117. if n, err := io.ReadFull(s, buf); err != nil || n != len {
  118. return buf, errors.New("读取指定长度错误" + err.Error())
  119. }
  120. return buf, nil
  121. }
  122. //获取长度
  123. func (s *Conn) GetLen() (int, error) {
  124. val, err := s.ReadLen(4)
  125. if err != nil {
  126. return 0, err
  127. }
  128. return GetLenByBytes(val)
  129. }
  130. //写入长度+内容 粘包
  131. func (s *Conn) WriteLen(buf []byte) (int, error) {
  132. var b []byte
  133. if b, err = GetLenBytes(buf); err != nil {
  134. return 0, err
  135. }
  136. return s.Write(b)
  137. }
  138. //读取flag
  139. func (s *Conn) ReadFlag() (string, error) {
  140. val, err := s.ReadLen(4)
  141. if err != nil {
  142. return "", err
  143. }
  144. return string(val), err
  145. }
  146. //读取host 连接地址 压缩类型
  147. func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt bool, err error) {
  148. retry:
  149. lType, err := s.ReadLen(3)
  150. if err != nil {
  151. return
  152. }
  153. if typeStr = string(lType); typeStr == TEST_FLAG {
  154. en, de, crypt = s.GetConnInfoFromConn()
  155. goto retry
  156. }
  157. cLen, err := s.GetLen()
  158. if err != nil {
  159. return
  160. }
  161. hostByte, err := s.ReadLen(cLen)
  162. if err != nil {
  163. return
  164. }
  165. host = string(hostByte)
  166. return
  167. }
  168. //写连接类型 和 host地址
  169. func (s *Conn) WriteHost(ltype string, host string) (int, error) {
  170. raw := bytes.NewBuffer([]byte{})
  171. binary.Write(raw, binary.LittleEndian, []byte(ltype))
  172. binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
  173. binary.Write(raw, binary.LittleEndian, []byte(host))
  174. return s.Write(raw.Bytes())
  175. }
  176. //设置连接为长连接
  177. func (s *Conn) SetAlive() {
  178. conn := s.conn.(*net.TCPConn)
  179. conn.SetReadDeadline(time.Time{})
  180. conn.SetKeepAlive(true)
  181. conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
  182. }
  183. //从tcp报文中解析出host,连接类型等
  184. func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
  185. var b [32 * 1024]byte
  186. var n int
  187. if n, err = s.Read(b[:]); err != nil {
  188. return
  189. }
  190. rb = b[:n]
  191. r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
  192. if err != nil {
  193. log.Println("解析host出错:", err)
  194. return
  195. }
  196. hostPortURL, err := url.Parse(r.Host)
  197. if err != nil {
  198. return
  199. }
  200. if hostPortURL.Opaque == "443" { //https访问
  201. address = r.Host + ":443"
  202. } else { //http访问
  203. if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
  204. address = r.Host + ":80"
  205. } else {
  206. address = r.Host
  207. }
  208. }
  209. return
  210. }
  211. //单独读(加密|压缩)
  212. func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
  213. if COMPRESS_SNAPY_DECODE == compress {
  214. return NewSnappyConn(s.conn, crypt).Read(b)
  215. }
  216. return NewCryptConn(s.conn, crypt).Read(b)
  217. }
  218. //单独写(加密|压缩)
  219. func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
  220. if COMPRESS_SNAPY_ENCODE == compress {
  221. return NewSnappyConn(s.conn, crypt).Write(b)
  222. }
  223. return NewCryptConn(s.conn, crypt).Write(b)
  224. }
  225. //写压缩方式,加密
  226. func (s *Conn) WriteConnInfo(en, de int, crypt bool) {
  227. s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt)))
  228. }
  229. //获取压缩方式,是否加密
  230. func (s *Conn) GetConnInfoFromConn() (en, de int, crypt bool) {
  231. buf, err := s.ReadLen(3)
  232. //TODO:错误处理
  233. if err != nil {
  234. return
  235. }
  236. en, _ = strconv.Atoi(string(buf[0]))
  237. de, _ = strconv.Atoi(string(buf[1]))
  238. crypt = GetBoolByStr(string(buf[2]))
  239. return
  240. }
  241. //close
  242. func (s *Conn) Close() error {
  243. return s.conn.Close()
  244. }
  245. //write
  246. func (s *Conn) Write(b []byte) (int, error) {
  247. return s.conn.Write(b)
  248. }
  249. //read
  250. func (s *Conn) Read(b []byte) (int, error) {
  251. return s.conn.Read(b)
  252. }
  253. //write error
  254. func (s *Conn) wError() (int, error) {
  255. return s.Write([]byte(RES_MSG))
  256. }
  257. //write sign flag
  258. func (s *Conn) wSign() (int, error) {
  259. return s.Write([]byte(RES_SIGN))
  260. }
  261. //write main
  262. func (s *Conn) wMain() (int, error) {
  263. return s.Write([]byte(WORK_MAIN))
  264. }
  265. //write chan
  266. func (s *Conn) wChan() (int, error) {
  267. return s.Write([]byte(WORK_CHAN))
  268. }
  269. //write test
  270. func (s *Conn) wTest() (int, error) {
  271. return s.Write([]byte(TEST_FLAG))
  272. }
  273. //获取长度+内容
  274. func GetLenBytes(buf []byte) (b []byte, err error) {
  275. raw := bytes.NewBuffer([]byte{})
  276. if err = binary.Write(raw, binary.LittleEndian, int32(len(buf))); err != nil {
  277. return
  278. }
  279. if err = binary.Write(raw, binary.LittleEndian, buf); err != nil {
  280. return
  281. }
  282. b = raw.Bytes()
  283. return
  284. }
  285. //解析出长度
  286. func GetLenByBytes(buf []byte) (int, error) {
  287. nlen := binary.LittleEndian.Uint32(buf)
  288. if nlen <= 0 {
  289. return 0, errors.New("数据长度错误")
  290. }
  291. return int(nlen), nil
  292. }