aboutsummaryrefslogblamecommitdiff
path: root/src/dns_updater.rs
blob: d781671dc1f08d6e7f5a050bf99f44ce58d2a779 (plain) (tree)
1
2
3
4
5
6
7
8

                                   
                        

                                   

                       
               









                                                       



                                                     
















































































                                                                                                                                                                         
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{anyhow, bail, Result};
use tokio::select;
use tokio::sync::watch;
use tracing::*;

use crate::dns_config::*;
use crate::provider::DnsProvider;

/// Long-running task that pushes DNS configuration changes to a provider.
///
/// After an initial 5-second delay, waits for new values on `rx_dns_config`.
/// For every entry whose value differs from the previously seen configuration,
/// it calls [`update_dns_entry`]. It does so only if the entry's full name
/// (`subdomain.domain`) is one of `allowed_domains` or lies beneath one of
/// them on a label boundary.
///
/// The task returns when `must_exit` becomes `true`, or when the sender of
/// either watch channel is dropped.
///
/// Failed updates are logged, not retried. The new configuration is recorded
/// as current either way, so a failed entry is only attempted again once its
/// value changes.
pub async fn dns_updater_task(
    mut rx_dns_config: watch::Receiver<Arc<DnsConfig>>,
    provider: Box<dyn DnsProvider>,
    allowed_domains: Vec<String>,
    mut must_exit: watch::Receiver<bool>,
) {
    info!("DNS updater will start in 5 seconds");
    tokio::time::sleep(Duration::from_secs(5)).await;
    info!("DNS updater starting");

    let mut config = Arc::new(DnsConfig::new());
    while !*must_exit.borrow() {
        select!(
            c = rx_dns_config.changed() => {
                // Config sender dropped: no further updates can arrive.
                if c.is_err() {
                    break;
                }
            }
            c = must_exit.changed() => {
                // If the exit-signal sender is dropped, `changed()` returns
                // `Err` immediately on every poll and the flag never becomes
                // true, so the loop would busy-spin. Treat that as "stop".
                if c.is_err() {
                    break;
                }
                // Otherwise re-check the exit flag in the loop condition.
                continue;
            }
        );
        let new_config: Arc<DnsConfig> = rx_dns_config.borrow().clone();

        for (k, v) in new_config.entries.iter() {
            // Only push entries that are new or whose value changed.
            if config.entries.get(k) == Some(v) {
                continue;
            }

            let fulldomain = format!("{}.{}", k.subdomain, k.domain);
            if !is_allowed_domain(&fulldomain, &allowed_domains) {
                error!(
                    "Got an entry for domain {} which is not in allowed list",
                    fulldomain
                );
                continue;
            }

            info!("Updating {} {}", k, v);
            if let Err(e) = update_dns_entry(k, v, provider.as_ref()).await {
                error!("Unable to update entry {} {}: {}", k, v, e);
            }
        }

        config = new_config;
    }
}

/// Returns whether `fulldomain` equals one of `allowed_domains` or is a
/// subdomain of one.
///
/// Matching is done on whole DNS labels. `www.example.com` is allowed by
/// `example.com`, but `www.badexample.com` is not; a plain `ends_with` would
/// accept it. A leading `.` on an allowed entry is ignored, and empty allowed
/// entries match nothing.
fn is_allowed_domain(fulldomain: &str, allowed_domains: &[String]) -> bool {
    allowed_domains.iter().any(|allowed| {
        let allowed = allowed.trim_start_matches('.');
        if allowed.is_empty() {
            return false;
        }
        match fulldomain.strip_suffix(allowed) {
            // Exact match, or the character before the suffix is a label separator.
            Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
            None => false,
        }
    })
}

/// Pushes a single DNS entry to `provider`, dispatching on its record type.
///
/// * `A` / `AAAA`: every target must parse as an IPv4 / IPv6 address. The
///   first target that fails to parse aborts the update with an error, and
///   nothing is sent to the provider.
/// * `CNAME`: only one target is published. If several are present, they are
///   sorted, a warning is logged, and the first in sorted order is used.
///
/// # Errors
///
/// Fails if the entry has no targets, if an address target does not parse,
/// or if the provider call returns an error.
async fn update_dns_entry(
    key: &DnsEntryKey,
    value: &DnsEntryValue,
    provider: &dyn DnsProvider,
) -> Result<()> {
    // Callers are expected never to produce an entry without targets.
    if value.targets.is_empty() {
        bail!("zero targets (internal error)");
    }

    let domain = &key.domain;
    let subdomain = &key.subdomain;

    match key.record_type {
        DnsRecordType::A => {
            // Collecting into `Result` stops at the first unparsable target.
            let addrs = value
                .targets
                .iter()
                .map(|t| {
                    t.parse::<Ipv4Addr>()
                        .map_err(|_| anyhow!("Invalid ipv4 address: {}", t))
                })
                .collect::<Result<Vec<_>>>()?;
            provider.update_a(domain, subdomain, &addrs).await?;
        }
        DnsRecordType::AAAA => {
            let addrs = value
                .targets
                .iter()
                .map(|t| {
                    t.parse::<Ipv6Addr>()
                        .map_err(|_| anyhow!("Invalid ipv6 address: {}", t))
                })
                .collect::<Result<Vec<_>>>()?;
            provider.update_aaaa(domain, subdomain, &addrs).await?;
        }
        DnsRecordType::CNAME => {
            let mut names: Vec<_> = value.targets.iter().cloned().collect();
            if names.len() > 1 {
                // Sort so the chosen target is deterministic.
                names.sort();
                warn!("Several CNAME targets for {}: {:?}. Taking first one in alphabetical order. Consider switching to a single global target instead.", key, names);
            }
            // Non-empty: guaranteed by the check at the top of the function.
            let chosen = &names[0];
            provider.update_cname(domain, subdomain, chosen).await?;
        }
    }

    Ok(())
}