aboutsummaryrefslogblamecommitdiff
path: root/src/main.rs
blob: 051530596931bfd4d340fce1f56dc9cf2b7157f9 (plain) (tree)
1
2
3
4
5
6
7
8
9


                    
                               
                   
                       
               

                                       
                         
                         

                       
 

               
         
          
            
                 
                  
             
 
                            




                                        
 



                                    


                                           
                                                       
          

                                



                                                                            



                                                                                            







                                                                                    
                                                         




                                                




                                                                                            

                                        




                                              


                                         




                                               
                                        
 



                                                                                  
                                                                 
                                                                                  
                                      








                                                                                    
                                                                                                                                                                       

                                        





                                                       

 
                                                            
                 


                                                   
                                               
                                                            
         



                                                                                           
 





                                                     

                                   
                                 
                                    
 





                                                         

                                                                                 
                                            

                                                    
                                                            



                                                            
                                                                              
                                                        




                                                                    
 

                                                    
                                      

                                              
                                    
          





                                                                                                                                               
 






                                                                            








                                                                 





                                                           
                                   

                                               
                                            
          









                                                                  
          
 

                                                                                                  



                                                              





                                                               

                                                                             


                                                          
                                               




                                                            





                                                                             
                                                                                                   



                                                     



                                                                                                             


                                                                                               






                                                                




                                      
                 

         
 





















                                                                            
 
#[macro_use]
extern crate anyhow;

use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Instant;
use tracing::*;

use futures::{FutureExt, TryFutureExt};
use std::net::SocketAddr;
use structopt::StructOpt;
use tokio::select;
use tokio::sync::watch;

mod cert;
mod cert_store;
mod http;
mod https;
mod metrics;
mod proxy_config;
mod reverse_proxy;
mod tls_util;

pub use df_consul as consul;
use proxy_config::ProxyConfig;

#[cfg(feature = "dhat-heap")]
#[global_allocator]
static ALLOC: dhat::Alloc = dhat::Alloc;

#[derive(StructOpt, Debug)]
#[structopt(name = "tricot")]
struct Opt {
	/// Address of consul server
	#[structopt(
		long = "consul-addr",
		env = "TRICOT_CONSUL_HOST",
		default_value = "http://127.0.0.1:8500"
	)]
	pub consul_addr: String,

	/// CA certificate for Consul server with TLS
	#[structopt(long = "consul-ca-cert", env = "TRICOT_CONSUL_CA_CERT")]
	pub consul_ca_cert: Option<String>,

	/// Skip TLS verification for Consul
	#[structopt(long = "consul-tls-skip-verify", env = "TRICOT_CONSUL_TLS_SKIP_VERIFY")]
	pub consul_tls_skip_verify: bool,

	/// Client certificate for Consul server with TLS
	#[structopt(long = "consul-client-cert", env = "TRICOT_CONSUL_CLIENT_CERT")]
	pub consul_client_cert: Option<String>,

	/// Client key for Consul server with TLS
	#[structopt(long = "consul-client-key", env = "TRICOT_CONSUL_CLIENT_KEY")]
	pub consul_client_key: Option<String>,

	/// Prefix of Tricot's entries in Consul KV space
	#[structopt(
		long = "consul-kv-prefix",
		env = "TRICOT_CONSUL_KV_PREFIX",
		default_value = "tricot/"
	)]
	pub consul_kv_prefix: String,

	/// Node name
	#[structopt(long = "node-name", env = "TRICOT_NODE_NAME", default_value = "<none>")]
	pub node_name: String,

	/// Bind address for HTTP server
	#[structopt(
		long = "http-bind-addr",
		env = "TRICOT_HTTP_BIND_ADDR",
		default_value = "0.0.0.0:80"
	)]
	pub http_bind_addr: SocketAddr,

	/// Bind address for HTTPS server
	#[structopt(
		long = "https-bind-addr",
		env = "TRICOT_HTTPS_BIND_ADDR",
		default_value = "0.0.0.0:443"
	)]
	pub https_bind_addr: SocketAddr,

	/// Bind address for metrics server (Prometheus format over HTTP)
	#[structopt(long = "metrics-bind-addr", env = "TRICOT_METRICS_BIND_ADDR")]
	pub metrics_bind_addr: Option<SocketAddr>,

	/// E-mail address for Let's Encrypt certificate requests
	#[structopt(long = "letsencrypt-email", env = "TRICOT_LETSENCRYPT_EMAIL")]
	pub letsencrypt_email: String,

	/// Enable compression of responses
	#[structopt(long = "enable-compression", env = "TRICOT_ENABLE_COMPRESSION")]
	pub enable_compression: bool,

	/// Mime types for which to enable compression (comma-separated list)
	#[structopt(
		long = "compress-mime-types",
		env = "TRICOT_COMPRESS_MIME_TYPES",
		default_value = "text/html,text/plain,text/css,text/javascript,text/xml,application/javascript,application/json,application/xml,image/svg+xml,font/ttf"
	)]
	pub compress_mime_types: String,

	/// Warm up the in-memory certificate store with data from Consul at startup
	/// (if this fails, Tricot logs the error and continues without caching)
	#[structopt(
		long = "warmup-cert-memory-store",
		env = "TRICOT_WARMUP_CERT_MEMORY_STORE"
	)]
	pub warmup_cert_memory_store: bool,
}

