From 2fe82be3bcb326af04c4c862431237c576ed1152 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 23 Apr 2020 14:40:59 +0000 Subject: RPC to ourself do not pass through serialization + HTTPS --- src/rpc_client.rs | 45 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) (limited to 'src/rpc_client.rs') diff --git a/src/rpc_client.rs b/src/rpc_client.rs index 8bc3fe50..e78079c2 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -1,10 +1,13 @@ use std::borrow::Borrow; use std::marker::PhantomData; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; +use arc_swap::ArcSwapOption; use bytes::IntoBuf; +use futures::future::Future; use futures::stream::futures_unordered::FuturesUnordered; use futures::stream::StreamExt; use futures_util::future::FutureExt; @@ -47,10 +50,15 @@ impl RequestStrategy { } } +pub type LocalHandlerFn = + Box) -> Pin> + Send>> + Send + Sync>; + pub struct RpcClient { status: watch::Receiver>, background: Arc, + local_handler: ArcSwapOption<(UUID, LocalHandlerFn)>, + pub rpc_addr_client: RpcAddrClient, } @@ -64,19 +72,38 @@ impl RpcClient { rpc_addr_client: rac, background, status, + local_handler: ArcSwapOption::new(None), }) } + pub fn set_local_handler(&self, my_id: UUID, handler: F) + where + F: Fn(Arc) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + let handler_arc = Arc::new(handler); + let handler: LocalHandlerFn = Box::new(move |msg| { + let handler_arc2 = handler_arc.clone(); + Box::pin(async move { handler_arc2(msg).await }) + }); + self.local_handler.swap(Some(Arc::new((my_id, handler)))); + } + pub fn by_addr(&self) -> &RpcAddrClient { &self.rpc_addr_client } - pub async fn call, N: Borrow>( - &self, - to: N, - msg: MB, - timeout: Duration, - ) -> Result { + pub async fn call(&self, to: UUID, msg: M, timeout: Duration) -> Result { + self.call_arc(to, Arc::new(msg), timeout).await + } + + pub async fn call_arc(&self, to: UUID, msg: Arc, timeout: Duration) -> Result { + if let Some(lh) = self.local_handler.load_full() { + let (my_id, local_handler) = lh.as_ref(); + if to.borrow() == my_id { + return local_handler(msg).await; + } + } let addr = { let status = self.status.borrow().clone(); match status.nodes.get(to.borrow()) { @@ -96,7 +123,7 @@ impl RpcClient { let msg = Arc::new(msg); let mut resp_stream = to .iter() - .map(|to| self.call(to, msg.clone(), timeout)) + .map(|to| self.call_arc(*to, msg.clone(), timeout)) .collect::>(); let mut results = vec![]; @@ -121,7 +148,7 @@ impl RpcClient { .map(|to| { let self2 = self.clone(); let msg = msg.clone(); - async move { self2.call(to, msg, timeout).await } + async move { self2.call_arc(to, msg, timeout).await } }) .collect::>(); @@ -155,7 +182,7 @@ impl RpcClient { resp_stream.collect::>().await; Ok(()) }); - self.clone().background.spawn(wait_finished_fut.map(|x| { + self.background.spawn(wait_finished_fut.map(|x| { x.unwrap_or_else(|e| Err(Error::Message(format!("Await failed: {}", e)))) })); } -- cgit v1.2.3