aboutsummaryrefslogblamecommitdiff
path: root/src/membership.rs
blob: e468c9b05c86401554d41e192e95aa4b09b3276a (plain) (tree)
1
2
3
4
5
6
7
8
9


                              
                                   
 
                              

                          
                           




















                                                                          
                                              
                              
 
                                  

 
              
                                                                           

                                                                    
                                   
                                                   
                                                                          







                                                                                                  



                                                                                   
                                                                                             








                                                                                  


                                               


                                                                       
                                               
                                                               
                                           
                                                                         
                                                                                       
                 
                                    



                                                                       

                             
                                           

 


                                                      
                                          
                                                       
                                                       



                                                                






                                                      


                 
                                                  
                                                        


                                                       

                                                                 


                  

                                                                                  
                                                                                                       

                                                                                
         
 















                                                                                                               



                                                  
 
                                               







                                                                            
                                                                  


                                                          

                                                              

                              






                                                                                


                                          


















                                                                                                  

                                          












                                                                                         
                                                          
























                                                                                                                  
                                                                          















                                                                                                                                               

                                                    


                                                                                           


                                                                   










                                                                                                                                        
                                                                   






                                                                                                            


                                                                  





                                        








                                                                                                                     








                                                                    
                                                                                          
                 
         
 
use std::sync::Arc;
use std::collections::HashMap;
use std::time::Duration;
use std::net::{IpAddr, SocketAddr};

use futures::future::join_all;
use hyper::client::Client;
use tokio::sync::RwLock;
use sha2::{Sha256, Digest};

use crate::Config;
use crate::error::Error;
use crate::data::*;
use crate::proto::*;
use crate::rpc::*;

/// Duration of the delay that `ping_loop` arms at the top of each round and
/// awaits at the end of it before starting the next round.
const PING_INTERVAL: Duration = Duration::from_secs(10);
/// Duration handed to every RPC helper called from this file (pings, status
/// and config pulls, broadcasts).
// NOTE(review): presumably a per-call timeout, as the `timeout` parameter name
// of `broadcast` suggests; confirm against `crate::rpc`.
const PING_TIMEOUT: Duration = Duration::from_secs(2);
/// Value `NodeStatus::remaining_ping_attempts` is set to whenever a node is
/// registered or answers a ping. `ping_loop` decrements it on each failed ping
/// and removes the node on the failure that finds it at zero.
const MAX_FAILED_PINGS: usize = 3;

/// Local node state: configuration, identity, RPC client and the
/// lock-protected membership table.
pub struct System {
	/// Local configuration; this file reads `rpc_port` and `bootstrap_peers` from it.
	pub config: Config,
	/// Identifier of the local node. It is sent in every ping and used to skip
	/// ourselves when pinging or broadcasting.
	pub id: UUID,

	/// HTTP client built with `Client::new()`. It is not used directly in this file.
	// NOTE(review): presumably consumed by the helpers in `crate::rpc`; confirm.
	pub rpc_client: Client<hyper::client::HttpConnector, hyper::Body>,

	/// Membership table, guarded by an async read/write lock.
	pub members: RwLock<Members>,
}

/// What this node currently knows about the cluster.
pub struct Members {
	/// Known nodes, keyed by node identifier.
	pub status: HashMap<UUID, NodeStatus>,
	/// SHA-256 digest over the sorted `"<hex id> <addr>\n"` lines of `status`,
	/// as computed by `recalculate_status_hash`. Peers send theirs in pings so
	/// that differing node sets can be detected.
	pub status_hash: Hash,

	/// Latest network configuration known to this node; `config.version` is
	/// compared against the version advertised by peers.
	pub config: NetworkConfig,
}

impl Members {
	/// Registers (or refreshes) the node described by a ping message.
	///
	/// The node's address is built from `ip` and the RPC port announced in
	/// `info`, and its failed-ping budget is reset to `MAX_FAILED_PINGS`.
	///
	/// Returns `true` when the node was unknown or its address changed, i.e.
	/// when `status_hash` is stale. This method does NOT recompute the hash;
	/// callers are expected to call `recalculate_status_hash`.
	fn handle_ping(&mut self, ip: IpAddr, info: &PingMessage) -> bool {
		let addr = SocketAddr::new(ip, info.rpc_port);
		let previous = self.status.insert(info.id,
			NodeStatus{
				addr,
				remaining_ping_attempts: MAX_FAILED_PINGS,
			});
		match previous {
			None => {
				eprintln!("Discovered new node (ping): {}", hex::encode(info.id));
				true
			}
			Some(old) => old.addr != addr,
		}
	}

