diff options
Diffstat (limited to 'src/rpc')
-rw-r--r-- | src/rpc/Cargo.toml | 12 | ||||
-rw-r--r-- | src/rpc/lib.rs | 8 | ||||
-rw-r--r-- | src/rpc/membership.rs | 722 | ||||
-rw-r--r-- | src/rpc/ring.rs | 11 | ||||
-rw-r--r-- | src/rpc/rpc_client.rs | 369 | ||||
-rw-r--r-- | src/rpc/rpc_helper.rs | 206 | ||||
-rw-r--r-- | src/rpc/rpc_server.rs | 247 | ||||
-rw-r--r-- | src/rpc/system.rs | 363 | ||||
-rw-r--r-- | src/rpc/tls_util.rs | 140 |
9 files changed, 587 insertions, 1491 deletions
diff --git a/src/rpc/Cargo.toml b/src/rpc/Cargo.toml index f1204cdf..1100c737 100644 --- a/src/rpc/Cargo.toml +++ b/src/rpc/Cargo.toml @@ -22,7 +22,10 @@ bytes = "1.0" gethostname = "0.2" hex = "0.4" log = "0.4" +rand = "0.8" +sodiumoxide = { version = "0.2.5-0", package = "kuska-sodiumoxide" } +async-trait = "0.1.7" rmp-serde = "0.15" serde = { version = "1.0", default-features = false, features = ["derive", "rc"] } serde_json = "1.0" @@ -32,11 +35,6 @@ futures-util = "0.3" tokio = { version = "1.0", default-features = false, features = ["rt", "rt-multi-thread", "io-util", "net", "time", "macros", "sync", "signal", "fs"] } tokio-stream = { version = "0.1", features = ["net"] } -http = "0.2" -hyper = { version = "0.14", features = ["full"] } -hyper-rustls = { version = "0.22", default-features = false } -rustls = "0.19" -tokio-rustls = "0.22" -webpki = "0.21" - +netapp = { version = "0.3.0", git = "https://git.deuxfleurs.fr/lx/netapp" } +hyper = "0.14" diff --git a/src/rpc/lib.rs b/src/rpc/lib.rs index 96561d0e..ea3f1139 100644 --- a/src/rpc/lib.rs +++ b/src/rpc/lib.rs @@ -4,10 +4,10 @@ extern crate log; mod consul; -pub(crate) mod tls_util; -pub mod membership; pub mod ring; +pub mod system; -pub mod rpc_client; -pub mod rpc_server; +pub mod rpc_helper; + +pub use rpc_helper::*; diff --git a/src/rpc/membership.rs b/src/rpc/membership.rs deleted file mode 100644 index a77eeed3..00000000 --- a/src/rpc/membership.rs +++ /dev/null @@ -1,722 +0,0 @@ -//! Module containing structs related to membership management -use std::collections::HashMap; -use std::fmt::Write as FmtWrite; -use std::io::{Read, Write}; -use std::net::{IpAddr, SocketAddr}; -use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::Duration; - -use futures::future::join_all; -use futures::select; -use futures_util::future::*; -use serde::{Deserialize, Serialize}; -use tokio::sync::watch; -use tokio::sync::Mutex; - -use garage_util::background::BackgroundRunner; -use garage_util::data::*; -use garage_util::error::Error; -use garage_util::persister::Persister; -use garage_util::time::*; - -use crate::consul::get_consul_nodes; -use crate::ring::*; -use crate::rpc_client::*; -use crate::rpc_server::*; - -const PING_INTERVAL: Duration = Duration::from_secs(10); -const DISCOVERY_INTERVAL: Duration = Duration::from_secs(60); -const PING_TIMEOUT: Duration = Duration::from_secs(2); -const MAX_FAILURES_BEFORE_CONSIDERED_DOWN: usize = 5; - -/// RPC endpoint used for calls related to membership -pub const MEMBERSHIP_RPC_PATH: &str = "_membership"; - -/// RPC messages related to membership -#[derive(Debug, Serialize, Deserialize)] -pub enum Message { - /// Response to successfull advertisements - Ok, - /// Message sent to detect other nodes status - Ping(PingMessage), - /// Ask other node for the nodes it knows. Answered with AdvertiseNodesUp - PullStatus, - /// Ask other node its config. Answered with AdvertiseConfig - PullConfig, - /// Advertisement of nodes the host knows up. Sent spontanously or in response to PullStatus - AdvertiseNodesUp(Vec<AdvertisedNode>), - /// Advertisement of nodes config. Sent spontanously or in response to PullConfig - AdvertiseConfig(NetworkConfig), -} - -impl RpcMessage for Message {} - -/// A ping, containing informations about status and config -#[derive(Debug, Serialize, Deserialize)] -pub struct PingMessage { - id: Uuid, - rpc_port: u16, - - status_hash: Hash, - config_version: u64, - - state_info: StateInfo, -} - -/// A node advertisement -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct AdvertisedNode { - /// Id of the node this advertisement relates to - pub id: Uuid, - /// IP and port of the node - pub addr: SocketAddr, - - /// Is the node considered up - pub is_up: bool, - /// When was the node last seen up, in milliseconds since UNIX epoch - pub last_seen: u64, - - pub state_info: StateInfo, -} - -/// This node's membership manager -pub struct System { - /// The id of this node - pub id: Uuid, - - persist_config: Persister<NetworkConfig>, - persist_status: Persister<Vec<AdvertisedNode>>, - rpc_local_port: u16, - - state_info: StateInfo, - - rpc_http_client: Arc<RpcHttpClient>, - rpc_client: Arc<RpcClient<Message>>, - - replication_factor: usize, - pub(crate) status: watch::Receiver<Arc<Status>>, - /// The ring - pub ring: watch::Receiver<Arc<Ring>>, - - update_lock: Mutex<Updaters>, - - /// The job runner of this node - pub background: Arc<BackgroundRunner>, -} - -struct Updaters { - update_status: watch::Sender<Arc<Status>>, - update_ring: watch::Sender<Arc<Ring>>, -} - -/// The status of each nodes, viewed by this node -#[derive(Debug, Clone)] -pub struct Status { - /// Mapping of each node id to its known status - pub nodes: HashMap<Uuid, Arc<StatusEntry>>, - /// Hash of `nodes`, used to detect when nodes have different views of the cluster - pub hash: Hash, -} - -/// The status of a single node -#[derive(Debug)] -pub struct StatusEntry { - /// The IP and port used to connect to this node - pub addr: SocketAddr, - /// Last time this node was seen - pub last_seen: u64, - /// Number of consecutive pings sent without reply to this node - pub num_failures: AtomicUsize, - pub state_info: StateInfo, -} - -impl StatusEntry { - /// is the node associated to this entry considered up - pub fn is_up(&self) -> bool { - self.num_failures.load(Ordering::SeqCst) < MAX_FAILURES_BEFORE_CONSIDERED_DOWN - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StateInfo { - /// Hostname of the node - pub hostname: String, - /// Replication factor configured on the node - pub replication_factor: Option<usize>, // TODO Option is just for retrocompatibility. It should become a simple usize at some point -} - -impl Status { - fn handle_ping(&mut self, ip: IpAddr, info: &PingMessage) -> bool { - let addr = SocketAddr::new(ip, info.rpc_port); - let old_status = self.nodes.insert( - info.id, - Arc::new(StatusEntry { - addr, - last_seen: now_msec(), - num_failures: AtomicUsize::from(0), - state_info: info.state_info.clone(), - }), - ); - match old_status { - None => { - info!("Newly pingable node: {}", hex::encode(&info.id)); - true - } - Some(x) => x.addr != addr, - } - } - - fn recalculate_hash(&mut self) { - let mut nodes = self.nodes.iter().collect::<Vec<_>>(); - nodes.sort_unstable_by_key(|(id, _status)| *id); - - let mut nodes_txt = String::new(); - debug!("Current set of pingable nodes: --"); - for (id, status) in nodes { - debug!("{} {}", hex::encode(&id), status.addr); - writeln!(&mut nodes_txt, "{} {}", hex::encode(&id), status.addr).unwrap(); - } - debug!("END --"); - self.hash = blake2sum(nodes_txt.as_bytes()); - } - - fn to_serializable_membership(&self, system: &System) -> Vec<AdvertisedNode> { - let mut mem = vec![]; - for (node, status) in self.nodes.iter() { - let state_info = if *node == system.id { - system.state_info.clone() - } else { - status.state_info.clone() - }; - mem.push(AdvertisedNode { - id: *node, - addr: status.addr, - is_up: status.is_up(), - last_seen: status.last_seen, - state_info, - }); - } - mem - } -} - -fn gen_node_id(metadata_dir: &Path) -> Result<Uuid, Error> { - let mut id_file = metadata_dir.to_path_buf(); - id_file.push("node_id"); - if id_file.as_path().exists() { - let mut f = std::fs::File::open(id_file.as_path())?; - let mut d = vec![]; - f.read_to_end(&mut d)?; - if d.len() != 32 { - return Err(Error::Message("Corrupt node_id file".to_string())); - } - - let mut id = [0u8; 32]; - id.copy_from_slice(&d[..]); - Ok(id.into()) - } else { - let id = gen_uuid(); - - let mut f = std::fs::File::create(id_file.as_path())?; - f.write_all(id.as_slice())?; - Ok(id) - } -} - -impl System { - /// Create this node's membership manager - pub fn new( - metadata_dir: PathBuf, - rpc_http_client: Arc<RpcHttpClient>, - background: Arc<BackgroundRunner>, - rpc_server: &mut RpcServer, - replication_factor: usize, - ) -> Arc<Self> { - let id = gen_node_id(&metadata_dir).expect("Unable to read or generate node ID"); - info!("Node ID: {}", hex::encode(&id)); - - let persist_config = Persister::new(&metadata_dir, "network_config"); - let persist_status = Persister::new(&metadata_dir, "peer_info"); - - let net_config = match persist_config.load() { - Ok(x) => x, - Err(e) => { - match Persister::<garage_rpc_021::ring::NetworkConfig>::new( - &metadata_dir, - "network_config", - ) - .load() - { - Ok(old_config) => NetworkConfig::migrate_from_021(old_config), - Err(e2) => { - info!( - "No valid previous network configuration stored ({}, {}), starting fresh.", - e, e2 - ); - NetworkConfig::new() - } - } - } - }; - - let mut status = Status { - nodes: HashMap::new(), - hash: Hash::default(), - }; - status.recalculate_hash(); - let (update_status, status) = watch::channel(Arc::new(status)); - - let state_info = StateInfo { - hostname: gethostname::gethostname() - .into_string() - .unwrap_or_else(|_| "<invalid utf-8>".to_string()), - replication_factor: Some(replication_factor), - }; - - let ring = Ring::new(net_config, replication_factor); - let (update_ring, ring) = watch::channel(Arc::new(ring)); - - let rpc_path = MEMBERSHIP_RPC_PATH.to_string(); - let rpc_client = RpcClient::new( - RpcAddrClient::<Message>::new(rpc_http_client.clone(), rpc_path.clone()), - background.clone(), - status.clone(), - ); - - let sys = Arc::new(System { - id, - persist_config, - persist_status, - rpc_local_port: rpc_server.bind_addr.port(), - state_info, - rpc_http_client, - rpc_client, - replication_factor, - status, - ring, - update_lock: Mutex::new(Updaters { - update_status, - update_ring, - }), - background, - }); - sys.clone().register_handler(rpc_server, rpc_path); - sys - } - - fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) { - rpc_server.add_handler::<Message, _, _>(path, move |msg, addr| { - let self2 = self.clone(); - async move { - match msg { - Message::Ping(ping) => self2.handle_ping(&addr, &ping).await, - - Message::PullStatus => Ok(self2.handle_pull_status()), - Message::PullConfig => Ok(self2.handle_pull_config()), - Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await, - Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await, - - _ => Err(Error::BadRpc("Unexpected RPC message".to_string())), - } - } - }); - } - - /// Get an RPC client - pub fn rpc_client<M: RpcMessage + 'static>(self: &Arc<Self>, path: &str) -> Arc<RpcClient<M>> { - RpcClient::new( - RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()), - self.background.clone(), - self.status.clone(), - ) - } - - /// Save network configuration to disc - async fn save_network_config(self: Arc<Self>) -> Result<(), Error> { - let ring = self.ring.borrow().clone(); - self.persist_config - .save_async(&ring.config) - .await - .expect("Cannot save current cluster configuration"); - Ok(()) - } - - fn make_ping(&self) -> Message { - let status = self.status.borrow().clone(); - let ring = self.ring.borrow().clone(); - Message::Ping(PingMessage { - id: self.id, - rpc_port: self.rpc_local_port, - status_hash: status.hash, - config_version: ring.config.version, - state_info: self.state_info.clone(), - }) - } - - async fn broadcast(self: Arc<Self>, msg: Message, timeout: Duration) { - let status = self.status.borrow().clone(); - let to = status - .nodes - .keys() - .filter(|x| **x != self.id) - .cloned() - .collect::<Vec<_>>(); - self.rpc_client.call_many(&to[..], msg, timeout).await; - } - - /// Perform bootstraping, starting the ping loop - pub async fn bootstrap( - self: Arc<Self>, - peers: Vec<SocketAddr>, - consul_host: Option<String>, - consul_service_name: Option<String>, - ) { - let self2 = self.clone(); - self.background - .spawn_worker("discovery loop".to_string(), |stop_signal| { - self2.discovery_loop(peers, consul_host, consul_service_name, stop_signal) - }); - - let self2 = self.clone(); - self.background - .spawn_worker("ping loop".to_string(), |stop_signal| { - self2.ping_loop(stop_signal) - }); - } - - async fn ping_nodes(self: Arc<Self>, peers: Vec<(SocketAddr, Option<Uuid>)>) { - let ping_msg = self.make_ping(); - let ping_resps = join_all(peers.iter().map(|(addr, id_option)| { - let sys = self.clone(); - let ping_msg_ref = &ping_msg; - async move { - ( - id_option, - addr, - sys.rpc_client - .by_addr() - .call(&addr, ping_msg_ref, PING_TIMEOUT) - .await, - ) - } - })) - .await; - - let update_locked = self.update_lock.lock().await; - let mut status: Status = self.status.borrow().as_ref().clone(); - let ring = self.ring.borrow().clone(); - - let mut has_changes = false; - let mut to_advertise = vec![]; - - for (id_option, addr, ping_resp) in ping_resps { - if let Ok(Ok(Message::Ping(info))) = ping_resp { - let is_new = status.handle_ping(addr.ip(), &info); - if is_new { - has_changes = true; - to_advertise.push(AdvertisedNode { - id: info.id, - addr: *addr, - is_up: true, - last_seen: now_msec(), - state_info: info.state_info.clone(), - }); - } - if is_new || status.hash != info.status_hash { - self.background - .spawn_cancellable(self.clone().pull_status(info.id).map(Ok)); - } - if is_new || ring.config.version < info.config_version { - self.background - .spawn_cancellable(self.clone().pull_config(info.id).map(Ok)); - } - } else if let Some(id) = id_option { - if let Some(st) = status.nodes.get_mut(id) { - // we need to increment failure counter as call was done using by_addr so the - // counter was not auto-incremented - st.num_failures.fetch_add(1, Ordering::SeqCst); - if !st.is_up() { - warn!("Node {:?} seems to be down.", id); - if !ring.config.members.contains_key(id) { - info!("Removing node {:?} from status (not in config and not responding to pings anymore)", id); - status.nodes.remove(&id); - has_changes = true; - } - } - } - } - } - if has_changes { - status.recalculate_hash(); - } - self.update_status(&update_locked, status).await; - drop(update_locked); - - if !to_advertise.is_empty() { - self.broadcast(Message::AdvertiseNodesUp(to_advertise), PING_TIMEOUT) - .await; - } - } - - async fn handle_ping( - self: Arc<Self>, - from: &SocketAddr, - ping: &PingMessage, - ) -> Result<Message, Error> { - let update_locked = self.update_lock.lock().await; - let mut status: Status = self.status.borrow().as_ref().clone(); - - let is_new = status.handle_ping(from.ip(), ping); - if is_new { - status.recalculate_hash(); - } - let status_hash = status.hash; - let config_version = self.ring.borrow().config.version; - - self.update_status(&update_locked, status).await; - drop(update_locked); - - if is_new || status_hash != ping.status_hash { - self.background - .spawn_cancellable(self.clone().pull_status(ping.id).map(Ok)); - } - if is_new || config_version < ping.config_version { - self.background - .spawn_cancellable(self.clone().pull_config(ping.id).map(Ok)); - } - - Ok(self.make_ping()) - } - - fn handle_pull_status(&self) -> Message { - Message::AdvertiseNodesUp(self.status.borrow().to_serializable_membership(self)) - } - - fn handle_pull_config(&self) -> Message { - let ring = self.ring.borrow().clone(); - Message::AdvertiseConfig(ring.config.clone()) - } - - async fn handle_advertise_nodes_up( - self: Arc<Self>, - adv: &[AdvertisedNode], - ) -> Result<Message, Error> { - let mut to_ping = vec![]; - - let update_lock = self.update_lock.lock().await; - let mut status: Status = self.status.borrow().as_ref().clone(); - let mut has_changed = false; - let mut max_replication_factor = 0; - - for node in adv.iter() { - if node.id == self.id { - // learn our own ip address - let self_addr = SocketAddr::new(node.addr.ip(), self.rpc_local_port); - let old_self = status.nodes.insert( - node.id, - Arc::new(StatusEntry { - addr: self_addr, - last_seen: now_msec(), - num_failures: AtomicUsize::from(0), - state_info: self.state_info.clone(), - }), - ); - has_changed = match old_self { - None => true, - Some(x) => x.addr != self_addr, - }; - } else { - let ping_them = match status.nodes.get(&node.id) { - // Case 1: new node - None => true, - // Case 2: the node might have changed address - Some(our_node) => node.is_up && !our_node.is_up() && our_node.addr != node.addr, - }; - max_replication_factor = std::cmp::max( - max_replication_factor, - node.state_info.replication_factor.unwrap_or_default(), - ); - if ping_them { - to_ping.push((node.addr, Some(node.id))); - } - } - } - - if self.replication_factor < max_replication_factor { - error!("Some node have a higher replication factor ({}) than this one ({}). This is not supported and might lead to bugs", - max_replication_factor, - self.replication_factor); - std::process::exit(1); - } - if has_changed { - status.recalculate_hash(); - } - self.update_status(&update_lock, status).await; - drop(update_lock); - - if !to_ping.is_empty() { - self.background - .spawn_cancellable(self.clone().ping_nodes(to_ping).map(Ok)); - } - - Ok(Message::Ok) - } - - async fn handle_advertise_config( - self: Arc<Self>, - adv: &NetworkConfig, - ) -> Result<Message, Error> { - let update_lock = self.update_lock.lock().await; - let ring: Arc<Ring> = self.ring.borrow().clone(); - - if adv.version > ring.config.version { - let ring = Ring::new(adv.clone(), self.replication_factor); - update_lock.update_ring.send(Arc::new(ring))?; - drop(update_lock); - - self.background.spawn_cancellable( - self.clone() - .broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT) - .map(Ok), - ); - self.background.spawn(self.clone().save_network_config()); - } - - Ok(Message::Ok) - } - - async fn ping_loop(self: Arc<Self>, mut stop_signal: watch::Receiver<bool>) { - while !*stop_signal.borrow() { - let restart_at = tokio::time::sleep(PING_INTERVAL); - - let status = self.status.borrow().clone(); - let ping_addrs = status - .nodes - .iter() - .filter(|(id, _)| **id != self.id) - .map(|(id, status)| (status.addr, Some(*id))) - .collect::<Vec<_>>(); - - self.clone().ping_nodes(ping_addrs).await; - - select! { - _ = restart_at.fuse() => {}, - _ = stop_signal.changed().fuse() => {}, - } - } - } - - async fn discovery_loop( - self: Arc<Self>, - bootstrap_peers: Vec<SocketAddr>, - consul_host: Option<String>, - consul_service_name: Option<String>, - mut stop_signal: watch::Receiver<bool>, - ) { - let consul_config = match (consul_host, consul_service_name) { - (Some(ch), Some(csn)) => Some((ch, csn)), - _ => None, - }; - - while !*stop_signal.borrow() { - let not_configured = self.ring.borrow().config.members.is_empty(); - let no_peers = self.status.borrow().nodes.len() < 3; - let bad_peers = self - .status - .borrow() - .nodes - .iter() - .filter(|(_, v)| v.is_up()) - .count() != self.ring.borrow().config.members.len(); - - if not_configured || no_peers || bad_peers { - info!("Doing a bootstrap/discovery step (not_configured: {}, no_peers: {}, bad_peers: {})", not_configured, no_peers, bad_peers); - - let mut ping_list = bootstrap_peers - .iter() - .map(|ip| (*ip, None)) - .collect::<Vec<_>>(); - - if let Ok(peers) = self.persist_status.load_async().await { - ping_list.extend(peers.iter().map(|x| (x.addr, Some(x.id)))); - } - - if let Some((consul_host, consul_service_name)) = &consul_config { - match get_consul_nodes(consul_host, consul_service_name).await { - Ok(node_list) => { - ping_list.extend(node_list.iter().map(|a| (*a, None))); - } - Err(e) => { - warn!("Could not retrieve node list from Consul: {}", e); - } - } - } - - self.clone().ping_nodes(ping_list).await; - } - - let restart_at = tokio::time::sleep(DISCOVERY_INTERVAL); - select! { - _ = restart_at.fuse() => {}, - _ = stop_signal.changed().fuse() => {}, - } - } - } - - // for some reason fixing this is causing compilation error, see https://github.com/rust-lang/rust-clippy/issues/7052 - #[allow(clippy::manual_async_fn)] - fn pull_status( - self: Arc<Self>, - peer: Uuid, - ) -> impl futures::future::Future<Output = ()> + Send + 'static { - async move { - let resp = self - .rpc_client - .call(peer, Message::PullStatus, PING_TIMEOUT) - .await; - if let Ok(Message::AdvertiseNodesUp(nodes)) = resp { - let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await; - } - } - } - - async fn pull_config(self: Arc<Self>, peer: Uuid) { - let resp = self - .rpc_client - .call(peer, Message::PullConfig, PING_TIMEOUT) - .await; - if let Ok(Message::AdvertiseConfig(config)) = resp { - let _: Result<_, _> = self.handle_advertise_config(&config).await; - } - } - - async fn update_status(self: &Arc<Self>, updaters: &Updaters, status: Status) { - if status.hash != self.status.borrow().hash { - let mut list = status.to_serializable_membership(&self); - - // Combine with old peer list to make sure no peer is lost - if let Ok(old_list) = self.persist_status.load_async().await { - for pp in old_list { - if !list.iter().any(|np| pp.id == np.id) { - list.push(pp); - } - } - } - - if !list.is_empty() { - info!("Persisting new peer list ({} peers)", list.len()); - self.persist_status - .save_async(&list) - .await - .expect("Unable to persist peer list"); - } - } - - updaters - .update_status - .send(Arc::new(status)) - .expect("Could not update internal membership status"); - } -} diff --git a/src/rpc/ring.rs b/src/rpc/ring.rs index 90db8fd2..7cbab762 100644 --- a/src/rpc/ring.rs +++ b/src/rpc/ring.rs @@ -3,6 +3,8 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryInto; +use netapp::NodeID; + use serde::{Deserialize, Serialize}; use garage_util::data::*; @@ -98,7 +100,7 @@ pub struct Ring { pub config: NetworkConfig, // Internal order of nodes used to make a more compact representation of the ring - nodes: Vec<Uuid>, + nodes: Vec<NodeID>, // The list of entries in the ring ring: Vec<RingEntry>, @@ -260,6 +262,11 @@ impl Ring { }) .collect::<Vec<_>>(); + let nodes = nodes + .iter() + .map(|id| NodeID::from_slice(id.as_slice()).unwrap()) + .collect::<Vec<_>>(); + Self { replication_factor, config, @@ -291,7 +298,7 @@ impl Ring { } /// Walk the ring to find the n servers in which data should be replicated - pub fn get_nodes(&self, position: &Hash, n: usize) -> Vec<Uuid> { + pub fn get_nodes(&self, position: &Hash, n: usize) -> Vec<NodeID> { if self.ring.len() != 1 << PARTITION_BITS { warn!("Ring not yet ready, read/writes will be lost!"); return vec![]; diff --git a/src/rpc/rpc_client.rs b/src/rpc/rpc_client.rs deleted file mode 100644 index 806c7e69..00000000 --- a/src/rpc/rpc_client.rs +++ /dev/null @@ -1,369 +0,0 @@ -//! Contain structs related to making RPCs -use std::borrow::Borrow; -use std::marker::PhantomData; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::time::Duration; - -use arc_swap::ArcSwapOption; -use futures::future::Future; -use futures::stream::futures_unordered::FuturesUnordered; -use futures::stream::StreamExt; -use futures_util::future::FutureExt; -use hyper::client::{Client, HttpConnector}; -use hyper::{Body, Method, Request}; -use tokio::sync::{watch, Semaphore}; - -use garage_util::background::BackgroundRunner; -use garage_util::config::TlsConfig; -use garage_util::data::*; -use garage_util::error::{Error, RpcError}; - -use crate::membership::Status; -use crate::rpc_server::RpcMessage; -use crate::tls_util; - -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); - -/// Strategy to apply when making RPC -#[derive(Copy, Clone)] -pub struct RequestStrategy { - /// Max time to wait for reponse - pub rs_timeout: Duration, - /// Min number of response to consider the request successful - pub rs_quorum: usize, - /// Should requests be dropped after enough response are received - pub rs_interrupt_after_quorum: bool, -} - -impl RequestStrategy { - /// Create a RequestStrategy with default timeout and not interrupting when quorum reached - pub fn with_quorum(quorum: usize) -> Self { - RequestStrategy { - rs_timeout: DEFAULT_TIMEOUT, - rs_quorum: quorum, - rs_interrupt_after_quorum: false, - } - } - /// Set timeout of the strategy - pub fn with_timeout(mut self, timeout: Duration) -> Self { - self.rs_timeout = timeout; - self - } - /// Set if requests can be dropped after quorum has been reached - /// In general true for read requests, and false for write - pub fn interrupt_after_quorum(mut self, interrupt: bool) -> Self { - self.rs_interrupt_after_quorum = interrupt; - self - } -} - -/// Shortcut for a boxed async function taking a message, and resolving to another message or an -/// error -pub type LocalHandlerFn<M> = - Box<dyn Fn(Arc<M>) -> Pin<Box<dyn Future<Output = Result<M, Error>> + Send>> + Send + Sync>; - -/// Client used to send RPC -pub struct RpcClient<M: RpcMessage> { - status: watch::Receiver<Arc<Status>>, - background: Arc<BackgroundRunner>, - - local_handler: ArcSwapOption<(Uuid, LocalHandlerFn<M>)>, - - rpc_addr_client: RpcAddrClient<M>, -} - -impl<M: RpcMessage + 'static> RpcClient<M> { - /// Create a new RpcClient from an address, a job runner, and the status of all RPC servers - pub fn new( - rac: RpcAddrClient<M>, - background: Arc<BackgroundRunner>, - status: watch::Receiver<Arc<Status>>, - ) -> Arc<Self> { - Arc::new(Self { - rpc_addr_client: rac, - background, - status, - local_handler: ArcSwapOption::new(None), - }) - } - - /// Set the local handler, to process RPC to this node without network usage - pub fn set_local_handler<F, Fut>(&self, my_id: Uuid, handler: F) - where - F: Fn(Arc<M>) -> Fut + Send + Sync + 'static, - Fut: Future<Output = Result<M, Error>> + Send + 'static, - { - let handler_arc = Arc::new(handler); - let handler: LocalHandlerFn<M> = Box::new(move |msg| { - let handler_arc2 = handler_arc.clone(); - Box::pin(async move { handler_arc2(msg).await }) - }); - self.local_handler.swap(Some(Arc::new((my_id, handler)))); - } - - /// Get a RPC client to make calls using node's SocketAddr instead of its ID - pub fn by_addr(&self) -> &RpcAddrClient<M> { - &self.rpc_addr_client - } - - /// Make a RPC call - pub async fn call(&self, to: Uuid, msg: M, timeout: Duration) -> Result<M, Error> { - self.call_arc(to, Arc::new(msg), timeout).await - } - - /// Make a RPC call from a message stored in an Arc - pub async fn call_arc(&self, to: Uuid, msg: Arc<M>, timeout: Duration) -> Result<M, Error> { - if let Some(lh) = self.local_handler.load_full() { - let (my_id, local_handler) = lh.as_ref(); - if to.borrow() == my_id { - return local_handler(msg).await; - } - } - let status = self.status.borrow().clone(); - let node_status = match status.nodes.get(&to) { - Some(node_status) => { - if node_status.is_up() { - node_status - } else { - return Err(Error::from(RpcError::NodeDown(to))); - } - } - None => { - return Err(Error::Message(format!( - "Peer ID not found: {:?}", - to.borrow() - ))) - } - }; - match self - .rpc_addr_client - .call(&node_status.addr, msg, timeout) - .await - { - Err(rpc_error) => { - node_status.num_failures.fetch_add(1, Ordering::SeqCst); - Err(Error::from(rpc_error)) - } - Ok(x) => x, - } - } - - /// Make a RPC call to multiple servers, returning a Vec containing each result - pub async fn call_many(&self, to: &[Uuid], msg: M, timeout: Duration) -> Vec<Result<M, Error>> { - let msg = Arc::new(msg); - let mut resp_stream = to - .iter() - .map(|to| self.call_arc(*to, msg.clone(), timeout)) - .collect::<FuturesUnordered<_>>(); - - let mut results = vec![]; - while let Some(resp) = resp_stream.next().await { - results.push(resp); - } - results - } - - /// Make a RPC call to multiple servers, returning either a Vec of responses, or an error if - /// strategy could not be respected due to too many errors - pub async fn try_call_many( - self: &Arc<Self>, - to: &[Uuid], - msg: M, - strategy: RequestStrategy, - ) -> Result<Vec<M>, Error> { - let timeout = strategy.rs_timeout; - - let msg = Arc::new(msg); - let mut resp_stream = to - .to_vec() - .into_iter() - .map(|to| { - let self2 = self.clone(); - let msg = msg.clone(); - async move { self2.call_arc(to, msg, timeout).await } - }) - .collect::<FuturesUnordered<_>>(); - - let mut results = vec![]; - let mut errors = vec![]; - - while let Some(resp) = resp_stream.next().await { - match resp { - Ok(msg) => { - results.push(msg); - if results.len() >= strategy.rs_quorum { - break; - } - } - Err(e) => { - errors.push(e); - } - } - } - - if results.len() >= strategy.rs_quorum { - // Continue requests in background. - // Continue the remaining requests immediately using tokio::spawn - // but enqueue a task in the background runner - // to ensure that the process won't exit until the requests are done - // (if we had just enqueued the resp_stream.collect directly in the background runner, - // the requests might have been put on hold in the background runner's queue, - // in which case they might timeout or otherwise fail) - if !strategy.rs_interrupt_after_quorum { - let wait_finished_fut = tokio::spawn(async move { - resp_stream.collect::<Vec<_>>().await; - }); - self.background.spawn(wait_finished_fut.map(|_| Ok(()))); - } - - Ok(results) - } else { - let errors = errors.iter().map(|e| format!("{}", e)).collect::<Vec<_>>(); - Err(Error::from(RpcError::TooManyErrors(errors))) - } - } -} - -/// Thin wrapper arround an `RpcHttpClient` specifying the path of the request -pub struct RpcAddrClient<M: RpcMessage> { - phantom: PhantomData<M>, - - http_client: Arc<RpcHttpClient>, - path: String, -} - -impl<M: RpcMessage> RpcAddrClient<M> { - /// Create an RpcAddrClient from an HTTP client and the endpoint to reach for RPCs - pub fn new(http_client: Arc<RpcHttpClient>, path: String) -> Self { - Self { - phantom: PhantomData::default(), - http_client, - path, - } - } - - /// Make a RPC - pub async fn call<MB>( - &self, - to_addr: &SocketAddr, - msg: MB, - timeout: Duration, - ) -> Result<Result<M, Error>, RpcError> - where - MB: Borrow<M>, - { - self.http_client - .call(&self.path, to_addr, msg, timeout) - .await - } -} - -/// HTTP client used to make RPCs -pub struct RpcHttpClient { - request_limiter: Semaphore, - method: ClientMethod, -} - -enum ClientMethod { - Http(Client<HttpConnector, hyper::Body>), - Https(Client<tls_util::HttpsConnectorFixedDnsname<HttpConnector>, hyper::Body>), -} - -impl RpcHttpClient { - /// Create a new RpcHttpClient - pub fn new( - max_concurrent_requests: usize, - tls_config: &Option<TlsConfig>, - ) -> Result<Self, Error> { - let method = if let Some(cf) = tls_config { - let ca_certs = tls_util::load_certs(&cf.ca_cert).map_err(|e| { - Error::Message(format!("Failed to open CA certificate file: {:?}", e)) - })?; - let node_certs = tls_util::load_certs(&cf.node_cert) - .map_err(|e| Error::Message(format!("Failed to open certificate file: {:?}", e)))?; - let node_key = tls_util::load_private_key(&cf.node_key) - .map_err(|e| Error::Message(format!("Failed to open private key file: {:?}", e)))?; - - let mut config = rustls::ClientConfig::new(); - - for crt in ca_certs.iter() { - config.root_store.add(crt)?; - } - - config.set_single_client_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; - - let connector = - tls_util::HttpsConnectorFixedDnsname::<HttpConnector>::new(config, "garage"); - - ClientMethod::Https(Client::builder().build(connector)) - } else { - ClientMethod::Http(Client::new()) - }; - Ok(RpcHttpClient { - method, - request_limiter: Semaphore::new(max_concurrent_requests), - }) - } - - /// Make a RPC - async fn call<M, MB>( - &self, - path: &str, - to_addr: &SocketAddr, - msg: MB, - timeout: Duration, - ) -> Result<Result<M, Error>, RpcError> - where - MB: Borrow<M>, - M: RpcMessage, - { - let uri = match self.method { - ClientMethod::Http(_) => format!("http://{}/{}", to_addr, path), - ClientMethod::Https(_) => format!("https://{}/{}", to_addr, path), - }; - - let req = Request::builder() - .method(Method::POST) - .uri(uri) - .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; - - let resp_fut = match &self.method { - ClientMethod::Http(client) => client.request(req).fuse(), - ClientMethod::Https(client) => client.request(req).fuse(), - }; - - trace!("({}) Acquiring request_limiter slot...", path); - let slot = self.request_limiter.acquire().await; - trace!("({}) Got slot, doing request to {}...", path, to_addr); - let resp = tokio::time::timeout(timeout, resp_fut) - .await - .map_err(|e| { - debug!( - "RPC timeout to {}: {}", - to_addr, - debug_serialize(msg.borrow()) - ); - e - })? - .map_err(|e| { - warn!( - "RPC HTTP client error when connecting to {}: {}", - to_addr, e - ); - e - })?; - - let status = resp.status(); - trace!("({}) Request returned, got status {}", path, status); - let body = hyper::body::to_bytes(resp.into_body()).await?; - drop(slot); - - match rmp_serde::decode::from_read::<_, Result<M, String>>(&body[..])? { - Err(e) => Ok(Err(Error::RemoteError(e, status))), - Ok(x) => Ok(Ok(x)), - } - } -} diff --git a/src/rpc/rpc_helper.rs b/src/rpc/rpc_helper.rs new file mode 100644 index 00000000..c9458ee6 --- /dev/null +++ b/src/rpc/rpc_helper.rs @@ -0,0 +1,206 @@ +//! Contain structs related to making RPCs +use std::sync::Arc; +use std::time::Duration; + +use futures::future::join_all; +use futures::stream::futures_unordered::FuturesUnordered; +use futures::stream::StreamExt; +use futures_util::future::FutureExt; +use tokio::select; + +pub use netapp::endpoint::{Endpoint, EndpointHandler, Message}; +use netapp::peering::fullmesh::FullMeshPeeringStrategy; +pub use netapp::proto::*; +pub use netapp::{NetApp, NodeID}; + +use garage_util::background::BackgroundRunner; +use garage_util::error::{Error, RpcError}; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); + +/// Strategy to apply when making RPC +#[derive(Copy, Clone)] +pub struct RequestStrategy { + /// Max time to wait for reponse + pub rs_timeout: Duration, + /// Min number of response to consider the request successful + pub rs_quorum: Option<usize>, + /// Should requests be dropped after enough response are received + pub rs_interrupt_after_quorum: bool, + /// Request priority + pub rs_priority: RequestPriority, +} + +impl RequestStrategy { + /// Create a RequestStrategy with default timeout and not interrupting when quorum reached + pub fn with_priority(prio: RequestPriority) -> Self { + RequestStrategy { + rs_timeout: DEFAULT_TIMEOUT, + rs_quorum: None, + rs_interrupt_after_quorum: false, + rs_priority: prio, + } + } + /// Set quorum to be reached for request + pub fn with_quorum(mut self, quorum: usize) -> Self { + self.rs_quorum = Some(quorum); + self + } + /// Set timeout of the strategy + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.rs_timeout = timeout; + self + } + /// Set if requests can be dropped after quorum has been reached + /// In general true for read requests, and false for write + pub fn interrupt_after_quorum(mut self, interrupt: bool) -> Self { + self.rs_interrupt_after_quorum = interrupt; + self + } +} + +#[derive(Clone)] +pub struct RpcHelper { + pub(crate) fullmesh: Arc<FullMeshPeeringStrategy>, + pub(crate) background: Arc<BackgroundRunner>, +} + +impl RpcHelper { + pub async fn call<M, H>( + &self, + endpoint: &Endpoint<M, H>, + to: NodeID, + msg: M, + strat: RequestStrategy, + ) -> Result<M::Response, Error> + where + M: Message, + H: EndpointHandler<M>, + { + self.call_arc(endpoint, to, Arc::new(msg), strat).await + } + + pub async fn call_arc<M, H>( + &self, + endpoint: &Endpoint<M, H>, + to: NodeID, + msg: Arc<M>, + strat: RequestStrategy, + ) -> Result<M::Response, Error> + where + M: Message, + H: EndpointHandler<M>, + { + select! { + res = endpoint.call(&to, &msg, strat.rs_priority) => Ok(res?), + _ = tokio::time::sleep(strat.rs_timeout) => Err(Error::Rpc(RpcError::Timeout)), + } + } + + pub async fn call_many<M, H>( + &self, + endpoint: &Endpoint<M, H>, + to: &[NodeID], + msg: M, + strat: RequestStrategy, + ) -> Vec<(NodeID, Result<M::Response, Error>)> + where + M: Message, + H: EndpointHandler<M>, + { + let msg = Arc::new(msg); + let resps = join_all( + to.iter() + .map(|to| self.call_arc(endpoint, *to, msg.clone(), strat)), + ) + .await; + to.iter() + .cloned() + .zip(resps.into_iter()) + .collect::<Vec<_>>() + } + + pub async fn broadcast<M, H>( + &self, + endpoint: &Endpoint<M, H>, + msg: M, + strat: RequestStrategy, + ) -> Vec<(NodeID, Result<M::Response, Error>)> + where + M: Message, + H: EndpointHandler<M>, + { + let to = self + .fullmesh + .get_peer_list() + .iter() + .map(|p| p.id) + .collect::<Vec<_>>(); + self.call_many(endpoint, &to[..], msg, strat).await + } + + /// Make a RPC call to multiple servers, returning either a Vec of responses, or an error if + /// strategy could not be respected due to too many errors + pub async fn try_call_many<M, H>( + &self, + endpoint: &Arc<Endpoint<M, H>>, + to: &[NodeID], + msg: M, + strategy: RequestStrategy, + ) -> Result<Vec<M::Response>, Error> + where + M: Message + 'static, + H: EndpointHandler<M> + 'static, + { + let msg = Arc::new(msg); + let mut resp_stream = to + .to_vec() + .into_iter() + .map(|to| { + let self2 = self.clone(); + let msg = msg.clone(); + let endpoint2 = endpoint.clone(); + async move { self2.call_arc(&endpoint2, to, msg, strategy).await } + }) + .collect::<FuturesUnordered<_>>(); + + let mut results = vec![]; + let mut errors = vec![]; + let quorum = strategy.rs_quorum.unwrap_or(to.len()); + + while let Some(resp) = resp_stream.next().await { + match resp { + Ok(msg) => { + results.push(msg); + if results.len() >= quorum { + break; + } + } + Err(e) => { + errors.push(e); + } + } + } + + if results.len() >= quorum { + // Continue requests in background. + // Continue the remaining requests immediately using tokio::spawn + // but enqueue a task in the background runner + // to ensure that the process won't exit until the requests are done + // (if we had just enqueued the resp_stream.collect directly in the background runner, + // the requests might have been put on hold in the background runner's queue, + // in which case they might timeout or otherwise fail) + if !strategy.rs_interrupt_after_quorum { + let wait_finished_fut = tokio::spawn(async move { + resp_stream.collect::<Vec<_>>().await; + }); + self.background.spawn(wait_finished_fut.map(|_| Ok(()))); + } + + Ok(results) + } else { + let errors = errors.iter().map(|e| format!("{}", e)).collect::<Vec<_>>(); + Err(Error::from(RpcError::TooManyErrors(errors))) + } + } +} diff --git a/src/rpc/rpc_server.rs b/src/rpc/rpc_server.rs deleted file mode 100644 index 81361ab9..00000000 --- a/src/rpc/rpc_server.rs +++ /dev/null @@ -1,247 +0,0 @@ -//! Contains structs related to receiving RPCs -use std::collections::HashMap; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::time::Instant; - -use futures::future::Future; -use futures_util::future::*; -use futures_util::stream::*; -use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Request, Response, Server, StatusCode}; -use serde::{Deserialize, Serialize}; -use tokio::net::{TcpListener, TcpStream}; -use tokio_rustls::server::TlsStream; -use tokio_rustls::TlsAcceptor; -use tokio_stream::wrappers::TcpListenerStream; - -use garage_util::config::TlsConfig; -use garage_util::data::*; -use garage_util::error::Error; - -use crate::tls_util; - -/// Trait for messages that can be sent as RPC -pub trait RpcMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {} - -type ResponseFuture = Pin<Box<dyn Future<Output = Result<Response<Body>, Error>> + Send>>; -type Handler = Box<dyn Fn(Request<Body>, SocketAddr) -> ResponseFuture + Send + Sync>; - -/// Structure handling RPCs -pub struct RpcServer { - /// The address the RpcServer will bind - pub bind_addr: SocketAddr, - /// The tls configuration used for RPC - pub tls_config: Option<TlsConfig>, - - handlers: HashMap<String, Handler>, -} - -async fn handle_func<M, F, Fut>( - handler: Arc<F>, - req: Request<Body>, - sockaddr: SocketAddr, - name: Arc<String>, -) -> Result<Response<Body>, Error> -where - M: RpcMessage + 'static, - F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, - Fut: Future<Output = Result<M, Error>> + Send + 'static, -{ - let begin_time = Instant::now(); - let whole_body = hyper::body::to_bytes(req.into_body()).await?; - let msg = rmp_serde::decode::from_read::<_, M>(&whole_body[..])?; - - trace!( - "Request message: {}", - serde_json::to_string(&msg) - .unwrap_or_else(|_| "<json error>".into()) - .chars() - .take(100) - .collect::<String>() - ); - - match handler(msg, sockaddr).await { - Ok(resp) => { - let resp_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Ok(resp))?; - let rpc_duration = (Instant::now() - begin_time).as_millis(); - if rpc_duration > 100 { - debug!("RPC {} ok, took long: {} ms", name, rpc_duration,); - } - Ok(Response::new(Body::from(resp_bytes))) - } - Err(e) => { - let err_str = format!("{}", e); - let rep_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Err(err_str))?; - let mut err_response = Response::new(Body::from(rep_bytes)); - *err_response.status_mut() = match e { - Error::BadRpc(_) => StatusCode::BAD_REQUEST, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; - warn!( - "RPC error ({}): {} ({} ms)", - name, - e, - (Instant::now() - begin_time).as_millis(), - ); - Ok(err_response) - } - } -} - -impl RpcServer { - /// Create a new RpcServer - pub fn new(bind_addr: SocketAddr, tls_config: Option<TlsConfig>) -> Self { - Self { - bind_addr, - tls_config, - handlers: HashMap::new(), - } - } - - /// Add handler handling request made to `name` - pub fn add_handler<M, F, Fut>(&mut self, name: String, handler: F) - where - M: RpcMessage + 'static, - F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, - Fut: Future<Output = Result<M, Error>> + Send + 'static, - { - let name2 = Arc::new(name.clone()); - let handler_arc = Arc::new(handler); - let handler = Box::new(move |req: Request<Body>, sockaddr: SocketAddr| { - let handler2 = handler_arc.clone(); - let b: ResponseFuture = Box::pin(handle_func(handler2, req, sockaddr, name2.clone())); - b - }); - self.handlers.insert(name, handler); - } - - async fn handler( - self: Arc<Self>, - req: Request<Body>, - addr: SocketAddr, - ) -> Result<Response<Body>, Error> { - if req.method() != Method::POST { - let mut bad_request = Response::default(); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); - } - - let path = &req.uri().path()[1..].to_string(); - - let handler = match self.handlers.get(path) { - Some(h) => h, - None => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - return Ok(not_found); - } - }; - - trace!("({}) Handling request", path); - - let resp_waiter = tokio::spawn(handler(req, addr)); - match resp_waiter.await { - Err(err) => { - warn!("Handler await error: {}", err); - let mut ise = Response::default(); - *ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - Ok(ise) - } - Ok(Err(err)) => { - trace!("({}) Request handler failed: {}", path, err); - let mut bad_request = Response::new(Body::from(format!("{}", err))); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - Ok(bad_request) - } - Ok(Ok(resp)) => { - trace!("({}) Request handler succeeded", path); - Ok(resp) - } - } - } - - /// Run the RpcServer - pub async fn run( - self: Arc<Self>, - shutdown_signal: impl Future<Output = ()>, - ) -> Result<(), Error> { - if let Some(tls_config) = self.tls_config.as_ref() { - let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?; - let node_certs = tls_util::load_certs(&tls_config.node_cert)?; - let node_key = tls_util::load_private_key(&tls_config.node_key)?; - - let mut ca_store = rustls::RootCertStore::empty(); - for crt in ca_certs.iter() { - ca_store.add(crt)?; - } - - let mut config = - rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store)); - config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; - let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config))); - - let listener = TcpListener::bind(&self.bind_addr).await?; - let incoming = TcpListenerStream::new(listener).filter_map(|socket| async { - match socket { - Ok(stream) => match tls_acceptor.clone().accept(stream).await { - Ok(x) => Some(Ok::<_, hyper::Error>(x)), - Err(_e) => None, - }, - Err(_) => None, - } - }); - let incoming = hyper::server::accept::from_stream(incoming); - - let self_arc = self.clone(); - let service = make_service_fn(|conn: &TlsStream<TcpStream>| { - let client_addr = conn - .get_ref() - .0 - .peer_addr() - .unwrap_or_else(|_| ([0, 0, 0, 0], 0).into()); - let self_arc = self_arc.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request<Body>| { - self_arc.clone().handler(req, client_addr).map_err(|e| { - warn!("RPC handler error: {}", e); - e - }) - })) - } - }); - - let server = Server::builder(incoming).serve(service); - - let graceful = server.with_graceful_shutdown(shutdown_signal); - info!("RPC server listening on http://{}", self.bind_addr); - - graceful.await?; - } else { - let self_arc = self.clone(); - let service = make_service_fn(move |conn: &AddrStream| { - let client_addr = conn.remote_addr(); - let self_arc = self_arc.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request<Body>| { - self_arc.clone().handler(req, client_addr).map_err(|e| { - warn!("RPC handler error: {}", e); - e - }) - })) - } - }); - - let server = Server::bind(&self.bind_addr).serve(service); - - let graceful = server.with_graceful_shutdown(shutdown_signal); - info!("RPC server listening on http://{}", self.bind_addr); - - graceful.await?; - } - - Ok(()) - } -} diff --git a/src/rpc/system.rs b/src/rpc/system.rs new file mode 100644 index 00000000..7ccec945 --- /dev/null +++ b/src/rpc/system.rs @@ -0,0 +1,363 @@ +//! Module containing structs related to membership management +use std::io::{Read, Write}; +use std::net::SocketAddr; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwap; +use async_trait::async_trait; +use futures::{join, select}; +use futures_util::future::*; +use serde::{Deserialize, Serialize}; +use sodiumoxide::crypto::sign::ed25519; +use tokio::sync::watch; +use tokio::sync::Mutex; + +use netapp::endpoint::{Endpoint, EndpointHandler, Message}; +use netapp::peering::fullmesh::FullMeshPeeringStrategy; +use netapp::proto::*; +use netapp::{NetApp, NetworkKey, NodeID, NodeKey}; + +use garage_util::background::BackgroundRunner; +use garage_util::error::Error; +use garage_util::persister::Persister; +//use garage_util::time::*; + +//use crate::consul::get_consul_nodes; +use crate::ring::*; +use crate::rpc_helper::{RequestStrategy, RpcHelper}; + +const DISCOVERY_INTERVAL: Duration = Duration::from_secs(60); +const PING_TIMEOUT: Duration = Duration::from_secs(2); + +/// RPC endpoint used for calls related to membership +pub const SYSTEM_RPC_PATH: &str = "garage_rpc/membership.rs/SystemRpc"; + +/// RPC messages related to membership +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum SystemRpc { + /// Response to successfull advertisements + Ok, + /// Error response + Error(String), + /// Ask other node its config. Answered with AdvertiseConfig + PullConfig, + /// Advertise Garage status. Answered with another AdvertiseStatus. + /// Exchanged with every node on a regular basis. + AdvertiseStatus(StateInfo), + /// Advertisement of nodes config. Sent spontanously or in response to PullConfig + AdvertiseConfig(NetworkConfig), + /// Get known nodes states + GetKnownNodes, + /// Return known nodes + ReturnKnownNodes(Vec<(NodeID, SocketAddr, bool)>), +} + +impl Message for SystemRpc { + type Response = SystemRpc; +} + +/// This node's membership manager +pub struct System { + /// The id of this node + pub id: NodeID, + + persist_config: Persister<NetworkConfig>, + + state_info: ArcSwap<StateInfo>, + + pub netapp: Arc<NetApp>, + fullmesh: Arc<FullMeshPeeringStrategy>, + pub rpc: RpcHelper, + + system_endpoint: Arc<Endpoint<SystemRpc, System>>, + + rpc_listen_addr: SocketAddr, + bootstrap_peers: Vec<(NodeID, SocketAddr)>, + consul_host: Option<String>, + consul_service_name: Option<String>, + replication_factor: usize, + + /// The ring + pub ring: watch::Receiver<Arc<Ring>>, + update_ring: Mutex<watch::Sender<Arc<Ring>>>, + + /// The job runner of this node + pub background: Arc<BackgroundRunner>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateInfo { + /// Hostname of the node + pub hostname: String, + /// Replication factor configured on the node + pub replication_factor: usize, + /// Configuration version + pub config_version: u64, +} + +fn gen_node_key(metadata_dir: &Path) -> Result<NodeKey, Error> { + let mut id_file = metadata_dir.to_path_buf(); + id_file.push("node_id"); + if id_file.as_path().exists() { + let mut f = std::fs::File::open(id_file.as_path())?; + let mut d = vec![]; + f.read_to_end(&mut d)?; + if d.len() != 64 { + return Err(Error::Message("Corrupt node_id file".to_string())); + } + + let mut key = [0u8; 64]; + key.copy_from_slice(&d[..]); + Ok(NodeKey::from_slice(&key[..]).unwrap()) + } else { + let (key, _) = ed25519::gen_keypair(); + + let mut f = std::fs::File::create(id_file.as_path())?; + f.write_all(&key[..])?; + Ok(NodeKey::from_slice(&key[..]).unwrap()) + } +} + +impl System { + /// Create this node's membership manager + pub fn new( + network_key: NetworkKey, + metadata_dir: PathBuf, + background: Arc<BackgroundRunner>, + replication_factor: usize, + rpc_listen_addr: SocketAddr, + bootstrap_peers: Vec<(NodeID, SocketAddr)>, + consul_host: Option<String>, + consul_service_name: Option<String>, + ) -> Arc<Self> { + let node_key = gen_node_key(&metadata_dir).expect("Unable to read or generate node ID"); + info!("Node public key: {}", hex::encode(&node_key.public_key())); + + let persist_config = Persister::new(&metadata_dir, "network_config"); + + let net_config = match persist_config.load() { + Ok(x) => x, + Err(e) => { + match Persister::<garage_rpc_021::ring::NetworkConfig>::new( + &metadata_dir, + "network_config", + ) + .load() + { + Ok(old_config) => NetworkConfig::migrate_from_021(old_config), + Err(e2) => { + info!( + "No valid previous network configuration stored ({}, {}), starting fresh.", + e, e2 + ); + NetworkConfig::new() + } + } + } + }; + + let state_info = StateInfo { + hostname: gethostname::gethostname() + .into_string() + .unwrap_or_else(|_| "<invalid utf-8>".to_string()), + replication_factor: replication_factor, + config_version: net_config.version, + }; + + let ring = Ring::new(net_config, replication_factor); + let (update_ring, ring) = watch::channel(Arc::new(ring)); + + let netapp = NetApp::new(network_key, node_key); + let fullmesh = FullMeshPeeringStrategy::new(netapp.clone(), bootstrap_peers.clone()); + + let system_endpoint = netapp.endpoint(SYSTEM_RPC_PATH.into()); + + let sys = Arc::new(System { + id: netapp.id.clone(), + persist_config, + state_info: ArcSwap::new(Arc::new(state_info)), + netapp: netapp.clone(), + fullmesh: fullmesh.clone(), + rpc: RpcHelper { + fullmesh: fullmesh.clone(), + background: background.clone(), + }, + system_endpoint, + replication_factor, + rpc_listen_addr, + bootstrap_peers, + consul_host, + consul_service_name, + ring, + update_ring: Mutex::new(update_ring), + background: background.clone(), + }); + sys.system_endpoint.set_handler(sys.clone()); + sys + } + + /// Perform bootstraping, starting the ping loop + pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) { + join!( + self.netapp + .clone() + .listen(self.rpc_listen_addr, None, must_exit.clone()), + self.fullmesh.clone().run(must_exit.clone()), + self.discovery_loop(must_exit.clone()), + ); + } + + // ---- INTERNALS ---- + + /// Save network configuration to disc + async fn save_network_config(self: Arc<Self>) -> Result<(), Error> { + let ring: Arc<Ring> = self.ring.borrow().clone(); + self.persist_config + .save_async(&ring.config) + .await + .expect("Cannot save current cluster configuration"); + Ok(()) + } + + fn update_state_info(&self) { + let mut new_si: StateInfo = self.state_info.load().as_ref().clone(); + + let ring = self.ring.borrow(); + new_si.config_version = ring.config.version; + self.state_info.swap(Arc::new(new_si)); + } + + fn handle_pull_config(&self) -> SystemRpc { + let ring = self.ring.borrow().clone(); + SystemRpc::AdvertiseConfig(ring.config.clone()) + } + + async fn handle_advertise_config( + self: Arc<Self>, + adv: &NetworkConfig, + ) -> Result<SystemRpc, Error> { + let update_ring = self.update_ring.lock().await; + let ring: Arc<Ring> = self.ring.borrow().clone(); + + if adv.version > ring.config.version { + let ring = Ring::new(adv.clone(), self.replication_factor); + update_ring.send(Arc::new(ring))?; + drop(update_ring); + + let self2 = self.clone(); + let adv2 = adv.clone(); + self.background.spawn_cancellable(async move { + self2 + .rpc + .broadcast( + &self2.system_endpoint, + SystemRpc::AdvertiseConfig(adv2), + RequestStrategy::with_priority(PRIO_NORMAL), + ) + .await; + Ok(()) + }); + self.background.spawn(self.clone().save_network_config()); + } + + Ok(SystemRpc::Ok) + } + + async fn discovery_loop(&self, mut stop_signal: watch::Receiver<bool>) { + /* TODO + let consul_config = match (&self.consul_host, &self.consul_service_name) { + (Some(ch), Some(csn)) => Some((ch.clone(), csn.clone())), + _ => None, + }; + */ + + while !*stop_signal.borrow() { + let not_configured = self.ring.borrow().config.members.is_empty(); + let no_peers = self.fullmesh.get_peer_list().len() < self.replication_factor; + let bad_peers = self + .fullmesh + .get_peer_list() + .iter() + .filter(|p| p.is_up()) + .count() != self.ring.borrow().config.members.len(); + + if not_configured || no_peers || bad_peers { + info!("Doing a bootstrap/discovery step (not_configured: {}, no_peers: {}, bad_peers: {})", not_configured, no_peers, bad_peers); + + let ping_list = self.bootstrap_peers.clone(); + + /* + *TODO bring this back: persisted list of peers + if let Ok(peers) = self.persist_status.load_async().await { + ping_list.extend(peers.iter().map(|x| (x.addr, Some(x.id)))); + } + */ + + /* + * TODO bring this back: get peers from consul + if let Some((consul_host, consul_service_name)) = &consul_config { + match get_consul_nodes(consul_host, consul_service_name).await { + Ok(node_list) => { + ping_list.extend(node_list.iter().map(|a| (*a, None))); + } + Err(e) => { + warn!("Could not retrieve node list from Consul: {}", e); + } + } + } + */ + + for (node_id, node_addr) in ping_list { + tokio::spawn(self.netapp.clone().try_connect(node_addr, node_id)); + } + } + + let restart_at = tokio::time::sleep(DISCOVERY_INTERVAL); + select! { + _ = restart_at.fuse() => {}, + _ = stop_signal.changed().fuse() => {}, + } + } + } + + async fn pull_config(self: Arc<Self>, peer: NodeID) { + let resp = self + .rpc + .call( + &self.system_endpoint, + peer, + SystemRpc::PullConfig, + RequestStrategy::with_priority(PRIO_HIGH).with_timeout(PING_TIMEOUT), + ) + .await; + if let Ok(SystemRpc::AdvertiseConfig(config)) = resp { + let _: Result<_, _> = self.handle_advertise_config(&config).await; + } + } +} + +#[async_trait] +impl EndpointHandler<SystemRpc> for System { + async fn handle(self: &Arc<Self>, msg: &SystemRpc, _from: NodeID) -> SystemRpc { + let resp = match msg { + SystemRpc::PullConfig => Ok(self.handle_pull_config()), + SystemRpc::AdvertiseConfig(adv) => self.clone().handle_advertise_config(&adv).await, + SystemRpc::GetKnownNodes => { + let known_nodes = self + .fullmesh + .get_peer_list() + .iter() + .map(|n| (n.id, n.addr, n.is_up())) + .collect::<Vec<_>>(); + Ok(SystemRpc::ReturnKnownNodes(known_nodes)) + } + _ => Err(Error::BadRpc("Unexpected RPC message".to_string())), + }; + match resp { + Ok(r) => r, + Err(e) => SystemRpc::Error(format!("{}", e)), + } + } +} diff --git a/src/rpc/tls_util.rs b/src/rpc/tls_util.rs deleted file mode 100644 index 8189f93b..00000000 --- a/src/rpc/tls_util.rs +++ /dev/null @@ -1,140 +0,0 @@ -use core::future::Future; -use core::task::{Context, Poll}; -use std::pin::Pin; -use std::sync::Arc; -use std::{fs, io}; - -use futures_util::future::*; -use hyper::client::connect::Connection; -use hyper::client::HttpConnector; -use hyper::service::Service; -use hyper::Uri; -use hyper_rustls::MaybeHttpsStream; -use rustls::internal::pemfile; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_rustls::TlsConnector; -use webpki::DNSNameRef; - -use garage_util::error::Error; - -pub fn load_certs(filename: &str) -> Result<Vec<rustls::Certificate>, Error> { - let certfile = fs::File::open(&filename)?; - let mut reader = io::BufReader::new(certfile); - - let certs = pemfile::certs(&mut reader).map_err(|_| { - Error::Message(format!( - "Could not deecode certificates from file: {}", - filename - )) - })?; - - if certs.is_empty() { - return Err(Error::Message(format!( - "Invalid certificate file: {}", - filename - ))); - } - Ok(certs) -} - -pub fn load_private_key(filename: &str) -> Result<rustls::PrivateKey, Error> { - let keydata = fs::read_to_string(filename)?; - - let mut buf1 = keydata.as_bytes(); - let rsa_keys = pemfile::rsa_private_keys(&mut buf1).unwrap_or_default(); - - let mut buf2 = keydata.as_bytes(); - let pkcs8_keys = pemfile::pkcs8_private_keys(&mut buf2).unwrap_or_default(); - - let mut keys = rsa_keys; - keys.extend(pkcs8_keys.into_iter()); - - if keys.len() != 1 { - return Err(Error::Message(format!( - "Invalid private key file: {} ({} private keys)", - filename, - keys.len() - ))); - } - Ok(keys[0].clone()) -} - -// ---- AWFUL COPYPASTA FROM HYPER-RUSTLS connector.rs -// ---- ALWAYS USE `garage` AS HOSTNAME FOR TLS VERIFICATION - -#[derive(Clone)] -pub struct HttpsConnectorFixedDnsname<T> { - http: T, - tls_config: Arc<rustls::ClientConfig>, - fixed_dnsname: &'static str, -} - -type BoxError = Box<dyn std::error::Error + Send + Sync>; - -impl HttpsConnectorFixedDnsname<HttpConnector> { - pub fn new(mut tls_config: rustls::ClientConfig, fixed_dnsname: &'static str) -> Self { - let mut http = HttpConnector::new(); - http.enforce_http(false); - tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - Self { - http, - tls_config: Arc::new(tls_config), - fixed_dnsname, - } - } -} - -impl<T> Service<Uri> for HttpsConnectorFixedDnsname<T> -where - T: Service<Uri>, - T::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin + 'static, - T::Future: Send + 'static, - T::Error: Into<BoxError>, -{ - type Response = MaybeHttpsStream<T::Response>; - type Error = BoxError; - - #[allow(clippy::type_complexity)] - type Future = - Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - match self.http.poll_ready(cx) { - Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Pending => Poll::Pending, - } - } - - fn call(&mut self, dst: Uri) -> Self::Future { - let is_https = dst.scheme_str() == Some("https"); - - if !is_https { - let connecting_future = self.http.call(dst); - - let f = async move { - let tcp = connecting_future.await.map_err(Into::into)?; - - Ok(MaybeHttpsStream::Http(tcp)) - }; - f.boxed() - } else { - let cfg = self.tls_config.clone(); - let connecting_future = self.http.call(dst); - - let dnsname = - DNSNameRef::try_from_ascii_str(self.fixed_dnsname).expect("Invalid fixed dnsname"); - - let f = async move { - let tcp = connecting_future.await.map_err(Into::into)?; - let connector = TlsConnector::from(cfg); - let tls = connector - .connect(dnsname, tcp) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - Ok(MaybeHttpsStream::Https(tls)) - }; - f.boxed() - } - } -} |