base.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package proxy
  2. import (
  3. "errors"
  4. "github.com/cnlh/nps/bridge"
  5. "github.com/cnlh/nps/lib/common"
  6. "github.com/cnlh/nps/lib/conn"
  7. "github.com/cnlh/nps/lib/file"
  8. "github.com/cnlh/nps/lib/pool"
  9. "net"
  10. "net/http"
  11. "sync"
  12. )
  13. //server base struct
  14. type server struct {
  15. id int
  16. bridge *bridge.Bridge
  17. task *file.Tunnel
  18. errorContent []byte
  19. sync.Mutex
  20. }
  21. func (s *server) FlowAdd(in, out int64) {
  22. s.Lock()
  23. defer s.Unlock()
  24. s.task.Flow.ExportFlow += out
  25. s.task.Flow.InletFlow += in
  26. }
  27. func (s *server) FlowAddHost(host *file.Host, in, out int64) {
  28. s.Lock()
  29. defer s.Unlock()
  30. host.Flow.ExportFlow += out
  31. host.Flow.InletFlow += in
  32. }
  33. func (s *server) linkCopy(link *conn.Link, c *conn.Conn, rb []byte, tunnel *conn.Conn, flow *file.Flow) {
  34. if rb != nil {
  35. if _, err := tunnel.SendMsg(rb, link); err != nil {
  36. c.Close()
  37. return
  38. }
  39. flow.Add(len(rb), 0)
  40. }
  41. buf := pool.BufPoolCopy.Get().([]byte)
  42. for {
  43. if err := s.checkFlow(); err != nil {
  44. c.Close()
  45. break
  46. }
  47. if n, err := c.Read(buf); err != nil {
  48. tunnel.SendMsg([]byte(common.IO_EOF), link)
  49. break
  50. } else {
  51. if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
  52. c.Close()
  53. break
  54. }
  55. flow.Add(n, 0)
  56. }
  57. <-link.StatusCh
  58. }
  59. pool.PutBufPoolCopy(buf)
  60. }
  61. func (s *server) writeConnFail(c net.Conn) {
  62. c.Write([]byte(common.ConnectionFailBytes))
  63. c.Write(s.errorContent)
  64. }
  65. //权限认证
  66. func (s *server) auth(r *http.Request, c *conn.Conn, u, p string) error {
  67. if u != "" && p != "" && !common.CheckAuth(r, u, p) {
  68. c.Write([]byte(common.UnauthorizedBytes))
  69. c.Close()
  70. return errors.New("401 Unauthorized")
  71. }
  72. return nil
  73. }
  74. func (s *server) checkFlow() error {
  75. if s.task.Client.Flow.FlowLimit > 0 && (s.task.Client.Flow.FlowLimit<<20) < (s.task.Client.Flow.ExportFlow+s.task.Client.Flow.InletFlow) {
  76. return errors.New("Traffic exceeded")
  77. }
  78. return nil
  79. }