	/// Registers a node learnt from another peer's advertisement.
	///
	/// An already-known node is left untouched (its address and failed-ping
	/// budget are not updated). Returns `true` only when the node was inserted,
	/// in which case `status_hash` is stale until `recalculate_status_hash`
	/// is called.
	fn handle_advertise_node(&mut self, id: &UUID, addr: &SocketAddr) -> bool {
		// Entry API: a single hash lookup instead of `contains_key` + `insert`.
		match self.status.entry(*id) {
			std::collections::hash_map::Entry::Occupied(_) => false,
			std::collections::hash_map::Entry::Vacant(slot) => {
				eprintln!("Discovered new node (advertisment): {}", hex::encode(id));
				slot.insert(NodeStatus{
					addr: *addr,
					remaining_ping_attempts: MAX_FAILED_PINGS,
				});
				true
			}
		}
	}

	/// Recomputes `status_hash` from the current content of `status` and logs
	/// the node list to stderr.
	///
	/// Nodes are sorted by identifier first so that the digest does not depend
	/// on `HashMap` iteration order. Each node contributes the line
	/// `"<hex id> <addr>\n"` to the SHA-256 input.
	fn recalculate_status_hash(&mut self) {
		let mut nodes = self.status.iter().collect::<Vec<_>>();
		nodes.sort_by_key(|(id, _status)| *id);

		let mut hasher = Sha256::new();
		eprintln!("Current set of pingable nodes: --");
		for (id, status) in nodes {
			// Build the line once; the same bytes are logged and hashed.
			let line = format!("{} {}\n", hex::encode(id), status.addr);
			eprint!("{}", line);
			hasher.input(line);
		}
		eprintln!("END --");
		self.status_hash.copy_from_slice(&hasher.result()[..]);
	}
}

/// Per-node entry of the membership table.
pub struct NodeStatus {
	/// Address the node is pinged at.
	pub addr: SocketAddr,
	/// Failed pings still tolerated before `ping_loop` removes the node.
	/// Reset to `MAX_FAILED_PINGS` each time the node answers a ping.
	pub remaining_ping_attempts: usize,
}


impl System {
	pub fn new(config: Config, id: UUID) -> Self {
		let mut members = Members{
				status: HashMap::new(),
				status_hash: [0u8; 32],
				config: NetworkConfig{
					members: HashMap::new(),
					version: 0,
				},
			};
		members.recalculate_status_hash();
		System{
			config,
			id,
			rpc_client: Client::new(),
			members: RwLock::new(members),
		}
	}

	/// Builds the ping message describing this node: its id, its RPC port, and
	/// the status hash and configuration version it currently holds.
	pub async fn make_ping(&self) -> Message {
		// Snapshot both values under a single read lock.
		let (status_hash, config_version) = {
			let guard = self.members.read().await;
			(guard.status_hash.clone(), guard.config.version)
		};

		let ping = PingMessage{
			id: self.id,
			rpc_port: self.config.rpc_port,
			status_hash,
			config_version,
		};
		Message::Ping(ping)
	}

	/// Sends `msg` to every known node except ourselves via `rpc_call_many`,
	/// passing `timeout` through. The result of the call is discarded.
	pub async fn broadcast(self: Arc<Self>, msg: Message, timeout: Duration) {
		// Collect the recipients in a short scope so the read lock is released
		// before any RPC is issued.
		let recipients = {
			let guard = self.members.read().await;
			let ids = guard.status.keys()
				.filter(|node| **node != self.id)
				.cloned()
				.collect::<Vec<_>>();
			ids
		};
		rpc_call_many(self.clone(), &recipients[..], &msg, None, timeout).await;
	}

	/// Contacts the configured bootstrap peers, then starts the ping loop.
	///
	/// All peers in `config.bootstrap_peers` are pinged concurrently. Every
	/// peer that answers with a ping is registered in the membership table;
	/// other answers and errors are ignored. The status hash is recomputed
	/// once afterwards and `ping_loop` is spawned as a background task.
	pub async fn bootstrap(self: Arc<Self>) {
		let ping_msg = self.make_ping().await;

		let calls = self.config.bootstrap_peers.iter().cloned()
			.map(|peer| {
				let sys = self.clone();
				let msg = &ping_msg;
				async move {
					let resp = rpc_call_addr(sys, &peer, msg, PING_TIMEOUT).await;
					(peer, resp)
				}
			});
		let answers = join_all(calls).await;

		{
			// The write lock is only taken once all answers are in.
			let mut guard = self.members.write().await;
			for (peer, answer) in answers {
				if let Ok(Message::Ping(info)) = answer {
					guard.handle_ping(peer.ip(), &info);
				}
			}
			guard.recalculate_status_hash();
		}

		tokio::spawn(self.ping_loop());
	}

