From 6f13d083ab188060d2a2dc5f619070a445fe61ba Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 3 Nov 2021 17:00:40 +0100 Subject: Add semaphore to limit RAM used by buffered outgoing requests --- src/rpc/rpc_helper.rs | 33 ++++++++++++++++++++++++++++++--- src/rpc/system.rs | 5 +---- src/util/error.rs | 3 +++ 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/rpc/rpc_helper.rs b/src/rpc/rpc_helper.rs index 8c7cc681..cdac6f14 100644 --- a/src/rpc/rpc_helper.rs +++ b/src/rpc/rpc_helper.rs @@ -7,6 +7,7 @@ use futures::stream::futures_unordered::FuturesUnordered; use futures::stream::StreamExt; use futures_util::future::FutureExt; use tokio::select; +use tokio::sync::Semaphore; pub use netapp::endpoint::{Endpoint, EndpointHandler, Message as Rpc}; use netapp::peering::fullmesh::FullMeshPeeringStrategy; @@ -14,11 +15,16 @@ pub use netapp::proto::*; pub use netapp::{NetApp, NodeID}; use garage_util::background::BackgroundRunner; -use garage_util::data::Uuid; +use garage_util::data::*; use garage_util::error::Error; const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); +// Try to never have more than 200MB of outgoing requests +// buffered at the same time. Other requests are queued until +// space is freed. +const REQUEST_BUFFER_SIZE: usize = 200 * 1024 * 1024; + /// Strategy to apply when making RPC #[derive(Copy, Clone)] pub struct RequestStrategy { @@ -64,9 +70,21 @@ impl RequestStrategy { pub struct RpcHelper { pub(crate) fullmesh: Arc, pub(crate) background: Arc, + request_buffer_semaphore: Arc, } impl RpcHelper { + pub(crate) fn new( + fullmesh: Arc, + background: Arc, + ) -> Self { + Self { + fullmesh, + background, + request_buffer_semaphore: Arc::new(Semaphore::new(REQUEST_BUFFER_SIZE)), + } + } + pub async fn call( &self, endpoint: &Endpoint, @@ -92,10 +110,19 @@ impl RpcHelper { M: Rpc>, H: EndpointHandler, { + let msg_size = rmp_to_vec_all_named(&msg)?.len() as u32; + let permit = self.request_buffer_semaphore.acquire_many(msg_size).await?; + let node_id = to.into(); select! { - res = endpoint.call(&node_id, &msg, strat.rs_priority) => Ok(res??), - _ = tokio::time::sleep(strat.rs_timeout) => Err(Error::Timeout), + res = endpoint.call(&node_id, &msg, strat.rs_priority) => { + drop(permit); + Ok(res??) + } + _ = tokio::time::sleep(strat.rs_timeout) => { + drop(permit); + Err(Error::Timeout) + } } } diff --git a/src/rpc/system.rs b/src/rpc/system.rs index 8f5a1ec5..a518ef21 100644 --- a/src/rpc/system.rs +++ b/src/rpc/system.rs @@ -235,10 +235,7 @@ impl System { node_status: RwLock::new(HashMap::new()), netapp: netapp.clone(), fullmesh: fullmesh.clone(), - rpc: RpcHelper { - fullmesh, - background: background.clone(), - }, + rpc: RpcHelper::new(fullmesh, background.clone()), system_endpoint, replication_factor, rpc_listen_addr: config.rpc_bind_addr, diff --git a/src/util/error.rs b/src/util/error.rs index 626958da..ff03d05b 100644 --- a/src/util/error.rs +++ b/src/util/error.rs @@ -41,6 +41,9 @@ pub enum Error { #[error(display = "Tokio join error: {}", _0)] TokioJoin(#[error(source)] tokio::task::JoinError), + #[error(display = "Tokio semaphore acquire error: {}", _0)] + TokioSemAcquire(#[error(source)] tokio::sync::AcquireError), + #[error(display = "Remote error: {}", _0)] RemoteError(String), -- cgit v1.2.3