file.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package file
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "sync"
  9. "sync/atomic"
  10. "ehang.io/nps/lib/common"
  11. "ehang.io/nps/lib/rate"
  12. )
  13. func NewJsonDb(runPath string) *JsonDb {
  14. return &JsonDb{
  15. RunPath: runPath,
  16. TaskFilePath: filepath.Join(runPath, "conf", "tasks.json"),
  17. HostFilePath: filepath.Join(runPath, "conf", "hosts.json"),
  18. ClientFilePath: filepath.Join(runPath, "conf", "clients.json"),
  19. }
  20. }
  21. type JsonDb struct {
  22. Tasks sync.Map
  23. Hosts sync.Map
  24. HostsTmp sync.Map
  25. Clients sync.Map
  26. RunPath string
  27. ClientIncreaseId int32 //client increased id
  28. TaskIncreaseId int32 //task increased id
  29. HostIncreaseId int32 //host increased id
  30. TaskFilePath string //task file path
  31. HostFilePath string //host file path
  32. ClientFilePath string //client file path
  33. }
  34. func (s *JsonDb) LoadTaskFromJsonFile() {
  35. loadSyncMapFromFile(s.TaskFilePath, func(v string) {
  36. var err error
  37. post := new(Tunnel)
  38. if json.Unmarshal([]byte(v), &post) != nil {
  39. return
  40. }
  41. if post.Client, err = s.GetClient(post.Client.Id); err != nil {
  42. return
  43. }
  44. s.Tasks.Store(post.Id, post)
  45. if post.Id > int(s.TaskIncreaseId) {
  46. s.TaskIncreaseId = int32(post.Id)
  47. }
  48. })
  49. }
  50. func (s *JsonDb) LoadClientFromJsonFile() {
  51. loadSyncMapFromFile(s.ClientFilePath, func(v string) {
  52. post := new(Client)
  53. if json.Unmarshal([]byte(v), &post) != nil {
  54. return
  55. }
  56. if post.RateLimit > 0 {
  57. post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
  58. } else {
  59. post.Rate = rate.NewRate(int64(2 << 23))
  60. }
  61. post.Rate.Start()
  62. post.NowConn = 0
  63. s.Clients.Store(post.Id, post)
  64. if post.Id > int(s.ClientIncreaseId) {
  65. s.ClientIncreaseId = int32(post.Id)
  66. }
  67. })
  68. }
  69. func (s *JsonDb) LoadHostFromJsonFile() {
  70. loadSyncMapFromFile(s.HostFilePath, func(v string) {
  71. var err error
  72. post := new(Host)
  73. if json.Unmarshal([]byte(v), &post) != nil {
  74. return
  75. }
  76. if post.Client, err = s.GetClient(post.Client.Id); err != nil {
  77. return
  78. }
  79. s.Hosts.Store(post.Id, post)
  80. if post.Id > int(s.HostIncreaseId) {
  81. s.HostIncreaseId = int32(post.Id)
  82. }
  83. })
  84. }
  85. func (s *JsonDb) GetClient(id int) (c *Client, err error) {
  86. if v, ok := s.Clients.Load(id); ok {
  87. c = v.(*Client)
  88. return
  89. }
  90. err = errors.New("未找到客户端")
  91. return
  92. }
  93. func (s *JsonDb) StoreHostToJsonFile() {
  94. storeSyncMapToFile(s.Hosts, s.HostFilePath)
  95. }
  96. func (s *JsonDb) StoreTasksToJsonFile() {
  97. storeSyncMapToFile(s.Tasks, s.TaskFilePath)
  98. }
  99. func (s *JsonDb) StoreClientsToJsonFile() {
  100. storeSyncMapToFile(s.Clients, s.ClientFilePath)
  101. }
  102. func (s *JsonDb) GetClientId() int32 {
  103. return atomic.AddInt32(&s.ClientIncreaseId, 1)
  104. }
  105. func (s *JsonDb) GetTaskId() int32 {
  106. return atomic.AddInt32(&s.TaskIncreaseId, 1)
  107. }
  108. func (s *JsonDb) GetHostId() int32 {
  109. return atomic.AddInt32(&s.HostIncreaseId, 1)
  110. }
  111. func loadSyncMapFromFile(filePath string, f func(value string)) {
  112. b, err := common.ReadAllFromFile(filePath)
  113. if err != nil {
  114. panic(err)
  115. }
  116. for _, v := range strings.Split(string(b), "\n"+common.CONN_DATA_SEQ) {
  117. f(v)
  118. }
  119. }
  120. func storeSyncMapToFile(m sync.Map, filePath string) {
  121. file, err := os.Create(filePath)
  122. if err != nil {
  123. panic(err)
  124. }
  125. defer file.Close()
  126. m.Range(func(key, value interface{}) bool {
  127. var b []byte
  128. var err error
  129. switch value.(type) {
  130. case *Tunnel:
  131. obj := value.(*Tunnel)
  132. if obj.NoStore {
  133. return true
  134. }
  135. b, err = json.Marshal(obj)
  136. case *Host:
  137. obj := value.(*Host)
  138. if obj.NoStore {
  139. return true
  140. }
  141. b, err = json.Marshal(obj)
  142. case *Client:
  143. obj := value.(*Client)
  144. if obj.NoStore {
  145. return true
  146. }
  147. b, err = json.Marshal(obj)
  148. default:
  149. return true
  150. }
  151. if err != nil {
  152. return true
  153. }
  154. _, err = file.Write(b)
  155. if err != nil {
  156. panic(err)
  157. }
  158. _, err = file.Write([]byte("\n" + common.CONN_DATA_SEQ))
  159. if err != nil {
  160. panic(err)
  161. }
  162. return true
  163. })
  164. }