diff options
-rw-r--r-- | src/api/generic_server.rs | 155 | ||||
-rw-r--r-- | src/garage/server.rs | 5 | ||||
-rw-r--r-- | src/web/web_server.rs | 75 |
3 files changed, 132 insertions, 103 deletions
diff --git a/src/api/generic_server.rs b/src/api/generic_server.rs index 832f2da3..e3005f8a 100644 --- a/src/api/generic_server.rs +++ b/src/api/generic_server.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use async_trait::async_trait; use futures::future::Future; +use futures::stream::{futures_unordered::FuturesUnordered, StreamExt}; use http_body_util::BodyExt; use hyper::header::HeaderValue; @@ -15,7 +16,7 @@ use hyper::{HeaderMap, StatusCode}; use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::{TcpListener, UnixListener}; +use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use opentelemetry::{ global, @@ -110,20 +111,12 @@ impl<A: ApiHandler> ApiServer<A> { bind_addr ); - tokio::pin!(shutdown_signal); - match bind_addr { UnixOrTCPSocketAddress::TCPSocket(addr) => { let listener = TcpListener::bind(addr).await?; - loop { - let (stream, client_addr) = tokio::select! { - acc = listener.accept() => acc?, - _ = &mut shutdown_signal => break, - }; - - self.launch_handler(stream, client_addr.to_string()); - } + let handler = move |request, socketaddr| self.clone().handler(request, socketaddr); + server_loop(listener, handler, shutdown_signal).await } UnixOrTCPSocketAddress::UnixSocket(ref path) => { if path.exists() { @@ -131,52 +124,24 @@ impl<A: ApiHandler> ApiServer<A> { } let listener = UnixListener::bind(path)?; + let listener = UnixListenerOn(listener, path.display().to_string()); fs::set_permissions( path, Permissions::from_mode(unix_bind_addr_mode.unwrap_or(0o222)), )?; - loop { - let (stream, _) = tokio::select! { - acc = listener.accept() => acc?, - _ = &mut shutdown_signal => break, - }; - - self.launch_handler(stream, path.display().to_string()); - } - } - }; - - Ok(()) - } - - fn launch_handler<S>(self: &Arc<Self>, stream: S, client_addr: String) - where - S: AsyncRead + AsyncWrite + Send + Sync + 'static, - { - let this = self.clone(); - let io = TokioIo::new(stream); - - let serve = - move |req: Request<IncomingBody>| this.clone().handler(req, client_addr.to_string()); - - tokio::task::spawn(async move { - let io = Box::pin(io); - if let Err(e) = http1::Builder::new() - .serve_connection(io, service_fn(serve)) - .await - { - debug!("Error handling HTTP connection: {}", e); + let handler = move |request, socketaddr| self.clone().handler(request, socketaddr); + server_loop(listener, handler, shutdown_signal).await } - }); + } } async fn handler( self: Arc<Self>, req: Request<IncomingBody>, addr: String, - ) -> Result<Response<BoxBody<A::Error>>, GarageError> { + ) -> Result<Response<BoxBody<A::Error>>, http::Error> { let uri = req.uri().clone(); if let Ok(forwarded_for_ip_addr) = @@ -278,3 +243,105 @@ impl<A: ApiHandler> ApiServer<A> { res } } + +// ==== helper functions ==== + +#[async_trait] +pub trait Accept: Send + Sync + 'static { + type Stream: AsyncRead + AsyncWrite + Send + Sync + 'static; + async fn accept(&self) -> std::io::Result<(Self::Stream, String)>; +} + +#[async_trait] +impl Accept for TcpListener { + type Stream = TcpStream; + async fn accept(&self) -> std::io::Result<(Self::Stream, String)> { + self.accept() + .await + .map(|(stream, addr)| (stream, addr.to_string())) + } +} + +pub struct UnixListenerOn(pub UnixListener, pub String); + +#[async_trait] +impl Accept for UnixListenerOn { + type Stream = UnixStream; + async fn accept(&self) -> std::io::Result<(Self::Stream, String)> { + self.0 + .accept() + .await + .map(|(stream, _addr)| (stream, self.1.clone())) + } +} + +pub async fn server_loop<A, H, F, E>( + listener: A, + handler: H, + shutdown_signal: impl Future<Output = ()>, +) -> Result<(), GarageError> +where + A: Accept, + H: Fn(Request<IncomingBody>, String) -> F + Send + Sync + Clone + 'static, + F: Future<Output = Result<Response<BoxBody<E>>, http::Error>> + Send + 'static, + E: Send + Sync + std::error::Error + 'static, +{ + tokio::pin!(shutdown_signal); + + let (conn_in, mut conn_out) = tokio::sync::mpsc::unbounded_channel(); + let connection_collector = tokio::spawn(async move { + let mut collection = FuturesUnordered::new(); + loop { + let collect_next = async { + if collection.is_empty() { + futures::future::pending().await + } else { + collection.next().await + } + }; + tokio::select! { + result = collect_next => { + trace!("HTTP connection finished: {:?}", result); + } + new_fut = conn_out.recv() => { + match new_fut { + Some(f) => collection.push(f), + None => break, + } + } + } + } + debug!("Collecting last open HTTP connections."); + while let Some(conn_res) = collection.next().await { + trace!("HTTP connection finished: {:?}", conn_res); + } + debug!("No more HTTP connections to collect"); + }); + + loop { + let (stream, client_addr) = tokio::select! { + acc = listener.accept() => acc?, + _ = &mut shutdown_signal => break, + }; + + let io = TokioIo::new(stream); + + let handler = handler.clone(); + let serve = move |req: Request<IncomingBody>| handler(req, client_addr.clone()); + + let fut = tokio::task::spawn(async move { + let io = Box::pin(io); + if let Err(e) = http1::Builder::new() + .serve_connection(io, service_fn(serve)) + .await + { + debug!("Error handling HTTP connection: {}", e); + } + }); + conn_in.send(fut)?; + } + + connection_collector.await?; + + Ok(()) +} diff --git a/src/garage/server.rs b/src/garage/server.rs index ac76a44d..de8ac9e2 100644 --- a/src/garage/server.rs +++ b/src/garage/server.rs @@ -113,12 +113,11 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er if let Some(web_config) = &config.s3_web { info!("Initializing web server..."); + let web_server = WebServer::new(garage.clone(), web_config.root_domain.clone()); servers.push(( "Web", - tokio::spawn(WebServer::run( - garage.clone(), + tokio::spawn(web_server.run( web_config.bind_addr.clone(), - web_config.root_domain.clone(), wait_from(watch_cancel.clone()), )), )); diff --git a/src/web/web_server.rs b/src/web/web_server.rs index 668a897a..766e3829 100644 --- a/src/web/web_server.rs +++ b/src/web/web_server.rs @@ -4,16 +4,12 @@ use std::{convert::Infallible, sync::Arc}; use futures::future::Future; -use hyper::server::conn::http1; use hyper::{ body::Incoming as IncomingBody, header::{HeaderValue, HOST}, - service::service_fn, Method, Request, Response, StatusCode, }; -use hyper_util::rt::TokioIo; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, UnixListener}; use opentelemetry::{ @@ -25,6 +21,7 @@ use opentelemetry::{ use crate::error::*; +use garage_api::generic_server::{server_loop, UnixListenerOn}; use garage_api::helpers::*; use garage_api::s3::cors::{add_cors_headers, find_matching_cors_rule, handle_options_for_bucket}; use garage_api::s3::error::{ @@ -75,35 +72,29 @@ pub struct WebServer { impl WebServer { /// Run a web server - pub async fn run( - garage: Arc<Garage>, - bind_addr: UnixOrTCPSocketAddress, - root_domain: String, - shutdown_signal: impl Future<Output = ()>, - ) -> Result<(), GarageError> { + pub fn new(garage: Arc<Garage>, root_domain: String) -> Arc<Self> { let metrics = Arc::new(WebMetrics::new()); - let web_server = Arc::new(WebServer { + Arc::new(WebServer { garage, metrics, root_domain, - }); + }) + } + pub async fn run( + self: Arc<Self>, + bind_addr: UnixOrTCPSocketAddress, + shutdown_signal: impl Future<Output = ()>, + ) -> Result<(), GarageError> { info!("Web server listening on {}", bind_addr); - tokio::pin!(shutdown_signal); - match bind_addr { UnixOrTCPSocketAddress::TCPSocket(addr) => { let listener = TcpListener::bind(addr).await?; - loop { - let (stream, client_addr) = tokio::select! { - acc = listener.accept() => acc?, - _ = &mut shutdown_signal => break, - }; - - web_server.launch_handler(stream, client_addr.to_string()); - } + let handler = + move |stream, socketaddr| self.clone().handle_request(stream, socketaddr); + server_loop(listener, handler, shutdown_signal).await } UnixOrTCPSocketAddress::UnixSocket(ref path) => { if path.exists() { @@ -111,50 +102,22 @@ impl WebServer { } let listener = UnixListener::bind(path)?; + let listener = UnixListenerOn(listener, path.display().to_string()); fs::set_permissions(path, Permissions::from_mode(0o222))?; - loop { - let (stream, _) = tokio::select! { - acc = listener.accept() => acc?, - _ = &mut shutdown_signal => break, - }; - - web_server.launch_handler(stream, path.display().to_string()); - } + let handler = + move |stream, socketaddr| self.clone().handle_request(stream, socketaddr); + server_loop(listener, handler, shutdown_signal).await } - }; - - Ok(()) - } - - fn launch_handler<S>(self: &Arc<Self>, stream: S, client_addr: String) - where - S: AsyncRead + AsyncWrite + Send + Sync + 'static, - { - let this = self.clone(); - let io = TokioIo::new(stream); - - let serve = move |req: Request<IncomingBody>| { - this.clone().handle_request(req, client_addr.to_string()) - }; - - tokio::task::spawn(async move { - let io = Box::pin(io); - if let Err(e) = http1::Builder::new() - .serve_connection(io, service_fn(serve)) - .await - { - debug!("Error handling HTTP connection: {}", e); - } - }); + } } async fn handle_request( self: Arc<Self>, req: Request<IncomingBody>, addr: String, - ) -> Result<Response<BoxBody<Error>>, Infallible> { + ) -> Result<Response<BoxBody<Error>>, http::Error> { if let Ok(forwarded_for_ip_addr) = forwarded_headers::handle_forwarded_for_headers(req.headers()) { |