فهرست منبع

socks5 udp support

刘河 5 سال پیش
والد
کامیت
927038fd4c
3فایلهای تغییر یافته به همراه385 افزوده شده و 18 حذف شده
  1. 60 0
      client/client.go
  2. 208 1
      lib/common/netpackager.go
  3. 117 17
      server/proxy/socks5.go

+ 60 - 0
client/client.go

@@ -2,6 +2,7 @@ package client
 
 import (
 	"bufio"
+	"bytes"
 	"net"
 	"net/http"
 	"strconv"
@@ -189,6 +190,10 @@ func (s *TRPClient) handleChan(src net.Conn) {
 		}
 		return
 	}
+	if lk.ConnType == "udp" {
+		logs.Trace("new %s connection with the goal of %s, remote address:%s", lk.ConnType, lk.Host, lk.RemoteAddr)
+		s.handleUdp(src)
+	}
 	//connect to target if conn type is tcp or udp
 	if targetConn, err := net.DialTimeout(lk.ConnType, lk.Host, lk.Option.Timeout); err != nil {
 		logs.Warn("connect to %s error %s", lk.Host, err.Error())
@@ -199,6 +204,61 @@ func (s *TRPClient) handleChan(src net.Conn) {
 	}
 }
 
+func (s *TRPClient) handleUdp(serverConn net.Conn) {
+	// bind a local udp port
+	local, err := net.ListenUDP("udp", nil)
+	defer local.Close()
+	defer serverConn.Close()
+	if err != nil {
+		logs.Error("bind local udp port error ", err.Error())
+		return
+	}
+	go func() {
+		defer serverConn.Close()
+		b := common.BufPoolUdp.Get().([]byte)
+		defer common.BufPoolUdp.Put(b)
+		for {
+			n, raddr, err := local.ReadFrom(b)
+			if err != nil {
+				logs.Error("read data from remote server error", err.Error())
+			}
+			buf := bytes.Buffer{}
+			dgram := common.NewUDPDatagram(common.NewUDPHeader(0, 0, common.ToSocksAddr(raddr)), b[:n])
+			dgram.Write(&buf)
+			if _, err := serverConn.Write(buf.Bytes()); err != nil {
+				logs.Error("write data to remote  error", err.Error())
+				return
+			}
+		}
+	}()
+	b := common.BufPoolUdp.Get().([]byte)
+	defer common.BufPoolUdp.Put(b)
+	for {
+		n, err := serverConn.Read(b)
+		if err != nil {
+			logs.Error("read udp data from server error ", err.Error())
+			return
+		}
+
+		udpData, err := common.ReadUDPDatagram(bytes.NewReader(b[:n]))
+		if err != nil {
+			logs.Error("unpack data error", err.Error())
+			return
+		}
+
+		raddr, err := net.ResolveUDPAddr("udp", udpData.Header.Addr.String())
+		if err != nil {
+			logs.Error("build remote addr err", err.Error())
+			continue // drop silently
+		}
+		_, err = local.WriteTo(udpData.Data, raddr)
+		if err != nil {
+			logs.Error("write data to remote ", raddr.String(), "error", err.Error())
+			return
+		}
+	}
+}
+
 // Whether the monitor channel is closed
 func (s *TRPClient) ping() {
 	s.ticker = time.NewTicker(time.Second * 5)

+ 208 - 1
lib/common/netpackager.go

@@ -6,6 +6,9 @@ import (
 	"encoding/json"
 	"errors"
 	"io"
+	"io/ioutil"
+	"net"
+	"strconv"
 	"strings"
 )
 
@@ -119,7 +122,8 @@ func (Self *BasePackager) Split() (strList []string) {
 	return
 }
 
-type ConnPackager struct { // Todo
+type ConnPackager struct {
+	// Todo
 	ConnType uint8
 	BasePackager
 }
@@ -233,3 +237,206 @@ func (Self *MuxPackager) UnPack(reader io.Reader) (n uint16, err error) {
 	n += 5 //uint8 int32
 	return
 }
+
+const (
+	ipV4       = 1
+	domainName = 3
+	ipV6       = 4
+)
+
+type UDPHeader struct {
+	Rsv  uint16
+	Frag uint8
+	Addr *Addr
+}
+
+func NewUDPHeader(rsv uint16, frag uint8, addr *Addr) *UDPHeader {
+	return &UDPHeader{
+		Rsv:  rsv,
+		Frag: frag,
+		Addr: addr,
+	}
+}
+
+type Addr struct {
+	Type uint8
+	Host string
+	Port uint16
+}
+
+func (addr *Addr) String() string {
+	return net.JoinHostPort(addr.Host, strconv.Itoa(int(addr.Port)))
+}
+
+func (addr *Addr) Decode(b []byte) error {
+	addr.Type = b[0]
+	pos := 1
+	switch addr.Type {
+	case ipV4:
+		addr.Host = net.IP(b[pos:pos+net.IPv4len]).String()
+		pos += net.IPv4len
+	case ipV6:
+		addr.Host = net.IP(b[pos:pos+net.IPv6len]).String()
+		pos += net.IPv6len
+	case domainName:
+		addrlen := int(b[pos])
+		pos++
+		addr.Host = string(b[pos : pos+addrlen])
+		pos += addrlen
+	default:
+		return errors.New("decode error")
+	}
+
+	addr.Port = binary.BigEndian.Uint16(b[pos:])
+
+	return nil
+}
+
+func (addr *Addr) Encode(b []byte) (int, error) {
+	b[0] = addr.Type
+	pos := 1
+	switch addr.Type {
+	case ipV4:
+		ip4 := net.ParseIP(addr.Host).To4()
+		if ip4 == nil {
+			ip4 = net.IPv4zero.To4()
+		}
+		pos += copy(b[pos:], ip4)
+	case domainName:
+		b[pos] = byte(len(addr.Host))
+		pos++
+		pos += copy(b[pos:], []byte(addr.Host))
+	case ipV6:
+		ip16 := net.ParseIP(addr.Host).To16()
+		if ip16 == nil {
+			ip16 = net.IPv6zero.To16()
+		}
+		pos += copy(b[pos:], ip16)
+	default:
+		b[0] = ipV4
+		copy(b[pos:pos+4], net.IPv4zero.To4())
+		pos += 4
+	}
+	binary.BigEndian.PutUint16(b[pos:], addr.Port)
+	pos += 2
+
+	return pos, nil
+}
+
+func (h *UDPHeader) Write(w io.Writer) error {
+	b := BufPoolUdp.Get().([]byte)
+	defer BufPoolUdp.Put(b)
+
+	binary.BigEndian.PutUint16(b[:2], h.Rsv)
+	b[2] = h.Frag
+
+	addr := h.Addr
+	if addr == nil {
+		addr = &Addr{}
+	}
+	length, _ := addr.Encode(b[3:])
+
+	_, err := w.Write(b[:3+length])
+	return err
+}
+
+type UDPDatagram struct {
+	Header *UDPHeader
+	Data   []byte
+}
+
+func ReadUDPDatagram(r io.Reader) (*UDPDatagram, error) {
+	b := BufPoolUdp.Get().([]byte)
+	defer BufPoolUdp.Put(b)
+
+	// when r is a streaming (such as TCP connection), we may read more than the required data,
+	// but we don't know how to handle it. So we use io.ReadFull to instead of io.ReadAtLeast
+	// to make sure that no redundant data will be discarded.
+	n, err := io.ReadFull(r, b[:5])
+	if err != nil {
+		return nil, err
+	}
+
+	header := &UDPHeader{
+		Rsv:  binary.BigEndian.Uint16(b[:2]),
+		Frag: b[2],
+	}
+
+	atype := b[3]
+	hlen := 0
+	switch atype {
+	case ipV4:
+		hlen = 10
+	case ipV6:
+		hlen = 22
+	case domainName:
+		hlen = 7 + int(b[4])
+	default:
+		return nil, errors.New("addr not support")
+	}
+	dlen := int(header.Rsv)
+	if dlen == 0 { // standard SOCKS5 UDP datagram
+		extra, err := ioutil.ReadAll(r) // we assume no redundant data
+		if err != nil {
+			return nil, err
+		}
+		copy(b[n:], extra)
+		n += len(extra) // total length
+		dlen = n - hlen // data length
+	} else { // extended feature, for UDP over TCP, using reserved field as data length
+		if _, err := io.ReadFull(r, b[n:hlen+dlen]); err != nil {
+			return nil, err
+		}
+		n = hlen + dlen
+	}
+	header.Addr = new(Addr)
+	if err := header.Addr.Decode(b[3:hlen]); err != nil {
+		return nil, err
+	}
+	data := make([]byte, dlen)
+	copy(data, b[hlen:n])
+	d := &UDPDatagram{
+		Header: header,
+		Data:   data,
+	}
+	return d, nil
+}
+
+func NewUDPDatagram(header *UDPHeader, data []byte) *UDPDatagram {
+	return &UDPDatagram{
+		Header: header,
+		Data:   data,
+	}
+}
+
+func (d *UDPDatagram) Write(w io.Writer) error {
+	h := d.Header
+	if h == nil {
+		h = &UDPHeader{}
+	}
+	buf := bytes.Buffer{}
+	if err := h.Write(&buf); err != nil {
+		return err
+	}
+	if _, err := buf.Write(d.Data); err != nil {
+		return err
+	}
+
+	_, err := buf.WriteTo(w)
+	return err
+}
+
+func ToSocksAddr(addr net.Addr) *Addr {
+	host := "0.0.0.0"
+	port := 0
+	if addr != nil {
+		h, p, _ := net.SplitHostPort(addr.String())
+		host = h
+		port, _ = strconv.Atoi(p)
+	}
+	return &Addr{
+		Type: ipV4,
+		Host: host,
+		Port: uint16(port),
+	}
+}

+ 117 - 17
server/proxy/socks5.go

@@ -3,6 +3,7 @@ package proxy
 import (
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
 	"net"
 	"strconv"
@@ -154,27 +155,126 @@ func (s *Sock5ModeServer) handleConnect(c net.Conn) {
 // passive mode
 func (s *Sock5ModeServer) handleBind(c net.Conn) {
 }
+func (s *Sock5ModeServer) sendUdpReply(writeConn net.Conn, c net.Conn, rep uint8, serverIp string) {
+	reply := []byte{
+		5,
+		rep,
+		0,
+		1,
+	}
+	localHost, localPort, _ := net.SplitHostPort(c.LocalAddr().String())
+	localHost = serverIp
+	ipBytes := net.ParseIP(localHost).To4()
+	nPort, _ := strconv.Atoi(localPort)
+	reply = append(reply, ipBytes...)
+	portBytes := make([]byte, 2)
+	binary.BigEndian.PutUint16(portBytes, uint16(nPort))
+	reply = append(reply, portBytes...)
+	writeConn.Write(reply)
+
+}
 
-//udp
 func (s *Sock5ModeServer) handleUDP(c net.Conn) {
-	/*
-	   +----+------+------+----------+----------+----------+
-	   |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
-	   +----+------+------+----------+----------+----------+
-	   | 2  |  1   |  1   | Variable |    2     | Variable |
-	   +----+------+------+----------+----------+----------+
-	*/
-	buf := make([]byte, 3)
-	c.Read(buf)
-	// relay udp datagram silently, without any notification to the requesting client
-	if buf[2] != 0 {
-		// does not support fragmentation, drop it
-		logs.Warn("does not support fragmentation, drop")
-		dummy := make([]byte, maxUDPPacketSize)
-		c.Read(dummy)
+	defer c.Close()
+	addrType := make([]byte, 1)
+	c.Read(addrType)
+	var host string
+	switch addrType[0] {
+	case ipV4:
+		ipv4 := make(net.IP, net.IPv4len)
+		c.Read(ipv4)
+		host = ipv4.String()
+	case ipV6:
+		ipv6 := make(net.IP, net.IPv6len)
+		c.Read(ipv6)
+		host = ipv6.String()
+	case domainName:
+		var domainLen uint8
+		binary.Read(c, binary.BigEndian, &domainLen)
+		domain := make([]byte, domainLen)
+		c.Read(domain)
+		host = string(domain)
+	default:
+		s.sendReply(c, addrTypeNotSupported)
+		return
+	}
+	//读取端口
+	var port uint16
+	binary.Read(c, binary.BigEndian, &port)
+	fmt.Println(host, string(port))
+	replyAddr, err := net.ResolveUDPAddr("udp", s.task.ServerIp+":0")
+	if err != nil {
+		logs.Error("build local reply addr error", err)
+		return
+	}
+	reply, err := net.ListenUDP("udp", replyAddr)
+	if err != nil {
+		s.sendReply(c, addrTypeNotSupported)
+		logs.Error("listen local reply udp port error")
+		return
+	}
+
+	// reply the local addr
+	s.sendUdpReply(c, reply, succeeded, "106.12.146.199")
+	defer reply.Close()
+
+	// new a tunnel to client
+	link := conn.NewLink("udp", "", s.task.Client.Cnf.Crypt, s.task.Client.Cnf.Compress, c.RemoteAddr().String(), false)
+	target, err := s.bridge.SendLinkInfo(s.task.Client.Id, link, s.task)
+	if err != nil {
+		logs.Warn("get connection from client id %d  error %s", s.task.Client.Id, err.Error())
+		return
 	}
 
-	s.doConnect(c, associateMethod)
+	var clientAddr net.Addr
+	// copy buffer
+	go func() {
+		b := common.BufPoolUdp.Get().([]byte)
+		defer common.BufPoolUdp.Put(b)
+		defer c.Close()
+
+		for {
+			n, laddr, err := reply.ReadFrom(b)
+			if err != nil {
+				logs.Error("read data from %s err %s", reply.LocalAddr().String(), err.Error())
+				return
+			}
+			if clientAddr == nil {
+				clientAddr = laddr
+			}
+			if _, err := target.Write(b[:n]); err != nil {
+				logs.Error("write data to client error", err.Error())
+				return
+			}
+		}
+	}()
+
+	go func() {
+		b := common.BufPoolUdp.Get().([]byte)
+		defer common.BufPoolUdp.Put(b)
+		defer c.Close()
+		for {
+			n, err := target.Read(b)
+			if err != nil {
+				logs.Warn("read data form client error", err.Error())
+				return
+			}
+			if _, err := reply.WriteTo(b[:n], clientAddr); err != nil {
+				logs.Warn("write data to user ", err.Error())
+				return
+			}
+		}
+	}()
+
+	b := common.BufPoolUdp.Get().([]byte)
+	defer common.BufPoolUdp.Put(b)
+	for {
+		_, err := c.Read(b)
+		if err != nil {
+			c.Close()
+			return
+		}
+	}
 }
 
 //new conn