1
0

util.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package utils
  2. import (
  3. "encoding/base64"
  4. "io"
  5. "io/ioutil"
  6. "log"
  7. "net"
  8. "net/http"
  9. "os"
  10. "regexp"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  // Package-wide protocol constants.
const (
	// Relay modes. GetCompressType hands these out as (decode, encode) pairs
	// and Relay switches on them to pick the connection wrapper and teardown.
	COMPRESS_NONE_ENCODE = iota
	COMPRESS_NONE_DECODE
	COMPRESS_SNAPY_ENCODE
	COMPRESS_SNAPY_DECODE

	// Short string tokens. NOTE(review): presumably exchanged between peers
	// as flags/status words — confirm against their users elsewhere.
	VERIFY_EER   = "vkey"
	WORK_MAIN    = "main"
	WORK_CHAN    = "chan"
	RES_SIGN     = "sign"
	RES_MSG      = "msg0"
	CONN_SUCCESS = "sucs"
	CONN_ERROR   = "fail"
	TEST_FLAG    = "tst"
	CONN_TCP     = "tcp"
	CONN_UDP     = "udp"

	// UnauthorizedBytes is a canned HTTP 401 response carrying a Basic auth
	// challenge. The empty line is required: it terminates the header
	// section, otherwise the body text would be parsed as a header line.
	UnauthorizedBytes = `HTTP/1.1 401 Unauthorized
Content-Type: text/plain; charset=utf-8
WWW-Authenticate: Basic realm="easyProxy"

401 Unauthorized`

	// IO_EOF is the end-of-stream marker Relay writes after an encode-mode copy.
	IO_EOF = "PROXYEOF"

	// ConnectionFailBytes is a canned HTTP 404 response with no body; the
	// trailing empty line terminates the response head.
	ConnectionFailBytes = `HTTP/1.1 404 Not Found

`
)
  39. //copy
  40. func Relay(in, out net.Conn, compressType int, crypt, mux bool) (n int64, err error) {
  41. switch compressType {
  42. case COMPRESS_SNAPY_ENCODE:
  43. n, err = copyBuffer(NewSnappyConn(in, crypt), out)
  44. out.Close()
  45. NewSnappyConn(in, crypt).Write([]byte(IO_EOF))
  46. case COMPRESS_SNAPY_DECODE:
  47. n, err = copyBuffer(in, NewSnappyConn(out, crypt))
  48. in.Close()
  49. if !mux {
  50. out.Close()
  51. }
  52. case COMPRESS_NONE_ENCODE:
  53. n, err = copyBuffer(NewCryptConn(in, crypt), out)
  54. out.Close()
  55. NewCryptConn(in, crypt).Write([]byte(IO_EOF))
  56. case COMPRESS_NONE_DECODE:
  57. n, err = copyBuffer(in, NewCryptConn(out, crypt))
  58. in.Close()
  59. if !mux {
  60. out.Close()
  61. }
  62. }
  63. return
  64. }
  65. //判断压缩方式
  66. func GetCompressType(compress string) (int, int) {
  67. switch compress {
  68. case "":
  69. return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
  70. case "snappy":
  71. return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
  72. default:
  73. log.Fatalln("数据压缩格式错误")
  74. }
  75. return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
  76. }
  77. //通过host获取对应的ip地址
  78. func GetHostByName(hostname string) string {
  79. if !DomainCheck(hostname) {
  80. return hostname
  81. }
  82. ips, _ := net.LookupIP(hostname)
  83. if ips != nil {
  84. for _, v := range ips {
  85. if v.To4() != nil {
  86. return v.String()
  87. }
  88. }
  89. }
  90. return ""
  91. }
  // domainPattern matches strings that start with an optional http:// or
// https:// scheme, then one or more dot-terminated labels, then a 2-6 letter
// top-level domain. It is deliberately not anchored at the end, so anything
// may follow the TLD (a path, a ":port", ...). Compiled once at package
// init instead of on every call.
var domainPattern = regexp.MustCompile(`^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,6}`)

// DomainCheck reports whether domain begins with something shaped like a
// domain name (optionally scheme-prefixed), as opposed to e.g. a bare IPv4
// address or a single-label name such as "localhost".
//
// An earlier version also tried a variant of the pattern requiring a
// trailing "/". Because neither pattern was end-anchored, that variant could
// only match strings the plain pattern already matched, so it has been
// dropped without changing results.
func DomainCheck(domain string) bool {
	return domainPattern.MatchString(domain)
}
  // CheckAuth reports whether r carries HTTP Basic credentials equal to user
// and passwd.
//
// The Authorization header must read "Basic <base64(user:password)>". The
// scheme name is matched case-insensitively, as RFC 7617 allows. Any other
// scheme (e.g. "Bearer"), a header without a space, invalid base64, or a
// decoded value without ':' yields false. Only the first ':' separates the
// user from the password, so passwords may contain ':'.
//
// NOTE(review): the final comparison uses ==, which is not constant-time;
// consider crypto/subtle.ConstantTimeCompare if timing attacks matter here.
func CheckAuth(r *http.Request, user, passwd string) bool {
	parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
	// Reject non-Basic schemes; previously any scheme word was accepted.
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") {
		return false
	}
	decoded, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return false
	}
	pair := strings.SplitN(string(decoded), ":", 2)
	if len(pair) != 2 {
		return false
	}
	return pair[0] == user && pair[1] == passwd
}
  // GetBoolByStr reports whether s is one of the two accepted truthy
