aboutsummaryrefslogblamecommitdiff
path: root/src/main.rs
blob: 9461596a071a4ec4d42be063398a0272e8b852b3 (plain) (tree)
1
2
3
4
5
6
7
8

                   


                         
               
 
                  































                                                                                     

                                                           









                                                                       




                                                 




                                                    



                                                                                   












                                                
                                           








                                                                                         

















                                                                                                                                






                                 
                                                                                       


                                                                  
                         












                                                                                    

                                                                         








                                                







                                                                       
         
                                         

















                                                                            
use std::sync::Arc;

use structopt::StructOpt;
use tokio::select;
use tokio::sync::watch;
use tracing::*;

mod autodiscovery;
mod dns_config;
mod dns_updater;
mod provider;

// Command-line options of the program, parsed with structopt.
//
// Each option below can also be supplied through the environment variable
// named in its `env = "..."` attribute.
//
// NOTE: the `///` comments on the fields are picked up by the structopt derive
// and shown as help text, so they are part of the program's user-facing output.
// Maintainer notes are therefore written as plain `//` comments.
#[derive(StructOpt, Debug)]
#[structopt(name = "d53")]
pub struct Opt {
    // NOTE(review): the flag is `--consul-addr` but the environment variable is
    // `D53_CONSUL_HOST`; the two names do not match.
    /// Address of consul server
    #[structopt(
        long = "consul-addr",
        env = "D53_CONSUL_HOST",
        default_value = "http://127.0.0.1:8500"
    )]
    pub consul_addr: String,

    // The four TLS-related options below are copied as-is into
    // `df_consul::Config` by `main`.
    /// CA certificate for Consul server with TLS
    #[structopt(long = "consul-ca-cert", env = "D53_CONSUL_CA_CERT")]
    pub consul_ca_cert: Option<String>,

    /// Skip TLS verification for Consul
    #[structopt(long = "consul-tls-skip-verify", env = "D53_CONSUL_TLS_SKIP_VERIFY")]
    pub consul_tls_skip_verify: bool,

    /// Client certificate for Consul server with TLS
    #[structopt(long = "consul-client-cert", env = "D53_CONSUL_CLIENT_CERT")]
    pub consul_client_cert: Option<String>,

    /// Client key for Consul server with TLS
    #[structopt(long = "consul-client-key", env = "D53_CONSUL_CLIENT_KEY")]
    pub consul_client_key: Option<String>,

    // Comma-separated list of `<domain_name>:<provider>` pairs; `main` splits it
    // on ',' and then on the first ':' of each item. Has no default value.
    /// DNS provider
    #[structopt(long = "providers", env = "D53_PROVIDERS")]
    pub providers: String,

    // Comma-separated list; `main` splits it on ',' without trimming the items.
    // Has no default value.
    /// Allowed domains
    #[structopt(long = "allowed-domains", env = "D53_ALLOWED_DOMAINS")]
    pub allowed_domains: String,

    // Optional at the CLI level. Presumably required by the Gandi provider
    // constructor, which receives the whole `Opt` -- TODO confirm in
    // `provider::gandi`.
    /// API key for Gandi DNS provider
    #[structopt(long = "gandi-api-key", env = "D53_GANDI_API_KEY")]
    pub gandi_api_key: Option<String>,
}

/// Pairs a domain name with the DNS provider instance configured for it.
///
/// `main` builds one of these for every `<domain_name>:<provider>` item of the
/// `--providers` option.
pub struct DomainProvider {
    /// Domain name, i.e. the part before the first `:` of a `--providers` item.
    pub domain: String,
    /// Provider backend selected by the part after the `:` (e.g. `gandi`).
    pub provider: Box<dyn provider::DnsProvider>,
}

