diff options
Diffstat (limited to 'src/rpc_client.rs')
-rw-r--r-- | src/rpc_client.rs | 67 |
1 files changed, 48 insertions, 19 deletions
diff --git a/src/rpc_client.rs b/src/rpc_client.rs index e78079c2..c083fcfd 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -2,11 +2,13 @@ use std::borrow::Borrow; use std::marker::PhantomData; use std::net::SocketAddr; use std::pin::Pin; +use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; use arc_swap::ArcSwapOption; use bytes::IntoBuf; +use err_derive::Error; use futures::future::Future; use futures::stream::futures_unordered::FuturesUnordered; use futures::stream::StreamExt; @@ -25,6 +27,22 @@ use crate::tls_util; const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); +#[derive(Debug, Error)] +pub enum RPCError { + #[error(display = "Node is down: {:?}.", _0)] + NodeDown(UUID), + #[error(display = "Timeout: {}", _0)] + Timeout(#[error(source)] tokio::time::Elapsed), + #[error(display = "HTTP error: {}", _0)] + HTTP(#[error(source)] http::Error), + #[error(display = "Hyper error: {}", _0)] + Hyper(#[error(source)] hyper::Error), + #[error(display = "Messagepack encode error: {}", _0)] + RMPEncode(#[error(source)] rmp_serde::encode::Error), + #[error(display = "Messagepack decode error: {}", _0)] + RMPDecode(#[error(source)] rmp_serde::decode::Error), +} + #[derive(Copy, Clone)] pub struct RequestStrategy { pub rs_timeout: Duration, @@ -104,19 +122,34 @@ impl<M: RpcMessage + 'static> RpcClient<M> { return local_handler(msg).await; } } - let addr = { - let status = self.status.borrow().clone(); - match status.nodes.get(to.borrow()) { - Some(status) => status.addr, - None => { - return Err(Error::Message(format!( - "Peer ID not found: {:?}", - to.borrow() - ))) + let status = self.status.borrow().clone(); + let node_status = match status.nodes.get(&to) { + Some(node_status) => { + if node_status.is_up() { + node_status + } else { + return Err(Error::from(RPCError::NodeDown(to))); } } + None => { + return Err(Error::Message(format!( + "Peer ID not found: {:?}", + to.borrow() + ))) + } }; - self.rpc_addr_client.call(&addr, msg, timeout).await + match self + .rpc_addr_client + .call(&node_status.addr, msg, timeout) + .await + { + Err(rpc_error) => { + node_status.num_failures.fetch_add(1, Ordering::SeqCst); + // TODO: Save failure info somewhere + Err(Error::from(rpc_error)) + } + Ok(x) => x, + } } pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec<Result<M, Error>> { @@ -219,7 +252,7 @@ impl<M: RpcMessage> RpcAddrClient<M> { to_addr: &SocketAddr, msg: MB, timeout: Duration, - ) -> Result<M, Error> + ) -> Result<Result<M, Error>, RPCError> where MB: Borrow<M>, { @@ -276,7 +309,7 @@ impl RpcHttpClient { to_addr: &SocketAddr, msg: MB, timeout: Duration, - ) -> Result<M, Error> + ) -> Result<Result<M, Error>, RPCError> where MB: Borrow<M>, M: RpcMessage, @@ -318,13 +351,9 @@ impl RpcHttpClient { let status = resp.status(); let body = hyper::body::to_bytes(resp.into_body()).await?; - match rmp_serde::decode::from_read::<_, Result<M, String>>(body.into_buf()) { - Err(e) => Err(Error::RPCError( - format!("Invalid reply (deserialize error: {})", e), - status, - )), - Ok(Err(e)) => Err(Error::RPCError(e, status)), - Ok(Ok(x)) => Ok(x), + match rmp_serde::decode::from_read::<_, Result<M, String>>(body.into_buf())? { + Err(e) => Ok(Err(Error::RemoteError(e, status))), + Ok(x) => Ok(Ok(x)), } } } |