浏览代码

Revert http reverse proxy changes

snowie2000 5 年之前
父节点
当前提交
16be6d1b55
共有 2 个文件被更改,包括 147 次插入268 次删除
  1. 147 132
      server/proxy/http.go
  2. 0 136
      server/proxy/reverseproxy.go

+ 147 - 132
server/proxy/http.go

@@ -1,7 +1,7 @@
 package proxy
 
 import (
-	"context"
+	"bufio"
 	"crypto/tls"
 	"io"
 	"net"
@@ -10,8 +10,8 @@ import (
 	"os"
 	"path/filepath"
 	"strconv"
+	"strings"
 	"sync"
-	"time"
 
 	"ehang.io/nps/bridge"
 	"ehang.io/nps/lib/cache"
@@ -101,159 +101,174 @@ func (s *httpServer) Close() error {
 	return nil
 }
 
-func (s *httpServer) NewServer(port int, scheme string) *http.Server {
-	rProxy := NewHttpReverseProxy(s)
-	return &http.Server{
-		Addr: ":" + strconv.Itoa(port),
-		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			r.URL.Scheme = scheme
-			rProxy.ServeHTTP(w, r)
-		}),
-		// Disable HTTP/2.
-		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
+func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) {
+	hijacker, ok := w.(http.Hijacker)
+	if !ok {
+		http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
+		return
 	}
+	c, _, err := hijacker.Hijack()
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusServiceUnavailable)
+	}
+	s.handleHttp(conn.NewConn(c), r)
 }
 
-type HttpReverseProxy struct {
-	proxy *ReverseProxy
-
-	responseHeaderTimeout time.Duration
-}
-
-func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+func (s *httpServer) handleHttp(c *conn.Conn, r *http.Request) {
 	var (
 		host       *file.Host
-		targetAddr string
+		target     net.Conn
 		err        error
+		connClient io.ReadWriteCloser
+		scheme     = r.URL.Scheme
+		lk         *conn.Link
+		targetAddr string
+		lenConn    *conn.LenConn
+		isReset    bool
+		wg         sync.WaitGroup
 	)
-	if host, err = file.GetDb().GetInfoByHost(req.Host, req); err != nil {
-		rw.WriteHeader(http.StatusNotFound)
-		rw.Write([]byte(req.Host + " not found"))
+	defer func() {
+		if connClient != nil {
+			connClient.Close()
+		} else {
+			s.writeConnFail(c.Conn)
+		}
+		c.Close()
+	}()
+reset:
+	if isReset {
+		host.Client.AddConn()
+	}
+	if host, err = file.GetDb().GetInfoByHost(r.Host, r); err != nil {
+		logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI)
 		return
 	}
-	if host.Client.Cnf.U != "" && host.Client.Cnf.P != "" && !common.CheckAuth(req, host.Client.Cnf.U, host.Client.Cnf.P) {
-		rw.WriteHeader(http.StatusUnauthorized)
-		rw.Write([]byte("Unauthorized"))
+	if err := s.CheckFlowAndConnNum(host.Client); err != nil {
+		logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error())
+		return
+	}
+	if !isReset {
+		defer host.Client.AddConn()
+	}
+	if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil {
+		logs.Warn("auth error", err, r.RemoteAddr)
 		return
 	}
 	if targetAddr, err = host.Target.GetRandomTarget(); err != nil {
-		rw.WriteHeader(http.StatusBadGateway)
-		rw.Write([]byte("502 Bad Gateway"))
+		logs.Warn(err.Error())
 		return
 	}
-	req = req.WithContext(context.WithValue(req.Context(), "host", host))
-	req = req.WithContext(context.WithValue(req.Context(), "target", targetAddr))
-	req = req.WithContext(context.WithValue(req.Context(), "req", req))
-
-	rp.proxy.ServeHTTP(rw, req)
-}
-
-func NewHttpReverseProxy(s *httpServer) *HttpReverseProxy {
-	rp := &HttpReverseProxy{
-		responseHeaderTimeout: 30 * time.Second,
+	lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy)
+	if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil {
+		logs.Notice("connect to target %s error %s", lk.Host, err)
+		return
 	}
-	local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1")
-	proxy := NewReverseProxy(&httputil.ReverseProxy{
-		Director: func(r *http.Request) {
-			host := r.Context().Value("host").(*file.Host)
-			common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false)
-		},
-		Transport: &http.Transport{
-			ResponseHeaderTimeout: rp.responseHeaderTimeout,
-			DisableKeepAlives:     true,
-			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-				var (
-					host       *file.Host
-					target     net.Conn
-					err        error
-					connClient io.ReadWriteCloser
-					targetAddr string
-					lk         *conn.Link
-				)
-
-				r := ctx.Value("req").(*http.Request)
-				host = ctx.Value("host").(*file.Host)
-				targetAddr = ctx.Value("target").(string)
+	connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true)
 
-				lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy)
-				if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil {
-					logs.Notice("connect to target %s error %s", lk.Host, err)
-					return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the server")
+	//read from inc-client
+	go func() {
+		wg.Add(1)
+		isReset = false
+		defer connClient.Close()
+		defer func() {
+			wg.Done()
+			if !isReset {
+				c.Close()
+			}
+		}()
+		for {
+			if resp, err := http.ReadResponse(bufio.NewReader(connClient), r); err != nil || resp == nil {
+				return
+			} else {
+				//if the cache is start and the response is in the extension,store the response to the cache list
+				if s.useCache && r.URL != nil && strings.Contains(r.URL.Path, ".") {
+					b, err := httputil.DumpResponse(resp, true)
+					if err != nil {
+						return
+					}
+					c.Write(b)
+					host.Flow.Add(0, int64(len(b)))
+					s.cache.Add(filepath.Join(host.Host, r.URL.Path), b)
+				} else {
+					lenConn := conn.NewLenConn(c)
+					if err := resp.Write(lenConn); err != nil {
+						logs.Error(err)
+						return
+					}
+					host.Flow.Add(0, int64(lenConn.Len))
 				}
-				connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true)
-				return &flowConn{
-					ReadWriteCloser: connClient,
-					fakeAddr:        local,
-					host:            host,
-				}, nil
-			},
-		},
-		ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
-			logs.Warn("do http proxy request error: %v", err)
-			rw.WriteHeader(http.StatusNotFound)
-		},
-	})
-	proxy.WebSocketDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
-		var (
-			host       *file.Host
-			target     net.Conn
-			err        error
-			connClient io.ReadWriteCloser
-			targetAddr string
-			lk         *conn.Link
-		)
-		r := ctx.Value("req").(*http.Request)
-		host = ctx.Value("host").(*file.Host)
-		targetAddr = ctx.Value("target").(string)
+			}
+		}
+	}()
 
