aboutsummaryrefslogtreecommitdiff
path: root/src/rpc_client.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/rpc_client.rs')
-rw-r--r--src/rpc_client.rs247
1 files changed, 153 insertions, 94 deletions
diff --git a/src/rpc_client.rs b/src/rpc_client.rs
index f8da778c..6d26d86a 100644
--- a/src/rpc_client.rs
+++ b/src/rpc_client.rs
@@ -1,4 +1,5 @@
use std::borrow::Borrow;
+use std::marker::PhantomData;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
@@ -9,110 +10,166 @@ use futures::stream::StreamExt;
use futures_util::future::FutureExt;
use hyper::client::{Client, HttpConnector};
use hyper::{Body, Method, Request, StatusCode};
+use tokio::sync::watch;
+use crate::background::*;
use crate::data::*;
use crate::error::Error;
-use crate::membership::System;
-use crate::proto::Message;
+use crate::membership::Status;
+use crate::rpc_server::RpcMessage;
use crate::server::*;
use crate::tls_util;
-pub async fn rpc_call_many(
- sys: Arc<System>,
- to: &[UUID],
- msg: Message,
- timeout: Duration,
-) -> Vec<Result<Message, Error>> {
- let msg = Arc::new(msg);
- let mut resp_stream = to
- .iter()
- .map(|to| rpc_call(sys.clone(), to, msg.clone(), timeout))
- .collect::<FuturesUnordered<_>>();
-
- let mut results = vec![];
- while let Some(resp) = resp_stream.next().await {
- results.push(resp);
- }
- results
+pub struct RpcClient<M: RpcMessage> {
+ status: watch::Receiver<Arc<Status>>,
+ background: Arc<BackgroundRunner>,
+
+ pub rpc_addr_client: RpcAddrClient<M>,
}
-pub async fn rpc_try_call_many(
- sys: Arc<System>,
- to: &[UUID],
- msg: Message,
- stop_after: usize,
- timeout: Duration,
-) -> Result<Vec<Message>, Error> {
- let sys2 = sys.clone();
- let msg = Arc::new(msg);
- let mut resp_stream = to
- .to_vec()
- .into_iter()
- .map(move |to| rpc_call(sys2.clone(), to.clone(), msg.clone(), timeout))
- .collect::<FuturesUnordered<_>>();
-
- let mut results = vec![];
- let mut errors = vec![];
-
- while let Some(resp) = resp_stream.next().await {
- match resp {
- Ok(msg) => {
- results.push(msg);
- if results.len() >= stop_after {
- break;
+impl<M: RpcMessage + 'static> RpcClient<M> {
+ pub fn new(
+ rac: RpcAddrClient<M>,
+ background: Arc<BackgroundRunner>,
+ status: watch::Receiver<Arc<Status>>,
+ ) -> Arc<Self> {
+ Arc::new(Self {
+ rpc_addr_client: rac,
+ background,
+ status,
+ })
+ }
+
+ pub fn by_addr(&self) -> &RpcAddrClient<M> {
+ &self.rpc_addr_client
+ }
+
+ pub async fn call<MB: Borrow<M>, N: Borrow<UUID>>(
+ &self,
+ to: N,
+ msg: MB,
+ timeout: Duration,
+ ) -> Result<M, Error> {
+ let addr = {
+ let status = self.status.borrow().clone();
+ match status.nodes.get(to.borrow()) {
+ Some(status) => status.addr.clone(),
+ None => {
+ return Err(Error::Message(format!(
+ "Peer ID not found: {:?}",
+ to.borrow()
+ )))
}
}
- Err(e) => {
- errors.push(e);
- }
+ };
+ self.rpc_addr_client.call(&addr, msg, timeout).await
+ }
+
+ pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec<Result<M, Error>> {
+ let msg = Arc::new(msg);
+ let mut resp_stream = to
+ .iter()
+ .map(|to| self.call(to, msg.clone(), timeout))
+ .collect::<FuturesUnordered<_>>();
+
+ let mut results = vec![];
+ while let Some(resp) = resp_stream.next().await {
+ results.push(resp);
}
+ results
}
- if results.len() >= stop_after {
- // Continue requests in background
- // TODO: make this optionnal (only usefull for write requests)
- sys.background.spawn(async move {
- resp_stream.collect::<Vec<_>>().await;
- Ok(())
- });
-
- Ok(results)
- } else {
- let mut msg = "Too many failures:".to_string();
- for e in errors {
- msg += &format!("\n{}", e);
+ pub async fn try_call_many(
+ self: &Arc<Self>,
+ to: &[UUID],
+ msg: M,
+ stop_after: usize,
+ timeout: Duration,
+ ) -> Result<Vec<M>, Error> {
+ let msg = Arc::new(msg);
+ let mut resp_stream = to
+ .to_vec()
+ .into_iter()
+ .map(|to| {
+ let self2 = self.clone();
+ let msg = msg.clone();
+ async move { self2.call(to.clone(), msg, timeout).await }
+ })
+ .collect::<FuturesUnordered<_>>();
+
+ let mut results = vec![];
+ let mut errors = vec![];
+
+ while let Some(resp) = resp_stream.next().await {
+ match resp {
+ Ok(msg) => {
+ results.push(msg);
+ if results.len() >= stop_after {
+ break;
+ }
+ }
+ Err(e) => {
+ errors.push(e);
+ }
+ }
+ }
+
+ if results.len() >= stop_after {
+ // Continue requests in background
+ // TODO: make this optionnal (only usefull for write requests)
+ self.clone().background.spawn(async move {
+ resp_stream.collect::<Vec<_>>().await;
+ Ok(())
+ });
+
+ Ok(results)
+ } else {
+ let mut msg = "Too many failures:".to_string();
+ for e in errors {
+ msg += &format!("\n{}", e);
+ }
+ Err(Error::Message(msg))
}
- Err(Error::Message(msg))
}
}
-pub async fn rpc_call<M: Borrow<Message>, N: Borrow<UUID>>(
- sys: Arc<System>,
- to: N,
- msg: M,
- timeout: Duration,
-) -> Result<Message, Error> {
- let addr = {
- let status = sys.status.borrow().clone();
- match status.nodes.get(to.borrow()) {
- Some(status) => status.addr.clone(),
- None => {
- return Err(Error::Message(format!(
- "Peer ID not found: {:?}",
- to.borrow()
- )))
- }
+pub struct RpcAddrClient<M: RpcMessage> {
+ phantom: PhantomData<M>,
+
+ pub http_client: Arc<RpcHttpClient>,
+ pub path: String,
+}
+
+impl<M: RpcMessage> RpcAddrClient<M> {
+ pub fn new(http_client: Arc<RpcHttpClient>, path: String) -> Self {
+ Self {
+ phantom: PhantomData::default(),
+ http_client: http_client,
+ path,
}
- };
- sys.rpc_client.call(&addr, msg, timeout).await
+ }
+
+ pub async fn call<MB>(
+ &self,
+ to_addr: &SocketAddr,
+ msg: MB,
+ timeout: Duration,
+ ) -> Result<M, Error>
+ where
+ MB: Borrow<M>,
+ {
+ self.http_client
+ .call(&self.path, to_addr, msg, timeout)
+ .await
+ }
}
-pub enum RpcClient {
+pub enum RpcHttpClient {
HTTP(Client<HttpConnector, hyper::Body>),
HTTPS(Client<tls_util::HttpsConnectorFixedDnsname<HttpConnector>, hyper::Body>),
}
-impl RpcClient {
+impl RpcHttpClient {
pub fn new(tls_config: &Option<TlsConfig>) -> Result<Self, Error> {
if let Some(cf) = tls_config {
let ca_certs = tls_util::load_certs(&cf.ca_cert)?;
@@ -130,21 +187,26 @@ impl RpcClient {
let connector =
tls_util::HttpsConnectorFixedDnsname::<HttpConnector>::new(config, "garage");
- Ok(RpcClient::HTTPS(Client::builder().build(connector)))
+ Ok(RpcHttpClient::HTTPS(Client::builder().build(connector)))
} else {
- Ok(RpcClient::HTTP(Client::new()))
+ Ok(RpcHttpClient::HTTP(Client::new()))
}
}
- pub async fn call<M: Borrow<Message>>(
+ async fn call<M, MB>(
&self,
+ path: &str,
to_addr: &SocketAddr,
- msg: M,
+ msg: MB,
timeout: Duration,
- ) -> Result<Message, Error> {
+ ) -> Result<M, Error>
+ where
+ MB: Borrow<M>,
+ M: RpcMessage,
+ {
let uri = match self {
- RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr),
- RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr),
+ RpcHttpClient::HTTP(_) => format!("http://{}/{}", to_addr, path),
+ RpcHttpClient::HTTPS(_) => format!("https://{}/{}", to_addr, path),
};
let req = Request::builder()
@@ -153,8 +215,8 @@ impl RpcClient {
.body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?;
let resp_fut = match self {
- RpcClient::HTTP(client) => client.request(req).fuse(),
- RpcClient::HTTPS(client) => client.request(req).fuse(),
+ RpcHttpClient::HTTP(client) => client.request(req).fuse(),
+ RpcHttpClient::HTTPS(client) => client.request(req).fuse(),
};
let resp = tokio::time::timeout(timeout, resp_fut)
.await?
@@ -168,11 +230,8 @@ impl RpcClient {
if resp.status() == StatusCode::OK {
let body = hyper::body::to_bytes(resp.into_body()).await?;
- let msg = rmp_serde::decode::from_read::<_, Message>(body.into_buf())?;
- match msg {
- Message::Error(e) => Err(Error::RPCError(e)),
- x => Ok(x),
- }
+ let msg = rmp_serde::decode::from_read::<_, Result<M, String>>(body.into_buf())?;
+ msg.map_err(Error::RPCError)
} else {
Err(Error::RPCError(format!("Status code {}", resp.status())))
}