use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock}; use std::time::{Duration, Instant}; use anyhow::Result; use chrono::Utc; use futures::{FutureExt, TryFutureExt}; use log::*; use tokio::select; use tokio::sync::{mpsc, watch}; use tokio::task::block_in_place; use acme_micro::create_p384_key; use acme_micro::{Directory, DirectoryUrl}; use rustls::sign::CertifiedKey; use crate::cert::{Cert, CertSer}; use crate::consul::{self, Consul}; use crate::proxy_config::*; pub struct CertStore { consul: Consul, node_name: String, letsencrypt_email: String, certs: RwLock>>, self_signed_certs: RwLock>>, rx_proxy_config: watch::Receiver>, tx_need_cert: mpsc::UnboundedSender, } impl CertStore { pub fn new( consul: Consul, node_name: String, rx_proxy_config: watch::Receiver>, letsencrypt_email: String, exit_on_err: impl Fn(anyhow::Error) + Send + 'static, ) -> Arc { let (tx, rx) = mpsc::unbounded_channel(); let cert_store = Arc::new(Self { consul, node_name, certs: RwLock::new(HashMap::new()), self_signed_certs: RwLock::new(HashMap::new()), rx_proxy_config, letsencrypt_email, tx_need_cert: tx, }); tokio::spawn( cert_store .clone() .certificate_loop(rx) .map_err(exit_on_err) .then(|_| async { info!("Certificate renewal task exited") }), ); cert_store } async fn certificate_loop( self: Arc, mut rx_need_cert: mpsc::UnboundedReceiver, ) -> Result<()> { let mut rx_proxy_config = self.rx_proxy_config.clone(); let mut t_last_check: HashMap = HashMap::new(); loop { let mut domains: HashSet = HashSet::new(); select! { res = rx_proxy_config.changed() => { if res.is_err() { bail!("rx_proxy_config closed"); } let proxy_config: Arc = rx_proxy_config.borrow().clone(); for ent in proxy_config.entries.iter() { if let HostDescription::Hostname(domain) = &ent.host { if let Some((host, _port)) = domain.split_once(':') { domains.insert(host.to_string()); } else { domains.insert(domain.clone()); } } } } need_cert = rx_need_cert.recv() => { match need_cert { Some(dom) => { domains.insert(dom); while let Ok(dom2) = rx_need_cert.try_recv() { domains.insert(dom2); } } None => bail!("rx_need_cert closed"), }; } } for dom in domains.iter() { match t_last_check.get(dom) { Some(t) if Instant::now() - *t < Duration::from_secs(60) => continue, _ => t_last_check.insert(dom.to_string(), Instant::now()), }; debug!("Checking cert for domain: {}", dom); if let Err(e) = self.check_cert(dom).await { warn!("({}) Could not get certificate: {}", dom, e); } } } } fn get_cert_for_https(self: &Arc, domain: &str) -> Result> { // Check if domain is authorized if !self .rx_proxy_config .borrow() .entries .iter() .any(|ent| ent.host.matches(domain)) { bail!("Domain {} should not have a TLS certificate.", domain); } // Check in local memory if it exists if let Some(cert) = self.certs.read().unwrap().get(domain) { if cert.is_old() { self.tx_need_cert.send(domain.to_string())?; } return Ok(cert.clone()); } // Not found in local memory, try to get it in background self.tx_need_cert.send(domain.to_string())?; // In the meantime, use a self-signed certificate if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) { return Ok(cert.clone()); } self.gen_self_signed_certificate(domain) } pub async fn check_cert(self: &Arc, domain: &str) -> Result<()> { // First, try locally. { let certs = self.certs.read().unwrap(); if let Some(cert) = certs.get(domain) { if !cert.is_old() { return Ok(()); } } } // Second, try from Consul. if let Some(consul_cert) = self .consul .kv_get_json::(&format!("certs/{}", domain)) .await? { if let Ok(cert) = Cert::new(consul_cert) { let cert = Arc::new(cert); self.certs .write() .unwrap() .insert(domain.to_string(), cert.clone()); if !cert.is_old() { return Ok(()); } } } // Third, ask from Let's Encrypt self.renew_cert(domain).await } pub async fn renew_cert(self: &Arc, domain: &str) -> Result<()> { info!("({}) Renewing certificate", domain); // Basic sanity check (we could add more kinds of checks here) // This is just to help avoid getting rate-limited against ACME server if !domain.contains('.') || domain.ends_with(".local") { bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)"); } // ---- Acquire lock ---- // the lock is acquired for half an hour, // so that in case of an error we won't retry before // that delay expires let lock_path = format!("renew_lock/{}", domain); let lock_name = format!("tricot/renew:{}@{}", domain, self.node_name); let session = self .consul .create_session(&consul::locking::SessionRequest { name: lock_name.clone(), node: None, lock_delay: Some("30m".into()), ttl: Some("45m".into()), behavior: Some("delete".into()), }) .await?; debug!("Lock session: {}", session); if !self .consul .acquire(&lock_path, lock_name.clone().into(), &session) .await? { bail!("Lock is already taken, not renewing for now."); } // ---- Accessibility check ---- // We don't want to ask Let's encrypt for a domain that // is not configured to point here. This can happen with wildcards: someone can send // a fake SNI to a domain that is not ours. We have to detect it here. self.check_domain_accessibility(domain, &session).await?; // ---- Do let's encrypt stuff ---- let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?; let contact = vec![format!("mailto:{}", self.letsencrypt_email)]; // Use existing Let's encrypt account or register new one if necessary let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? { Some(acc_privkey) => { info!("Using existing Let's encrypt account"); dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)? } None => { info!("Creating new Let's encrypt account"); let acc = block_in_place(|| dir.register_account(contact.clone()))?; self.consul .kv_put( "letsencrypt_account_key.pem", acc.acme_private_key_pem()?.into_bytes().into(), ) .await?; acc } }; // Order certificate and perform validation let mut ord_new = acc.new_order(domain, &[])?; let ord_csr = loop { if let Some(ord_csr) = ord_new.confirm_validations() { break ord_csr; } let auths = ord_new.authorizations()?; info!("({}) Creating challenge and storing in Consul", domain); let chall = auths[0].http_challenge().unwrap(); let chall_key = format!("challenge/{}", chall.http_token()); self.consul .acquire(&chall_key, chall.http_proof()?.into(), &session) .await?; info!("({}) Validating challenge", domain); block_in_place(|| chall.validate(Duration::from_millis(5000)))?; info!("({}) Deleting challenge", domain); self.consul.kv_delete(&chall_key).await?; block_in_place(|| ord_new.refresh())?; }; // Generate key and finalize certificate let pkey_pri = create_p384_key()?; let ord_cert = block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?; let cert = block_in_place(|| ord_cert.download_cert())?; info!("({}) Keys and certificate obtained", domain); let key_pem = cert.private_key().to_string(); let cert_pem = cert.certificate().to_string(); let certser = CertSer { hostname: domain.to_string(), date: Utc::now().date_naive(), valid_days: cert.valid_days_left()?, key_pem, cert_pem, }; let cert = Arc::new(Cert::new(certser.clone())?); // Store certificate in Consul and local store self.certs.write().unwrap().insert(domain.to_string(), cert); self.consul .kv_put_json(&format!("certs/{}", domain), &certser) .await?; // Release locks self.consul.release(&lock_path, "".into(), &session).await?; self.consul.kv_delete(&lock_path).await?; info!("({}) Cert successfully renewed and stored", domain); Ok(()) } async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> { // Returns Ok(()) only if domain is a correct domain name that // redirects to this server let self_challenge_id = uuid::Uuid::new_v4().to_string(); let self_challenge_key = format!("challenge/{}", self_challenge_id); let self_challenge_resp = uuid::Uuid::new_v4().to_string(); self.consul .acquire( &self_challenge_key, self_challenge_resp.as_bytes().to_vec().into(), session, ) .await?; let httpcli = reqwest::Client::new(); let chall_url = format!( "http://{}/.well-known/acme-challenge/{}", domain, self_challenge_id ); for i in 1..=4 { tokio::time::sleep(Duration::from_secs(2)).await; info!("({}) Accessibility check {}/4", domain, i); let httpresp = httpcli.get(&chall_url).send().await?; if httpresp.status() == reqwest::StatusCode::OK && httpresp.bytes().await? == self_challenge_resp.as_bytes() { // Challenge successfully validated info!("({}) Accessibility check successfull", domain); return Ok(()); } tokio::time::sleep(Duration::from_secs(2)).await; } bail!("Unable to validate self-challenge for domain accessibility check"); } fn gen_self_signed_certificate(&self, domain: &str) -> Result> { let subject_alt_names = vec![domain.to_string(), "localhost".to_string()]; let cert = rcgen::generate_simple_self_signed(subject_alt_names)?; let certser = CertSer { hostname: domain.to_string(), date: Utc::now().date_naive(), valid_days: 1024, key_pem: cert.serialize_private_key_pem(), cert_pem: cert.serialize_pem()?, }; let cert = Arc::new(Cert::new(certser)?); self.self_signed_certs .write() .unwrap() .insert(domain.to_string(), cert.clone()); info!("Added self-signed certificate for {}", domain); Ok(cert) } } pub struct StoreResolver(pub Arc); impl rustls::server::ResolvesServerCert for StoreResolver { fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option> { let domain = client_hello.server_name()?; match self.0.get_cert_for_https(domain) { Ok(cert) => Some(cert.certkey.clone()), Err(e) => { warn!("Could not get certificate for {}: {}", domain, e); None } } } }