#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
async fn main() {
	#[cfg(feature = "dhat-heap")]
	let _profiler = dhat::Profiler::new_heap();

	if std::env::var("RUST_LOG").is_err() {
		std::env::set_var("RUST_LOG", "tricot=info")
	}
	tracing_subscriber::fmt()
		.with_writer(std::io::stderr)
		.with_env_filter(tracing_subscriber::filter::EnvFilter::from_default_env())
		.init();

	// Abort on panic (same behavior as in Go)
	std::panic::set_hook(Box::new(|panic_info| {
		error!("{}", panic_info.to_string());
		std::process::abort();
	}));

	let opt = Opt::from_args();

	info!("Starting Tricot");
	println!("Starting Tricot");

	let (exit_signal, provoke_exit) = watch_ctrl_c();
	let exit_on_err = move |err: anyhow::Error| {
		error!("Error: {}", err);
		let _ = provoke_exit.send(true);
	};

	let metrics_server = metrics::MetricsServer::init(opt.metrics_bind_addr);

	let consul_config = consul::Config {
		addr: opt.consul_addr.clone(),
		ca_cert: opt.consul_ca_cert.clone(),
		tls_skip_verify: opt.consul_tls_skip_verify,
		client_cert: opt.consul_client_cert.clone(),
		client_key: opt.consul_client_key.clone(),
	};

	let consul = consul::Consul::new(consul_config, &opt.consul_kv_prefix)
		.expect("Error creating Consul client");
	let rx_proxy_config = proxy_config::spawn_proxy_config_task(
		consul.clone(),
		opt.node_name.clone(),
		exit_signal.clone(),
	);

	let cert_store = cert_store::CertStore::new(
		consul.clone(),
		opt.node_name.clone(),
		rx_proxy_config.clone(),
		opt.letsencrypt_email.clone(),
		exit_on_err.clone(),
	);
	if opt.warmup_cert_memory_store {
		match cert_store.warmup_memory_store().await {
            Err(e) => error!("An error occured while warming up the certificate memory store with Consul data, continue without caching: {e}"),
            _ => (),
        };
	}

	let metrics_task = tokio::spawn(
		metrics_server
			.run(wait_from(exit_signal.clone()))
			.map_err(exit_on_err.clone())
			.then(|_| async { info!("Metrics server exited") }),
	);

	let http_task = tokio::spawn(
		http::serve_http(
			opt.http_bind_addr,
			consul.clone(),
			wait_from(exit_signal.clone()),
		)
		.map_err(exit_on_err.clone())
		.then(|_| async { info!("HTTP server exited") }),
	);

	let https_config = https::HttpsConfig {
		bind_addr: opt.https_bind_addr,
		enable_compression: opt.enable_compression,
		compress_mime_types: opt
			.compress_mime_types
			.split(',')
			.map(|x| x.to_string())
			.collect(),
		time_origin: Instant::now(),
	};

	let https_task = tokio::spawn(
		https::serve_https(
			https_config,
			cert_store.clone(),
			rx_proxy_config.clone(),
			exit_signal.clone(),
		)
		.map_err(exit_on_err.clone())
		.then(|_| async { info!("HTTPS server exited") }),
	);

	let dump_task = tokio::spawn(dump_config_on_change(rx_proxy_config, exit_signal.clone()));

	metrics_task.await.expect("Tokio task await failure");
	http_task.await.expect("Tokio task await failure");
	https_task.await.expect("Tokio task await failure");
	dump_task.await.expect("Tokio task await failure");
}

