aboutsummaryrefslogblamecommitdiff
path: root/src/cert_store.rs
blob: 283479570fda6e2ba4fe2883564feac810c869f4 (plain) (tree)
1
2
3
4
5
6
7
8
9







                                         
                                


                                          
                               

                                 
                     
                           


                       
                                  
                                                  
                                                              
                                                           


                
                                                                                                                                


                                                           
                                                                       
                                        
                                          


                  


                                                                       




                                                                                              


                                                                                      
                         

                                                   
                                                                                           






                                                                               






                                                                                        
                                                            




                                                                                      
                                                                            




                                                        
                                                                         
                                                                             








                                                                                        





                                                                                          



































                                                                                          
                                         


                                                                    






                                                                                                      


                                                               


                                                                

                                                    









                                                                                
                                                                          
                                                                                 






                                                                                                             
                                                                                                    




















                                                                                                
                                                                                          


                                                      
                                                                                        



                                                                 
                                                              


                                                  


                                                                                                         















                                                                            
                                                                            









                                                                  




















                                                                                          
 





                                                                                                       






                                                                                         

         
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::Duration;

use anyhow::Result;
use chrono::Utc;
use log::*;
use tokio::sync::watch;
use tokio::task::block_in_place;

use acme_micro::create_p384_key;
use acme_micro::{Directory, DirectoryUrl};
use rustls::sign::CertifiedKey;

use crate::cert::{Cert, CertSer};
use crate::consul::*;
use crate::proxy_config::*;

/// Holds the TLS certificates served by the proxy.
///
/// Certificates are cached in memory, persisted in Consul under `certs/<domain>`,
/// and requested from Let's Encrypt when missing or outdated. Self-signed
/// certificates are kept in a separate cache and handed out while a real
/// certificate is being obtained.
pub struct CertStore {
	/// Current proxy configuration; only hostnames it matches may get a certificate.
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
	/// Consul client used for certificate storage, renewal locks and ACME challenges.
	consul: Consul,
	/// Contact address for the Let's Encrypt account (sent as `mailto:`).
	letsencrypt_email: String,
	/// Self-signed fallback certificates, keyed by hostname.
	self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,
	/// Cache of real certificates, keyed by hostname.
	certs: RwLock<HashMap<String, Arc<Cert>>>,
}

impl CertStore {
	pub fn new(consul: Consul, rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>, letsencrypt_email: String) -> Arc<Self> {
		Arc::new(Self {
			consul,
			certs: RwLock::new(HashMap::new()),
			self_signed_certs: RwLock::new(HashMap::new()),
			rx_proxy_config,
			letsencrypt_email,
		})
	}

	pub async fn watch_proxy_config(self: Arc<Self>) {
		let mut rx_proxy_config = self.rx_proxy_config.clone();

		while rx_proxy_config.changed().await.is_ok() {
			let mut domains: HashSet<String> = HashSet::new();

			let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
			for ent in proxy_config.entries.iter() {
				if let HostDescription::Hostname(domain) = &ent.host {
					domains.insert(domain.clone());
				}
			}

			for dom in domains.iter() {
				info!("Ensuring we have certs for domains: {:?}", domains);
				if let Err(e) = self.get_cert(dom).await {
					warn!("Error get_cert {}: {}", dom, e);
				}
			}
		}
	}

	pub fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		// Check if domain is authorized
		if !self
			.rx_proxy_config
			.borrow()
			.entries
			.iter()
			.any(|ent| ent.host.matches(domain))
		{
			bail!("Domain {} should not have a TLS certificate.", domain);
		}

		// Check in local memory if it exists
		if let Some(cert) = self.certs.read().unwrap().get(domain) {
			if !cert.is_old() {
				return Ok(cert.clone());
			}
		}

		// Not found in local memory, try to get it in background
		tokio::spawn(self.clone().get_cert_task(domain.to_string()));