	/// Handles a ping received from `from` and answers with our own ping.
	///
	/// The sender is registered (or refreshed) in the membership table. A
	/// background pull of the sender's node list is spawned when the sender is
	/// new (or changed address) or advertises a different status hash. A pull
	/// of its configuration is spawned when it is new or advertises a higher
	/// configuration version than ours.
	pub async fn handle_ping(self: Arc<Self>,
							 from: &SocketAddr,
							 ping: &PingMessage)
		-> Result<Message, Error>
	{
		// Update the table and snapshot what we need, all under one write lock.
		let (changed, local_hash, local_version) = {
			let mut guard = self.members.write().await;
			let changed = guard.handle_ping(from.ip(), ping);
			if changed {
				guard.recalculate_status_hash();
			}
			(changed, guard.status_hash.clone(), guard.config.version)
		};

		let need_status = changed || local_hash != ping.status_hash;
		if need_status {
			tokio::spawn(self.clone().pull_status(ping.id.clone()));
		}

		let need_config = changed || local_version < ping.config_version;
		if need_config {
			tokio::spawn(self.clone().pull_config(ping.id.clone()));
		}

		Ok(self.make_ping().await)
	}

	/// Answers a status pull with the id and address of every node currently
	/// in the membership table, wrapped in `Message::AdvertiseNodesUp`.
	pub async fn handle_pull_status(&self) -> Result<Message, Error> {
		let guard = self.members.read().await;
		let known = guard.status.iter()
			.map(|(id, st)| AdvertisedNode{
				id: id.clone(),
				addr: st.addr.clone(),
			})
			.collect::<Vec<_>>();
		Ok(Message::AdvertiseNodesUp(known))
	}

	/// Answers a configuration pull with a copy of the network configuration
	/// this node currently holds, wrapped in `Message::AdvertiseConfig`.
	pub async fn handle_pull_config(&self) -> Result<Message, Error> {
		let guard = self.members.read().await;
		let snapshot = guard.config.clone();
		Ok(Message::AdvertiseConfig(snapshot))
	}

	/// Handles a list of nodes advertised by a peer.
	///
	/// For each node that was not known yet, background pulls of its status
	/// and configuration are spawned. If at least one node was new, the status
	/// hash is recomputed and the new nodes are re-broadcast to all known
	/// nodes. Always answers `Message::Ok`.
	pub async fn handle_advertise_nodes_up(self: Arc<Self>,
									   adv: &[AdvertisedNode])
		 -> Result<Message, Error>
	{
		let mut guard = self.members.write().await;

		// Nodes we had never heard of; they are the ones worth forwarding.
		let mut fresh = Vec::new();
		for node in adv {
			if !guard.handle_advertise_node(&node.id, &node.addr) {
				continue;
			}
			tokio::spawn(self.clone().pull_status(node.id.clone()));
			tokio::spawn(self.clone().pull_config(node.id.clone()));
			fresh.push(node.clone());
		}

		if !fresh.is_empty() {
			guard.recalculate_status_hash();
			let announce = Message::AdvertiseNodesUp(fresh);
			tokio::spawn(self.clone().broadcast(announce, PING_TIMEOUT));
		}

		Ok(Message::Ok)
	}

	/// Handles a network configuration advertised by a peer.
	///
	/// The advertised configuration replaces ours only when its version is
	/// strictly greater; in that case it is also re-broadcast to all known
	/// nodes. Always answers `Message::Ok`.
	pub async fn handle_advertise_config(self: Arc<Self>,
										 adv: &NetworkConfig)
		-> Result<Message, Error>
	{
		let mut guard = self.members.write().await;

		let is_newer = adv.version > guard.config.version;
		if is_newer {
			guard.config = adv.clone();
			let announce = Message::AdvertiseConfig(adv.clone());
			tokio::spawn(self.clone().broadcast(announce, PING_TIMEOUT));
		}

		Ok(Message::Ok)
	}