/// Prints a human-readable dump of the proxy configuration to stdout each
/// time it changes, until `must_exit` becomes `true`.
///
/// Entries are grouped (and sorted, via a `BTreeMap`) by
/// `(host, path_prefix)`; entries whose `flags.healthy` is false are prefixed
/// with a `/!\` marker. The configuration present when the task starts is only
/// used as the comparison baseline and is not printed.
///
/// The task also stops when the sender of either watch channel is dropped.
async fn dump_config_on_change(
	mut rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
	mut must_exit: watch::Receiver<bool>,
) {
	let mut old_cfg: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();

	while !*must_exit.borrow() {
		select!(
			c = rx_proxy_config.changed() => {
				if c.is_err() {
					break;
				}
			}
			e = must_exit.changed() => {
				// Once the exit sender is gone, `changed()` fails immediately on
				// every call, so plain `continue` would spin forever. Treat a
				// closed channel as an exit request, like `wait_from` does.
				if e.is_err() {
					break;
				}
				// Otherwise go back and re-check the exit flag.
				continue;
			}
		);

		let cfg: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
		if cfg == old_cfg {
			continue;
		}

		let mut cfg_map = BTreeMap::<_, Vec<_>>::new();
		for ent in cfg.entries.iter() {
			cfg_map
				.entry((&ent.url_prefix.host, &ent.url_prefix.path_prefix))
				.or_default()
				.push(ent);
		}

		println!(
			"---- PROXY CONFIGURATION at {} ----",
			chrono::offset::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true)
		);
		for ((host, prefix), ents) in cfg_map.iter() {
			println!("{}{}:", host, prefix.as_deref().unwrap_or_default());
			for ent in ents.iter() {
				// Unhealthy entries are flagged; healthy ones get equal-width padding.
				let marker = if ent.flags.healthy { "    " } else { "/!\\ " };
				println!("    {}{}", marker, ent);
			}
		}
		println!();

		old_cfg = cfg;
	}
}

/// Creates a watch that contains `false`, and that changes
/// to `true` when a Ctrl+C signal is received.
///
/// Returns the receiving side of the watch together with a shared handle on
/// its sender, so other code can also request shutdown by sending `true`.
///
/// The listener runs in a task started with `tokio::spawn`, so this must be
/// called from inside a Tokio runtime. That task panics if the Ctrl+C handler
/// cannot be installed.
pub fn watch_ctrl_c() -> (watch::Receiver<bool>, Arc<watch::Sender<bool>>) {
	let (send_cancel, watch_cancel) = watch::channel(false);
	let send_cancel = Arc::new(send_cancel);
	let send_cancel_2 = send_cancel.clone();
	tokio::spawn(async move {
		tokio::signal::ctrl_c()
			.await
			.expect("failed to install CTRL+C signal handler");
		info!("Received CTRL+C, shutting down.");
		// `send` only fails when every receiver has been dropped, i.e. there is
		// nobody left to notify. Ignore that case (as `exit_on_err` in `main`
		// does) instead of panicking, which under the abort-on-panic hook
		// would kill the process rather than let it shut down.
		let _ = send_cancel.send(true);
	});
	(watch_cancel, send_cancel_2)
}

/// Resolves once the watched flag reads `true`, or as soon as the sending
/// side of the channel has been dropped.
async fn wait_from(mut chan: watch::Receiver<bool>) {
	loop {
		// The borrow guard is released at the end of this condition, before
		// any `.await` below.
		if *chan.borrow() {
			break;
		}
		match chan.changed().await {
			Ok(()) => {}
			Err(_) => break,
		}
	}
}