Selaa lähdekoodia

Udp 多路复用 优化

刘河 6 vuotta sitten
vanhempi
commit
05e66af647
6 muutettua tiedostoa jossa 43 lisäystä ja 25 poistoa
  1. 2 1
      lib/bridge.go
  2. 1 0
      lib/client.go
  3. 8 4
      lib/crypt.go
  4. 2 2
      lib/tcp.go
  5. 17 18
      lib/udp.go
  6. 13 0
      lib/util.go

+ 2 - 1
lib/bridge.go

@@ -35,7 +35,7 @@ type Tunnel struct {
 	signalList map[string]*list //通信
 	tunnelList map[string]*list //隧道
 	lock       sync.Mutex
-	tunnelLock      sync.Mutex
+	tunnelLock sync.Mutex
 }
 
 func newTunnel(tunnelPort int) *Tunnel {
@@ -181,6 +181,7 @@ func (s *Tunnel) ReturnSignal(conn *Conn, cFlag string) {
 //重回slice 复用
 func (s *Tunnel) ReturnTunnel(conn *Conn, cFlag string) {
 	if v, ok := s.tunnelList[cFlag]; ok {
+		FlushConn(conn.conn)
 		v.Add(conn)
 	}
 }

+ 1 - 0
lib/client.go

@@ -123,6 +123,7 @@ re:
 	relay(c.conn, server, en, crypt, mux)
 end:
 	if mux {
+		FlushConn(conn)
 		goto re
 	} else {
 		c.Close()

+ 8 - 4
lib/crypt.go

@@ -6,6 +6,7 @@ import (
 	"crypto/cipher"
 	"crypto/md5"
 	"encoding/hex"
+	"github.com/pkg/errors"
 	"math/rand"
 	"time"
 )
@@ -38,9 +39,9 @@ func AesDecrypt(crypted, key []byte) ([]byte, error) {
 	origData := make([]byte, len(crypted))
 	// origData := crypted
 	blockMode.CryptBlocks(origData, crypted)
-	origData = PKCS5UnPadding(origData)
+	err, origData = PKCS5UnPadding(origData)
 	// origData = ZeroUnPadding(origData)
-	return origData, nil
+	return origData, err
 }
 
 //补全
@@ -51,11 +52,14 @@ func PKCS5Padding(ciphertext []byte, blockSize int) []byte {
 }
 
 //去补
-func PKCS5UnPadding(origData []byte) []byte {
+func PKCS5UnPadding(origData []byte) (error, []byte) {
 	length := len(origData)
 	// 去掉最后一个字节 unpadding 次
 	unpadding := int(origData[length-1])
-	return origData[:(length - unpadding)]
+	if (length - unpadding) < 0 {
+		return errors.New("len error"), nil
+	}
+	return nil, origData[:(length - unpadding)]
 }
 
 //生成32位md5字串

+ 2 - 2
lib/tcp.go

@@ -190,6 +190,8 @@ func (s *TunnelModeServer) dealClient(c *Conn, cnf *ServerConfig, addr string, m
 	defer func() {
 		if cnf.Mux {
 			s.bridge.ReturnTunnel(link, getverifyval(cnf.VerifyKey))
+		} else {
+			c.Close()
 		}
 	}()
 	if err != nil {
@@ -212,8 +214,6 @@ func (s *TunnelModeServer) dealClient(c *Conn, cnf *ServerConfig, addr string, m
 			}
 			go relay(link.conn, c.conn, cnf.CompressEncode, cnf.Crypt, cnf.Mux)
 			relay(c.conn, link.conn, cnf.CompressDecode, cnf.Crypt, cnf.Mux)
-		} else {
-			c.Close()
 		}
 	}
 	return nil

+ 17 - 18
lib/udp.go

@@ -5,7 +5,6 @@ import (
 	"log"
 	"net"
 	"strings"
-	"time"
 )
 
 type UdpModeServer struct {
@@ -54,25 +53,25 @@ func (s *UdpModeServer) process(addr *net.UDPAddr, data []byte) {
 		conn.Close()
 		return
 	}
-	conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt)
 	if flag, err := conn.ReadFlag(); err == nil {
-		if flag == CONN_SUCCESS {
-			go func(addr *net.UDPAddr, conn *Conn) {
-				defer func() {
-					if s.config.Mux {
-						s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey))
-					}
-				}()
-				buf := make([]byte, 1024)
-				conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3)))
-				n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt)
-				if err != nil || err == io.EOF {
-					conn.Close()
-					return
-				}
-				s.listener.WriteToUDP(buf[:n], addr)
+		defer func() {
+			if s.config.Mux {
+				s.bridge.ReturnTunnel(conn, getverifyval(s.config.VerifyKey))
+			} else {
 				conn.Close()
-			}(addr, conn)
+			}
+		}()
+		if flag == CONN_SUCCESS {
+			conn.WriteTo(data, s.config.CompressEncode, s.config.Crypt)
+			buf := make([]byte, 1024)
+			//conn.conn.SetReadDeadline(time.Now().Add(time.Duration(time.Second * 3)))
+			n, err := conn.ReadFrom(buf, s.config.CompressDecode, s.config.Crypt)
+			if err != nil || err == io.EOF {
+				log.Println("revieve error:", err)
+				return
+			}
+			s.listener.WriteToUDP(buf[:n], addr)
+			conn.WriteTo([]byte(IO_EOF), s.config.CompressEncode, s.config.Crypt)
 		}
 	}
 }

+ 13 - 0
lib/util.go

@@ -17,6 +17,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 )
 
 var (
@@ -315,3 +316,15 @@ func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
 	}
 	return written, err
 }
+
+//连接重置 清空缓存区
+func FlushConn(c net.Conn) {
+	c.SetReadDeadline(time.Now().Add(time.Second * 3))
+	buf := bufPool.Get().([]byte)
+	for {
+		if _, err := c.Read(buf); err != nil {
+			break
+		}
+	}
+	c.SetReadDeadline(time.Time{})
+}