server.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. package grace
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "fmt"
  6. "io/ioutil"
  7. "log"
  8. "net"
  9. "net/http"
  10. "os"
  11. "os/exec"
  12. "os/signal"
  13. "strings"
  14. "sync"
  15. "syscall"
  16. "time"
  17. )
  // Server embeds *http.Server and adds graceful shutdown and restart support:
// it waits for in-flight connections before terminating, reacts to OS signals,
// and can hand its listening socket to a re-executed copy of the process.
type Server struct {
	*http.Server
	// GraceListener is the listener handed to http.Server.Serve: the
	// graceListener itself for plain HTTP, or a TLS listener wrapping
	// tlsInnerListener for the TLS variants.
	GraceListener net.Listener
	// SignalHooks holds user callbacks indexed first by PreSignal/PostSignal
	// and then by signal; populated by RegisterSignalHook and run by
	// signalHooks. NOTE(review): the inner maps must already exist (created
	// elsewhere) — RegisterSignalHook writes into them directly.
	SignalHooks map[int]map[os.Signal][]func()
	// tlsInnerListener is the graceListener underneath the TLS wrapper; fork
	// uses it to obtain the socket file when GraceListener is not a
	// *graceListener.
	tlsInnerListener *graceListener
	// wg is waited on by Serve before it reports StateTerminate and is
	// force-drained by serverTimeout. Presumably incremented per accepted
	// connection by graceListener — confirm against its implementation.
	wg sync.WaitGroup
	// sigChan is registered with signal.Notify in handleSignals.
	// NOTE(review): signal.Notify panics on a nil channel, so this must be
	// created before ListenAndServe* is called (not visible here).
	sigChan chan os.Signal
	// isChild reports whether this process was started by fork: getListener
	// then reuses an inherited descriptor instead of opening a new socket, and
	// the ListenAndServe* methods terminate the parent process.
	isChild bool
	// state is one of StateRunning, StateShuttingDown or StateTerminate.
	// NOTE(review): read and written from several goroutines (Serve,
	// shutdown, serverTimeout) without synchronization — a data race.
	state uint8
	// Network is the network name passed to net.Listen by getListener.
	Network string
}
  30. // Serve accepts incoming connections on the Listener l,
  31. // creating a new service goroutine for each.
  32. // The service goroutines read requests and then call srv.Handler to reply to them.
  33. func (srv *Server) Serve() (err error) {
  34. srv.state = StateRunning
  35. err = srv.Server.Serve(srv.GraceListener)
  36. log.Println(syscall.Getpid(), "Waiting for connections to finish...")
  37. srv.wg.Wait()
  38. srv.state = StateTerminate
  39. return
  40. }
  41. // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
  42. // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
  43. // used.
  44. func (srv *Server) ListenAndServe() (err error) {
  45. addr := srv.Addr
  46. if addr == "" {
  47. addr = ":http"
  48. }
  49. go srv.handleSignals()
  50. l, err := srv.getListener(addr)
  51. if err != nil {
  52. log.Println(err)
  53. return err
  54. }
  55. srv.GraceListener = newGraceListener(l, srv)
  56. if srv.isChild {
  57. process, err := os.FindProcess(os.Getppid())
  58. if err != nil {
  59. log.Println(err)
  60. return err
  61. }
  62. err = process.Signal(syscall.SIGTERM)
  63. if err != nil {
  64. return err
  65. }
  66. }
  67. log.Println(os.Getpid(), srv.Addr)
  68. return srv.Serve()
  69. }
  70. // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
  71. // Serve to handle requests on incoming TLS connections.
  72. //
  73. // Filenames containing a certificate and matching private key for the server must
  74. // be provided. If the certificate is signed by a certificate authority, the
  75. // certFile should be the concatenation of the server's certificate followed by the
  76. // CA's certificate.
  77. //
  78. // If srv.Addr is blank, ":https" is used.
  79. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
  80. addr := srv.Addr
  81. if addr == "" {
  82. addr = ":https"
  83. }
  84. if srv.TLSConfig == nil {
  85. srv.TLSConfig = &tls.Config{}
  86. }
  87. if srv.TLSConfig.NextProtos == nil {
  88. srv.TLSConfig.NextProtos = []string{"http/1.1"}
  89. }
  90. srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
  91. srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  92. if err != nil {
  93. return
  94. }
  95. go srv.handleSignals()
  96. l, err := srv.getListener(addr)
  97. if err != nil {
  98. log.Println(err)
  99. return err
  100. }
  101. srv.tlsInnerListener = newGraceListener(l, srv)
  102. srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
  103. if srv.isChild {
  104. process, err := os.FindProcess(os.Getppid())
  105. if err != nil {
  106. log.Println(err)
  107. return err
  108. }
  109. err = process.Signal(syscall.SIGTERM)
  110. if err != nil {
  111. return err
  112. }
  113. }
  114. log.Println(os.Getpid(), srv.Addr)
  115. return srv.Serve()
  116. }
  117. // ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
  118. // Serve to handle requests on incoming mutual TLS connections.
  119. func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
  120. addr := srv.Addr
  121. if addr == "" {
  122. addr = ":https"
  123. }
  124. if srv.TLSConfig == nil {
  125. srv.TLSConfig = &tls.Config{}
  126. }
  127. if srv.TLSConfig.NextProtos == nil {
  128. srv.TLSConfig.NextProtos = []string{"http/1.1"}
  129. }
  130. srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
  131. srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  132. if err != nil {
  133. return
  134. }
  135. srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
  136. pool := x509.NewCertPool()
  137. data, err := ioutil.ReadFile(trustFile)
  138. if err != nil {
  139. log.Println(err)
  140. return err
  141. }
  142. pool.AppendCertsFromPEM(data)
  143. srv.TLSConfig.ClientCAs = pool
  144. log.Println("Mutual HTTPS")
  145. go srv.handleSignals()
  146. l, err := srv.getListener(addr)
  147. if err != nil {
  148. log.Println(err)
  149. return err
  150. }
  151. srv.tlsInnerListener = newGraceListener(l, srv)
  152. srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
  153. if srv.isChild {
  154. process, err := os.FindProcess(os.Getppid())
  155. if err != nil {
  156. log.Println(err)
  157. return err
  158. }
  159. err = process.Kill()
  160. if err != nil {
  161. return err
  162. }
  163. }
  164. log.Println(os.Getpid(), srv.Addr)
  165. return srv.Serve()
  166. }
  167. // getListener either opens a new socket to listen on, or takes the acceptor socket
  168. // it got passed when restarted.
  169. func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
  170. if srv.isChild {
  171. var ptrOffset uint
  172. if len(socketPtrOffsetMap) > 0 {
  173. ptrOffset = socketPtrOffsetMap[laddr]
  174. log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
  175. }
  176. f := os.NewFile(uintptr(3+ptrOffset), "")
  177. l, err = net.FileListener(f)
  178. if err != nil {
  179. err = fmt.Errorf("net.FileListener error: %v", err)
  180. return
  181. }
  182. } else {
  183. l, err = net.Listen(srv.Network, laddr)
  184. if err != nil {
  185. err = fmt.Errorf("net.Listen error: %v", err)
  186. return
  187. }
  188. }
  189. return
  190. }
  191. // handleSignals listens for os Signals and calls any hooked in function that the
  192. // user had registered with the signal.
  193. func (srv *Server) handleSignals() {
  194. var sig os.Signal
  195. signal.Notify(
  196. srv.sigChan,
  197. hookableSignals...,
  198. )
  199. pid := syscall.Getpid()
  200. for {
  201. sig = <-srv.sigChan
  202. srv.signalHooks(PreSignal, sig)
  203. switch sig {
  204. case syscall.SIGHUP:
  205. log.Println(pid, "Received SIGHUP. forking.")
  206. err := srv.fork()
  207. if err != nil {
  208. log.Println("Fork err:", err)
  209. }
  210. case syscall.SIGINT:
  211. log.Println(pid, "Received SIGINT.")
  212. srv.shutdown()
  213. case syscall.SIGTERM:
  214. log.Println(pid, "Received SIGTERM.")
  215. srv.shutdown()
  216. default:
  217. log.Printf("Received %v: nothing i care about...\n", sig)
  218. }
  219. srv.signalHooks(PostSignal, sig)
  220. }
  221. }
  222. func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
  223. if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
  224. return
  225. }
  226. for _, f := range srv.SignalHooks[ppFlag][sig] {
  227. f()
  228. }
  229. }
  230. // shutdown closes the listener so that no new connections are accepted. it also
  231. // starts a goroutine that will serverTimeout (stop all running requests) the server
  232. // after DefaultTimeout.
  233. func (srv *Server) shutdown() {
  234. if srv.state != StateRunning {
  235. return
  236. }
  237. srv.state = StateShuttingDown
  238. if DefaultTimeout >= 0 {
  239. go srv.serverTimeout(DefaultTimeout)
  240. }
  241. err := srv.GraceListener.Close()
  242. if err != nil {
  243. log.Println(syscall.Getpid(), "Listener.Close() error:", err)
  244. } else {
  245. log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
  246. }
  247. }
  248. // serverTimeout forces the server to shutdown in a given timeout - whether it
  249. // finished outstanding requests or not. if Read/WriteTimeout are not set or the
  250. // max header size is very big a connection could hang
  251. func (srv *Server) serverTimeout(d time.Duration) {
  252. defer func() {
  253. if r := recover(); r != nil {
  254. log.Println("WaitGroup at 0", r)
  255. }
  256. }()
  257. if srv.state != StateShuttingDown {
  258. return
  259. }
  260. time.Sleep(d)
  261. log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
  262. for {
  263. if srv.state == StateTerminate {
  264. break
  265. }
  266. srv.wg.Done()
  267. }
  268. }
  269. func (srv *Server) fork() (err error) {
  270. regLock.Lock()
  271. defer regLock.Unlock()
  272. if runningServersForked {
  273. return
  274. }
  275. runningServersForked = true
  276. var files = make([]*os.File, len(runningServers))
  277. var orderArgs = make([]string, len(runningServers))
  278. for _, srvPtr := range runningServers {
  279. switch srvPtr.GraceListener.(type) {
  280. case *graceListener:
  281. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
  282. default:
  283. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
  284. }
  285. orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
  286. }
  287. log.Println(files)
  288. path := os.Args[0]
  289. var args []string
  290. if len(os.Args) > 1 {
  291. for _, arg := range os.Args[1:] {
  292. if arg == "-graceful" {
  293. break
  294. }
  295. args = append(args, arg)
  296. }
  297. }
  298. args = append(args, "-graceful")
  299. if len(runningServers) > 1 {
  300. args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
  301. log.Println(args)
  302. }
  303. cmd := exec.Command(path, args...)
  304. cmd.Stdout = os.Stdout
  305. cmd.Stderr = os.Stderr
  306. cmd.ExtraFiles = files
  307. err = cmd.Start()
  308. if err != nil {
  309. log.Fatalf("Restart: Failed to launch, error: %v", err)
  310. }
  311. return
  312. }
  313. // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
  314. func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
  315. if ppFlag != PreSignal && ppFlag != PostSignal {
  316. err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal")
  317. return
  318. }
  319. for _, s := range hookableSignals {
  320. if s == sig {
  321. srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
  322. return
  323. }
  324. }
  325. err = fmt.Errorf("Signal '%v' is not supported", sig)
  326. return
  327. }