|
@@ -1,18 +1,16 @@
|
|
|
package file
|
|
|
|
|
|
import (
|
|
|
- "encoding/csv"
|
|
|
+ "encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"github.com/cnlh/nps/lib/common"
|
|
|
"github.com/cnlh/nps/lib/crypt"
|
|
|
"github.com/cnlh/nps/lib/rate"
|
|
|
- "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"regexp"
|
|
|
- "strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
@@ -20,7 +18,10 @@ import (
|
|
|
|
|
|
func NewCsv(runPath string) *Csv {
|
|
|
return &Csv{
|
|
|
- RunPath: runPath,
|
|
|
+ RunPath: runPath,
|
|
|
+ TaskFilePath: filepath.Join(runPath, "conf", "tasks.json"),
|
|
|
+ HostFilePath: filepath.Join(runPath, "conf", "hosts.json"),
|
|
|
+ ClientFilePath: filepath.Join(runPath, "conf", "clients.json"),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -33,96 +34,62 @@ type Csv struct {
|
|
|
ClientIncreaseId int32 //客户端id
|
|
|
TaskIncreaseId int32 //任务自增ID
|
|
|
HostIncreaseId int32 //host increased id
|
|
|
+ TaskFilePath string
|
|
|
+ HostFilePath string
|
|
|
+ ClientFilePath string
|
|
|
}
|
|
|
|
|
|
-func (s *Csv) StoreTasksToCsv() {
|
|
|
- // 创建文件
|
|
|
- csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv"))
|
|
|
- if err != nil {
|
|
|
- logs.Error(err.Error())
|
|
|
- }
|
|
|
- defer csvFile.Close()
|
|
|
- writer := csv.NewWriter(csvFile)
|
|
|
- s.Tasks.Range(func(key, value interface{}) bool {
|
|
|
- task := value.(*Tunnel)
|
|
|
- if task.NoStore {
|
|
|
- return true
|
|
|
+func (s *Csv) LoadTaskFromCsv() {
|
|
|
+ loadSyncMapFromFile(s.TaskFilePath, func(v string) {
|
|
|
+ var err error
|
|
|
+ post := new(Tunnel)
|
|
|
+ if json.Unmarshal([]byte(v), &post) != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- record := []string{
|
|
|
- strconv.Itoa(task.Port),
|
|
|
- task.Mode,
|
|
|
- task.Target.TargetStr,
|
|
|
- common.GetStrByBool(task.Status),
|
|
|
- strconv.Itoa(task.Id),
|
|
|
- strconv.Itoa(task.Client.Id),
|
|
|
- task.Remark,
|
|
|
- strconv.Itoa(int(task.Flow.ExportFlow)),
|
|
|
- strconv.Itoa(int(task.Flow.InletFlow)),
|
|
|
- task.Password,
|
|
|
- task.ServerIp,
|
|
|
+ if post.Client, err = s.GetClient(post.Client.Id); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- err := writer.Write(record)
|
|
|
- if err != nil {
|
|
|
- logs.Error(err.Error())
|
|
|
+ s.Tasks.Store(post.Id, post)
|
|
|
+ if post.Id > int(s.TaskIncreaseId) {
|
|
|
+ s.TaskIncreaseId = int32(post.Id)
|
|
|
}
|
|
|
- return true
|
|
|
})
|
|
|
- writer.Flush()
|
|
|
}
|
|
|
|
|
|
-func (s *Csv) openFile(path string) ([][]string, error) {
|
|
|
- // 打开文件
|
|
|
- file, err := os.Open(path)
|
|
|
- if err != nil {
|
|
|
- panic(err)
|
|
|
- }
|
|
|
- defer file.Close()
|
|
|
-
|
|
|
- // 获取csv的reader
|
|
|
- reader := csv.NewReader(file)
|
|
|
-
|
|
|
- // 设置FieldsPerRecord为-1
|
|
|
- reader.FieldsPerRecord = -1
|
|
|
-
|
|
|
- // 读取文件中所有行保存到slice中
|
|
|
- return reader.ReadAll()
|
|
|
+func (s *Csv) LoadClientFromCsv() {
|
|
|
+ loadSyncMapFromFile(s.ClientFilePath, func(v string) {
|
|
|
+ post := new(Client)
|
|
|
+ if json.Unmarshal([]byte(v), &post) != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if post.RateLimit > 0 {
|
|
|
+ post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
|
|
|
+ } else {
|
|
|
+ post.Rate = rate.NewRate(int64(2 << 23))
|
|
|
+ }
|
|
|
+ post.Rate.Start()
|
|
|
+ s.Clients.Store(post.Id, post)
|
|
|
+ if post.Id > int(s.ClientIncreaseId) {
|
|
|
+ s.ClientIncreaseId = int32(post.Id)
|
|
|
+ }
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
-func (s *Csv) LoadTaskFromCsv() {
|
|
|
- path := filepath.Join(s.RunPath, "conf", "tasks.csv")
|
|
|
- records, err := s.openFile(path)
|
|
|
- if err != nil {
|
|
|
- logs.Error("Profile Opening Error:", path)
|
|
|
- os.Exit(0)
|
|
|
- }
|
|
|
- // 将每一行数据保存到内存slice中
|
|
|
- for _, item := range records {
|
|
|
- post := &Tunnel{
|
|
|
- Port: common.GetIntNoErrByStr(item[0]),
|
|
|
- Mode: item[1],
|
|
|
- Status: common.GetBoolByStr(item[3]),
|
|
|
- Id: common.GetIntNoErrByStr(item[4]),
|
|
|
- Remark: item[6],
|
|
|
- Password: item[9],
|
|
|
- }
|
|
|
- post.Target = new(Target)
|
|
|
- post.Target.TargetStr = item[2]
|
|
|
- post.Flow = new(Flow)
|
|
|
- post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[7]))
|
|
|
- post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[8]))
|
|
|
- if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[5])); err != nil {
|
|
|
- continue
|
|
|
+func (s *Csv) LoadHostFromCsv() {
|
|
|
+ loadSyncMapFromFile(s.HostFilePath, func(v string) {
|
|
|
+ var err error
|
|
|
+ post := new(Host)
|
|
|
+ if json.Unmarshal([]byte(v), &post) != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- if len(item) > 10 {
|
|
|
- post.ServerIp = item[10]
|
|
|
- } else {
|
|
|
- post.ServerIp = "0.0.0.0"
|
|
|
+ if post.Client, err = s.GetClient(post.Client.Id); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- s.Tasks.Store(post.Id, post)
|
|
|
- if post.Id > int(s.TaskIncreaseId) {
|
|
|
- s.TaskIncreaseId = int32(s.TaskIncreaseId)
|
|
|
+ s.Hosts.Store(post.Id, post)
|
|
|
+ if post.Id > int(s.HostIncreaseId) {
|
|
|
+ s.HostIncreaseId = int32(post.Id)
|
|
|
}
|
|
|
- }
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (id int, err error) {
|
|
@@ -195,135 +162,15 @@ func (s *Csv) GetTask(id int) (t *Tunnel, err error) {
|
|
|
}
|
|
|
|
|
|
func (s *Csv) StoreHostToCsv() {
|
|
|
- // 创建文件
|
|
|
- csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv"))
|
|
|
- if err != nil {
|
|
|
- panic(err)
|
|
|
- }
|
|
|
- defer csvFile.Close()
|
|
|
- // 获取csv的Writer
|
|
|
- writer := csv.NewWriter(csvFile)
|
|
|
- // 将map中的Post转换成slice,因为csv的Write需要slice参数
|
|
|
- // 并写入csv文件
|
|
|
- s.Hosts.Range(func(key, value interface{}) bool {
|
|
|
- host := value.(*Host)
|
|
|
- if host.NoStore {
|
|
|
- return true
|
|
|
- }
|
|
|
- record := []string{
|
|
|
- host.Host,
|
|
|
- host.Target.TargetStr,
|
|
|
- strconv.Itoa(host.Client.Id),
|
|
|
- host.HeaderChange,
|
|
|
- host.HostChange,
|
|
|
- host.Remark,
|
|
|
- host.Location,
|
|
|
- strconv.Itoa(host.Id),
|
|
|
- strconv.Itoa(int(host.Flow.ExportFlow)),
|
|
|
- strconv.Itoa(int(host.Flow.InletFlow)),
|
|
|
- host.Scheme,
|
|
|
- }
|
|
|
- err1 := writer.Write(record)
|
|
|
- if err1 != nil {
|
|
|
- panic(err1)
|
|
|
- }
|
|
|
- return true
|
|
|
- })
|
|
|
-
|
|
|
- // 确保所有内存数据刷到csv文件
|
|
|
- writer.Flush()
|
|
|
+ storeSyncMapToFile(s.Hosts, s.HostFilePath)
|
|
|
}
|
|
|
|
|
|
-func (s *Csv) LoadClientFromCsv() {
|
|
|
- path := filepath.Join(s.RunPath, "conf", "clients.csv")
|
|
|
- records, err := s.openFile(path)
|
|
|
- if err != nil {
|
|
|
- logs.Error("Profile Opening Error:", path)
|
|
|
- os.Exit(0)
|
|
|
- }
|
|
|
- // 将每一行数据保存到内存slice中
|
|
|
- for _, item := range records {
|
|
|
- post := &Client{
|
|
|
- Id: common.GetIntNoErrByStr(item[0]),
|
|
|
- VerifyKey: item[1],
|
|
|
- Remark: item[2],
|
|
|
- Status: common.GetBoolByStr(item[3]),
|
|
|
- RateLimit: common.GetIntNoErrByStr(item[8]),
|
|
|
- Cnf: &Config{
|
|
|
- U: item[4],
|
|
|
- P: item[5],
|
|
|
- Crypt: common.GetBoolByStr(item[6]),
|
|
|
- Compress: common.GetBoolByStr(item[7]),
|
|
|
- },
|
|
|
- MaxConn: common.GetIntNoErrByStr(item[10]),
|
|
|
- }
|
|
|
- if post.Id > int(s.ClientIncreaseId) {
|
|
|
- s.ClientIncreaseId = int32(post.Id)
|
|
|
- }
|
|
|
- if post.RateLimit > 0 {
|
|
|
- post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
|
|
|
- post.Rate.Start()
|
|
|
- } else {
|
|
|
- post.Rate = rate.NewRate(int64(2 << 23))
|
|
|
- post.Rate.Start()
|
|
|
- }
|
|
|
- post.Flow = new(Flow)
|
|
|
- post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9]))
|
|
|
- if len(item) >= 12 {
|
|
|
- post.ConfigConnAllow = common.GetBoolByStr(item[11])
|
|
|
- } else {
|
|
|
- post.ConfigConnAllow = true
|
|
|
- }
|
|
|
- if len(item) >= 13 {
|
|
|
- post.WebUserName = item[12]
|
|
|
- } else {
|
|
|
- post.WebUserName = ""
|
|
|
- }
|
|
|
- if len(item) >= 14 {
|
|
|
- post.WebPassword = item[13]
|
|
|
- } else {
|
|
|
- post.WebPassword = ""
|
|
|
- }
|
|
|
- s.Clients.Store(post.Id, post)
|
|
|
- }
|
|
|
+func (s *Csv) StoreTasksToCsv() {
|
|
|
+ storeSyncMapToFile(s.Tasks, s.TaskFilePath)
|
|
|
}
|
|
|
|
|
|
-func (s *Csv) LoadHostFromCsv() {
|
|
|
- path := filepath.Join(s.RunPath, "conf", "hosts.csv")
|
|
|
- records, err := s.openFile(path)
|
|
|
- if err != nil {
|
|
|
- logs.Error("Profile Opening Error:", path)
|
|
|
- os.Exit(0)
|
|
|
- }
|
|
|
- // 将每一行数据保存到内存slice中
|
|
|
- for _, item := range records {
|
|
|
- post := &Host{
|
|
|
- Host: item[0],
|
|
|
- HeaderChange: item[3],
|
|
|
- HostChange: item[4],
|
|
|
- Remark: item[5],
|
|
|
- Location: item[6],
|
|
|
- Id: common.GetIntNoErrByStr(item[7]),
|
|
|
- }
|
|
|
- if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil {
|
|
|
- continue
|
|
|
- }
|
|
|
- post.Target = new(Target)
|
|
|
- post.Target.TargetStr = item[1]
|
|
|
- post.Flow = new(Flow)
|
|
|
- post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[8]))
|
|
|
- post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[9]))
|
|
|
- if len(item) > 10 {
|
|
|
- post.Scheme = item[10]
|
|
|
- } else {
|
|
|
- post.Scheme = "all"
|
|
|
- }
|
|
|
- s.Hosts.Store(post.Id, post)
|
|
|
- if post.Id > int(s.HostIncreaseId) {
|
|
|
- s.HostIncreaseId = int32(post.Id)
|
|
|
- }
|
|
|
- //store host to hostMap if the host url is none
|
|
|
- }
|
|
|
+func (s *Csv) StoreClientsToCsv() {
|
|
|
+ storeSyncMapToFile(s.Clients, s.ClientFilePath)
|
|
|
}
|
|
|
|
|
|
func (s *Csv) DelHost(id int) error {
|
|
@@ -439,6 +286,7 @@ func (s *Csv) VerifyVkey(vkey string, id int) (res bool) {
|
|
|
})
|
|
|
return res
|
|
|
}
|
|
|
+
|
|
|
func (s *Csv) VerifyUserName(username string, id int) (res bool) {
|
|
|
res = true
|
|
|
s.Clients.Range(func(key, value interface{}) bool {
|
|
@@ -452,18 +300,6 @@ func (s *Csv) VerifyUserName(username string, id int) (res bool) {
|
|
|
return res
|
|
|
}
|
|
|
|
|
|
-func (s *Csv) GetClientId() int32 {
|
|
|
- return atomic.AddInt32(&s.ClientIncreaseId, 1)
|
|
|
-}
|
|
|
-
|
|
|
-func (s *Csv) GetTaskId() int32 {
|
|
|
- return atomic.AddInt32(&s.TaskIncreaseId, 1)
|
|
|
-}
|
|
|
-
|
|
|
-func (s *Csv) GetHostId() int32 {
|
|
|
- return atomic.AddInt32(&s.HostIncreaseId, 1)
|
|
|
-}
|
|
|
-
|
|
|
func (s *Csv) UpdateClient(t *Client) error {
|
|
|
s.Clients.Store(t.Id, t)
|
|
|
if t.RateLimit == 0 {
|
|
@@ -516,6 +352,7 @@ func (s *Csv) GetClient(id int) (c *Client, err error) {
|
|
|
err = errors.New("未找到客户端")
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) {
|
|
|
var exist bool
|
|
|
s.Clients.Range(func(key, value interface{}) bool {
|
|
@@ -585,40 +422,70 @@ func (s *Csv) GetInfoByHost(host string, r *http.Request) (h *Host, err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
-func (s *Csv) StoreClientsToCsv() {
|
|
|
- // 创建文件
|
|
|
- csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv"))
|
|
|
+func (s *Csv) GetClientId() int32 {
|
|
|
+ return atomic.AddInt32(&s.ClientIncreaseId, 1)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Csv) GetTaskId() int32 {
|
|
|
+ return atomic.AddInt32(&s.TaskIncreaseId, 1)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Csv) GetHostId() int32 {
|
|
|
+ return atomic.AddInt32(&s.HostIncreaseId, 1)
|
|
|
+}
|
|
|
+
|
|
|
+func loadSyncMapFromFile(filePath string, f func(value string)) {
|
|
|
+ b, err := common.ReadAllFromFile(filePath)
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
+ }
|
|
|
+ for _, v := range strings.Split(string(b), "\n"+common.CONN_DATA_SEQ) {
|
|
|
+ f(v)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func storeSyncMapToFile(m sync.Map, filePath string) {
|
|
|
+ file, err := os.Create(filePath)
|
|
|
if err != nil {
|
|
|
- logs.Error(err.Error())
|
|
|
+ panic(err)
|
|
|
}
|
|
|
- defer csvFile.Close()
|
|
|
- writer := csv.NewWriter(csvFile)
|
|
|
- s.Clients.Range(func(key, value interface{}) bool {
|
|
|
- client := value.(*Client)
|
|
|
- if client.NoStore {
|
|
|
+ defer file.Close()
|
|
|
+ m.Range(func(key, value interface{}) bool {
|
|
|
+ var b []byte
|
|
|
+ var err error
|
|
|
+ switch value.(type) {
|
|
|
+ case *Tunnel:
|
|
|
+ obj := value.(*Tunnel)
|
|
|
+ if obj.NoStore {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ b, err = json.Marshal(obj)
|
|
|
+ case *Host:
|
|
|
+ obj := value.(*Host)
|
|
|
+ if obj.NoStore {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ b, err = json.Marshal(obj)
|
|
|
+ case *Client:
|
|
|
+ obj := value.(*Client)
|
|
|
+ if obj.NoStore {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ b, err = json.Marshal(obj)
|
|
|
+ default:
|
|
|
return true
|
|
|
}
|
|
|
- record := []string{
|
|
|
- strconv.Itoa(client.Id),
|
|
|
- client.VerifyKey,
|
|
|
- client.Remark,
|
|
|
- strconv.FormatBool(client.Status),
|
|
|
- client.Cnf.U,
|
|
|
- client.Cnf.P,
|
|
|
- common.GetStrByBool(client.Cnf.Crypt),
|
|
|
- strconv.FormatBool(client.Cnf.Compress),
|
|
|
- strconv.Itoa(client.RateLimit),
|
|
|
- strconv.Itoa(int(client.Flow.FlowLimit)),
|
|
|
- strconv.Itoa(int(client.MaxConn)),
|
|
|
- common.GetStrByBool(client.ConfigConnAllow),
|
|
|
- client.WebUserName,
|
|
|
- client.WebPassword,
|
|
|
+ if err != nil {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ _, err = file.Write(b)
|
|
|
+ if err != nil {
|
|
|
+ panic(err)
|
|
|
}
|
|
|
- err := writer.Write(record)
|
|
|
+ _, err = file.Write([]byte("\n" + common.CONN_DATA_SEQ))
|
|
|
if err != nil {
|
|
|
- logs.Error(err.Error())
|
|
|
+ panic(err)
|
|
|
}
|
|
|
return true
|
|
|
})
|
|
|
- writer.Flush()
|
|
|
}
|