use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use anyhow::Result;
use chrono::Utc;
use futures::{FutureExt, TryFutureExt};
use log::*;
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio::task::block_in_place;
use acme_micro::create_p384_key;
use acme_micro::{Directory, DirectoryUrl};
use rustls::sign::CertifiedKey;
use crate::cert::{Cert, CertSer};
use crate::consul::{self, Consul};
use crate::proxy_config::*;
/// Certificate store: serves TLS certificates for proxied domains and keeps them renewed.
///
/// Certificates are looked up in memory, then in Consul, and finally requested from
/// Let's Encrypt; a self-signed certificate is served while no real one is available.
pub struct CertStore {
/// Consul client used for certificate storage, ACME challenges and renewal locks.
consul: Consul,
/// Name of this node; embedded in the renewal lock name.
node_name: String,
/// Contact e-mail used for the Let's Encrypt account (`mailto:` contact).
letsencrypt_email: String,
/// Certificates loaded from Consul or obtained from Let's Encrypt, keyed by domain.
certs: RwLock<HashMap<String, Arc<Cert>>>,
/// Self-signed fallback certificates, keyed by domain.
self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,
/// Latest proxy configuration; decides which domains are allowed a certificate.
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
/// Queue of domains for the background task (`certificate_loop`) to check.
tx_need_cert: mpsc::UnboundedSender<String>,
}
impl CertStore {
pub fn new(
consul: Consul,
node_name: String,
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
letsencrypt_email: String,
exit_on_err: impl Fn(anyhow::Error) + Send + 'static,
) -> Arc<Self> {
let (tx, rx) = mpsc::unbounded_channel();
let cert_store = Arc::new(Self {
consul,
node_name,
certs: RwLock::new(HashMap::new()),
self_signed_certs: RwLock::new(HashMap::new()),
rx_proxy_config,
letsencrypt_email,
tx_need_cert: tx,
});
tokio::spawn(
cert_store
.clone()
.certificate_loop(rx)
.map_err(exit_on_err)
.then(|_| async { info!("Certificate renewal task exited") }),
);
cert_store
}
/// Background task: checks (and obtains if needed) certificates for domains as they appear.
///
/// Domains come from two sources:
/// - each change of the proxy configuration contributes the host of every
///   `HostDescription::Hostname` entry, with any `:port` suffix stripped;
/// - `tx_need_cert` (fed by `get_cert_for_https`); every request already waiting in the
///   channel is drained and handled in the same batch.
///
/// A given domain is passed to `check_cert` at most once per `RECHECK_INTERVAL`. Failures are
/// logged, and the domain is retried when it is triggered again after the interval.
///
/// # Errors
/// Returns only when one of the two input channels is closed.
async fn certificate_loop(
    self: Arc<Self>,
    mut rx_need_cert: mpsc::UnboundedReceiver<String>,
) -> Result<()> {
    // Minimum delay between two checks of the same domain.
    const RECHECK_INTERVAL: Duration = Duration::from_secs(60);

    let mut rx_proxy_config = self.rx_proxy_config.clone();
    let mut t_last_check: HashMap<String, Instant> = HashMap::new();
    loop {
        let mut domains: HashSet<String> = HashSet::new();
        select! {
            res = rx_proxy_config.changed() => {
                if res.is_err() {
                    bail!("rx_proxy_config closed");
                }
                let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
                for ent in proxy_config.entries.iter() {
                    if let HostDescription::Hostname(domain) = &ent.host {
                        // Strip an optional ":port" suffix.
                        let host = domain
                            .split_once(':')
                            .map_or(domain.as_str(), |(host, _port)| host);
                        domains.insert(host.to_string());
                    }
                }
            }
            need_cert = rx_need_cert.recv() => {
                match need_cert {
                    Some(dom) => {
                        domains.insert(dom);
                        // Batch every request that is already queued.
                        while let Ok(dom2) = rx_need_cert.try_recv() {
                            domains.insert(dom2);
                        }
                    }
                    None => bail!("rx_need_cert closed"),
                };
            }
        }

        // Entries older than RECHECK_INTERVAL no longer throttle anything. Dropping them keeps
        // the map bounded even when many distinct names are requested over time (e.g. SNI
        // names accepted by a wildcard host entry).
        t_last_check.retain(|_, t| t.elapsed() < RECHECK_INTERVAL);

        for dom in domains.iter() {
            if matches!(t_last_check.get(dom), Some(t) if t.elapsed() < RECHECK_INTERVAL) {
                continue;
            }
            t_last_check.insert(dom.clone(), Instant::now());

            debug!("Checking cert for domain: {}", dom);
            if let Err(e) = self.check_cert(dom).await {
                warn!("({}) Could not get certificate: {}", dom, e);
            }
        }
    }
}
fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
// Check if domain is authorized
if !self
.rx_proxy_config
.borrow()
.entries
.iter()
.any(|ent| ent.host.matches(domain))
{
bail!("Domain {} should not have a TLS certificate.", domain);
}
// Check in local memory if it exists
if let Some(cert) = self.certs.read().unwrap().get(domain) {
if cert.is_old() {
self.tx_need_cert.send(domain.to_string())?;
}
return Ok(cert.clone());
}
// Not found in local memory, try to get it in background
self.tx_need_cert.send(domain.to_string())?;
// In the meantime, use a self-signed certificate
if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
return Ok(cert.clone());
}
self.gen_self_signed_certificate(domain)
}
/// Ensures a non-old certificate for `domain` is available in memory.
///
/// Sources are tried in order:
/// 1. the in-memory `certs` map: done if it holds a certificate that is not old;
/// 2. Consul key `certs/<domain>`: a parseable certificate is cached in memory (even if old),
///    and the function is done if it is not old;
/// 3. otherwise `renew_cert` is called, and its result is returned.
///
/// # Errors
/// Propagates Consul read errors and any error from `renew_cert`. A certificate in Consul that
/// fails to parse is silently skipped.
pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
    // 1. Local memory. The guard is a temporary of this `let`, so it is released before any
    //    `.await` below.
    let fresh_in_memory = matches!(
        self.certs.read().unwrap().get(domain),
        Some(cert) if !cert.is_old()
    );
    if fresh_in_memory {
        return Ok(());
    }

    // 2. Consul.
    let consul_key = format!("certs/{}", domain);
    let stored = self.consul.kv_get_json::<CertSer>(&consul_key).await?;
    if let Some(cert) = stored.and_then(|ser| Cert::new(ser).ok()) {
        let cert = Arc::new(cert);
        self.certs
            .write()
            .unwrap()
            .insert(domain.to_string(), Arc::clone(&cert));
        if !cert.is_old() {
            return Ok(());
        }
    }

    // 3. Let's Encrypt.
    self.renew_cert(domain).await
}
pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
info!("({}) Renewing certificate", domain);
// Basic sanity check (we could add more kinds of checks here)
// This is just to help avoid getting rate-limited against ACME server
if !domain.contains('.') || domain.ends_with(".local") {
bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)");
}
// ---- Acquire lock ----
// the lock is acquired for half an hour,
// so that in case of an error we won't retry before
// that delay expires
let lock_path = format!("renew_lock/{}", domain);
let lock_name = format!("tricot/renew:{}@{}", domain, self.node_name);
let session = self
.consul
.create_session(&consul::locking::SessionRequest {
name: lock_name.clone(),
node: None,
lock_delay: Some("30m".into()),
ttl: Some("45m".into()),
behavior: Some("delete".into()),
})
.await?;
debug!("Lock session: {}", session);
if !self
.consul
.acquire(&lock_path, lock_name.clone().into(), &session)
.await?
{
bail!("Lock is already taken, not renewing for now.");
}
// ---- Accessibility check ----
// We don't want to ask Let's encrypt for a domain that
// is not configured to point here. This can happen with wildcards: someone can send
// a fake SNI to a domain that is not ours. We have to detect it here.
self.check_domain_accessibility(domain, &session).await?;
// ---- Do let's encrypt stuff ----
let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
let contact = vec![format!("mailto:{}", self.letsencrypt_email)];
// Use existing Let's encrypt account or register new one if necessary
let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? {
Some(acc_privkey) => {
info!("Using existing Let's encrypt account");
dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
}
None => {
info!("Creating new Let's encrypt account");
let acc = block_in_place(|| dir.register_account(contact.clone()))?;
self.consul
.kv_put(
"letsencrypt_account_key.pem",
acc.acme_private_key_pem()?.into_bytes().into(),
)
.await?;
acc
}
};
// Order certificate and perform validation
let mut ord_new = acc.new_order(domain, &[])?;
let ord_csr = loop {
if let Some(ord_csr) = ord_new.confirm_validations() {
break ord_csr;
}
let auths = ord_new.authorizations()?;
info!("({}) Creating challenge and storing in Consul", domain);
let chall = auths[0].http_challenge().unwrap();
let chall_key = format!("challenge/{}", chall.http_token());
self.consul
.acquire(&chall_key, chall.http_proof()?.into(), &session)
.await?;
info!("({}) Validating challenge", domain);
block_in_place(|| chall.validate(Duration::from_millis(5000)))?;
info!("({}) Deleting challenge", domain);
self.consul.kv_delete(&chall_key).await?;
block_in_place(|| ord_new.refresh())?;
};
// Generate key and finalize certificate
let pkey_pri = create_p384_key()?;
let ord_cert =
block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
let cert = block_in_place(|| ord_cert.download_cert())?;
info!("({}) Keys and certificate obtained", domain);
let key_pem = cert.private_key().to_string();
let cert_pem = cert.certificate().to_string();
let certser = CertSer {
hostname: domain.to_string(),
date: Utc::now().date_naive(),
valid_days: cert.valid_days_left()?,
key_pem,
cert_pem,
};
let cert = Arc::new(Cert::new(certser.clone())?);
// Store certificate in Consul and local store
self.certs.write().unwrap().insert(domain.to_string(), cert);
self.consul
.kv_put_json(&format!("certs/{}", domain), &certser)
.await?;
// Release locks
self.consul.release(&lock_path, "".into(), &session).await?;
self.consul.kv_delete(&lock_path).await?;
info!("({}) Cert successfully renewed and stored", domain);
Ok(())
}
async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> {
// Returns Ok(()) only if domain is a correct domain name that
// redirects to this server
let self_challenge_id = uuid::Uuid::new_v4().to_string();
let self_challenge_key = format!("challenge/{}", self_challenge_id);
let self_challenge_resp = uuid::Uuid::new_v4().to_string();
self.consul
.acquire(
&self_challenge_key,
self_challenge_resp.as_bytes().to_vec().into(),
session,
)
.await?;
let httpcli = reqwest::Client::new();
let chall_url = format!(
"http://{}/.well-known/acme-challenge/{}",
domain, self_challenge_id
);
for i in 1..=4 {
tokio::time::sleep(Duration::from_secs(2)).await;
info!("({}) Accessibility check {}/4", domain, i);
let httpresp = httpcli.get(&chall_url).send().await?;
if httpresp.status() == reqwest::StatusCode::OK
&& httpresp.bytes().await? == self_challenge_resp.as_bytes()
{
// Challenge successfully validated
info!("({}) Accessibility check successfull", domain);
return Ok(());
}
tokio::time::sleep(Duration::from_secs(2)).await;
}
bail!("Unable to validate self-challenge for domain accessibility check");
}
/// Generates a self-signed certificate for `domain` (also valid for `localhost`).
///
/// The certificate is recorded as valid for 1024 days from today, cached in
/// `self_signed_certs` (replacing any previous entry for `domain`), and returned.
///
/// # Errors
/// Propagates certificate generation, serialization and `Cert::new` parsing errors.
fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {
    let generated =
        rcgen::generate_simple_self_signed(vec![domain.to_string(), "localhost".to_string()])?;

    let cert = Arc::new(Cert::new(CertSer {
        hostname: domain.to_string(),
        date: Utc::now().date_naive(),
        valid_days: 1024,
        key_pem: generated.serialize_private_key_pem(),
        cert_pem: generated.serialize_pem()?,
    })?);

    self.self_signed_certs
        .write()
        .unwrap()
        .insert(domain.to_string(), Arc::clone(&cert));
    info!("Added self-signed certificate for {}", domain);
    Ok(cert)
}
}
/// rustls server-certificate resolver backed by a [`CertStore`].
pub struct StoreResolver(pub Arc<CertStore>);

impl rustls::server::ResolvesServerCert for StoreResolver {
    /// Picks the certificate for the SNI name in `client_hello`.
    ///
    /// Returns `None` when the client sent no SNI name, or when the store cannot provide a
    /// certificate; the latter case is logged.
    fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let domain = client_hello.server_name()?;
        let cert = match self.0.get_cert_for_https(domain) {
            Ok(cert) => cert,
            Err(e) => {
                warn!("Could not get certificate for {}: {}", domain, e);
                return None;
            }
        };
        Some(cert.certkey.clone())
    }
}