diff options
Diffstat (limited to 'src/rpc_client.rs')
-rw-r--r-- | src/rpc_client.rs | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/src/rpc_client.rs b/src/rpc_client.rs index 8bc3fe50..e78079c2 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -1,10 +1,13 @@ use std::borrow::Borrow; use std::marker::PhantomData; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; +use arc_swap::ArcSwapOption; use bytes::IntoBuf; +use futures::future::Future; use futures::stream::futures_unordered::FuturesUnordered; use futures::stream::StreamExt; use futures_util::future::FutureExt; @@ -47,10 +50,15 @@ impl RequestStrategy { } } +pub type LocalHandlerFn<M> = + Box<dyn Fn(Arc<M>) -> Pin<Box<dyn Future<Output = Result<M, Error>> + Send>> + Send + Sync>; + pub struct RpcClient<M: RpcMessage> { status: watch::Receiver<Arc<Status>>, background: Arc<BackgroundRunner>, + local_handler: ArcSwapOption<(UUID, LocalHandlerFn<M>)>, + pub rpc_addr_client: RpcAddrClient<M>, } @@ -64,19 +72,38 @@ impl<M: RpcMessage + 'static> RpcClient<M> { rpc_addr_client: rac, background, status, + local_handler: ArcSwapOption::new(None), }) } + pub fn set_local_handler<F, Fut>(&self, my_id: UUID, handler: F) + where + F: Fn(Arc<M>) -> Fut + Send + Sync + 'static, + Fut: Future<Output = Result<M, Error>> + Send + 'static, + { + let handler_arc = Arc::new(handler); + let handler: LocalHandlerFn<M> = Box::new(move |msg| { + let handler_arc2 = handler_arc.clone(); + Box::pin(async move { handler_arc2(msg).await }) + }); + self.local_handler.swap(Some(Arc::new((my_id, handler)))); + } + pub fn by_addr(&self) -> &RpcAddrClient<M> { &self.rpc_addr_client } - pub async fn call<MB: Borrow<M>, N: Borrow<UUID>>( - &self, - to: N, - msg: MB, - timeout: Duration, - ) -> Result<M, Error> { + pub async fn call(&self, to: UUID, msg: M, timeout: Duration) -> Result<M, Error> { + self.call_arc(to, Arc::new(msg), timeout).await + } + + pub async fn call_arc(&self, to: UUID, msg: Arc<M>, timeout: Duration) -> Result<M, Error> { + if let Some(lh) = self.local_handler.load_full() { + let (my_id, local_handler) = lh.as_ref(); + if to.borrow() == my_id { + return local_handler(msg).await; + } + } let addr = { let status = self.status.borrow().clone(); match status.nodes.get(to.borrow()) { @@ -96,7 +123,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> { let msg = Arc::new(msg); let mut resp_stream = to .iter() - .map(|to| self.call(to, msg.clone(), timeout)) + .map(|to| self.call_arc(*to, msg.clone(), timeout)) .collect::<FuturesUnordered<_>>(); let mut results = vec![]; @@ -121,7 +148,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> { .map(|to| { let self2 = self.clone(); let msg = msg.clone(); - async move { self2.call(to, msg, timeout).await } + async move { self2.call_arc(to, msg, timeout).await } }) .collect::<FuturesUnordered<_>>(); @@ -155,7 +182,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> { resp_stream.collect::<Vec<_>>().await; Ok(()) }); - self.clone().background.spawn(wait_finished_fut.map(|x| { + self.background.spawn(wait_finished_fut.map(|x| { x.unwrap_or_else(|e| Err(Error::Message(format!("Await failed: {}", e)))) })); } |