From 489d364676003fa08130689a9f509de7d4df1602 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Tue, 7 Dec 2021 18:19:51 +0100 Subject: Add support for custom headers --- .gitignore | 1 + src/cert_store.rs | 2 +- src/http.rs | 57 +++++++++++++++++++++++++++-------------------------- src/https.rs | 16 ++++++++++----- src/main.rs | 17 ++++++++++++++-- src/proxy_config.rs | 24 +++++++++++++++++++--- 6 files changed, 78 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index ea8c4bf..1c2b9da 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +run_local.sh diff --git a/src/cert_store.rs b/src/cert_store.rs index a58288c..eca39b9 100644 --- a/src/cert_store.rs +++ b/src/cert_store.rs @@ -40,7 +40,7 @@ impl CertStore { for ent in proxy_config.entries.iter() { domains.insert(ent.host.clone()); } - info!("Ensuring we have certs for domains: {:#?}", domains); + info!("Ensuring we have certs for domains: {:?}", domains); for dom in domains.iter() { if let Err(e) = self.get_cert(dom).await { diff --git a/src/http.rs b/src/http.rs index 385456a..4731645 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::net::SocketAddr; use anyhow::Result; use log::*; @@ -11,6 +12,34 @@ use crate::consul::Consul; const CHALLENGE_PREFIX: &str = "/.well-known/acme-challenge/"; +pub async fn serve_http( + bind_addr: SocketAddr, + consul: Consul, +) -> Result<(), Box> { + let consul = Arc::new(consul); + // For every connection, we must make a `Service` to handle all + // incoming HTTP requests on said connection. + let make_svc = make_service_fn(|_conn| { + let consul = consul.clone(); + // This is the `Service` that will handle the connection. + // `service_fn` is a helper to convert a function that + // returns a Response into a `Service`. + async move { + Ok::<_, anyhow::Error>(service_fn(move |req: Request| { + let consul = consul.clone(); + handle(req, consul) + })) + } + }); + + info!("Listening on http://{}", bind_addr); + let server = Server::bind(&bind_addr).serve(make_svc); + + server.await?; + + Ok(()) +} + async fn handle(req: Request, consul: Arc) -> Result> { let path = req.uri().path(); info!("HTTP request {}", path); @@ -45,31 +74,3 @@ async fn handle(req: Request, consul: Arc) -> Result Result<(), Box> { - let consul = Arc::new(consul); - // For every connection, we must make a `Service` to handle all - // incoming HTTP requests on said connection. - let make_svc = make_service_fn(|_conn| { - let consul = consul.clone(); - // This is the `Service` that will handle the connection. - // `service_fn` is a helper to convert a function that - // returns a Response into a `Service`. - async move { - Ok::<_, anyhow::Error>(service_fn(move |req: Request| { - let consul = consul.clone(); - handle(req, consul) - })) - } - }); - - let addr = ([0, 0, 0, 0], 1080).into(); - - let server = Server::bind(&addr).serve(make_svc); - - println!("Listening on http://{}", addr); - - server.await?; - - Ok(()) -} diff --git a/src/https.rs b/src/https.rs index a62ebea..43a93e2 100644 --- a/src/https.rs +++ b/src/https.rs @@ -8,6 +8,7 @@ use futures::FutureExt; use hyper::server::conn::Http; use hyper::service::service_fn; use hyper::{Body, Request, Response, StatusCode}; +use http::header::{HeaderName, HeaderValue}; use tokio::net::TcpListener; use tokio::sync::watch; use tokio_rustls::TlsAcceptor; @@ -17,11 +18,10 @@ use crate::proxy_config::ProxyConfig; use crate::reverse_proxy; pub async fn serve_https( + bind_addr: SocketAddr, cert_store: Arc, proxy_config: watch::Receiver>, ) -> Result<()> { - let addr = format!("0.0.0.0:1443"); - let mut cfg = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() @@ -31,9 +31,9 @@ pub async fn serve_https( let tls_cfg = Arc::new(cfg); let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg)); - println!("Starting to serve on https://{}.", addr); + info!("Starting to serve on https://{}.", bind_addr); - let tcp = TcpListener::bind(&addr).await?; + let tcp = TcpListener::bind(bind_addr).await?; loop { let (socket, remote_addr) = tcp.accept().await?; @@ -118,7 +118,13 @@ async fn handle( let to_addr = format!("http://{}", proxy_to.target_addr); info!("Proxying {} {} -> {}", host, path, to_addr); - reverse_proxy::call(remote_addr.ip(), &to_addr, req).await + let mut response = reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?; + + for (header, value) in proxy_to.add_headers.iter() { + response.headers_mut().insert(HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?); + } + + Ok(response) } else { info!("Proxying {} {} -> NOT FOUND", host, path); diff --git a/src/main.rs b/src/main.rs index 4a0c0ec..3a51702 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ #[macro_use] extern crate anyhow; +use std::net::SocketAddr; use structopt::StructOpt; mod cert; @@ -27,6 +28,14 @@ struct Opt { /// Node name #[structopt(long = "node-name", env = "TRICOT_NODE_NAME", default_value = "")] pub node_name: String, + + /// Bind address for HTTP server + #[structopt(long = "http-bind-addr", env = "TRICOT_HTTP_BIND_ADDR", default_value = "0.0.0.0:80")] + pub http_bind_addr: SocketAddr, + + /// Bind address for HTTPS server + #[structopt(long = "https-bind-addr", env = "TRICOT_HTTPS_BIND_ADDR", default_value = "0.0.0.0:443")] + pub https_bind_addr: SocketAddr, } #[tokio::main(flavor = "multi_thread", worker_threads = 10)] @@ -50,13 +59,17 @@ async fn main() { .watch_proxy_config(rx_proxy_config.clone()), ); - tokio::spawn(http::serve_http(consul.clone())); + tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone())); tokio::spawn(https::serve_https( + opt.https_bind_addr, cert_store.clone(), rx_proxy_config.clone(), )); while rx_proxy_config.changed().await.is_ok() { - info!("Proxy config: {:#?}", *rx_proxy_config.borrow()); + info!("Proxy config:"); + for ent in rx_proxy_config.borrow().entries.iter() { + info!(" {:?}", ent); + } } } diff --git a/src/proxy_config.rs b/src/proxy_config.rs index 31a2659..9d07604 100644 --- a/src/proxy_config.rs +++ b/src/proxy_config.rs @@ -22,6 +22,7 @@ pub struct ProxyEntry { pub host: String, pub path_prefix: Option, pub priority: u32, + pub add_headers: Vec<(String, String)>, // Counts the number of times this proxy server has been called to // This implements a round-robin load balancer if there are multiple @@ -44,7 +45,7 @@ fn retry_to_time(retries: u32, max_time: Duration) -> Duration { )); } -fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option { +fn parse_tricot_tag(tag: &str, target_addr: SocketAddr, add_headers: &[(String, String)]) -> Option { let splits = tag.split(' ').collect::>(); if (splits.len() != 2 && splits.len() != 3) || splits[0] != "tricot" { return None; @@ -65,10 +66,20 @@ fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option { host: host.to_string(), path_prefix, priority, + add_headers: add_headers.to_vec(), calls: atomic::AtomicU64::from(0), }) } +fn parse_tricot_add_header_tag(tag: &str) -> Option<(String, String)> { + let splits = tag.split(' ').collect::>(); + if splits.len() == 3 && splits[0] == "tricot-add-header" { + Some((splits[1].to_string(), splits[2].to_string())) + } else { + None + } +} + fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec { let mut entries = vec![]; @@ -78,8 +89,16 @@ fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec { _ => continue, }; let addr = SocketAddr::new(ip_addr, svc.port); + + let mut add_headers = vec![]; + for tag in svc.tags.iter() { + if let Some(pair) = parse_tricot_add_header_tag(tag) { + add_headers.push(pair); + } + } + for tag in svc.tags.iter() { - if let Some(ent) = parse_tricot_tag(addr, tag) { + if let Some(ent) = parse_tricot_tag(tag, addr, &add_headers[..]) { entries.push(ent); } } @@ -181,7 +200,6 @@ pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver