aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2021-12-07 18:19:51 +0100
committerAlex Auvolat <alex@adnab.me>2021-12-07 18:19:51 +0100
commit489d364676003fa08130689a9f509de7d4df1602 (patch)
treec3595a10ef94eead74da41101be7bd42ed292c58
parent0682c74e9d5083b43b3f83f8bb1ca747658d1455 (diff)
downloadtricot-489d364676003fa08130689a9f509de7d4df1602.tar.gz
tricot-489d364676003fa08130689a9f509de7d4df1602.zip
Add support for custom headers
-rw-r--r--.gitignore1
-rw-r--r--src/cert_store.rs2
-rw-r--r--src/http.rs57
-rw-r--r--src/https.rs16
-rw-r--r--src/main.rs17
-rw-r--r--src/proxy_config.rs24
6 files changed, 78 insertions, 39 deletions
diff --git a/.gitignore b/.gitignore
index ea8c4bf..1c2b9da 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
/target
+run_local.sh
diff --git a/src/cert_store.rs b/src/cert_store.rs
index a58288c..eca39b9 100644
--- a/src/cert_store.rs
+++ b/src/cert_store.rs
@@ -40,7 +40,7 @@ impl CertStore {
for ent in proxy_config.entries.iter() {
domains.insert(ent.host.clone());
}
- info!("Ensuring we have certs for domains: {:#?}", domains);
+ info!("Ensuring we have certs for domains: {:?}", domains);
for dom in domains.iter() {
if let Err(e) = self.get_cert(dom).await {
diff --git a/src/http.rs b/src/http.rs
index 385456a..4731645 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -1,4 +1,5 @@
use std::sync::Arc;
+use std::net::SocketAddr;
use anyhow::Result;
use log::*;
@@ -11,6 +12,34 @@ use crate::consul::Consul;
const CHALLENGE_PREFIX: &str = "/.well-known/acme-challenge/";
+pub async fn serve_http(
+ bind_addr: SocketAddr,
+ consul: Consul,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+ let consul = Arc::new(consul);
+ // For every connection, we must make a `Service` to handle all
+ // incoming HTTP requests on said connection.
+ let make_svc = make_service_fn(|_conn| {
+ let consul = consul.clone();
+ // This is the `Service` that will handle the connection.
+ // `service_fn` is a helper to convert a function that
+ // returns a Response into a `Service`.
+ async move {
+ Ok::<_, anyhow::Error>(service_fn(move |req: Request<Body>| {
+ let consul = consul.clone();
+ handle(req, consul)
+ }))
+ }
+ });
+
+ info!("Listening on http://{}", bind_addr);
+ let server = Server::bind(&bind_addr).serve(make_svc);
+
+ server.await?;
+
+ Ok(())
+}
+
async fn handle(req: Request<Body>, consul: Arc<Consul>) -> Result<Response<Body>> {
let path = req.uri().path();
info!("HTTP request {}", path);
@@ -45,31 +74,3 @@ async fn handle(req: Request<Body>, consul: Arc<Consul>) -> Result<Response<Body
.body(Body::from(""))?)
}
}
-
-pub async fn serve_http(consul: Consul) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- let consul = Arc::new(consul);
- // For every connection, we must make a `Service` to handle all
- // incoming HTTP requests on said connection.
- let make_svc = make_service_fn(|_conn| {
- let consul = consul.clone();
- // This is the `Service` that will handle the connection.
- // `service_fn` is a helper to convert a function that
- // returns a Response into a `Service`.
- async move {
- Ok::<_, anyhow::Error>(service_fn(move |req: Request<Body>| {
- let consul = consul.clone();
- handle(req, consul)
- }))
- }
- });
-
- let addr = ([0, 0, 0, 0], 1080).into();
-
- let server = Server::bind(&addr).serve(make_svc);
-
- println!("Listening on http://{}", addr);
-
- server.await?;
-
- Ok(())
-}
diff --git a/src/https.rs b/src/https.rs
index a62ebea..43a93e2 100644
--- a/src/https.rs
+++ b/src/https.rs
@@ -8,6 +8,7 @@ use futures::FutureExt;
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{Body, Request, Response, StatusCode};
+use http::header::{HeaderName, HeaderValue};
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
@@ -17,11 +18,10 @@ use crate::proxy_config::ProxyConfig;
use crate::reverse_proxy;
pub async fn serve_https(
+ bind_addr: SocketAddr,
cert_store: Arc<CertStore>,
proxy_config: watch::Receiver<Arc<ProxyConfig>>,
) -> Result<()> {
- let addr = format!("0.0.0.0:1443");
-
let mut cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
@@ -31,9 +31,9 @@ pub async fn serve_https(
let tls_cfg = Arc::new(cfg);
let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg));
- println!("Starting to serve on https://{}.", addr);
+ info!("Starting to serve on https://{}.", bind_addr);
- let tcp = TcpListener::bind(&addr).await?;
+ let tcp = TcpListener::bind(bind_addr).await?;
loop {
let (socket, remote_addr) = tcp.accept().await?;
@@ -118,7 +118,13 @@ async fn handle(
let to_addr = format!("http://{}", proxy_to.target_addr);
info!("Proxying {} {} -> {}", host, path, to_addr);
- reverse_proxy::call(remote_addr.ip(), &to_addr, req).await
+ let mut response = reverse_proxy::call(remote_addr.ip(), &to_addr, req).await?;
+
+ for (header, value) in proxy_to.add_headers.iter() {
+ response.headers_mut().insert(HeaderName::from_bytes(header.as_bytes())?, HeaderValue::from_str(value)?);
+ }
+
+ Ok(response)
} else {
info!("Proxying {} {} -> NOT FOUND", host, path);
diff --git a/src/main.rs b/src/main.rs
index 4a0c0ec..3a51702 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,7 @@
#[macro_use]
extern crate anyhow;
+use std::net::SocketAddr;
use structopt::StructOpt;
mod cert;
@@ -27,6 +28,14 @@ struct Opt {
/// Node name
#[structopt(long = "node-name", env = "TRICOT_NODE_NAME", default_value = "<none>")]
pub node_name: String,
+
+ /// Bind address for HTTP server
+ #[structopt(long = "http-bind-addr", env = "TRICOT_HTTP_BIND_ADDR", default_value = "0.0.0.0:80")]
+ pub http_bind_addr: SocketAddr,
+
+ /// Bind address for HTTPS server
+ #[structopt(long = "https-bind-addr", env = "TRICOT_HTTPS_BIND_ADDR", default_value = "0.0.0.0:443")]
+ pub https_bind_addr: SocketAddr,
}
#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
@@ -50,13 +59,17 @@ async fn main() {
.watch_proxy_config(rx_proxy_config.clone()),
);
- tokio::spawn(http::serve_http(consul.clone()));
+ tokio::spawn(http::serve_http(opt.http_bind_addr, consul.clone()));
tokio::spawn(https::serve_https(
+ opt.https_bind_addr,
cert_store.clone(),
rx_proxy_config.clone(),
));
while rx_proxy_config.changed().await.is_ok() {
- info!("Proxy config: {:#?}", *rx_proxy_config.borrow());
+ info!("Proxy config:");
+ for ent in rx_proxy_config.borrow().entries.iter() {
+ info!(" {:?}", ent);
+ }
}
}
diff --git a/src/proxy_config.rs b/src/proxy_config.rs
index 31a2659..9d07604 100644
--- a/src/proxy_config.rs
+++ b/src/proxy_config.rs
@@ -22,6 +22,7 @@ pub struct ProxyEntry {
pub host: String,
pub path_prefix: Option<String>,
pub priority: u32,
+ pub add_headers: Vec<(String, String)>,
// Counts the number of times this proxy server has been called to
// This implements a round-robin load balancer if there are multiple
@@ -44,7 +45,7 @@ fn retry_to_time(retries: u32, max_time: Duration) -> Duration {
));
}
-fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option<ProxyEntry> {
+fn parse_tricot_tag(tag: &str, target_addr: SocketAddr, add_headers: &[(String, String)]) -> Option<ProxyEntry> {
let splits = tag.split(' ').collect::<Vec<_>>();
if (splits.len() != 2 && splits.len() != 3) || splits[0] != "tricot" {
return None;
@@ -65,10 +66,20 @@ fn parse_tricot_tag(target_addr: SocketAddr, tag: &str) -> Option<ProxyEntry> {
host: host.to_string(),
path_prefix,
priority,
+ add_headers: add_headers.to_vec(),
calls: atomic::AtomicU64::from(0),
})
}
+fn parse_tricot_add_header_tag(tag: &str) -> Option<(String, String)> {
+ let splits = tag.split(' ').collect::<Vec<_>>();
+ if splits.len() == 3 && splits[0] == "tricot-add-header" {
+ Some((splits[1].to_string(), splits[2].to_string()))
+ } else {
+ None
+ }
+}
+
fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec<ProxyEntry> {
let mut entries = vec![];
@@ -78,8 +89,16 @@ fn parse_consul_catalog(catalog: &ConsulNodeCatalog) -> Vec<ProxyEntry> {
_ => continue,
};
let addr = SocketAddr::new(ip_addr, svc.port);
+
+ let mut add_headers = vec![];
+ for tag in svc.tags.iter() {
+ if let Some(pair) = parse_tricot_add_header_tag(tag) {
+ add_headers.push(pair);
+ }
+ }
+
for tag in svc.tags.iter() {
- if let Some(ent) = parse_tricot_tag(addr, tag) {
+ if let Some(ent) = parse_tricot_tag(tag, addr, &add_headers[..]) {
entries.push(ent);
}
}
@@ -181,7 +200,6 @@ pub fn spawn_proxy_config_task(consul: Consul) -> watch::Receiver<Arc<ProxyConfi
}
}
let config = ProxyConfig { entries };
- debug!("Extracted configuration: {:#?}", config);
tx.send(Arc::new(config)).expect("Internal error");
}