|
@@ -13,6 +13,7 @@ import (
|
|
|
"net/url"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
@@ -52,12 +53,6 @@ func (s *CryptConn) Write(b []byte) (n int, err error) {
|
|
|
|
|
|
//解密读
|
|
|
func (s *CryptConn) Read(b []byte) (n int, err error) {
|
|
|
- defer func() {
|
|
|
- if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
|
|
|
- err = io.EOF
|
|
|
- n = 0
|
|
|
- }
|
|
|
- }()
|
|
|
var lens int
|
|
|
var buf []byte
|
|
|
var rb []byte
|
|
@@ -122,14 +117,8 @@ func (s *SnappyConn) Write(b []byte) (n int, err error) {
|
|
|
|
|
|
//snappy压缩读 包含解密
|
|
|
func (s *SnappyConn) Read(b []byte) (n int, err error) {
|
|
|
- buf := bufPool.Get().([]byte)
|
|
|
- defer func() {
|
|
|
- if err == nil && n == len(IO_EOF) && string(b[:n]) == IO_EOF {
|
|
|
- err = io.EOF
|
|
|
- n = 0
|
|
|
- }
|
|
|
- bufPool.Put(buf)
|
|
|
- }()
|
|
|
+ buf := BufPool.Get().([]byte)
|
|
|
+ defer BufPool.Put(buf)
|
|
|
if n, err = s.r.Read(buf); err != nil {
|
|
|
return
|
|
|
}
|
|
@@ -152,6 +141,7 @@ func (s *SnappyConn) Read(b []byte) (n int, err error) {
|
|
|
|
|
|
type Conn struct {
|
|
|
Conn net.Conn
|
|
|
+ sync.Mutex
|
|
|
}
|
|
|
|
|
|
//new conn
|
|
@@ -161,6 +151,36 @@ func NewConn(conn net.Conn) *Conn {
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
+//从tcp报文中解析出host,连接类型等
|
|
|
+func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
|
|
|
+ var b [32 * 1024]byte
|
|
|
+ var n int
|
|
|
+ if n, err = s.Read(b[:]); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ rb = b[:n]
|
|
|
+ r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ hostPortURL, err := url.Parse(r.Host)
|
|
|
+ if err != nil {
|
|
|
+ address = r.Host
|
|
|
+ err = nil
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if hostPortURL.Opaque == "443" { //https访问
|
|
|
+ address = r.Host + ":443"
|
|
|
+ } else { //http访问
|
|
|
+ if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
|
|
|
+ address = r.Host + ":80"
|
|
|
+ } else {
|
|
|
+ address = r.Host
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
//读取指定长度内容
|
|
|
func (s *Conn) ReadLen(cLen int) ([]byte, error) {
|
|
|
if cLen > poolSize {
|
|
@@ -168,11 +188,11 @@ func (s *Conn) ReadLen(cLen int) ([]byte, error) {
|
|
|
}
|
|
|
var buf []byte
|
|
|
if cLen <= poolSizeSmall {
|
|
|
- buf = bufPoolSmall.Get().([]byte)[:cLen]
|
|
|
- defer bufPoolSmall.Put(buf)
|
|
|
+ buf = BufPoolSmall.Get().([]byte)[:cLen]
|
|
|
+ defer BufPoolSmall.Put(buf)
|
|
|
} else {
|
|
|
- buf = bufPoolMax.Get().([]byte)[:cLen]
|
|
|
- defer bufPoolMax.Put(buf)
|
|
|
+ buf = BufPoolMax.Get().([]byte)[:cLen]
|
|
|
+ defer BufPoolMax.Put(buf)
|
|
|
}
|
|
|
if n, err := io.ReadFull(s, buf); err != nil || n != cLen {
|
|
|
return buf, errors.New("读取指定长度错误" + err.Error())
|
|
@@ -180,7 +200,7 @@ func (s *Conn) ReadLen(cLen int) ([]byte, error) {
|
|
|
return buf, nil
|
|
|
}
|
|
|
|
|
|
-//获取长度
|
|
|
+//read length or id (content length=4)
|
|
|
func (s *Conn) GetLen() (int, error) {
|
|
|
val, err := s.ReadLen(4)
|
|
|
if err != nil {
|
|
@@ -189,17 +209,7 @@ func (s *Conn) GetLen() (int, error) {
|
|
|
return GetLenByBytes(val)
|
|
|
}
|
|
|
|
|
|
-//写入长度+内容 粘包
|
|
|
-func (s *Conn) WriteLen(buf []byte) (int, error) {
|
|
|
- var b []byte
|
|
|
- var err error
|
|
|
- if b, err = GetLenBytes(buf); err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- return s.Write(b)
|
|
|
-}
|
|
|
-
|
|
|
-//读取flag
|
|
|
+//read flag
|
|
|
func (s *Conn) ReadFlag() (string, error) {
|
|
|
val, err := s.ReadLen(4)
|
|
|
if err != nil {
|
|
@@ -208,41 +218,21 @@ func (s *Conn) ReadFlag() (string, error) {
|
|
|
return string(val), err
|
|
|
}
|
|
|
|
|
|
-//读取host 连接地址 压缩类型
|
|
|
-func (s *Conn) GetHostFromConn() (typeStr string, host string, en, de int, crypt, mux bool, err error) {
|
|
|
-retry:
|
|
|
- lType, err := s.ReadLen(3)
|
|
|
+//read connect status
|
|
|
+func (s *Conn) GetConnStatus() (id int, status bool, err error) {
|
|
|
+ id, err = s.GetLen()
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
- if typeStr = string(lType); typeStr == TEST_FLAG {
|
|
|
- en, de, crypt, mux = s.GetConnInfoFromConn()
|
|
|
- goto retry
|
|
|
- } else if typeStr != CONN_TCP && typeStr != CONN_UDP {
|
|
|
- err = errors.New("unknown conn type")
|
|
|
- return
|
|
|
- }
|
|
|
- cLen, err := s.GetLen()
|
|
|
- if err != nil || cLen > poolSize {
|
|
|
- return
|
|
|
- }
|
|
|
- hostByte, err := s.ReadLen(cLen)
|
|
|
- if err != nil {
|
|
|
+ var b []byte
|
|
|
+ if b, err = s.ReadLen(1); err != nil {
|
|
|
return
|
|
|
+ } else {
|
|
|
+ status = GetBoolByStr(string(b[0]))
|
|
|
}
|
|
|
- host = string(hostByte)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-//写连接类型 和 host地址
|
|
|
-func (s *Conn) WriteHost(ltype string, host string) (int, error) {
|
|
|
- raw := bytes.NewBuffer([]byte{})
|
|
|
- binary.Write(raw, binary.LittleEndian, []byte(ltype))
|
|
|
- binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
|
|
|
- binary.Write(raw, binary.LittleEndian, []byte(host))
|
|
|
- return s.Write(raw.Bytes())
|
|
|
-}
|
|
|
-
|
|
|
//设置连接为长连接
|
|
|
func (s *Conn) SetAlive() {
|
|
|
conn := s.Conn.(*net.TCPConn)
|
|
@@ -251,40 +241,11 @@ func (s *Conn) SetAlive() {
|
|
|
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
|
|
|
}
|
|
|
|
|
|
+//set read dead time
|
|
|
func (s *Conn) SetReadDeadline(t time.Duration) {
|
|
|
s.Conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(t) * time.Second))
|
|
|
}
|
|
|
|
|
|
-//从tcp报文中解析出host,连接类型等 TODO 多种情况
|
|
|
-func (s *Conn) GetHost() (method, address string, rb []byte, err error, r *http.Request) {
|
|
|
- var b [32 * 1024]byte
|
|
|
- var n int
|
|
|
- if n, err = s.Read(b[:]); err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
- rb = b[:n]
|
|
|
- r, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(rb)))
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
- hostPortURL, err := url.Parse(r.Host)
|
|
|
- if err != nil {
|
|
|
- address = r.Host
|
|
|
- err = nil
|
|
|
- return
|
|
|
- }
|
|
|
- if hostPortURL.Opaque == "443" { //https访问
|
|
|
- address = r.Host + ":443"
|
|
|
- } else { //http访问
|
|
|
- if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80
|
|
|
- address = r.Host + ":80"
|
|
|
- } else {
|
|
|
- address = r.Host
|
|
|
- }
|
|
|
- }
|
|
|
- return
|
|
|
-}
|
|
|
-
|
|
|
//单独读(加密|压缩)
|
|
|
func (s *Conn) ReadFrom(b []byte, compress int, crypt bool, rate *Rate) (int, error) {
|
|
|
if COMPRESS_SNAPY_DECODE == compress {
|
|
@@ -301,24 +262,112 @@ func (s *Conn) WriteTo(b []byte, compress int, crypt bool, rate *Rate) (n int, e
|
|
|
return NewCryptConn(s.Conn, crypt, rate).Write(b)
|
|
|
}
|
|
|
|
|
|
-//写压缩方式,加密
|
|
|
-func (s *Conn) WriteConnInfo(en, de int, crypt, mux bool) {
|
|
|
- s.Write([]byte(strconv.Itoa(en) + strconv.Itoa(de) + GetStrByBool(crypt) + GetStrByBool(mux)))
|
|
|
+//send msg
|
|
|
+func (s *Conn) SendMsg(content []byte, link *Link) (n int, err error) {
|
|
|
+ /*
|
|
|
+ The msg info is formed as follows:
|
|
|
+ +----+--------+
|
|
|
+ |id | content |
|
|
|
+ +----+--------+
|
|
|
+ | 4 | ... |
|
|
|
+ +----+--------+
|
|
|
+*/
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
+ raw := bytes.NewBuffer([]byte{})
|
|
|
+ binary.Write(raw, binary.LittleEndian, int32(link.Id))
|
|
|
+ if n, err = s.Write(raw.Bytes()); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ raw.Reset()
|
|
|
+ binary.Write(raw, binary.LittleEndian, content)
|
|
|
+ n, err = s.WriteTo(raw.Bytes(), link.En, link.Crypt, link.Rate)
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
-//获取压缩方式,是否加密
|
|
|
-func (s *Conn) GetConnInfoFromConn() (en, de int, crypt, mux bool) {
|
|
|
- buf, err := s.ReadLen(4)
|
|
|
- if err != nil {
|
|
|
+//get msg content from conn
|
|
|
+func (s *Conn) GetMsgContent(link *Link) (content []byte, err error) {
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
+ buf := BufPoolCopy.Get().([]byte)
|
|
|
+ if n, err := s.ReadFrom(buf, link.De, link.Crypt, link.Rate); err == nil && n > 4 {
|
|
|
+ content = buf[:n]
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+//send info for link
|
|
|
+func (s *Conn) SendLinkInfo(link *Link) (int, error) {
|
|
|
+ /*
|
|
|
+ The link info is formed as follows:
|
|
|
+ +----------+------+----------+------+----------+-----+
|
|
|
+ | id | len | type | hostlen | host | en | de |crypt |
|
|
|
+ +----------+------+----------+------+---------+------+
|
|
|
+ | 4 | 4 | 3 | 4 | host | 1 | 1 | 1 |
|
|
|
+ +----------+------+----------+------+----+----+------+
|
|
|
+ */
|
|
|
+ raw := bytes.NewBuffer([]byte{})
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte(NEW_CONN))
|
|
|
+ binary.Write(raw, binary.LittleEndian, int32(14+len(link.Host)))
|
|
|
+ binary.Write(raw, binary.LittleEndian, int32(link.Id))
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte(link.ConnType))
|
|
|
+ binary.Write(raw, binary.LittleEndian, int32(len(link.Host)))
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte(link.Host))
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.En)))
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte(strconv.Itoa(link.De)))
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte(GetStrByBool(link.Crypt)))
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
+ return s.Write(raw.Bytes())
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Conn) GetLinkInfo() (link *Link, err error) {
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
+ var hostLen, n int
|
|
|
+ var buf []byte
|
|
|
+ if n, err = s.GetLen(); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ link = new(Link)
|
|
|
+ if buf, err = s.ReadLen(n); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if link.Id, err = GetLenByBytes(buf[:4]); err != nil {
|
|
|
return
|
|
|
}
|
|
|
- en, _ = strconv.Atoi(string(buf[0]))
|
|
|
- de, _ = strconv.Atoi(string(buf[1]))
|
|
|
- crypt = GetBoolByStr(string(buf[2]))
|
|
|
- mux = GetBoolByStr(string(buf[3]))
|
|
|
+ link.ConnType = string(buf[4:7])
|
|
|
+ if hostLen, err = GetLenByBytes(buf[7:11]); err != nil {
|
|
|
+ return
|
|
|
+ } else {
|
|
|
+ link.Host = string(buf[11 : 11+hostLen])
|
|
|
+ link.En = GetIntNoErrByStr(string(buf[11+hostLen]))
|
|
|
+ link.De = GetIntNoErrByStr(string(buf[12+hostLen]))
|
|
|
+ link.Crypt = GetBoolByStr(string(buf[13+hostLen]))
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+//write connect success
|
|
|
+func (s *Conn) WriteSuccess(id int) (int, error) {
|
|
|
+ raw := bytes.NewBuffer([]byte{})
|
|
|
+ binary.Write(raw, binary.LittleEndian, int32(id))
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte("1"))
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
+ return s.Write(raw.Bytes())
|
|
|
+}
|
|
|
+
|
|
|
+//write connect fail
|
|
|
+func (s *Conn) WriteFail(id int) (int, error) {
|
|
|
+ raw := bytes.NewBuffer([]byte{})
|
|
|
+ binary.Write(raw, binary.LittleEndian, int32(id))
|
|
|
+ binary.Write(raw, binary.LittleEndian, []byte("0"))
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
+ return s.Write(raw.Bytes())
|
|
|
+}
|
|
|
+
|
|
|
//close
|
|
|
func (s *Conn) Close() error {
|
|
|
return s.Conn.Close()
|
|
@@ -351,29 +400,18 @@ func (s *Conn) WriteClose() (int, error) {
|
|
|
|
|
|
//write main
|
|
|
func (s *Conn) WriteMain() (int, error) {
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
return s.Write([]byte(WORK_MAIN))
|
|
|
}
|
|
|
|
|
|
//write chan
|
|
|
func (s *Conn) WriteChan() (int, error) {
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
return s.Write([]byte(WORK_CHAN))
|
|
|
}
|
|
|
|
|
|
-//write test
|
|
|
-func (s *Conn) WriteTest() (int, error) {
|
|
|
- return s.Write([]byte(TEST_FLAG))
|
|
|
-}
|
|
|
-
|
|
|
-//write test
|
|
|
-func (s *Conn) WriteSuccess() (int, error) {
|
|
|
- return s.Write([]byte(CONN_SUCCESS))
|
|
|
-}
|
|
|
-
|
|
|
-//write test
|
|
|
-func (s *Conn) WriteFail() (int, error) {
|
|
|
- return s.Write([]byte(CONN_ERROR))
|
|
|
-}
|
|
|
-
|
|
|
//获取长度+内容
|
|
|
func GetLenBytes(buf []byte) (b []byte, err error) {
|
|
|
raw := bytes.NewBuffer([]byte{})
|