aboutsummaryrefslogblamecommitdiff
path: root/src/cert_store.rs
blob: 3d137f93fed93c8250af81927e4a2767420cae06 (plain) (tree)
1
2
3
4
5
6
7
8
9

                                         
                                   


                   
                                       

                               
                                
               


                                          
                               

                                 
                                  
                           


                       
                          
                                  
 
                                                  
                                                              
 
                                                           
                                                    

 




                                                                
                

                               
                                  

                                                                   
                                                                     
                        


                                                         
                               
                                  
                                          
                                                           
                                                                       
                                        


                                         






                                                                                              

                          

         



                                                                  

                                                                       
                                                                                
                                                                      
 
                      
                                               
                                                                                                  


                                                                                
                                         
 


                                                                                                                 
                                                                                                              
 
                                                                                








                                                                                                                     
                                                          



                                                                                                                                         
                                         
 


                                                                                                                                            
                                 
                                                                     


                                                                    
                                                                                                             
 
                                                                                                   
                                                                               
                                                                                                      
                                                                                        
                                                         
 
                                                                                                                       

                                                                                     
                                         
                                 
                          
 

                                                                         
                                                   

                                                                                               
                                                             
                                                                                                             


                                                                                                  
                                                                                                


                                                                                            




                                 


                                            
                                                              


                                                                  








                                                                                                             
                                         
                                                                 

                                                                                      
                                                                                                                                                       



                                                                    
                                                                             





                                                                                          
                                                                                                 










                                                                                                                  
                                                                        




                                                                         

                                                                                         








                                                                                                                             
                                                                        






                                                                                                                                                                                       





                                                                              
                                                                                    
                                                     
                                                                            

                                                                            
                         
                                                

                 
                                                                         
                                                            


                                                                                        
                                                


                                                        

         











                                                                          
                                                 


                                                                                    
                                                                                              


                                                 
                                                              




















                                                                                                             








                                                                                    
                                                                               




                                                               
                                                      









                                                                             











                                                                                                                
                                 



                                                                                                             





                                                
                                                                         


                                                                               
                                         
                                                         

                                                                    

                                                                 
                                                                                      

                                  
                                                                          
                                                        
                                           

                                                               


                                                                

                                                    







                                                                                





                                                                                                    

                                                   
                                                                          
                                                                                 
 


                                                                                          

                                                                                              

                                 
                                                                            
                                                                                                    






                                                                                                

                         
 
                                                           







                                                                              
                                                                                       


                                                                                    
                                                                                          

                                        
                                                                   
                                                                                        
 
                                                                 

                                                                 
                                                              

                  
                                                        
                                                  


                                                                                                         
 
                                                                    




                                                              
                                                      



                                                            
                                                                 
 

                                                                             


                                                                            

                                
                                                                            
                                                         
 

                                                                           
         
 












                                                                                        







































                                                                                               





                                                                                          
                                                      












                                                                      
 





                                                                                                       






                                                                                         

         
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

use anyhow::Result;
use chrono::Utc;
use futures::{FutureExt, TryFutureExt};
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio::task::block_in_place;
use tracing::*;

use acme_micro::create_p384_key;
use acme_micro::{Directory, DirectoryUrl};
use rustls::sign::CertifiedKey;

use crate::cert::{Cert, CertSer};
use crate::consul::{self, Consul};
use crate::proxy_config::*;

/// Certificate store for the TLS front-end.
///
/// Certificates are kept in memory, persisted in Consul under `certs/<domain>`,
/// and obtained from Let's Encrypt by a background task when missing or old.
/// Self-signed certificates are served in the meantime.
pub struct CertStore {
	/// Consul client: certificate storage, renewal locks and ACME challenges.
	consul: Consul,
	/// Name of this node; part of the renewal lock name.
	node_name: String,
	/// Contact address of the Let's Encrypt account (used as a `mailto:` URL).
	letsencrypt_email: String,

	/// Certificates loaded from Consul or obtained from Let's Encrypt, by domain.
	certs: RwLock<HashMap<String, Arc<Cert>>>,
	/// Fallback self-signed certificates, by domain.
	self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,

	/// Current proxy configuration; decides which domains may get a certificate.
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
	/// Queue of domains whose certificate should be checked by the background task.
	tx_need_cert: mpsc::UnboundedSender<String>,
}

