aboutsummaryrefslogblamecommitdiff
path: root/src/cert_store.rs
blob: 8d45df4e73caa8b098ce282a96144648edc2e07a (plain) (tree)
1
2
3
4
5
6
7
8
9







                                         
                                


                                          
                               

                                 
                     
                           



                                                  
                                                           


                
                                                                                                     


                                                           
                                        


                  


                                                                       




                                                                                              


                                                                                      
                         

                                                   
                                                                                           






                                                                               






                                                                                        
                                                            




















                                                                                          



































                                                                                          
                                         


                                                                    






                                                                                                      


                                                               


                                                                

                                                    









                                                                                








                                                                                                             
                                                                                                    




















                                                                                                
                                                                                          


                                                      
                                                                                        



                                                                 
                                                              


                                                  


                                                                                                         















                                                                            
                                                                            










                                                                  





                                                                                                       
                                                                   


                                          
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::Duration;

use anyhow::Result;
use chrono::Utc;
use log::*;
use tokio::sync::watch;
use tokio::task::block_in_place;

use acme_micro::create_p384_key;
use acme_micro::{Directory, DirectoryUrl};
use rustls::sign::CertifiedKey;

use crate::cert::{Cert, CertSer};
use crate::consul::*;
use crate::proxy_config::*;

pub struct CertStore {
	consul: Consul,
	certs: RwLock<HashMap<String, Arc<Cert>>>,
	rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
}

impl CertStore {
	pub fn new(consul: Consul, rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>) -> Arc<Self> {
		Arc::new(Self {
			consul,
			certs: RwLock::new(HashMap::new()),
			rx_proxy_config,
		})
	}

	pub async fn watch_proxy_config(self: Arc<Self>) {
		let mut rx_proxy_config = self.rx_proxy_config.clone();

		while rx_proxy_config.changed().await.is_ok() {
			let mut domains: HashSet<String> = HashSet::new();

			let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
			for ent in proxy_config.entries.iter() {
				if let HostDescription::Hostname(domain) = &ent.host {
					domains.insert(domain.clone());
				}
			}

			for dom in domains.iter() {
				info!("Ensuring we have certs for domains: {:?}", domains);
				if let Err(e) = self.get_cert(dom).await {
					warn!("Error get_cert {}: {}", dom, e);
				}
			}
		}
	}

	pub fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		// Check if domain is authorized
		if !self
			.rx_proxy_config
			.borrow()
			.entries
			.iter()
			.any(|ent| ent.host.matches(domain))
		{
			bail!("Domain {} should not have a TLS certificate.", domain);
		}

		// Check in local memory if it exists
		let certs = self.certs.read().unwrap();
		if let Some(cert) = certs.get(domain) {
			if !cert.is_old() {
				return Ok(cert.clone());
			}
		}

		// Not found in local memory
		tokio::spawn(self.clone().get_cert_task(domain.to_string()));
		bail!("Certificate not found (will try to get it in background)");
	}

	pub async fn get_cert_task(self: Arc<Self>, domain: String) -> Result<Arc<Cert>> {
		self.get_cert(domain.as_str()).await
	}

	pub async fn get_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		// First, try locally.
		{
			let certs = self.certs.read().unwrap();
			if let Some(cert) = certs.get(domain) {
				if !cert.is_old() {
					return Ok(cert.clone());
				}
			}
		}

		// Second, try from Consul.
		if let Some(consul_cert) = self
			.consul
			.kv_get_json::<CertSer>(&format!("certs/{}", domain))
			.await?
		{
			if let Ok(cert) = Cert::new(consul_cert) {
				let cert = Arc::new(cert);
				if !cert.is_old() {
					self.certs
						.write()
						.unwrap()
						.insert(domain.to_string(), cert.clone());
					return Ok(cert);
				}
			}
		}

		// Third, ask from Let's Encrypt
		self.renew_cert(domain).await
	}

	pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
		info!("Renewing certificate for {}", domain);

		// ---- Acquire lock ----
		// the lock is acquired for fifteen minutes,
		// so that in case of an error we won't retry before
		// that delay expires

		let lock_path = format!("renew_lock/{}", domain);
		let lock_name = format!("tricot/renew:{}@{}", domain, self.consul.local_node.clone());
		let session = self
			.consul
			.create_session(&ConsulSessionRequest {
				name: lock_name.clone(),
				node: None,
				lock_delay: Some("15m".into()),
				ttl: Some("30m".into()),
				behavior: Some("delete".into()),
			})
			.await?;
		debug!("Lock session: {}", session);

		if !self
			.consul
			.acquire(&lock_path, lock_name.clone().into(), &session)
			.await?
		{
			bail!("Lock is already taken, not renewing for now.");
		}

		// ---- Do let's encrypt stuff ----

		let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
		let contact = vec!["mailto:alex@adnab.me".to_string()];

		let acc =
			if let Some(acc_privkey) = self.consul.kv_get("letsencrypt_account_key.pem").await? {
				info!("Using existing Let's encrypt account");
				dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
			} else {
				info!("Creating new Let's encrypt account");
				let acc = block_in_place(|| dir.register_account(contact.clone()))?;
				self.consul
					.kv_put(
						"letsencrypt_account_key.pem",
						acc.acme_private_key_pem()?.into_bytes().into(),
					)
					.await?;
				acc
			};

		let mut ord_new = acc.new_order(domain, &[])?;
		let ord_csr = loop {
			if let Some(ord_csr) = ord_new.confirm_validations() {
				break ord_csr;
			}

			let auths = ord_new.authorizations()?;

			info!("Creating challenge and storing in Consul");
			let chall = auths[0].http_challenge().unwrap();
			let chall_key = format!("challenge/{}", chall.http_token());
			self.consul
				.acquire(&chall_key, chall.http_proof()?.into(), &session)
				.await?;

			info!("Validating challenge");
			block_in_place(|| chall.validate(Duration::from_millis(5000)))?;

			info!("Deleting challenge");
			self.consul.kv_delete(&chall_key).await?;

			block_in_place(|| ord_new.refresh())?;
		};

		let pkey_pri = create_p384_key()?;
		let ord_cert =
			block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
		let cert = block_in_place(|| ord_cert.download_cert())?;

		info!("Keys and certificate obtained");
		let key_pem = cert.private_key().to_string();
		let cert_pem = cert.certificate().to_string();

		let certser = CertSer {
			hostname: domain.to_string(),
			date: Utc::today().naive_utc(),
			valid_days: cert.valid_days_left()?,
			key_pem,
			cert_pem,
		};

		self.consul
			.kv_put_json(&format!("certs/{}", domain), &certser)
			.await?;
		self.consul.release(&lock_path, "".into(), &session).await?;

		let cert = Arc::new(Cert::new(certser)?);
		self.certs
			.write()
			.unwrap()
			.insert(domain.to_string(), cert.clone());

		info!("Cert successfully renewed: {}", domain);
		Ok(cert)
	}
}

/// Adapter that lets rustls pick server certificates from a [`CertStore`].
pub struct StoreResolver(pub Arc<CertStore>);

impl rustls::server::ResolvesServerCert for StoreResolver {
	/// Picks a certificate based on the SNI name in `client_hello`.
	///
	/// Returns `None` when the client sent no server name, or when the store
	/// does not hand back a certificate for that name.
	fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
		let StoreResolver(store) = self;
		client_hello
			.server_name()
			.and_then(|sni| store.get_cert_for_https(sni).ok())
			.map(|cert| cert.certkey.clone())
	}
}