// spellings, "1" or "true" (case-sensitive). Everything else, including "",
// is false.
func GetBoolByStr(s string) bool {
	return s == "1" || s == "true"
}
  // GetStrByBool encodes b as "1" for true and "0" for false.
func GetStrByBool(b bool) string {
	out := "0"
	if b {
		out = "1"
	}
	return out
}
  // GetIntNoErrByStr converts str with strconv.Atoi and discards the error.
// A malformed string (including "" or one with surrounding spaces) yields 0.
// An out-of-range one yields the clamped value Atoi returns with its range error.
func GetIntNoErrByStr(str string) int {
	value, _ := strconv.Atoi(str) // error intentionally dropped, as the name promises
	return value
}
  139. // io.copy的优化版,读取buffer长度原为32*1024,与snappy不同,导致读取出的内容存在差异,不利于解密
  140. //内存优化 用到pool,快速回收
  141. func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
  142. for {
  143. //放在里面是为了加快回收和重利用
  144. buf := bufPoolCopy.Get().([]byte)
  145. nr, er := src.Read(buf)
  146. if nr > 0 {
  147. nw, ew := dst.Write(buf[0:nr])
  148. bufPoolCopy.Put(buf)
  149. if nw > 0 {
  150. written += int64(nw)
  151. }
  152. if ew != nil {
  153. err = ew
  154. break
  155. }
  156. if nr != nw {
  157. err = io.ErrShortWrite
  158. break
  159. }
  160. } else {
  161. bufPoolCopy.Put(buf)
  162. }
  163. if er != nil {
  164. if er != io.EOF {
  165. err = er
  166. }
  167. break
  168. }
  169. }
  170. return written, err
  171. }
  172. //连接重置 清空缓存区
  173. func FlushConn(c net.Conn) {
  174. c.SetReadDeadline(time.Now().Add(time.Second * 3))
  175. buf := bufPool.Get().([]byte)
  176. defer bufPool.Put(buf)
  177. for {
  178. if _, err := c.Read(buf); err != nil {
  179. break
  180. }
  181. }
  182. c.SetReadDeadline(time.Time{})
  183. }
  184. //简单的一个校验值
  185. func Getverifyval(vkey string) string {
  186. return Md5(vkey)
  187. }
  188. //wait replay group
  189. //conn1 网桥 conn2
  190. func ReplayWaitGroup(conn1 net.Conn, conn2 net.Conn, compressEncode, compressDecode int, crypt, mux bool) (out int64, in int64) {
  191. var wg sync.WaitGroup
  192. wg.Add(1)
  193. go func() {
  194. in, _ = Relay(conn1, conn2, compressEncode, crypt, mux)
  195. wg.Done()
  196. }()
  197. out, _ = Relay(conn2, conn1, compressDecode, crypt, mux)
  198. wg.Wait()
  199. return
  200. }
  // ChangeHostAndHeader rewrites a request in place before it is forwarded.
//
// If host is non-empty it replaces r.Host.
//
// header is a newline-separated list of "Name: value" lines, applied with
// Header.Set (overwriting existing values). Only the first ':' separates name
// from value, so values may themselves contain ':' (URLs, times, host:port).
// Surrounding whitespace, including a trailing '\r' from CRLF input, is
// trimmed. Lines without ':' or with an empty name are ignored.
//
// The host part of addr (NOTE(review): presumably the client's "host:port"
// remote address) is written to X-Forwarded-For and X-Real-IP, replacing any
// existing values. IPv6 forms such as "[::1]:80" yield "::1". An addr
// without a port is used as-is.
func ChangeHostAndHeader(r *http.Request, host string, header string, addr string) {
	if host != "" {
		r.Host = host
	}
	if header != "" {
		for _, line := range strings.Split(header, "\n") {
			kv := strings.SplitN(line, ":", 2)
			if len(kv) != 2 {
				continue
			}
			name := strings.TrimSpace(kv[0])
			if name == "" {
				continue
			}
			r.Header.Set(name, strings.TrimSpace(kv[1]))
		}
	}
	// Splitting on every ':' would truncate IPv6 addresses; SplitHostPort
	// understands the bracketed form.
	clientIP := addr
	if h, _, err := net.SplitHostPort(addr); err == nil {
		clientIP = h
	}
	r.Header.Set("X-Forwarded-For", clientIP)
	r.Header.Set("X-Real-IP", clientIP)
}
  // ReadAllFromFile returns the entire contents of the file at filePth.
// It returns the error from os.Open if the file cannot be opened, or from
// reading it otherwise. The file handle is always closed before returning;
// previously it leaked on every call.
func ReadAllFromFile(filePth string) ([]byte, error) {
	f, err := os.Open(filePth)
	if err != nil {
		return nil, err
	}
	defer f.Close()
	return ioutil.ReadAll(f)
}