	/// Endless background task that pings every known node (except ourselves)
	/// once per round and maintains the membership table from the answers.
	///
	/// For each node that answers with a ping, the node is refreshed and
	/// status/config pulls are spawned when it looks new or out of sync with us.
	/// For each node that does not (error or any other message), its
	/// `remaining_ping_attempts` counter is decremented; the node is removed on
	/// the failure that finds the counter already at zero.
	/// The status hash is recomputed once per round if anything changed.
	/// This function never returns.
	pub async fn ping_loop(self: Arc<Self>) {
		loop {
			// The delay is created before the round's work and only awaited at the end.
			// NOTE(review): presumably the deadline is fixed at creation, so time
			// spent pinging counts towards PING_INTERVAL; confirm for the tokio version in use.
			let restart_at = tokio::time::delay_for(PING_INTERVAL);
			
			// Snapshot the targets under a read lock, released before any network I/O.
			let members = self.members.read().await;
			let ping_addrs = members.status.iter()
					.filter(|(id, _)| **id != self.id)
					.map(|(id, status)| (id.clone(), status.addr.clone()))
					.collect::<Vec<_>>();
			drop(members);

			// Ping all targets concurrently; each future yields (id, addr, result).
			// `id` is a reference into `ping_addrs`.
			let ping_msg = self.make_ping().await;
			let ping_resps = join_all(
					ping_addrs.iter()
						.map(|(id, addr)| {
							let sys = self.clone();
							let ping_msg_ref = &ping_msg;
							async move {
								(id, addr.clone(), rpc_call_addr(sys, &addr, ping_msg_ref, PING_TIMEOUT).await)
							}
					})).await;

			// All answers are in: apply them under a single write lock.
			let mut members = self.members.write().await;
			let mut has_changes = false;

			for (id, addr, ping_resp) in ping_resps {
				if let Ok(Message::Ping(ping)) = ping_resp {
					// The node answered: refresh it (this also resets its failure counter).
					// It is registered under the id found in the answer (`ping.id`).
					let is_new = members.handle_ping(addr.ip(), &ping);
					if is_new {
						has_changes = true;
					}
					// `status_hash` is only recomputed after this loop, so the comparison
					// uses the hash from before this round's changes.
					if is_new || members.status_hash != ping.status_hash {
						tokio::spawn(self.clone().pull_status(ping.id.clone()));
					}
					if is_new || members.config.version < ping.config_version {
						tokio::spawn(self.clone().pull_config(ping.id.clone()));
					}
				} else {
					// Error, or an answer that is not a ping: count it as a failed ping.
					// A node that vanished from the table meanwhile is treated as having 0 attempts left.
					let remaining_attempts = members.status.get(id).map(|x| x.remaining_ping_attempts).unwrap_or(0);
					if remaining_attempts == 0 {
						eprintln!("Removing node {} after too many failed pings", hex::encode(id));
						members.status.remove(id);
						has_changes = true;
					} else {
						if let Some(st) = members.status.get_mut(id) {
							st.remaining_ping_attempts = remaining_attempts - 1;
						}
					}
				}
			}
			if has_changes {
				members.recalculate_status_hash();
			}
			// Release the write lock before waiting for the next round.
			drop(members);

			restart_at.await
		}
	}

	/// Asks `peer` for its node list and feeds the answer to
	/// `handle_advertise_nodes_up`. RPC errors and any other answer are
	/// silently ignored, as is the handler's result.
	///
	/// This is a plain `fn` returning an explicitly `Send + 'static` future
	/// rather than an `async fn`.
	// NOTE(review): presumably required because `handle_advertise_nodes_up`
	// spawns `pull_status`, which itself awaits `handle_advertise_nodes_up`;
	// confirm before turning this into an `async fn`.
	pub fn pull_status(self: Arc<Self>, peer: UUID) -> impl futures::future::Future<Output=()> + Send + 'static {
		async move {
			let answer = rpc_call(self.clone(), &peer, &Message::PullStatus, PING_TIMEOUT).await;
			match answer {
				Ok(Message::AdvertiseNodesUp(nodes)) => {
					let _ = self.handle_advertise_nodes_up(&nodes).await;
				}
				_ => {}
			}
		}
	}

	/// Asks `peer` for its network configuration and feeds the answer to
	/// `handle_advertise_config`. RPC errors and any other answer are silently
	/// ignored, as is the handler's result.
	pub async fn pull_config(self: Arc<Self>, peer: UUID) {
		let answer = rpc_call(self.clone(), &peer, &Message::PullConfig, PING_TIMEOUT).await;
		match answer {
			Ok(Message::AdvertiseConfig(remote)) => {
				let _ = self.handle_advertise_config(&remote).await;
			}
			_ => {}
		}
	}
}