Bladeren bron

Bug修复+流量限制+带宽限制

刘河 6 jaren geleden
bovenliggende
commit
eccf1dbfb8

+ 20 - 26
README.md

@@ -21,26 +21,6 @@ easyProxy是一款轻量级、高性能、功能最为强大的**内网穿透**
 
 5. 搭建一个内网穿透ss,在外网如同使用内网vpn一样访问内网资源或者设备----> [socks5代理模式](#socks5代理模式)
 
-## 特点
-- [x] 支持snappy压缩,减小传输过程流量消耗
-- [x] 断线自动重连
-- [x] 支持多路传输,提高并发
-- [x] 跨站自动匹配替换
-- [x] 支持tcp隧道,提升访问效率
-- [x] 支持udp隧道
-- [x] 支持http代理
-- [x] 支持内网穿透sock5代理,配合proxifier可达到vpn的效果,在外网访问内网资源或者设备,同时可以设置用户名和密码验证
-- [x] 强大的web管理界面,可方便的设置的和管理隧道
-- [x] 支持站点密码保护
-- [x] 支持加密传输
-- [x] 支持TCP多路复用
-- [x] 支持同时开多条tcp、udp隧道等等,且只需要开一个客户端和服务端
-- [x] 支持一个服务端,多个客户端模式
-- [x] host修改支持
-- [x] 自定义header支持
-- [x] 流量统计
-- [x] 自定义404页面
-- [x] 热更新支持
 
 ## 目录
 
@@ -63,12 +43,15 @@ easyProxy是一款轻量级、高性能、功能最为强大的**内网穿透**
    * [TCP多路复用](#多路复用)
    * [host修改](#host修改)
    * [自定义header](#自定义header)
-   * [获取用户真实ip](#获取用户真实ip)
-   * [热更新支持](#热更新支持)
-   * [客户端地址显示](#客户端地址显示)
    * [自定义404页面](#404页面配置)
+   * [流量限制](#流量限制)
+   * [带宽限制](#带宽限制)
+* [相关说明](#相关说明)
    * [流量统计](#流量统计)
    * [连接池](#连接池)
+   * [热更新支持](#热更新支持)
+   * [获取用户真实ip](#获取用户真实ip)
+   * [客户端地址显示](#客户端地址显示)
 
 ## 安装
 
@@ -362,6 +345,20 @@ easyProxy支持通过 HTTP Basic Auth 来保护你的 web 服务,使用户需
 
 支持对header进行新增或者修改,以配合服务的需要
 
+### 404页面配置
+支持域名解析模式的自定义404页面,修改/web/static/page/error.html中内容即可,暂不支持静态文件等内容
+
+### 流量限制
+
+支持客户端级流量限制,当该客户端入口流量与出口流量达到设定的总量后会拒绝服务
+,域名代理会返回404页面,其他代理会拒绝连接
+
+### 带宽限制
+
+支持客户端级带宽限制,带宽计算方式为入口和出口总和,权重均衡
+
+## 相关说明
+
 ### 获取用户真实ip
 
 目前只有域名模式的代理支持这一功能,可以通过用户请求的 header 中的 X-Forwarded-For 和 X-Real-IP 来获取用户真实 IP。
@@ -374,9 +371,6 @@ easyProxy支持通过 HTTP Basic Auth 来保护你的 web 服务,使用户需
 ### 客户端地址显示
 在web管理中将显示客户端的连接地址
 
-### 404页面配置
-支持域名解析模式的自定义404页面,修改/web/static/page/error.html中内容即可,暂不支持静态文件等内容
-
 ### 流量统计
 可统计显示每个代理使用的流量,由于压缩和加密等原因,会和实际环境中的略有差异
 

+ 28 - 35
bridge/bridge.go

@@ -76,46 +76,55 @@ func (s *Bridge) tunnelProcess() error {
 
 //验证失败,返回错误验证flag,并且关闭连接
 func (s *Bridge) verifyError(c *utils.Conn) {
-	c.Conn.Write([]byte(utils.VERIFY_EER))
+	c.Write([]byte(utils.VERIFY_EER))
 	c.Conn.Close()
 }
 
-func (s *Bridge) cliProcess(c *utils.Conn) error {
-	c.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
-	vval := make([]byte, 32)
-	if _, err := c.Conn.Read(vval); err != nil {
-		log.Println("客户端读超时。客户端地址为::", c.Conn.RemoteAddr())
-		c.Conn.Close()
-		return err
+func (s *Bridge) cliProcess(c *utils.Conn) {
+	c.SetReadDeadline(5)
+	var buf []byte
+	var err error
+	if buf, err = c.ReadLen(32); err != nil {
+		c.Close()
+		return
 	}
-	id, err := utils.GetCsvDb().GetIdByVerifyKey(string(vval),c.Conn.RemoteAddr().String())
+	//验证
+	id, err := utils.GetCsvDb().GetIdByVerifyKey(string(buf), c.Conn.RemoteAddr().String())
 	if err != nil {
 		log.Println("当前客户端连接校验错误,关闭此客户端:", c.Conn.RemoteAddr())
 		s.verifyError(c)
-		return errors.New("验证错误")
+		return
 	}
-	c.Conn.(*net.TCPConn).SetReadDeadline(time.Time{})
 	//做一个判断 添加到对应的channel里面以供使用
-	if flag, err := c.ReadFlag(); err != nil {
-		return err
-	} else {
-		return s.typeDeal(flag, c, id)
+	if flag, err := c.ReadFlag(); err == nil {
+		s.typeDeal(flag, c, id)
+	}
+	return
+}
+
+func (s *Bridge) closeClient(id int) {
+	if len(s.SignalList) > 0 {
+		s.SignalList[id].Pop().WriteClose()
 	}
+	s.DelClientSignal(id)
+	s.DelClientTunnel(id)
 }
 
 //tcp连接类型区分
-func (s *Bridge) typeDeal(typeVal string, c *utils.Conn, id int) error {
+func (s *Bridge) typeDeal(typeVal string, c *utils.Conn, id int) {
 	switch typeVal {
 	case utils.WORK_MAIN:
+		//客户端已经存在,下线
+		if _, ok := s.SignalList[id]; ok {
+			s.closeClient(id)
+		}
 		log.Println("客户端连接成功", c.Conn.RemoteAddr())
 		s.addList(s.SignalList, c, id)
 	case utils.WORK_CHAN:
 		s.addList(s.TunnelList, c, id)
-	default:
-		return errors.New("无法识别")
 	}
 	c.SetAlive()
-	return nil
+	return
 }
 
 //加到对应的list中
@@ -131,23 +140,7 @@ func (s *Bridge) addList(m map[int]*list, c *utils.Conn, id int) {
 	s.lock.Unlock()
 }
 
-//新建隧道
-func (s *Bridge) newChan(id int) error {
-	var connPass *utils.Conn
-	var err error
-retry:
-	if connPass, err = s.waitAndPop(s.SignalList, id); err != nil {
-		return err
-	}
-	if _, err = connPass.Conn.Write([]byte("chan")); err != nil {
-		goto retry
-	}
-	s.SignalList[id].Add(connPass)
-	return nil
-}
-
 //得到一个tcp隧道
-//TODO 超时问题 锁机制问题 对单个客户端加锁
 func (s *Bridge) GetTunnel(id int, en, de int, crypt, mux bool) (c *utils.Conn, err error) {
 retry:
 	if c, err = s.waitAndPop(s.TunnelList, id); err != nil {

+ 5 - 2
client/client.go

@@ -12,6 +12,7 @@ import (
 type TRPClient struct {
 	svrAddr      string
 	tcpNum       int
+	connPoolSize int
 	tunnelNum    int64
 	tunnel       chan bool
 	serverStatus bool
@@ -26,6 +27,7 @@ func NewRPClient(svraddr string, tcpNum int, vKey string) *TRPClient {
 	c.tcpNum = tcpNum
 	c.vKey = vKey
 	c.tunnel = make(chan bool)
+	c.connPoolSize = 5
 	return c
 }
 
@@ -56,7 +58,6 @@ func (s *TRPClient) NewConn() error {
 	s.Unlock()
 	return s.processor(utils.NewConn(conn))
 }
-
 //处理
 func (s *TRPClient) processor(c *utils.Conn) error {
 	s.serverStatus = true
@@ -76,6 +77,8 @@ func (s *TRPClient) processor(c *utils.Conn) error {
 		case utils.VERIFY_EER:
 			log.Fatalln("vkey:", s.vKey, "不正确,服务端拒绝连接,请检查")
 		case utils.WORK_CHAN: //隧道模式,每次开启10个,加快连接速度
+		case utils.RES_CLOSE:
+			log.Fatal("该vkey被另一客户连接")
 		case utils.RES_MSG:
 			log.Println("服务端返回错误。")
 		default:
@@ -145,5 +148,5 @@ func (s *TRPClient) ConnectAndCopy(c *utils.Conn, typeStr, host string, en, de i
 		return
 	}
 	c.WriteSuccess()
-	utils.ReplayWaitGroup(c.Conn, server, en, de, crypt, mux)
+	utils.ReplayWaitGroup(c.Conn, server, en, de, crypt, mux, nil)
 }

+ 1 - 2
conf/clients.csv

@@ -1,2 +1 @@
-1,rfd0tl1anega0d0g,127.0.0.1:53603,测试,true,1,1,1,1,snappy
-2,zl4p3da659qa9rh3,127.0.0.1:52096,测试2,true,1,1,1,1,snappy
+1,wuz1nozs9dhtxic6,,true,,,0,0,,0,1

+ 1 - 2
conf/hosts.csv

@@ -1,2 +1 @@
-b.o.com,127.0.0.1:8082,2,,,测试
-a.o.com,127.0.0.1:8080,1,Connection: close,,测试2
+a.o.com,127.0.0.1:8082,1,Connection:close,,

+ 1 - 5
conf/tasks.csv

@@ -1,5 +1 @@
-53,udpServer,114.114.114.114:53,,,,1,0,0,0,1,2,2,true,udp测试
-9001,tunnelServer,127.0.0.1:8080,1,1,snappy,1,1,1,0,0,1,1,false,test
-9009,tunnelServer,127.0.0.1:5900,,,,1,0,0,0,0,5,2,true,vnc
-8025,httpProxyServer,,2,2,snappy,1,1,1,0,0,4,2,false,http测试
-8024,socks5Server,,,,,1,0,0,0,0,3,2,false,socks5测试
+9001,tunnelServer,127.0.0.1:8082,,,,1,0,0,0,0,1,1,true,

+ 8 - 3
server/base.go

@@ -39,16 +39,19 @@ func (s *server) FlowAddHost(host *utils.Host, in, out int64) {
 }
 
 //热更新配置
-func (s *server) ResetConfig() {
+func (s *server) ResetConfig() bool {
 	//获取最新数据
 	task, err := CsvDb.GetTask(s.task.Id)
 	if err != nil {
-		return
+		return false
+	}
+	if s.task.Client.Flow.FlowLimit > 0 && (s.task.Client.Flow.FlowLimit<<20) < (s.task.Client.Flow.ExportFlow+s.task.Client.Flow.InletFlow) {
+		return false
 	}
 	s.task.UseClientCnf = task.UseClientCnf
 	//使用客户端配置
+	client, err := CsvDb.GetClient(s.task.Client.Id)
 	if s.task.UseClientCnf {
-		client, err := CsvDb.GetClient(s.task.Client.Id)
 		if err == nil {
 			s.config.U = client.Cnf.U
 			s.config.P = client.Cnf.P
@@ -65,5 +68,7 @@ func (s *server) ResetConfig() {
 			s.config.Crypt = task.Config.Crypt
 		}
 	}
+	s.task.Client.Rate = client.Rate
 	s.config.CompressDecode, s.config.CompressEncode = utils.GetCompressType(s.config.Compress)
+	return true
 }

+ 17 - 5
server/process.go

@@ -3,6 +3,7 @@ package server
 import (
 	"bufio"
 	"github.com/cnlh/easyProxy/utils"
+	"github.com/pkg/errors"
 	"log"
 	"net/http"
 	"net/http/httputil"
@@ -13,11 +14,19 @@ type process func(c *utils.Conn, s *TunnelModeServer) error
 
 //tcp隧道模式
 func ProcessTunnel(c *utils.Conn, s *TunnelModeServer) error {
+	if !s.ResetConfig() {
+		c.Close()
+		return errors.New("流量超出")
+	}
 	return s.dealClient(c, s.config, s.task.Target, "", nil)
 }
 
 //http代理模式
 func ProcessHttp(c *utils.Conn, s *TunnelModeServer) error {
+	if !s.ResetConfig() {
+		c.Close()
+		return errors.New("流量超出")
+	}
 	method, addr, rb, err, r := c.GetHost()
 	if err != nil {
 		log.Println(err)
@@ -49,9 +58,12 @@ func ProcessHost(c *utils.Conn, s *TunnelModeServer) error {
 				log.Printf("the host %s is not found !", r.Host)
 				break
 			}
-
+			//流量限制
+			if host.Client.Flow.FlowLimit > 0 && (host.Client.Flow.FlowLimit<<20) < (host.Client.Flow.ExportFlow+host.Client.Flow.InletFlow) {
+				break
+			}
 			host.Client.Cnf.CompressDecode, host.Client.Cnf.CompressEncode = utils.GetCompressType(host.Client.Cnf.Compress)
-
+			//权限控制
 			if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil {
 				break
 			}
@@ -65,7 +77,7 @@ func ProcessHost(c *utils.Conn, s *TunnelModeServer) error {
 			} else {
 				wg.Add(1)
 				go func() {
-					out, _ := utils.Relay(c.Conn, link.Conn, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, host.Client.Cnf.Mux)
+					out, _ := utils.Relay(c.Conn, link.Conn, host.Client.Cnf.CompressDecode, host.Client.Cnf.Crypt, host.Client.Cnf.Mux, host.Client.Rate)
 					wg.Done()
 					s.FlowAddHost(host, 0, out)
 				}()
@@ -79,13 +91,13 @@ func ProcessHost(c *utils.Conn, s *TunnelModeServer) error {
 			break
 		}
 		s.FlowAddHost(host, int64(len(b)), 0)
-		if _, err := link.WriteTo(b, host.Client.Cnf.CompressEncode, host.Client.Cnf.Crypt); err != nil {
+		if _, err := link.WriteTo(b, host.Client.Cnf.CompressEncode, host.Client.Cnf.Crypt, host.Client.Rate); err != nil {
 			break
 		}
 	}
 	wg.Wait()
 	if host != nil && host.Client.Cnf != nil && host.Client.Cnf.Mux && link != nil {
-		link.WriteTo([]byte(utils.IO_EOF), host.Client.Cnf.CompressEncode, host.Client.Cnf.Crypt)
+		link.WriteTo([]byte(utils.IO_EOF), host.Client.Cnf.CompressEncode, host.Client.Cnf.Crypt, host.Client.Rate)
 		s.bridge.ReturnTunnel(link, host.Client.Id)
 	} else if link != nil {
 		link.Close()

+ 6 - 3
server/socks5.go

@@ -166,7 +166,7 @@ func (s *Sock5ModeServer) handleConnect(c net.Conn) {
 	if err != nil {
 		c.Close()
 	} else {
-		out, in := utils.ReplayWaitGroup(proxyConn.Conn, c, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, s.config.Mux)
+		out, in := utils.ReplayWaitGroup(proxyConn.Conn, c, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, s.config.Mux, s.task.Client.Rate)
 		s.FlowAdd(in, out)
 	}
 }
@@ -204,7 +204,7 @@ func (s *Sock5ModeServer) handleUDP(c net.Conn) {
 	if err != nil {
 		c.Close()
 	} else {
-		out, in := utils.ReplayWaitGroup(proxyConn.Conn, c, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, s.config.Mux)
+		out, in := utils.ReplayWaitGroup(proxyConn.Conn, c, s.config.CompressEncode, s.config.CompressDecode, s.config.Crypt, s.config.Mux, s.task.Client.Rate)
 		s.FlowAdd(in, out)
 	}
 }
@@ -297,7 +297,10 @@ func (s *Sock5ModeServer) Start() error {
 			}
 			log.Fatal("accept error: ", err)
 		}
-		s.ResetConfig()
+		if !s.ResetConfig() {
+			conn.Close()
+			continue
+		}
 		go s.handleConn(conn)
 	}
 	return nil

+ 2 - 3
server/tcp.go

@@ -48,7 +48,6 @@ func (s *TunnelModeServer) Start() error {
 			log.Println(err)
 			continue
 		}
-		s.ResetConfig()
 		go s.process(utils.NewConn(conn), s)
 	}
 	return nil
@@ -87,9 +86,9 @@ func (s *TunnelModeServer) dealClient(c *utils.Conn, cnf *utils.Config, addr str
 			if method == "CONNECT" {
 				fmt.Fprint(c, "HTTP/1.1 200 Connection established\r\n")
 			} else if rb != nil {
-				link.WriteTo(rb, cnf.CompressEncode, cnf.Crypt)
+				link.WriteTo(rb, cnf.CompressEncode, cnf.Crypt, s.task.Client.Rate)
 			}
-			out, in := utils.ReplayWaitGroup(link.Conn, c.Conn, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, cnf.Mux)
+			out, in := utils.ReplayWaitGroup(link.Conn, c.Conn, cnf.CompressEncode, cnf.CompressDecode, cnf.Crypt, cnf.Mux, s.task.Client.Rate)
 			s.FlowAdd(in, out)
 		}
 	}

+ 6 - 4
server/udp.go

@@ -40,7 +40,9 @@ func (s *UdpModeServer) Start() error {
 			}
 			continue
 		}
-		s.ResetConfig()
+		if !s.ResetConfig() {
+			continue
+		}
 		go s.process(addr, data[:n])
 	}
 	return nil
@@ -60,16 +62,16 @@ func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) {
 	if flag, err := conn.ReadFlag(); err == nil {
 		defer func() {
 			if conn != nil && s.config.Mux {
-				conn.WriteTo([]byte(utils.IO_EOF), s.config.CompressEncode, s.config.Crypt)
+				conn.WriteTo([]byte(utils.IO_EOF), s.config.CompressEncode, s.config.Crypt, s.task.Client.Rate)
 				s.bridge.ReturnTunnel(conn, s.task.Client.Id)
 			} else {
 				conn.Close()
 			}
 		}()
 		if flag == utils.CONN_SUCCESS {
-			in, _ := conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt)
+			in, _ := conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt, s.task.Client.Rate)
 			buf := utils.BufPoolUdp.Get().([]byte)
-			out, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt)
+			out, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt, s.task.Client.Rate)
 			if err != nil || err == io.EOF {
 				return
 			}

+ 36 - 9
utils/conn.go

@@ -21,12 +21,14 @@ const cryptKey = "1234567812345678"
 type CryptConn struct {
 	conn  net.Conn
 	crypt bool
+	rate  *Rate
 }
 
-func NewCryptConn(conn net.Conn, crypt bool) *CryptConn {
+func NewCryptConn(conn net.Conn, crypt bool, rate *Rate) *CryptConn {
 	c := new(CryptConn)
 	c.conn = conn
 	c.crypt = crypt
+	c.rate = rate
 	return c
 }
 
@@ -42,6 +44,9 @@ func (s *CryptConn) Write(b []byte) (n int, err error) {
 		return
 	}
 	_, err = s.conn.Write(b)
+	if s.rate != nil {
+		s.rate.Get(int64(n))
+	}
 	return
 }
 
@@ -72,6 +77,9 @@ func (s *CryptConn) Read(b []byte) (n int, err error) {
 	}
 	copy(b, rb)
 	n = len(rb)
+	if s.rate != nil {
+		s.rate.Get(int64(n))
+	}
 	return
 }
 
@@ -79,13 +87,15 @@ type SnappyConn struct {
 	w     *snappy.Writer
 	r     *snappy.Reader
 	crypt bool
+	rate  *Rate
 }
 
-func NewSnappyConn(conn net.Conn, crypt bool) *SnappyConn {
+func NewSnappyConn(conn net.Conn, crypt bool, rate *Rate) *SnappyConn {
 	c := new(SnappyConn)
 	c.w = snappy.NewBufferedWriter(conn)
 	c.r = snappy.NewReader(conn)
 	c.crypt = crypt
+	c.rate = rate
 	return c
 }
 
@@ -101,7 +111,12 @@ func (s *SnappyConn) Write(b []byte) (n int, err error) {
 	if _, err = s.w.Write(b); err != nil {
 		return
 	}
-	err = s.w.Flush()
+	if err = s.w.Flush(); err != nil {
+		return
+	}
+	if s.rate != nil {
+		s.rate.Get(int64(n))
+	}
 	return
 }
 
@@ -129,6 +144,9 @@ func (s *SnappyConn) Read(b []byte) (n int, err error) {
 	}
 	n = len(bs)
 	copy(b, bs)
+	if s.rate != nil {
+		s.rate.Get(int64(n))
+	}
 	return
 }
 
@@ -233,6 +251,10 @@ func (s *Conn) SetAlive() {
 	conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
 }
 
+func (s *Conn) SetReadDeadline(t time.Duration) {
+	s.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
+}
+
 //从tcp报文中解析出host,连接类型等 TODO 多种情况
 func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
 	var b [32 * 1024]byte
@@ -264,19 +286,19 @@ func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.
 }
 
 //单独读(加密|压缩)
-func (s *Conn) ReadFrom(b []byte, compress int, crypt bool) (int, error) {
+func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *Rate) (int, error) {
 	if COMPRESS_SNAPY_DECODE == compress {
-		return NewSnappyConn(s.Conn, crypt).Read(b)
+		return NewSnappyConn(s.Conn, crypt, rate).Read(b)
 	}
-	return NewCryptConn(s.Conn, crypt).Read(b)
+	return NewCryptConn(s.Conn, crypt, rate).Read(b)
 }
 
 //单独写(加密|压缩)
-func (s *Conn) WriteTo(b []byte, compress int, crypt bool) (n int, err error) {
+func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *Rate) (n int, err error) {
 	if COMPRESS_SNAPY_ENCODE == compress {
-		return NewSnappyConn(s.Conn, crypt).Write(b)
+		return NewSnappyConn(s.Conn, crypt, rate).Write(b)
 	}
-	return NewCryptConn(s.Conn, crypt).Write(b)
+	return NewCryptConn(s.Conn, crypt, rate).Write(b)
 }
 
 //写压缩方式,加密
@@ -322,6 +344,11 @@ func (s *Conn) WriteSign() (int, error) {
 	return s.Write([]byte(RES_SIGN))
 }
 
+//write sign flag
+func (s *Conn) WriteClose() (int, error) {
+	return s.Write([]byte(RES_CLOSE))
+}
+
 //write main
 func (s *Conn) WriteMain() (int, error) {
 	return s.Write([]byte(WORK_MAIN))

+ 23 - 11
utils/file.go

@@ -8,6 +8,7 @@ import (
 	"log"
 	"os"
 	"strconv"
+	"strings"
 	"sync"
 )
 
@@ -19,6 +20,7 @@ var (
 type Flow struct {
 	ExportFlow int64 //出口流量
 	InletFlow  int64 //入口流量
+	FlowLimit  int64 //流量限制,出口+入口 /M
 }
 
 type Client struct {
@@ -29,7 +31,9 @@ type Client struct {
 	Remark    string //备注
 	Status    bool   //是否开启
 	IsConnect bool   //是否连接
-	Flow      *Flow
+	RateLimit int    //速度限制 /kb
+	Flow      *Flow  //流量
+	Rate      *Rate  //速度控制
 }
 
 type Tunnel struct {
@@ -189,7 +193,9 @@ func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
 	defer s.Unlock()
 	for _, v := range s.Clients {
 		if utils.Getverifyval(v.VerifyKey) == vKey && v.Status {
-			v.Addr = addr
+			if arr := strings.Split(addr, ":"); len(arr) > 0 {
+				v.Addr = arr[0]
+			}
 			return v.Id, nil
 		}
 	}
@@ -276,21 +282,26 @@ func (s *Csv) LoadClientFromCsv() {
 		post := &Client{
 			Id:        GetIntNoErrByStr(item[0]),
 			VerifyKey: item[1],
-			Addr:      item[2],
-			Remark:    item[3],
-			Status:    GetBoolByStr(item[4]),
+			Remark:    item[2],
+			Status:    GetBoolByStr(item[3]),
+			RateLimit: GetIntNoErrByStr(item[9]),
 			Cnf: &Config{
-				U:        item[5],
-				P:        item[6],
-				Crypt:    GetBoolByStr(item[7]),
-				Mux:      GetBoolByStr(item[8]),
-				Compress: item[9],
+				U:        item[4],
+				P:        item[5],
+				Crypt:    GetBoolByStr(item[6]),
+				Mux:      GetBoolByStr(item[7]),
+				Compress: item[8],
 			},
 		}
 		if post.Id > s.ClientIncreaseId {
 			s.ClientIncreaseId = post.Id
 		}
+		if post.RateLimit > 0 {
+			post.Rate = NewRate(int64(post.RateLimit * 1024))
+			post.Rate.Start()
+		}
 		post.Flow = new(Flow)
+		post.Flow.FlowLimit = int64(utils.GetIntNoerrByStr(item[10]))
 		clients = append(clients, post)
 	}
 	s.Clients = clients
@@ -442,7 +453,6 @@ func (s *Csv) StoreClientsToCsv() {
 		record := []string{
 			strconv.Itoa(client.Id),
 			client.VerifyKey,
-			client.Addr,
 			client.Remark,
 			strconv.FormatBool(client.Status),
 			client.Cnf.U,
@@ -450,6 +460,8 @@ func (s *Csv) StoreClientsToCsv() {
 			utils.GetStrByBool(client.Cnf.Crypt),
 			utils.GetStrByBool(client.Cnf.Mux),
 			client.Cnf.Compress,
+			strconv.Itoa(client.RateLimit),
+			strconv.Itoa(int(client.Flow.FlowLimit)),
 		}
 		err := writer.Write(record)
 		if err != nil {

+ 1 - 0
utils/pool.go

@@ -12,6 +12,7 @@ var bufPool = sync.Pool{
 		return make([]byte, poolSize)
 	},
 }
+
 var BufPoolUdp = sync.Pool{
 	New: func() interface{} {
 		return make([]byte, poolSizeUdp)

+ 74 - 0
utils/rate.go

@@ -0,0 +1,74 @@
+package utils
+
+import (
+	"sync/atomic"
+	"time"
+)
+
+type Rate struct {
+	bucketSize        int64     //木桶容量
+	bucketSurplusSize int64     //当前桶中体积
+	bucketAddSize     int64     //每次加水大小
+	stopChan          chan bool //停止
+}
+
+func NewRate(addSize int64) *Rate {
+	return &Rate{
+		bucketSize:        addSize * 2,
+		bucketSurplusSize: 0,
+		bucketAddSize:     addSize,
+		stopChan:          make(chan bool),
+	}
+}
+
+func (s *Rate) Start() {
+	go s.session()
+}
+
+func (s *Rate) add(size int64) {
+	if (s.bucketSize - s.bucketSurplusSize) < s.bucketAddSize {
+		return
+	}
+	atomic.AddInt64(&s.bucketSurplusSize, size)
+}
+
+//回桶
+func (s *Rate) ReturnBucket(size int64) {
+	s.add(size)
+}
+
+//停止
+func (s *Rate) Stop() {
+	s.stopChan <- true
+}
+
+func (s *Rate) Get(size int64) {
+	if s.bucketSurplusSize >= size {
+		atomic.AddInt64(&s.bucketSurplusSize, -size)
+		return
+	}
+	ticker := time.NewTicker(time.Millisecond * 100)
+	for {
+		select {
+		case <-ticker.C:
+			if s.bucketSurplusSize >= size {
+				atomic.AddInt64(&s.bucketSurplusSize, -size)
+				ticker.Stop()
+				return
+			}
+		}
+	}
+}
+
+func (s *Rate) session() {
+	ticker := time.NewTicker(time.Second * 1)
+	for {
+		select {
+		case <-ticker.C:
+			s.add(s.bucketAddSize)
+		case <-s.stopChan:
+			ticker.Stop()
+			return
+		}
+	}
+}

+ 23 - 0
utils/rate_test.go

@@ -0,0 +1,23 @@
+package utils
+
+import (
+	"log"
+	"testing"
+)
+
+var rate = NewRate(100 * 1024)
+
+func TestRate_Get(t *testing.T) {
+	rate.Start()
+	for i := 0; i < 5; i++ {
+		go test(i)
+	}
+	test(5)
+}
+
+func test(i int) {
+	for {
+		rate.Get(64 * 1024)
+		log.Println("get ok", i)
+	}
+}

+ 11 - 10
utils/util.go

@@ -25,6 +25,7 @@ const (
 	WORK_CHAN         = "chan"
 	RES_SIGN          = "sign"
 	RES_MSG           = "msg0"
+	RES_CLOSE         = "clse"
 	CONN_SUCCESS      = "sucs"
 	CONN_ERROR        = "fail"
 	TEST_FLAG         = "tst"
@@ -42,24 +43,24 @@ WWW-Authenticate: Basic realm="easyProxy"
 )
 
 //copy
-func Relay(in, out net.Conn, compressType int, crypt, mux bool) (n int64, err error) {
+func Relay(in, out net.Conn, compressType int, crypt, mux bool, rate *Rate) (n int64, err error) {
 	switch compressType {
 	case COMPRESS_SNAPY_ENCODE:
-		n, err = copyBuffer(NewSnappyConn(in, crypt), out)
+		n, err = copyBuffer(NewSnappyConn(in, crypt, rate), out)
 		out.Close()
-		NewSnappyConn(in, crypt).Write([]byte(IO_EOF))
+		NewSnappyConn(in, crypt, rate).Write([]byte(IO_EOF))
 	case COMPRESS_SNAPY_DECODE:
-		n, err = copyBuffer(in, NewSnappyConn(out, crypt))
+		n, err = copyBuffer(in, NewSnappyConn(out, crypt, rate))
 		in.Close()
 		if !mux {
 			out.Close()
 		}
 	case COMPRESS_NONE_ENCODE:
-		n, err = copyBuffer(NewCryptConn(in, crypt), out)
+		n, err = copyBuffer(NewCryptConn(in, crypt, rate), out)
 		out.Close()
-		NewCryptConn(in, crypt).Write([]byte(IO_EOF))
+		NewCryptConn(in, crypt, rate).Write([]byte(IO_EOF))
 	case COMPRESS_NONE_DECODE:
-		n, err = copyBuffer(in, NewCryptConn(out, crypt))
+		n, err = copyBuffer(in, NewCryptConn(out, crypt, rate))
 		in.Close()
 		if !mux {
 			out.Close()
@@ -205,14 +206,14 @@ func Getverifyval(vkey string) string {
 
 //wait replay group
 //conn1 网桥 conn2
-func ReplayWaitGroup(conn1 net.Conn, conn2 net.Conn, compressEncode, compressDecode int, crypt, mux bool) (out int64, in int64) {
+func ReplayWaitGroup(conn1 net.Conn, conn2 net.Conn, compressEncode, compressDecode int, crypt, mux bool, rate *Rate) (out int64, in int64) {
 	var wg sync.WaitGroup
 	wg.Add(1)
 	go func() {
-		in, _ = Relay(conn1, conn2, compressEncode, crypt, mux)
+		in, _ = Relay(conn1, conn2, compressEncode, crypt, mux, rate)
 		wg.Done()
 	}()
-	out, _ = Relay(conn2, conn1, compressDecode, crypt, mux)
+	out, _ = Relay(conn2, conn1, compressDecode, crypt, mux, rate)
 	wg.Wait()
 	return
 }

+ 21 - 0
web/controllers/client.go

@@ -40,6 +40,16 @@ func (s *ClientController) Add() {
 				Crypt:    s.GetBoolNoErr("crypt"),
 				Mux:      s.GetBoolNoErr("mux"),
 			},
+			RateLimit: s.GetIntNoErr("rate_limit"),
+			Flow: &utils.Flow{
+				ExportFlow: 0,
+				InletFlow:  0,
+				FlowLimit:  int64(s.GetIntNoErr("flow_limit")),
+			},
+		}
+		if t.RateLimit > 0 {
+			t.Rate = utils.NewRate(int64(t.RateLimit * 1024))
+			t.Rate.Start()
 		}
 		server.CsvDb.NewClient(t)
 		s.AjaxOk("添加成功")
@@ -69,6 +79,17 @@ func (s *ClientController) Edit() {
 			c.Cnf.Compress = s.GetString("compress")
 			c.Cnf.Crypt = s.GetBoolNoErr("crypt")
 			c.Cnf.Mux = s.GetBoolNoErr("mux")
+			c.Flow.FlowLimit = int64(s.GetIntNoErr("flow_limit"))
+			c.RateLimit = s.GetIntNoErr("rate_limit")
+			if c.Rate != nil {
+				c.Rate.Stop()
+			}
+			if c.RateLimit > 0 {
+				c.Rate = utils.NewRate(int64(c.RateLimit * 1024))
+				c.Rate.Start()
+			} else {
+				c.Rate = nil
+			}
 			server.CsvDb.UpdateClient(c)
 		}
 		s.AjaxOk("修改成功")

+ 13 - 4
web/controllers/index.go

@@ -87,6 +87,7 @@ func (s *IndexController) Add() {
 			UseClientCnf: s.GetBoolNoErr("use_client"),
 			Status:       true,
 			Remark:       s.GetString("remark"),
+			Flow:         &utils.Flow{},
 		}
 		var err error
 		if t.Client, err = server.CsvDb.GetClient(s.GetIntNoErr("client_id")); err != nil {
@@ -127,6 +128,9 @@ func (s *IndexController) Edit() {
 			t.Config.Mux = s.GetBoolNoErr("mux")
 			t.UseClientCnf = s.GetBoolNoErr("use_client")
 			t.Remark = s.GetString("remark")
+			if t.Client, err = server.CsvDb.GetClient(s.GetIntNoErr("client_id")); err != nil {
+				s.AjaxErr("修改失败")
+			}
 			server.CsvDb.UpdateTask(t)
 		}
 		s.AjaxOk("修改成功")
@@ -187,14 +191,16 @@ func (s *IndexController) AddHost() {
 		s.display("index/hadd")
 	} else {
 		h := &utils.Host{
-			Client: &utils.Client{
-				Id: s.GetIntNoErr("client_id"),
-			},
 			Host:         s.GetString("host"),
 			Target:       s.GetString("target"),
 			HeaderChange: s.GetString("header"),
 			HostChange:   s.GetString("hostchange"),
 			Remark:       s.GetString("remark"),
+			Flow:         &utils.Flow{},
+		}
+		var err error
+		if h.Client, err = server.CsvDb.GetClient(s.GetIntNoErr("client_id")); err != nil {
+			s.AjaxErr("添加失败")
 		}
 		server.CsvDb.NewHost(h)
 		s.AjaxOk("添加成功")
@@ -216,13 +222,16 @@ func (s *IndexController) EditHost() {
 		if h, err := server.GetInfoByHost(host); err != nil {
 			s.error()
 		} else {
-			h.Client.Id = s.GetIntNoErr("client_id")
 			h.Host = s.GetString("nhost")
 			h.Target = s.GetString("target")
 			h.HeaderChange = s.GetString("header")
 			h.HostChange = s.GetString("hostchange")
 			h.Remark = s.GetString("remark")
 			server.CsvDb.UpdateHost(h)
+			var err error
+			if h.Client, err = server.CsvDb.GetClient(s.GetIntNoErr("client_id")); err != nil {
+				s.AjaxErr("修改失败")
+			}
 		}
 		s.AjaxOk("修改成功")
 	}

+ 8 - 0
web/views/client/add.html

@@ -8,6 +8,14 @@
                         <label class="control-label">备注</label>
                         <input class="form-control" type="text" name="Remark" placeholder="客户端备注">
                     </div>
+                    <div class="form-group" id="flow_limit">
+                        <label class="control-label">流量限制(单位:M,为空不限制)</label>
+                        <input class="form-control" type="text" name="flow_limit" placeholder="为空不限制">
+                    </div>
+                    <div class="form-group" id="rate_limit">
+                        <label class="control-label">速度限制(单位:KB,为空不限制)</label>
+                        <input class="form-control" type="text" name="rate_limit" placeholder="为空不限制">
+                    </div>
                     <div class="form-group" id="u">
                         <label class="control-label">验证用户名(仅socks5,web穿透支持)</label>
                         <input class="form-control" type="text" name="u" placeholder="不填则无需验证">

+ 10 - 0
web/views/client/edit.html

@@ -9,6 +9,16 @@
                         <label class="control-label">备注</label>
                         <input class="form-control" value="{{.c.Remark}}" type="text" name="Remark" placeholder="客户端备注">
                     </div>
+                    <div class="form-group" id="flow_limit">
+                        <label class="control-label">流量限制(单位:M,为空不限制)</label>
+                        <input class="form-control" value="{{.c.Flow.FlowLimit}}" type="text" name="flow_limit"
+                               placeholder="为空不限制">
+                    </div>
+                    <div class="form-group" id="rate_limit">
+                        <label class="control-label">速度限制(单位:KB,为空不限制)</label>
+                        <input class="form-control" value="{{.c.RateLimit}}" type="text" name="rate_limit"
+                               placeholder="为空不限制">
+                    </div>
                     <div class="form-group" id="u">
                         <label class="control-label">验证用户名(仅socks5,web穿透支持)</label>
                         <input class="form-control" value="{{.c.Cnf.U}}" type="text" name="u"

+ 1 - 1
web/views/index/hlist.html

@@ -99,7 +99,7 @@
                 }
             },
                 {
-                    targets: 1,
+                    targets: 0,
                     render: function (data, type, row, meta) {
                         return row.Client.Id
                     }