aboutsummaryrefslogtreecommitdiff
path: root/src/rpc_client.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/rpc_client.rs')
-rw-r--r--src/rpc_client.rs45
1 files changed, 36 insertions, 9 deletions
diff --git a/src/rpc_client.rs b/src/rpc_client.rs
index 8bc3fe50..e78079c2 100644
--- a/src/rpc_client.rs
+++ b/src/rpc_client.rs
@@ -1,10 +1,13 @@
use std::borrow::Borrow;
use std::marker::PhantomData;
use std::net::SocketAddr;
+use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
+use arc_swap::ArcSwapOption;
use bytes::IntoBuf;
+use futures::future::Future;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt;
use futures_util::future::FutureExt;
@@ -47,10 +50,15 @@ impl RequestStrategy {
}
}
+pub type LocalHandlerFn<M> =
+ Box<dyn Fn(Arc<M>) -> Pin<Box<dyn Future<Output = Result<M, Error>> + Send>> + Send + Sync>;
+
pub struct RpcClient<M: RpcMessage> {
status: watch::Receiver<Arc<Status>>,
background: Arc<BackgroundRunner>,
+ local_handler: ArcSwapOption<(UUID, LocalHandlerFn<M>)>,
+
pub rpc_addr_client: RpcAddrClient<M>,
}
@@ -64,19 +72,38 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
rpc_addr_client: rac,
background,
status,
+ local_handler: ArcSwapOption::new(None),
})
}
+ pub fn set_local_handler<F, Fut>(&self, my_id: UUID, handler: F)
+ where
+ F: Fn(Arc<M>) -> Fut + Send + Sync + 'static,
+ Fut: Future<Output = Result<M, Error>> + Send + 'static,
+ {
+ let handler_arc = Arc::new(handler);
+ let handler: LocalHandlerFn<M> = Box::new(move |msg| {
+ let handler_arc2 = handler_arc.clone();
+ Box::pin(async move { handler_arc2(msg).await })
+ });
+ self.local_handler.swap(Some(Arc::new((my_id, handler))));
+ }
+
pub fn by_addr(&self) -> &RpcAddrClient<M> {
&self.rpc_addr_client
}
- pub async fn call<MB: Borrow<M>, N: Borrow<UUID>>(
- &self,
- to: N,
- msg: MB,
- timeout: Duration,
- ) -> Result<M, Error> {
+ pub async fn call(&self, to: UUID, msg: M, timeout: Duration) -> Result<M, Error> {
+ self.call_arc(to, Arc::new(msg), timeout).await
+ }
+
+ pub async fn call_arc(&self, to: UUID, msg: Arc<M>, timeout: Duration) -> Result<M, Error> {
+ if let Some(lh) = self.local_handler.load_full() {
+ let (my_id, local_handler) = lh.as_ref();
+ if to.borrow() == my_id {
+ return local_handler(msg).await;
+ }
+ }
let addr = {
let status = self.status.borrow().clone();
match status.nodes.get(to.borrow()) {
@@ -96,7 +123,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
let msg = Arc::new(msg);
let mut resp_stream = to
.iter()
- .map(|to| self.call(to, msg.clone(), timeout))
+ .map(|to| self.call_arc(*to, msg.clone(), timeout))
.collect::<FuturesUnordered<_>>();
let mut results = vec![];
@@ -121,7 +148,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
.map(|to| {
let self2 = self.clone();
let msg = msg.clone();
- async move { self2.call(to, msg, timeout).await }
+ async move { self2.call_arc(to, msg, timeout).await }
})
.collect::<FuturesUnordered<_>>();
@@ -155,7 +182,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
resp_stream.collect::<Vec<_>>().await;
Ok(())
});
- self.clone().background.spawn(wait_finished_fut.map(|x| {
+ self.background.spawn(wait_finished_fut.map(|x| {
x.unwrap_or_else(|e| Err(Error::Message(format!("Await failed: {}", e))))
}));
}