aboutsummaryrefslogtreecommitdiff
path: root/src/rpc/membership.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/rpc/membership.rs')
-rw-r--r--src/rpc/membership.rs692
1 files changed, 692 insertions, 0 deletions
diff --git a/src/rpc/membership.rs b/src/rpc/membership.rs
new file mode 100644
index 00000000..e0509536
--- /dev/null
+++ b/src/rpc/membership.rs
@@ -0,0 +1,692 @@
+use std::collections::HashMap;
+use std::hash::Hash as StdHash;
+use std::hash::Hasher;
+use std::io::{Read, Write};
+use std::net::{IpAddr, SocketAddr};
+use std::path::PathBuf;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+
+use futures::future::join_all;
+use futures::select;
+use futures_util::future::*;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use tokio::prelude::*;
+use tokio::sync::watch;
+use tokio::sync::Mutex;
+
+use crate::background::BackgroundRunner;
+use crate::data::*;
+use crate::error::Error;
+
+use crate::rpc::rpc_client::*;
+use crate::rpc::rpc_server::*;
+
+const PING_INTERVAL: Duration = Duration::from_secs(10);
+const PING_TIMEOUT: Duration = Duration::from_secs(2);
+const MAX_FAILURES_BEFORE_CONSIDERED_DOWN: usize = 5;
+
+pub const MEMBERSHIP_RPC_PATH: &str = "_membership";
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Message {
+ Ok,
+ Ping(PingMessage),
+ PullStatus,
+ PullConfig,
+ AdvertiseNodesUp(Vec<AdvertisedNode>),
+ AdvertiseConfig(NetworkConfig),
+}
+
+impl RpcMessage for Message {}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct PingMessage {
+ pub id: UUID,
+ pub rpc_port: u16,
+
+ pub status_hash: Hash,
+ pub config_version: u64,
+
+ pub state_info: StateInfo,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct AdvertisedNode {
+ pub id: UUID,
+ pub addr: SocketAddr,
+
+ pub is_up: bool,
+ pub last_seen: u64,
+
+ pub state_info: StateInfo,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct NetworkConfig {
+ pub members: HashMap<UUID, NetworkConfigEntry>,
+ pub version: u64,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct NetworkConfigEntry {
+ pub datacenter: String,
+ pub n_tokens: u32,
+ pub tag: String,
+}
+
+pub struct System {
+ pub id: UUID,
+ pub data_dir: PathBuf,
+ pub rpc_local_port: u16,
+
+ pub state_info: StateInfo,
+
+ pub rpc_http_client: Arc<RpcHttpClient>,
+ rpc_client: Arc<RpcClient<Message>>,
+
+ pub status: watch::Receiver<Arc<Status>>,
+ pub ring: watch::Receiver<Arc<Ring>>,
+
+ update_lock: Mutex<(watch::Sender<Arc<Status>>, watch::Sender<Arc<Ring>>)>,
+
+ pub background: Arc<BackgroundRunner>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Status {
+ pub nodes: HashMap<UUID, Arc<StatusEntry>>,
+ pub hash: Hash,
+}
+
+#[derive(Debug)]
+pub struct StatusEntry {
+ pub addr: SocketAddr,
+ pub last_seen: u64,
+ pub num_failures: AtomicUsize,
+ pub state_info: StateInfo,
+}
+
+impl StatusEntry {
+ pub fn is_up(&self) -> bool {
+ self.num_failures.load(Ordering::SeqCst) < MAX_FAILURES_BEFORE_CONSIDERED_DOWN
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StateInfo {
+ pub hostname: String,
+}
+
+#[derive(Clone)]
+pub struct Ring {
+ pub config: NetworkConfig,
+ pub ring: Vec<RingEntry>,
+ pub n_datacenters: usize,
+}
+
+#[derive(Clone, Debug)]
+pub struct RingEntry {
+ pub location: Hash,
+ pub node: UUID,
+ pub datacenter: u64,
+}
+
+impl Status {
+ fn handle_ping(&mut self, ip: IpAddr, info: &PingMessage) -> bool {
+ let addr = SocketAddr::new(ip, info.rpc_port);
+ let old_status = self.nodes.insert(
+ info.id,
+ Arc::new(StatusEntry {
+ addr,
+ last_seen: now_msec(),
+ num_failures: AtomicUsize::from(0),
+ state_info: info.state_info.clone(),
+ }),
+ );
+ match old_status {
+ None => {
+ info!("Newly pingable node: {}", hex::encode(&info.id));
+ true
+ }
+ Some(x) => x.addr != addr,
+ }
+ }
+
+ fn recalculate_hash(&mut self) {
+ let mut nodes = self.nodes.iter().collect::<Vec<_>>();
+ nodes.sort_unstable_by_key(|(id, _status)| *id);
+
+ let mut hasher = Sha256::new();
+ debug!("Current set of pingable nodes: --");
+ for (id, status) in nodes {
+ debug!("{} {}", hex::encode(&id), status.addr);
+ hasher.input(format!("{} {}\n", hex::encode(&id), status.addr));
+ }
+ debug!("END --");
+ self.hash
+ .as_slice_mut()
+ .copy_from_slice(&hasher.result()[..]);
+ }
+}
+
+impl Ring {
+ fn rebuild_ring(&mut self) {
+ let mut new_ring = vec![];
+ let mut datacenters = vec![];
+
+ for (id, config) in self.config.members.iter() {
+ let mut dc_hasher = std::collections::hash_map::DefaultHasher::new();
+ config.datacenter.hash(&mut dc_hasher);
+ let datacenter = dc_hasher.finish();
+
+ if !datacenters.contains(&datacenter) {
+ datacenters.push(datacenter);
+ }
+
+ for i in 0..config.n_tokens {
+ let location = hash(format!("{} {}", hex::encode(&id), i).as_bytes());
+
+ new_ring.push(RingEntry {
+ location: location.into(),
+ node: *id,
+ datacenter,
+ })
+ }
+ }
+
+ new_ring.sort_unstable_by(|x, y| x.location.cmp(&y.location));
+ self.ring = new_ring;
+ self.n_datacenters = datacenters.len();
+
+ // eprintln!("RING: --");
+ // for e in self.ring.iter() {
+ // eprintln!("{:?}", e);
+ // }
+ // eprintln!("END --");
+ }
+
+ pub fn walk_ring(&self, from: &Hash, n: usize) -> Vec<UUID> {
+ if n >= self.config.members.len() {
+ return self.config.members.keys().cloned().collect::<Vec<_>>();
+ }
+
+ let start = match self.ring.binary_search_by(|x| x.location.cmp(from)) {
+ Ok(i) => i,
+ Err(i) => {
+ if i == 0 {
+ self.ring.len() - 1
+ } else {
+ i - 1
+ }
+ }
+ };
+
+ self.walk_ring_from_pos(start, n)
+ }
+
+ fn walk_ring_from_pos(&self, start: usize, n: usize) -> Vec<UUID> {
+ if n >= self.config.members.len() {
+ return self.config.members.keys().cloned().collect::<Vec<_>>();
+ }
+
+ let mut ret = vec![];
+ let mut datacenters = vec![];
+
+ let mut delta = 0;
+ while ret.len() < n {
+ let i = (start + delta) % self.ring.len();
+ delta += 1;
+
+ if !datacenters.contains(&self.ring[i].datacenter) {
+ ret.push(self.ring[i].node);
+ datacenters.push(self.ring[i].datacenter);
+ } else if datacenters.len() == self.n_datacenters && !ret.contains(&self.ring[i].node) {
+ ret.push(self.ring[i].node);
+ }
+ }
+
+ ret
+ }
+}
+
+fn gen_node_id(metadata_dir: &PathBuf) -> Result<UUID, Error> {
+ let mut id_file = metadata_dir.clone();
+ id_file.push("node_id");
+ if id_file.as_path().exists() {
+ let mut f = std::fs::File::open(id_file.as_path())?;
+ let mut d = vec![];
+ f.read_to_end(&mut d)?;
+ if d.len() != 32 {
+ return Err(Error::Message(format!("Corrupt node_id file")));
+ }
+
+ let mut id = [0u8; 32];
+ id.copy_from_slice(&d[..]);
+ Ok(id.into())
+ } else {
+ let id = gen_uuid();
+
+ let mut f = std::fs::File::create(id_file.as_path())?;
+ f.write_all(id.as_slice())?;
+ Ok(id)
+ }
+}
+
+fn read_network_config(metadata_dir: &PathBuf) -> Result<NetworkConfig, Error> {
+ let mut path = metadata_dir.clone();
+ path.push("network_config");
+
+ let mut file = std::fs::OpenOptions::new()
+ .read(true)
+ .open(path.as_path())?;
+
+ let mut net_config_bytes = vec![];
+ file.read_to_end(&mut net_config_bytes)?;
+
+ let net_config = rmp_serde::decode::from_read_ref(&net_config_bytes[..])
+ .expect("Unable to parse network configuration file (has version format changed?).");
+
+ Ok(net_config)
+}
+
+impl System {
+ pub fn new(
+ data_dir: PathBuf,
+ rpc_http_client: Arc<RpcHttpClient>,
+ background: Arc<BackgroundRunner>,
+ rpc_server: &mut RpcServer,
+ ) -> Arc<Self> {
+ let id = gen_node_id(&data_dir).expect("Unable to read or generate node ID");
+ info!("Node ID: {}", hex::encode(&id));
+
+ let net_config = match read_network_config(&data_dir) {
+ Ok(x) => x,
+ Err(e) => {
+ info!(
+ "No valid previous network configuration stored ({}), starting fresh.",
+ e
+ );
+ NetworkConfig {
+ members: HashMap::new(),
+ version: 0,
+ }
+ }
+ };
+ let mut status = Status {
+ nodes: HashMap::new(),
+ hash: Hash::default(),
+ };
+ status.recalculate_hash();
+ let (update_status, status) = watch::channel(Arc::new(status));
+
+ let state_info = StateInfo {
+ hostname: gethostname::gethostname()
+ .into_string()
+ .unwrap_or("<invalid utf-8>".to_string()),
+ };
+
+ let mut ring = Ring {
+ config: net_config,
+ ring: Vec::new(),
+ n_datacenters: 0,
+ };
+ ring.rebuild_ring();
+ let (update_ring, ring) = watch::channel(Arc::new(ring));
+
+ let rpc_path = MEMBERSHIP_RPC_PATH.to_string();
+ let rpc_client = RpcClient::new(
+ RpcAddrClient::<Message>::new(rpc_http_client.clone(), rpc_path.clone()),
+ background.clone(),
+ status.clone(),
+ );
+
+ let sys = Arc::new(System {
+ id,
+ data_dir,
+ rpc_local_port: rpc_server.bind_addr.port(),
+ state_info,
+ rpc_http_client,
+ rpc_client,
+ status,
+ ring,
+ update_lock: Mutex::new((update_status, update_ring)),
+ background,
+ });
+ sys.clone().register_handler(rpc_server, rpc_path);
+ sys
+ }
+
+ fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
+ rpc_server.add_handler::<Message, _, _>(path, move |msg, addr| {
+ let self2 = self.clone();
+ async move {
+ match msg {
+ Message::Ping(ping) => self2.handle_ping(&addr, &ping).await,
+
+ Message::PullStatus => self2.handle_pull_status(),
+ Message::PullConfig => self2.handle_pull_config(),
+ Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await,
+ Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await,
+
+ _ => Err(Error::BadRequest(format!("Unexpected RPC message"))),
+ }
+ }
+ });
+ }
+
+ pub fn rpc_client<M: RpcMessage + 'static>(self: &Arc<Self>, path: &str) -> Arc<RpcClient<M>> {
+ RpcClient::new(
+ RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()),
+ self.background.clone(),
+ self.status.clone(),
+ )
+ }
+
+ async fn save_network_config(self: Arc<Self>) -> Result<(), Error> {
+ let mut path = self.data_dir.clone();
+ path.push("network_config");
+
+ let ring = self.ring.borrow().clone();
+ let data = rmp_to_vec_all_named(&ring.config)?;
+
+ let mut f = tokio::fs::File::create(path.as_path()).await?;
+ f.write_all(&data[..]).await?;
+ Ok(())
+ }
+
+ pub fn make_ping(&self) -> Message {
+ let status = self.status.borrow().clone();
+ let ring = self.ring.borrow().clone();
+ Message::Ping(PingMessage {
+ id: self.id,
+ rpc_port: self.rpc_local_port,
+ status_hash: status.hash,
+ config_version: ring.config.version,
+ state_info: self.state_info.clone(),
+ })
+ }
+
+ pub async fn broadcast(self: Arc<Self>, msg: Message, timeout: Duration) {
+ let status = self.status.borrow().clone();
+ let to = status
+ .nodes
+ .keys()
+ .filter(|x| **x != self.id)
+ .cloned()
+ .collect::<Vec<_>>();
+ self.rpc_client.call_many(&to[..], msg, timeout).await;
+ }
+
+ pub async fn bootstrap(self: Arc<Self>, peers: &[SocketAddr]) {
+ let bootstrap_peers = peers.iter().map(|ip| (*ip, None)).collect::<Vec<_>>();
+ self.clone().ping_nodes(bootstrap_peers).await;
+
+ self.clone()
+ .background
+ .spawn_worker(format!("ping loop"), |stop_signal| {
+ self.ping_loop(stop_signal).map(Ok)
+ })
+ .await;
+ }
+
+ async fn ping_nodes(self: Arc<Self>, peers: Vec<(SocketAddr, Option<UUID>)>) {
+ let ping_msg = self.make_ping();
+ let ping_resps = join_all(peers.iter().map(|(addr, id_option)| {
+ let sys = self.clone();
+ let ping_msg_ref = &ping_msg;
+ async move {
+ (
+ id_option,
+ addr,
+ sys.rpc_client
+ .by_addr()
+ .call(&addr, ping_msg_ref, PING_TIMEOUT)
+ .await,
+ )
+ }
+ }))
+ .await;
+
+ let update_locked = self.update_lock.lock().await;
+ let mut status: Status = self.status.borrow().as_ref().clone();
+ let ring = self.ring.borrow().clone();
+
+ let mut has_changes = false;
+ let mut to_advertise = vec![];
+
+ for (id_option, addr, ping_resp) in ping_resps {
+ if let Ok(Ok(Message::Ping(info))) = ping_resp {
+ let is_new = status.handle_ping(addr.ip(), &info);
+ if is_new {
+ has_changes = true;
+ to_advertise.push(AdvertisedNode {
+ id: info.id,
+ addr: *addr,
+ is_up: true,
+ last_seen: now_msec(),
+ state_info: info.state_info.clone(),
+ });
+ }
+ if is_new || status.hash != info.status_hash {
+ self.background
+ .spawn_cancellable(self.clone().pull_status(info.id).map(Ok));
+ }
+ if is_new || ring.config.version < info.config_version {
+ self.background
+ .spawn_cancellable(self.clone().pull_config(info.id).map(Ok));
+ }
+ } else if let Some(id) = id_option {
+ if let Some(st) = status.nodes.get_mut(id) {
+ st.num_failures.fetch_add(1, Ordering::SeqCst);
+ if !st.is_up() {
+ warn!("Node {:?} seems to be down.", id);
+ if !ring.config.members.contains_key(id) {
+ info!("Removing node {:?} from status (not in config and not responding to pings anymore)", id);
+ drop(st);
+ status.nodes.remove(&id);
+ has_changes = true;
+ }
+ }
+ }
+ }
+ }
+ if has_changes {
+ status.recalculate_hash();
+ }
+ if let Err(e) = update_locked.0.broadcast(Arc::new(status)) {
+ error!("In ping_nodes: could not save status update ({})", e);
+ }
+ drop(update_locked);
+
+ if to_advertise.len() > 0 {
+ self.broadcast(Message::AdvertiseNodesUp(to_advertise), PING_TIMEOUT)
+ .await;
+ }
+ }
+
+ pub async fn handle_ping(
+ self: Arc<Self>,
+ from: &SocketAddr,
+ ping: &PingMessage,
+ ) -> Result<Message, Error> {
+ let update_locked = self.update_lock.lock().await;
+ let mut status: Status = self.status.borrow().as_ref().clone();
+
+ let is_new = status.handle_ping(from.ip(), ping);
+ if is_new {
+ status.recalculate_hash();
+ }
+ let status_hash = status.hash;
+ let config_version = self.ring.borrow().config.version;
+
+ update_locked.0.broadcast(Arc::new(status))?;
+ drop(update_locked);
+
+ if is_new || status_hash != ping.status_hash {
+ self.background
+ .spawn_cancellable(self.clone().pull_status(ping.id).map(Ok));
+ }
+ if is_new || config_version < ping.config_version {
+ self.background
+ .spawn_cancellable(self.clone().pull_config(ping.id).map(Ok));
+ }
+
+ Ok(self.make_ping())
+ }
+
+ pub fn handle_pull_status(&self) -> Result<Message, Error> {
+ let status = self.status.borrow().clone();
+ let mut mem = vec![];
+ for (node, status) in status.nodes.iter() {
+ let state_info = if *node == self.id {
+ self.state_info.clone()
+ } else {
+ status.state_info.clone()
+ };
+ mem.push(AdvertisedNode {
+ id: *node,
+ addr: status.addr,
+ is_up: status.is_up(),
+ last_seen: status.last_seen,
+ state_info,
+ });
+ }
+ Ok(Message::AdvertiseNodesUp(mem))
+ }
+
+ pub fn handle_pull_config(&self) -> Result<Message, Error> {
+ let ring = self.ring.borrow().clone();
+ Ok(Message::AdvertiseConfig(ring.config.clone()))
+ }
+
+ pub async fn handle_advertise_nodes_up(
+ self: Arc<Self>,
+ adv: &[AdvertisedNode],
+ ) -> Result<Message, Error> {
+ let mut to_ping = vec![];
+
+ let update_lock = self.update_lock.lock().await;
+ let mut status: Status = self.status.borrow().as_ref().clone();
+ let mut has_changed = false;
+
+ for node in adv.iter() {
+ if node.id == self.id {
+ // learn our own ip address
+ let self_addr = SocketAddr::new(node.addr.ip(), self.rpc_local_port);
+ let old_self = status.nodes.insert(
+ node.id,
+ Arc::new(StatusEntry {
+ addr: self_addr,
+ last_seen: now_msec(),
+ num_failures: AtomicUsize::from(0),
+ state_info: self.state_info.clone(),
+ }),
+ );
+ has_changed = match old_self {
+ None => true,
+ Some(x) => x.addr != self_addr,
+ };
+ } else {
+ let ping_them = match status.nodes.get(&node.id) {
+ // Case 1: new node
+ None => true,
+ // Case 2: the node might have changed address
+ Some(our_node) => node.is_up && !our_node.is_up() && our_node.addr != node.addr,
+ };
+ if ping_them {
+ to_ping.push((node.addr, Some(node.id)));
+ }
+ }
+ }
+ if has_changed {
+ status.recalculate_hash();
+ }
+ update_lock.0.broadcast(Arc::new(status))?;
+ drop(update_lock);
+
+ if to_ping.len() > 0 {
+ self.background
+ .spawn_cancellable(self.clone().ping_nodes(to_ping).map(Ok));
+ }
+
+ Ok(Message::Ok)
+ }
+
+ pub async fn handle_advertise_config(
+ self: Arc<Self>,
+ adv: &NetworkConfig,
+ ) -> Result<Message, Error> {
+ let update_lock = self.update_lock.lock().await;
+ let mut ring: Ring = self.ring.borrow().as_ref().clone();
+
+ if adv.version > ring.config.version {
+ ring.config = adv.clone();
+ ring.rebuild_ring();
+ update_lock.1.broadcast(Arc::new(ring))?;
+ drop(update_lock);
+
+ self.background.spawn_cancellable(
+ self.clone()
+ .broadcast(Message::AdvertiseConfig(adv.clone()), PING_TIMEOUT)
+ .map(Ok),
+ );
+ self.background.spawn(self.clone().save_network_config());
+ }
+
+ Ok(Message::Ok)
+ }
+
+ pub async fn ping_loop(self: Arc<Self>, mut stop_signal: watch::Receiver<bool>) {
+ loop {
+ let restart_at = tokio::time::delay_for(PING_INTERVAL);
+
+ let status = self.status.borrow().clone();
+ let ping_addrs = status
+ .nodes
+ .iter()
+ .filter(|(id, _)| **id != self.id)
+ .map(|(id, status)| (status.addr, Some(*id)))
+ .collect::<Vec<_>>();
+
+ self.clone().ping_nodes(ping_addrs).await;
+
+ select! {
+ _ = restart_at.fuse() => (),
+ must_exit = stop_signal.recv().fuse() => {
+ match must_exit {
+ None | Some(true) => return,
+ _ => (),
+ }
+ }
+ }
+ }
+ }
+
+ pub fn pull_status(
+ self: Arc<Self>,
+ peer: UUID,
+ ) -> impl futures::future::Future<Output = ()> + Send + 'static {
+ async move {
+ let resp = self
+ .rpc_client
+ .call(peer, Message::PullStatus, PING_TIMEOUT)
+ .await;
+ if let Ok(Message::AdvertiseNodesUp(nodes)) = resp {
+ let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await;
+ }
+ }
+ }
+
+ pub async fn pull_config(self: Arc<Self>, peer: UUID) {
+ let resp = self
+ .rpc_client
+ .call(peer, Message::PullConfig, PING_TIMEOUT)
+ .await;
+ if let Ok(Message::AdvertiseConfig(config)) = resp {
+ let _: Result<_, _> = self.handle_advertise_config(&config).await;
+ }
+ }
+}