Diffstat (limited to 'src/rpc')
-rw-r--r-- | src/rpc/membership.rs | 692
-rw-r--r-- | src/rpc/mod.rs        |   4
-rw-r--r-- | src/rpc/rpc_client.rs | 360
-rw-r--r-- | src/rpc/rpc_server.rs | 219
-rw-r--r-- | src/rpc/tls_util.rs   | 139
5 files changed, 1414 insertions, 0 deletions
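The bulk of this change is membership.rs (node status gossip and the consistent-hash ring) plus rpc_client.rs and rpc_server.rs (a typed RPC layer over HTTP or mutually-authenticated TLS). As orientation before the diff itself, the sketch below shows how a caller could plug a new endpoint into these types. It is illustrative only and not part of the commit: ExampleMessage, the "_example" path, and the register_example_handler / call_example_quorum helpers are hypothetical, and a configured system: Arc<System> and rpc_server: &mut RpcServer are assumed to already exist as they do in Garage's server setup code.

use std::sync::Arc;
use std::time::Duration;

use serde::{Deserialize, Serialize};

use crate::data::UUID;
use crate::error::Error;
use crate::rpc::membership::System;
use crate::rpc::rpc_client::RequestStrategy;
use crate::rpc::rpc_server::{RpcMessage, RpcServer};

// Hypothetical message type for a new endpoint; any Serialize + Deserialize
// type becomes usable over this RPC layer by implementing the marker trait.
#[derive(Debug, Serialize, Deserialize)]
enum ExampleMessage {
	Ok,
	Echo(String),
}
impl RpcMessage for ExampleMessage {}

const EXAMPLE_RPC_PATH: &str = "_example"; // hypothetical endpoint name

// Server side: register a handler for the endpoint. The handler receives the
// decoded message and the peer address, and replies with the same message type.
fn register_example_handler(rpc_server: &mut RpcServer) {
	rpc_server.add_handler::<ExampleMessage, _, _>(EXAMPLE_RPC_PATH.to_string(), |msg, _addr| async move {
		let resp: Result<ExampleMessage, Error> = match msg {
			ExampleMessage::Echo(s) => Ok(ExampleMessage::Echo(s)),
			ExampleMessage::Ok => Ok(ExampleMessage::Ok),
		};
		resp
	});
}

// Client side: obtain a typed client for the endpoint from System and issue a
// quorum call; try_call_many returns as soon as rs_quorum replies have arrived.
async fn call_example_quorum(system: &Arc<System>, peers: &[UUID]) -> Result<(), Error> {
	let client = system.rpc_client::<ExampleMessage>(EXAMPLE_RPC_PATH);
	let _replies = client
		.try_call_many(
			peers,
			ExampleMessage::Echo("hello".to_string()),
			RequestStrategy::with_quorum(2).with_timeout(Duration::from_secs(10)),
		)
		.await?;
	Ok(())
}

Once the quorum is reached, the remaining requests keep running and are tracked by the background runner so they still complete; with interrupt_after_quorum(true) they are dropped instead.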
diff --git a/src/rpc/membership.rs b/src/rpc/membership.rs new file mode 100644 index 00000000..e0509536 --- /dev/null +++ b/src/rpc/membership.rs @@ -0,0 +1,692 @@ +use std::collections::HashMap; +use std::hash::Hash as StdHash; +use std::hash::Hasher; +use std::io::{Read, Write}; +use std::net::{IpAddr, SocketAddr}; +use std::path::PathBuf; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use futures::future::join_all; +use futures::select; +use futures_util::future::*; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use tokio::prelude::*; +use tokio::sync::watch; +use tokio::sync::Mutex; + +use crate::background::BackgroundRunner; +use crate::data::*; +use crate::error::Error; + +use crate::rpc::rpc_client::*; +use crate::rpc::rpc_server::*; + +const PING_INTERVAL: Duration = Duration::from_secs(10); +const PING_TIMEOUT: Duration = Duration::from_secs(2); +const MAX_FAILURES_BEFORE_CONSIDERED_DOWN: usize = 5; + +pub const MEMBERSHIP_RPC_PATH: &str = "_membership"; + +#[derive(Debug, Serialize, Deserialize)] +pub enum Message { + Ok, + Ping(PingMessage), + PullStatus, + PullConfig, + AdvertiseNodesUp(Vec<AdvertisedNode>), + AdvertiseConfig(NetworkConfig), +} + +impl RpcMessage for Message {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PingMessage { + pub id: UUID, + pub rpc_port: u16, + + pub status_hash: Hash, + pub config_version: u64, + + pub state_info: StateInfo, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AdvertisedNode { + pub id: UUID, + pub addr: SocketAddr, + + pub is_up: bool, + pub last_seen: u64, + + pub state_info: StateInfo, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NetworkConfig { + pub members: HashMap<UUID, NetworkConfigEntry>, + pub version: u64, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NetworkConfigEntry { + pub datacenter: String, + pub n_tokens: u32, + pub tag: String, +} + +pub struct System { + pub id: UUID, + pub data_dir: PathBuf, + pub rpc_local_port: u16, + + pub state_info: StateInfo, + + pub rpc_http_client: Arc<RpcHttpClient>, + rpc_client: Arc<RpcClient<Message>>, + + pub status: watch::Receiver<Arc<Status>>, + pub ring: watch::Receiver<Arc<Ring>>, + + update_lock: Mutex<(watch::Sender<Arc<Status>>, watch::Sender<Arc<Ring>>)>, + + pub background: Arc<BackgroundRunner>, +} + +#[derive(Debug, Clone)] +pub struct Status { + pub nodes: HashMap<UUID, Arc<StatusEntry>>, + pub hash: Hash, +} + +#[derive(Debug)] +pub struct StatusEntry { + pub addr: SocketAddr, + pub last_seen: u64, + pub num_failures: AtomicUsize, + pub state_info: StateInfo, +} + +impl StatusEntry { + pub fn is_up(&self) -> bool { + self.num_failures.load(Ordering::SeqCst) < MAX_FAILURES_BEFORE_CONSIDERED_DOWN + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateInfo { + pub hostname: String, +} + +#[derive(Clone)] +pub struct Ring { + pub config: NetworkConfig, + pub ring: Vec<RingEntry>, + pub n_datacenters: usize, +} + +#[derive(Clone, Debug)] +pub struct RingEntry { + pub location: Hash, + pub node: UUID, + pub datacenter: u64, +} + +impl Status { + fn handle_ping(&mut self, ip: IpAddr, info: &PingMessage) -> bool { + let addr = SocketAddr::new(ip, info.rpc_port); + let old_status = self.nodes.insert( + info.id, + Arc::new(StatusEntry { + addr, + last_seen: now_msec(), + num_failures: AtomicUsize::from(0), + state_info: info.state_info.clone(), + }), + ); + match old_status { + None => { + info!("Newly pingable node: 
{}", hex::encode(&info.id)); + true + } + Some(x) => x.addr != addr, + } + } + + fn recalculate_hash(&mut self) { + let mut nodes = self.nodes.iter().collect::<Vec<_>>(); + nodes.sort_unstable_by_key(|(id, _status)| *id); + + let mut hasher = Sha256::new(); + debug!("Current set of pingable nodes: --"); + for (id, status) in nodes { + debug!("{} {}", hex::encode(&id), status.addr); + hasher.input(format!("{} {}\n", hex::encode(&id), status.addr)); + } + debug!("END --"); + self.hash + .as_slice_mut() + .copy_from_slice(&hasher.result()[..]); + } +} + +impl Ring { + fn rebuild_ring(&mut self) { + let mut new_ring = vec![]; + let mut datacenters = vec![]; + + for (id, config) in self.config.members.iter() { + let mut dc_hasher = std::collections::hash_map::DefaultHasher::new(); + config.datacenter.hash(&mut dc_hasher); + let datacenter = dc_hasher.finish(); + + if !datacenters.contains(&datacenter) { + datacenters.push(datacenter); + } + + for i in 0..config.n_tokens { + let location = hash(format!("{} {}", hex::encode(&id), i).as_bytes()); + + new_ring.push(RingEntry { + location: location.into(), + node: *id, + datacenter, + }) + } + } + + new_ring.sort_unstable_by(|x, y| x.location.cmp(&y.location)); + self.ring = new_ring; + self.n_datacenters = datacenters.len(); + + // eprintln!("RING: --"); + // for e in self.ring.iter() { + // eprintln!("{:?}", e); + // } + // eprintln!("END --"); + } + + pub fn walk_ring(&self, from: &Hash, n: usize) -> Vec<UUID> { + if n >= self.config.members.len() { + return self.config.members.keys().cloned().collect::<Vec<_>>(); + } + + let start = match self.ring.binary_search_by(|x| x.location.cmp(from)) { + Ok(i) => i, + Err(i) => { + if i == 0 { + self.ring.len() - 1 + } else { + i - 1 + } + } + }; + + self.walk_ring_from_pos(start, n) + } + + fn walk_ring_from_pos(&self, start: usize, n: usize) -> Vec<UUID> { + if n >= self.config.members.len() { + return self.config.members.keys().cloned().collect::<Vec<_>>(); + } + + let mut ret = vec![]; + let mut datacenters = vec![]; + + let mut delta = 0; + while ret.len() < n { + let i = (start + delta) % self.ring.len(); + delta += 1; + + if !datacenters.contains(&self.ring[i].datacenter) { + ret.push(self.ring[i].node); + datacenters.push(self.ring[i].datacenter); + } else if datacenters.len() == self.n_datacenters && !ret.contains(&self.ring[i].node) { + ret.push(self.ring[i].node); + } + } + + ret + } +} + +fn gen_node_id(metadata_dir: &PathBuf) -> Result<UUID, Error> { + let mut id_file = metadata_dir.clone(); + id_file.push("node_id"); + if id_file.as_path().exists() { + let mut f = std::fs::File::open(id_file.as_path())?; + let mut d = vec![]; + f.read_to_end(&mut d)?; + if d.len() != 32 { + return Err(Error::Message(format!("Corrupt node_id file"))); + } + + let mut id = [0u8; 32]; + id.copy_from_slice(&d[..]); + Ok(id.into()) + } else { + let id = gen_uuid(); + + let mut f = std::fs::File::create(id_file.as_path())?; + f.write_all(id.as_slice())?; + Ok(id) + } +} + +fn read_network_config(metadata_dir: &PathBuf) -> Result<NetworkConfig, Error> { + let mut path = metadata_dir.clone(); + path.push("network_config"); + + let mut file = std::fs::OpenOptions::new() + .read(true) + .open(path.as_path())?; + + let mut net_config_bytes = vec![]; + file.read_to_end(&mut net_config_bytes)?; + + let net_config = rmp_serde::decode::from_read_ref(&net_config_bytes[..]) + .expect("Unable to parse network configuration file (has version format changed?)."); + + Ok(net_config) +} + +impl System { + pub fn new( + data_dir: 
PathBuf, + rpc_http_client: Arc<RpcHttpClient>, + background: Arc<BackgroundRunner>, + rpc_server: &mut RpcServer, + ) -> Arc<Self> { + let id = gen_node_id(&data_dir).expect("Unable to read or generate node ID"); + info!("Node ID: {}", hex::encode(&id)); + + let net_config = match read_network_config(&data_dir) { + Ok(x) => x, + Err(e) => { + info!( + "No valid previous network configuration stored ({}), starting fresh.", + e + ); + NetworkConfig { + members: HashMap::new(), + version: 0, + } + } + }; + let mut status = Status { + nodes: HashMap::new(), + hash: Hash::default(), + }; + status.recalculate_hash(); + let (update_status, status) = watch::channel(Arc::new(status)); + + let state_info = StateInfo { + hostname: gethostname::gethostname() + .into_string() + .unwrap_or("<invalid utf-8>".to_string()), + }; + + let mut ring = Ring { + config: net_config, + ring: Vec::new(), + n_datacenters: 0, + }; + ring.rebuild_ring(); + let (update_ring, ring) = watch::channel(Arc::new(ring)); + + let rpc_path = MEMBERSHIP_RPC_PATH.to_string(); + let rpc_client = RpcClient::new( + RpcAddrClient::<Message>::new(rpc_http_client.clone(), rpc_path.clone()), + background.clone(), + status.clone(), + ); + + let sys = Arc::new(System { + id, + data_dir, + rpc_local_port: rpc_server.bind_addr.port(), + state_info, + rpc_http_client, + rpc_client, + status, + ring, + update_lock: Mutex::new((update_status, update_ring)), + background, + }); + sys.clone().register_handler(rpc_server, rpc_path); + sys + } + + fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) { + rpc_server.add_handler::<Message, _, _>(path, move |msg, addr| { + let self2 = self.clone(); + async move { + match msg { + Message::Ping(ping) => self2.handle_ping(&addr, &ping).await, + + Message::PullStatus => self2.handle_pull_status(), + Message::PullConfig => self2.handle_pull_config(), + Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await, + Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await, + + _ => Err(Error::BadRequest(format!("Unexpected RPC message"))), + } + } + }); + } + + pub fn rpc_client<M: RpcMessage + 'static>(self: &Arc<Self>, path: &str) -> Arc<RpcClient<M>> { + RpcClient::new( + RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()), + self.background.clone(), + self.status.clone(), + ) + } + + async fn save_network_config(self: Arc<Self>) -> Result<(), Error> { + let mut path = self.data_dir.clone(); + path.push("network_config"); + + let ring = self.ring.borrow().clone(); + let data = rmp_to_vec_all_named(&ring.config)?; + + let mut f = tokio::fs::File::create(path.as_path()).await?; + f.write_all(&data[..]).await?; + Ok(()) + } + + pub fn make_ping(&self) -> Message { + let status = self.status.borrow().clone(); + let ring = self.ring.borrow().clone(); + Message::Ping(PingMessage { + id: self.id, + rpc_port: self.rpc_local_port, + status_hash: status.hash, + config_version: ring.config.version, + state_info: self.state_info.clone(), + }) + } + + pub async fn broadcast(self: Arc<Self>, msg: Message, timeout: Duration) { + let status = self.status.borrow().clone(); + let to = status + .nodes + .keys() + .filter(|x| **x != self.id) + .cloned() + .collect::<Vec<_>>(); + self.rpc_client.call_many(&to[..], msg, timeout).await; + } + + pub async fn bootstrap(self: Arc<Self>, peers: &[SocketAddr]) { + let bootstrap_peers = peers.iter().map(|ip| (*ip, None)).collect::<Vec<_>>(); + self.clone().ping_nodes(bootstrap_peers).await; + + self.clone() 
+ .background + .spawn_worker(format!("ping loop"), |stop_signal| { + self.ping_loop(stop_signal).map(Ok) + }) + .await; + } + + async fn ping_nodes(self: Arc<Self>, peers: Vec<(SocketAddr, Option<UUID>)>) { + let ping_msg = self.make_ping(); + let ping_resps = join_all(peers.iter().map(|(addr, id_option)| { + let sys = self.clone(); + let ping_msg_ref = &ping_msg; + async move { + ( + id_option, + addr, + sys.rpc_client + .by_addr() + .call(&addr, ping_msg_ref, PING_TIMEOUT) + .await, + ) + } + })) + .await; + + let update_locked = self.update_lock.lock().await; + let mut status: Status = self.status.borrow().as_ref().clone(); + let ring = self.ring.borrow().clone(); + + let mut has_changes = false; + let mut to_advertise = vec![]; + + for (id_option, addr, ping_resp) in ping_resps { + if let Ok(Ok(Message::Ping(info))) = ping_resp { + let is_new = status.handle_ping(addr.ip(), &info); + if is_new { + has_changes = true; + to_advertise.push(AdvertisedNode { + id: info.id, + addr: *addr, + is_up: true, + last_seen: now_msec(), + state_info: info.state_info.clone(), + }); + } + if is_new || status.hash != info.status_hash { + self.background + .spawn_cancellable(self.clone().pull_status(info.id).map(Ok)); + } + if is_new || ring.config.version < info.config_version { + self.background + .spawn_cancellable(self.clone().pull_config(info.id).map(Ok)); + } + } else if let Some(id) = id_option { + if let Some(st) = status.nodes.get_mut(id) { + st.num_failures.fetch_add(1, Ordering::SeqCst); + if !st.is_up() { + warn!("Node {:?} seems to be down.", id); + if !ring.config.members.contains_key(id) { + info!("Removing node {:?} from status (not in config and not responding to pings anymore)", id); + drop(st); + status.nodes.remove(&id); + has_changes = true; + } + } + } + } + } + if has_changes { + status.recalculate_hash(); + } + if let Err(e) = update_locked.0.broadcast(Arc::new(status)) { + error!("In ping_nodes: could not save status update ({})", e); + } + drop(update_locked); + + if to_advertise.len() > 0 { + self.broadcast(Message::AdvertiseNodesUp(to_advertise), PING_TIMEOUT) + .await; + } + } + + pub async fn handle_ping( + self: Arc<Self>, + from: &SocketAddr, + ping: &PingMessage, + ) -> Result<Message, Error> { + let update_locked = self.update_lock.lock().await; + let mut status: Status = self.status.borrow().as_ref().clone(); + + let is_new = status.handle_ping(from.ip(), ping); + if is_new { + status.recalculate_hash(); + } + let status_hash = status.hash; + let config_version = self.ring.borrow().config.version; + + update_locked.0.broadcast(Arc::new(status))?; + drop(update_locked); + + if is_new || status_hash != ping.status_hash { + self.background + .spawn_cancellable(self.clone().pull_status(ping.id).map(Ok)); + } + if is_new || config_version < ping.config_version { + self.background + .spawn_cancellable(self.clone().pull_config(ping.id).map(Ok)); + } + + Ok(self.make_ping()) + } + + pub fn handle_pull_status(&self) -> Result<Message, Error> { + let status = self.status.borrow().clone(); + let mut mem = vec![]; + for (node, status) in status.nodes.iter() { + let state_info = if *node == self.id { + self.state_info.clone() + } else { + status.state_info.clone() + }; + mem.push(AdvertisedNode { + id: *node, + addr: status.addr, + is_up: status.is_up(), + last_seen: status.last_seen, + state_info, + }); + } + Ok(Message::AdvertiseNodesUp(mem)) + } + + pub fn handle_pull_config(&self) -> Result<Message, Error> { + let ring = self.ring.borrow().clone(); + 
Ok(Message::AdvertiseConfig(ring.config.clone())) + } + + pub async fn handle_advertise_nodes_up( + self: Arc<Self>, + adv: &[AdvertisedNode], + ) -> Result<Message, Error> { + let mut to_ping = vec![]; + + let update_lock = self.update_lock.lock().await; + let mut status: Status = self.status.borrow().as_ref().clone(); + let mut has_changed = false; + + for node in adv.iter() { + if node.id == self.id { + // learn our own ip address + let self_addr = SocketAddr::new(node.addr.ip(), self.rpc_local_port); + let old_self = status.nodes.insert( + node.id, + Arc::new(StatusEntry { + addr: self_addr, + last_seen: now_msec(), + num_failures: AtomicUsize::from(0), + state_info: self.state_info.clone(), + }), + ); + has_changed = match old_self { + None => true, + Some(x) => x.addr != self_addr, + }; + } else { + let ping_them = match status.nodes.get(&node.id) { + // Case 1: new node + None => true, + // Case 2: the node might have changed address + Some(our_node) => node.is_up && !our_node.is_up() && our_node.addr != node.addr, + }; + if ping_them { + to_ping.push((node.addr, Some(node.id))); + } + } + } + if has_changed { + status.recalculate_hash(); + } + update_lock.0.broadcast(Arc::new(status))?; + drop(update_lock); + + if to_ping.len() > 0 { + self.background + .spawn_cancellable(self.clone().ping_nodes(to_ping).map(Ok)); + } + + Ok(Message::Ok) + } + + pub async fn handle_advertise_config( + self: Arc<Self>, + adv: &NetworkConfig, + ) -> Result<Message, Error> { + let update_lock = self.update_lock.lock().await; + let mut ring: Ring = self.ring.borrow().as_ref().clone(); + + if adv.version > ring.config.version { + ring.config = adv.clone(); + ring.rebuild_ring(); + update_lock.1.broadcast(Arc::new(ring))?; + drop(update_lock); + + self.background.spawn_cancellable( + self.clone() + .broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT) + .map(Ok), + ); + self.background.spawn(self.clone().save_network_config()); + } + + Ok(Message::Ok) + } + + pub async fn ping_loop(self: Arc<Self>, mut stop_signal: watch::Receiver<bool>) { + loop { + let restart_at = tokio::time::delay_for(PING_INTERVAL); + + let status = self.status.borrow().clone(); + let ping_addrs = status + .nodes + .iter() + .filter(|(id, _)| **id != self.id) + .map(|(id, status)| (status.addr, Some(*id))) + .collect::<Vec<_>>(); + + self.clone().ping_nodes(ping_addrs).await; + + select! 
{ + _ = restart_at.fuse() => (), + must_exit = stop_signal.recv().fuse() => { + match must_exit { + None | Some(true) => return, + _ => (), + } + } + } + } + } + + pub fn pull_status( + self: Arc<Self>, + peer: UUID, + ) -> impl futures::future::Future<Output = ()> + Send + 'static { + async move { + let resp = self + .rpc_client + .call(peer, Message::PullStatus, PING_TIMEOUT) + .await; + if let Ok(Message::AdvertiseNodesUp(nodes)) = resp { + let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await; + } + } + } + + pub async fn pull_config(self: Arc<Self>, peer: UUID) { + let resp = self + .rpc_client + .call(peer, Message::PullConfig, PING_TIMEOUT) + .await; + if let Ok(Message::AdvertiseConfig(config)) = resp { + let _: Result<_, _> = self.handle_advertise_config(&config).await; + } + } +} diff --git a/src/rpc/mod.rs b/src/rpc/mod.rs new file mode 100644 index 00000000..83fd0aac --- /dev/null +++ b/src/rpc/mod.rs @@ -0,0 +1,4 @@ +pub mod membership; +pub mod rpc_client; +pub mod rpc_server; +pub mod tls_util; diff --git a/src/rpc/rpc_client.rs b/src/rpc/rpc_client.rs new file mode 100644 index 00000000..027a3cde --- /dev/null +++ b/src/rpc/rpc_client.rs @@ -0,0 +1,360 @@ +use std::borrow::Borrow; +use std::marker::PhantomData; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwapOption; +use bytes::IntoBuf; +use err_derive::Error; +use futures::future::Future; +use futures::stream::futures_unordered::FuturesUnordered; +use futures::stream::StreamExt; +use futures_util::future::FutureExt; +use hyper::client::{Client, HttpConnector}; +use hyper::{Body, Method, Request}; +use tokio::sync::{watch, Semaphore}; + +use crate::background::BackgroundRunner; +use crate::data::*; +use crate::error::Error; + +use crate::rpc::membership::Status; +use crate::rpc::rpc_server::RpcMessage; +use crate::rpc::tls_util; + +use crate::config::TlsConfig; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); + +#[derive(Debug, Error)] +pub enum RPCError { + #[error(display = "Node is down: {:?}.", _0)] + NodeDown(UUID), + #[error(display = "Timeout: {}", _0)] + Timeout(#[error(source)] tokio::time::Elapsed), + #[error(display = "HTTP error: {}", _0)] + HTTP(#[error(source)] http::Error), + #[error(display = "Hyper error: {}", _0)] + Hyper(#[error(source)] hyper::Error), + #[error(display = "Messagepack encode error: {}", _0)] + RMPEncode(#[error(source)] rmp_serde::encode::Error), + #[error(display = "Messagepack decode error: {}", _0)] + RMPDecode(#[error(source)] rmp_serde::decode::Error), + #[error(display = "Too many errors: {:?}", _0)] + TooManyErrors(Vec<String>), +} + +#[derive(Copy, Clone)] +pub struct RequestStrategy { + pub rs_timeout: Duration, + pub rs_quorum: usize, + pub rs_interrupt_after_quorum: bool, +} + +impl RequestStrategy { + pub fn with_quorum(quorum: usize) -> Self { + RequestStrategy { + rs_timeout: DEFAULT_TIMEOUT, + rs_quorum: quorum, + rs_interrupt_after_quorum: false, + } + } + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.rs_timeout = timeout; + self + } + pub fn interrupt_after_quorum(mut self, interrupt: bool) -> Self { + self.rs_interrupt_after_quorum = interrupt; + self + } +} + +pub type LocalHandlerFn<M> = + Box<dyn Fn(Arc<M>) -> Pin<Box<dyn Future<Output = Result<M, Error>> + Send>> + Send + Sync>; + +pub struct RpcClient<M: RpcMessage> { + status: watch::Receiver<Arc<Status>>, + background: Arc<BackgroundRunner>, + + local_handler: 
ArcSwapOption<(UUID, LocalHandlerFn<M>)>, + + pub rpc_addr_client: RpcAddrClient<M>, +} + +impl<M: RpcMessage + 'static> RpcClient<M> { + pub fn new( + rac: RpcAddrClient<M>, + background: Arc<BackgroundRunner>, + status: watch::Receiver<Arc<Status>>, + ) -> Arc<Self> { + Arc::new(Self { + rpc_addr_client: rac, + background, + status, + local_handler: ArcSwapOption::new(None), + }) + } + + pub fn set_local_handler<F, Fut>(&self, my_id: UUID, handler: F) + where + F: Fn(Arc<M>) -> Fut + Send + Sync + 'static, + Fut: Future<Output = Result<M, Error>> + Send + 'static, + { + let handler_arc = Arc::new(handler); + let handler: LocalHandlerFn<M> = Box::new(move |msg| { + let handler_arc2 = handler_arc.clone(); + Box::pin(async move { handler_arc2(msg).await }) + }); + self.local_handler.swap(Some(Arc::new((my_id, handler)))); + } + + pub fn by_addr(&self) -> &RpcAddrClient<M> { + &self.rpc_addr_client + } + + pub async fn call(&self, to: UUID, msg: M, timeout: Duration) -> Result<M, Error> { + self.call_arc(to, Arc::new(msg), timeout).await + } + + pub async fn call_arc(&self, to: UUID, msg: Arc<M>, timeout: Duration) -> Result<M, Error> { + if let Some(lh) = self.local_handler.load_full() { + let (my_id, local_handler) = lh.as_ref(); + if to.borrow() == my_id { + return local_handler(msg).await; + } + } + let status = self.status.borrow().clone(); + let node_status = match status.nodes.get(&to) { + Some(node_status) => { + if node_status.is_up() { + node_status + } else { + return Err(Error::from(RPCError::NodeDown(to))); + } + } + None => { + return Err(Error::Message(format!( + "Peer ID not found: {:?}", + to.borrow() + ))) + } + }; + match self + .rpc_addr_client + .call(&node_status.addr, msg, timeout) + .await + { + Err(rpc_error) => { + node_status.num_failures.fetch_add(1, Ordering::SeqCst); + // TODO: Save failure info somewhere + Err(Error::from(rpc_error)) + } + Ok(x) => x, + } + } + + pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec<Result<M, Error>> { + let msg = Arc::new(msg); + let mut resp_stream = to + .iter() + .map(|to| self.call_arc(*to, msg.clone(), timeout)) + .collect::<FuturesUnordered<_>>(); + + let mut results = vec![]; + while let Some(resp) = resp_stream.next().await { + results.push(resp); + } + results + } + + pub async fn try_call_many( + self: &Arc<Self>, + to: &[UUID], + msg: M, + strategy: RequestStrategy, + ) -> Result<Vec<M>, Error> { + let timeout = strategy.rs_timeout; + + let msg = Arc::new(msg); + let mut resp_stream = to + .to_vec() + .into_iter() + .map(|to| { + let self2 = self.clone(); + let msg = msg.clone(); + async move { self2.call_arc(to, msg, timeout).await } + }) + .collect::<FuturesUnordered<_>>(); + + let mut results = vec![]; + let mut errors = vec![]; + + while let Some(resp) = resp_stream.next().await { + match resp { + Ok(msg) => { + results.push(msg); + if results.len() >= strategy.rs_quorum { + break; + } + } + Err(e) => { + errors.push(e); + } + } + } + + if results.len() >= strategy.rs_quorum { + // Continue requests in background. 
+ // Continue the remaining requests immediately using tokio::spawn + // but enqueue a task in the background runner + // to ensure that the process won't exit until the requests are done + // (if we had just enqueued the resp_stream.collect directly in the background runner, + // the requests might have been put on hold in the background runner's queue, + // in which case they might timeout or otherwise fail) + if !strategy.rs_interrupt_after_quorum { + let wait_finished_fut = tokio::spawn(async move { + resp_stream.collect::<Vec<_>>().await; + Ok(()) + }); + self.background.spawn(wait_finished_fut.map(|x| { + x.unwrap_or_else(|e| Err(Error::Message(format!("Await failed: {}", e)))) + })); + } + + Ok(results) + } else { + let errors = errors.iter().map(|e| format!("{}", e)).collect::<Vec<_>>(); + Err(Error::from(RPCError::TooManyErrors(errors))) + } + } +} + +pub struct RpcAddrClient<M: RpcMessage> { + phantom: PhantomData<M>, + + pub http_client: Arc<RpcHttpClient>, + pub path: String, +} + +impl<M: RpcMessage> RpcAddrClient<M> { + pub fn new(http_client: Arc<RpcHttpClient>, path: String) -> Self { + Self { + phantom: PhantomData::default(), + http_client: http_client, + path, + } + } + + pub async fn call<MB>( + &self, + to_addr: &SocketAddr, + msg: MB, + timeout: Duration, + ) -> Result<Result<M, Error>, RPCError> + where + MB: Borrow<M>, + { + self.http_client + .call(&self.path, to_addr, msg, timeout) + .await + } +} + +pub struct RpcHttpClient { + request_limiter: Semaphore, + method: ClientMethod, +} + +enum ClientMethod { + HTTP(Client<HttpConnector, hyper::Body>), + HTTPS(Client<tls_util::HttpsConnectorFixedDnsname<HttpConnector>, hyper::Body>), +} + +impl RpcHttpClient { + pub fn new( + max_concurrent_requests: usize, + tls_config: &Option<TlsConfig>, + ) -> Result<Self, Error> { + let method = if let Some(cf) = tls_config { + let ca_certs = tls_util::load_certs(&cf.ca_cert)?; + let node_certs = tls_util::load_certs(&cf.node_cert)?; + let node_key = tls_util::load_private_key(&cf.node_key)?; + + let mut config = rustls::ClientConfig::new(); + + for crt in ca_certs.iter() { + config.root_store.add(crt)?; + } + + config.set_single_client_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; + + let connector = + tls_util::HttpsConnectorFixedDnsname::<HttpConnector>::new(config, "garage"); + + ClientMethod::HTTPS(Client::builder().build(connector)) + } else { + ClientMethod::HTTP(Client::new()) + }; + Ok(RpcHttpClient { + method, + request_limiter: Semaphore::new(max_concurrent_requests), + }) + } + + async fn call<M, MB>( + &self, + path: &str, + to_addr: &SocketAddr, + msg: MB, + timeout: Duration, + ) -> Result<Result<M, Error>, RPCError> + where + MB: Borrow<M>, + M: RpcMessage, + { + let uri = match self.method { + ClientMethod::HTTP(_) => format!("http://{}/{}", to_addr, path), + ClientMethod::HTTPS(_) => format!("https://{}/{}", to_addr, path), + }; + + let req = Request::builder() + .method(Method::POST) + .uri(uri) + .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; + + let resp_fut = match &self.method { + ClientMethod::HTTP(client) => client.request(req).fuse(), + ClientMethod::HTTPS(client) => client.request(req).fuse(), + }; + + let slot = self.request_limiter.acquire().await; + let resp = tokio::time::timeout(timeout, resp_fut) + .await + .map_err(|e| { + debug!( + "RPC timeout to {}: {}", + to_addr, + debug_serialize(msg.borrow()) + ); + e + })? 
+ .map_err(|e| { + warn!( + "RPC HTTP client error when connecting to {}: {}", + to_addr, e + ); + e + })?; + drop(slot); + + let status = resp.status(); + let body = hyper::body::to_bytes(resp.into_body()).await?; + match rmp_serde::decode::from_read::<_, Result<M, String>>(body.into_buf())? { + Err(e) => Ok(Err(Error::RemoteError(e, status))), + Ok(x) => Ok(Ok(x)), + } + } +} diff --git a/src/rpc/rpc_server.rs b/src/rpc/rpc_server.rs new file mode 100644 index 00000000..4ee53909 --- /dev/null +++ b/src/rpc/rpc_server.rs @@ -0,0 +1,219 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Instant; + +use bytes::IntoBuf; +use futures::future::Future; +use futures_util::future::*; +use futures_util::stream::*; +use hyper::server::conn::AddrStream; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use serde::{Deserialize, Serialize}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_rustls::server::TlsStream; +use tokio_rustls::TlsAcceptor; + +use crate::config::TlsConfig; +use crate::data::*; +use crate::error::Error; + +use crate::rpc::tls_util; + +pub trait RpcMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {} + +type ResponseFuture = Pin<Box<dyn Future<Output = Result<Response<Body>, Error>> + Send>>; +type Handler = Box<dyn Fn(Request<Body>, SocketAddr) -> ResponseFuture + Send + Sync>; + +pub struct RpcServer { + pub bind_addr: SocketAddr, + pub tls_config: Option<TlsConfig>, + + handlers: HashMap<String, Handler>, +} + +async fn handle_func<M, F, Fut>( + handler: Arc<F>, + req: Request<Body>, + sockaddr: SocketAddr, + name: Arc<String>, +) -> Result<Response<Body>, Error> +where + M: RpcMessage + 'static, + F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, + Fut: Future<Output = Result<M, Error>> + Send + 'static, +{ + let begin_time = Instant::now(); + let whole_body = hyper::body::to_bytes(req.into_body()).await?; + let msg = rmp_serde::decode::from_read::<_, M>(whole_body.into_buf())?; + match handler(msg, sockaddr).await { + Ok(resp) => { + let resp_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Ok(resp))?; + let rpc_duration = (Instant::now() - begin_time).as_millis(); + if rpc_duration > 100 { + debug!("RPC {} ok, took long: {} ms", name, rpc_duration,); + } + Ok(Response::new(Body::from(resp_bytes))) + } + Err(e) => { + let err_str = format!("{}", e); + let rep_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Err(err_str))?; + let mut err_response = Response::new(Body::from(rep_bytes)); + *err_response.status_mut() = e.http_status_code(); + warn!( + "RPC error ({}): {} ({} ms)", + name, + e, + (Instant::now() - begin_time).as_millis(), + ); + Ok(err_response) + } + } +} + +impl RpcServer { + pub fn new(bind_addr: SocketAddr, tls_config: Option<TlsConfig>) -> Self { + Self { + bind_addr, + tls_config, + handlers: HashMap::new(), + } + } + + pub fn add_handler<M, F, Fut>(&mut self, name: String, handler: F) + where + M: RpcMessage + 'static, + F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, + Fut: Future<Output = Result<M, Error>> + Send + 'static, + { + let name2 = Arc::new(name.clone()); + let handler_arc = Arc::new(handler); + let handler = Box::new(move |req: Request<Body>, sockaddr: SocketAddr| { + let handler2 = handler_arc.clone(); + let b: ResponseFuture = Box::pin(handle_func(handler2, req, sockaddr, name2.clone())); + b + }); + self.handlers.insert(name, handler); + } + + async fn handler( + self: 
Arc<Self>, + req: Request<Body>, + addr: SocketAddr, + ) -> Result<Response<Body>, Error> { + if req.method() != &Method::POST { + let mut bad_request = Response::default(); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + return Ok(bad_request); + } + + let path = &req.uri().path()[1..]; + let handler = match self.handlers.get(path) { + Some(h) => h, + None => { + let mut not_found = Response::default(); + *not_found.status_mut() = StatusCode::NOT_FOUND; + return Ok(not_found); + } + }; + + let resp_waiter = tokio::spawn(handler(req, addr)); + match resp_waiter.await { + Err(err) => { + warn!("Handler await error: {}", err); + let mut ise = Response::default(); + *ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(ise) + } + Ok(Err(err)) => { + let mut bad_request = Response::new(Body::from(format!("{}", err))); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + Ok(bad_request) + } + Ok(Ok(resp)) => Ok(resp), + } + } + + pub async fn run( + self: Arc<Self>, + shutdown_signal: impl Future<Output = ()>, + ) -> Result<(), Error> { + if let Some(tls_config) = self.tls_config.as_ref() { + let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?; + let node_certs = tls_util::load_certs(&tls_config.node_cert)?; + let node_key = tls_util::load_private_key(&tls_config.node_key)?; + + let mut ca_store = rustls::RootCertStore::empty(); + for crt in ca_certs.iter() { + ca_store.add(crt)?; + } + + let mut config = + rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store)); + config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; + let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config))); + + let mut listener = TcpListener::bind(&self.bind_addr).await?; + let incoming = listener.incoming().filter_map(|socket| async { + match socket { + Ok(stream) => match tls_acceptor.clone().accept(stream).await { + Ok(x) => Some(Ok::<_, hyper::Error>(x)), + Err(_e) => None, + }, + Err(_) => None, + } + }); + let incoming = hyper::server::accept::from_stream(incoming); + + let self_arc = self.clone(); + let service = make_service_fn(|conn: &TlsStream<TcpStream>| { + let client_addr = conn + .get_ref() + .0 + .peer_addr() + .unwrap_or(([0, 0, 0, 0], 0).into()); + let self_arc = self_arc.clone(); + async move { + Ok::<_, Error>(service_fn(move |req: Request<Body>| { + self_arc.clone().handler(req, client_addr).map_err(|e| { + warn!("RPC handler error: {}", e); + e + }) + })) + } + }); + + let server = Server::builder(incoming).serve(service); + + let graceful = server.with_graceful_shutdown(shutdown_signal); + info!("RPC server listening on http://{}", self.bind_addr); + + graceful.await?; + } else { + let self_arc = self.clone(); + let service = make_service_fn(move |conn: &AddrStream| { + let client_addr = conn.remote_addr(); + let self_arc = self_arc.clone(); + async move { + Ok::<_, Error>(service_fn(move |req: Request<Body>| { + self_arc.clone().handler(req, client_addr).map_err(|e| { + warn!("RPC handler error: {}", e); + e + }) + })) + } + }); + + let server = Server::bind(&self.bind_addr).serve(service); + + let graceful = server.with_graceful_shutdown(shutdown_signal); + info!("RPC server listening on http://{}", self.bind_addr); + + graceful.await?; + } + + Ok(()) + } +} diff --git a/src/rpc/tls_util.rs b/src/rpc/tls_util.rs new file mode 100644 index 00000000..52c52110 --- /dev/null +++ b/src/rpc/tls_util.rs @@ -0,0 +1,139 @@ +use core::future::Future; +use core::task::{Context, Poll}; +use std::pin::Pin; +use std::sync::Arc; +use 
std::{fs, io}; + +use futures_util::future::*; +use hyper::client::connect::Connection; +use hyper::client::HttpConnector; +use hyper::service::Service; +use hyper::Uri; +use hyper_rustls::MaybeHttpsStream; +use rustls::internal::pemfile; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::TlsConnector; +use webpki::DNSNameRef; + +use crate::error::Error; + +pub fn load_certs(filename: &str) -> Result<Vec<rustls::Certificate>, Error> { + let certfile = fs::File::open(&filename)?; + let mut reader = io::BufReader::new(certfile); + + let certs = pemfile::certs(&mut reader).map_err(|_| { + Error::Message(format!( + "Could not deecode certificates from file: {}", + filename + )) + })?; + + if certs.is_empty() { + return Err(Error::Message(format!( + "Invalid certificate file: {}", + filename + ))); + } + Ok(certs) +} + +pub fn load_private_key(filename: &str) -> Result<rustls::PrivateKey, Error> { + let keyfile = fs::File::open(&filename)?; + let mut reader = io::BufReader::new(keyfile); + + let keys = pemfile::rsa_private_keys(&mut reader).map_err(|_| { + Error::Message(format!( + "Could not decode private key from file: {}", + filename + )) + })?; + + if keys.len() != 1 { + return Err(Error::Message(format!( + "Invalid private key file: {} ({} private keys)", + filename, + keys.len() + ))); + } + Ok(keys[0].clone()) +} + +// ---- AWFUL COPYPASTA FROM HYPER-RUSTLS connector.rs +// ---- ALWAYS USE `garage` AS HOSTNAME FOR TLS VERIFICATION + +#[derive(Clone)] +pub struct HttpsConnectorFixedDnsname<T> { + http: T, + tls_config: Arc<rustls::ClientConfig>, + fixed_dnsname: &'static str, +} + +type BoxError = Box<dyn std::error::Error + Send + Sync>; + +impl HttpsConnectorFixedDnsname<HttpConnector> { + pub fn new(mut tls_config: rustls::ClientConfig, fixed_dnsname: &'static str) -> Self { + let mut http = HttpConnector::new(); + http.enforce_http(false); + tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + Self { + http, + tls_config: Arc::new(tls_config), + fixed_dnsname, + } + } +} + +impl<T> Service<Uri> for HttpsConnectorFixedDnsname<T> +where + T: Service<Uri>, + T::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static, + T::Future: Send + 'static, + T::Error: Into<BoxError>, +{ + type Response = MaybeHttpsStream<T::Response>; + type Error = BoxError; + + #[allow(clippy::type_complexity)] + type Future = + Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + match self.http.poll_ready(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Pending => Poll::Pending, + } + } + + fn call(&mut self, dst: Uri) -> Self::Future { + let is_https = dst.scheme_str() == Some("https"); + + if !is_https { + let connecting_future = self.http.call(dst); + + let f = async move { + let tcp = connecting_future.await.map_err(Into::into)?; + + Ok(MaybeHttpsStream::Http(tcp)) + }; + f.boxed() + } else { + let cfg = self.tls_config.clone(); + let connecting_future = self.http.call(dst); + + let dnsname = + DNSNameRef::try_from_ascii_str(self.fixed_dnsname).expect("Invalid fixed dnsname"); + + let f = async move { + let tcp = connecting_future.await.map_err(Into::into)?; + let connector = TlsConnector::from(cfg); + let tls = connector + .connect(dnsname, tcp) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok(MaybeHttpsStream::Https(tls)) + }; + f.boxed() + } + } +} |
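One part of membership.rs that benefits from a concrete trace is the replica-placement loop in Ring::walk_ring_from_pos: it walks the ring from a starting position, first taking at most one node per datacenter, and only accepting additional distinct nodes once every datacenter is represented. The toy restatement below is not code from this commit; it uses integer datacenter ids and char node ids purely for illustration, and like the real code it assumes n does not exceed the number of distinct nodes (walk_ring checks that case first and simply returns all configured members).

// Toy restatement of the selection loop in Ring::walk_ring_from_pos (not code
// from this commit). Each ring entry is (datacenter id, node id), already
// sorted by ring position.
fn pick_replicas(ring: &[(u64, char)], start: usize, n: usize, n_datacenters: usize) -> Vec<char> {
	let mut ret = vec![];
	let mut seen_dcs = vec![];
	let mut delta = 0;
	// Like the real implementation, this assumes n does not exceed the number
	// of distinct nodes on the ring (walk_ring handles that case separately).
	while ret.len() < n {
		let (dc, node) = ring[(start + delta) % ring.len()];
		delta += 1;
		if !seen_dcs.contains(&dc) {
			// First, take at most one node per datacenter.
			ret.push(node);
			seen_dcs.push(dc);
		} else if seen_dcs.len() == n_datacenters && !ret.contains(&node) {
			// Once every datacenter is represented, accept further distinct nodes.
			ret.push(node);
		}
	}
	ret
}

fn main() {
	// Two datacenters (1 and 2) and three nodes (a, b, c) spread over six ring positions.
	let ring = [(1, 'a'), (2, 'c'), (1, 'b'), (2, 'c'), (1, 'a'), (2, 'c')];
	// Starting at position 0 with n = 3: 'a' (dc 1), then 'c' (dc 2), then 'b',
	// the next node not already chosen once both datacenters are covered.
	assert_eq!(pick_replicas(&ring, 0, 3, 2), vec!['a', 'c', 'b']);
}

This is why Ring tracks n_datacenters: it is the condition for relaxing the one-node-per-datacenter rule when more replicas are requested than there are datacenters.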