base.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. package server
  2. import (
  3. "errors"
  4. "github.com/cnlh/nps/bridge"
  5. "github.com/cnlh/nps/lib"
  6. "net"
  7. "net/http"
  8. "sync"
  9. )
  10. //server base struct
  11. type server struct {
  12. id int
  13. bridge *bridge.Bridge
  14. task *lib.Tunnel
  15. config *lib.Config
  16. errorContent []byte
  17. sync.Mutex
  18. }
  19. func (s *server) FlowAdd(in, out int64) {
  20. s.Lock()
  21. defer s.Unlock()
  22. s.task.Flow.ExportFlow += out
  23. s.task.Flow.InletFlow += in
  24. }
  25. func (s *server) FlowAddHost(host *lib.Host, in, out int64) {
  26. s.Lock()
  27. defer s.Unlock()
  28. host.Flow.ExportFlow += out
  29. host.Flow.InletFlow += in
  30. }
  31. //热更新配置
  32. func (s *server) ResetConfig() bool {
  33. //获取最新数据
  34. task, err := lib.GetCsvDb().GetTask(s.task.Id)
  35. if err != nil {
  36. return false
  37. }
  38. if s.task.Client.Flow.FlowLimit > 0 && (s.task.Client.Flow.FlowLimit<<20) < (s.task.Client.Flow.ExportFlow+s.task.Client.Flow.InletFlow) {
  39. return false
  40. }
  41. s.task.UseClientCnf = task.UseClientCnf
  42. //使用客户端配置
  43. client, err := lib.GetCsvDb().GetClient(s.task.Client.Id)
  44. if s.task.UseClientCnf {
  45. if err == nil {
  46. s.config.U = client.Cnf.U
  47. s.config.P = client.Cnf.P
  48. s.config.Compress = client.Cnf.Compress
  49. s.config.Crypt = client.Cnf.Crypt
  50. }
  51. } else {
  52. if err == nil {
  53. s.config.U = task.Config.U
  54. s.config.P = task.Config.P
  55. s.config.Compress = task.Config.Compress
  56. s.config.Crypt = task.Config.Crypt
  57. }
  58. }
  59. s.task.Client.Rate = client.Rate
  60. s.config.CompressDecode, s.config.CompressEncode = lib.GetCompressType(s.config.Compress)
  61. return true
  62. }
  63. func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Conn, flow *lib.Flow) {
  64. if rb != nil {
  65. if _, err := tunnel.SendMsg(rb, link); err != nil {
  66. c.Close()
  67. return
  68. }
  69. flow.Add(len(rb), 0)
  70. }
  71. for {
  72. buf := lib.BufPoolCopy.Get().([]byte)
  73. if n, err := c.Read(buf); err != nil {
  74. tunnel.SendMsg([]byte(lib.IO_EOF), link)
  75. break
  76. } else {
  77. if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
  78. lib.PutBufPoolCopy(buf)
  79. c.Close()
  80. break
  81. }
  82. lib.PutBufPoolCopy(buf)
  83. flow.Add(n, 0)
  84. }
  85. }
  86. }
  87. func (s *server) writeConnFail(c net.Conn) {
  88. c.Write([]byte(lib.ConnectionFailBytes))
  89. c.Write(s.errorContent)
  90. }
  91. //权限认证
  92. func (s *server) auth(r *http.Request, c *lib.Conn, u, p string) error {
  93. if u != "" && p != "" && !lib.CheckAuth(r, u, p) {
  94. c.Write([]byte(lib.UnauthorizedBytes))
  95. c.Close()
  96. return errors.New("401 Unauthorized")
  97. }
  98. return nil
  99. }