/// Domains for which certificates may be requested, derived from the proxy
/// configuration.
struct ProcessedDomains {
	/// Plain hostnames (without port); their certificates are requested eagerly.
	static_domains: HashSet<String>,
	/// Glob patterns whose matching names get a certificate on demand, each
	/// paired with an optional URL that must approve the name first.
	on_demand_domains: Vec<(glob::Pattern, Option<String>)>,
}

impl CertStore {
	/// Upper bound for outgoing HTTP requests (on-demand TLS ask and the
	/// domain accessibility self-check).
	///
	/// `reqwest::Client::new()` applies no timeout, and the certificate loop
	/// handles domains one at a time, so an endpoint that never answers would
	/// otherwise stall every certificate check.
	const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

	/// Minimum delay between two checks of the same domain.
	const RECHECK_DELAY: Duration = Duration::from_secs(60);

	/// Number of attempts made by the domain accessibility self-check.
	const ACCESSIBILITY_ATTEMPTS: u32 = 4;

	/// Creates the certificate store and spawns its background certificate task.
	///
	/// If the background task returns an error, it is handed to `exit_on_err`.
	/// Must be called from within a Tokio runtime (uses `tokio::spawn`).
	pub fn new(
		consul: Consul,
		node_name: String,
		rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
		letsencrypt_email: String,
		exit_on_err: impl Fn(anyhow::Error) + Send + 'static,
	) -> Arc<Self> {
		let (tx, rx) = mpsc::unbounded_channel();

		let cert_store = Arc::new(Self {
			consul,
			node_name,
			letsencrypt_email,
			certs: RwLock::new(HashMap::new()),
			self_signed_certs: RwLock::new(HashMap::new()),
			rx_proxy_config,
			tx_need_cert: tx,
		});

		tokio::spawn(
			cert_store
				.clone()
				.certificate_loop(rx)
				.map_err(exit_on_err)
				.then(|_| async { info!("Certificate renewal task exited") }),
		);

		cert_store
	}

	/// Background task deciding which domains get their certificate checked.
	///
	/// It reacts to two events:
	/// - a proxy configuration change: allowed domains are recomputed and all
	///   static domains are scheduled for a check;
	/// - domains queued by `get_cert_for_https`: they are validated against the
	///   last processed configuration and scheduled for a check.
	///
	/// A given domain is checked at most once per `RECHECK_DELAY`.
	///
	/// # Errors
	/// Returns an error when either input channel is closed.
	async fn certificate_loop(
		self: Arc<Self>,
		mut rx_need_cert: mpsc::UnboundedReceiver<String>,
	) -> Result<()> {
		let mut rx_proxy_config = self.rx_proxy_config.clone();

		let mut t_last_check: HashMap<String, Instant> = HashMap::new();
		let mut proc_domains: Option<ProcessedDomains> = None;

		loop {
			let domains = select! {
				// Refresh some internal states, schedule static_domains for renew
				res = rx_proxy_config.changed() => {
					if res.is_err() {
						bail!("rx_proxy_config closed");
					}

					// The borrow guard is a temporary dropped at the end of this
					// statement, so it is never held across an await.
					let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
					let processed = Self::process_proxy_config(&proxy_config);

					// only static_domains are refreshed
					let static_domains = processed.static_domains.clone();
					proc_domains = Some(processed);
					self.domain_validation(static_domains, proc_domains.as_ref()).await
				}
				// renew static and on-demand domains
				need_cert = rx_need_cert.recv() => {
					match need_cert {
						Some(dom) => {
							// Drain everything already queued so that a burst of
							// requests is validated in a single pass.
							let mut candidates: HashSet<String> = HashSet::new();
							candidates.insert(dom);
							while let Ok(dom2) = rx_need_cert.try_recv() {
								candidates.insert(dom2);
							}

							self.domain_validation(candidates, proc_domains.as_ref()).await
						}
						None => bail!("rx_need_cert closed"),
					}
				}
			};

			// Entries older than RECHECK_DELAY no longer rate-limit anything;
			// dropping them keeps this map from accumulating every domain ever
			// requested (SNI names are chosen by clients).
			t_last_check.retain(|_, t| t.elapsed() < Self::RECHECK_DELAY);

			// Now that we have our list of domains to check,
			// actually do something
			for dom in domains.iter() {
				// Skip domains that were checked less than RECHECK_DELAY ago
				if let Some(t) = t_last_check.get(dom) {
					if t.elapsed() < Self::RECHECK_DELAY {
						continue;
					}
				}
				t_last_check.insert(dom.clone(), Instant::now());

				// Actual Let's Encrypt calls are done here (in sister function)
				debug!("Checking cert for domain: {}", dom);
				if let Err(e) = self.check_cert(dom).await {
					warn!("({}) Could not get certificate: {}", dom, e);
				}
			}
		}
	}

	/// Splits the proxy configuration into static and on-demand domains.
	///
	/// Plain hostnames (any `:port` suffix stripped) become static domains;
	/// glob patterns become on-demand domains paired with the entry's optional
	/// `on_demand_tls_ask` URL.
	fn process_proxy_config(proxy_config: &ProxyConfig) -> ProcessedDomains {
		let mut static_domains: HashSet<String> = HashSet::new();
		let mut on_demand_domains: Vec<(glob::Pattern, Option<String>)> = vec![];

		for ent in proxy_config.entries.iter() {
			match &ent.url_prefix.host {
				// Eagerly generate certificates for domains that are not patterns
				HostDescription::Hostname(domain) => {
					let host = domain
						.split_once(':')
						.map_or(domain.as_str(), |(host, _port)| host);
					static_domains.insert(host.to_string());
				}
				HostDescription::Pattern(pattern) => {
					on_demand_domains.push((pattern.clone(), ent.on_demand_tls_ask.clone()));
				}
			}
		}

		ProcessedDomains {
			static_domains,
			on_demand_domains,
		}
	}

	/// Keeps only the `candidates` that may receive a Let's Encrypt certificate.
	///
	/// A candidate is accepted when it is a static domain, or when it matches an
	/// on-demand glob pattern and, if that pattern has a check URL, the URL
	/// approves it. Names without a dot or ending in `.local` are rejected.
	/// Everything is rejected while no configuration has been processed
	/// (`maybe_proc_domains` is `None`).
	async fn domain_validation(
		&self,
		candidates: HashSet<String>,
		maybe_proc_domains: Option<&ProcessedDomains>,
	) -> HashSet<String> {
		let mut domains: HashSet<String> = HashSet::new();

		// Handle initialization
		let proc_domains = match maybe_proc_domains {
			None => {
				warn!("Proxy config is not yet loaded, refusing all certificate generation");
				return domains;
			}
			Some(proc) => proc,
		};

		// Filter certificates...
		'outer: for candidate in candidates.into_iter() {
			// Disallow obvious wrong domains...
			if !candidate.contains('.') || candidate.ends_with(".local") {
				warn!("{} is probably not a publicly accessible domain, skipping (a self-signed certificate will be used)", candidate);
				continue;
			}

			// Try to register domain as a static domain
			if proc_domains.static_domains.contains(&candidate) {
				trace!("domain {} validated as static domain", candidate);
				domains.insert(candidate);
				continue;
			}

			// It's not a static domain, maybe an on-demand domain?
			for (pattern, maybe_check_url) in proc_domains.on_demand_domains.iter() {
				// check glob pattern
				if pattern.matches(&candidate) {
					// if no check url is set, accept domain as long as it matches the pattern
					let check_url = match maybe_check_url {
						None => {
							trace!(
								"domain {} validated on glob pattern {} only",
								candidate,
								pattern
							);
							domains.insert(candidate);
							continue 'outer;
						}
						Some(url) => url,
					};

					// if a check url is set, call it
					// -- avoid DDoSing a backend
					tokio::time::sleep(Duration::from_secs(2)).await;
					match self.on_demand_tls_ask(check_url, &candidate).await {
						Ok(()) => {
							trace!(
								"domain {} validated on glob pattern {} and on check url {}",
								candidate,
								pattern,
								check_url
							);
							domains.insert(candidate);
							continue 'outer;
						}
						Err(e) => {
							warn!("domain {} validation refused  on glob pattern {} and on check url {} with error: {}", candidate, pattern, check_url, e);
						}
					}
				}
			}
		}