-		lk = conn.NewLink("tcp", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy)
-		if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil {
-			logs.Notice("connect to target %s error %s", lk.Host, err)
-			return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the target")
+	for {
+		//if the cache start and the request is in the cache list, return the cache
+		if s.useCache {
+			if v, ok := s.cache.Get(filepath.Join(host.Host, r.URL.Path)); ok {
+				n, err := c.Write(v.([]byte))
+				if err != nil {
+					break
+				}
+				logs.Trace("%s request, method %s, host %s, url %s, remote address %s, return cache", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String())
+				host.Flow.Add(0, int64(n))
+				//if return cache and does not create a new conn with client and Connection is not set or close, close the connection.
+				if strings.ToLower(r.Header.Get("Connection")) == "close" || strings.ToLower(r.Header.Get("Connection")) == "" {
+					break
+				}
+				goto readReq
+			}
 		}
-		connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true)
-		return &flowConn{
-			ReadWriteCloser: connClient,
-			fakeAddr:        local,
-			host:            host,
-		}, nil
-	}
-	rp.proxy = proxy
-	return rp
-}
 
-type flowConn struct {
-	io.ReadWriteCloser
-	fakeAddr net.Addr
-	host     *file.Host
-	flowIn   int64
-	flowOut  int64
-	once     sync.Once
-}
+		//change the host and header and set proxy setting
+		common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String(), s.addOrigin)
+		logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String(), lk.Host)
+		//write
+		lenConn = conn.NewLenConn(connClient)
+		if err := r.Write(lenConn); err != nil {
+			logs.Error(err)
+			break
+		}
+		host.Flow.Add(int64(lenConn.Len), 0)
 
-func (c *flowConn) Read(p []byte) (n int, err error) {
-	n, err = c.ReadWriteCloser.Read(p)
-	c.flowIn += int64(n)
-	return n, err
+	readReq:
+		//read req from connection
+		if r, err = http.ReadRequest(bufio.NewReader(c)); err != nil {
+			break
+		}
+		r.URL.Scheme = scheme
+		//What happened ,Why one character less???
+		r.Method = resetReqMethod(r.Method)
+		if hostTmp, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil {
+			logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI)
+			break
+		} else if host != hostTmp {
+			host = hostTmp
+			isReset = true
+			connClient.Close()
+			goto reset
+		}
+	}
+	wg.Wait()
 }
 
