file.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. package file
  2. import (
  3. "encoding/csv"
  4. "errors"
  5. "fmt"
  6. "github.com/cnlh/nps/lib/common"
  7. "github.com/cnlh/nps/lib/crypt"
  8. "github.com/cnlh/nps/lib/rate"
  9. "github.com/cnlh/nps/vender/github.com/astaxie/beego"
  10. "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs"
  11. "net/http"
  12. "os"
  13. "path/filepath"
  14. "regexp"
  15. "strconv"
  16. "strings"
  17. "sync"
  18. "sync/atomic"
  19. )
  20. func NewCsv(runPath string) *Csv {
  21. return &Csv{
  22. RunPath: runPath,
  23. }
  24. }
  // Csv is the CSV-file-backed store for tasks (tunnels), hosts and clients.
  // Each collection is a sync.Map keyed by the entry's Id; the CSV files are
  // written under RunPath/conf.
  type Csv struct {
  	Tasks    sync.Map // tasks (*Tunnel) keyed by task Id
  	Hosts    sync.Map // domain list: hosts (*Host) keyed by host Id
  	HostsTmp sync.Map // not used by any method visible here — TODO confirm purpose
  	Clients  sync.Map // clients (*Client) keyed by client Id

  	RunPath string // storage root directory

  	// Auto-increment id counters: each holds the highest id issued or loaded.
  	ClientIncreaseId int32 // client ids
  	TaskIncreaseId   int32 // task ids
  	HostIncreaseId   int32 // host ids
  }
  35. func (s *Csv) StoreTasksToCsv() {
  36. // 创建文件
  37. csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv"))
  38. if err != nil {
  39. logs.Error(err.Error())
  40. }
  41. defer csvFile.Close()
  42. writer := csv.NewWriter(csvFile)
  43. s.Tasks.Range(func(key, value interface{}) bool {
  44. task := value.(*Tunnel)
  45. if task.NoStore {
  46. return true
  47. }
  48. record := []string{
  49. strconv.Itoa(task.Port),
  50. task.Mode,
  51. task.Target,
  52. common.GetStrByBool(task.Status),
  53. strconv.Itoa(task.Id),
  54. strconv.Itoa(task.Client.Id),
  55. task.Remark,
  56. strconv.Itoa(int(task.Flow.ExportFlow)),
  57. strconv.Itoa(int(task.Flow.InletFlow)),
  58. task.Password,
  59. task.ServerIp,
  60. }
  61. err := writer.Write(record)
  62. if err != nil {
  63. logs.Error(err.Error())
  64. }
  65. return true
  66. })
  67. writer.Flush()
  68. }
  69. func (s *Csv) openFile(path string) ([][]string, error) {
  70. // 打开文件
  71. file, err := os.Open(path)
  72. if err != nil {
  73. panic(err)
  74. }
  75. defer file.Close()
  76. // 获取csv的reader
  77. reader := csv.NewReader(file)
  78. // 设置FieldsPerRecord为-1
  79. reader.FieldsPerRecord = -1
  80. // 读取文件中所有行保存到slice中
  81. return reader.ReadAll()
  82. }
  83. func (s *Csv) LoadTaskFromCsv() {
  84. path := filepath.Join(s.RunPath, "conf", "tasks.csv")
  85. records, err := s.openFile(path)
  86. if err != nil {
  87. logs.Error("Profile Opening Error:", path)
  88. os.Exit(0)
  89. }
  90. // 将每一行数据保存到内存slice中
  91. for _, item := range records {
  92. post := &Tunnel{
  93. Port: common.GetIntNoErrByStr(item[0]),
  94. Mode: item[1],
  95. Target: item[2],
  96. Status: common.GetBoolByStr(item[3]),
  97. Id: common.GetIntNoErrByStr(item[4]),
  98. Remark: item[6],
  99. Password: item[9],
  100. }
  101. post.Flow = new(Flow)
  102. post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[7]))
  103. post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[8]))
  104. if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[5])); err != nil {
  105. continue
  106. }
  107. if len(item) > 10 {
  108. post.ServerIp = item[10]
  109. } else {
  110. post.ServerIp = "0.0.0.0"
  111. }
  112. s.Tasks.Store(post.Id, post)
  113. if post.Id > int(s.TaskIncreaseId) {
  114. s.TaskIncreaseId = int32(s.TaskIncreaseId)
  115. }
  116. }
  117. }
  118. func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (id int, err error) {
  119. var exist bool
  120. s.Clients.Range(func(key, value interface{}) bool {
  121. v := value.(*Client)
  122. if common.Getverifyval(v.VerifyKey) == vKey && v.Status {
  123. v.Addr = common.GetIpByAddr(addr)
  124. id = v.Id
  125. exist = true
  126. return false
  127. }
  128. return true
  129. })
  130. if exist {
  131. return
  132. }
  133. return 0, errors.New("not found")
  134. }
  135. func (s *Csv) NewTask(t *Tunnel) (err error) {
  136. s.Tasks.Range(func(key, value interface{}) bool {
  137. v := value.(*Tunnel)
  138. if (v.Mode == "secret" || v.Mode == "p2p") && v.Password == t.Password {
  139. err = errors.New(fmt.Sprintf("Secret mode keys %s must be unique", t.Password))
  140. return false
  141. }
  142. return true
  143. })
  144. if err != nil {
  145. return
  146. }
  147. t.Flow = new(Flow)
  148. s.Tasks.Store(t.Id, t)
  149. s.StoreTasksToCsv()
  150. return
  151. }
  152. func (s *Csv) UpdateTask(t *Tunnel) error {
  153. s.Tasks.Store(t.Id, t)
  154. s.StoreTasksToCsv()
  155. return nil
  156. }
  157. func (s *Csv) DelTask(id int) error {
  158. s.Tasks.Delete(id)
  159. s.StoreTasksToCsv()
  160. return nil
  161. }
  162. //md5 password
  163. func (s *Csv) GetTaskByMd5Password(p string) (t *Tunnel) {
  164. s.Tasks.Range(func(key, value interface{}) bool {
  165. if crypt.Md5(value.(*Tunnel).Password) == p {
  166. t = value.(*Tunnel)
  167. return false
  168. }
  169. return true
  170. })
  171. return
  172. }
  173. func (s *Csv) GetTask(id int) (t *Tunnel, err error) {
  174. if v, ok := s.Tasks.Load(id); ok {
  175. t = v.(*Tunnel)
  176. return
  177. }
  178. err = errors.New("not found")
  179. return
  180. }
  181. func (s *Csv) StoreHostToCsv() {
  182. // 创建文件
  183. csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv"))
  184. if err != nil {
  185. panic(err)
  186. }
  187. defer csvFile.Close()
  188. // 获取csv的Writer
  189. writer := csv.NewWriter(csvFile)
  190. // 将map中的Post转换成slice,因为csv的Write需要slice参数
  191. // 并写入csv文件
  192. s.Hosts.Range(func(key, value interface{}) bool {
  193. host := value.(*Host)
  194. if host.NoStore {
  195. return true
  196. }
  197. record := []string{
  198. host.Host,
  199. host.Target,
  200. strconv.Itoa(host.Client.Id),
  201. host.HeaderChange,
  202. host.HostChange,
  203. host.Remark,
  204. host.Location,
  205. strconv.Itoa(host.Id),
  206. strconv.Itoa(int(host.Flow.ExportFlow)),
  207. strconv.Itoa(int(host.Flow.InletFlow)),
  208. host.Scheme,
  209. }
  210. err1 := writer.Write(record)
  211. if err1 != nil {
  212. panic(err1)
  213. }
  214. return true
  215. })
  216. // 确保所有内存数据刷到csv文件
  217. writer.Flush()
  218. }
  219. func (s *Csv) LoadClientFromCsv() {
  220. path := filepath.Join(s.RunPath, "conf", "clients.csv")
  221. records, err := s.openFile(path)
  222. if err != nil {
  223. logs.Error("Profile Opening Error:", path)
  224. os.Exit(0)
  225. }
  226. // 将每一行数据保存到内存slice中
  227. for _, item := range records {
  228. post := &Client{
  229. Id: common.GetIntNoErrByStr(item[0]),
  230. VerifyKey: item[1],
  231. Remark: item[2],
  232. Status: common.GetBoolByStr(item[3]),
  233. RateLimit: common.GetIntNoErrByStr(item[8]),
  234. Cnf: &Config{
  235. U: item[4],
  236. P: item[5],
  237. Crypt: common.GetBoolByStr(item[6]),
  238. Compress: common.GetBoolByStr(item[7]),
  239. },
  240. MaxConn: common.GetIntNoErrByStr(item[10]),
  241. }
  242. if post.Id > int(s.ClientIncreaseId) {
  243. s.ClientIncreaseId = int32(post.Id)
  244. }
  245. if post.RateLimit > 0 {
  246. post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
  247. post.Rate.Start()
  248. } else {
  249. post.Rate = rate.NewRate(int64(2 << 23))
  250. post.Rate.Start()
  251. }
  252. post.Flow = new(Flow)
  253. post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9]))
  254. if len(item) >= 12 {
  255. post.ConfigConnAllow = common.GetBoolByStr(item[11])
  256. } else {
  257. post.ConfigConnAllow = true
  258. }
  259. s.Clients.Store(post.Id, post)
  260. }
  261. }
  262. func (s *Csv) LoadHostFromCsv() {
  263. path := filepath.Join(s.RunPath, "conf", "hosts.csv")
  264. records, err := s.openFile(path)
  265. if err != nil {
  266. logs.Error("Profile Opening Error:", path)
  267. os.Exit(0)
  268. }
  269. // 将每一行数据保存到内存slice中
  270. for _, item := range records {
  271. post := &Host{
  272. Host: item[0],
  273. Target: item[1],
  274. HeaderChange: item[3],
  275. HostChange: item[4],
  276. Remark: item[5],
  277. Location: item[6],
  278. Id: common.GetIntNoErrByStr(item[7]),
  279. }
  280. if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil {
  281. continue
  282. }
  283. post.Flow = new(Flow)
  284. post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[8]))
  285. post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[9]))
  286. if len(item) > 10 {
  287. post.Scheme = item[10]
  288. } else {
  289. post.Scheme = "all"
  290. }
  291. s.Hosts.Store(post.Id, post)
  292. if post.Id > int(s.HostIncreaseId) {
  293. s.HostIncreaseId = int32(post.Id)
  294. }
  295. //store host to hostMap if the host url is none
  296. }
  297. }
  298. func (s *Csv) DelHost(id int) error {
  299. s.Hosts.Delete(id)
  300. s.StoreHostToCsv()
  301. return nil
  302. }
  303. func (s *Csv) GetMapLen(m sync.Map) int {
  304. var c int
  305. m.Range(func(key, value interface{}) bool {
  306. c++
  307. return true
  308. })
  309. return c
  310. }
  311. func (s *Csv) IsHostExist(h *Host) bool {
  312. var exist bool
  313. s.Hosts.Range(func(key, value interface{}) bool {
  314. v := value.(*Host)
  315. if v.Host == h.Host && h.Location == v.Location && (v.Scheme == "all" || v.Scheme == h.Scheme) {
  316. exist = true
  317. return false
  318. }
  319. return true
  320. })
  321. return exist
  322. }
  323. func (s *Csv) NewHost(t *Host) error {
  324. if s.IsHostExist(t) {
  325. return errors.New("host has exist")
  326. }
  327. if t.Location == "" {
  328. t.Location = "/"
  329. }
  330. t.Flow = new(Flow)
  331. s.Hosts.Store(t.Id, t)
  332. s.StoreHostToCsv()
  333. return nil
  334. }
  335. func (s *Csv) GetHost(start, length int, id int, search string) ([]*Host, int) {
  336. list := make([]*Host, 0)
  337. var cnt int
  338. keys := common.GetMapKeys(s.Hosts)
  339. for _, key := range keys {
  340. if value, ok := s.Hosts.Load(key); ok {
  341. v := value.(*Host)
  342. if search != "" && !(v.Id == common.GetIntNoErrByStr(search) || strings.Contains(v.Host, search) || strings.Contains(v.Remark, search)) {
  343. continue
  344. }
  345. if id == 0 || v.Client.Id == id {
  346. cnt++
  347. if start--; start < 0 {
  348. if length--; length > 0 {
  349. list = append(list, v)
  350. }
  351. }
  352. }
  353. }
  354. }
  355. return list, cnt
  356. }
  357. func (s *Csv) DelClient(id int) error {
  358. s.Clients.Delete(id)
  359. s.StoreClientsToCsv()
  360. return nil
  361. }
  362. func (s *Csv) NewClient(c *Client) error {
  363. var isNotSet bool
  364. reset:
  365. if c.VerifyKey == "" || isNotSet {
  366. isNotSet = true
  367. c.VerifyKey = crypt.GetRandomString(16)
  368. }
  369. if c.RateLimit == 0 {
  370. c.Rate = rate.NewRate(int64(2 << 23))
  371. c.Rate.Start()
  372. }
  373. if !s.VerifyVkey(c.VerifyKey, c.id) {
  374. if isNotSet {
  375. goto reset
  376. }
  377. return errors.New("Vkey duplicate, please reset")
  378. }
  379. if c.Id == 0 {
  380. c.Id = int(s.GetClientId())
  381. }
  382. if c.Flow == nil {
  383. c.Flow = new(Flow)
  384. }
  385. s.Clients.Store(c.Id, c)
  386. s.StoreClientsToCsv()
  387. return nil
  388. }
  389. func (s *Csv) VerifyVkey(vkey string, id int) (res bool) {
  390. res = true
  391. s.Clients.Range(func(key, value interface{}) bool {
  392. v := value.(*Client)
  393. if v.VerifyKey == vkey && v.Id != id {
  394. res = false
  395. return false
  396. }
  397. return true
  398. })
  399. return res
  400. }
  401. func (s *Csv) GetClientId() int32 {
  402. return atomic.AddInt32(&s.ClientIncreaseId, 1)
  403. }
  404. func (s *Csv) GetTaskId() int32 {
  405. return atomic.AddInt32(&s.TaskIncreaseId, 1)
  406. }
  407. func (s *Csv) GetHostId() int32 {
  408. return atomic.AddInt32(&s.HostIncreaseId, 1)
  409. }
  410. func (s *Csv) UpdateClient(t *Client) error {
  411. s.Clients.Store(t.Id, t)
  412. if t.RateLimit == 0 {
  413. t.Rate = rate.NewRate(int64(2 << 23))
  414. t.Rate.Start()
  415. }
  416. return nil
  417. }
  418. func (s *Csv) GetClientList(start, length int, search string, clientId int) ([]*Client, int) {
  419. list := make([]*Client, 0)
  420. var cnt int
  421. keys := common.GetMapKeys(s.Clients)
  422. for _, key := range keys {
  423. if value, ok := s.Clients.Load(key); ok {
  424. v := value.(*Client)
  425. if v.NoDisplay {
  426. continue
  427. }
  428. if clientId != 0 && clientId != v.Id {
  429. continue
  430. }
  431. if search != "" && !(v.Id == common.GetIntNoErrByStr(search) || strings.Contains(v.VerifyKey, search) || strings.Contains(v.Remark, search)) {
  432. continue
  433. }
  434. cnt++
  435. if start--; start < 0 {
  436. if length--; length > 0 {
  437. list = append(list, v)
  438. }
  439. }
  440. }
  441. }
  442. return list, cnt
  443. }
  444. func (s *Csv) IsPubClient(id int) bool {
  445. client, err := s.GetClient(id)
  446. if err == nil {
  447. if client.VerifyKey == beego.AppConfig.String("public_vkey") {
  448. return true
  449. } else {
  450. return false
  451. }
  452. }
  453. return false
  454. }
  455. func (s *Csv) GetClient(id int) (c *Client, err error) {
  456. if v, ok := s.Clients.Load(id); ok {
  457. c = v.(*Client)
  458. return
  459. }
  460. err = errors.New("未找到客户端")
  461. return
  462. }
  463. func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) {
  464. var exist bool
  465. s.Clients.Range(func(key, value interface{}) bool {
  466. v := value.(*Client)
  467. if crypt.Md5(v.VerifyKey) == vkey {
  468. exist = true
  469. id = v.Id
  470. return false
  471. }
  472. return true
  473. })
  474. if exist {
  475. return
  476. }
  477. err = errors.New("未找到客户端")
  478. return
  479. }
  480. func (s *Csv) GetHostById(id int) (h *Host, err error) {
  481. if v, ok := s.Hosts.Load(id); ok {
  482. h = v.(*Host)
  483. return
  484. }
  485. err = errors.New("The host could not be parsed")
  486. return
  487. }
  488. //get key by host from x
  489. func (s *Csv) GetInfoByHost(host string, r *http.Request) (h *Host, err error) {
  490. var hosts []*Host
  491. //Handling Ported Access
  492. host = common.GetIpByAddr(host)
  493. s.Hosts.Range(func(key, value interface{}) bool {
  494. v := value.(*Host)
  495. if v.IsClose {
  496. return true
  497. }
  498. //Remove http(s) http(s)://a.proxy.com
  499. //*.proxy.com *.a.proxy.com Do some pan-parsing
  500. tmp := strings.Replace(v.Host, "*", `\w+?`, -1)
  501. var re *regexp.Regexp
  502. if re, err = regexp.Compile(tmp); err != nil {
  503. return true
  504. }
  505. if len(re.FindAllString(host, -1)) > 0 && (v.Scheme == "all" || v.Scheme == r.URL.Scheme) {
  506. //URL routing
  507. hosts = append(hosts, v)
  508. }
  509. return true
  510. })
  511. for _, v := range hosts {
  512. //If not set, default matches all
  513. if v.Location == "" {
  514. v.Location = "/"
  515. }
  516. if strings.Index(r.RequestURI, v.Location) == 0 {
  517. if h == nil || (len(v.Location) > len(h.Location)) {
  518. h = v
  519. }
  520. }
  521. }
  522. if h != nil {
  523. return
  524. }
  525. err = errors.New("The host could not be parsed")
  526. return
  527. }
  528. func (s *Csv) StoreClientsToCsv() {
  529. // 创建文件
  530. csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv"))
  531. if err != nil {
  532. logs.Error(err.Error())
  533. }
  534. defer csvFile.Close()
  535. writer := csv.NewWriter(csvFile)
  536. s.Clients.Range(func(key, value interface{}) bool {
  537. client := value.(*Client)
  538. if client.NoStore {
  539. return true
  540. }
  541. record := []string{
  542. strconv.Itoa(client.Id),
  543. client.VerifyKey,
  544. client.Remark,
  545. strconv.FormatBool(client.Status),
  546. client.Cnf.U,
  547. client.Cnf.P,
  548. common.GetStrByBool(client.Cnf.Crypt),
  549. strconv.FormatBool(client.Cnf.Compress),
  550. strconv.Itoa(client.RateLimit),
  551. strconv.Itoa(int(client.Flow.FlowLimit)),
  552. strconv.Itoa(int(client.MaxConn)),
  553. common.GetStrByBool(client.ConfigConnAllow),
  554. }
  555. err := writer.Write(record)
  556. if err != nil {
  557. logs.Error(err.Error())
  558. }
  559. return true
  560. })
  561. writer.Flush()
  562. }