base.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package proxy
  2. import (
  3. "errors"
  4. "github.com/cnlh/nps/bridge"
  5. "github.com/cnlh/nps/lib/common"
  6. "github.com/cnlh/nps/lib/conn"
  7. "github.com/cnlh/nps/lib/file"
  8. "github.com/cnlh/nps/lib/pool"
  9. "net"
  10. "net/http"
  11. "sync"
  12. )
  // Service is a proxy server with a start/stop lifecycle.
type Service interface {
	// Start launches the service, returning an error if it cannot run.
	Start() error
	// Close stops the service, returning an error if shutdown fails.
	Close() error
}
  17. //Server BaseServer struct
  18. type BaseServer struct {
  19. id int
  20. bridge *bridge.Bridge
  21. task *file.Tunnel
  22. errorContent []byte
  23. sync.Mutex
  24. }
  25. func NewBaseServer(bridge *bridge.Bridge, task *file.Tunnel) *BaseServer {
  26. return &BaseServer{
  27. bridge: bridge,
  28. task: task,
  29. errorContent: nil,
  30. Mutex: sync.Mutex{},
  31. }
  32. }
  33. func (s *BaseServer) FlowAdd(in, out int64) {
  34. s.Lock()
  35. defer s.Unlock()
  36. s.task.Flow.ExportFlow += out
  37. s.task.Flow.InletFlow += in
  38. }
  39. func (s *BaseServer) FlowAddHost(host *file.Host, in, out int64) {
  40. s.Lock()
  41. defer s.Unlock()
  42. host.Flow.ExportFlow += out
  43. host.Flow.InletFlow += in
  44. }
  45. func (s *BaseServer) linkCopy(link *conn.Link, c *conn.Conn, rb []byte, tunnel *conn.Conn, flow *file.Flow) {
  46. if rb != nil {
  47. if _, err := tunnel.SendMsg(rb, link); err != nil {
  48. c.Close()
  49. return
  50. }
  51. flow.Add(len(rb), 0)
  52. <-link.StatusCh
  53. }
  54. buf := pool.BufPoolCopy.Get().([]byte)
  55. for {
  56. if err := s.checkFlow(); err != nil {
  57. c.Close()
  58. break
  59. }
  60. if n, err := c.Read(buf); err != nil {
  61. tunnel.SendMsg([]byte(common.IO_EOF), link)
  62. break
  63. } else {
  64. if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
  65. c.Close()
  66. break
  67. }
  68. flow.Add(n, 0)
  69. }
  70. <-link.StatusCh
  71. }
  72. s.task.Client.AddConn()
  73. pool.PutBufPoolCopy(buf)
  74. }
  75. func (s *BaseServer) writeConnFail(c net.Conn) {
  76. c.Write([]byte(common.ConnectionFailBytes))
  77. c.Write(s.errorContent)
  78. }
  79. //权限认证
  80. func (s *BaseServer) auth(r *http.Request, c *conn.Conn, u, p string) error {
  81. if u != "" && p != "" && !common.CheckAuth(r, u, p) {
  82. c.Write([]byte(common.UnauthorizedBytes))
  83. c.Close()
  84. return errors.New("401 Unauthorized")
  85. }
  86. return nil
  87. }
  88. func (s *BaseServer) checkFlow() error {
  89. if s.task.Client.Flow.FlowLimit > 0 && (s.task.Client.Flow.FlowLimit<<20) < (s.task.Client.Flow.ExportFlow+s.task.Client.Flow.InletFlow) {
  90. return errors.New("Traffic exceeded")
  91. }
  92. return nil
  93. }
  94. //与客户端建立通道
  95. func (s *BaseServer) DealClient(c *conn.Conn, addr string, rb []byte) error {
  96. link := conn.NewLink(s.task.Client.GetId(), common.CONN_TCP, addr, s.task.Client.Cnf.CompressEncode, s.task.Client.Cnf.CompressDecode, s.task.Client.Cnf.Crypt, c, s.task.Flow, nil, s.task.Client.Rate, nil)
  97. if tunnel, err := s.bridge.SendLinkInfo(s.task.Client.Id, link, c.Conn.RemoteAddr().String()); err != nil {
  98. c.Close()
  99. return err
  100. } else {
  101. link.Run(true)
  102. s.linkCopy(link, c, rb, tunnel, s.task.Flow)
  103. }
  104. return nil
  105. }