소스 검색

fixed typo in test.go
replaced self-made http reverseproxy with a more robust and versatile one.
dynamically generate cert for client-server tls encryption

snowie2000 5 년 전
부모
커밋
a732febf3b
5개의 변경된 파일342개의 추가작업 그리고 161개의 파일을 삭제
  1. 11 9
      cmd/nps/nps.go
  2. 58 5
      lib/crypt/tls.go
  3. 135 145
      server/proxy/http.go
  4. 136 0
      server/proxy/reverseproxy.go
  5. 2 2
      server/test/test.go

+ 11 - 9
cmd/nps/nps.go

@@ -1,14 +1,6 @@
 package main
 
 import (
-	"ehang.io/nps/lib/crypt"
-	"ehang.io/nps/lib/file"
-	"ehang.io/nps/lib/install"
-	"ehang.io/nps/lib/version"
-	"ehang.io/nps/server"
-	"ehang.io/nps/server/connection"
-	"ehang.io/nps/server/tool"
-	"ehang.io/nps/web/routers"
 	"flag"
 	"log"
 	"os"
@@ -18,7 +10,16 @@ import (
 	"strings"
 	"sync"
 
+	"ehang.io/nps/lib/file"
+	"ehang.io/nps/lib/install"
+	"ehang.io/nps/lib/version"
+	"ehang.io/nps/server"
+	"ehang.io/nps/server/connection"
+	"ehang.io/nps/server/tool"
+	"ehang.io/nps/web/routers"
+
 	"ehang.io/nps/lib/common"
+	"ehang.io/nps/lib/crypt"
 	"ehang.io/nps/lib/daemon"
 	"github.com/astaxie/beego"
 	"github.com/astaxie/beego/logs"
@@ -200,7 +201,8 @@ func run() {
 	}
 	logs.Info("the version of server is %s ,allow client core version to be %s", version.VERSION, version.GetVersion())
 	connection.InitConnectionService()
-	crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key"))
+	//crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key"))
+	crypt.InitTls()
 	tool.InitAllowPort()
 	tool.StartSystemInfo()
 	go server.StartNewServer(bridgePort, task, beego.AppConfig.String("bridge_type"))

+ 58 - 5
lib/crypt/tls.go

@@ -1,22 +1,37 @@
 package crypt
 
 import (
+	"crypto/rand"
+	"crypto/rsa"
 	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
+	"log"
+	"math/big"
 	"net"
 	"os"
+	"time"
 
 	"github.com/astaxie/beego/logs"
 )
 
-var pemPath, keyPath string
+var (
+	cert tls.Certificate
+)
 
-func InitTls(pem, key string) {
-	pemPath = pem
-	keyPath = key
+func InitTls() {
+	c, k, err := generateKeyPair("NPS Corp,.Inc")
+	if err == nil {
+		cert, err = tls.X509KeyPair(c, k)
+	}
+	if err != nil {
+		log.Fatalln("Error initializing crypto certs", err)
+	}
 }
 
 func NewTlsServerConn(conn net.Conn) net.Conn {
-	cert, err := tls.LoadX509KeyPair(pemPath, keyPath)
+	var err error
 	if err != nil {
 		logs.Error(err)
 		os.Exit(0)
@@ -32,3 +47,41 @@ func NewTlsClientConn(conn net.Conn) net.Conn {
 	}
 	return tls.Client(conn, conf)
 }
+
+func generateKeyPair(CommonName string) (rawCert, rawKey []byte, err error) {
+	// Create private key and self-signed certificate
+	// Adapted from https://golang.org/src/crypto/tls/generate_cert.go
+
+	priv, err := rsa.GenerateKey(rand.Reader, 2048)
+	if err != nil {
+		return
+	}
+	validFor := time.Hour * 24 * 365 * 10 // ten years
+	notBefore := time.Now()
+	notAfter := notBefore.Add(validFor)
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+	template := x509.Certificate{
+		SerialNumber: serialNumber,
+		Subject: pkix.Name{
+			Organization: []string{"My Company Name LTD."},
+			CommonName:   CommonName,
+			Country:      []string{"US"},
+		},
+		NotBefore: notBefore,
+		NotAfter:  notAfter,
+
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		BasicConstraintsValid: true,
+	}
+	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
+	if err != nil {
+		return
+	}
+
+	rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+	rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
+
+	return
+}

+ 135 - 145
server/proxy/http.go

@@ -1,7 +1,7 @@
 package proxy
 
 import (
-	"bufio"
+	"context"
 	"crypto/tls"
 	"io"
 	"net"
@@ -10,8 +10,8 @@ import (
 	"os"
 	"path/filepath"
 	"strconv"
-	"strings"
 	"sync"
+	"time"
 
 	"ehang.io/nps/bridge"
 	"ehang.io/nps/lib/cache"
@@ -101,174 +101,164 @@ func (s *httpServer) Close() error {
 	return nil
 }
 
-func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) {
-	hijacker, ok := w.(http.Hijacker)
-	if !ok {
-		http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
-		return
-	}
-	c, _, err := hijacker.Hijack()
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusServiceUnavailable)
+func (s *httpServer) NewServer(port int, scheme string) *http.Server {
+	rProxy := NewHttpReverseProxy(s)
+	return &http.Server{
+		Addr: ":" + strconv.Itoa(port),
+		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			r.URL.Scheme = scheme
+			rProxy.ServeHTTP(w, r)
+		}),
+		// Disable HTTP/2.
+		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
 	}
-	s.handleHttp(conn.NewConn(c), r)
 }
 
-func (s *httpServer) handleHttp(c *conn.Conn, r *http.Request) {
+type HttpReverseProxy struct {
+	proxy *ReverseProxy
+
+	responseHeaderTimeout time.Duration
+}
+
+func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 	var (
 		host       *file.Host
-		target     net.Conn
-		err        error
-		connClient io.ReadWriteCloser
-		scheme     = r.URL.Scheme
-		lk         *conn.Link
 		targetAddr string
-		lenConn    *conn.LenConn
-		isReset    bool
-		wg         sync.WaitGroup
+		err        error
 	)
-	defer func() {
-		if connClient != nil {
-			connClient.Close()
-		}else {
-			s.writeConnFail(c.Conn)
-		}
-		c.Close()
-	}()
-reset:
-	if isReset {
-		host.Client.AddConn()
-	}
-	if host, err = file.GetDb().GetInfoByHost(r.Host, r); err != nil {
-		logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI)
-		return
-	}
-	if err := s.CheckFlowAndConnNum(host.Client); err != nil {
-		logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error())
+	if host, err = file.GetDb().GetInfoByHost(req.Host, req); err != nil {
+		rw.WriteHeader(http.StatusNotFound)
+		rw.Write([]byte(req.Host + " not found"))
 		return
 	}
-	if !isReset {
-		defer host.Client.AddConn()
-	}
-	if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil {
-		logs.Warn("auth error", err, r.RemoteAddr)
+	if host.Client.Cnf.U != "" && host.Client.Cnf.P != "" && !common.CheckAuth(req, host.Client.Cnf.U, host.Client.Cnf.P) {
+		rw.WriteHeader(http.StatusUnauthorized)
+		rw.Write([]byte("Unauthorized"))
 		return
 	}
 	if targetAddr, err = host.Target.GetRandomTarget(); err != nil {
-		logs.Warn(err.Error())
-		return
-	}
-	lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy)
-	if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil {
-		logs.Notice("connect to target %s error %s", lk.Host, err)
+		rw.WriteHeader(http.StatusBadGateway)
+		rw.Write([]byte("502 Bad Gateway"))
 		return
 	}
-	connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true)
+	req = req.WithContext(context.WithValue(req.Context(), "host", host))
+	req = req.WithContext(context.WithValue(req.Context(), "target", targetAddr))
+	req = req.WithContext(context.WithValue(req.Context(), "req", req))
 
-	//read from inc-client
-	go func() {
-		wg.Add(1)
-		isReset = false
-		defer connClient.Close()
-		defer func() {
-			wg.Done()
-			if !isReset {
-				c.Close()
-			}
-		}()
-		for {
-			if resp, err := http.ReadResponse(bufio.NewReader(connClient), r); err != nil || resp == nil {
+	rp.proxy.ServeHTTP(rw, req)
+}
+
+func NewHttpReverseProxy(s *httpServer) *HttpReverseProxy {
+	rp := &HttpReverseProxy{
+		responseHeaderTimeout: 30 * time.Second,
+	}
+	local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1")
+	proxy := NewReverseProxy(&httputil.ReverseProxy{
+		Director: func(r *http.Request) {
+			r.URL.Host = r.Host
+			if host, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil {
+				logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI)
 				return
 			} else {
-				//if the cache is start and the response is in the extension,store the response to the cache list
-				if s.useCache && r.URL != nil && strings.Contains(r.URL.Path, ".") {
-					b, err := httputil.DumpResponse(resp, true)
-					if err != nil {
-						return
-					}
-					c.Write(b)
-					host.Flow.Add(0, int64(len(b)))
-					s.cache.Add(filepath.Join(host.Host, r.URL.Path), b)
-				} else {
-					lenConn := conn.NewLenConn(c)
-					if err := resp.Write(lenConn); err != nil {
-						logs.Error(err)
-						return
-					}
-					host.Flow.Add(0, int64(lenConn.Len))
-				}
+				common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false)
 			}
-		}
-	}()
+		},
+		Transport: &http.Transport{
+			ResponseHeaderTimeout: rp.responseHeaderTimeout,
+			DisableKeepAlives:     true,
+			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+				var (
+					host       *file.Host
+					target     net.Conn
+					err        error
+					connClient io.ReadWriteCloser
+					targetAddr string
+					lk         *conn.Link
+				)
 
-	for {
-		//if the cache start and the request is in the cache list, return the cache
-		if s.useCache {
-			if v, ok := s.cache.Get(filepath.Join(host.Host, r.URL.Path)); ok {
-				n, err := c.Write(v.([]byte))
-				if err != nil {
-					break
-				}
-				logs.Trace("%s request, method %s, host %s, url %s, remote address %s, return cache", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String())
-				host.Flow.Add(0, int64(n))
-				//if return cache and does not create a new conn with client and Connection is not set or close, close the connection.
-				if strings.ToLower(r.Header.Get("Connection")) == "close" || strings.ToLower(r.Header.Get("Connection")) == "" {
-					break
-				}
-				goto readReq
-			}
-		}
+				r := ctx.Value("req").(*http.Request)
+				host = ctx.Value("host").(*file.Host)
+				targetAddr = ctx.Value("target").(string)
 
-		//change the host and header and set proxy setting
-		common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String(), s.addOrigin)
-		logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String(), lk.Host)
-		//write
-		lenConn = conn.NewLenConn(connClient)
-		if err := r.Write(lenConn); err != nil {
-			logs.Error(err)
-			break
-		}
-		host.Flow.Add(int64(lenConn.Len), 0)
+				lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy)
+				if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil {
+					logs.Notice("connect to target %s error %s", lk.Host, err)
+					return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the server")
+				}
+				connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true)
+				return &flowConn{
+					ReadWriteCloser: connClient,
+					fakeAddr:        local,
+					host:            host,
+				}, nil
+			},
+		},
+		ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
+			logs.Warn("do http proxy request error: %v", err)
+			rw.WriteHeader(http.StatusNotFound)
+		},
+	})
+	proxy.WebSocketDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+		var (
+			host       *file.Host
+			target     net.Conn
+			err        error
+			connClient io.ReadWriteCloser
+			targetAddr string
+			lk         *conn.Link
+		)
+		r := ctx.Value("req").(*http.Request)
+		host = ctx.Value("host").(*file.Host)
+		targetAddr = ctx.Value("target").(string)
 
