util.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package lib
  import (
  	"bufio"
  	"bytes"
  	"crypto/subtle"
  	"encoding/base64"
  	"encoding/binary"
  	"errors"
  	"fmt"
  	"io"
  	"log"
  	"net"
  	"net/http"
  	"net/http/httputil"
  	"net/url"
  	"regexp"
  	"strconv"
  	"strings"
  	"sync"
  )
  var (
  	// disabledRedirect is returned by the CheckRedirect hook installed in
  	// GetEncodeResponse so that the HTTP client stops at the first redirect
  	// instead of following it.
  	disabledRedirect = errors.New("disabled redirect.")

  	// bufPool recycles byte slices; its New func allocates 32 KiB slices.
  	bufPool = &sync.Pool{
  		New: func() interface{} { return make([]byte, 32*1024) },
  	}
  )

  // bufType is a fixed 32 KiB array type, the same size as the slices that
  // bufPool's New func allocates (pool implementation).
  type bufType [32 * 1024]byte
  // Relay modes consumed by relay and returned (as decode/encode pairs) by
  // getCompressType. In relay, the *_ENCODE modes wrap the destination
  // connection of the copy and the *_DECODE modes wrap the source connection.
  // The names (including the "SNAPY" spelling) are kept because other code
  // refers to them.
  const (
  	COMPRESS_NONE_ENCODE  = 0 // no compression, encode side
  	COMPRESS_NONE_DECODE  = 1 // no compression, decode side
  	COMPRESS_SNAPY_ENCODE = 2 // snappy, encode side
  	COMPRESS_SNAPY_DECODE = 3 // snappy, decode side
  )
  // BadRequest replies to w with HTTP 400 and the standard "Bad Request" text.
  func BadRequest(w http.ResponseWriter) {
  	const status = http.StatusBadRequest
  	http.Error(w, http.StatusText(status), status)
  }
  39. //发送请求并转为bytes
  40. func GetEncodeResponse(req *http.Request) ([]byte, error) {
  41. var respBytes []byte
  42. client := new(http.Client)
  43. client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
  44. return disabledRedirect
  45. }
  46. resp, err := client.Do(req)
  47. disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error())
  48. if err != nil && !disRedirect {
  49. return respBytes, err
  50. }
  51. if !disRedirect {
  52. defer resp.Body.Close()
  53. } else {
  54. resp.Body = nil
  55. resp.ContentLength = 0
  56. }
  57. respBytes, err = EncodeResponse(resp)
  58. return respBytes, nil
  59. }
  // EncodeRequest serializes r as one flag byte followed by the raw HTTP dump
  // of the request (including its body).
  //
  // The flag byte is 1 when r.URL.Scheme is "https" and 0 otherwise.
  // DecodeRequest reads this layout back.
  func EncodeRequest(r *http.Request) ([]byte, error) {
  	dump, err := httputil.DumpRequest(r, true)
  	if err != nil {
  		return nil, err
  	}
  	isTLS := r.URL.Scheme == "https"
  	var out bytes.Buffer
  	// Writing to a bytes.Buffer cannot fail, so these errors are ignored.
  	_ = binary.Write(&out, binary.LittleEndian, isTLS)
  	_ = binary.Write(&out, binary.LittleEndian, dump)
  	return out.Bytes(), nil
  }
  71. // 将字节转为request
  72. func DecodeRequest(data []byte) (*http.Request, error) {
  73. req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:])))
  74. if err != nil {
  75. return nil, err
  76. }
  77. str := strings.Split(req.Host, ":")
  78. req.Host, err = getHost(str[0])
  79. if err != nil {
  80. return nil, err
  81. }
  82. scheme := "http"
  83. if data[0] == 1 {
  84. scheme = "https"
  85. }
  86. req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI))
  87. req.RequestURI = ""
  88. return req, nil
  89. }
  90. //// 将response转为字节
  91. func EncodeResponse(r *http.Response) ([]byte, error) {
  92. respBytes, err := httputil.DumpResponse(r, true)
  93. if err != nil {
  94. return nil, err
  95. }
  96. if config.Replace == 1 {
  97. respBytes = replaceHost(respBytes)
  98. }
  99. return respBytes, nil
  100. }
  // DecodeResponse parses data as a raw HTTP/1.x response.
  // No request is associated with the result (the request argument is nil).
  // On a parse error it returns a nil response and the error.
  func DecodeResponse(data []byte) (*http.Response, error) {
  	reader := bufio.NewReader(bytes.NewReader(data))
  	parsed, err := http.ReadResponse(reader, nil)
  	if err != nil {
  		return nil, err
  	}
  	return parsed, nil
  }
  109. func getHost(str string) (string, error) {
  110. for _, v := range config.SiteList {
  111. if v.Host == str {
  112. return v.Url + ":" + strconv.Itoa(v.Port), nil
  113. }
  114. }
  115. return "", errors.New("没有找到解析的的host!")
  116. }
  117. func replaceHost(resp []byte) []byte {
  118. str := string(resp)
  119. for _, v := range config.SiteList {
  120. str = strings.Replace(str, v.Url+":"+strconv.Itoa(v.Port), v.Host, -1)
  121. str = strings.Replace(str, v.Url, v.Host, -1)
  122. }
  123. return []byte(str)
  124. }
  125. func relay(in, out *Conn, compressType int, crypt bool) {
  126. fmt.Println(crypt)
  127. switch compressType {
  128. case COMPRESS_SNAPY_ENCODE:
  129. copyBuffer(NewSnappyConn(in.conn, crypt), out)
  130. case COMPRESS_SNAPY_DECODE:
  131. copyBuffer(in, NewSnappyConn(out.conn, crypt))
  132. case COMPRESS_NONE_ENCODE:
  133. copyBuffer(NewCryptConn(in.conn, crypt), out)
  134. case COMPRESS_NONE_DECODE:
  135. copyBuffer(in, NewCryptConn(out.conn, crypt))
  136. }
  137. out.Close()
  138. in.Close()
  139. }
  140. //判断压缩方式
  141. func getCompressType(compress string) (int, int) {
  142. switch compress {
  143. case "":
  144. return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
  145. case "snappy":
  146. return COMPRESS_SNAPY_DECODE, COMPRESS_SNAPY_ENCODE
  147. default:
  148. log.Fatalln("数据压缩格式错误")
  149. }
  150. return COMPRESS_NONE_DECODE, COMPRESS_NONE_ENCODE
  151. }
  152. //简单的一个校验值
  153. func getverifyval(vkey string) string {
  154. //单客户端模式
  155. if *verifyKey != "" {
  156. return Md5(*verifyKey)
  157. }
  158. return Md5(vkey)
  159. }
  160. //验证
  161. func verify(verifyKeyMd5 string) bool {
  162. if *verifyKey != "" && getverifyval(*verifyKey) == verifyKeyMd5 {
  163. return true
  164. }
  165. if *verifyKey == "" {
  166. for k := range RunList {
  167. if getverifyval(k) == verifyKeyMd5 {
  168. return true
  169. }
  170. }
  171. }
  172. return false
  173. }
  174. //get key by host from x
  175. func getKeyByHost(host string) (h *HostList, t *TaskList, err error) {
  176. for _, v := range CsvDb.Hosts {
  177. if strings.Contains(host, v.Host) {
  178. h = v
  179. t, err = CsvDb.GetTask(v.Vkey)
  180. return
  181. }
  182. }
  183. err = errors.New("未找到host对应的内网目标")
  184. return
  185. }
  186. //通过host获取对应的ip地址
  187. func Gethostbyname(hostname string) string {
  188. if !DomainCheck(hostname) {
  189. return hostname
  190. }
  191. ips, _ := net.LookupIP(hostname)
  192. if ips != nil {
  193. for _, v := range ips {
  194. if v.To4() != nil {
  195. return v.String()
  196. }
  197. }
  198. }
  199. return ""
  200. }
  // domainCheckPattern matches, at the start of the input, an optional
  // "http://" or "https://" followed by one or more dot-terminated labels and a
  // final label of 2-6 ASCII letters. Compiling it once avoids recompiling
  // on every DomainCheck call.
  var domainCheckPattern = regexp.MustCompile(`^((http://)|(https://))?([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,6}`)

  // DomainCheck reports whether domain looks like a domain name, optionally
  // prefixed by an http/https scheme.
  //
  // Only the start of the string is anchored, so trailing content such as a
  // port or path is ignored ("example.com:8080" and "http://a.b.cn/x" match).
  // IPv4 literals and single-label names such as "localhost" do not match.
  //
  // The previous second pattern, which required a trailing "/", was redundant:
  // anything it matched is also matched by this pattern.
  func DomainCheck(domain string) bool {
  	return domainCheckPattern.MatchString(domain)
  }
  // checkAuth reports whether r carries an HTTP Basic Authorization header
  // whose credentials equal user and passwd.
  //
  // The scheme must be "Basic" (compared case-insensitively, per RFC 7617).
  // Previously any scheme was accepted. The password may itself contain ':'.
  // Both halves are always compared, in constant time, so the response time
  // reveals neither which half was wrong nor how close a guess was.
  func checkAuth(r *http.Request, user, passwd string) bool {
  	parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
  	if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") {
  		return false
  	}
  	decoded, err := base64.StdEncoding.DecodeString(parts[1])
  	if err != nil {
  		return false
  	}
  	pair := strings.SplitN(string(decoded), ":", 2)
  	if len(pair) != 2 {
  		return false
  	}
  	userOK := subtle.ConstantTimeCompare([]byte(pair[0]), []byte(user))
  	passOK := subtle.ConstantTimeCompare([]byte(pair[1]), []byte(passwd))
  	return (userOK & passOK) == 1
  }
  // GetBoolByStr reports whether s is exactly "1" or "true" (case-sensitive).
  // Every other value, including "", yields false.
  func GetBoolByStr(s string) bool {
  	return s == "1" || s == "true"
  }
  // GetStrByBool encodes b as "1" (true) or "0" (false), the inverse of
  // GetBoolByStr for those two values.
  func GetStrByBool(b bool) string {
  	s := "0"
  	if b {
  		s = "1"
  	}
  	return s
  }
  // copyBuffer is a variant of io.Copy that reads through a 65535-byte buffer
  // instead of io.Copy's 32*1024. Per the original note, the 32 KiB size
  // differed from what the snappy side uses, so chunks were read differently
  // and decryption suffered.
  //
  // It copies from src to dst until EOF or the first error. It returns the
  // bytes written and the first read or write error; EOF is not an error. A
  // write that accepts fewer bytes than requested yields io.ErrShortWrite.
  //
  // NOTE(review): when src implements io.WriterTo or dst implements
  // io.ReaderFrom, that method performs the copy and the 65535-byte buffer is
  // not used at all. Confirm the connection types passed here do not implement
  // those interfaces if the chunk size really matters.
  func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
  	// Fast paths: let either side drive the copy when it knows how.
  	if wt, ok := src.(io.WriterTo); ok {
  		return wt.WriteTo(dst)
  	}
  	if rf, ok := dst.(io.ReaderFrom); ok {
  		return rf.ReadFrom(src)
  	}
  	chunk := make([]byte, 65535)
  	for {
  		n, readErr := src.Read(chunk)
  		if n > 0 {
  			m, writeErr := dst.Write(chunk[:n])
  			if m > 0 {
  				written += int64(m)
  			}
  			switch {
  			case writeErr != nil:
  				return written, writeErr
  			case m != n:
  				return written, io.ErrShortWrite
  			}
  		}
  		// Data returned with an error is written first, then the error is handled.
  		if readErr == io.EOF {
  			return written, nil
  		}
  		if readErr != nil {
  			return written, readErr
  		}
  	}
  }