-func (c *flowConn) Write(p []byte) (n int, err error) {
-	n, err = c.ReadWriteCloser.Write(p)
-	c.flowOut += int64(n)
-	return n, err
+func resetReqMethod(method string) string {
+	if method == "ET" {
+		return "GET"
+	}
+	if method == "OST" {
+		return "POST"
+	}
+	return method
 }
 
-func (c *flowConn) Close() error {
-	c.once.Do(func() { c.host.Flow.Add(c.flowIn, c.flowOut) })
-	return c.ReadWriteCloser.Close()
+func (s *httpServer) NewServer(port int, scheme string) *http.Server {
+	return &http.Server{
+		Addr: ":" + strconv.Itoa(port),
+		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			r.URL.Scheme = scheme
+			s.handleTunneling(w, r)
+		}),
+		// Disable HTTP/2.
+		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
+	}
 }
-
-func (c *flowConn) LocalAddr() net.Addr { return c.fakeAddr }
-
-func (c *flowConn) RemoteAddr() net.Addr { return c.fakeAddr }
-
-func (*flowConn) SetDeadline(t time.Time) error { return nil }
-
-func (*flowConn) SetReadDeadline(t time.Time) error { return nil }
-
-func (*flowConn) SetWriteDeadline(t time.Time) error { return nil }

+ 0 - 136
server/proxy/reverseproxy.go

@@ -1,136 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// HTTP reverse proxy handler
-
-package proxy
-
-import (
-	"context"
-	"errors"
-	"io"
-	"net"
-	"net/http"
-	"net/http/httputil"
-	"net/url"
-	"strings"
-	"sync"
-)
-
-type HTTPError struct {
-	error
-	HTTPCode int
-}
-
-func NewHTTPError(code int, errmsg string) error {
-	return &HTTPError{
-		error:    errors.New(errmsg),
-		HTTPCode: code,
-	}
-}
-
-type ReverseProxy struct {
-	*httputil.ReverseProxy
-	WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
-}
-
-func IsWebsocketRequest(req *http.Request) bool {
-	containsHeader := func(name, value string) bool {
-		items := strings.Split(req.Header.Get(name), ",")
-		for _, item := range items {
-			if value == strings.ToLower(strings.TrimSpace(item)) {
-				return true
-			}
-		}
-		return false
-	}
-	return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket")
-}
-
-func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
-	rp := &ReverseProxy{
-		ReverseProxy:         httputil.NewSingleHostReverseProxy(target),
-		WebSocketDialContext: nil,
-	}
-	rp.ErrorHandler = rp.errHandler
-	return rp
-}
-
-func NewReverseProxy(orp *httputil.ReverseProxy) *ReverseProxy {
-	rp := &ReverseProxy{
-		ReverseProxy:         orp,
-		WebSocketDialContext: nil,
-	}
-	rp.ErrorHandler = rp.errHandler
-	return rp
-}
-
-func (p *ReverseProxy) errHandler(rw http.ResponseWriter, r *http.Request, e error) {
-	if e == io.EOF {
-		rw.WriteHeader(521)
-		//rw.Write(getWaitingPageContent())
-	} else {
-		if httperr, ok := e.(*HTTPError); ok {
-			rw.WriteHeader(httperr.HTTPCode)
-		} else {
-			rw.WriteHeader(http.StatusNotFound)
-		}
-		rw.Write([]byte("error: " + e.Error()))
-	}
-}
-
-func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
-	if IsWebsocketRequest(req) {
-		p.serveWebSocket(rw, req)
-	} else {
-		p.ReverseProxy.ServeHTTP(rw, req)
-	}
-}
-
-func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
-	if p.WebSocketDialContext == nil {
-		rw.WriteHeader(500)
-		return
-	}
-	targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
-	if err != nil {
-		rw.WriteHeader(501)
-		return
-	}
-	defer targetConn.Close()
-
-	p.Director(req)
-
-	hijacker, ok := rw.(http.Hijacker)
-	if !ok {
-		rw.WriteHeader(500)
-		return
-	}
-	conn, _, errHijack := hijacker.Hijack()
-	if errHijack != nil {
-		rw.WriteHeader(500)
-		return
-	}
-	defer conn.Close()
-
-	req.Write(targetConn)
-	Join(conn, targetConn)
-}
-
-func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) {
-	var wait sync.WaitGroup
-	pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) {
-		defer to.Close()
-		defer from.Close()
-		defer wait.Done()
-
-		*count, _ = io.Copy(to, from)
-	}
-
-	wait.Add(2)
-	go pipe(c1, c2, &inCount)
-	go pipe(c2, c1, &outCount)
-	wait.Wait()
-	return
-}