file.go 8.7 KB


  1. package utils
  2. import (
  3. "encoding/csv"
  4. "errors"
  5. "github.com/astaxie/beego"
  6. "log"
  7. "os"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. )
var (
	// CsvDb is the package-level Csv instance. GetCsvDb assigns it on its
	// first call; it is nil until then.
	CsvDb *Csv
	// once guards the one-time creation and loading of CsvDb in GetCsvDb.
	once sync.Once
)
  16. func NewCsv() *Csv {
  17. return new(Csv)
  18. }
// Csv keeps the task, host and client lists in memory together with the
// counters used to hand out new ids. The embedded mutex is taken by some of
// the methods on this type (for example GetTaskId and NewClient) but not by
// all of them.
type Csv struct {
	Tasks            []*Tunnel
	Path             string
	Hosts            []*Host   // list of hosts (domain names)
	Clients          []*Client // clients
	ClientIncreaseId int       // client id counter
	TaskIncreaseId   int       // auto-increment task id
	sync.Mutex
}
  28. func (s *Csv) Init() {
  29. s.LoadClientFromCsv()
  30. s.LoadTaskFromCsv()
  31. s.LoadHostFromCsv()
  32. }
  33. func (s *Csv) StoreTasksToCsv() {
  34. // 创建文件
  35. csvFile, err := os.Create(beego.AppPath + "/conf/tasks.csv")
  36. if err != nil {
  37. log.Fatalf(err.Error())
  38. }
  39. defer csvFile.Close()
  40. writer := csv.NewWriter(csvFile)
  41. for _, task := range s.Tasks {
  42. record := []string{
  43. strconv.Itoa(task.TcpPort),
  44. task.Mode,
  45. task.Target,
  46. task.Config.U,
  47. task.Config.P,
  48. task.Config.Compress,
  49. GetStrByBool(task.Status),
  50. GetStrByBool(task.Config.Crypt),
  51. strconv.Itoa(task.Config.CompressEncode),
  52. strconv.Itoa(task.Config.CompressDecode),
  53. strconv.Itoa(task.Id),
  54. strconv.Itoa(task.Client.Id),
  55. strconv.FormatBool(task.UseClientCnf),
  56. task.Remark,
  57. }
  58. err := writer.Write(record)
  59. if err != nil {
  60. log.Fatalf(err.Error())
  61. }
  62. }
  63. writer.Flush()
  64. }
  65. func (s *Csv) openFile(path string) ([][]string, error) {
  66. // 打开文件
  67. file, err := os.Open(path)
  68. if err != nil {
  69. panic(err)
  70. }
  71. defer file.Close()
  72. // 获取csv的reader
  73. reader := csv.NewReader(file)
  74. // 设置FieldsPerRecord为-1
  75. reader.FieldsPerRecord = -1
  76. // 读取文件中所有行保存到slice中
  77. return reader.ReadAll()
  78. }
  79. func (s *Csv) LoadTaskFromCsv() {
  80. path := beego.AppPath + "/conf/tasks.csv"
  81. records, err := s.openFile(path)
  82. if err != nil {
  83. log.Fatal("配置文件打开错误:", path)
  84. }
  85. var tasks []*Tunnel
  86. // 将每一行数据保存到内存slice中
  87. for _, item := range records {
  88. post := &Tunnel{
  89. TcpPort: GetIntNoErrByStr(item[0]),
  90. Mode: item[1],
  91. Target: item[2],
  92. Config: &Config{
  93. U: item[3],
  94. P: item[4],
  95. Compress: item[5],
  96. Crypt: GetBoolByStr(item[7]),
  97. CompressEncode: GetIntNoErrByStr(item[8]),
  98. CompressDecode: GetIntNoErrByStr(item[9]),
  99. },
  100. Status: GetBoolByStr(item[6]),
  101. Id: GetIntNoErrByStr(item[10]),
  102. UseClientCnf: GetBoolByStr(item[12]),
  103. Remark: item[13],
  104. }
  105. post.Flow = new(Flow)
  106. if post.Client, err = s.GetClient(GetIntNoErrByStr(item[11])); err != nil {
  107. continue
  108. }
  109. tasks = append(tasks, post)
  110. if post.Id > s.TaskIncreaseId {
  111. s.TaskIncreaseId = post.Id
  112. }
  113. }
  114. s.Tasks = tasks
  115. }
  116. func (s *Csv) GetTaskId() int {
  117. s.Lock()
  118. defer s.Unlock()
  119. s.TaskIncreaseId++
  120. return s.TaskIncreaseId
  121. }
  122. func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (int, error) {
  123. s.Lock()
  124. defer s.Unlock()
  125. for _, v := range s.Clients {
  126. if Getverifyval(v.VerifyKey) == vKey && v.Status {
  127. if arr := strings.Split(addr, ":"); len(arr) > 0 {
  128. v.Addr = arr[0]
  129. }
  130. return v.Id, nil
  131. }
  132. }
  133. return 0, errors.New("not found")
  134. }
  135. func (s *Csv) NewTask(t *Tunnel) {
  136. t.Flow = new(Flow)
  137. s.Tasks = append(s.Tasks, t)
  138. s.StoreTasksToCsv()
  139. }
  140. func (s *Csv) UpdateTask(t *Tunnel) error {
  141. for k, v := range s.Tasks {
  142. if v.Id == t.Id {
  143. s.Tasks = append(s.Tasks[:k], s.Tasks[k+1:]...)
  144. s.Tasks = append(s.Tasks, t)
  145. s.StoreTasksToCsv()
  146. return nil
  147. }
  148. }
  149. return errors.New("不存在")
  150. }
  151. func (s *Csv) DelTask(id int) error {
  152. for k, v := range s.Tasks {
  153. if v.Id == id {
  154. s.Tasks = append(s.Tasks[:k], s.Tasks[k+1:]...)
  155. s.StoreTasksToCsv()
  156. return nil
  157. }
  158. }
  159. return errors.New("不存在")
  160. }
  161. func (s *Csv) GetTask(id int) (v *Tunnel, err error) {
  162. for _, v = range s.Tasks {
  163. if v.Id == id {
  164. return
  165. }
  166. }
  167. err = errors.New("未找到")
  168. return
  169. }
  170. func (s *Csv) StoreHostToCsv() {
  171. // 创建文件
  172. csvFile, err := os.Create(beego.AppPath + "/conf/hosts.csv")
  173. if err != nil {
  174. panic(err)
  175. }
  176. defer csvFile.Close()
  177. // 获取csv的Writer
  178. writer := csv.NewWriter(csvFile)
  179. // 将map中的Post转换成slice,因为csv的Write需要slice参数
  180. // 并写入csv文件
  181. for _, host := range s.Hosts {
  182. record := []string{
  183. host.Host,
  184. host.Target,
  185. strconv.Itoa(host.Client.Id),
  186. host.HeaderChange,
  187. host.HostChange,
  188. host.Remark,
  189. }
  190. err1 := writer.Write(record)
  191. if err1 != nil {
  192. panic(err1)
  193. }
  194. }
  195. // 确保所有内存数据刷到csv文件
  196. writer.Flush()
  197. }
  198. func (s *Csv) LoadClientFromCsv() {
  199. path := beego.AppPath + "/conf/clients.csv"
  200. records, err := s.openFile(path)
  201. if err != nil {
  202. log.Fatal("配置文件打开错误:", path)
  203. }
  204. var clients []*Client
  205. // 将每一行数据保存到内存slice中
  206. for _, item := range records {
  207. post := &Client{
  208. Id: GetIntNoErrByStr(item[0]),
  209. VerifyKey: item[1],
  210. Remark: item[2],
  211. Status: GetBoolByStr(item[3]),
  212. RateLimit: GetIntNoErrByStr(item[8]),
  213. Cnf: &Config{
  214. U: item[4],
  215. P: item[5],
  216. Crypt: GetBoolByStr(item[6]),
  217. Compress: item[7],
  218. },
  219. }
  220. if post.Id > s.ClientIncreaseId {
  221. s.ClientIncreaseId = post.Id
  222. }
  223. if post.RateLimit > 0 {
  224. post.Rate = NewRate(int64(post.RateLimit * 1024))
  225. post.Rate.Start()
  226. }
  227. post.Flow = new(Flow)
  228. post.Flow.FlowLimit = int64(GetIntNoErrByStr(item[9]))
  229. clients = append(clients, post)
  230. }
  231. s.Clients = clients
  232. }
  233. func (s *Csv) LoadHostFromCsv() {
  234. path := beego.AppPath + "/conf/hosts.csv"
  235. records, err := s.openFile(path)
  236. if err != nil {
  237. log.Fatal("配置文件打开错误:", path)
  238. }
  239. var hosts []*Host
  240. // 将每一行数据保存到内存slice中
  241. for _, item := range records {
  242. post := &Host{
  243. Host: item[0],
  244. Target: item[1],
  245. HeaderChange: item[3],
  246. HostChange: item[4],
  247. Remark: item[5],
  248. }
  249. if post.Client, err = s.GetClient(GetIntNoErrByStr(item[2])); err != nil {
  250. continue
  251. }
  252. post.Flow = new(Flow)
  253. hosts = append(hosts, post)
  254. }
  255. s.Hosts = hosts
  256. }
  257. func (s *Csv) DelHost(host string) error {
  258. for k, v := range s.Hosts {
  259. if v.Host == host {
  260. s.Hosts = append(s.Hosts[:k], s.Hosts[k+1:]...)
  261. s.StoreHostToCsv()
  262. return nil
  263. }
  264. }
  265. return errors.New("不存在")
  266. }
  267. func (s *Csv) NewHost(t *Host) {
  268. t.Flow = new(Flow)
  269. s.Hosts = append(s.Hosts, t)
  270. s.StoreHostToCsv()
  271. }
  272. func (s *Csv) UpdateHost(t *Host) error {
  273. for k, v := range s.Hosts {
  274. if v.Host == t.Host {
  275. s.Hosts = append(s.Hosts[:k], s.Hosts[k+1:]...)
  276. s.Hosts = append(s.Hosts, t)
  277. s.StoreHostToCsv()
  278. return nil
  279. }
  280. }
  281. return errors.New("不存在")
  282. }
  283. func (s *Csv) GetHost(start, length int, id int) ([]*Host, int) {
  284. list := make([]*Host, 0)
  285. var cnt int
  286. for _, v := range s.Hosts {
  287. if id == 0 || v.Client.Id == id {
  288. cnt++
  289. if start--; start < 0 {
  290. if length--; length > 0 {
  291. list = append(list, v)
  292. }
  293. }
  294. }
  295. }
  296. return list, cnt
  297. }
  298. func (s *Csv) DelClient(id int) error {
  299. for k, v := range s.Clients {
  300. if v.Id == id {
  301. s.Clients = append(s.Clients[:k], s.Clients[k+1:]...)
  302. s.StoreClientsToCsv()
  303. return nil
  304. }
  305. }
  306. return errors.New("不存在")
  307. }
  308. func (s *Csv) NewClient(c *Client) {
  309. s.Lock()
  310. defer s.Unlock()
  311. c.Flow = new(Flow)
  312. s.Clients = append(s.Clients, c)
  313. s.StoreClientsToCsv()
  314. }
  315. func (s *Csv) GetClientId() int {
  316. s.Lock()
  317. defer s.Unlock()
  318. s.ClientIncreaseId++
  319. return s.ClientIncreaseId
  320. }
  321. func (s *Csv) UpdateClient(t *Client) error {
  322. s.Lock()
  323. defer s.Unlock()
  324. for _, v := range s.Clients {
  325. if v.Id == t.Id {
  326. v.Cnf = t.Cnf
  327. v.VerifyKey = t.VerifyKey
  328. v.Remark = t.Remark
  329. v.RateLimit = t.RateLimit
  330. v.Flow = t.Flow
  331. v.Rate = t.Rate
  332. s.StoreClientsToCsv()
  333. return nil
  334. }
  335. }
  336. return errors.New("不存在")
  337. }
  338. func (s *Csv) GetClientList(start, length int) ([]*Client, int) {
  339. list := make([]*Client, 0)
  340. var cnt int
  341. for _, v := range s.Clients {
  342. cnt++
  343. if start--; start < 0 {
  344. if length--; length > 0 {
  345. list = append(list, v)
  346. }
  347. }
  348. }
  349. return list, cnt
  350. }
  351. func (s *Csv) GetClient(id int) (v *Client, err error) {
  352. for _, v = range s.Clients {
  353. if v.Id == id {
  354. return
  355. }
  356. }
  357. err = errors.New("未找到")
  358. return
  359. }
  360. func (s *Csv) StoreClientsToCsv() {
  361. // 创建文件
  362. csvFile, err := os.Create(beego.AppPath + "/conf/clients.csv")
  363. if err != nil {
  364. log.Fatalf(err.Error())
  365. }
  366. defer csvFile.Close()
  367. writer := csv.NewWriter(csvFile)
  368. for _, client := range s.Clients {
  369. record := []string{
  370. strconv.Itoa(client.Id),
  371. client.VerifyKey,
  372. client.Remark,
  373. strconv.FormatBool(client.Status),
  374. client.Cnf.U,
  375. client.Cnf.P,
  376. GetStrByBool(client.Cnf.Crypt),
  377. client.Cnf.Compress,
  378. strconv.Itoa(client.RateLimit),
  379. strconv.Itoa(int(client.Flow.FlowLimit)),
  380. }
  381. err := writer.Write(record)
  382. if err != nil {
  383. log.Fatalf(err.Error())
  384. }
  385. }
  386. writer.Flush()
  387. }
  388. //init csv from file
  389. func GetCsvDb() *Csv {
  390. once.Do(func() {
  391. CsvDb = NewCsv()
  392. CsvDb.Init()
  393. })
  394. return CsvDb
  395. }