From f41583e1b731574b4bb13a20d4b3fd9fe3a899f5 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Sat, 18 Apr 2020 19:21:34 +0200 Subject: Massive RPC refactoring --- src/rpc_client.rs | 247 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 153 insertions(+), 94 deletions(-) (limited to 'src/rpc_client.rs') diff --git a/src/rpc_client.rs b/src/rpc_client.rs index f8da778c..6d26d86a 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -1,4 +1,5 @@ use std::borrow::Borrow; +use std::marker::PhantomData; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -9,110 +10,166 @@ use futures::stream::StreamExt; use futures_util::future::FutureExt; use hyper::client::{Client, HttpConnector}; use hyper::{Body, Method, Request, StatusCode}; +use tokio::sync::watch; +use crate::background::*; use crate::data::*; use crate::error::Error; -use crate::membership::System; -use crate::proto::Message; +use crate::membership::Status; +use crate::rpc_server::RpcMessage; use crate::server::*; use crate::tls_util; -pub async fn rpc_call_many( - sys: Arc, - to: &[UUID], - msg: Message, - timeout: Duration, -) -> Vec> { - let msg = Arc::new(msg); - let mut resp_stream = to - .iter() - .map(|to| rpc_call(sys.clone(), to, msg.clone(), timeout)) - .collect::>(); - - let mut results = vec![]; - while let Some(resp) = resp_stream.next().await { - results.push(resp); - } - results +pub struct RpcClient { + status: watch::Receiver>, + background: Arc, + + pub rpc_addr_client: RpcAddrClient, } -pub async fn rpc_try_call_many( - sys: Arc, - to: &[UUID], - msg: Message, - stop_after: usize, - timeout: Duration, -) -> Result, Error> { - let sys2 = sys.clone(); - let msg = Arc::new(msg); - let mut resp_stream = to - .to_vec() - .into_iter() - .map(move |to| rpc_call(sys2.clone(), to.clone(), msg.clone(), timeout)) - .collect::>(); - - let mut results = vec![]; - let mut errors = vec![]; - - while let Some(resp) = resp_stream.next().await { - match resp { - Ok(msg) => { - results.push(msg); - if results.len() >= stop_after { - break; +impl RpcClient { + pub fn new( + rac: RpcAddrClient, + background: Arc, + status: watch::Receiver>, + ) -> Arc { + Arc::new(Self { + rpc_addr_client: rac, + background, + status, + }) + } + + pub fn by_addr(&self) -> &RpcAddrClient { + &self.rpc_addr_client + } + + pub async fn call, N: Borrow>( + &self, + to: N, + msg: MB, + timeout: Duration, + ) -> Result { + let addr = { + let status = self.status.borrow().clone(); + match status.nodes.get(to.borrow()) { + Some(status) => status.addr.clone(), + None => { + return Err(Error::Message(format!( + "Peer ID not found: {:?}", + to.borrow() + ))) } } - Err(e) => { - errors.push(e); - } + }; + self.rpc_addr_client.call(&addr, msg, timeout).await + } + + pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec> { + let msg = Arc::new(msg); + let mut resp_stream = to + .iter() + .map(|to| self.call(to, msg.clone(), timeout)) + .collect::>(); + + let mut results = vec![]; + while let Some(resp) = resp_stream.next().await { + results.push(resp); } + results } - if results.len() >= stop_after { - // Continue requests in background - // TODO: make this optionnal (only usefull for write requests) - sys.background.spawn(async move { - resp_stream.collect::>().await; - Ok(()) - }); - - Ok(results) - } else { - let mut msg = "Too many failures:".to_string(); - for e in errors { - msg += &format!("\n{}", e); + pub async fn try_call_many( + self: &Arc, + to: &[UUID], + msg: M, + stop_after: usize, + timeout: Duration, + ) -> Result, Error> { + let msg = Arc::new(msg); + let mut resp_stream = to + .to_vec() + .into_iter() + .map(|to| { + let self2 = self.clone(); + let msg = msg.clone(); + async move { self2.call(to.clone(), msg, timeout).await } + }) + .collect::>(); + + let mut results = vec![]; + let mut errors = vec![]; + + while let Some(resp) = resp_stream.next().await { + match resp { + Ok(msg) => { + results.push(msg); + if results.len() >= stop_after { + break; + } + } + Err(e) => { + errors.push(e); + } + } + } + + if results.len() >= stop_after { + // Continue requests in background + // TODO: make this optionnal (only usefull for write requests) + self.clone().background.spawn(async move { + resp_stream.collect::>().await; + Ok(()) + }); + + Ok(results) + } else { + let mut msg = "Too many failures:".to_string(); + for e in errors { + msg += &format!("\n{}", e); + } + Err(Error::Message(msg)) } - Err(Error::Message(msg)) } } -pub async fn rpc_call, N: Borrow>( - sys: Arc, - to: N, - msg: M, - timeout: Duration, -) -> Result { - let addr = { - let status = sys.status.borrow().clone(); - match status.nodes.get(to.borrow()) { - Some(status) => status.addr.clone(), - None => { - return Err(Error::Message(format!( - "Peer ID not found: {:?}", - to.borrow() - ))) - } +pub struct RpcAddrClient { + phantom: PhantomData, + + pub http_client: Arc, + pub path: String, +} + +impl RpcAddrClient { + pub fn new(http_client: Arc, path: String) -> Self { + Self { + phantom: PhantomData::default(), + http_client: http_client, + path, } - }; - sys.rpc_client.call(&addr, msg, timeout).await + } + + pub async fn call( + &self, + to_addr: &SocketAddr, + msg: MB, + timeout: Duration, + ) -> Result + where + MB: Borrow, + { + self.http_client + .call(&self.path, to_addr, msg, timeout) + .await + } } -pub enum RpcClient { +pub enum RpcHttpClient { HTTP(Client), HTTPS(Client, hyper::Body>), } -impl RpcClient { +impl RpcHttpClient { pub fn new(tls_config: &Option) -> Result { if let Some(cf) = tls_config { let ca_certs = tls_util::load_certs(&cf.ca_cert)?; @@ -130,21 +187,26 @@ impl RpcClient { let connector = tls_util::HttpsConnectorFixedDnsname::::new(config, "garage"); - Ok(RpcClient::HTTPS(Client::builder().build(connector))) + Ok(RpcHttpClient::HTTPS(Client::builder().build(connector))) } else { - Ok(RpcClient::HTTP(Client::new())) + Ok(RpcHttpClient::HTTP(Client::new())) } } - pub async fn call>( + async fn call( &self, + path: &str, to_addr: &SocketAddr, - msg: M, + msg: MB, timeout: Duration, - ) -> Result { + ) -> Result + where + MB: Borrow, + M: RpcMessage, + { let uri = match self { - RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr), - RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr), + RpcHttpClient::HTTP(_) => format!("http://{}/{}", to_addr, path), + RpcHttpClient::HTTPS(_) => format!("https://{}/{}", to_addr, path), }; let req = Request::builder() @@ -153,8 +215,8 @@ impl RpcClient { .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; let resp_fut = match self { - RpcClient::HTTP(client) => client.request(req).fuse(), - RpcClient::HTTPS(client) => client.request(req).fuse(), + RpcHttpClient::HTTP(client) => client.request(req).fuse(), + RpcHttpClient::HTTPS(client) => client.request(req).fuse(), }; let resp = tokio::time::timeout(timeout, resp_fut) .await? @@ -168,11 +230,8 @@ impl RpcClient { if resp.status() == StatusCode::OK { let body = hyper::body::to_bytes(resp.into_body()).await?; - let msg = rmp_serde::decode::from_read::<_, Message>(body.into_buf())?; - match msg { - Message::Error(e) => Err(Error::RPCError(e)), - x => Ok(x), - } + let msg = rmp_serde::decode::from_read::<_, Result>(body.into_buf())?; + msg.map_err(Error::RPCError) } else { Err(Error::RPCError(format!("Status code {}", resp.status()))) } -- cgit v1.2.3