base.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package server
  2. import (
  3. "errors"
  4. "github.com/cnlh/nps/bridge"
  5. "github.com/cnlh/nps/lib/common"
  6. "github.com/cnlh/nps/lib/conn"
  7. "github.com/cnlh/nps/lib/file"
  8. "github.com/cnlh/nps/lib/pool"
  9. "net"
  10. "net/http"
  11. "sync"
  12. )
  // server is the base struct shared by this package's server implementations
  // (its embedding sites are not visible in this file — presumably the
  // concrete TCP/HTTP/UDP servers; confirm).
  type server struct {
  	// id identifies this server instance; its meaning is not established here.
  	id int
  	// bridge is the bridge this server forwards traffic through
  	// (NOTE(review): not used in this file — confirm against callers).
  	bridge *bridge.Bridge
  	// task is the tunnel definition: FlowAdd updates task.Flow, and
  	// ResetConfig reloads settings for task.Id and reads task.Client.
  	task *file.Tunnel
  	// config holds the active auth (U/P), Compress and Crypt settings,
  	// refreshed by ResetConfig.
  	config *file.Config
  	// errorContent is written after common.ConnectionFailBytes by writeConnFail.
  	errorContent []byte
  	// Mutex serializes the flow-counter updates in FlowAdd and FlowAddHost.
  	sync.Mutex
  }
  22. func (s *server) FlowAdd(in, out int64) {
  23. s.Lock()
  24. defer s.Unlock()
  25. s.task.Flow.ExportFlow += out
  26. s.task.Flow.InletFlow += in
  27. }
  28. func (s *server) FlowAddHost(host *file.Host, in, out int64) {
  29. s.Lock()
  30. defer s.Unlock()
  31. host.Flow.ExportFlow += out
  32. host.Flow.InletFlow += in
  33. }
  34. //热更新配置
  35. func (s *server) ResetConfig() bool {
  36. //获取最新数据
  37. task, err := file.GetCsvDb().GetTask(s.task.Id)
  38. if err != nil {
  39. return false
  40. }
  41. if s.task.Client.Flow.FlowLimit > 0 && (s.task.Client.Flow.FlowLimit<<20) < (s.task.Client.Flow.ExportFlow+s.task.Client.Flow.InletFlow) {
  42. return false
  43. }
  44. s.task.UseClientCnf = task.UseClientCnf
  45. //使用客户端配置
  46. client, err := file.GetCsvDb().GetClient(s.task.Client.Id)
  47. if s.task.UseClientCnf {
  48. if err == nil {
  49. s.config.U = client.Cnf.U
  50. s.config.P = client.Cnf.P
  51. s.config.Compress = client.Cnf.Compress
  52. s.config.Crypt = client.Cnf.Crypt
  53. }
  54. } else {
  55. if err == nil {
  56. s.config.U = task.Config.U
  57. s.config.P = task.Config.P
  58. s.config.Compress = task.Config.Compress
  59. s.config.Crypt = task.Config.Crypt
  60. }
  61. }
  62. s.task.Client.Rate = client.Rate
  63. s.config.CompressDecode, s.config.CompressEncode = common.GetCompressType(s.config.Compress)
  64. return true
  65. }
  66. func (s *server) linkCopy(link *conn.Link, c *conn.Conn, rb []byte, tunnel *conn.Conn, flow *file.Flow) {
  67. if rb != nil {
  68. if _, err := tunnel.SendMsg(rb, link); err != nil {
  69. c.Close()
  70. return
  71. }
  72. flow.Add(len(rb), 0)
  73. }
  74. buf := pool.BufPoolCopy.Get().([]byte)
  75. for {
  76. if n, err := c.Read(buf); err != nil {
  77. tunnel.SendMsg([]byte(common.IO_EOF), link)
  78. break
  79. } else {
  80. if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
  81. c.Close()
  82. break
  83. }
  84. flow.Add(n, 0)
  85. }
  86. }
  87. pool.PutBufPoolCopy(buf)
  88. }
  89. func (s *server) writeConnFail(c net.Conn) {
  90. c.Write([]byte(common.ConnectionFailBytes))
  91. c.Write(s.errorContent)
  92. }
  93. //权限认证
  94. func (s *server) auth(r *http.Request, c *conn.Conn, u, p string) error {
  95. if u != "" && p != "" && !common.CheckAuth(r, u, p) {
  96. c.Write([]byte(common.UnauthorizedBytes))
  97. c.Close()
  98. return errors.New("401 Unauthorized")
  99. }
  100. return nil
  101. }