Parcourir la source

添加多种模式

刘河 il y a 6 ans
Parent
commit
3ea895feb5
8 fichiers modifiés avec 824 ajouts et 208 suppressions
  1. 29 6
      README.md
  2. 72 80
      client.go
  3. 118 0
      conn.go
  4. 23 13
      main.go
  5. 101 109
      server.go
  6. 236 0
      sock5.go
  7. 97 0
      tunnel.go
  8. 148 0
      util.go

+ 29 - 6
README.md

@@ -1,3 +1,4 @@
+<<<<<<< Updated upstream
 # easyProxy
 轻量级、较高性能http代理服务器,主要应用与内网穿透。支持多站点配置、客户端与服务端连接中断自动重连,多路传输,大大的提高请求处理速度,go语言编写,无第三方依赖,经过测试内存占用小,普通场景下,仅占用10m内存。
 
@@ -135,12 +136,34 @@ server {
 
 如需开启,请加配置文件Replace值设置为1
 >注意:开启可能导致不应该被替换的内容被替换,请谨慎开启
+=======
+# rproxy
+简单的反向代理用于内网穿透  
+
+**特别注意,此工具只适合小文件类的访问测试,用来做做数据调试。当初也只是用于微信公众号开发,所以定位也是如此** 
+
+## 前言	  
+最近周末闲来无事,想起了做下微信公共号的开发,但微信限制只能80端口的,自己用的城中村的那种宽带,共用一个公网,没办法自己用路由做端口映射。自己的服务器在腾讯云上,每次都要编译完后用ftp上传再进行调试,非常的浪费时间。 一时间又不知道上哪找一个符合我的这种要求的工具,就索性自己构思了下,整个工作流程大致为:   
+
+## 工作原理  
+> 外部请求自己服务器上的HTTP服务端 -> 将数据传递给Socket服务器 -> Socket服务器将数据发送至已连接的Socket客户端 -> Socket客户端收到数据 -> 使用http请求本地http服务端 -> 本地http服务端处理相关后返回 -> Socket客户端将返回的数据发送至Socket服务端 -> Socket服务端解析出数据后原路返回至外部请求的HTTP  
+ 
+## 使用方法  
+> 1、go get github.com/ying32/rproxy  
+> 2、go build   
+> 3、服务端运行runsvr.bat或者runsvr.sh    
+> 4、客户端运行runcli.bat或者runcli.sh    
+
+## 命令行说明    
+>  --tcpport    Socket连接或者监听的端口   
+>  --httpport   当mode为server时为服务端监听端口,当为mode为client时为转发至本地客户端的端口  
+>  --mode       启动模式,可选为client、server,默认为client  
+>  --svraddr    当mode为client时有效,为连接服务器的地址,不需要填写端口    
+>  --vkey       客户端与服务端建立连接时校验的加密key,简单的。  
+>>>>>>> Stashed changes
 
 ## 操作系统支持  
-支持Windows、Linux、MacOSX等,无第三方依赖库。
-
-## 二级域名泛解析配置详细教程
-
-[详细教程](https://github.com/cnlh/easyProxy/wiki/%E4%BD%BF%E7%94%A8%E6%95%99%E7%A8%8B)
-
+支持Windows、Linux、MacOSX等,无第三方依赖库。  
 
+## 二进制下载
+https://github.com/ying32/rproxy/releases/tag/v0.4  

+ 72 - 80
client.go

@@ -1,20 +1,15 @@
 package main
 
 import (
-	"encoding/binary"
 	"errors"
+	"fmt"
+	"io"
 	"log"
 	"net"
-	"net/http"
-	"strings"
 	"sync"
 	"time"
 )
 
-var (
-	disabledRedirect = errors.New("disabled redirect.")
-)
-
 type TRPClient struct {
 	svrAddr string
 	tcpNum  int
@@ -28,56 +23,58 @@ func NewRPClient(svraddr string, tcpNum int) *TRPClient {
 	return c
 }
 
-func (c *TRPClient) Start() error {
-	for i := 0; i < c.tcpNum; i++ {
-		go c.newConn()
+func (s *TRPClient) Start() error {
+	for i := 0; i < s.tcpNum; i++ {
+		go s.newConn()
 	}
 	for {
-		time.Sleep(5 * time.Second)
+		time.Sleep(time.Second * 5)
 	}
 	return nil
 }
 
-func (c *TRPClient) newConn() error {
-	c.Lock()
-	conn, err := net.Dial("tcp", c.svrAddr)
+//新建
+func (s *TRPClient) newConn() error {
+	s.Lock()
+	conn, err := net.Dial("tcp", s.svrAddr)
 	if err != nil {
 		log.Println("连接服务端失败,五秒后将重连")
 		time.Sleep(time.Second * 5)
-		c.Unlock()
-		c.newConn()
+		s.Unlock()
+		go s.newConn()
 		return err
 	}
-	c.Unlock()
-	conn.(*net.TCPConn).SetKeepAlive(true)
-	conn.(*net.TCPConn).SetKeepAlivePeriod(time.Duration(2 * time.Second))
-	return c.process(conn)
+	s.Unlock()
+	return s.process(NewConn(conn))
 }
 
-func (c *TRPClient) werror(conn net.Conn) {
-	conn.Write([]byte("msg0"))
-}
-
-func (c *TRPClient) process(conn net.Conn) error {
-	if _, err := conn.Write(getverifyval()); err != nil {
+func (s *TRPClient) process(c *Conn) error {
+	c.SetAlive()
+	if _, err := c.Write(getverifyval()); err != nil {
 		return err
 	}
-	val := make([]byte, 4)
+	c.wMain()
 	for {
-		_, err := conn.Read(val)
+		flags, err := c.ReadFlag()
 		if err != nil {
 			log.Println("服务端断开,五秒后将重连", err)
 			time.Sleep(5 * time.Second)
-			go c.newConn()
-			return err
+			go s.newConn()
+			break
 		}
-		flags := string(val)
 		switch flags {
-		case "vkey":
+		case VERIFY_EER:
 			log.Fatal("vkey不正确,请检查配置文件")
-		case "sign":
-			c.deal(conn)
-		case "msg0":
+		case RES_SIGN: //代理请求模式
+			if err := s.dealHttp(c); err != nil {
+				log.Println(err)
+				return err
+			}
+		case WORK_CHAN: //隧道模式,每次开启10个,加快连接速度
+			for i := 0; i < 10; i++ {
+				go s.dealChan()
+			}
+		case RES_MSG:
 			log.Println("服务端返回错误。")
 		default:
 			log.Println("无法解析该错误。")
@@ -85,69 +82,64 @@ func (c *TRPClient) process(conn net.Conn) error {
 	}
 	return nil
 }
-func (c *TRPClient) deal(conn net.Conn) error {
-	val := make([]byte, 4)
-	_, err := conn.Read(val)
-	nlen := binary.LittleEndian.Uint32(val)
-	log.Println("收到服务端数据,长度:", nlen)
-	if nlen <= 0 {
-		log.Println("数据长度错误。")
-		c.werror(conn)
-		return errors.New("数据长度错误")
+
+//隧道模式处理
+func (s *TRPClient) dealChan() error {
+	//创建一个tcp连接
+	conn, err := net.Dial("tcp", s.svrAddr)
+	//验证
+	if _, err := conn.Write(getverifyval()); err != nil {
+		return err
 	}
-	raw := make([]byte, nlen)
-	n, err := conn.Read(raw)
+	//默认长连接保持
+	c := NewConn(conn)
+	c.SetAlive()
+	//写标志
+	c.wChan()
+	//获取连接的host
+	host, err := c.GetHostFromConn()
 	if err != nil {
 		return err
 	}
-	if n != int(nlen) {
-		log.Printf("读取服务端数据长度错误,已经读取%dbyte,总长度%d字节\n", n, nlen)
-		c.werror(conn)
-		return errors.New("读取服务端数据长度错误")
-	}
-	req, err := DecodeRequest(raw)
+	//与目标建立连接
+	server, err := net.Dial("tcp", host)
 	if err != nil {
-		log.Println("DecodeRequest错误:", err)
-		c.werror(conn)
 		return err
 	}
-	rawQuery := ""
-	if req.URL.RawQuery != "" {
-		rawQuery = "?" + req.URL.RawQuery
-	}
-	log.Println(req.URL.Path + rawQuery)
-	client := new(http.Client)
-	client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
-		return disabledRedirect
+	//创建成功后io.copy
+	go io.Copy(server, c)
+	io.Copy(c, server)
+	return nil
+}
+
+//http模式处理
+func (s *TRPClient) dealHttp(c *Conn) error {
+	nlen, err := c.GetLen()
+	if err != nil {
+		c.wError()
+		return err
 	}
-	resp, err := client.Do(req)
-	disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error())
-	if err != nil && !disRedirect {
-		log.Println("请求本地客户端错误:", err)
-		c.werror(conn)
+	raw, err := c.ReadLen(int(nlen))
+	if err != nil {
+		c.wError()
 		return err
 	}
-	if !disRedirect {
-		defer resp.Body.Close()
-	} else {
-		resp.Body = nil
-		resp.ContentLength = 0
+	req, err := DecodeRequest(raw)
+	if err != nil {
+		c.wError()
+		return err
 	}
-	respBytes, err := EncodeResponse(resp)
+	respBytes, err := GetEncodeResponse(req)
 	if err != nil {
-		log.Println("EncodeResponse错误:", err)
-		c.werror(conn)
+		c.wError()
 		return err
 	}
-	n, err = conn.Write(respBytes)
+	n, err := c.Write(respBytes)
 	if err != nil {
-		log.Println("发送数据错误,错误:", err)
 		return err
 	}
 	if n != len(respBytes) {
-		log.Printf("发送数据长度错误,已经发送:%dbyte,总字节长:%dbyte\n", n, len(respBytes))
-	} else {
-		log.Printf("本次请求成功完成,共发送:%dbyte\n", n)
+		return errors.New(fmt.Sprintf("发送数据长度错误,已经发送:%dbyte,总字节长:%dbyte\n", n, len(respBytes)))
 	}
 	return nil
 }

+ 118 - 0
conn.go

@@ -0,0 +1,118 @@
+package main
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"time"
+)
+
+type Conn struct {
+	conn net.Conn
+}
+
+func NewConn(conn net.Conn) *Conn {
+	c := new(Conn)
+	c.conn = conn
+	return c
+}
+
+//读取指定内容长度
+func (s *Conn) ReadLen(len int) ([]byte, error) {
+	raw := make([]byte, 0)
+	buff := make([]byte, 1024)
+	c := 0
+	for {
+		clen, err := s.conn.Read(buff)
+		if err != nil && err != io.EOF {
+			return raw, err
+		}
+		raw = append(raw, buff[:clen]...)
+		if c += clen; c >= len {
+			break
+		}
+	}
+	if c != len {
+		return raw, errors.New(fmt.Sprintf("已读取长度错误,已读取%dbyte,需要读取%dbyte。", c, len))
+	}
+	return raw, nil
+}
+
+//获取长度
+func (s *Conn) GetLen() (int, error) {
+	val := make([]byte, 4)
+	_, err := s.conn.Read(val)
+	if err != nil {
+		return 0, err
+	}
+	nlen := binary.LittleEndian.Uint32(val)
+	if nlen <= 0 {
+		return 0, errors.New("数据长度错误")
+	}
+	return int(nlen), nil
+}
+
+//读取flag
+func (s *Conn) ReadFlag() (string, error) {
+	val := make([]byte, 4)
+	_, err := s.conn.Read(val)
+	if err != nil {
+		return "", err
+	}
+	return string(val), err
+}
+
+//读取host
+func (s *Conn) GetHostFromConn() (string, error) {
+	len, err := s.GetLen()
+	if err != nil {
+		return "", err
+	}
+	hostByte := make([]byte, len)
+	_, err = s.conn.Read(hostByte)
+	if err != nil {
+		return "", err
+	}
+	return string(hostByte), nil
+}
+
+//获取host
+func (s *Conn) WriteHost(host string) (int, error) {
+	raw := bytes.NewBuffer([]byte{})
+	binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
+	binary.Write(raw, binary.LittleEndian, []byte(host))
+	return s.Write(raw.Bytes())
+}
+
+//设置连接为长连接
+func (s *Conn) SetAlive() {
+	conn := s.conn.(*net.TCPConn)
+	conn.SetReadDeadline(time.Time{})
+	conn.SetKeepAlive(true)
+	conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
+}
+
+func (s *Conn) Close() error {
+	return s.conn.Close()
+}
+func (s *Conn) Write(b []byte) (int, error) {
+	return s.conn.Write(b)
+}
+func (s *Conn) Read(b []byte) (int, error) {
+	return s.conn.Read(b)
+}
+
+func (s *Conn) wError() {
+	s.conn.Write([]byte(RES_MSG))
+}
+
+func (s *Conn) wMain() {
+	s.conn.Write([]byte(WORK_MAIN))
+}
+
+func (s *Conn) wChan() {
+	s.conn.Write([]byte(WORK_CHAN))
+}

+ 23 - 13
main.go

@@ -7,13 +7,14 @@ import (
 )
 
 var (
-	configPath = flag.String("config", "config.json", "配置文件路径")
-	tcpPort    = flag.Int("tcpport", 8284, "Socket连接或者监听的端口")
-	httpPort   = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口,当为mode为client时为转发至本地客户端的端口")
-	rpMode     = flag.String("mode", "client", "启动模式,可选为client、server")
-	verifyKey  = flag.String("vkey", "", "验证密钥")
-	config     Config
-	err        error
+	configPath   = flag.String("config", "config.json", "配置文件路径")
+	tcpPort      = flag.Int("tcpport", 8284, "Socket连接或者监听的端口")
+	httpPort     = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口,当为mode为client时为转发至本地客户端的端口")
+	rpMode       = flag.String("mode", "client", "启动模式,可选为client、server")
+	tunnelTarget = flag.String("target", "10.1.50.203:80", "tunnel模式远程目标")
+	verifyKey    = flag.String("vkey", "", "验证密钥")
+	config       Config
+	err          error
 )
 
 func main() {
@@ -29,7 +30,7 @@ func main() {
 		log.Println("客户端启动,连接:", config.Server.Ip, ", 端口:", config.Server.Tcp)
 		cli := NewRPClient(fmt.Sprintf("%s:%d", config.Server.Ip, config.Server.Tcp), config.Server.Num)
 		cli.Start()
-	} else if *rpMode == "server" {
+	} else {
 		if *verifyKey == "" {
 			log.Fatalln("必须输入一个验证的key")
 		}
@@ -39,11 +40,20 @@ func main() {
 		if *httpPort <= 0 || *httpPort >= 65536 {
 			log.Fatalln("请输入正确的http端口。")
 		}
-		log.Println("服务端启动,监听tcp服务端端口:", *tcpPort, ", http服务端端口:", *httpPort)
-		svr := NewRPServer(*tcpPort, *httpPort)
-		if err := svr.Start(); err != nil {
-			log.Fatalln(err)
+		log.Println("服务端启动,监听tcp服务端端口:", *tcpPort, ", 外部服务端端口:", *httpPort)
+		if *rpMode == "httpServer" {
+			svr := NewHttpModeServer(*tcpPort, *httpPort)
+			if err := svr.Start(); err != nil {
+				log.Fatalln(err)
+			}
+		} else if *rpMode == "tunnelServer" {
+			svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget)
+			if err := svr.Start(); err != nil {
+				log.Fatalln(err)
+			}
+		} else if *rpMode == "sock5Server" {
+			svr := NewSock5ModeServer(*tcpPort, *httpPort)
+			svr.Start()
 		}
-		defer svr.Close()
 	}
 }

+ 101 - 109
server.go

@@ -1,8 +1,6 @@
 package main
 
 import (
-	"bytes"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -10,115 +8,69 @@ import (
 	"log"
 	"net"
 	"net/http"
-	"sync"
-	"time"
 )
 
-type TRPServer struct {
-	tcpPort  int
+const (
+	VERIFY_EER = "vkey"
+	WORK_MAIN  = "main"
+	WORK_CHAN  = "chan"
+	RES_SIGN   = "sign"
+	RES_MSG    = "msg0"
+)
+
+type HttpModeServer struct {
+	Tunnel
 	httpPort int
-	listener *net.TCPListener
-	connList chan net.Conn
-	sync.RWMutex
 }
 
-func NewRPServer(tcpPort, httpPort int) *TRPServer {
-	s := new(TRPServer)
-	s.tcpPort = tcpPort
+func NewHttpModeServer(tcpPort, httpPort int) *HttpModeServer {
+	s := new(HttpModeServer)
+	s.tunnelPort = tcpPort
 	s.httpPort = httpPort
-	s.connList = make(chan net.Conn, 1000)
+	s.signalList = make(chan *Conn, 1000)
 	return s
 }
 
-func (s *TRPServer) Start() error {
-	var err error
-	s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tcpPort, ""})
+//开始
+func (s *HttpModeServer) Start() (error) {
+	err := s.StartTunnel()
 	if err != nil {
+		log.Fatalln("开启客户端失败!", err)
 		return err
 	}
-	go s.httpserver()
-	return s.tcpserver()
-}
-
-func (s *TRPServer) Close() error {
-	if s.listener != nil {
-		err := s.listener.Close()
-		s.listener = nil
-		return err
-	}
-	return errors.New("TCP实例未创建!")
-}
-
-func (s *TRPServer) tcpserver() error {
-	var err error
-	for {
-		conn, err := s.listener.AcceptTCP()
-		if err != nil {
-			log.Println(err)
-			continue
-		}
-		go s.cliProcess(conn)
-	}
-	return err
-}
-
-func badRequest(w http.ResponseWriter) {
-	http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
+	s.startHttpServer()
+	return nil
 }
 
-func (s *TRPServer) httpserver() {
+//开启http端口监听
+func (s *HttpModeServer) startHttpServer() {
 	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 	retry:
-		if len(s.connList) == 0 {
-			badRequest(w)
+		if len(s.signalList) == 0 {
+			BadRequest(w)
 			return
 		}
-		conn := <-s.connList
-		log.Println(r.RequestURI)
-		err := s.write(r, conn)
-		if err != nil {
+		conn := <-s.signalList
+		if err := s.writeRequest(r, conn); err != nil {
 			log.Println(err)
 			conn.Close()
 			goto retry
 			return
 		}
-		err = s.read(w, conn)
+		err = s.writeResponse(w, conn)
 		if err != nil {
 			log.Println(err)
 			conn.Close()
 			goto retry
 			return
 		}
-		s.connList <- conn
-		conn = nil
+		s.signalList <- conn
 	})
 	log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", s.httpPort), nil))
 }
 
-func (s *TRPServer) cliProcess(conn *net.TCPConn) error {
-	conn.SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
-	vval := make([]byte, 20)
-	_, err := conn.Read(vval)
-	if err != nil {
-		log.Println("客户端读超时。客户端地址为::", conn.RemoteAddr())
-		conn.Close()
-		return err
-	}
-	if bytes.Compare(vval, getverifyval()[:]) != 0 {
-		log.Println("当前客户端连接校验错误,关闭此客户端:", conn.RemoteAddr())
-		conn.Write([]byte("vkey"))
-		conn.Close()
-		return err
-	}
-	conn.SetReadDeadline(time.Time{})
-	log.Println("连接新的客户端:", conn.RemoteAddr())
-	conn.SetKeepAlive(true)
-	conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
-	s.connList <- conn
-	return nil
-}
-
-func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
+//req转为bytes发送给client端
+func (s *HttpModeServer) writeRequest(r *http.Request, conn *Conn) error {
 	raw, err := EncodeRequest(r)
 	if err != nil {
 		return err
@@ -133,41 +85,21 @@ func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
 	return nil
 }
 
-func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
-	val := make([]byte, 4)
-	_, err := conn.Read(val)
+//从client读取出Response
+func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error {
+	flags, err := c.ReadFlag()
 	if err != nil {
 		return err
 	}
-	flags := string(val)
 	switch flags {
-	case "sign":
-		_, err = conn.Read(val)
+	case RES_SIGN:
+		nlen, err := c.GetLen()
 		if err != nil {
 			return err
 		}
-		nlen := int(binary.LittleEndian.Uint32(val))
-		if nlen == 0 {
-			return errors.New("读取客户端长度错误。")
-		}
-		log.Println("收到客户端数据,需要读取长度:", nlen)
-		raw := make([]byte, 0)
-		buff := make([]byte, 1024)
-		c := 0
-		for {
-			clen, err := conn.Read(buff)
-			if err != nil && err != io.EOF {
-				return err
-			}
-			raw = append(raw, buff[:clen]...)
-			c += clen
-			if c >= nlen {
-				break
-			}
-		}
-		log.Println("读取完成,长度:", c, "实际raw长度:", len(raw))
-		if c != nlen {
-			return fmt.Errorf("已读取长度错误,已读取%dbyte,需要读取%dbyte。", c, nlen)
+		raw, err := c.ReadLen(nlen)
+		if err != nil {
+			return err
 		}
 		resp, err := DecodeResponse(raw)
 		if err != nil {
@@ -184,10 +116,70 @@ func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
 		}
 		w.WriteHeader(resp.StatusCode)
 		w.Write(bodyBytes)
-	case "msg0":
-		return nil
+	case RES_MSG:
+		BadRequest(w)
+		return errors.New("客户端请求出错")
 	default:
-		log.Println("无法解析此错误", string(val))
+		BadRequest(w)
+		return errors.New("无法解析此错误")
+	}
+	return nil
+}
+
+type TunnelModeServer struct {
+	Tunnel
+	httpPort     int
+	tunnelTarget string
+}
+
+func NewTunnelModeServer(tcpPort, httpPort int, tunnelTarget string) *TunnelModeServer {
+	s := new(TunnelModeServer)
+	s.tunnelPort = tcpPort
+	s.httpPort = httpPort
+	s.tunnelTarget = tunnelTarget
+	s.tunnelList = make(chan *Conn, 1000)
+	s.signalList = make(chan *Conn, 10)
+	return s
+}
+
+//开始
+func (s *TunnelModeServer) Start() (error) {
+	err := s.StartTunnel()
+	if err != nil {
+		log.Fatalln("开启客户端失败!", err)
+		return err
+	}
+	s.startTunnelServer()
+	return nil
+}
+
+//隧道模式server
+func (s *TunnelModeServer) startTunnelServer() {
+	listener, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.httpPort, ""})
+	if err != nil {
+		log.Fatalln(err)
+	}
+	for {
+		conn, err := listener.AcceptTCP()
+		if err != nil {
+			log.Println(err)
+			continue
+		}
+		go s.process(NewConn(conn))
+	}
+}
+
+//监听连接处理
+func (s *TunnelModeServer) process(c *Conn) error {
+retry:
+	if len(s.tunnelList) < 10 { //新建通道
+		go s.newChan()
+	}
+	link := <-s.tunnelList
+	if _, err := link.WriteHost(s.tunnelTarget); err != nil {
+		goto retry
 	}
+	go io.Copy(link, c)
+	io.Copy(c, link.conn)
 	return nil
 }

+ 236 - 0
sock5.go

@@ -0,0 +1,236 @@
+package main
+
+import (
+	"encoding/binary"
+	"errors"
+	"io"
+	"log"
+	"net"
+	"strconv"
+)
+
+const (
+	ipV4       = 1
+	domainName = 3
+	ipV6       = 4
+	connectMethod   = 1
+	bindMethod      = 2
+	associateMethod = 3
+	// The maximum packet size of any udp Associate packet, based on ethernet's max size,
+	// minus the IP and UDP headers. IPv4 has a 20 byte header, UDP adds an
+	// additional 4 bytes.  This is a total overhead of 24 bytes.  Ethernet's
+	// max packet size is 1500 bytes,  1500 - 24 = 1476.
+	maxUDPPacketSize = 1476
+)
+
+const (
+	succeeded uint8 = iota
+	serverFailure
+	notAllowed
+	networkUnreachable
+	hostUnreachable
+	connectionRefused
+	ttlExpired
+	commandNotSupported
+	addrTypeNotSupported
+)
+
+type Sock5ModeServer struct {
+	Tunnel
+	httpPort int
+}
+
+func (s *Sock5ModeServer) handleRequest(c net.Conn) {
+	/*
+		The SOCKS request is formed as follows:
+		+----+-----+-------+------+----------+----------+
+		|VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
+		+----+-----+-------+------+----------+----------+
+		| 1  |  1  | X'00' |  1   | Variable |    2     |
+		+----+-----+-------+------+----------+----------+
+	*/
+	header := make([]byte, 3)
+
+	_, err := io.ReadFull(c, header)
+
+	if err != nil {
+		log.Println("illegal request", err)
+		c.Close()
+		return
+	}
+
+	switch header[1] {
+	case connectMethod:
+		s.handleConnect(c)
+	case bindMethod:
+		s.handleBind(c)
+	case associateMethod:
+		s.handleUDP(c)
+	default:
+		s.sendReply(c, commandNotSupported)
+		c.Close()
+	}
+}
+
+func (s *Sock5ModeServer) sendReply(c net.Conn, rep uint8) {
+	reply := []byte{
+		5,
+		rep,
+		0,
+		1,
+	}
+
+	localAddr := c.LocalAddr().String()
+	localHost, localPort, _ := net.SplitHostPort(localAddr)
+	ipBytes := net.ParseIP(localHost).To4()
+	nPort, _ := strconv.Atoi(localPort)
+	reply = append(reply, ipBytes...)
+	portBytes := make([]byte, 2)
+	binary.BigEndian.PutUint16(portBytes, uint16(nPort))
+	reply = append(reply, portBytes...)
+
+	c.Write(reply)
+}
+
+func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) (proxyConn *Conn, err error) {
+	addrType := make([]byte, 1)
+	c.Read(addrType)
+	var host string
+	switch addrType[0] {
+	case ipV4:
+		ipv4 := make(net.IP, net.IPv4len)
+		c.Read(ipv4)
+		host = ipv4.String()
+	case ipV6:
+		ipv6 := make(net.IP, net.IPv6len)
+		c.Read(ipv6)
+		host = ipv6.String()
+	case domainName:
+		var domainLen uint8
+		binary.Read(c, binary.BigEndian, &domainLen)
+		domain := make([]byte, domainLen)
+		c.Read(domain)
+		host = string(domain)
+	default:
+		s.sendReply(c, addrTypeNotSupported)
+		err = errors.New("Address type not supported")
+		return nil, err
+	}
+
+	var port uint16
+	binary.Read(c, binary.BigEndian, &port)
+
+	// connect to host
+	addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
+	//取出一个连接
+	if len(s.tunnelList) < 10 { //新建通道
+		go s.newChan()
+	}
+	client := <-s.tunnelList
+	s.sendReply(c, succeeded)
+	_, err = client.WriteHost(addr)
+	return client, nil
+}
+
+func (s *Sock5ModeServer) handleConnect(c net.Conn) {
+	proxyConn, err := s.doConnect(c, connectMethod)
+	if err != nil {
+		c.Close()
+	} else {
+		go io.Copy(c, proxyConn)
+		go io.Copy(proxyConn, c)
+	}
+
+}
+
+func (s *Sock5ModeServer) relay(in, out net.Conn) {
+	if _, err := io.Copy(in, out); err != nil {
+		log.Println("copy error", err)
+	}
+	in.Close() // will trigger an error in the other relay, then call out.Close()
+}
+
+// passive mode
+func (s *Sock5ModeServer) handleBind(c net.Conn) {
+}
+
+func (s *Sock5ModeServer) handleUDP(c net.Conn) {
+	log.Println("UDP Associate")
+	/*
+	   +----+------+------+----------+----------+----------+
+	   |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
+	   +----+------+------+----------+----------+----------+
+	   | 2  |  1   |  1   | Variable |    2     | Variable |
+	   +----+------+------+----------+----------+----------+
+	*/
+	buf := make([]byte, 3)
+	c.Read(buf)
+	// relay udp datagram silently, without any notification to the requesting client
+	if buf[2] != 0 {
+		// does not support fragmentation, drop it
+		log.Println("does not support fragmentation, drop")
+		dummy := make([]byte, maxUDPPacketSize)
+		c.Read(dummy)
+	}
+
+	proxyConn, err := s.doConnect(c, associateMethod)
+	if err != nil {
+		c.Close()
+	} else {
+		go io.Copy(c, proxyConn)
+		go io.Copy(proxyConn, c)
+	}
+}
+
+func (s *Sock5ModeServer) handleNewConn(c net.Conn) {
+	buf := make([]byte, 2)
+	if _, err := io.ReadFull(c, buf); err != nil {
+		log.Println("negotiation err", err)
+		c.Close()
+		return
+	}
+
+	if version := buf[0]; version != 5 {
+		log.Println("only support socks5, request from: ", c.RemoteAddr())
+		c.Close()
+		return
+	}
+	nMethods := buf[1]
+
+	methods := make([]byte, nMethods)
+	if len, err := c.Read(methods); len != int(nMethods) || err != nil {
+		log.Println("wrong method")
+		c.Close()
+		return
+	}
+	// no authentication required for now
+	buf[1] = 0
+	// send a METHOD selection message
+	c.Write(buf)
+
+	s.handleRequest(c)
+}
+
+func (s *Sock5ModeServer) Start() {
+	l, err := net.Listen("tcp", ":"+strconv.Itoa(s.httpPort))
+	if err != nil {
+		log.Fatal("listen error: ", err)
+	}
+	s.StartTunnel()
+	for {
+		conn, err := l.Accept()
+		if err != nil {
+			log.Fatal("accept error: ", err)
+		}
+		go s.handleNewConn(conn)
+	}
+}
+
+func NewSock5ModeServer(tcpPort, httpPort int) *Sock5ModeServer {
+	s := new(Sock5ModeServer)
+	s.tunnelPort = tcpPort
+	s.httpPort = httpPort
+	s.tunnelList = make(chan *Conn, 1000)
+	s.signalList = make(chan *Conn, 10)
+	return s
+}

+ 97 - 0
tunnel.go

@@ -0,0 +1,97 @@
+package main
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"log"
+	"net"
+	"sync"
+	"time"
+)
+
+type Tunnel struct {
+	tunnelPort int              //通信隧道端口
+	listener   *net.TCPListener //server端监听
+	signalList chan *Conn       //通信
+	tunnelList chan *Conn       //隧道
+	sync.RWMutex
+}
+
+func (s *Tunnel) StartTunnel() error {
+	var err error
+	s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tunnelPort, ""})
+	if err != nil {
+		return err
+	}
+	go s.tunnelProcess()
+	return nil
+}
+
+//tcp server
+func (s *Tunnel) tunnelProcess() error {
+	var err error
+	for {
+		conn, err := s.listener.Accept()
+		if err != nil {
+			log.Println(err)
+			continue
+		}
+		go s.cliProcess(NewConn(conn))
+	}
+	return err
+}
+
+//验证失败,返回错误验证flag,并且关闭连接
+func (s *Tunnel) verifyError(c *Conn) {
+	c.conn.Write([]byte(VERIFY_EER))
+	c.conn.Close()
+}
+
+func (s *Tunnel) cliProcess(c *Conn) error {
+	c.conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
+	vval := make([]byte, 20)
+	_, err := c.conn.Read(vval)
+	if err != nil {
+		log.Println("客户端读超时。客户端地址为::", c.conn.RemoteAddr())
+		c.conn.Close()
+		return err
+	}
+	if bytes.Compare(vval, getverifyval()[:]) != 0 {
+		log.Println("当前客户端连接校验错误,关闭此客户端:", c.conn.RemoteAddr())
+		s.verifyError(c)
+		return err
+	}
+	//做一个判断 添加到对应的channel里面以供使用
+	flag, err := c.ReadFlag()
+	if err != nil {
+		return err
+	}
+	return s.typeDeal(flag, c)
+}
+
+//tcp连接类型区分
+func (s *Tunnel) typeDeal(typeVal string, c *Conn) error {
+	switch typeVal {
+	case WORK_MAIN:
+		s.signalList <- c
+	case WORK_CHAN:
+		s.tunnelList <- c
+	default:
+		return errors.New("无法识别")
+	}
+	c.SetAlive()
+	return nil
+}
+
+//新建隧道
+func (s *Tunnel) newChan() {
+retry:
+	connPass := <-s.signalList
+	_, err := connPass.conn.Write([]byte("chan"))
+	if err != nil {
+		fmt.Println(err)
+		goto retry
+	}
+	s.signalList <- connPass
+}

+ 148 - 0
util.go

@@ -0,0 +1,148 @@
+package main
+
+import (
+	"bufio"
+	"bytes"
+	"compress/gzip"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/http/httputil"
+	"net/url"
+	"strconv"
+	"strings"
+)
+
+var (
+	disabledRedirect = errors.New("disabled redirect.")
+)
+
+
+
+
+func BadRequest(w http.ResponseWriter) {
+	http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
+}
+
+
+
+//发送请求并转为bytes
+func GetEncodeResponse(req *http.Request) ([]byte, error) {
+	var respBytes []byte
+	client := new(http.Client)
+	client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+		return disabledRedirect
+	}
+	resp, err := client.Do(req)
+	disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error())
+	if err != nil && !disRedirect {
+		return respBytes, err
+	}
+	if !disRedirect {
+		defer resp.Body.Close()
+	} else {
+		resp.Body = nil
+		resp.ContentLength = 0
+	}
+	respBytes, err = EncodeResponse(resp)
+	return respBytes, nil
+}
+
+
+// 将request 的处理
+func EncodeRequest(r *http.Request) ([]byte, error) {
+	raw := bytes.NewBuffer([]byte{})
+	// 写签名
+	binary.Write(raw, binary.LittleEndian, []byte("sign"))
+	reqBytes, err := httputil.DumpRequest(r, true)
+	if err != nil {
+		return nil, err
+	}
+	// 写body数据长度 + 1
+	binary.Write(raw, binary.LittleEndian, int32(len(reqBytes)+1))
+	// 判断是否为http或者https的标识1字节
+	binary.Write(raw, binary.LittleEndian, bool(r.URL.Scheme == "https"))
+	if err := binary.Write(raw, binary.LittleEndian, reqBytes); err != nil {
+		return nil, err
+	}
+	return raw.Bytes(), nil
+}
+
+// 将字节转为request
+func DecodeRequest(data []byte) (*http.Request, error) {
+	if len(data) <= 100 {
+		return nil, errors.New("待解码的字节长度太小")
+	}
+	req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:])))
+	if err != nil {
+		return nil, err
+	}
+	str := strings.Split(req.Host, ":")
+	req.Host, err = getHost(str[0])
+	if err != nil {
+		return nil, err
+	}
+	scheme := "http"
+	if data[0] == 1 {
+		scheme = "https"
+	}
+	req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI))
+	req.RequestURI = ""
+	return req, nil
+}
+
+//// 将response转为字节
+func EncodeResponse(r *http.Response) ([]byte, error) {
+	raw := bytes.NewBuffer([]byte{})
+	binary.Write(raw, binary.LittleEndian, []byte(RES_SIGN))
+	respBytes, err := httputil.DumpResponse(r, true)
+	if config.Replace == 1 {
+		respBytes = replaceHost(respBytes)
+	}
+	if err != nil {
+		return nil, err
+	}
+	var buf bytes.Buffer
+	zw := gzip.NewWriter(&buf)
+	zw.Write(respBytes)
+	zw.Close()
+	binary.Write(raw, binary.LittleEndian, int32(len(buf.Bytes())))
+	if err := binary.Write(raw, binary.LittleEndian, buf.Bytes()); err != nil {
+		fmt.Println(err)
+		return nil, err
+	}
+	return raw.Bytes(), nil
+}
+
+// 将字节转为response
+func DecodeResponse(data []byte) (*http.Response, error) {
+	zr, err := gzip.NewReader(bytes.NewReader(data))
+	if err != nil {
+		return nil, err
+	}
+	defer zr.Close()
+	resp, err := http.ReadResponse(bufio.NewReader(zr), nil)
+	if err != nil {
+		return nil, err
+	}
+	return resp, nil
+}
+
+func getHost(str string) (string, error) {
+	for _, v := range config.SiteList {
+		if v.Host == str {
+			return v.Url + ":" + strconv.Itoa(v.Port), nil
+		}
+	}
+	return "", errors.New("没有找到解析的的host!")
+}
+
+func replaceHost(resp []byte) []byte {
+	str := string(resp)
+	for _, v := range config.SiteList {
+		str = strings.Replace(str, v.Url+":"+strconv.Itoa(v.Port), v.Host, -1)
+		str = strings.Replace(str, v.Url, v.Host, -1)
+	}
+	return []byte(str)
+}