file.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package file
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "github.com/astaxie/beego/logs"
  6. "os"
  7. "path/filepath"
  8. "strings"
  9. "sync"
  10. "sync/atomic"
  11. "ehang.io/nps/lib/common"
  12. "ehang.io/nps/lib/rate"
  13. )
  14. func NewJsonDb(runPath string) *JsonDb {
  15. return &JsonDb{
  16. RunPath: runPath,
  17. TaskFilePath: filepath.Join(runPath, "conf", "tasks.json"),
  18. HostFilePath: filepath.Join(runPath, "conf", "hosts.json"),
  19. ClientFilePath: filepath.Join(runPath, "conf", "clients.json"),
  20. }
  21. }
  22. type JsonDb struct {
  23. Tasks sync.Map
  24. Hosts sync.Map
  25. HostsTmp sync.Map
  26. Clients sync.Map
  27. RunPath string
  28. ClientIncreaseId int32 //client increased id
  29. TaskIncreaseId int32 //task increased id
  30. HostIncreaseId int32 //host increased id
  31. TaskFilePath string //task file path
  32. HostFilePath string //host file path
  33. ClientFilePath string //client file path
  34. }
  35. func (s *JsonDb) LoadTaskFromJsonFile() {
  36. loadSyncMapFromFile(s.TaskFilePath, func(v string) {
  37. var err error
  38. post := new(Tunnel)
  39. if json.Unmarshal([]byte(v), &post) != nil {
  40. return
  41. }
  42. if post.Client, err = s.GetClient(post.Client.Id); err != nil {
  43. return
  44. }
  45. s.Tasks.Store(post.Id, post)
  46. if post.Id > int(s.TaskIncreaseId) {
  47. s.TaskIncreaseId = int32(post.Id)
  48. }
  49. })
  50. }
  51. func (s *JsonDb) LoadClientFromJsonFile() {
  52. loadSyncMapFromFile(s.ClientFilePath, func(v string) {
  53. post := new(Client)
  54. if json.Unmarshal([]byte(v), &post) != nil {
  55. return
  56. }
  57. if post.RateLimit > 0 {
  58. post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
  59. } else {
  60. post.Rate = rate.NewRate(int64(2 << 23))
  61. }
  62. post.Rate.Start()
  63. post.NowConn = 0
  64. s.Clients.Store(post.Id, post)
  65. if post.Id > int(s.ClientIncreaseId) {
  66. s.ClientIncreaseId = int32(post.Id)
  67. }
  68. })
  69. }
  70. func (s *JsonDb) LoadHostFromJsonFile() {
  71. loadSyncMapFromFile(s.HostFilePath, func(v string) {
  72. var err error
  73. post := new(Host)
  74. if json.Unmarshal([]byte(v), &post) != nil {
  75. return
  76. }
  77. if post.Client, err = s.GetClient(post.Client.Id); err != nil {
  78. return
  79. }
  80. s.Hosts.Store(post.Id, post)
  81. if post.Id > int(s.HostIncreaseId) {
  82. s.HostIncreaseId = int32(post.Id)
  83. }
  84. })
  85. }
  86. func (s *JsonDb) GetClient(id int) (c *Client, err error) {
  87. if v, ok := s.Clients.Load(id); ok {
  88. c = v.(*Client)
  89. return
  90. }
  91. err = errors.New("未找到客户端")
  92. return
  93. }
  94. var hostLock sync.Mutex
  95. func (s *JsonDb) StoreHostToJsonFile() {
  96. hostLock.Lock()
  97. storeSyncMapToFile(s.Hosts, s.HostFilePath)
  98. hostLock.Unlock()
  99. }
  100. var taskLock sync.Mutex
  101. func (s *JsonDb) StoreTasksToJsonFile() {
  102. taskLock.Lock()
  103. storeSyncMapToFile(s.Tasks, s.TaskFilePath)
  104. taskLock.Unlock()
  105. }
  106. var clientLock sync.Mutex
  107. func (s *JsonDb) StoreClientsToJsonFile() {
  108. clientLock.Lock()
  109. storeSyncMapToFile(s.Clients, s.ClientFilePath)
  110. clientLock.Unlock()
  111. }
  112. func (s *JsonDb) GetClientId() int32 {
  113. return atomic.AddInt32(&s.ClientIncreaseId, 1)
  114. }
  115. func (s *JsonDb) GetTaskId() int32 {
  116. return atomic.AddInt32(&s.TaskIncreaseId, 1)
  117. }
  118. func (s *JsonDb) GetHostId() int32 {
  119. return atomic.AddInt32(&s.HostIncreaseId, 1)
  120. }
  121. func loadSyncMapFromFile(filePath string, f func(value string)) {
  122. b, err := common.ReadAllFromFile(filePath)
  123. if err != nil {
  124. panic(err)
  125. }
  126. for _, v := range strings.Split(string(b), "\n"+common.CONN_DATA_SEQ) {
  127. f(v)
  128. }
  129. }
  130. func storeSyncMapToFile(m sync.Map, filePath string) {
  131. file, err := os.Create(filePath + ".tmp")
  132. // first create a temporary file to store
  133. if err != nil {
  134. panic(err)
  135. }
  136. m.Range(func(key, value interface{}) bool {
  137. var b []byte
  138. var err error
  139. switch value.(type) {
  140. case *Tunnel:
  141. obj := value.(*Tunnel)
  142. if obj.NoStore {
  143. return true
  144. }
  145. b, err = json.Marshal(obj)
  146. case *Host:
  147. obj := value.(*Host)
  148. if obj.NoStore {
  149. return true
  150. }
  151. b, err = json.Marshal(obj)
  152. case *Client:
  153. obj := value.(*Client)
  154. if obj.NoStore {
  155. return true
  156. }
  157. b, err = json.Marshal(obj)
  158. default:
  159. return true
  160. }
  161. if err != nil {
  162. return true
  163. }
  164. _, err = file.Write(b)
  165. if err != nil {
  166. panic(err)
  167. }
  168. _, err = file.Write([]byte("\n" + common.CONN_DATA_SEQ))
  169. if err != nil {
  170. panic(err)
  171. }
  172. return true
  173. })
  174. _ = file.Sync()
  175. _ = file.Close()
  176. // must close file first, then rename it
  177. err = os.Rename(filePath+".tmp", filePath)
  178. if err != nil {
  179. logs.Error(err, "store to file err, data will lost")
  180. }
  181. // replace the file, maybe provides atomic operation
  182. }