base.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package proxy
  2. import (
  3. "errors"
  4. "github.com/cnlh/nps/bridge"
  5. "github.com/cnlh/nps/lib/common"
  6. "github.com/cnlh/nps/lib/conn"
  7. "github.com/cnlh/nps/lib/file"
  8. "github.com/cnlh/nps/lib/pool"
  9. "net"
  10. "net/http"
  11. "sync"
  12. )
  // Service describes a component with a start/stop lifecycle.
  //
  // Start brings the component up and Close tears it down; each reports
  // failure through its error result.
  type Service interface {
  	Start() error
  	Close() error
  }
  17. //server base struct
  18. type server struct {
  19. id int
  20. bridge *bridge.Bridge
  21. task *file.Tunnel
  22. errorContent []byte
  23. sync.Mutex
  24. }
  25. func (s *server) FlowAdd(in, out int64) {
  26. s.Lock()
  27. defer s.Unlock()
  28. s.task.Flow.ExportFlow += out
  29. s.task.Flow.InletFlow += in
  30. }
  31. func (s *server) FlowAddHost(host *file.Host, in, out int64) {
  32. s.Lock()
  33. defer s.Unlock()
  34. host.Flow.ExportFlow += out
  35. host.Flow.InletFlow += in
  36. }
  37. func (s *server) linkCopy(link *conn.Link, c *conn.Conn, rb []byte, tunnel *conn.Conn, flow *file.Flow) {
  38. if rb != nil {
  39. if _, err := tunnel.SendMsg(rb, link); err != nil {
  40. c.Close()
  41. return
  42. }
  43. flow.Add(len(rb), 0)
  44. }
  45. buf := pool.BufPoolCopy.Get().([]byte)
  46. for {
  47. if err := s.checkFlow(); err != nil {
  48. c.Close()
  49. break
  50. }
  51. if n, err := c.Read(buf); err != nil {
  52. tunnel.SendMsg([]byte(common.IO_EOF), link)
  53. break
  54. } else {
  55. if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
  56. c.Close()
  57. break
  58. }
  59. flow.Add(n, 0)
  60. }
  61. <-link.StatusCh
  62. }
  63. pool.PutBufPoolCopy(buf)
  64. }
  65. func (s *server) writeConnFail(c net.Conn) {
  66. c.Write([]byte(common.ConnectionFailBytes))
  67. c.Write(s.errorContent)
  68. }
  69. //权限认证
  70. func (s *server) auth(r *http.Request, c *conn.Conn, u, p string) error {
  71. if u != "" && p != "" && !common.CheckAuth(r, u, p) {
  72. c.Write([]byte(common.UnauthorizedBytes))
  73. c.Close()
  74. return errors.New("401 Unauthorized")
  75. }
  76. return nil
  77. }
  78. func (s *server) checkFlow() error {
  79. if s.task.Client.Flow.FlowLimit > 0 && (s.task.Client.Flow.FlowLimit<<20) < (s.task.Client.Flow.ExportFlow+s.task.Client.Flow.InletFlow) {
  80. return errors.New("Traffic exceeded")
  81. }
  82. return nil
  83. }