#[tokio::main]
async fn main() {
    if std::env::var("RUST_LOG").is_err() {
        std::env::set_var("RUST_LOG", "tricot=info")
    }
    tracing_subscriber::fmt()
        .with_writer(std::io::stderr)
        .with_env_filter(tracing_subscriber::filter::EnvFilter::from_default_env())
        .init();

    // Abort on panic (same behavior as in Go)
    std::panic::set_hook(Box::new(|panic_info| {
        error!("{}", panic_info.to_string());
        std::process::abort();
    }));

    let opt = Opt::from_args();

    info!("Starting D53");

    let (exit_signal, _) = watch_ctrl_c();

    let consul_config = df_consul::Config {
        addr: opt.consul_addr.clone(),
        ca_cert: opt.consul_ca_cert.clone(),
        tls_skip_verify: opt.consul_tls_skip_verify,
        client_cert: opt.consul_client_cert.clone(),
        client_key: opt.consul_client_key.clone(),
    };

    let consul = df_consul::Consul::new(consul_config, "").expect("Cannot build Consul");

    let mut domain_providers = vec![];
    for pstr in opt.providers.as_str().split(',') {
        let (domain, provider) = pstr.split_once(':')
			.expect("Invalid provider syntax, expected: <domain_name>:<provider>[,<domain_name>:<provider>[,...]]");
        let provider: Box<dyn provider::DnsProvider> = match provider {
            "gandi" => Box::new(
                provider::gandi::GandiProvider::new(&opt).expect("Cannot initialize Gandi provier"),
            ),
            p => panic!("Unsupported DNS provider: {}", p),
        };
        domain_providers.push(DomainProvider {
            domain: domain.to_string(),
            provider,
        });
    }
    if domain_providers.is_empty() {
        panic!("No domain providers were specified.");
    }

    let allowed_domains = opt
        .allowed_domains
        .split(',')
        .map(ToString::to_string)
        .collect::<Vec<_>>();

    let rx_dns_config = dns_config::spawn_dns_config_task(consul, exit_signal.clone());

    let updater_task = tokio::spawn(dns_updater::dns_updater_task(
        rx_dns_config.clone(),
        domain_providers,
        allowed_domains,
        exit_signal.clone(),
    ));
    let dump_task = tokio::spawn(dump_config_on_change(rx_dns_config, exit_signal));

    updater_task.await.expect("Tokio task await failure");
    dump_task.await.expect("Tokio task await failure");
}

/// Prints the DNS configuration to stdout each time it changes.
///
/// Waits for new values on `rx_dns_config`. Whenever the received
/// configuration differs from the previous one (initially
/// `DnsConfig::default()`), all of its entries are printed.
///
/// Returns once `must_exit` holds `true`, or when the sending side of
/// `rx_dns_config` has been dropped.
async fn dump_config_on_change(
    mut rx_dns_config: watch::Receiver<Arc<dns_config::DnsConfig>>,
    mut must_exit: watch::Receiver<bool>,
) {
    let mut prev_dns_config = Arc::new(dns_config::DnsConfig::default());

    while !*must_exit.borrow() {
        select!(
            c = rx_dns_config.changed() => {
                if c.is_err() {
                    // Sender dropped: no further configuration will arrive.
                    break;
                }
            }
            // Re-evaluate the loop condition as soon as the exit flag changes.
            _ = must_exit.changed() => continue,
        );

        // Take a cheap `Arc` snapshot so the channel's read lock is released
        // before doing any I/O.
        let new_dns_config = rx_dns_config.borrow_and_update().clone();
        if new_dns_config != prev_dns_config {
            println!("---- DNS CONFIGURATION ----");
            // Print from the snapshot that was just compared, not from a fresh
            // borrow of the channel, which could already hold a newer value.
            for (k, v) in new_dns_config.entries.iter() {
                println!("   {} {}", k, v);
            }
            println!();
        }
        prev_dns_config = new_dns_config;
    }
}

/// Creates a watch that contains `false`, and that changes
/// to `true` when a Ctrl+C signal is received.
pub fn watch_ctrl_c() -> (watch::Receiver<bool>, Arc<watch::Sender<bool>>) {
    let (send_cancel, watch_cancel) = watch::channel(false);
    let send_cancel = Arc::new(send_cancel);
    let send_cancel_2 = send_cancel.clone();
    tokio::spawn(async move {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install CTRL+C signal handler");
        info!("Received CTRL+C, shutting down.");
        send_cancel.send(true).unwrap();
    });
    (watch_cancel, send_cancel_2)
}