diff options
author | Alex Auvolat <alex@adnab.me> | 2021-12-07 18:19:51 +0100 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2021-12-07 18:19:51 +0100 |
commit | 489d364676003fa08130689a9f509de7d4df1602 (patch) | |
tree | c3595a10ef94eead74da41101be7bd42ed292c58 | |
parent | 0682c74e9d5083b43b3f83f8bb1ca747658d1455 (diff) | |
download | tricot-489d364676003fa08130689a9f509de7d4df1602.tar.gz tricot-489d364676003fa08130689a9f509de7d4df1602.zip |
Add support for custom headers
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | src/cert_store.rs | 2 | ||||
-rw-r--r-- | src/http.rs | 57 | ||||
-rw-r--r-- | src/https.rs | 16 | ||||
-rw-r--r-- | src/main.rs | 17 | ||||
-rw-r--r-- | src/proxy_config.rs | 24 |
6 files changed, 78 insertions, 39 deletions
@@ -1 +1,2 @@ /target +run_local.sh diff --git a/src/cert_store.rs b/src/cert_store.rs index a58288c..eca39b9 100644 --- a/src/cert_store.rs +++ b/src/cert_store.rs @@ -40,7 +40,7 @@ impl CertStore { for ent in proxy_config.entries.iter() { domains.insert(ent.host.clone()); } - info!("Ensuring we have certs for domains: {:#?}", domains); + info!("Ensuring we have certs for domains: {:?}", domains); for dom in domains.iter() { if let Err(e) = self.get_cert(dom).await { diff --git a/src/http.rs b/src/http.rs index 385456a..4731645 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::net::SocketAddr; use anyhow::Result; use log::*; @@ -11,6 +12,34 @@ use crate::consul::Consul; const CHALLENGE_PREFIX: &str = "/.well-known/acme-challenge/"; +pub async fn serve_http( + bind_addr: SocketAddr, + consul: Consul, +) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { + let consul = Arc::new(consul); + // For every connection, we must make a `Service` to handle all + // incoming HTTP requests on said connection. + let make_svc = make_service_fn(|_conn| { + let consul = consul.clone(); + // This is the `Service` that will handle the connection. + // `service_fn` is a helper to convert a function that + // returns a Response into a `Service`. + async move { + Ok::<_, anyhow::Error>(service_fn(move |req: Request<Body>| { + let consul = consul.clone(); + handle(req, consul) + })) + } + }); + + info!("Listening on http://{}", bind_addr); + let server = Server::bind(&bind_addr).serve(make_svc); + + server.await?; + + Ok(()) +} + async fn handle(req: Request<Body>, consul: Arc<Consul>) -> Result<Response<Body>> { let path = req.uri().path(); info!("HTTP request {}", path); @@ -45,31 +74,3 @@ async fn handle(req: Request<Body>, consul: Arc<Consul>) -> Result<Response<Body .body(Body::from(""))?) } } - -pub async fn serve_http(consul: Consul) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { - let consul = Arc::new(consul); - // For every connection, we must make a `Service` to handle all - // incoming HTTP requests on said connection. - let make_svc = make_service_fn(|_conn| { - let consul = consul.clone(); - // This is the `Service` that will handle the connection. - // `service_fn` is a helper to convert a function that - // returns a Response into a `Service`. - async move { - Ok::<_, anyhow::Error>(service_fn(move |req: Request<Body>| { - let consul = consul.clone(); - handle(req, consul) - })) - } - }); - - let addr = ([0, 0, 0, 0], 1080).into(); - - let server = Server::bind(&addr).serve(make_svc); - - println!("Listening on http://{}", addr); - - server.await?; - - Ok(()) -} diff --git a/src/https.rs b/src/https.rs index a62ebea..43a93e2 100644 --- a/src/https.rs +++ b/src/https.rs @@ -8,6 +8,7 @@ use futures::FutureExt; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{Body, Request, Response, StatusCode}; +use http::header::{HeaderName, HeaderValue}; use tokio::net::TcpListener; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; @@ -17,11 +18,10 @@ use crate::proxy_config::ProxyConfig; use crate::reverse_proxy; pub async fn serve_https( + bind_addr: SocketAddr, cert_store: Arc<CertStore>, proxy_config: watch::Receiver<Arc<ProxyConfig>>, ) -> Result<()> { - let addr = format!("0.0.0.0:1443"); - let mut cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() @@ -31,9 +31,9 @@ pub async fn serve_https( let tls_cfg = Arc::new(cfg); let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg)); - println!("Starting to serve on https://{}.", addr); + info!("Starting to serve on https://{}.", bind_addr); - let tcp = TcpListener::bind(&addr).await?; + let tcp = TcpListener::bind(bind_addr).await?; loop { let (socket, remote_addr) = tcp.accept().await?; @@ -118,7 +118,13 @@ async fn handle( let to_addr = format!("http://{}", proxy_to.target_addr); info!("Proxying {} {} -> {}", host, path, to_addr); - reverse_proxy::call(remote_addr.ip(), &to_addr, req).await + let mut response = reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?; + + for (header, value) in proxy_to.add_headers.iter() { + response.headers_mut().insert(HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?); + } + + Ok(response) } else { info!("Proxying {} {} -> NOT FOUND", host, path); diff --git a/src/main.rs b/src/main.rs index 4a0c0ec..3a51702 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ #[macro_use] extern crate anyhow; +use std::net::SocketAddr; use structopt::StructOpt; mod cert; @@ -27,6 +28,14 @@ struct Opt { /// Node name #[structopt(long = "node-name", env = "TRICOT_NODE_NAME", default_value = "<none>")] pub node_name: String, + + /// Bind address for HTTP server + #[structopt(long = "http-bind-addr", env = "TRICOT_HTTP_BIND_ADDR", default_value = "0.0.0.0:80")] + pub http_bind_addr: SocketAddr, + + /// Bind address for HTTPS server + #[structopt(long = "https-bind-addr", env = "TRICOT_HTTPS_BIND_ADDR", default_value = "0.0.0.0:443")] + pub https_bind_addr: SocketAddr, } #[tokio::main(flavor = "multi_thread", worker_threads = 10)] @@ -50,13 +59,17 @@ async fn main() { .watch_proxy_config(rx_proxy_config.clone()), ); - tokio::spawn(http::serve_http(consul.clone())); + tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone())); tokio::spawn(https::serve_https( + opt.https_bind_addr, cert_store.clone(), rx_proxy_config.clone(), )); while rx_proxy_config.changed().await.is_ok() { - info!("Proxy config: {:#?}", *rx_proxy_config.borrow()); + info!("Proxy config:"); + for ent in rx_proxy_config.borrow().entries.iter() { + info!(" {:?}", ent); + } } } diff --git a/src/proxy_config.rs b/src/proxy_config.rs index 31a2659..9d07604 100644 --- a/src/proxy_config.rs +++ b/src/proxy_config.rs @@ -22,6 +22,7 @@ pub struct ProxyEntry { pub host: String, pub path_prefix: Option<String>, pub priority: u32, + pub add_headers: Vec<(String, String)>, // Counts the number of times this proxy server has been called to // This implements a round-robin load balancer if there are multiple @@ -44,7 +45,7 @@ fn retry_to_time(retries: u32, max_time: Duration) -> Duration { )); } -fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option<ProxyEntry> { +fn parse_tricot_tag(tag: &str, target_addr: SocketAddr, add_headers: &[(String, String)]) -> Option<ProxyEntry> { let splits = tag.split(' ').collect::<Vec<_>>(); if (splits.len() != 2 && splits.len() != 3) || splits[0] != "tricot" { return None; @@ -65,10 +66,20 @@ fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option<ProxyEntry> { host: host.to_string(), path_prefix, priority, + add_headers: add_headers.to_vec(), calls: atomic::AtomicU64::from(0), }) } +fn parse_tricot_add_header_tag(tag: &str) -> Option<(String, String)> { + let splits = tag.split(' ').collect::<Vec<_>>(); + if splits.len() == 3 && splits[0] == "tricot-add-header" { + Some((splits[1].to_string(), splits[2].to_string())) + } else { + None + } +} + fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec<ProxyEntry> { let mut entries = vec![]; @@ -78,8 +89,16 @@ fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec<ProxyEntry> { _ => continue, }; let addr = SocketAddr::new(ip_addr, svc.port); + + let mut add_headers = vec![]; + for tag in svc.tags.iter() { + if let Some(pair) = parse_tricot_add_header_tag(tag) { + add_headers.push(pair); + } + } + for tag in svc.tags.iter() { - if let Some(ent) = parse_tricot_tag(addr, tag) { + if let Some(ent) = parse_tricot_tag(tag, addr, &add_headers[..]) { entries.push(ent); } } @@ -181,7 +200,6 @@ pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver<Arc<ProxyConfi } } let config = ProxyConfig { entries }; - debug!("Extracted configuration: {:#?}", config); tx.send(Arc::new(config)).expect("Internal error"); } |