aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2021-12-07 15:20:45 +0100
committerAlex Auvolat <alex@adnab.me>2021-12-07 15:20:45 +0100
commitcd7e5ad034b75d659d4d87a752ab7b11cf75de12 (patch)
tree32773f9758b33188402e137d435bdd61ce01b280 /src
parent5535c4951a832d65755afa53822a36e96681320f (diff)
downloadtricot-cd7e5ad034b75d659d4d87a752ab7b11cf75de12.tar.gz
tricot-cd7e5ad034b75d659d4d87a752ab7b11cf75de12.zip
Got a reverse proxy
Diffstat (limited to 'src')
-rw-r--r--src/cert.rs6
-rw-r--r--src/cert_store.rs23
-rw-r--r--src/https.rs126
-rw-r--r--src/main.rs8
-rw-r--r--src/reverse_proxy.rs114
5 files changed, 269 insertions, 8 deletions
diff --git a/src/cert.rs b/src/cert.rs
index de0d821..0be43f3 100644
--- a/src/cert.rs
+++ b/src/cert.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
use anyhow::Result;
use chrono::{Date, NaiveDate, Utc};
@@ -17,7 +19,7 @@ pub struct CertSer {
pub struct Cert {
pub ser: CertSer,
- pub certkey: CertifiedKey,
+ pub certkey: Arc<CertifiedKey>,
}
impl Cert {
@@ -46,7 +48,7 @@ impl Cert {
bail!("{} keys present in pem file", keys.len());
}
- let certkey = CertifiedKey::new(certs, keys.into_iter().next().unwrap());
+ let certkey = Arc::new(CertifiedKey::new(certs, keys.into_iter().next().unwrap()));
Ok(Cert { ser, certkey })
}
diff --git a/src/cert_store.rs b/src/cert_store.rs
index 6529395..1b1a478 100644
--- a/src/cert_store.rs
+++ b/src/cert_store.rs
@@ -6,9 +6,11 @@ use anyhow::Result;
use chrono::Utc;
use log::*;
use tokio::sync::watch;
+use tokio::task::block_in_place;
use acme_micro::create_p384_key;
use acme_micro::{Directory, DirectoryUrl};
+use rustls::sign::CertifiedKey;
use crate::cert::{Cert, CertSer};
use crate::consul::Consul;
@@ -93,7 +95,7 @@ impl CertStore {
dir.load_account(std::str::from_utf8(&acc_privkey)?, contact)?
} else {
info!("Creating new Let's encrypt account");
- let acc = dir.register_account(contact.clone())?;
+ let acc = block_in_place(|| dir.register_account(contact.clone()))?;
self.consul
.kv_put(
"letsencrypt_account_key.pem",
@@ -119,17 +121,18 @@ impl CertStore {
.await?;
info!("Validating challenge");
- chall.validate(Duration::from_millis(5000))?;
+ block_in_place(|| chall.validate(Duration::from_millis(5000)))?;
info!("Deleting challenge");
self.consul.kv_delete(&chall_key).await?;
- ord_new.refresh()?;
+ block_in_place(|| ord_new.refresh())?;
};
let pkey_pri = create_p384_key()?;
- let ord_cert = ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000))?;
- let cert = ord_cert.download_cert()?;
+ let ord_cert =
+ block_in_place(|| ord_csr.finalize_pkey(pkey_pri, Duration::from_millis(5000)))?;
+ let cert = block_in_place(|| ord_cert.download_cert())?;
info!("Keys and certificate obtained");
let key_pem = cert.private_key().to_string();
@@ -157,3 +160,13 @@ impl CertStore {
Ok(cert)
}
}
+
+pub struct StoreResolver(pub Arc<CertStore>);
+
+impl rustls::server::ResolvesServerCert for StoreResolver {
+ fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
+ let domain = client_hello.server_name()?;
+ let cert = futures::executor::block_on(self.0.get_cert(domain)).ok()?;
+ Some(cert.certkey.clone())
+ }
+}
diff --git a/src/https.rs b/src/https.rs
new file mode 100644
index 0000000..c80d51c
--- /dev/null
+++ b/src/https.rs
@@ -0,0 +1,126 @@
+use std::net::SocketAddr;
+use std::sync::Arc;
+
+use anyhow::Result;
+use log::*;
+
+use futures::FutureExt;
+use hyper::server::conn::Http;
+use hyper::service::service_fn;
+use hyper::{Body, Request, Response, StatusCode};
+use tokio::net::TcpListener;
+use tokio::sync::watch;
+use tokio_rustls::TlsAcceptor;
+
+use crate::cert_store::{CertStore, StoreResolver};
+use crate::proxy_config::ProxyConfig;
+use crate::reverse_proxy;
+
+pub async fn serve_https(
+ cert_store: Arc<CertStore>,
+ proxy_config: watch::Receiver<Arc<ProxyConfig>>,
+) -> Result<()> {
+ let addr = format!("0.0.0.0:1443");
+
+ let mut cfg = rustls::ServerConfig::builder()
+ .with_safe_defaults()
+ .with_no_client_auth()
+ .with_cert_resolver(Arc::new(StoreResolver(cert_store)));
+
+ cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+ let tls_cfg = Arc::new(cfg);
+ let tls_acceptor = Arc::new(TlsAcceptor::from(tls_cfg));
+
+ println!("Starting to serve on https://{}.", addr);
+
+ let tcp = TcpListener::bind(&addr).await?;
+ loop {
+ let (socket, remote_addr) = tcp.accept().await?;
+
+ let proxy_config = proxy_config.clone();
+ let tls_acceptor = tls_acceptor.clone();
+
+ tokio::spawn(async move {
+ match tls_acceptor.accept(socket).await {
+ Ok(stream) => {
+ debug!("TLS handshake was successfull");
+ let http_result = Http::new()
+ .serve_connection(
+ stream,
+ service_fn(move |req: Request<Body>| {
+ let proxy_config: Arc<ProxyConfig> = proxy_config.borrow().clone();
+ handle(remote_addr, req, proxy_config).map(|res| match res {
+ Err(e) => {
+ warn!("Handler error: {}", e);
+ Response::builder()
+ .status(StatusCode::INTERNAL_SERVER_ERROR)
+ .body(Body::from(format!("{}", e)))
+ .map_err(Into::into)
+ }
+ x => x,
+ })
+ }),
+ )
+ .await;
+ if let Err(http_err) = http_result {
+ debug!("HTTP error: {}", http_err);
+ }
+ }
+ Err(e) => debug!("Error in TLS connection: {}", e),
+ }
+ });
+ }
+}
+
+// Custom echo service, handling two different routes and a
+// catch-all 404 responder.
+async fn handle(
+ remote_addr: SocketAddr,
+ req: Request<Body>,
+ proxy_config: Arc<ProxyConfig>,
+) -> Result<Response<Body>, anyhow::Error> {
+ let host = if let Some(auth) = req.uri().authority() {
+ auth.as_str()
+ } else {
+ req.headers()
+ .get("host")
+ .ok_or_else(|| anyhow!("Missing host header"))?
+ .to_str()?
+ };
+ let path = req.uri().path();
+
+ let ent = proxy_config
+ .entries
+ .iter()
+ .filter(|ent| {
+ ent.host == host
+ && ent
+ .path_prefix
+ .as_ref()
+ .map(|prefix| path.starts_with(prefix))
+ .unwrap_or(true)
+ })
+ .min_by_key(|ent| {
+ (
+ ent.priority,
+ -(ent
+ .path_prefix
+ .as_ref()
+ .map(|x| x.len() as i32)
+ .unwrap_or(0)),
+ )
+ });
+
+ if let Some(proxy_to) = ent {
+ let to_addr = format!("http://{}", proxy_to.target_addr);
+ info!("Proxying {} {} -> {}", host, path, to_addr);
+
+ reverse_proxy::call(remote_addr.ip(), &to_addr, req).await
+ } else {
+ info!("Proxying {} {} -> NOT FOUND", host, path);
+
+ Ok(Response::builder()
+ .status(StatusCode::NOT_FOUND)
+ .body(Body::from("No matching proxy entry"))?)
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index d7f1e24..df0845d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,11 +5,13 @@ mod cert;
mod cert_store;
mod consul;
mod http;
+mod https;
mod proxy_config;
+mod reverse_proxy;
use log::*;
-#[tokio::main(flavor = "multi_thread")]
+#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
async fn main() {
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "tricot=debug")
@@ -28,6 +30,10 @@ async fn main() {
);
tokio::spawn(http::serve_http(consul.clone()));
+ tokio::spawn(https::serve_https(
+ cert_store.clone(),
+ rx_proxy_config.clone(),
+ ));
while rx_proxy_config.changed().await.is_ok() {
info!("Proxy config: {:#?}", *rx_proxy_config.borrow());
diff --git a/src/reverse_proxy.rs b/src/reverse_proxy.rs
new file mode 100644
index 0000000..82533d8
--- /dev/null
+++ b/src/reverse_proxy.rs
@@ -0,0 +1,114 @@
+//! Copied from https://github.com/felipenoris/hyper-reverse-proxy
+//! See there for original Copyright notice
+
+use anyhow::Result;
+
+use hyper::header::{HeaderMap, HeaderValue};
+use hyper::{Body, Client, Request, Response, Uri};
+use lazy_static::lazy_static;
+use std::net::IpAddr;
+use std::str::FromStr;
+
+fn is_hop_header(name: &str) -> bool {
+ use unicase::Ascii;
+
+ // A list of the headers, using `unicase` to help us compare without
+ // worrying about the case, and `lazy_static!` to prevent reallocation
+ // of the vector.
+ lazy_static! {
+ static ref HOP_HEADERS: Vec<Ascii<&'static str>> = vec![
+ Ascii::new("Connection"),
+ Ascii::new("Keep-Alive"),
+ Ascii::new("Proxy-Authenticate"),
+ Ascii::new("Proxy-Authorization"),
+ Ascii::new("Te"),
+ Ascii::new("Trailers"),
+ Ascii::new("Transfer-Encoding"),
+ Ascii::new("Upgrade"),
+ ];
+ }
+
+ HOP_HEADERS.iter().any(|h| h == &name)
+}
+
+/// Returns a clone of the headers without the [hop-by-hop headers].
+///
+/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
+ let mut result = HeaderMap::new();
+ for (k, v) in headers.iter() {
+ if !is_hop_header(k.as_str()) {
+ result.insert(k.clone(), v.clone());
+ }
+ }
+ result
+}
+
+fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
+ *response.headers_mut() = remove_hop_headers(response.headers());
+ response
+}
+
+fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
+ let forward_uri = match req.uri().query() {
+ Some(query) => format!("{}{}?{}", forward_url, req.uri().path(), query),
+ None => format!("{}{}", forward_url, req.uri().path()),
+ };
+
+ Ok(Uri::from_str(forward_uri.as_str())?)
+}
+
+fn create_proxied_request<B>(
+ client_ip: IpAddr,
+ forward_url: &str,
+ request: Request<B>,
+) -> Result<Request<B>> {
+ let mut builder = Request::builder().uri(forward_uri(forward_url, &request)?);
+
+ *builder.headers_mut().unwrap() = remove_hop_headers(request.headers());
+
+ let host_header_name = "host";
+ let x_forwarded_for_header_name = "x-forwarded-for";
+
+ // If request does not have host header, add it from original URI authority
+ if let Some(authority) = request.uri().authority() {
+ if let hyper::header::Entry::Vacant(entry) = builder
+ .headers_mut()
+ .unwrap()
+ .entry(host_header_name)
+ {
+ entry.insert(authority.as_str().parse()?);
+ }
+ }
+
+ // Add forwarding information in the headers
+ match builder
+ .headers_mut()
+ .unwrap()
+ .entry(x_forwarded_for_header_name)
+ {
+ hyper::header::Entry::Vacant(entry) => {
+ entry.insert(client_ip.to_string().parse()?);
+ }
+
+ hyper::header::Entry::Occupied(mut entry) => {
+ let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
+ entry.insert(addr.parse()?);
+ }
+ }
+
+ Ok(builder.body(request.into_body())?)
+}
+
+pub async fn call(
+ client_ip: IpAddr,
+ forward_uri: &str,
+ request: Request<Body>,
+) -> Result<Response<Body>> {
+ let proxied_request = create_proxied_request(client_ip, &forward_uri, request)?;
+
+ let client = Client::new();
+ let response = client.request(proxied_request).await?;
+ let proxied_response = create_proxied_response(response);
+ Ok(proxied_response)
+}