aboutsummaryrefslogblamecommitdiff
path: root/src/cert_store.rs
blob: 0ced1788eec087b95050f7c16118c0dd51fbdfff (plain) (tree)
1
2
3
4
5
6
7
8
9

                                         
                                   


                   
                                       

                               
                                
               


                                          
                               

                                 
                                  
                           


                       
                          
                                  
                                                  
                                                              
                                                           
                                                    


                

                               
                                  

                                                                   
                                                                     
                        


                                                         
                               
                                  
                                                           
                                                                       
                                        
                                          


                                         






                                                                                              

                          

         



                                                                  

                                                                       


                                                                                

                                                                          



                                                                                
                                         


                                                                                                              
                                                                                                                 

















                                                                                                             
                                 
                         

                                                   
                                                             
                                                                                                             





                                                                                                  




                                 
                                                                                    





                                                
                                                                       




                                                                                      
                                                                            

                                                                            
                         
                                                

                 
                                                                         
                                                            


                                                                                        
                                                


                                                        

         








































                                                                                                             
                                                                               




                                                               
                                                      









                                                                             











                                                                                                                
                                 



                                                                                                             





                                                







                                                                                                                              
 
                                         
                                                         

                                                                    

                                                                 
                                                                                      

                                  
                                                                          
                                                        
                                           

                                                               


                                                                

                                                    







                                                                                





                                                                                                    

                                                   
                                                                          
                                                                                 
 


                                                                                          

                                                                                              

                                 
                                                                            
                                                                                                    






                                                                                                

                         
 
                                                           







                                                                              
                                                                                       


                                                                                    
                                                                                          

                                        
                                                                   
                                                                                        
 
                                                                 

                                                                 
                                                              

                  
                                                        
                                                  


                                                                                                         
 
                                                                    




                                                              
                                                      



                                                            
                                                                 
 

                                                                             


                                                                            

                                
                                                                            
                                                         
 

                                                                           
         
 







































                                                                                               





                                                                                          
                                                      












                                                                      
 





                                                                                                       






                                                                                         

         
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

use anyhow::Result;
use chrono::Utc;
use futures::{FutureExt, TryFutureExt};
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio::task::block_in_place;
use tracing::*;

use acme_micro::create_p384_key;
use acme_micro::{Directory, DirectoryUrl};
use rustls::sign::CertifiedKey;

use crate::cert::{Cert, CertSer};
use crate::consul::{self, Consul};
use crate::proxy_config::*;

/// TLS certificate store: certificates are cached in memory, persisted in
/// Consul, and obtained from Let's Encrypt by a background task.
pub struct CertStore {
	// ---- Configuration / external services ----
	consul: Consul,
	node_name: String,
	letsencrypt_email: String,
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,

	// ---- In-memory certificate caches ----
	/// Let's Encrypt certificates (loaded from Consul or freshly issued), keyed by domain.
	certs: RwLock<HashMap<String, Arc<Cert>>>,
	/// Self-signed fallback certificates, keyed by domain.
	self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,

	// ---- Background task ----
	/// Queue of domains for which the background task should check/obtain a certificate.
	tx_need_cert: mpsc::UnboundedSender<String>,
}

impl CertStore {
	pub fn new(
		consul: Consul,
		node_name: String,
		rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
		letsencrypt_email: String,
		exit_on_err: impl Fn(anyhow::Error) + Send + 'static,
	) -> Arc<Self> {
		let (tx, rx) = mpsc::unbounded_channel();

		let cert_store = Arc::new(Self {
			consul,
			node_name,
			certs: RwLock::new(HashMap::new()),
			self_signed_certs: RwLock::new(HashMap::new()),
			rx_proxy_config,
			letsencrypt_email,
			tx_need_cert: tx,
		});

		tokio::spawn(
			cert_store
				.clone()
				.certificate_loop(rx)
				.map_err(exit_on_err)
				.then(|_| async { info!("Certificate renewal task exited") }),
		);

		cert_store
	}

	async fn certificate_loop(
		self: Arc<Self>,
		mut rx_need_cert: mpsc::UnboundedReceiver<String>,
	) -> Result<()> {
		let mut rx_proxy_config = self.rx_proxy_config.clone();

		let mut t_last_check: HashMap<String, Instant> = HashMap::new();

		loop {
			let mut domains: HashSet<String> = HashSet::new();

			select! {
				res = rx_proxy_config.changed() => {
					if res.is_err() {
						bail!("rx_proxy_config closed");
					}

					let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
					for ent in proxy_config.entries.iter() {
						if let HostDescription::Hostname(domain) = &ent.url_prefix.host {
							if let Some((host, _port)) = domain.split_once(':') {
								domains.insert(host.to_string());
							} else {
								domains.insert(domain.clone());
							}
						}
					}
				}
				need_cert = rx_need_cert.recv() => {
					match need_cert {
						Some(dom) => {
							domains.insert(dom);
							while let Ok(dom2) = rx_need_cert.try_recv() {
								domains.insert(dom2);
							}
						}
						None => bail!("rx_need_cert closed"),
					};
				}
			}

			for dom in domains.iter() {
				match t_last_check.get(dom) {
					Some(t) if Instant::now() - *t < Duration::from_secs(60) => continue,
					_ => t_last_check.insert(dom.to_string(), Instant::now()),
				};

				debug!("Checking cert for domain: {}", dom);
				if let Err(e) = self.check_cert(dom).await {
					warn!("({}) Could not get certificate: {}", dom, e);
				}
			}
		}
	}

	fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		// Check if domain is authorized
		if !self
			.rx_proxy_config
			.borrow()
			.entries
			.iter()
			.any(|ent| ent.url_prefix.host.matches(domain))
		{
			bail!("Domain {} should not have a TLS certificate.", domain);
		}

		// Check in local memory if it exists
		if let Some(cert) = self.certs.read().unwrap().get(domain) {
			if cert.is_old() {
				self.tx_need_cert.send(domain.to_string())?;
			}
			return Ok(cert.clone());
		}

		// Not found in local memory, try to get it in background
		self.tx_need_cert.send(domain.to_string())?;

		// In the meantime, use a self-signed certificate
		if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
			return Ok(cert.clone());
		}

		self.gen_self_signed_certificate(domain)
	}

	pub async fn warmup_memory_store(self: &Arc<Self>) -> Result<()> {
		let consul_certs = self
			.consul
			.kv_get_prefix("certs/", None)
			.await?
			.into_inner();

		trace!(
			"Fetched {} certificate entries from Consul",
			consul_certs.len()
		);
		let mut loaded_certs: usize = 0;
		for (domain, cert) in consul_certs {
			let certser: CertSer = match serde_json::from_slice(&cert) {
				Ok(cs) => cs,
				Err(e) => {
					warn!("Could not deserialize CertSer for {domain}: {e}");
					continue;
				}
			};

			let cert = match Cert::new(certser) {
				Ok(c) => c,
				Err(e) => {
					warn!("Could not create Cert from CertSer for domain {domain}: {e}");
					continue;
				}
			};

			self.certs
				.write()
				.unwrap()
				.insert(domain.to_string(), Arc::new(cert));

			debug!("({domain}) Certificate loaded from Consul to the Memory Store");
			loaded_certs += 1;
		}
		info!("Memory store warmed up with {loaded_certs} certificates");
		Ok(())
	}

	pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
		// First, try locally.
		{
			let certs = self.certs.read().unwrap();
			if let Some(cert) = certs.get(domain) {
				if !cert.is_old() {
					return Ok(());
				}
			}
		}

		// Second, try from Consul.
		if let Some(consul_cert) = self
			.consul
			.kv_get_json::<CertSer>(&format!("certs/{}", domain))
			.await?
		{
			match Cert::new(consul_cert) {
				Ok(cert) => {
					let cert = Arc::new(cert);
					self.certs
						.write()
						.unwrap()
						.insert(domain.to_string(), cert.clone());
					debug!("({domain}) Certificate loaded from Consul to the Memory Store");

					if !cert.is_old() {
						return Ok(());
					}
				}
				Err(e) => {
					warn!("Could not create Cert from CertSer for domain {domain}: {e}");
				}
			};
		}

		// Third, ask from Let's Encrypt
		self.renew_cert(domain).await
	}

	pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
		info!("({}) Renewing certificate", domain);

		// Basic sanity check (we could add more kinds of checks here)
		// This is just to help avoid getting rate-limited against ACME server
		if !domain.contains('.') || domain.ends_with(".local") {
			bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)");
		}

		// ---- Acquire lock ----
		// the lock is acquired for half an hour,
		// so that in case of an error we won't retry before
		// that delay expires

		let lock_path = format!("renew_lock/{}", domain);
		let lock_name = format!("tricot/renew:{}@{}", domain, self.node_name);
		let session = self
			.consul
			.create_session(&consul::locking::SessionRequest {
				name: lock_name.clone(),
				node: None,
				lock_delay: Some("30m".into()),
				ttl: Some("45m".into()),
				behavior: Some("delete".into()),
			})
			.await?;
		debug!("Lock session: {}", session);

		if !self
			.consul
			.acquire(&lock_path, lock_name.clone().into(), &session)
			.await?
		{
			bail!("Lock is already taken, not renewing for now.");
		}

		// ---- Accessibility check ----
		// We don't want to ask Let's encrypt for a domain that
		// is not configured to point here. This can happen with wildcards: someone can send
		// a fake SNI to a domain that is not ours. We have to detect it here.
		self.check_domain_accessibility(domain, &session).await?;

		// ---- Do let's encrypt stuff ----

		let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
		let contact = vec![format!("mailto:{}", self.letsencrypt_email)];

		// Use existing Let's encrypt account or register new one if necessary
		let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? {
			Some(acc_privkey) => {
				info!("Using existing Let's encrypt account");
				dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
			}
			None => {
				info!("Creating new Let's encrypt account");
				let acc = block_in_place(|| dir.register_account(contact.clone()))?;
				self.consul
					.kv_put(
						"letsencrypt_account_key.pem",
						acc.acme_private_key_pem()?.into_bytes().into(),
					)
					.await?;
				acc
			}
		};

		// Order certificate and perform validation
		let mut ord_new = acc.new_order(domain, &[])?;
		let ord_csr = loop {
			if let Some(ord_csr) = ord_new.confirm_validations() {
				break ord_csr;
			}

			let auths = ord_new.authorizations()?;

			info!("({}) Creating challenge and storing in Consul", domain);
			let chall = auths[0].http_challenge().unwrap();
			let chall_key = format!("challenge/{}", chall.http_token());
			self.consul
				.acquire(&chall_key, chall.http_proof()?.into(), &session)
				.await?;

			info!("({}) Validating challenge", domain);
			block_in_place(|| chall.validate(Duration::from_millis(5000)))?;

			info!("({}) Deleting challenge", domain);
			self.consul.kv_delete(&chall_key).await?;

			block_in_place(|| ord_new.refresh())?;
		};

		// Generate key and finalize certificate
		let pkey_pri = create_p384_key()?;
		let ord_cert =
			block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
		let cert = block_in_place(|| ord_cert.download_cert())?;

		info!("({}) Keys and certificate obtained", domain);
		let key_pem = cert.private_key().to_string();
		let cert_pem = cert.certificate().to_string();

		let certser = CertSer {
			hostname: domain.to_string(),
			date: Utc::now().date_naive(),
			valid_days: cert.valid_days_left()?,
			key_pem,
			cert_pem,
		};
		let cert = Arc::new(Cert::new(certser.clone())?);

		// Store certificate in Consul and local store
		self.certs.write().unwrap().insert(domain.to_string(), cert);
		self.consul
			.kv_put_json(&format!("certs/{}", domain), &certser)
			.await?;

		// Release locks
		self.consul.release(&lock_path, "".into(), &session).await?;
		self.consul.kv_delete(&lock_path).await?;

		info!("({}) Cert successfully renewed and stored", domain);
		Ok(())
	}

	async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> {
		// Returns Ok(()) only if domain is a correct domain name that
		// redirects to this server
		let self_challenge_id = uuid::Uuid::new_v4().to_string();
		let self_challenge_key = format!("challenge/{}", self_challenge_id);
		let self_challenge_resp = uuid::Uuid::new_v4().to_string();

		self.consul
			.acquire(
				&self_challenge_key,
				self_challenge_resp.as_bytes().to_vec().into(),
				session,
			)
			.await?;

		let httpcli = reqwest::Client::new();
		let chall_url = format!(
			"http://{}/.well-known/acme-challenge/{}",
			domain, self_challenge_id
		);

		for i in 1..=4 {
			tokio::time::sleep(Duration::from_secs(2)).await;
			info!("({}) Accessibility check {}/4", domain, i);

			let httpresp = httpcli.get(&chall_url).send().await?;
			if httpresp.status() == reqwest::StatusCode::OK
				&& httpresp.bytes().await? == self_challenge_resp.as_bytes()
			{
				// Challenge successfully validated
				info!("({}) Accessibility check successfull", domain);
				return Ok(());
			}

			tokio::time::sleep(Duration::from_secs(2)).await;
		}

		bail!("Unable to validate self-challenge for domain accessibility check");
	}

	fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {
		let subject_alt_names = vec![domain.to_string(), "localhost".to_string()];
		let cert = rcgen::generate_simple_self_signed(subject_alt_names)?;

		let certser = CertSer {
			hostname: domain.to_string(),
			date: Utc::now().date_naive(),
			valid_days: 1024,
			key_pem: cert.serialize_private_key_pem(),
			cert_pem: cert.serialize_pem()?,
		};
		let cert = Arc::new(Cert::new(certser)?);
		self.self_signed_certs
			.write()
			.unwrap()
			.insert(domain.to_string(), cert.clone());
		info!("Added self-signed certificate for {}", domain);

		Ok(cert)
	}
}

/// rustls server-certificate resolver that delegates to a shared [`CertStore`].
pub struct StoreResolver(pub Arc<CertStore>);

impl rustls::server::ResolvesServerCert for StoreResolver {
	/// Selects the certificate for the SNI name in `client_hello`.
	///
	/// Returns `None` if the client sent no SNI name, or if the store could not
	/// provide a certificate (the reason is logged as a warning).
	fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
		let sni_name = client_hello.server_name()?;
		let store = &self.0;

		store
			.get_cert_for_https(sni_name)
			.map(|found| found.certkey.clone())
			.map_err(|err| warn!("Could not get certificate for {}: {}", sni_name, err))
			.ok()
	}
}