-	readReq:
-		//read req from connection
-		if r, err = http.ReadRequest(bufio.NewReader(c)); err != nil {
-			break
-		}
-		r.URL.Scheme = scheme
-		//What happened ,Why one character less???
-		r.Method = resetReqMethod(r.Method)
-		if hostTmp, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil {
-			logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI)
-			break
-		} else if host != hostTmp {
-			host = hostTmp
-			isReset = true
-			connClient.Close()
-			goto reset
+		lk = conn.NewLink("tcp", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy)
+		if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil {
+			logs.Notice("connect to target %s error %s", lk.Host, err)
+			return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the target")
 		}
+		connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true)
+		return &flowConn{
+			ReadWriteCloser: connClient,
+			fakeAddr:        local,
+			host:            host,
+		}, nil
 	}
-	wg.Wait()
+	rp.proxy = proxy
+	return rp
 }
 
-func resetReqMethod(method string) string {
-	if method == "ET" {
-		return "GET"
-	}
-	if method == "OST" {
-		return "POST"
-	}
-	return method
+type flowConn struct {
+	io.ReadWriteCloser
+	fakeAddr net.Addr
+	host     *file.Host
+	flowIn   int64
+	flowOut  int64
+	once     sync.Once
 }
 
-func (s *httpServer) NewServer(port int, scheme string) *http.Server {
-	return &http.Server{
-		Addr: ":" + strconv.Itoa(port),
-		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			r.URL.Scheme = scheme
-			s.handleTunneling(w, r)
-		}),
-		// Disable HTTP/2.
-		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
-	}
+func (c *flowConn) Read(p []byte) (n int, err error) {
+	n, err = c.ReadWriteCloser.Read(p)
+	c.flowIn += int64(n)
+	return n, err
+}
+
+func (c *flowConn) Write(p []byte) (n int, err error) {
+	n, err = c.ReadWriteCloser.Write(p)
+	c.flowOut += int64(n)
+	return n, err
 }
