diff options
-rw-r--r-- | src/cert_store.rs | 174 | ||||
-rw-r--r-- | src/proxy_config.rs | 23 |
2 files changed, 158 insertions, 39 deletions
diff --git a/src/cert_store.rs b/src/cert_store.rs index 0ced178..edbd0a1 100644 --- a/src/cert_store.rs +++ b/src/cert_store.rs @@ -22,12 +22,19 @@ pub struct CertStore { consul: Consul, node_name: String, letsencrypt_email: String, + certs: RwLock<HashMap<String, Arc<Cert>>>, self_signed_certs: RwLock<HashMap<String, Arc<Cert>>>, + rx_proxy_config: watch::Receiver<Arc<ProxyConfig>>, tx_need_cert: mpsc::UnboundedSender<String>, } +struct ProcessedDomains { + static_domains: HashSet<String>, + on_demand_domains: Vec<(glob::Pattern, Option<String>)>, +} + impl CertStore { pub fn new( consul: Consul, @@ -41,10 +48,10 @@ impl CertStore { let cert_store = Arc::new(Self { consul, node_name, + letsencrypt_email, certs: RwLock::new(HashMap::new()), self_signed_certs: RwLock::new(HashMap::new()), rx_proxy_config, - letsencrypt_email, tx_need_cert: tx, }); @@ -66,46 +73,72 @@ impl CertStore { let mut rx_proxy_config = self.rx_proxy_config.clone(); let mut t_last_check: HashMap<String, Instant> = HashMap::new(); + let mut proc_domains: Option<ProcessedDomains> = None; loop { - let mut domains: HashSet<String> = HashSet::new(); - - select! { + let domains = select! { + // Refresh some internal states, schedule static_domains for renew res = rx_proxy_config.changed() => { if res.is_err() { bail!("rx_proxy_config closed"); } + let mut static_domains: HashSet<String> = HashSet::new(); + let mut on_demand_domains: Vec<(glob::Pattern, Option<String>)> = vec![]; + let proxy_config: Arc<ProxyConfig> = rx_proxy_config.borrow().clone(); + for ent in proxy_config.entries.iter() { - if let HostDescription::Hostname(domain) = &ent.url_prefix.host { - if let Some((host, _port)) = domain.split_once(':') { - domains.insert(host.to_string()); - } else { - domains.insert(domain.clone()); - } + // Eagerly generate certificates for domains that + // are not patterns + match &ent.url_prefix.host { + HostDescription::Hostname(domain) => { + if let Some((host, _port)) = domain.split_once(':') { + static_domains.insert(host.to_string()); + } else { + static_domains.insert(domain.clone()); + } + }, + HostDescription::Pattern(pattern) => { + on_demand_domains.push((pattern.clone(), ent.on_demand_tls_ask.clone())); + }, } } + + // only static_domains are refreshed + proc_domains = Some(ProcessedDomains { static_domains: static_domains.clone(), on_demand_domains }); + self.domain_validation(static_domains, proc_domains.as_ref()).await } + // renew static and on-demand domains need_cert = rx_need_cert.recv() => { match need_cert { Some(dom) => { - domains.insert(dom); + let mut candidates: HashSet<String> = HashSet::new(); + + // collect certificates as much as possible + candidates.insert(dom); while let Ok(dom2) = rx_need_cert.try_recv() { - domains.insert(dom2); + candidates.insert(dom2); } + + self.domain_validation(candidates, proc_domains.as_ref()).await } None => bail!("rx_need_cert closed"), - }; + } } - } + }; + // Now that we have our list of domains to check, + // actually do something for dom in domains.iter() { + // Exclude from the list domains that were checked less than 60 + // seconds ago match t_last_check.get(dom) { Some(t) if Instant::now() - *t < Duration::from_secs(60) => continue, _ => t_last_check.insert(dom.to_string(), Instant::now()), }; + // Actual Let's Encrypt calls are done here (in sister function) debug!("Checking cert for domain: {}", dom); if let Err(e) = self.check_cert(dom).await { warn!("({}) Could not get certificate: {}", dom, e); @@ -114,18 +147,82 @@ impl CertStore { } } - fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> { - // Check if domain is authorized - if !self - .rx_proxy_config - .borrow() - .entries - .iter() - .any(|ent| ent.url_prefix.host.matches(domain)) - { - bail!("Domain {} should not have a TLS certificate.", domain); + async fn domain_validation( + &self, + candidates: HashSet<String>, + maybe_proc_domains: Option<&ProcessedDomains>, + ) -> HashSet<String> { + let mut domains: HashSet<String> = HashSet::new(); + + // Handle initialization + let proc_domains = match maybe_proc_domains { + None => { + warn!("Proxy config is not yet loaded, refusing all certificate generation"); + return domains; + } + Some(proc) => proc, + }; + + // Filter certificates... + 'outer: for candidate in candidates.into_iter() { + // Disallow obvious wrong domains... + if !candidate.contains('.') || candidate.ends_with(".local") { + warn!("{} is probably not a publicly accessible domain, skipping (a self-signed certificate will be used)", candidate); + continue; + } + + // Try to register domain as a static domain + if proc_domains.static_domains.contains(&candidate) { + trace!("domain {} validated as static domain", candidate); + domains.insert(candidate); + continue; + } + + // It's not a static domain, maybe an on-demand domain? + for (pattern, maybe_check_url) in proc_domains.on_demand_domains.iter() { + // check glob pattern + if pattern.matches(&candidate) { + // if no check url is set, accept domain as long as it matches the pattern + let check_url = match maybe_check_url { + None => { + trace!( + "domain {} validated on glob pattern {} only", + candidate, + pattern + ); + domains.insert(candidate); + continue 'outer; + } + Some(url) => url, + }; + + // if a check url is set, call it + // -- avoid DDoSing a backend + tokio::time::sleep(Duration::from_secs(2)).await; + match self.on_demand_tls_ask(check_url, &candidate).await { + Ok(()) => { + trace!( + "domain {} validated on glob pattern {} and on check url {}", + candidate, + pattern, + check_url + ); + domains.insert(candidate); + continue 'outer; + } + Err(e) => { + warn!("domain {} validation refused on glob pattern {} and on check url {} with error: {}", candidate, pattern, check_url, e); + } + } + } + } } + return domains; + } + + /// This function is also in charge of the refresh of the domain names + fn get_cert_for_https(self: &Arc<Self>, domain: &str) -> Result<Arc<Cert>> { // Check in local memory if it exists if let Some(cert) = self.certs.read().unwrap().get(domain) { if cert.is_old() { @@ -186,6 +283,15 @@ impl CertStore { Ok(()) } + /// Check certificate ensure that the certificate is in the memory store + /// and that it does not need to be renewed. + /// + /// If it's not in the memory store, it tries to load it from Consul, + /// if it's not in Consul, it calls Let's Encrypt. + /// + /// If the certificate is outdated in the memory store, it tries to load + /// a more recent version in Consul, if the Consul version is also outdated, + /// it tries to renew it pub async fn check_cert(self: &Arc<Self>, domain: &str) -> Result<()> { // First, try locally. { @@ -226,15 +332,10 @@ impl CertStore { self.renew_cert(domain).await } + /// This is the place where certificates are generated or renewed pub async fn renew_cert(self: &Arc<Self>, domain: &str) -> Result<()> { info!("({}) Renewing certificate", domain); - // Basic sanity check (we could add more kinds of checks here) - // This is just to help avoid getting rate-limited against ACME server - if !domain.contains('.') || domain.ends_with(".local") { - bail!("Probably not a publicly accessible domain, skipping (a self-signed certificate will be used)"); - } - // ---- Acquire lock ---- // the lock is acquired for half an hour, // so that in case of an error we won't retry before @@ -350,6 +451,19 @@ impl CertStore { Ok(()) } + async fn on_demand_tls_ask(&self, check_url: &str, domain: &str) -> Result<()> { + let httpcli = reqwest::Client::new(); + let chall_url = format!("{}?domain={}", check_url, domain); + info!("({}) On-demand TLS check", domain); + + let httpresp = httpcli.get(&chall_url).send().await?; + if httpresp.status() != reqwest::StatusCode::OK { + bail!("{} is not authorized for on-demand TLS", domain); + } + + Ok(()) + } + async fn check_domain_accessibility(&self, domain: &str, session: &str) -> Result<()> { // Returns Ok(()) only if domain is a correct domain name that // redirects to this server diff --git a/src/proxy_config.rs b/src/proxy_config.rs index 8381de2..7690f8a 100644 --- a/src/proxy_config.rs +++ b/src/proxy_config.rs @@ -108,6 +108,10 @@ pub struct ProxyEntry { /// when matching this rule pub redirects: Vec<(UrlPrefix, UrlPrefix, u16)>, + /// Wether or not the domain must be validated before asking a certificate + /// to let's encrypt (only for Glob patterns) + pub on_demand_tls_ask: Option<String>, + /// Number of calls in progress, used to deprioritize slow back-ends pub calls_in_progress: atomic::AtomicI64, /// Time of last call, used for round-robin selection @@ -142,14 +146,14 @@ impl ProxyEntry { let mut add_headers = vec![]; let mut redirects = vec![]; + let mut on_demand_tls_ask: Option<String> = None; for mid in middleware.into_iter() { + // LocalLb and GlobalLb are handled in the parent function match mid { ConfigTag::AddHeader(k, v) => add_headers.push((k.to_string(), v.clone())), ConfigTag::AddRedirect(m, r, c) => redirects.push(((*m).clone(), (*r).clone(), *c)), - ConfigTag::LocalLb | ConfigTag::GlobalLb => { - /* handled in parent fx */ - () - } + ConfigTag::OnDemandTlsAsk(url) => on_demand_tls_ask = Some(url.to_string()), + ConfigTag::LocalLb | ConfigTag::GlobalLb => (), }; } @@ -166,6 +170,7 @@ impl ProxyEntry { flags, add_headers, redirects, + on_demand_tls_ask, // internal last_call: atomic::AtomicI64::from(0), calls_in_progress: atomic::AtomicI64::from(0), @@ -247,6 +252,7 @@ enum MatchTag { enum ConfigTag<'a> { AddHeader(&'a str, String), AddRedirect(UrlPrefix, UrlPrefix, u16), + OnDemandTlsAsk(&'a str), GlobalLb, LocalLb, } @@ -321,6 +327,9 @@ fn parse_tricot_tags(tag: &str) -> Option<ParsedTag> { p_match, p_replace, http_code, ))) } + ["tricot-on-demand-tls-ask", url, ..] => { + Some(ParsedTag::Middleware(ConfigTag::OnDemandTlsAsk(url))) + } ["tricot-global-lb", ..] => Some(ParsedTag::Middleware(ConfigTag::GlobalLb)), ["tricot-local-lb", ..] => Some(ParsedTag::Middleware(ConfigTag::LocalLb)), _ => None, @@ -369,13 +378,9 @@ fn parse_consul_service( // some legacy processing that would need a refactor later for mid in collected_middleware.iter() { match mid { - ConfigTag::AddHeader(_, _) | ConfigTag::AddRedirect(_, _, _) => - /* not handled here */ - { - () - } ConfigTag::GlobalLb => flags.global_lb = true, ConfigTag::LocalLb => flags.site_lb = true, + _ => (), }; } |