file.go 8.9 KB


  1. package file
  import (
  	"encoding/csv"
  	"errors"
  	"net"
  	"os"
  	"path/filepath"
  	"strconv"
  	"strings"
  	"sync"

  	"github.com/cnlh/nps/lib/common"
  	"github.com/cnlh/nps/lib/lg"
  	"github.com/cnlh/nps/lib/rate"
  )
  14. func NewCsv(runPath string) *Csv {
  15. return &Csv{
  16. RunPath: runPath,
  17. }
  18. }
  19. type Csv struct {
  20. Tasks []*Tunnel
  21. Path string
  22. Hosts []*Host //域名列表
  23. Clients []*Client //客户端
  24. RunPath string //存储根目录
  25. ClientIncreaseId int //客户端id
  26. TaskIncreaseId int //任务自增ID
  27. sync.Mutex
  28. }
  29. func (s *Csv) Init() {
  30. s.LoadClientFromCsv()
  31. s.LoadTaskFromCsv()
  32. s.LoadHostFromCsv()
  33. }
  // StoreTasksToCsv rewrites RunPath/conf/tasks.csv from s.Tasks, one row per
  // task. The column order must stay in sync with LoadTaskFromCsv:
  // TcpPort, Mode, Target, U, P, Compress, Status, Crypt, CompressEncode,
  // CompressDecode, Id, client Id, UseClientCnf, Remark.
  // Any error while creating, writing or flushing the file is fatal.
  func (s *Csv) StoreTasksToCsv() {
  	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv"))
  	if err != nil {
  		// Fatalln rather than Fatalf(err.Error()): the error text must not
  		// be interpreted as a format string.
  		lg.Fatalln(err)
  	}
  	defer csvFile.Close()
  	writer := csv.NewWriter(csvFile)
  	for _, task := range s.Tasks {
  		record := []string{
  			strconv.Itoa(task.TcpPort),
  			task.Mode,
  			task.Target,
  			task.Config.U,
  			task.Config.P,
  			task.Config.Compress,
  			common.GetStrByBool(task.Status),
  			common.GetStrByBool(task.Config.Crypt),
  			strconv.Itoa(task.Config.CompressEncode),
  			strconv.Itoa(task.Config.CompressDecode),
  			strconv.Itoa(task.Id),
  			strconv.Itoa(task.Client.Id),
  			strconv.FormatBool(task.UseClientCnf),
  			task.Remark,
  		}
  		if err := writer.Write(record); err != nil {
  			lg.Fatalln(err)
  		}
  	}
  	// csv.Writer buffers its output; a failure during Flush is only
  	// reported through Error, so check it or the file may be silently
  	// truncated.
  	writer.Flush()
  	if err := writer.Error(); err != nil {
  		lg.Fatalln(err)
  	}
  }
  // openFile reads every record from the CSV file at path. Rows may have
  // differing numbers of fields. An error opening or parsing the file is
  // returned to the caller (it previously panicked on open failure, which
  // made the callers' error handling unreachable).
  func (s *Csv) openFile(path string) ([][]string, error) {
  	file, err := os.Open(path)
  	if err != nil {
  		return nil, err
  	}
  	defer file.Close()
  	reader := csv.NewReader(file)
  	// -1 disables the "every row has as many fields as the first" check.
  	reader.FieldsPerRecord = -1
  	return reader.ReadAll()
  }
  // LoadTaskFromCsv replaces s.Tasks with the rows of RunPath/conf/tasks.csv
  // and raises s.TaskIncreaseId to the largest task ID loaded. Rows whose
  // client ID does not resolve through s.GetClient are skipped, so clients
  // must be loaded first. Failing to read the file is fatal.
  //
  // Columns (written by StoreTasksToCsv): 0 TcpPort, 1 Mode, 2 Target, 3 U,
  // 4 P, 5 Compress, 6 Status, 7 Crypt, 8 CompressEncode, 9 CompressDecode,
  // 10 Id, 11 client Id, 12 UseClientCnf, 13 Remark.
  func (s *Csv) LoadTaskFromCsv() {
  	const taskColumns = 14
  	path := filepath.Join(s.RunPath, "conf", "tasks.csv")
  	records, err := s.openFile(path)
  	if err != nil {
  		lg.Fatalln("配置文件打开错误:", path, err)
  	}
  	var tasks []*Tunnel
  	for _, item := range records {
  		// A truncated or hand-edited row used to panic with index out of
  		// range. Pad it so missing trailing columns read as "".
  		// NOTE(review): assumes the common.Get*ByStr helpers map "" to a
  		// zero value — confirm.
  		if len(item) < taskColumns {
  			padded := make([]string, taskColumns)
  			copy(padded, item)
  			item = padded
  		}
  		post := &Tunnel{
  			TcpPort: common.GetIntNoErrByStr(item[0]),
  			Mode:    item[1],
  			Target:  item[2],
  			Config: &Config{
  				U:              item[3],
  				P:              item[4],
  				Compress:       item[5],
  				Crypt:          common.GetBoolByStr(item[7]),
  				CompressEncode: common.GetIntNoErrByStr(item[8]),
  				CompressDecode: common.GetIntNoErrByStr(item[9]),
  			},
  			Status:       common.GetBoolByStr(item[6]),
  			Id:           common.GetIntNoErrByStr(item[10]),
  			UseClientCnf: common.GetBoolByStr(item[12]),
  			Remark:       item[13],
  		}
  		post.Flow = new(Flow)
  		if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[11])); err != nil {
  			continue
  		}
  		tasks = append(tasks, post)
  		if post.Id > s.TaskIncreaseId {
  			s.TaskIncreaseId = post.Id
  		}
  	}
  	s.Tasks = tasks
  }
  117. func (s *Csv) GetTaskId() int {
  118. s.Lock()
  119. defer s.Unlock()
  120. s.TaskIncreaseId++
  121. return s.TaskIncreaseId
  122. }
  // GetIdByVerifyKey returns the ID of the first enabled client whose
  // VerifyKey, passed through common.Getverifyval, equals vKey. It also
  // records the host part of addr as that client's Addr. If no enabled
  // client matches, it returns 0 and an error. The store's lock is held
  // throughout.
  func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
  	s.Lock()
  	defer s.Unlock()
  	for _, v := range s.Clients {
  		if common.Getverifyval(v.VerifyKey) == vKey && v.Status {
  			// SplitHostPort understands bracketed IPv6 ("[::1]:8024"), which a
  			// plain split on ':' truncated to "[". Input that is not
  			// host:port keeps the old behavior (text before the first ':').
  			if host, _, err := net.SplitHostPort(addr); err == nil {
  				v.Addr = host
  			} else {
  				v.Addr = strings.Split(addr, ":")[0]
  			}
  			return v.Id, nil
  		}
  	}
  	return 0, errors.New("not found")
  }
  136. func (s *Csv) NewTask(t *Tunnel) {
  137. t.Flow = new(Flow)
  138. s.Tasks = append(s.Tasks, t)
  139. s.StoreTasksToCsv()
  140. }
  // UpdateTask replaces the stored task whose Id equals t.Id with t and
  // rewrites tasks.csv. The task keeps its position in s.Tasks. It returns
  // an error if no task has that Id.
  func (s *Csv) UpdateTask(t *Tunnel) error {
  	for i, v := range s.Tasks {
  		if v.Id == t.Id {
  			// Replace in place. Removing and re-appending moved every edited
  			// task to the end of the list and of tasks.csv.
  			s.Tasks[i] = t
  			s.StoreTasksToCsv()
  			return nil
  		}
  	}
  	return errors.New("不存在")
  }
  // DelTask removes the first task with the given ID and rewrites
  // tasks.csv. It returns an error, without touching the file, if no task
  // has that ID.
  func (s *Csv) DelTask(id int) error {
  	idx := -1
  	for i := range s.Tasks {
  		if s.Tasks[i].Id == id {
  			idx = i
  			break
  		}
  	}
  	if idx < 0 {
  		return errors.New("不存在")
  	}
  	s.Tasks = append(s.Tasks[:idx], s.Tasks[idx+1:]...)
  	s.StoreTasksToCsv()
  	return nil
  }
  // GetTask returns the task with the given ID. If none exists it returns
  // a nil task and a non-nil error.
  func (s *Csv) GetTask(id int) (v *Tunnel, err error) {
  	for _, task := range s.Tasks {
  		if task.Id == id {
  			return task, nil
  		}
  	}
  	// Return nil explicitly. Ranging into the named result used to leave v
  	// pointing at the last task scanned, so a miss returned an unrelated task.
  	return nil, errors.New("未找到")
  }
  // StoreHostToCsv rewrites RunPath/conf/hosts.csv from s.Hosts, one row per
  // host in the column order LoadHostFromCsv reads: Host, Target, client Id,
  // HeaderChange, HostChange, Remark. Any error while creating, writing or
  // flushing the file is fatal and logged via lg, like the other Store
  // functions in this file (it previously panicked instead).
  func (s *Csv) StoreHostToCsv() {
  	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv"))
  	if err != nil {
  		lg.Fatalln(err)
  	}
  	defer csvFile.Close()
  	writer := csv.NewWriter(csvFile)
  	for _, host := range s.Hosts {
  		record := []string{
  			host.Host,
  			host.Target,
  			strconv.Itoa(host.Client.Id),
  			host.HeaderChange,
  			host.HostChange,
  			host.Remark,
  		}
  		if err := writer.Write(record); err != nil {
  			lg.Fatalln(err)
  		}
  	}
  	// csv.Writer buffers; a failed Flush is only visible through Error.
  	writer.Flush()
  	if err := writer.Error(); err != nil {
  		lg.Fatalln(err)
  	}
  }
  // LoadClientFromCsv replaces s.Clients with the rows of
  // RunPath/conf/clients.csv and raises s.ClientIncreaseId to the largest
  // client ID loaded. It gives every client a fresh Flow carrying the stored
  // flow limit. Clients with a positive RateLimit get a started rate limiter
  // of RateLimit*1024. Failing to read the file is fatal.
  //
  // Columns (written by StoreClientsToCsv): 0 Id, 1 VerifyKey, 2 Remark,
  // 3 Status, 4 U, 5 P, 6 Crypt, 7 Compress, 8 RateLimit, 9 FlowLimit.
  func (s *Csv) LoadClientFromCsv() {
  	const clientColumns = 10
  	path := filepath.Join(s.RunPath, "conf", "clients.csv")
  	records, err := s.openFile(path)
  	if err != nil {
  		lg.Fatalln("配置文件打开错误:", path, err)
  	}
  	var clients []*Client
  	for _, item := range records {
  		// A short row used to panic with index out of range. Pad it so
  		// missing trailing columns read as "".
  		// NOTE(review): assumes the common.Get*ByStr helpers map "" to a
  		// zero value — confirm.
  		if len(item) < clientColumns {
  			padded := make([]string, clientColumns)
  			copy(padded, item)
  			item = padded
  		}
  		post := &Client{
  			Id:        common.GetIntNoErrByStr(item[0]),
  			VerifyKey: item[1],
  			Remark:    item[2],
  			Status:    common.GetBoolByStr(item[3]),
  			RateLimit: common.GetIntNoErrByStr(item[8]),
  			Cnf: &Config{
  				U:        item[4],
  				P:        item[5],
  				Crypt:    common.GetBoolByStr(item[6]),
  				Compress: item[7],
  			},
  		}
  		if post.Id > s.ClientIncreaseId {
  			s.ClientIncreaseId = post.Id
  		}
  		if post.RateLimit > 0 {
  			// Widen before multiplying so large limits cannot overflow a 32-bit int.
  			post.Rate = rate.NewRate(int64(post.RateLimit) * 1024)
  			post.Rate.Start()
  		}
  		post.Flow = new(Flow)
  		post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9]))
  		clients = append(clients, post)
  	}
  	s.Clients = clients
  }
  // LoadHostFromCsv replaces s.Hosts with the rows of RunPath/conf/hosts.csv,
  // giving each host a fresh Flow. Rows whose client ID does not resolve
  // through s.GetClient are skipped, so clients must be loaded first.
  // Failing to read the file is fatal.
  //
  // Columns (written by StoreHostToCsv): 0 Host, 1 Target, 2 client Id,
  // 3 HeaderChange, 4 HostChange, 5 Remark.
  func (s *Csv) LoadHostFromCsv() {
  	const hostColumns = 6
  	path := filepath.Join(s.RunPath, "conf", "hosts.csv")
  	records, err := s.openFile(path)
  	if err != nil {
  		lg.Fatalln("配置文件打开错误:", path, err)
  	}
  	var hosts []*Host
  	for _, item := range records {
  		// A short row used to panic with index out of range. Pad it so
  		// missing trailing columns read as "".
  		if len(item) < hostColumns {
  			padded := make([]string, hostColumns)
  			copy(padded, item)
  			item = padded
  		}
  		post := &Host{
  			Host:         item[0],
  			Target:       item[1],
  			HeaderChange: item[3],
  			HostChange:   item[4],
  			Remark:       item[5],
  		}
  		if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil {
  			continue
  		}
  		post.Flow = new(Flow)
  		hosts = append(hosts, post)
  	}
  	s.Hosts = hosts
  }
  // DelHost removes the first host entry whose Host name equals host and
  // rewrites hosts.csv. It returns an error, without touching the file, if
  // no entry matches.
  func (s *Csv) DelHost(host string) error {
  	for i := range s.Hosts {
  		if s.Hosts[i].Host != host {
  			continue
  		}
  		s.Hosts = append(s.Hosts[:i], s.Hosts[i+1:]...)
  		s.StoreHostToCsv()
  		return nil
  	}
  	return errors.New("不存在")
  }
  268. func (s *Csv) NewHost(t *Host) {
  269. t.Flow = new(Flow)
  270. s.Hosts = append(s.Hosts, t)
  271. s.StoreHostToCsv()
  272. }
  // UpdateHost replaces the stored entry whose Host name equals t.Host with
  // t and rewrites hosts.csv. The entry keeps its position in s.Hosts. It
  // returns an error if no entry has that name.
  func (s *Csv) UpdateHost(t *Host) error {
  	for i, v := range s.Hosts {
  		if v.Host == t.Host {
  			// Replace in place. Removing and re-appending moved every edited
  			// host to the end of the list, hosts.csv and paged listings.
  			s.Hosts[i] = t
  			s.StoreHostToCsv()
  			return nil
  		}
  	}
  	return errors.New("不存在")
  }
  284. func (s *Csv) GetHost(start, length int, id int) ([]*Host, int) {
  285. list := make([]*Host, 0)
  286. var cnt int
  287. for _, v := range s.Hosts {
  288. if id == 0 || v.Client.Id == id {
  289. cnt++
  290. if start--; start < 0 {
  291. if length--; length > 0 {
  292. list = append(list, v)
  293. }
  294. }
  295. }
  296. }
  297. return list, cnt
  298. }
  // DelClient removes the client with the given ID and rewrites clients.csv.
  // It holds the store's lock, as NewClient and UpdateClient do, because it
  // mutates s.Clients. Tasks and hosts referring to the client are not
  // modified. It returns an error if no client has that ID.
  func (s *Csv) DelClient(id int) error {
  	s.Lock()
  	defer s.Unlock()
  	for k, v := range s.Clients {
  		if v.Id == id {
  			s.Clients = append(s.Clients[:k], s.Clients[k+1:]...)
  			s.StoreClientsToCsv()
  			return nil
  		}
  	}
  	return errors.New("不存在")
  }
  309. func (s *Csv) NewClient(c *Client) {
  310. s.Lock()
  311. defer s.Unlock()
  312. c.Flow = new(Flow)
  313. s.Clients = append(s.Clients, c)
  314. s.StoreClientsToCsv()
  315. }
  316. func (s *Csv) GetClientId() int {
  317. s.Lock()
  318. defer s.Unlock()
  319. s.ClientIncreaseId++
  320. return s.ClientIncreaseId
  321. }
  322. func (s *Csv) UpdateClient(t *Client) error {
  323. s.Lock()
  324. defer s.Unlock()
  325. for _, v := range s.Clients {
  326. if v.Id == t.Id {
  327. v.Cnf = t.Cnf
  328. v.VerifyKey = t.VerifyKey
  329. v.Remark = t.Remark
  330. v.RateLimit = t.RateLimit
  331. v.Flow = t.Flow
  332. v.Rate = t.Rate
  333. s.StoreClientsToCsv()
  334. return nil
  335. }
  336. }
  337. return errors.New("不存在")
  338. }
  339. func (s *Csv) GetClientList(start, length int) ([]*Client, int) {
  340. list := make([]*Client, 0)
  341. var cnt int
  342. for _, v := range s.Clients {
  343. cnt++
  344. if start--; start < 0 {
  345. if length--; length > 0 {
  346. list = append(list, v)
  347. }
  348. }
  349. }
  350. return list, cnt
  351. }
  // GetClient returns the client with the given ID. If none exists it
  // returns a nil client and a non-nil error.
  func (s *Csv) GetClient(id int) (v *Client, err error) {
  	for _, c := range s.Clients {
  		if c.Id == id {
  			return c, nil
  		}
  	}
  	// Return nil explicitly. Ranging into the named result used to leave v
  	// pointing at the last client scanned, so a miss returned an unrelated
  	// client.
  	return nil, errors.New("未找到")
  }
  // StoreClientsToCsv rewrites RunPath/conf/clients.csv from s.Clients, one
  // row per client in the column order LoadClientFromCsv reads: Id,
  // VerifyKey, Remark, Status, U, P, Crypt, Compress, RateLimit, FlowLimit.
  // Any error while creating, writing or flushing the file is fatal.
  func (s *Csv) StoreClientsToCsv() {
  	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv"))
  	if err != nil {
  		lg.Fatalln(err)
  	}
  	defer csvFile.Close()
  	writer := csv.NewWriter(csvFile)
  	for _, client := range s.Clients {
  		record := []string{
  			strconv.Itoa(client.Id),
  			client.VerifyKey,
  			client.Remark,
  			strconv.FormatBool(client.Status),
  			client.Cnf.U,
  			client.Cnf.P,
  			common.GetStrByBool(client.Cnf.Crypt),
  			client.Cnf.Compress,
  			strconv.Itoa(client.RateLimit),
  			// FormatInt keeps the full int64; int(...) truncated it where
  			// int is 32 bits.
  			strconv.FormatInt(client.Flow.FlowLimit, 10),
  		}
  		if err := writer.Write(record); err != nil {
  			lg.Fatalln(err)
  		}
  	}
  	// csv.Writer buffers; a failed Flush is only visible through Error.
  	writer.Flush()
  	if err := writer.Error(); err != nil {
  		lg.Fatalln(err)
  	}
  }