1
0

util.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package lib
import (
	"bufio"
	"bytes"
	"crypto/subtle"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"
)
  21. var (
  22. disabledRedirect = errors.New("disabled redirect.")
  23. )
  24. const (
  25. COMPRESS_NONE_ENCODE = iota
  26. COMPRESS_NONE_DECODE
  27. COMPRESS_SNAPY_ENCODE
  28. COMPRESS_SNAPY_DECODE
  29. IO_EOF = "PROXYEOF"
  30. )
  31. //error
  32. func BadRequest(w http.ResponseWriter) {
  33. http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
  34. }
  35. //发送请求并转为bytes
  36. func GetEncodeResponse(req *http.Request) ([]byte, error) {
  37. var respBytes []byte
  38. client := new(http.Client)
  39. client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
  40. return disabledRedirect
  41. }
  42. resp, err := client.Do(req)
  43. disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error())
  44. if err != nil && !disRedirect {
  45. return respBytes, err
  46. }
  47. if !disRedirect {
  48. defer resp.Body.Close()
  49. } else {
  50. resp.Body = nil
  51. resp.ContentLength = 0
  52. }
  53. respBytes, err = EncodeResponse(resp)
  54. return respBytes, nil
  55. }
  56. // 将request转为bytes
  57. func EncodeRequest(r *http.Request) ([]byte, error) {
  58. raw := bytes.NewBuffer([]byte{})
  59. reqBytes, err := httputil.DumpRequest(r, true)
  60. if err != nil {
  61. return nil, err
  62. }
  63. binary.Write(raw, binary.LittleEndian, bool(r.URL.Scheme == "https"))
  64. binary.Write(raw, binary.LittleEndian, reqBytes)
  65. return raw.Bytes(), nil
  66. }
  67. // 将字节转为request
  68. func DecodeRequest(data []byte) (*http.Request, error) {
  69. req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:])))
  70. if err != nil {
  71. return nil, err
  72. }
  73. str := strings.Split(req.Host, ":")
  74. req.Host, err = getHost(str[0])
  75. if err != nil {
  76. return nil, err
  77. }
  78. scheme := "http"
  79. if data[0] == 1 {
  80. scheme = "https"
  81. }
  82. req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI))
  83. req.RequestURI = ""
  84. return req, nil
  85. }
  86. // 将response转为字节
  87. func EncodeResponse(r *http.Response) ([]byte, error) {
  88. respBytes, err := httputil.DumpResponse(r, true)
  89. if err != nil {
  90. return nil, err
  91. }
  92. if config.Replace == 1 {
  93. respBytes = replaceHost(respBytes)
  94. }
  95. return respBytes, nil
  96. }
  97. // 将字节转为response
  98. func DecodeResponse(data []byte) (*http.Response, error) {
  99. resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), nil)
  100. if err != nil {
  101. return nil, err
  102. }
  103. return resp, nil
  104. }
  105. // 根据host地址从配置是文件中查找对应目标
  106. func getHost(str string) (string, error) {
  107. for _, v := range config.SiteList {
  108. if v.Host == str {
  109. return v.Url + ":" + strconv.Itoa(v.Port), nil
  110. }
  111. }
  112. return "", errors.New("没有找到解析的的host!")
  113. }
  114. //替换
  115. func replaceHost(resp []byte) []byte {
  116. str := string(resp)
  117. for _, v := range config.SiteList {
  118. str = strings.Replace(str, v.Url+":"+strconv.Itoa(v.Port), v.Host, -1)
  119. str = strings.Replace(str, v.Url, v.Host, -1)
  120. }
  121. return []byte(str)
  122. }
  123. //copy
  124. func relay(in, out net.Conn, compressType int, crypt, mux bool) {
  125. switch compressType {
  126. case COMPRESS_SNAPY_ENCODE:
  127. copyBuffer(NewSnappyConn(in, crypt), out)
  128. if mux {
  129. out.Close()
  130. NewSnappyConn(in, crypt).Write([]byte(IO_EOF))
  131. }
  132. case COMPRESS_SNAPY_DECODE:
  133. copyBuffer(in, NewSnappyConn(out, crypt))
  134. if mux {
  135. in.Close()
  136. }
  137. case COMPRESS_NONE_ENCODE:
  138. copyBuffer(NewCryptConn(in, crypt), out)
  139. if mux {
  140. out.Close()
  141. NewCryptConn(in, crypt).Write([]byte(IO_EOF))
  142. }
  143. case COMPRESS_NONE_DECODE:
  144. copyBuffer(in, NewCryptConn(out, crypt))
  145. if mux {
  146. in.Close()
  147. }
  148. }
  149. if !mux {
  150. in.Close()
  151. out.Close()
  152. }
  153. }
  154. //判断压缩方式
  155. func getCompressType(compress string) (int, int) {
  156. switch compress {
  157. case "":
  158. return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
  159. case "snappy":
  160. return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
  161. default:
  162. log.Fatalln("数据压缩格式错误")
  163. }
  164. return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
  165. }
  166. //简单的一个校验值
  167. func getverifyval(vkey string) string {
  168. //单客户端模式
  169. if *verifyKey != "" {
  170. return Md5(*verifyKey)
  171. }
  172. return Md5(vkey)
  173. }
  174. //验证
  175. func verify(verifyKeyMd5 string) bool {
  176. if *verifyKey != "" && getverifyval(*verifyKey) == verifyKeyMd5 {
  177. return true
  178. }
  179. if *verifyKey == "" {
  180. for k := range RunList {
  181. if getverifyval(k) == verifyKeyMd5 {
  182. return true
  183. }
  184. }
  185. }
  186. return false
  187. }
  188. //get key by host from x
  189. func getKeyByHost(host string) (h *HostList, t *ServerConfig, err error) {
  190. for _, v := range CsvDb.Hosts {
  191. if strings.Contains(host, v.Host) {
  192. h = v
  193. t, err = CsvDb.GetTask(v.Vkey)
  194. return
  195. }
  196. }
  197. err = errors.New("未找到host对应的内网目标")
  198. return
  199. }
  200. //通过host获取对应的ip地址
  201. func Gethostbyname(hostname string) string {
  202. if !DomainCheck(hostname) {
  203. return hostname
  204. }
  205. ips, _ := net.LookupIP(hostname)
  206. if ips != nil {
  207. for _, v := range ips {
  208. if v.To4() != nil {
  209. return v.String()
  210. }
  211. }
  212. }
  213. return ""
  214. }
  215. //检查是否是域名
  216. func DomainCheck(domain string) bool {
  217. var match bool
  218. IsLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}(/)"
  219. NotLine := "^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\\-]{0,61}[a-zA-Z0-9])?\\.)+[a-zA-Z]{2,6}"
  220. match, _ = regexp.MatchString(IsLine, domain)
  221. if !match {
  222. match, _ = regexp.MatchString(NotLine, domain)
  223. }
  224. return match
  225. }
  226. //检查basic认证
  227. func checkAuth(r *http.Request, user, passwd string) bool {
  228. s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
  229. if len(s) != 2 {
  230. return false
  231. }
  232. b, err := base64.StdEncoding.DecodeString(s[1])
  233. if err != nil {
  234. return false
  235. }
  236. pair := strings.SplitN(string(b), ":", 2)
  237. if len(pair) != 2 {
  238. return false
  239. }
  240. return pair[0] == user && pair[1] == passwd
  241. }
  242. //get bool by str
  243. func GetBoolByStr(s string) bool {
  244. switch s {
  245. case "1", "true":
  246. return true
  247. }
  248. return false
  249. }
  250. //get str by bool
  251. func GetStrByBool(b bool) string {
  252. if b {
  253. return "1"
  254. }
  255. return "0"
  256. }
  257. func GetIntNoerrByStr(str string) int {
  258. i, _ := strconv.Atoi(str)
  259. return i
  260. }
  261. var bufPool = sync.Pool{
  262. New: func() interface{} {
  263. return make([]byte, 65535)
  264. },
  265. }
  266. // io.copy的优化版,读取buffer长度原为32*1024,与snappy不同,导致读取出的内容存在差异,不利于解密,特此修改
  267. func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
  268. //TODO 回收问题
  269. buf := bufPool.Get().([]byte)
  270. for {
  271. nr, er := src.Read(buf)
  272. if nr > 0 {
  273. nw, ew := dst.Write(buf[0:nr])
  274. if nw > 0 {
  275. written += int64(nw)
  276. }
  277. if ew != nil {
  278. err = ew
  279. break
  280. }
  281. if nr != nw {
  282. err = io.ErrShortWrite
  283. break
  284. }
  285. }
  286. if er != nil {
  287. if er != io.EOF {
  288. err = er
  289. }
  290. break
  291. }
  292. }
  293. return written, err
  294. }
  295. //连接重置 清空缓存区
  296. func FlushConn(c net.Conn) {
  297. c.SetReadDeadline(time.Now().Add(time.Second * 3))
  298. buf := bufPool.Get().([]byte)
  299. for {
  300. if _, err := c.Read(buf); err != nil {
  301. break
  302. }
  303. }
  304. c.SetReadDeadline(time.Time{})
  305. }