1
0
刘河 6 жил өмнө
parent
commit
cd7f99063c

+ 0 - 2
conf/clients.csv

@@ -1,2 +0,0 @@
-12,ao0yd0jx6ty0ht69,,true,,,0,false,0,0,0,1
-11,mxg22qa06dc137of,,true,,,0,false,0,0,0,1

+ 0 - 0
conf/clients.json


+ 0 - 1
conf/hosts.csv

@@ -1 +0,0 @@
-a.o.com,123.206.77.88:8080,11,,,,/,1,0,0,all

+ 0 - 0
conf/hosts.json


+ 0 - 1
conf/tasks.csv

@@ -1 +0,0 @@
-9999,tcp,,1,3,11,,0,0,,0.0.0.0

+ 0 - 0
conf/tasks.json


+ 112 - 245
lib/file/file.go

@@ -1,18 +1,16 @@
 package file
 
 import (
-	"encoding/csv"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/cnlh/nps/lib/common"
 	"github.com/cnlh/nps/lib/crypt"
 	"github.com/cnlh/nps/lib/rate"
-	"github.com/cnlh/nps/vender/github.com/astaxie/beego/logs"
 	"net/http"
 	"os"
 	"path/filepath"
 	"regexp"
-	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -20,7 +18,10 @@ import (
 
 func NewCsv(runPath string) *Csv {
 	return &Csv{
-		RunPath: runPath,
+		RunPath:        runPath,
+		TaskFilePath:   filepath.Join(runPath, "conf", "tasks.json"),
+		HostFilePath:   filepath.Join(runPath, "conf", "hosts.json"),
+		ClientFilePath: filepath.Join(runPath, "conf", "clients.json"),
 	}
 }
 
@@ -33,96 +34,62 @@ type Csv struct {
 	ClientIncreaseId int32    //客户端id
 	TaskIncreaseId   int32    //任务自增ID
 	HostIncreaseId   int32    //host increased id
+	TaskFilePath     string
+	HostFilePath     string
+	ClientFilePath   string
 }
 
-func (s *Csv) StoreTasksToCsv() {
-	// 创建文件
-	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv"))
-	if err != nil {
-		logs.Error(err.Error())
-	}
-	defer csvFile.Close()
-	writer := csv.NewWriter(csvFile)
-	s.Tasks.Range(func(key, value interface{}) bool {
-		task := value.(*Tunnel)
-		if task.NoStore {
-			return true
+func (s *Csv) LoadTaskFromCsv() {
+	loadSyncMapFromFile(s.TaskFilePath, func(v string) {
+		var err error
+		post := new(Tunnel)
+		if json.Unmarshal([]byte(v), &post) != nil {
+			return
 		}
-		record := []string{
-			strconv.Itoa(task.Port),
-			task.Mode,
-			task.Target.TargetStr,
-			common.GetStrByBool(task.Status),
-			strconv.Itoa(task.Id),
-			strconv.Itoa(task.Client.Id),
-			task.Remark,
-			strconv.Itoa(int(task.Flow.ExportFlow)),
-			strconv.Itoa(int(task.Flow.InletFlow)),
-			task.Password,
-			task.ServerIp,
+		if post.Client, err = s.GetClient(post.Client.Id); err != nil {
+			return
 		}
-		err := writer.Write(record)
-		if err != nil {
-			logs.Error(err.Error())
+		s.Tasks.Store(post.Id, post)
+		if post.Id > int(s.TaskIncreaseId) {
+			s.TaskIncreaseId = int32(post.Id)
 		}
-		return true
 	})
-	writer.Flush()
 }
 
-func (s *Csv) openFile(path string) ([][]string, error) {
-	// 打开文件
-	file, err := os.Open(path)
-	if err != nil {
-		panic(err)
-	}
-	defer file.Close()
-
-	// 获取csv的reader
-	reader := csv.NewReader(file)
-
-	// 设置FieldsPerRecord为-1
-	reader.FieldsPerRecord = -1
-
-	// 读取文件中所有行保存到slice中
-	return reader.ReadAll()
+func (s *Csv) LoadClientFromCsv() {
+	loadSyncMapFromFile(s.ClientFilePath, func(v string) {
+		post := new(Client)
+		if json.Unmarshal([]byte(v), &post) != nil {
+			return
+		}
+		if post.RateLimit > 0 {
+			post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
+		} else {
+			post.Rate = rate.NewRate(int64(2 << 23))
+		}
+		post.Rate.Start()
+		s.Clients.Store(post.Id, post)
+		if post.Id > int(s.ClientIncreaseId) {
+			s.ClientIncreaseId = int32(post.Id)
+		}
+	})
 }
 
-func (s *Csv) LoadTaskFromCsv() {
-	path := filepath.Join(s.RunPath, "conf", "tasks.csv")
-	records, err := s.openFile(path)
-	if err != nil {
-		logs.Error("Profile Opening Error:", path)
-		os.Exit(0)
-	}
-	// 将每一行数据保存到内存slice中
-	for _, item := range records {
-		post := &Tunnel{
-			Port:     common.GetIntNoErrByStr(item[0]),
-			Mode:     item[1],
-			Status:   common.GetBoolByStr(item[3]),
-			Id:       common.GetIntNoErrByStr(item[4]),
-			Remark:   item[6],
-			Password: item[9],
-		}
-		post.Target = new(Target)
-		post.Target.TargetStr = item[2]
-		post.Flow = new(Flow)
-		post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[7]))
-		post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[8]))
-		if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[5])); err != nil {
-			continue
+func (s *Csv) LoadHostFromCsv() {
+	loadSyncMapFromFile(s.HostFilePath, func(v string) {
+		var err error
+		post := new(Host)
+		if json.Unmarshal([]byte(v), &post) != nil {
+			return
 		}
-		if len(item) > 10 {
-			post.ServerIp = item[10]
-		} else {
-			post.ServerIp = "0.0.0.0"
+		if post.Client, err = s.GetClient(post.Client.Id); err != nil {
+			return
 		}
-		s.Tasks.Store(post.Id, post)
-		if post.Id > int(s.TaskIncreaseId) {
-			s.TaskIncreaseId = int32(s.TaskIncreaseId)
+		s.Hosts.Store(post.Id, post)
+		if post.Id > int(s.HostIncreaseId) {
+			s.HostIncreaseId = int32(post.Id)
 		}
-	}
+	})
 }
 
 func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (id int, err error) {
@@ -195,135 +162,15 @@ func (s *Csv) GetTask(id int) (t *Tunnel, err error) {
 }
 
 func (s *Csv) StoreHostToCsv() {
-	// 创建文件
-	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv"))
-	if err != nil {
-		panic(err)
-	}
-	defer csvFile.Close()
-	// 获取csv的Writer
-	writer := csv.NewWriter(csvFile)
-	// 将map中的Post转换成slice,因为csv的Write需要slice参数
-	// 并写入csv文件
-	s.Hosts.Range(func(key, value interface{}) bool {
-		host := value.(*Host)
-		if host.NoStore {
-			return true
-		}
-		record := []string{
-			host.Host,
-			host.Target.TargetStr,
-			strconv.Itoa(host.Client.Id),
-			host.HeaderChange,
-			host.HostChange,
-			host.Remark,
-			host.Location,
-			strconv.Itoa(host.Id),
-			strconv.Itoa(int(host.Flow.ExportFlow)),
-			strconv.Itoa(int(host.Flow.InletFlow)),
-			host.Scheme,
-		}
-		err1 := writer.Write(record)
-		if err1 != nil {
-			panic(err1)
-		}
-		return true
-	})
-
-	// 确保所有内存数据刷到csv文件
-	writer.Flush()
+	storeSyncMapToFile(s.Hosts, s.HostFilePath)
 }
 
-func (s *Csv) LoadClientFromCsv() {
-	path := filepath.Join(s.RunPath, "conf", "clients.csv")
-	records, err := s.openFile(path)
-	if err != nil {
-		logs.Error("Profile Opening Error:", path)
-		os.Exit(0)
-	}
-	// 将每一行数据保存到内存slice中
-	for _, item := range records {
-		post := &Client{
-			Id:        common.GetIntNoErrByStr(item[0]),
-			VerifyKey: item[1],
-			Remark:    item[2],
-			Status:    common.GetBoolByStr(item[3]),
-			RateLimit: common.GetIntNoErrByStr(item[8]),
-			Cnf: &Config{
-				U:        item[4],
-				P:        item[5],
-				Crypt:    common.GetBoolByStr(item[6]),
-				Compress: common.GetBoolByStr(item[7]),
-			},
-			MaxConn: common.GetIntNoErrByStr(item[10]),
-		}
-		if post.Id > int(s.ClientIncreaseId) {
-			s.ClientIncreaseId = int32(post.Id)
-		}
-		if post.RateLimit > 0 {
-			post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
-			post.Rate.Start()
-		} else {
-			post.Rate = rate.NewRate(int64(2 << 23))
-			post.Rate.Start()
-		}
-		post.Flow = new(Flow)
-		post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9]))
-		if len(item) >= 12 {
-			post.ConfigConnAllow = common.GetBoolByStr(item[11])
-		} else {
-			post.ConfigConnAllow = true
-		}
-		if len(item) >= 13 {
-			post.WebUserName = item[12]
-		} else {
-			post.WebUserName = ""
-		}
-		if len(item) >= 14 {
-			post.WebPassword = item[13]
-		} else {
-			post.WebPassword = ""
-		}
-		s.Clients.Store(post.Id, post)
-	}
+func (s *Csv) StoreTasksToCsv() {
+	storeSyncMapToFile(s.Tasks, s.TaskFilePath)
 }
 
-func (s *Csv) LoadHostFromCsv() {
-	path := filepath.Join(s.RunPath, "conf", "hosts.csv")
-	records, err := s.openFile(path)
-	if err != nil {
-		logs.Error("Profile Opening Error:", path)
-		os.Exit(0)
-	}
-	// 将每一行数据保存到内存slice中
-	for _, item := range records {
-		post := &Host{
-			Host:         item[0],
-			HeaderChange: item[3],
-			HostChange:   item[4],
-			Remark:       item[5],
-			Location:     item[6],
-			Id:           common.GetIntNoErrByStr(item[7]),
-		}
-		if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil {
-			continue
-		}
-		post.Target = new(Target)
-		post.Target.TargetStr = item[1]
-		post.Flow = new(Flow)
-		post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[8]))
-		post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[9]))
-		if len(item) > 10 {
-			post.Scheme = item[10]
-		} else {
-			post.Scheme = "all"
-		}
-		s.Hosts.Store(post.Id, post)
-		if post.Id > int(s.HostIncreaseId) {
-			s.HostIncreaseId = int32(post.Id)
-		}
-		//store host to hostMap if the host url is none
-	}
+func (s *Csv) StoreClientsToCsv() {
+	storeSyncMapToFile(s.Clients, s.ClientFilePath)
 }
 
 func (s *Csv) DelHost(id int) error {
@@ -439,6 +286,7 @@ func (s *Csv) VerifyVkey(vkey string, id int) (res bool) {
 	})
 	return res
 }
+
 func (s *Csv) VerifyUserName(username string, id int) (res bool) {
 	res = true
 	s.Clients.Range(func(key, value interface{}) bool {
@@ -452,18 +300,6 @@ func (s *Csv) VerifyUserName(username string, id int) (res bool) {
 	return res
 }
 
-func (s *Csv) GetClientId() int32 {
-	return atomic.AddInt32(&s.ClientIncreaseId, 1)
-}
-
-func (s *Csv) GetTaskId() int32 {
-	return atomic.AddInt32(&s.TaskIncreaseId, 1)
-}
-
-func (s *Csv) GetHostId() int32 {
-	return atomic.AddInt32(&s.HostIncreaseId, 1)
-}
-
 func (s *Csv) UpdateClient(t *Client) error {
 	s.Clients.Store(t.Id, t)
 	if t.RateLimit == 0 {
@@ -516,6 +352,7 @@ func (s *Csv) GetClient(id int) (c *Client, err error) {
 	err = errors.New("未找到客户端")
 	return
 }
+
 func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) {
 	var exist bool
 	s.Clients.Range(func(key, value interface{}) bool {
@@ -585,40 +422,70 @@ func (s *Csv) GetInfoByHost(host string, r *http.Request) (h *Host, err error) {
 	return
 }
 
-func (s *Csv) StoreClientsToCsv() {
-	// 创建文件
-	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv"))
+func (s *Csv) GetClientId() int32 {
+	return atomic.AddInt32(&s.ClientIncreaseId, 1)
+}
+
+func (s *Csv) GetTaskId() int32 {
+	return atomic.AddInt32(&s.TaskIncreaseId, 1)
+}
+
+func (s *Csv) GetHostId() int32 {
+	return atomic.AddInt32(&s.HostIncreaseId, 1)
+}
+
+func loadSyncMapFromFile(filePath string, f func(value string)) {
+	b, err := common.ReadAllFromFile(filePath)
+	if err != nil {
+		panic(err)
+	}
+	for _, v := range strings.Split(string(b), "\n"+common.CONN_DATA_SEQ) {
+		f(v)
+	}
+}
+
+func storeSyncMapToFile(m sync.Map, filePath string) {
+	file, err := os.Create(filePath)
 	if err != nil {
-		logs.Error(err.Error())
+		panic(err)
 	}
-	defer csvFile.Close()
-	writer := csv.NewWriter(csvFile)
-	s.Clients.Range(func(key, value interface{}) bool {
-		client := value.(*Client)
-		if client.NoStore {
+	defer file.Close()
+	m.Range(func(key, value interface{}) bool {
+		var b []byte
+		var err error
+		switch value.(type) {
+		case *Tunnel:
+			obj := value.(*Tunnel)
+			if obj.NoStore {
+				return true
+			}
+			b, err = json.Marshal(obj)
+		case *Host:
+			obj := value.(*Host)
+			if obj.NoStore {
+				return true
+			}
+			b, err = json.Marshal(obj)
+		case *Client:
+			obj := value.(*Client)
+			if obj.NoStore {
+				return true
+			}
+			b, err = json.Marshal(obj)
+		default:
 			return true
 		}
-		record := []string{
-			strconv.Itoa(client.Id),
-			client.VerifyKey,
-			client.Remark,
-			strconv.FormatBool(client.Status),
-			client.Cnf.U,
-			client.Cnf.P,
-			common.GetStrByBool(client.Cnf.Crypt),
-			strconv.FormatBool(client.Cnf.Compress),
-			strconv.Itoa(client.RateLimit),
-			strconv.Itoa(int(client.Flow.FlowLimit)),
-			strconv.Itoa(int(client.MaxConn)),
-			common.GetStrByBool(client.ConfigConnAllow),
-			client.WebUserName,
-			client.WebPassword,
+		if err != nil {
+			return true
+		}
+		_, err = file.Write(b)
+		if err != nil {
+			panic(err)
 		}
-		err := writer.Write(record)
+		_, err = file.Write([]byte("\n" + common.CONN_DATA_SEQ))
 		if err != nil {
-			logs.Error(err.Error())
+			panic(err)
 		}
 		return true
 	})
-	writer.Flush()
 }

+ 1 - 1
lib/file/obj.go

@@ -183,6 +183,6 @@ type Host struct {
 	Flow         *Flow
 	Client       *Client
 	Target       *Target //目标
-	Health
+	Health       `json:"-"`
 	sync.RWMutex
 }

+ 1 - 1
lib/mux/mux.go

@@ -65,7 +65,7 @@ func (s *Mux) NewConn() (*conn, error) {
 		return nil, err
 	}
 	//set a timer timeout 30 second
-	timer := time.NewTimer(time.Second * 30)
+	timer := time.NewTimer(time.Minute * 2)
 	defer timer.Stop()
 	select {
 	case <-conn.connStatusOkCh: