diff options
Diffstat (limited to 'src/rpc_client.rs')
-rw-r--r-- | src/rpc_client.rs | 247 |
1 files changed, 153 insertions, 94 deletions
diff --git a/src/rpc_client.rs b/src/rpc_client.rs index f8da778c..6d26d86a 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -1,4 +1,5 @@ use std::borrow::Borrow; +use std::marker::PhantomData; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -9,110 +10,166 @@ use futures::stream::StreamExt; use futures_util::future::FutureExt; use hyper::client::{Client, HttpConnector}; use hyper::{Body, Method, Request, StatusCode}; +use tokio::sync::watch; +use crate::background::*; use crate::data::*; use crate::error::Error; -use crate::membership::System; -use crate::proto::Message; +use crate::membership::Status; +use crate::rpc_server::RpcMessage; use crate::server::*; use crate::tls_util; -pub async fn rpc_call_many( - sys: Arc<System>, - to: &[UUID], - msg: Message, - timeout: Duration, -) -> Vec<Result<Message, Error>> { - let msg = Arc::new(msg); - let mut resp_stream = to - .iter() - .map(|to| rpc_call(sys.clone(), to, msg.clone(), timeout)) - .collect::<FuturesUnordered<_>>(); - - let mut results = vec![]; - while let Some(resp) = resp_stream.next().await { - results.push(resp); - } - results +pub struct RpcClient<M: RpcMessage> { + status: watch::Receiver<Arc<Status>>, + background: Arc<BackgroundRunner>, + + pub rpc_addr_client: RpcAddrClient<M>, } -pub async fn rpc_try_call_many( - sys: Arc<System>, - to: &[UUID], - msg: Message, - stop_after: usize, - timeout: Duration, -) -> Result<Vec<Message>, Error> { - let sys2 = sys.clone(); - let msg = Arc::new(msg); - let mut resp_stream = to - .to_vec() - .into_iter() - .map(move |to| rpc_call(sys2.clone(), to.clone(), msg.clone(), timeout)) - .collect::<FuturesUnordered<_>>(); - - let mut results = vec![]; - let mut errors = vec![]; - - while let Some(resp) = resp_stream.next().await { - match resp { - Ok(msg) => { - results.push(msg); - if results.len() >= stop_after { - break; +impl<M: RpcMessage + 'static> RpcClient<M> { + pub fn new( + rac: RpcAddrClient<M>, + background: Arc<BackgroundRunner>, + status: watch::Receiver<Arc<Status>>, + ) -> Arc<Self> { + Arc::new(Self { + rpc_addr_client: rac, + background, + status, + }) + } + + pub fn by_addr(&self) -> &RpcAddrClient<M> { + &self.rpc_addr_client + } + + pub async fn call<MB: Borrow<M>, N: Borrow<UUID>>( + &self, + to: N, + msg: MB, + timeout: Duration, + ) -> Result<M, Error> { + let addr = { + let status = self.status.borrow().clone(); + match status.nodes.get(to.borrow()) { + Some(status) => status.addr.clone(), + None => { + return Err(Error::Message(format!( + "Peer ID not found: {:?}", + to.borrow() + ))) } } - Err(e) => { - errors.push(e); - } + }; + self.rpc_addr_client.call(&addr, msg, timeout).await + } + + pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec<Result<M, Error>> { + let msg = Arc::new(msg); + let mut resp_stream = to + .iter() + .map(|to| self.call(to, msg.clone(), timeout)) + .collect::<FuturesUnordered<_>>(); + + let mut results = vec![]; + while let Some(resp) = resp_stream.next().await { + results.push(resp); } + results } - if results.len() >= stop_after { - // Continue requests in background - // TODO: make this optionnal (only usefull for write requests) - sys.background.spawn(async move { - resp_stream.collect::<Vec<_>>().await; - Ok(()) - }); - - Ok(results) - } else { - let mut msg = "Too many failures:".to_string(); - for e in errors { - msg += &format!("\n{}", e); + pub async fn try_call_many( + self: &Arc<Self>, + to: &[UUID], + msg: M, + stop_after: usize, + timeout: Duration, + ) -> Result<Vec<M>, Error> { + let msg = Arc::new(msg); + let mut resp_stream = to + .to_vec() + .into_iter() + .map(|to| { + let self2 = self.clone(); + let msg = msg.clone(); + async move { self2.call(to.clone(), msg, timeout).await } + }) + .collect::<FuturesUnordered<_>>(); + + let mut results = vec![]; + let mut errors = vec![]; + + while let Some(resp) = resp_stream.next().await { + match resp { + Ok(msg) => { + results.push(msg); + if results.len() >= stop_after { + break; + } + } + Err(e) => { + errors.push(e); + } + } + } + + if results.len() >= stop_after { + // Continue requests in background + // TODO: make this optionnal (only usefull for write requests) + self.clone().background.spawn(async move { + resp_stream.collect::<Vec<_>>().await; + Ok(()) + }); + + Ok(results) + } else { + let mut msg = "Too many failures:".to_string(); + for e in errors { + msg += &format!("\n{}", e); + } + Err(Error::Message(msg)) } - Err(Error::Message(msg)) } } -pub async fn rpc_call<M: Borrow<Message>, N: Borrow<UUID>>( - sys: Arc<System>, - to: N, - msg: M, - timeout: Duration, -) -> Result<Message, Error> { - let addr = { - let status = sys.status.borrow().clone(); - match status.nodes.get(to.borrow()) { - Some(status) => status.addr.clone(), - None => { - return Err(Error::Message(format!( - "Peer ID not found: {:?}", - to.borrow() - ))) - } +pub struct RpcAddrClient<M: RpcMessage> { + phantom: PhantomData<M>, + + pub http_client: Arc<RpcHttpClient>, + pub path: String, +} + +impl<M: RpcMessage> RpcAddrClient<M> { + pub fn new(http_client: Arc<RpcHttpClient>, path: String) -> Self { + Self { + phantom: PhantomData::default(), + http_client: http_client, + path, } - }; - sys.rpc_client.call(&addr, msg, timeout).await + } + + pub async fn call<MB>( + &self, + to_addr: &SocketAddr, + msg: MB, + timeout: Duration, + ) -> Result<M, Error> + where + MB: Borrow<M>, + { + self.http_client + .call(&self.path, to_addr, msg, timeout) + .await + } } -pub enum RpcClient { +pub enum RpcHttpClient { HTTP(Client<HttpConnector, hyper::Body>), HTTPS(Client<tls_util::HttpsConnectorFixedDnsname<HttpConnector>, hyper::Body>), } -impl RpcClient { +impl RpcHttpClient { pub fn new(tls_config: &Option<TlsConfig>) -> Result<Self, Error> { if let Some(cf) = tls_config { let ca_certs = tls_util::load_certs(&cf.ca_cert)?; @@ -130,21 +187,26 @@ impl RpcClient { let connector = tls_util::HttpsConnectorFixedDnsname::<HttpConnector>::new(config, "garage"); - Ok(RpcClient::HTTPS(Client::builder().build(connector))) + Ok(RpcHttpClient::HTTPS(Client::builder().build(connector))) } else { - Ok(RpcClient::HTTP(Client::new())) + Ok(RpcHttpClient::HTTP(Client::new())) } } - pub async fn call<M: Borrow<Message>>( + async fn call<M, MB>( &self, + path: &str, to_addr: &SocketAddr, - msg: M, + msg: MB, timeout: Duration, - ) -> Result<Message, Error> { + ) -> Result<M, Error> + where + MB: Borrow<M>, + M: RpcMessage, + { let uri = match self { - RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr), - RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr), + RpcHttpClient::HTTP(_) => format!("http://{}/{}", to_addr, path), + RpcHttpClient::HTTPS(_) => format!("https://{}/{}", to_addr, path), }; let req = Request::builder() @@ -153,8 +215,8 @@ impl RpcClient { .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; let resp_fut = match self { - RpcClient::HTTP(client) => client.request(req).fuse(), - RpcClient::HTTPS(client) => client.request(req).fuse(), + RpcHttpClient::HTTP(client) => client.request(req).fuse(), + RpcHttpClient::HTTPS(client) => client.request(req).fuse(), }; let resp = tokio::time::timeout(timeout, resp_fut) .await? @@ -168,11 +230,8 @@ impl RpcClient { if resp.status() == StatusCode::OK { let body = hyper::body::to_bytes(resp.into_body()).await?; - let msg = rmp_serde::decode::from_read::<_, Message>(body.into_buf())?; - match msg { - Message::Error(e) => Err(Error::RPCError(e)), - x => Ok(x), - } + let msg = rmp_serde::decode::from_read::<_, Result<M, String>>(body.into_buf())?; + msg.map_err(Error::RPCError) } else { Err(Error::RPCError(format!("Status code {}", resp.status()))) } |