+
+func (c *flowConn) Close() error {
+	c.once.Do(func() { c.host.Flow.Add(c.flowIn, c.flowOut) })
+	return c.ReadWriteCloser.Close()
+}
+
+func (c *flowConn) LocalAddr() net.Addr { return c.fakeAddr }
+
+func (c *flowConn) RemoteAddr() net.Addr { return c.fakeAddr }
+
+func (*flowConn) SetDeadline(t time.Time) error { return nil }
+
+func (*flowConn) SetReadDeadline(t time.Time) error { return nil }
+
+func (*flowConn) SetWriteDeadline(t time.Time) error { return nil }

+ 136 - 0
server/proxy/reverseproxy.go

@@ -0,0 +1,136 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package proxy
+
+import (
+	"context"
+	"errors"
+	"io"
+	"net"
+	"net/http"
+	"net/http/httputil"
+	"net/url"
+	"strings"
+	"sync"
+)
+
+type HTTPError struct {
+	error
+	HTTPCode int
+}
+
+func NewHTTPError(code int, errmsg string) error {
+	return &HTTPError{
+		error:    errors.New(errmsg),
+		HTTPCode: code,
+	}
+}
+
+type ReverseProxy struct {
+	*httputil.ReverseProxy
+	WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+}
+
+func IsWebsocketRequest(req *http.Request) bool {
+	containsHeader := func(name, value string) bool {
+		items := strings.Split(req.Header.Get(name), ",")
+		for _, item := range items {
+			if value == strings.ToLower(strings.TrimSpace(item)) {
+				return true
+			}
+		}
+		return false
+	}
+	return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket")
+}
+
+func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
+	rp := &ReverseProxy{
+		ReverseProxy:         httputil.NewSingleHostReverseProxy(target),
+		WebSocketDialContext: nil,
+	}
+	rp.ErrorHandler = rp.errHandler
+	return rp
+}
+
+func NewReverseProxy(orp *httputil.ReverseProxy) *ReverseProxy {
+	rp := &ReverseProxy{
+		ReverseProxy:         orp,
+		WebSocketDialContext: nil,
+	}
+	rp.ErrorHandler = rp.errHandler
+	return rp
+}
+
+func (p *ReverseProxy) errHandler(rw http.ResponseWriter, r *http.Request, e error) {
+	if e == io.EOF {
+		rw.WriteHeader(521)
+		//rw.Write(getWaitingPageContent())
+	} else {
+		if httperr, ok := e.(*HTTPError); ok {
+			rw.WriteHeader(httperr.HTTPCode)
+		} else {
+			rw.WriteHeader(http.StatusNotFound)
+		}
+		rw.Write([]byte("error: " + e.Error()))
+	}
+}
+
+func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+	if IsWebsocketRequest(req) {
+		p.serveWebSocket(rw, req)
+	} else {
+		p.ReverseProxy.ServeHTTP(rw, req)
+	}
+}
+
+func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
+	if p.WebSocketDialContext == nil {
+		rw.WriteHeader(500)
+		return
+	}
+	targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
+	if err != nil {
+		rw.WriteHeader(501)
+		return
+	}
+	defer targetConn.Close()
+
+	p.Director(req)
+
+	hijacker, ok := rw.(http.Hijacker)
+	if !ok {
+		rw.WriteHeader(500)
+		return
+	}
+	conn, _, errHijack := hijacker.Hijack()
+	if errHijack != nil {
+		rw.WriteHeader(500)
+		return
+	}
+	defer conn.Close()
+
+	req.Write(targetConn)
+	Join(conn, targetConn)
+}
+
+func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) {
+	var wait sync.WaitGroup
+	pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) {
+		defer to.Close()
+		defer from.Close()
+		defer wait.Done()
+
+		*count, _ = io.Copy(to, from)
+	}
+
+	wait.Add(2)
+	go pipe(c1, c2, &inCount)
+	go pipe(c2, c1, &outCount)
+	wait.Wait()
+	return
+}

+ 2 - 2
server/test/test.go

@@ -52,10 +52,10 @@ func TestServerConfig() {
 			if port, err := strconv.Atoi(p); err != nil {
 				log.Fatalln("get https port error", err)
 			} else {
-				if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) {
+				if beego.AppConfig.String("pemPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) {
 					log.Fatalf("ssl certFile %s is not exist", beego.AppConfig.String("pemPath"))
 				}
-				if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("ketPath"))) {
+				if beego.AppConfig.String("keyPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("keyPath"))) {
 					log.Fatalf("ssl keyFile %s is not exist", beego.AppConfig.String("pemPath"))
 				}
 				isInArr(&postTcpArr, port, "http port", "tcp")