From 5e5299a6d0addc6498be12a24451860b9e4c3445 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Mon, 24 Jan 2022 19:28:18 +0100 Subject: Add graceful shutdown and memory tracing --- src/cert_store.rs | 12 +++++-- src/http.rs | 11 +++++-- src/https.rs | 52 +++++++++++++++++++++-------- src/main.rs | 95 ++++++++++++++++++++++++++++++++++++++++++++++------- src/proxy_config.rs | 23 ++++++++++--- 5 files changed, 159 insertions(+), 34 deletions(-) (limited to 'src') diff --git a/src/cert_store.rs b/src/cert_store.rs index d561605..c1381db 100644 --- a/src/cert_store.rs +++ b/src/cert_store.rs @@ -4,7 +4,7 @@ use std::time::{Duration, Instant}; use anyhow::Result; use chrono::Utc; -use futures::TryFutureExt; +use futures::{FutureExt, TryFutureExt}; use log::*; use tokio::select; use tokio::sync::{mpsc, watch}; @@ -16,7 +16,6 @@ use rustls::sign::CertifiedKey; use crate::cert::{Cert, CertSer}; use crate::consul::*; -use crate::exit_on_err; use crate::proxy_config::*; pub struct CertStore { @@ -33,6 +32,7 @@ impl CertStore { consul: Consul, rx_proxy_config: watch::Receiver>, letsencrypt_email: String, + exit_on_err: impl Fn(anyhow::Error) + Send + 'static, ) -> Arc { let (tx, rx) = mpsc::unbounded_channel(); @@ -45,7 +45,13 @@ impl CertStore { tx_need_cert: tx, }); - tokio::spawn(cert_store.clone().certificate_loop(rx).map_err(exit_on_err)); + tokio::spawn( + cert_store + .clone() + .certificate_loop(rx) + .map_err(exit_on_err) + .then(|_| async { info!("Certificate renewal task exited") }), + ); cert_store } diff --git a/src/http.rs b/src/http.rs index 05d7440..973e77f 100644 --- a/src/http.rs +++ b/src/http.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use anyhow::Result; use log::*; +use futures::future::Future; use http::uri::Authority; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response, Server, StatusCode, Uri}; @@ -12,7 +13,11 @@ use crate::consul::Consul; const CHALLENGE_PREFIX: &str = "/.well-known/acme-challenge/"; -pub async fn serve_http(bind_addr: SocketAddr, consul: Consul) -> Result<()> { +pub async fn serve_http( + bind_addr: SocketAddr, + consul: Consul, + shutdown_signal: impl Future, +) -> Result<()> { let consul = Arc::new(consul); // For every connection, we must make a `Service` to handle all // incoming HTTP requests on said connection. @@ -30,7 +35,9 @@ pub async fn serve_http(bind_addr: SocketAddr, consul: Consul) -> Result<()> { }); info!("Listening on http://{}", bind_addr); - let server = Server::bind(&bind_addr).serve(make_svc); + let server = Server::bind(&bind_addr) + .serve(make_svc) + .with_graceful_shutdown(shutdown_signal); server.await?; diff --git a/src/https.rs b/src/https.rs index 34e3f85..6b1f5e7 100644 --- a/src/https.rs +++ b/src/https.rs @@ -7,6 +7,7 @@ use log::*; use accept_encoding_fork::Encoding; use async_compression::tokio::bufread::*; +use futures::stream::FuturesUnordered; use futures::StreamExt; use futures::TryStreamExt; use http::header::{HeaderName, HeaderValue}; @@ -15,6 +16,7 @@ use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{header, Body, Request, Response, StatusCode}; use tokio::net::TcpListener; +use tokio::select; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; use tokio_util::io::{ReaderStream, StreamReader}; @@ -33,6 +35,7 @@ pub async fn serve_https( config: HttpsConfig, cert_store: Arc, rx_proxy_config: watch::Receiver>, + mut must_exit: watch::Receiver, ) -> Result<()> { let config = Arc::new(config); @@ -47,28 +50,43 @@ pub async fn serve_https( info!("Starting to serve on https://{}.", config.bind_addr); let tcp = TcpListener::bind(config.bind_addr).await?; - loop { - let (socket, remote_addr) = tcp.accept().await?; + let mut connections = FuturesUnordered::new(); + + while !*must_exit.borrow() { + let (socket, remote_addr) = select! { + a = tcp.accept() => a?, + _ = connections.next() => continue, + _ = must_exit.changed() => continue, + }; let rx_proxy_config = rx_proxy_config.clone(); let tls_acceptor = tls_acceptor.clone(); let config = config.clone(); - tokio::spawn(async move { + let mut must_exit_2 = must_exit.clone(); + let conn = tokio::spawn(async move { match tls_acceptor.accept(socket).await { Ok(stream) => { debug!("TLS handshake was successfull"); - let http_result = Http::new() - .serve_connection( - stream, - service_fn(move |req: Request| { - let https_config = config.clone(); - let proxy_config: Arc = - rx_proxy_config.borrow().clone(); - handle_outer(remote_addr, req, https_config, proxy_config) - }), + let http_conn = Http::new().serve_connection( + stream, + service_fn(move |req: Request| { + let https_config = config.clone(); + let proxy_config: Arc = rx_proxy_config.borrow().clone(); + handle_outer(remote_addr, req, https_config, proxy_config) + }), + ); + tokio::pin!(http_conn); + let http_result = loop { + select! ( + r = &mut http_conn => break r, + _ = must_exit_2.changed() => { + if *must_exit_2.borrow() { + http_conn.as_mut().graceful_shutdown(); + } + } ) - .await; + }; if let Err(http_err) = http_result { warn!("HTTP error: {}", http_err); } @@ -76,7 +94,15 @@ pub async fn serve_https( Err(e) => warn!("Error in TLS connection: {}", e), } }); + connections.push(conn); } + + info!("HTTPS server shutting down, draining remaining connections..."); + while !connections.is_empty() { + let _ = connections.next().await; + } + + Ok(()) } async fn handle_outer( diff --git a/src/main.rs b/src/main.rs index dcbd187..dada7e7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,14 @@ #[macro_use] extern crate anyhow; -use futures::TryFutureExt; +use log::*; +use std::sync::Arc; + +use futures::{FutureExt, TryFutureExt}; use std::net::SocketAddr; use structopt::StructOpt; +use tokio::select; +use tokio::sync::watch; mod cert; mod cert_store; @@ -14,7 +19,11 @@ mod proxy_config; mod reverse_proxy; mod tls_util; -use log::*; +use proxy_config::ProxyConfig; + +#[cfg(feature = "dhat-heap")] +#[global_allocator] +static ALLOC: dhat::Alloc = dhat::Alloc; #[derive(StructOpt, Debug)] #[structopt(name = "tricot")] @@ -86,6 +95,9 @@ struct Opt { #[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() { + #[cfg(feature = "dhat-heap")] + let _profiler = dhat::Profiler::new_heap(); + if std::env::var("RUST_LOG").is_err() { std::env::set_var("RUST_LOG", "tricot=info") } @@ -101,6 +113,12 @@ async fn main() { info!("Starting Tricot"); + let (exit_signal, provoke_exit) = watch_ctrl_c(); + let exit_on_err = move |err: anyhow::Error| { + error!("Error: {}", err); + let _ = provoke_exit.send(true); + }; + let consul_config = consul::ConsulConfig { addr: opt.consul_addr.clone(), ca_cert: opt.consul_ca_cert.clone(), @@ -110,15 +128,25 @@ async fn main() { let consul = consul::Consul::new(consul_config, &opt.consul_kv_prefix, &opt.node_name) .expect("Error creating Consul client"); - let mut rx_proxy_config = proxy_config::spawn_proxy_config_task(consul.clone()); + let rx_proxy_config = + proxy_config::spawn_proxy_config_task(consul.clone(), exit_signal.clone()); let cert_store = cert_store::CertStore::new( consul.clone(), rx_proxy_config.clone(), opt.letsencrypt_email.clone(), + exit_on_err.clone(), ); - tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err)); + let http_task = tokio::spawn( + http::serve_http( + opt.http_bind_addr, + consul.clone(), + wait_from(exit_signal.clone()), + ) + .map_err(exit_on_err.clone()) + .then(|_| async { info!("HTTP server exited") }), + ); let https_config = https::HttpsConfig { bind_addr: opt.https_bind_addr, @@ -129,12 +157,38 @@ async fn main() { .map(|x| x.to_string()) .collect(), }; - tokio::spawn( - https::serve_https(https_config, cert_store.clone(), rx_proxy_config.clone()) - .map_err(exit_on_err), + + let https_task = tokio::spawn( + https::serve_https( + https_config, + cert_store.clone(), + rx_proxy_config.clone(), + exit_signal.clone(), + ) + .map_err(exit_on_err.clone()) + .then(|_| async { info!("HTTPS server exited") }), ); - while rx_proxy_config.changed().await.is_ok() { + let dump_task = tokio::spawn(dump_config_on_change(rx_proxy_config, exit_signal.clone())); + + let _ = http_task.await.expect("Tokio task await failure"); + let _ = https_task.await.expect("Tokio task await failure"); + let _ = dump_task.await.expect("Tokio task await failure"); +} + +async fn dump_config_on_change( + mut rx_proxy_config: watch::Receiver>, + mut must_exit: watch::Receiver, +) { + while !*must_exit.borrow() { + select!( + c = rx_proxy_config.changed() => { + if !c.is_ok() { + break; + } + } + _ = must_exit.changed() => continue, + ); println!("---- PROXY CONFIGURATION ----"); for ent in rx_proxy_config.borrow().entries.iter() { println!(" {}", ent); @@ -143,7 +197,26 @@ async fn main() { } } -fn exit_on_err(e: anyhow::Error) { - error!("{}", e); - std::process::exit(1); +/// Creates a watch that contains `false`, and that changes +/// to `true` when a Ctrl+C signal is received. +pub fn watch_ctrl_c() -> (watch::Receiver, Arc>) { + let (send_cancel, watch_cancel) = watch::channel(false); + let send_cancel = Arc::new(send_cancel); + let send_cancel_2 = send_cancel.clone(); + tokio::spawn(async move { + tokio::signal::ctrl_c() + .await + .expect("failed to install CTRL+C signal handler"); + info!("Received CTRL+C, shutting down."); + send_cancel.send(true).unwrap(); + }); + (watch_cancel, send_cancel_2) +} + +async fn wait_from(mut chan: watch::Receiver) { + while !*chan.borrow() { + if chan.changed().await.is_err() { + return; + } + } } diff --git a/src/proxy_config.rs b/src/proxy_config.rs index 4add98c..e380885 100644 --- a/src/proxy_config.rs +++ b/src/proxy_config.rs @@ -9,7 +9,7 @@ use futures::future::BoxFuture; use futures::stream::{FuturesUnordered, StreamExt}; use log::*; -use tokio::{sync::watch, time::sleep}; +use tokio::{select, sync::watch, time::sleep}; use crate::consul::*; @@ -231,7 +231,10 @@ struct NodeWatchState { retries: u32, } -pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver> { +pub fn spawn_proxy_config_task( + consul: Consul, + mut must_exit: watch::Receiver, +) -> watch::Receiver> { let (tx, rx) = watch::channel(Arc::new(ProxyConfig { entries: Vec::new(), })); @@ -244,8 +247,13 @@ pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver ln, + _ = must_exit.changed() => continue, + }; + + match list_nodes { Ok(consul_nodes) => { info!("Watched consul nodes: {:?}", consul_nodes); for consul_node in consul_nodes { @@ -271,7 +279,12 @@ pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver) = match watches.next().await { + let next_watch = select! { + nw = watches.next() => nw, + _ = must_exit.changed() => continue, + }; + + let (node, res): (String, Result<_>) = match next_watch { Some(v) => v, None => { warn!("No nodes currently watched in proxy_config.rs"); -- cgit v1.2.3