aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2021-12-09 12:18:23 +0100
committerAlex Auvolat <alex@adnab.me>2021-12-09 12:18:23 +0100
commitfdb83162ce9979184a9d680c5ec4f64235497485 (patch)
tree543d1d0b7c8f63d7a10c38d1d95c8b2e2c8a180f /src
parent8153bdca4618eed76665eeb4c5a005378701df1f (diff)
downloadtricot-fdb83162ce9979184a9d680c5ec4f64235497485.tar.gz
tricot-fdb83162ce9979184a9d680c5ec4f64235497485.zip
Improved management of ACME orders and certificate pre-expiration period
Diffstat (limited to 'src')
-rw-r--r--src/cert.rs2
-rw-r--r--src/cert_store.rs151
-rw-r--r--src/https.rs1
-rw-r--r--src/main.rs3
-rw-r--r--src/proxy_config.rs3
-rw-r--r--src/reverse_proxy.rs26
-rw-r--r--src/tls_util.rs9
7 files changed, 116 insertions, 79 deletions
diff --git a/src/cert.rs b/src/cert.rs
index 0be43f3..12b9218 100644
--- a/src/cert.rs
+++ b/src/cert.rs
@@ -6,7 +6,7 @@ use chrono::{Date, NaiveDate, Utc};
use rustls::sign::CertifiedKey;
use serde::{Deserialize, Serialize};
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CertSer {
pub hostname: String,
pub date: NaiveDate,
diff --git a/src/cert_store.rs b/src/cert_store.rs
index fe2f8b0..2095660 100644
--- a/src/cert_store.rs
+++ b/src/cert_store.rs
@@ -1,11 +1,13 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
-use std::time::Duration;
+use std::time::{Duration, Instant};
use anyhow::Result;
use chrono::Utc;
+use futures::TryFutureExt;
use log::*;
-use tokio::sync::watch;
+use tokio::select;
+use tokio::sync::{mpsc, watch};
use tokio::task::block_in_place;
use acme_micro::create_p384_key;
@@ -14,6 +16,7 @@ use rustls::sign::CertifiedKey;
use crate::cert::{Cert, CertSer};
use crate::consul::*;
+use crate::exit_on_err;
use crate::proxy_config::*;
pub struct CertStore {
@@ -22,6 +25,7 @@ pub struct CertStore {
certs: RwLock<HashMap<String, Arc<Cert>>>,
self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>,
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
+ tx_need_cert: mpsc::UnboundedSender<String>,
}
impl CertStore {
@@ -30,44 +34,78 @@ impl CertStore {
rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>,
letsencrypt_email: String,
) -> Arc<Self> {
- Arc::new(Self {
+ let (tx, rx) = mpsc::unbounded_channel();
+
+ let cert_store = Arc::new(Self {
consul,
certs: RwLock::new(HashMap::new()),
self_signed_certs: RwLock::new(HashMap::new()),
rx_proxy_config,
letsencrypt_email,
- })
+ tx_need_cert: tx,
+ });
+
+ tokio::spawn(cert_store.clone().certificate_loop(rx).map_err(exit_on_err));
+
+ cert_store
}
- pub async fn watch_proxy_config(self: Arc<Self>) -> Result<()> {
+ async fn certificate_loop(
+ self: Arc<Self>,
+ mut rx_need_cert: mpsc::UnboundedReceiver<String>,
+ ) -> Result<()> {
let mut rx_proxy_config = self.rx_proxy_config.clone();
- while rx_proxy_config.changed().await.is_ok() {
+ let mut t_last_check: HashMap<String, Instant> = HashMap::new();
+
+ loop {
let mut domains: HashSet<String> = HashSet::new();
- let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
- for ent in proxy_config.entries.iter() {
- if let HostDescription::Hostname(domain) = &ent.host {
- if let Some((host, _port)) = domain.split_once(':') {
- domains.insert(host.to_string());
- } else {
- domains.insert(domain.clone());
+ select! {
+ res = rx_proxy_config.changed() => {
+ if res.is_err() {
+ bail!("rx_proxy_config closed");
+ }
+
+ let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone();
+ for ent in proxy_config.entries.iter() {
+ if let HostDescription::Hostname(domain) = &ent.host {
+ if let Some((host, _port)) = domain.split_once(':') {
+ domains.insert(host.to_string());
+ } else {
+ domains.insert(domain.clone());
+ }
+ }
}
}
+ need_cert = rx_need_cert.recv() => {
+ match need_cert {
+ Some(dom) => {
+ domains.insert(dom);
+ while let Ok(dom2) = rx_need_cert.try_recv() {
+ domains.insert(dom2);
+ }
+ }
+ None => bail!("rx_need_cert closed"),
+ };
+ }
}
- debug!("Ensuring we have certs for domains: {:#?}", domains);
for dom in domains.iter() {
- if let Err(e) = self.get_cert(dom).await {
- warn!("Error get_cert {}: {}", dom, e);
+ match t_last_check.get(dom) {
+ Some(t) if Instant::now() - *t < Duration::from_secs(3600) => continue,
+ _ => t_last_check.insert(dom.to_string(), Instant::now()),
+ };
+
+ debug!("Checking cert for domain: {}", dom);
+ if let Err(e) = self.check_cert(dom).await {
+ warn!("({}) Could not get certificate: {}", dom, e);
}
}
}
-
- bail!("rx_proxy_config closed");
}
- pub fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
+ fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
// Check if domain is authorized
if !self
.rx_proxy_config
@@ -81,35 +119,30 @@ impl CertStore {
// Check in local memory if it exists
if let Some(cert) = self.certs.read().unwrap().get(domain) {
- if !cert.is_old() {
- return Ok(cert.clone());
+ if cert.is_old() {
+ self.tx_need_cert.send(domain.to_string())?;
}
+ return Ok(cert.clone());
}
// Not found in local memory, try to get it in background
- tokio::spawn(self.clone().get_cert_task(domain.to_string()));
+ self.tx_need_cert.send(domain.to_string())?;
// In the meantime, use a self-signed certificate
if let Some(cert) = self.self_signed_certs.read().unwrap().get(domain) {
- if !cert.is_old() {
- return Ok(cert.clone());
- }
+ return Ok(cert.clone());
}
self.gen_self_signed_certificate(domain)
}
- pub async fn get_cert_task(self: Arc<Self>, domain: String) -> Result<Arc<Cert>> {
- self.get_cert(domain.as_str()).await
- }
-
- pub async fn get_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
+ pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
// First, try locally.
{
let certs = self.certs.read().unwrap();
if let Some(cert) = certs.get(domain) {
if !cert.is_old() {
- return Ok(cert.clone());
+ return Ok(());
}
}
}
@@ -122,12 +155,12 @@ impl CertStore {
{
if let Ok(cert) = Cert::new(consul_cert) {
let cert = Arc::new(cert);
+ self.certs
+ .write()
+ .unwrap()
+ .insert(domain.to_string(), cert.clone());
if !cert.is_old() {
- self.certs
- .write()
- .unwrap()
- .insert(domain.to_string(), cert.clone());
- return Ok(cert);
+ return Ok(());
}
}
}
@@ -136,8 +169,14 @@ impl CertStore {
self.renew_cert(domain).await
}
- pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> {
- info!("Renewing certificate for {}", domain);
+ pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> {
+ info!("({}) Renewing certificate", domain);
+
+ // Basic sanity check (we could add more kinds of checks here)
+ // This is just to help avoid getting rate-limited against ACME server
+ if !domain.contains('.') || domain.ends_with(".local") {
+ bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)");
+ }
// ---- Acquire lock ----
// the lock is acquired for fifteen minutes,
@@ -171,11 +210,13 @@ impl CertStore {
let dir = Directory::from_url(DirectoryUrl::LetsEncrypt)?;
let contact = vec![format!("mailto:{}", self.letsencrypt_email)];
- let acc =
- if let Some(acc_privkey) = self.consul.kv_get("letsencrypt_account_key.pem").await? {
+ // Use existing Let's encrypt account or register new one if necessary
+ let acc = match self.consul.kv_get("letsencrypt_account_key.pem").await? {
+ Some(acc_privkey) => {
info!("Using existing Let's encrypt account");
dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
- } else {
+ }
+ None => {
info!("Creating new Let's encrypt account");
let acc = block_in_place(|| dir.register_account(contact.clone()))?;
self.consul
@@ -185,8 +226,10 @@ impl CertStore {
)
.await?;
acc
- };
+ }
+ };
+ // Order certificate and perform validation
let mut ord_new = acc.new_order(domain, &[])?;
let ord_csr = loop {
if let Some(ord_csr) = ord_new.confirm_validations() {
@@ -195,28 +238,29 @@ impl CertStore {
let auths = ord_new.authorizations()?;
- info!("Creating challenge and storing in Consul");
+ info!("({}) Creating challenge and storing in Consul", domain);
let chall = auths[0].http_challenge().unwrap();
let chall_key = format!("challenge/{}", chall.http_token());
self.consul
.acquire(&chall_key, chall.http_proof()?.into(), &session)
.await?;
- info!("Validating challenge");
+ info!("({}) Validating challenge", domain);
block_in_place(|| chall.validate(Duration::from_millis(5000)))?;
- info!("Deleting challenge");
+ info!("({}) Deleting challenge", domain);
self.consul.kv_delete(&chall_key).await?;
block_in_place(|| ord_new.refresh())?;
};
+ // Generate key and finalize certificate
let pkey_pri = create_p384_key()?;
let ord_cert =
block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
let cert = block_in_place(|| ord_cert.download_cert())?;
- info!("Keys and certificate obtained");
+ info!("({}) Keys and certificate obtained", domain);
let key_pem = cert.private_key().to_string();
let cert_pem = cert.certificate().to_string();
@@ -227,21 +271,20 @@ impl CertStore {
key_pem,
cert_pem,
};
+ let cert = Arc::new(Cert::new(certser.clone())?);
+ // Store certificate in Consul and local store
+ self.certs.write().unwrap().insert(domain.to_string(), cert);
self.consul
.kv_put_json(&format!("certs/{}", domain), &certser)
.await?;
+
+ // Release locks
self.consul.release(&lock_path, "".into(), &session).await?;
self.consul.kv_delete(&lock_path).await?;
- let cert = Arc::new(Cert::new(certser)?);
- self.certs
- .write()
- .unwrap()
- .insert(domain.to_string(), cert.clone());
-
- info!("Cert successfully renewed: {}", domain);
- Ok(cert)
+ info!("({}) Cert successfully renewed and stored", domain);
+ Ok(())
}
fn gen_self_signed_certificate(&self, domain: &str) -> Result<Arc<Cert>> {
diff --git a/src/https.rs b/src/https.rs
index b0d452b..a389e72 100644
--- a/src/https.rs
+++ b/src/https.rs
@@ -114,7 +114,6 @@ async fn handle(
)
});
-
if let Some(proxy_to) = best_match {
proxy_to.calls.fetch_add(1, Ordering::SeqCst);
diff --git a/src/main.rs b/src/main.rs
index 1fffcbc..faffac6 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,7 +5,6 @@ use futures::TryFutureExt;
use std::net::SocketAddr;
use structopt::StructOpt;
-mod tls_util;
mod cert;
mod cert_store;
mod consul;
@@ -13,6 +12,7 @@ mod http;
mod https;
mod proxy_config;
mod reverse_proxy;
+mod tls_util;
use log::*;
@@ -85,7 +85,6 @@ async fn main() {
rx_proxy_config.clone(),
opt.letsencrypt_email.clone(),
);
- tokio::spawn(cert_store.clone().watch_proxy_config().map_err(exit_on_err));
tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()).map_err(exit_on_err));
tokio::spawn(
diff --git a/src/proxy_config.rs b/src/proxy_config.rs
index 2c55eb5..820d40a 100644
--- a/src/proxy_config.rs
+++ b/src/proxy_config.rs
@@ -99,7 +99,8 @@ fn parse_tricot_tag(
) -> Option<ProxyEntry> {
let splits = tag.split(' ').collect::<Vec<_>>();
if (splits.len() != 2 && splits.len() != 3)
- || (splits[0] != "tricot" && splits[0] != "tricot-https") {
+ || (splits[0] != "tricot" && splits[0] != "tricot-https")
+ {
return None;
}
diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs
index 10f463c..7b0f261 100644
--- a/src/reverse_proxy.rs
+++ b/src/reverse_proxy.rs
@@ -1,11 +1,11 @@
//! Copied from https://github.com/felipenoris/hyper-reverse-proxy
//! See there for original Copyright notice
-use std::sync::Arc;
use std::convert::TryInto;
-use std::time::SystemTime;
use std::net::IpAddr;
use std::str::FromStr;
+use std::sync::Arc;
+use std::time::SystemTime;
use anyhow::Result;
use log::*;
@@ -13,9 +13,9 @@ use log::*;
use http::header::HeaderName;
use hyper::header::{HeaderMap, HeaderValue};
use hyper::{Body, Client, Request, Response, Uri};
-use rustls::{Certificate, ServerName};
-use rustls::client::{ServerCertVerifier, ServerCertVerified};
use lazy_static::lazy_static;
+use rustls::client::{ServerCertVerified, ServerCertVerifier};
+use rustls::{Certificate, ServerName};
use crate::tls_util::HttpsConnectorFixedDnsname;
@@ -175,16 +175,14 @@ struct DontVerifyServerCert;
impl ServerCertVerifier for DontVerifyServerCert {
fn verify_server_cert(
- &self,
- _end_entity: &Certificate,
- _intermediates: &[Certificate],
- _server_name: &ServerName,
- _scts: &mut dyn Iterator<Item = &[u8]>,
- _ocsp_response: &[u8],
- _now: SystemTime
- ) -> Result<ServerCertVerified, rustls::Error> {
+ &self,
+ _end_entity: &Certificate,
+ _intermediates: &[Certificate],
+ _server_name: &ServerName,
+ _scts: &mut dyn Iterator<Item = &[u8]>,
+ _ocsp_response: &[u8],
+ _now: SystemTime,
+ ) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}
-
-
diff --git a/src/tls_util.rs b/src/tls_util.rs
index 054c35a..91ad31c 100644
--- a/src/tls_util.rs
+++ b/src/tls_util.rs
@@ -1,21 +1,20 @@
use core::future::Future;
use core::task::{Context, Poll};
use std::convert::TryFrom;
+use std::io;
use std::pin::Pin;
use std::sync::Arc;
-use std::io;
use futures_util::future::*;
-use rustls::ServerName;
use hyper::client::connect::Connection;
use hyper::client::HttpConnector;
use hyper::service::Service;
use hyper::Uri;
use hyper_rustls::MaybeHttpsStream;
+use rustls::ServerName;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;
-
#[derive(Clone)]
pub struct HttpsConnectorFixedDnsname<T> {
http: T,
@@ -62,8 +61,7 @@ where
let cfg = self.tls_config.clone();
let connecting_future = self.http.call(dst);
- let dnsname =
- ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname");
+ let dnsname = ServerName::try_from(self.fixed_dnsname).expect("Invalid fixed dnsname");
let f = async move {
let tcp = connecting_future.await.map_err(Into::into)?;
let connector = TlsConnector::from(cfg);
@@ -76,4 +74,3 @@ where
f.boxed()
}
}
-