		// In the meantime, use a self-signed certificate
		if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
			if !cert.is_old() {
				return Ok(cert.clone());
			}
		}

		self.gen_self_signed_certificate(domain)
	}

	pub async fn get_cert_task(self: Arc<Self>, domain: String) -> Result<Arc<Cert>> {
		self.get_cert(domain.as_str()).await
	}

	pub async fn get_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		// First, try locally.
		{
			let certs = self.certs.read().unwrap();
			if let Some(cert) = certs.get(domain) {
				if !cert.is_old() {
					return Ok(cert.clone());
				}
			}
		}

		// Second, try from Consul.
		if let Some(consul_cert) = self
			.consul
			.kv_get_json::<CertSer>(&format!("certs/{}", domain))
			.await?
		{
			if let Ok(cert) = Cert::new(consul_cert) {
				let cert = Arc::new(cert);
				if !cert.is_old() {
					self.certs
						.write()
						.unwrap()
						.insert(domain.to_string(), cert.clone());
					return Ok(cert);
				}
			}
		}

		// Third, ask from Let's Encrypt
		self.renew_cert(domain).await
	}

	pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		info!("Renewing certificate for {}", domain);

		// ---- Acquire lock ----
		// the lock is acquired for fifteen minutes,
		// so that in case of an error we won't retry before
		// that delay expires

		let lock_path = format!("renew_lock/{}", domain);
		let lock_name = format!("tricot/renew:{}@{}", domain, self.consul.local_node.clone());
		let session = self
			.consul
			.create_session(&ConsulSessionRequest {
				name: lock_name.clone(),
				node: None,
				lock_delay: Some("15m".into()),
				ttl: Some("30m".into()),
				behavior: Some("delete".into()),
			})
			.await?;
		debug!("Lock session: {}", session);

		if !self
			.consul
			.acquire(&lock_path, lock_name.clone().into(), &session)
			.await?
		{
			bail!("Lock is already taken, not renewing for now.");
		}

		// ---- Do let's encrypt stuff ----

		let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
		let contact = vec![format!("mailto:{}", self.letsencrypt_email)];

		let acc =
			if let Some(acc_privkey) = self.consul.kv_get("letsencrypt_account_key.pem").await? {
				info!("Using existing Let's encrypt account");
				dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
			} else {
				info!("Creating new Let's encrypt account");
				let acc = block_in_place(|| dir.register_account(contact.clone()))?;
				self.consul
					.kv_put(
						"letsencrypt_account_key.pem",
						acc.acme_private_key_pem()?.into_bytes().into(),
					)
					.await?;
				acc
			};

		let mut ord_new = acc.new_order(domain, &[])?;
		let ord_csr = loop {
			if let Some(ord_csr) = ord_new.confirm_validations() {
				break ord_csr;
			}

			let auths = ord_new.authorizations()?;

			info!("Creating challenge and storing in Consul");
			let chall = auths[0].http_challenge().unwrap();
			let chall_key = format!("challenge/{}", chall.http_token());
			self.consul
				.acquire(&chall_key, chall.http_proof()?.into(), &session)
				.await?;

			info!("Validating challenge");
			block_in_place(|| chall.validate(Duration::from_millis(5000)))?;

			info!("Deleting challenge");
			self.consul.kv_delete(&chall_key).await?;

			block_in_place(|| ord_new.refresh())?;
		};

		let pkey_pri = create_p384_key()?;
		let ord_cert =
			block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
		let cert = block_in_place(|| ord_cert.download_cert())?;

		info!("Keys and certificate obtained");
		let key_pem = cert.private_key().to_string();
		let cert_pem = cert.certificate().to_string();

		let certser = CertSer {
			hostname: domain.to_string(),
			date: Utc::today().naive_utc(),
			valid_days: cert.valid_days_left()?,
			key_pem,
			cert_pem,
		};

		self.consul
			.kv_put_json(&format!("certs/{}", domain), &certser)
			.await?;
		self.consul.release(&lock_path, "".into(), &session).await?;

		let cert = Arc::new(Cert::new(certser)?);
		self.certs
			.write()
			.unwrap()
			.insert(domain.to_string(), cert.clone());

		info!("Cert successfully renewed: {}", domain);
		Ok(cert)
	}

	fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {
		let subject_alt_names = vec![domain.to_string(), "localhost".to_string()];
		let cert = rcgen::generate_simple_self_signed(subject_alt_names)?;

		let certser = CertSer {
			hostname: domain.to_string(),
			date: Utc::today().naive_utc(),
			valid_days: 1024,
			key_pem: cert.serialize_private_key_pem(),
			cert_pem: cert.serialize_pem()?,
		};
		let cert = Arc::new(Cert::new(certser)?);
		self.self_signed_certs
			.write()
			.unwrap()
			.insert(domain.to_string(), cert.clone());
		info!("Added self-signed certificate for {}", domain);

		Ok(cert)
	}
}

/// rustls certificate resolver that picks certificates from a [`CertStore`]
/// using the SNI hostname sent by the client.
pub struct StoreResolver(pub Arc<CertStore>);

impl rustls::server::ResolvesServerCert for StoreResolver {
	/// Returns the certified key for the client's SNI name.
	///
	/// Returns `None` (aborting the handshake) when the client sent no SNI name or
	/// when the store reports an error; the error is logged.
	fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
		let sni_name = client_hello.server_name()?;
		let lookup = self.0.get_cert_for_https(sni_name);

		if let Err(err) = &lookup {
			warn!("Could not get certificate for {}: {}", sni_name, err);
		}

		lookup.ok().map(|found| found.certkey.clone())
	}
}