file.go 8.8 KB


  1. package lib
  import (
	"encoding/csv"
	"errors"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
)
  11. var (
  12. CsvDb *Csv
  13. once sync.Once
  14. )
  15. func NewCsv() *Csv {
  16. return new(Csv)
  17. }
  // Csv is an in-memory store of tunnels (tasks), hosts and clients that is
// loaded from, and written back to, CSV files under GetRunPath()/conf.
//
// NOTE(review): the embedded mutex is taken by GetTaskId, GetClientId,
// GetIdByVerifyKey, NewClient and UpdateClient only. The task/host mutators,
// DelClient and the list/lookup methods access the slices without locking.
type Csv struct {
	Tasks            []*Tunnel // tunnels, persisted to tasks.csv
	Path             string    // NOTE(review): not read by any method in this file
	Hosts            []*Host   // domain (host) list, persisted to hosts.csv
	Clients          []*Client // clients, persisted to clients.csv
	ClientIncreaseId int       // highest client id seen or issued; GetClientId increments it
	TaskIncreaseId   int       // highest task id seen or issued; GetTaskId increments it
	sync.Mutex
}
  27. func (s *Csv) Init() {
  28. s.LoadClientFromCsv()
  29. s.LoadTaskFromCsv()
  30. s.LoadHostFromCsv()
  31. }
  32. func (s *Csv) StoreTasksToCsv() {
  33. // 创建文件
  34. csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "tasks.csv"))
  35. if err != nil {
  36. Fatalf(err.Error())
  37. }
  38. defer csvFile.Close()
  39. writer := csv.NewWriter(csvFile)
  40. for _, task := range s.Tasks {
  41. record := []string{
  42. strconv.Itoa(task.TcpPort),
  43. task.Mode,
  44. task.Target,
  45. task.Config.U,
  46. task.Config.P,
  47. task.Config.Compress,
  48. GetStrByBool(task.Status),
  49. GetStrByBool(task.Config.Crypt),
  50. strconv.Itoa(task.Config.CompressEncode),
  51. strconv.Itoa(task.Config.CompressDecode),
  52. strconv.Itoa(task.Id),
  53. strconv.Itoa(task.Client.Id),
  54. strconv.FormatBool(task.UseClientCnf),
  55. task.Remark,
  56. }
  57. err := writer.Write(record)
  58. if err != nil {
  59. Fatalf(err.Error())
  60. }
  61. }
  62. writer.Flush()
  63. }
  64. func (s *Csv) openFile(path string) ([][]string, error) {
  65. // 打开文件
  66. file, err := os.Open(path)
  67. if err != nil {
  68. panic(err)
  69. }
  70. defer file.Close()
  71. // 获取csv的reader
  72. reader := csv.NewReader(file)
  73. // 设置FieldsPerRecord为-1
  74. reader.FieldsPerRecord = -1
  75. // 读取文件中所有行保存到slice中
  76. return reader.ReadAll()
  77. }
  78. func (s *Csv) LoadTaskFromCsv() {
  79. path := filepath.Join(GetRunPath(), "conf", "tasks.csv")
  80. records, err := s.openFile(path)
  81. if err != nil {
  82. Fatalln("配置文件打开错误:", path)
  83. }
  84. var tasks []*Tunnel
  85. // 将每一行数据保存到内存slice中
  86. for _, item := range records {
  87. post := &Tunnel{
  88. TcpPort: GetIntNoErrByStr(item[0]),
  89. Mode: item[1],
  90. Target: item[2],
  91. Config: &Config{
  92. U: item[3],
  93. P: item[4],
  94. Compress: item[5],
  95. Crypt: GetBoolByStr(item[7]),
  96. CompressEncode: GetIntNoErrByStr(item[8]),
  97. CompressDecode: GetIntNoErrByStr(item[9]),
  98. },
  99. Status: GetBoolByStr(item[6]),
  100. Id: GetIntNoErrByStr(item[10]),
  101. UseClientCnf: GetBoolByStr(item[12]),
  102. Remark: item[13],
  103. }
  104. post.Flow = new(Flow)
  105. if post.Client, err = s.GetClient(GetIntNoErrByStr(item[11])); err != nil {
  106. continue
  107. }
  108. tasks = append(tasks, post)
  109. if post.Id > s.TaskIncreaseId {
  110. s.TaskIncreaseId = post.Id
  111. }
  112. }
  113. s.Tasks = tasks
  114. }
  115. func (s *Csv) GetTaskId() int {
  116. s.Lock()
  117. defer s.Unlock()
  118. s.TaskIncreaseId++
  119. return s.TaskIncreaseId
  120. }
  121. func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
  122. s.Lock()
  123. defer s.Unlock()
  124. for _, v := range s.Clients {
  125. if Getverifyval(v.VerifyKey) == vKey && v.Status {
  126. if arr := strings.Split(addr, ":"); len(arr) > 0 {
  127. v.Addr = arr[0]
  128. }
  129. return v.Id, nil
  130. }
  131. }
  132. return 0, errors.New("not found")
  133. }
  134. func (s *Csv) NewTask(t *Tunnel) {
  135. t.Flow = new(Flow)
  136. s.Tasks = append(s.Tasks, t)
  137. s.StoreTasksToCsv()
  138. }
  139. func (s *Csv) UpdateTask(t *Tunnel) error {
  140. for k, v := range s.Tasks {
  141. if v.Id == t.Id {
  142. s.Tasks = append(s.Tasks[:k], s.Tasks[k+1:]...)
  143. s.Tasks = append(s.Tasks, t)
  144. s.StoreTasksToCsv()
  145. return nil
  146. }
  147. }
  148. return errors.New("不存在")
  149. }
  150. func (s *Csv) DelTask(id int) error {
  151. for k, v := range s.Tasks {
  152. if v.Id == id {
  153. s.Tasks = append(s.Tasks[:k], s.Tasks[k+1:]...)
  154. s.StoreTasksToCsv()
  155. return nil
  156. }
  157. }
  158. return errors.New("不存在")
  159. }
  160. func (s *Csv) GetTask(id int) (v *Tunnel, err error) {
  161. for _, v = range s.Tasks {
  162. if v.Id == id {
  163. return
  164. }
  165. }
  166. err = errors.New("未找到")
  167. return
  168. }
  169. func (s *Csv) StoreHostToCsv() {
  170. // 创建文件
  171. csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "hosts.csv"))
  172. if err != nil {
  173. panic(err)
  174. }
  175. defer csvFile.Close()
  176. // 获取csv的Writer
  177. writer := csv.NewWriter(csvFile)
  178. // 将map中的Post转换成slice,因为csv的Write需要slice参数
  179. // 并写入csv文件
  180. for _, host := range s.Hosts {
  181. record := []string{
  182. host.Host,
  183. host.Target,
  184. strconv.Itoa(host.Client.Id),
  185. host.HeaderChange,
  186. host.HostChange,
  187. host.Remark,
  188. }
  189. err1 := writer.Write(record)
  190. if err1 != nil {
  191. panic(err1)
  192. }
  193. }
  194. // 确保所有内存数据刷到csv文件
  195. writer.Flush()
  196. }
  197. func (s *Csv) LoadClientFromCsv() {
  198. path := filepath.Join(GetRunPath(), "conf", "clients.csv")
  199. records, err := s.openFile(path)
  200. if err != nil {
  201. Fatalln("配置文件打开错误:", path)
  202. }
  203. var clients []*Client
  204. // 将每一行数据保存到内存slice中
  205. for _, item := range records {
  206. post := &Client{
  207. Id: GetIntNoErrByStr(item[0]),
  208. VerifyKey: item[1],
  209. Remark: item[2],
  210. Status: GetBoolByStr(item[3]),
  211. RateLimit: GetIntNoErrByStr(item[8]),
  212. Cnf: &Config{
  213. U: item[4],
  214. P: item[5],
  215. Crypt: GetBoolByStr(item[6]),
  216. Compress: item[7],
  217. },
  218. }
  219. if post.Id > s.ClientIncreaseId {
  220. s.ClientIncreaseId = post.Id
  221. }
  222. if post.RateLimit > 0 {
  223. post.Rate = NewRate(int64(post.RateLimit * 1024))
  224. post.Rate.Start()
  225. }
  226. post.Flow = new(Flow)
  227. post.Flow.FlowLimit = int64(GetIntNoErrByStr(item[9]))
  228. clients = append(clients, post)
  229. }
  230. s.Clients = clients
  231. }
  232. func (s *Csv) LoadHostFromCsv() {
  233. path := filepath.Join(GetRunPath(), "conf", "hosts.csv")
  234. records, err := s.openFile(path)
  235. if err != nil {
  236. Fatalln("配置文件打开错误:", path)
  237. }
  238. var hosts []*Host
  239. // 将每一行数据保存到内存slice中
  240. for _, item := range records {
  241. post := &Host{
  242. Host: item[0],
  243. Target: item[1],
  244. HeaderChange: item[3],
  245. HostChange: item[4],
  246. Remark: item[5],
  247. }
  248. if post.Client, err = s.GetClient(GetIntNoErrByStr(item[2])); err != nil {
  249. continue
  250. }
  251. post.Flow = new(Flow)
  252. hosts = append(hosts, post)
  253. }
  254. s.Hosts = hosts
  255. }
  256. func (s *Csv) DelHost(host string) error {
  257. for k, v := range s.Hosts {
  258. if v.Host == host {
  259. s.Hosts = append(s.Hosts[:k], s.Hosts[k+1:]...)
  260. s.StoreHostToCsv()
  261. return nil
  262. }
  263. }
  264. return errors.New("不存在")
  265. }
  266. func (s *Csv) NewHost(t *Host) {
  267. t.Flow = new(Flow)
  268. s.Hosts = append(s.Hosts, t)
  269. s.StoreHostToCsv()
  270. }
  271. func (s *Csv) UpdateHost(t *Host) error {
  272. for k, v := range s.Hosts {
  273. if v.Host == t.Host {
  274. s.Hosts = append(s.Hosts[:k], s.Hosts[k+1:]...)
  275. s.Hosts = append(s.Hosts, t)
  276. s.StoreHostToCsv()
  277. return nil
  278. }
  279. }
  280. return errors.New("不存在")
  281. }
  282. func (s *Csv) GetHost(start, length int, id int) ([]*Host, int) {
  283. list := make([]*Host, 0)
  284. var cnt int
  285. for _, v := range s.Hosts {
  286. if id == 0 || v.Client.Id == id {
  287. cnt++
  288. if start--; start < 0 {
  289. if length--; length > 0 {
  290. list = append(list, v)
  291. }
  292. }
  293. }
  294. }
  295. return list, cnt
  296. }
  297. func (s *Csv) DelClient(id int) error {
  298. for k, v := range s.Clients {
  299. if v.Id == id {
  300. s.Clients = append(s.Clients[:k], s.Clients[k+1:]...)
  301. s.StoreClientsToCsv()
  302. return nil
  303. }
  304. }
  305. return errors.New("不存在")
  306. }
  307. func (s *Csv) NewClient(c *Client) {
  308. s.Lock()
  309. defer s.Unlock()
  310. c.Flow = new(Flow)
  311. s.Clients = append(s.Clients, c)
  312. s.StoreClientsToCsv()
  313. }
  314. func (s *Csv) GetClientId() int {
  315. s.Lock()
  316. defer s.Unlock()
  317. s.ClientIncreaseId++
  318. return s.ClientIncreaseId
  319. }
  320. func (s *Csv) UpdateClient(t *Client) error {
  321. s.Lock()
  322. defer s.Unlock()
  323. for _, v := range s.Clients {
  324. if v.Id == t.Id {
  325. v.Cnf = t.Cnf
  326. v.VerifyKey = t.VerifyKey
  327. v.Remark = t.Remark
  328. v.RateLimit = t.RateLimit
  329. v.Flow = t.Flow
  330. v.Rate = t.Rate
  331. s.StoreClientsToCsv()
  332. return nil
  333. }
  334. }
  335. return errors.New("不存在")
  336. }
  337. func (s *Csv) GetClientList(start, length int) ([]*Client, int) {
  338. list := make([]*Client, 0)
  339. var cnt int
  340. for _, v := range s.Clients {
  341. cnt++
  342. if start--; start < 0 {
  343. if length--; length > 0 {
  344. list = append(list, v)
  345. }
  346. }
  347. }
  348. return list, cnt
  349. }
  350. func (s *Csv) GetClient(id int) (v *Client, err error) {
  351. for _, v = range s.Clients {
  352. if v.Id == id {
  353. return
  354. }
  355. }
  356. err = errors.New("未找到")
  357. return
  358. }
  359. func (s *Csv) StoreClientsToCsv() {
  360. // 创建文件
  361. csvFile, err := os.Create(filepath.Join(GetRunPath(), "conf", "clients.csv"))
  362. if err != nil {
  363. Fatalln(err.Error())
  364. }
  365. defer csvFile.Close()
  366. writer := csv.NewWriter(csvFile)
  367. for _, client := range s.Clients {
  368. record := []string{
  369. strconv.Itoa(client.Id),
  370. client.VerifyKey,
  371. client.Remark,
  372. strconv.FormatBool(client.Status),
  373. client.Cnf.U,
  374. client.Cnf.P,
  375. GetStrByBool(client.Cnf.Crypt),
  376. client.Cnf.Compress,
  377. strconv.Itoa(client.RateLimit),
  378. strconv.Itoa(int(client.Flow.FlowLimit)),
  379. }
  380. err := writer.Write(record)
  381. if err != nil {
  382. Fatalln(err.Error())
  383. }
  384. }
  385. writer.Flush()
  386. }
  387. //init csv from file
  388. func GetCsvDb() *Csv {
  389. once.Do(func() {
  390. CsvDb = NewCsv()
  391. CsvDb.Init()
  392. })
  393. return CsvDb
  394. }