		domains
	}

	/// Returns the certificate to present for `domain` during a TLS handshake.
	///
	/// A cached certificate is returned immediately; if it is old, a check is
	/// also queued for the background task. When there is no cached
	/// certificate, a check is queued and a self-signed certificate (cached or
	/// freshly generated) is returned instead.
	///
	/// NOTE(review): self-signed certificates are cached for every requested
	/// name and never evicted by this module.
	///
	/// # Errors
	/// Fails if the background task's queue is closed or self-signed
	/// certificate generation fails.
	fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		// Check in local memory if it exists
		if let Some(cert) = self.certs.read().unwrap().get(domain) {
			if cert.is_old() {
				self.tx_need_cert.send(domain.to_string())?;
			}
			return Ok(cert.clone());
		}

		// Not found in local memory, try to get it in background
		self.tx_need_cert.send(domain.to_string())?;

		// In the meantime, use a self-signed certificate
		if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
			return Ok(cert.clone());
		}

		self.gen_self_signed_certificate(domain)
	}

	/// Loads every certificate stored under `certs/` in Consul into memory.
	///
	/// Entries that cannot be deserialized or turned into a `Cert` are skipped
	/// with a warning.
	///
	/// # Errors
	/// Fails only if the Consul listing itself fails.
	pub async fn warmup_memory_store(self: &Arc<Self>) -> Result<()> {
		let consul_certs = self
			.consul
			.kv_get_prefix("certs/", None)
			.await?
			.into_inner();

		trace!(
			"Fetched {} certificate entries from Consul",
			consul_certs.len()
		);
		let mut loaded_certs: usize = 0;
		for (key, cert) in consul_certs {
			let certser: CertSer = match serde_json::from_slice(&cert) {
				Ok(cs) => cs,
				Err(e) => {
					warn!("Could not deserialize CertSer for {key}: {e}");
					continue;
				}
			};
			let domain = certser.hostname.clone();

			let cert = match Cert::new(certser) {
				Ok(c) => c,
				Err(e) => {
					warn!("Could not create Cert from CertSer for domain {domain}: {e}");
					continue;
				}
			};

			self.certs
				.write()
				.unwrap()
				.insert(domain.clone(), Arc::new(cert));

			debug!("({domain}) Certificate loaded from Consul to the Memory Store");
			loaded_certs += 1;
		}
		info!("Memory store warmed up with {loaded_certs} certificates");
		Ok(())
	}

	/// Ensures that the memory store holds a certificate for `domain` that does
	/// not need to be renewed.
	///
	/// Looks in memory first, then in Consul (a valid Consul copy replaces the
	/// in-memory one), and only then asks Let's Encrypt through `renew_cert`.
	pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
		// First, try locally.
		{
			let certs = self.certs.read().unwrap();
			if let Some(cert) = certs.get(domain) {
				if !cert.is_old() {
					return Ok(());
				}
			}
		}

		// Second, try from Consul.
		if let Some(consul_cert) = self
			.consul
			.kv_get_json::<CertSer>(&format!("certs/{}", domain))
			.await?
		{
			match Cert::new(consul_cert) {
				Ok(cert) => {
					let cert = Arc::new(cert);
					self.certs
						.write()
						.unwrap()
						.insert(domain.to_string(), cert.clone());
					debug!("({domain}) Certificate loaded from Consul to the Memory Store");

					if !cert.is_old() {
						return Ok(());
					}
				}
				Err(e) => {
					warn!("Could not create Cert from CertSer for domain {domain}: {e}");
				}
			};
		}

		// Third, ask from Let's Encrypt
		self.renew_cert(domain).await
	}

	/// Obtains a new certificate for `domain` from Let's Encrypt and stores it
	/// in memory and in Consul under `certs/<domain>`.
	///
	/// Steps:
	/// 1. take the Consul lock `renew_lock/<domain>` so a single node renews;
	/// 2. check that `domain` really reaches this server;
	/// 3. load the Let's Encrypt account key from Consul, or register a new
	///    account and save its key;
	/// 4. answer the HTTP challenges by publishing proofs under `challenge/<token>`;
	/// 5. finalize the order and store the certificate.
	///
	/// # Errors
	/// Fails if the lock is already held, the accessibility check fails, or any
	/// Consul/ACME step fails. On failure the lock is intentionally left in
	/// place so that no retry happens before its delay expires.
	pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
		info!("({}) Renewing certificate", domain);

		// ---- Acquire lock ----
		// the lock is acquired for half an hour,
		// so that in case of an error we won't retry before
		// that delay expires

		let lock_path = format!("renew_lock/{}", domain);
		let lock_name = format!("tricot/renew:{}@{}", domain, self.node_name);
		let session = self
			.consul
			.create_session(&consul::locking::SessionRequest {
				name: lock_name.clone(),
				node: None,
				lock_delay: Some("30m".into()),
				ttl: Some("45m".into()),
				behavior: Some("delete".into()),
			})
			.await?;
		debug!("Lock session: {}", session);

		if !self
			.consul
			.acquire(&lock_path, lock_name.clone().into(), &session)
			.await?
		{
			bail!("Lock is already taken, not renewing for now.");
		}

		// ---- Accessibility check ----
		// We don't want to ask Let's encrypt for a domain that
		// is not configured to point here. This can happen with wildcards: someone can send
		// a fake SNI to a domain that is not ours. We have to detect it here.
		self.check_domain_accessibility(domain, &session).await?;

		// ---- Do let's encrypt stuff ----
		// acme-micro's API is synchronous: every call that may talk to the ACME
		// server goes through block_in_place so the Tokio worker is handed off
		// instead of being blocked.

		let dir = block_in_place(|| Directory::from_url(DirectoryUrl::LetsEncrypt))?;
		let contact = vec![format!("mailto:{}", self.letsencrypt_email)];

		// Use existing Let's encrypt account or register new one if necessary
		let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? {
			Some(acc_privkey) => {
				info!("Using existing Let's encrypt account");
				let acc_privkey_pem = std::str::from_utf8(&acc_privkey)?;
				block_in_place(|| dir.load_account(acc_privkey_pem, contact))?
			}
			None => {
				info!("Creating new Let's encrypt account");
				let acc = block_in_place(|| dir.register_account(contact.clone()))?;
				self.consul
					.kv_put(
						"letsencrypt_account_key.pem",
						acc.acme_private_key_pem()?.into_bytes().into(),
					)
					.await?;
				acc
			}
		};

		// Order certificate and perform validation
		let mut ord_new = block_in_place(|| acc.new_order(domain, &[]))?;
		let ord_csr = loop {
			if let Some(ord_csr) = ord_new.confirm_validations() {
				break ord_csr;
			}

			let auths = block_in_place(|| ord_new.authorizations())?;

			info!("({}) Creating challenge and storing in Consul", domain);
			// Fail with an error rather than panicking: a panic here would
			// silently kill the whole background certificate task.
			let chall = match auths.first().and_then(|auth| auth.http_challenge()) {
				Some(chall) => chall,
				None => bail!("({}) No HTTP challenge offered by Let's Encrypt", domain),
			};
			let chall_key = format!("challenge/{}", chall.http_token());
			self.consul
				.acquire(&chall_key, chall.http_proof()?.into(), &session)
				.await?;

			info!("({}) Validating challenge", domain);
			block_in_place(|| chall.validate(Duration::from_millis(5000)))?;

			info!("({}) Deleting challenge", domain);
			self.consul.kv_delete(&chall_key).await?;

			block_in_place(|| ord_new.refresh())?;
		};

		// Generate key and finalize certificate
		let pkey_pri = create_p384_key()?;
		let ord_cert =
			block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
		let cert = block_in_place(|| ord_cert.download_cert())?;

		info!("({}) Keys and certificate obtained", domain);
		let key_pem = cert.private_key().to_string();
		let cert_pem = cert.certificate().to_string();

		let certser = CertSer {
			hostname: domain.to_string(),
			date: Utc::now().date_naive(),
			valid_days: cert.valid_days_left()?,
			key_pem,
			cert_pem,
		};
		let cert = Arc::new(Cert::new(certser.clone())?);

		// Store certificate in Consul and local store
		self.certs.write().unwrap().insert(domain.to_string(), cert);
		self.consul
			.kv_put_json(&format!("certs/{}", domain), &certser)
			.await?;

		// Release locks
		self.consul.release(&lock_path, "".into(), &session).await?;
		self.consul.kv_delete(&lock_path).await?;

		info!("({}) Cert successfully renewed and stored", domain);
		Ok(())
	}

	/// Asks the on-demand TLS `check_url` whether `domain` may get a certificate.
	///
	/// Sends `GET <check_url>?domain=<domain>`; only a `200 OK` answer
	/// authorizes the domain. The request is bounded by `HTTP_TIMEOUT`.
	async fn on_demand_tls_ask(&self, check_url: &str, domain: &str) -> Result<()> {
		let httpcli = reqwest::Client::builder()
			.timeout(Self::HTTP_TIMEOUT)
			.build()?;
		let chall_url = format!("{}?domain={}", check_url, domain);
		info!("({}) On-demand TLS check", domain);

		let httpresp = httpcli.get(&chall_url).send().await?;
		if httpresp.status() != reqwest::StatusCode::OK {
			bail!("{} is not authorized for on-demand TLS", domain);
		}

		Ok(())
	}

	/// Returns `Ok(())` only if `domain` resolves to a server that serves our
	/// own challenge store.
	///
	/// A random challenge is published in Consul (held by `session`). Then
	/// `http://<domain>/.well-known/acme-challenge/<id>` is fetched up to
	/// `ACCESSIBILITY_ATTEMPTS` times, and the response body must match. Any
	/// request error (including hitting `HTTP_TIMEOUT`) aborts immediately.
	async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> {
		let self_challenge_id = uuid::Uuid::new_v4().to_string();
		let self_challenge_key = format!("challenge/{}", self_challenge_id);
		let self_challenge_resp = uuid::Uuid::new_v4().to_string();

		self.consul
			.acquire(
				&self_challenge_key,
				self_challenge_resp.as_bytes().to_vec().into(),
				session,
			)
			.await?;

		let httpcli = reqwest::Client::builder()
			.timeout(Self::HTTP_TIMEOUT)
			.build()?;
		let chall_url = format!(
			"http://{}/.well-known/acme-challenge/{}",
			domain, self_challenge_id
		);

		for i in 1..=Self::ACCESSIBILITY_ATTEMPTS {
			tokio::time::sleep(Duration::from_secs(2)).await;
			info!(
				"({}) Accessibility check {}/{}",
				domain,
				i,
				Self::ACCESSIBILITY_ATTEMPTS
			);

			let httpresp = httpcli.get(&chall_url).send().await?;
			if httpresp.status() == reqwest::StatusCode::OK
				&& httpresp.bytes().await? == self_challenge_resp.as_bytes()
			{
				// Challenge successfully validated
				info!("({}) Accessibility check successfull", domain);
				return Ok(());
			}

			// Extra pause between attempts; pointless after the last one.
			if i < Self::ACCESSIBILITY_ATTEMPTS {
				tokio::time::sleep(Duration::from_secs(2)).await;
			}
		}

		bail!("Unable to validate self-challenge for domain accessibility check");
	}

	/// Generates a self-signed certificate for `domain` (plus `localhost`),
	/// caches it in `self_signed_certs` and returns it.
	fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {
		let subject_alt_names = vec![domain.to_string(), "localhost".to_string()];
		let cert = rcgen::generate_simple_self_signed(subject_alt_names)?;

		let certser = CertSer {
			hostname: domain.to_string(),
			date: Utc::now().date_naive(),
			valid_days: 1024,
			key_pem: cert.serialize_private_key_pem(),
			cert_pem: cert.serialize_pem()?,
		};
		let cert = Arc::new(Cert::new(certser)?);
		self.self_signed_certs
			.write()
			.unwrap()
			.insert(domain.to_string(), cert.clone());
		info!("Added self-signed certificate for {}", domain);

		Ok(cert)
	}
}

/// rustls certificate resolver backed by a [`CertStore`].
///
/// Looking up a certificate may also ask the store's background task to
/// obtain or renew the certificate for the requested name.
pub struct StoreResolver(pub Arc<CertStore>);

impl rustls::server::ResolvesServerCert for StoreResolver {
	/// Selects the certificate matching the SNI name of `client_hello`.
	///
	/// Yields `None` (no certificate) when the client sent no SNI name or when
	/// the store failed to provide one; the failure is logged.
	fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
		let sni = client_hello.server_name()?;
		let lookup = self.0.get_cert_for_https(sni);

		if let Err(e) = &lookup {
			warn!("Could not get certificate for {}: {}", sni, e);
		}

		lookup.ok().map(|cert| cert.certkey.clone())
	}
}