aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/api/generic_server.rs155
-rw-r--r--src/garage/server.rs5
-rw-r--r--src/web/web_server.rs75
3 files changed, 132 insertions, 103 deletions
diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs
index 832f2da3..e3005f8a 100644
--- a/src/api/generic_server.rs
+++ b/src/api/generic_server.rs
@@ -5,6 +5,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use futures::future::Future;
+use futures::stream::{futures_unordered::FuturesUnordered, StreamExt};
use http_body_util::BodyExt;
use hyper::header::HeaderValue;
@@ -15,7 +16,7 @@ use hyper::{HeaderMap, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
-use tokio::net::{TcpListener, UnixListener};
+use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
use opentelemetry::{
global,
@@ -110,20 +111,12 @@ impl<A: ApiHandler> ApiServer<A> {
bind_addr
);
- tokio::pin!(shutdown_signal);
-
match bind_addr {
UnixOrTCPSocketAddress::TCPSocket(addr) => {
let listener = TcpListener::bind(addr).await?;
- loop {
- let (stream, client_addr) = tokio::select! {
- acc = listener.accept() => acc?,
- _ = &mut shutdown_signal => break,
- };
-
- self.launch_handler(stream, client_addr.to_string());
- }
+ let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
+ server_loop(listener, handler, shutdown_signal).await
}
UnixOrTCPSocketAddress::UnixSocket(ref path) => {
if path.exists() {
@@ -131,52 +124,24 @@ impl<A: ApiHandler> ApiServer<A> {
}
let listener = UnixListener::bind(path)?;
+ let listener = UnixListenerOn(listener, path.display().to_string());
fs::set_permissions(
path,
Permissions::from_mode(unix_bind_addr_mode.unwrap_or(0o222)),
)?;
- loop {
- let (stream, _) = tokio::select! {
- acc = listener.accept() => acc?,
- _ = &mut shutdown_signal => break,
- };
-
- self.launch_handler(stream, path.display().to_string());
- }
- }
- };
-
- Ok(())
- }
-
- fn launch_handler<S>(self: &Arc<Self>, stream: S, client_addr: String)
- where
- S: AsyncRead + AsyncWrite + Send + Sync + 'static,
- {
- let this = self.clone();
- let io = TokioIo::new(stream);
-
- let serve =
- move |req: Request<IncomingBody>| this.clone().handler(req, client_addr.to_string());
-
- tokio::task::spawn(async move {
- let io = Box::pin(io);
- if let Err(e) = http1::Builder::new()
- .serve_connection(io, service_fn(serve))
- .await
- {
- debug!("Error handling HTTP connection: {}", e);
+ let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
+ server_loop(listener, handler, shutdown_signal).await
}
- });
+ }
}
async fn handler(
self: Arc<Self>,
req: Request<IncomingBody>,
addr: String,
- ) -> Result<Response<BoxBody<A::Error>>, GarageError> {
+ ) -> Result<Response<BoxBody<A::Error>>, http::Error> {
let uri = req.uri().clone();
if let Ok(forwarded_for_ip_addr) =
@@ -278,3 +243,105 @@ impl<A: ApiHandler> ApiServer<A> {
res
}
}
+
+// ==== helper functions ====
+
+#[async_trait]
+pub trait Accept: Send + Sync + 'static {
+ type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static;
+ async fn accept(&self) -> std::io::Result<(Self::Stream, String)>;
+}
+
+#[async_trait]
+impl Accept for TcpListener {
+ type Stream = TcpStream;
+ async fn accept(&self) -> std::io::Result<(Self::Stream, String)> {
+ self.accept()
+ .await
+ .map(|(stream, addr)| (stream, addr.to_string()))
+ }
+}
+
+pub struct UnixListenerOn(pub UnixListener, pub String);
+
+#[async_trait]
+impl Accept for UnixListenerOn {
+ type Stream = UnixStream;
+ async fn accept(&self) -> std::io::Result<(Self::Stream, String)> {
+ self.0
+ .accept()
+ .await
+ .map(|(stream, _addr)| (stream, self.1.clone()))
+ }
+}
+
+pub async fn server_loop<A, H, F, E>(
+ listener: A,
+ handler: H,
+ shutdown_signal: impl Future<Output = ()>,
+) -> Result<(), GarageError>
+where
+ A: Accept,
+ H: Fn(Request<IncomingBody>, String) -> F + Send + Sync + Clone + 'static,
+ F: Future<Output = Result<Response<BoxBody<E>>, http::Error>> + Send + 'static,
+ E: Send + Sync + std::error::Error + 'static,
+{
+ tokio::pin!(shutdown_signal);
+
+ let (conn_in, mut conn_out) = tokio::sync::mpsc::unbounded_channel();
+ let connection_collector = tokio::spawn(async move {
+ let mut collection = FuturesUnordered::new();
+ loop {
+ let collect_next = async {
+ if collection.is_empty() {
+ futures::future::pending().await
+ } else {
+ collection.next().await
+ }
+ };
+ tokio::select! {
+ result = collect_next => {
+ trace!("HTTP connection finished: {:?}", result);
+ }
+ new_fut = conn_out.recv() => {
+ match new_fut {
+ Some(f) => collection.push(f),
+ None => break,
+ }
+ }
+ }
+ }
+ debug!("Collecting last open HTTP connections.");
+ while let Some(conn_res) = collection.next().await {
+ trace!("HTTP connection finished: {:?}", conn_res);
+ }
+ debug!("No more HTTP connections to collect");
+ });
+
+ loop {
+ let (stream, client_addr) = tokio::select! {
+ acc = listener.accept() => acc?,
+ _ = &mut shutdown_signal => break,
+ };
+
+ let io = TokioIo::new(stream);
+
+ let handler = handler.clone();
+ let serve = move |req: Request<IncomingBody>| handler(req, client_addr.clone());
+
+ let fut = tokio::task::spawn(async move {
+ let io = Box::pin(io);
+ if let Err(e) = http1::Builder::new()
+ .serve_connection(io, service_fn(serve))
+ .await
+ {
+ debug!("Error handling HTTP connection: {}", e);
+ }
+ });
+ conn_in.send(fut)?;
+ }
+
+ connection_collector.await?;
+
+ Ok(())
+}
diff --git a/src/garage/server.rs b/src/garage/server.rs
index ac76a44d..de8ac9e2 100644
--- a/src/garage/server.rs
+++ b/src/garage/server.rs
@@ -113,12 +113,11 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er
if let Some(web_config) = &config.s3_web {
info!("Initializing web server...");
+ let web_server = WebServer::new(garage.clone(), web_config.root_domain.clone());
servers.push((
"Web",
- tokio::spawn(WebServer::run(
- garage.clone(),
+ tokio::spawn(web_server.run(
web_config.bind_addr.clone(),
- web_config.root_domain.clone(),
wait_from(watch_cancel.clone()),
)),
));
diff --git a/src/web/web_server.rs b/src/web/web_server.rs
index 668a897a..766e3829 100644
--- a/src/web/web_server.rs
+++ b/src/web/web_server.rs
@@ -4,16 +4,12 @@ use std::{convert::Infallible, sync::Arc};
use futures::future::Future;
-use hyper::server::conn::http1;
use hyper::{
body::Incoming as IncomingBody,
header::{HeaderValue, HOST},
- service::service_fn,
Method, Request, Response, StatusCode,
};
-use hyper_util::rt::TokioIo;
-use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UnixListener};
use opentelemetry::{
@@ -25,6 +21,7 @@ use opentelemetry::{
use crate::error::*;
+use garage_api::generic_server::{server_loop, UnixListenerOn};
use garage_api::helpers::*;
use garage_api::s3::cors::{add_cors_headers, find_matching_cors_rule, handle_options_for_bucket};
use garage_api::s3::error::{
@@ -75,35 +72,29 @@ pub struct WebServer {
impl WebServer {
/// Run a web server
- pub async fn run(
- garage: Arc<Garage>,
- bind_addr: UnixOrTCPSocketAddress,
- root_domain: String,
- shutdown_signal: impl Future<Output = ()>,
- ) -> Result<(), GarageError> {
+ pub fn new(garage: Arc<Garage>, root_domain: String) -> Arc<Self> {
let metrics = Arc::new(WebMetrics::new());
- let web_server = Arc::new(WebServer {
+ Arc::new(WebServer {
garage,
metrics,
root_domain,
- });
+ })
+ }
+ pub async fn run(
+ self: Arc<Self>,
+ bind_addr: UnixOrTCPSocketAddress,
+ shutdown_signal: impl Future<Output = ()>,
+ ) -> Result<(), GarageError> {
info!("Web server listening on {}", bind_addr);
- tokio::pin!(shutdown_signal);
-
match bind_addr {
UnixOrTCPSocketAddress::TCPSocket(addr) => {
let listener = TcpListener::bind(addr).await?;
- loop {
- let (stream, client_addr) = tokio::select! {
- acc = listener.accept() => acc?,
- _ = &mut shutdown_signal => break,
- };
-
- web_server.launch_handler(stream, client_addr.to_string());
- }
+ let handler =
+ move |stream, socketaddr| self.clone().handle_request(stream, socketaddr);
+ server_loop(listener, handler, shutdown_signal).await
}
UnixOrTCPSocketAddress::UnixSocket(ref path) => {
if path.exists() {
@@ -111,50 +102,22 @@ impl WebServer {
}
let listener = UnixListener::bind(path)?;
+ let listener = UnixListenerOn(listener, path.display().to_string());
fs::set_permissions(path, Permissions::from_mode(0o222))?;
- loop {
- let (stream, _) = tokio::select! {
- acc = listener.accept() => acc?,
- _ = &mut shutdown_signal => break,
- };
-
- web_server.launch_handler(stream, path.display().to_string());
- }
+ let handler =
+ move |stream, socketaddr| self.clone().handle_request(stream, socketaddr);
+ server_loop(listener, handler, shutdown_signal).await
}
- };
-
- Ok(())
- }
-
- fn launch_handler<S>(self: &Arc<Self>, stream: S, client_addr: String)
- where
- S: AsyncRead + AsyncWrite + Send + Sync + 'static,
- {
- let this = self.clone();
- let io = TokioIo::new(stream);
-
- let serve = move |req: Request<IncomingBody>| {
- this.clone().handle_request(req, client_addr.to_string())
- };
-
- tokio::task::spawn(async move {
- let io = Box::pin(io);
- if let Err(e) = http1::Builder::new()
- .serve_connection(io, service_fn(serve))
- .await
- {
- debug!("Error handling HTTP connection: {}", e);
- }
- });
+ }
}
async fn handle_request(
self: Arc<Self>,
req: Request<IncomingBody>,
addr: String,
- ) -> Result<Response<BoxBody<Error>>, Infallible> {
+ ) -> Result<Response<BoxBody<Error>>, http::Error> {
if let Ok(forwarded_for_ip_addr) =
forwarded_headers::handle_forwarded_for_headers(req.headers())
{