123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- package server
- import (
- "errors"
- "github.com/cnlh/nps/bridge"
- "github.com/cnlh/nps/lib"
- "net"
- "net/http"
- "sync"
- )
- //server base struct
- type server struct {
- id int
- bridge *bridge.Bridge
- task *lib.Tunnel
- config *lib.Config
- errorContent []byte
- sync.Mutex
- }
- func (s *server) FlowAdd(in, out int64) {
- s.Lock()
- defer s.Unlock()
- s.task.Flow.ExportFlow += out
- s.task.Flow.InletFlow += in
- }
- func (s *server) FlowAddHost(host *lib.Host, in, out int64) {
- s.Lock()
- defer s.Unlock()
- host.Flow.ExportFlow += out
- host.Flow.InletFlow += in
- }
- //热更新配置
- func (s *server) ResetConfig() bool {
- //获取最新数据
- task, err := lib.GetCsvDb().GetTask(s.task.Id)
- if err != nil {
- return false
- }
- if s.task.Client.Flow.FlowLimit > 0 && (s.task.Client.Flow.FlowLimit<<20) < (s.task.Client.Flow.ExportFlow+s.task.Client.Flow.InletFlow) {
- return false
- }
- s.task.UseClientCnf = task.UseClientCnf
- //使用客户端配置
- client, err := lib.GetCsvDb().GetClient(s.task.Client.Id)
- if s.task.UseClientCnf {
- if err == nil {
- s.config.U = client.Cnf.U
- s.config.P = client.Cnf.P
- s.config.Compress = client.Cnf.Compress
- s.config.Crypt = client.Cnf.Crypt
- }
- } else {
- if err == nil {
- s.config.U = task.Config.U
- s.config.P = task.Config.P
- s.config.Compress = task.Config.Compress
- s.config.Crypt = task.Config.Crypt
- }
- }
- s.task.Client.Rate = client.Rate
- s.config.CompressDecode, s.config.CompressEncode = lib.GetCompressType(s.config.Compress)
- return true
- }
- func (s *server) linkCopy(link *lib.Link, c *lib.Conn, rb []byte, tunnel *lib.Conn, flow *lib.Flow) {
- if rb != nil {
- if _, err := tunnel.SendMsg(rb, link); err != nil {
- c.Close()
- return
- }
- flow.Add(len(rb), 0)
- }
- for {
- buf := lib.BufPoolCopy.Get().([]byte)
- if n, err := c.Read(buf); err != nil {
- tunnel.SendMsg([]byte(lib.IO_EOF), link)
- break
- } else {
- if _, err := tunnel.SendMsg(buf[:n], link); err != nil {
- lib.PutBufPoolCopy(buf)
- c.Close()
- break
- }
- lib.PutBufPoolCopy(buf)
- flow.Add(n, 0)
- }
- }
- }
- func (s *server) writeConnFail(c net.Conn) {
- c.Write([]byte(lib.ConnectionFailBytes))
- c.Write(s.errorContent)
- }
- //权限认证
- func (s *server) auth(r *http.Request, c *lib.Conn, u, p string) error {
- if u != "" && p != "" && !lib.CheckAuth(r, u, p) {
- c.Write([]byte(lib.UnauthorizedBytes))
- c.Close()
- return errors.New("401 Unauthorized")
- }
- return nil
- }
|