aboutsummaryrefslogblamecommitdiff
path: root/src/peering/basalt.rs
blob: 21506a78afc08e863ef67f652788048deccee895 (plain) (tree)
1
2
3
4
5
6
7

                              
                           


                             
                             






                                    

                       
                       
                      
                     







                                 








                                    



















                                                                                      
                                    






                                                                
                                                           







                                                                                                     
                                                           






                                                                                                     
                                            





































                                                                                                       
                                               













                                                                
                                         




                                                   
                                                                            

























































































                                                                                   
                                     






                                        

                                                        



























                                                              

                                                                                                 






                                                                      


                                                                 









                                                                                            















                                                                       
                                                                             



                                                                      



                                                                          

         

                                                                                       
                                                                               










                                                                                
                                      
                                                                 













                                                                           
                                                                                   















                                                                                  

                                                                                    
                                                                            



































































































                                                                                                    

                                              

                                                                                                 




                                              

                                                                                 


         




                                                                  
use std::collections::HashSet;
use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use async_trait::async_trait;
use log::{debug, info, trace, warn};
use lru::LruCache;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};

use sodiumoxide::crypto::hash;

use tokio::sync::watch;

use crate::endpoint::*;
use crate::message::*;
use crate::netapp::*;
use crate::NodeID;

// -- Protocol messages --

/// Request for the remote node's current view; the remote answers with a [`PushMessage`].
#[derive(Serialize, Deserialize)]
struct PullMessage {}

impl Message for PullMessage {
	type Response = PushMessage;
}

/// Advertises the sender's current view. Sent unsolicited (push) and as the reply to a
/// [`PullMessage`]; the receiver feeds the list to its connection-candidate logic.
#[derive(Serialize, Deserialize)]
struct PushMessage {
	// Distinct peers currently occupying at least one of the sender's slots.
	peers: Vec<Peer>,
}

impl Message for PushMessage {
	type Response = ();
}

// -- Algorithm data structures --

/// Per-slot random seed (see `rand_seed`) mixed into every cost hash for that slot.
type Seed = [u8; 32];

/// A candidate peer: a node identity together with the socket address used to reach it.
///
/// Both fields take part in `Hash`/`Eq` and in [`Peer::cost`], so the same node seen at
/// two different addresses is treated as two distinct peers.
#[derive(Hash, Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Serialize, Deserialize)]
struct Peer {
	id: NodeID,
	addr: SocketAddr,
}

/// Ranking key of a peer for one slot; lower is better.
///
/// Bytes 0..32 are four 8-byte hashes of successively longer IP prefixes and bytes 32..40
/// hash the full address and node ID (see [`Peer::cost`]). Arrays compare
/// lexicographically, so the comparison is decided by the shortest prefix level on which
/// two peers' hashes differ.
type Cost = [u8; 40];
/// Cost of an empty slot: no real cost compares greater than it.
const MAX_COST: Cost = [0xffu8; 40];

impl Peer {
	/// Computes this peer's ranking key for a slot seeded with `seed` (lower wins).
	///
	/// Every 8-byte component is the first 8 bytes of `hash(seed || data)`:
	/// - bytes 0..32: four levels where `data` is the first 1, 2, 3, 4 bytes of an IPv4
	///   address, or the first 2, 3, 4, 5 bytes of an IPv6 address;
	/// - bytes 32..40: `data` is the text `"<addr> <hex node id>"`.
	fn cost(&self, seed: &Seed) -> Cost {
		// Hash state already fed with the seed; each component starts from a copy of it.
		let seeded = {
			let mut state = hash::State::new();
			state.update(&seed[..]);
			state
		};
		let digest_into = |out: &mut [u8], data: &[u8]| {
			let mut state = seeded;
			state.update(data);
			out.copy_from_slice(&state.finalize()[..8]);
		};

		// IPv4 prefixes start at one byte, IPv6 prefixes at two.
		let (ip_bytes, first_prefix_len) = match self.addr {
			SocketAddr::V4(v4) => (v4.ip().octets().to_vec(), 1),
			SocketAddr::V6(v6) => (v6.ip().octets().to_vec(), 2),
		};

		let mut cost: Cost = [0u8; 40];
		for (level, chunk) in cost[..32].chunks_exact_mut(8).enumerate() {
			digest_into(chunk, &ip_bytes[..first_prefix_len + level]);
		}

		let identity = format!("{} {}", self.addr, hex::encode(self.id));
		digest_into(&mut cost[32..], identity.as_bytes());

		cost
	}
}

/// One position of the view: a seed that ranks candidates, and the peer currently holding it.
struct BasaltSlot {
	seed: Seed,
	// `None` while the slot is empty.
	peer: Option<Peer>,
}

impl BasaltSlot {
	/// Cost of the current occupant under this slot's seed, or `MAX_COST` when empty.
	fn cost(&self) -> Cost {
		match self.peer {
			Some(occupant) => occupant.cost(&self.seed),
			None => MAX_COST,
		}
	}
}

/// The local view: a fixed number of seeded slots, each holding at most one peer.
struct BasaltView {
	// Index of the next slot to reseed; advanced round-robin by `reset_some_slots`.
	i_reset: usize,
	slots: Vec<BasaltSlot>,
}

impl BasaltView {
	/// Creates a view of `size` empty slots, each with its own random seed.
	fn new(size: usize) -> Self {
		let slots = (0..size)
			.map(|_| BasaltSlot {
				seed: rand_seed(),
				peer: None,
			})
			.collect::<Vec<_>>();
		Self { i_reset: 0, slots }
	}

	/// Returns the distinct peers currently occupying at least one slot.
	fn current_peers(&self) -> HashSet<Peer> {
		self.slots
			.iter()
			.filter_map(|s| s.peer)
			.collect::<HashSet<_>>()
	}

	/// Same as [`BasaltView::current_peers`], collected into a `Vec` (order unspecified).
	fn current_peers_vec(&self) -> Vec<Peer> {
		self.current_peers().into_iter().collect::<Vec<_>>()
	}

	/// Draws `count` peers uniformly at random, with replacement, from the occupied slots.
	///
	/// A peer occupying several slots is proportionally more likely to be drawn, and the
	/// result may contain duplicates. Returns an empty vector when no slot is occupied.
	fn sample(&self, count: usize) -> Vec<Peer> {
		let occupied = self.slots.iter().filter_map(|s| s.peer).collect::<Vec<_>>();
		if occupied.is_empty() {
			return vec![];
		}
		let mut rng = thread_rng();
		(0..count)
			.map(|_| occupied[rng.gen_range(0..occupied.len())])
			.collect()
	}

	/// Puts into slot `i` the cheapest of `peers` under that slot's seed, if it beats the
	/// current occupant. An empty slot always takes the first candidate.
	fn update_slot(&mut self, i: usize, peers: &[Peer]) {
		let mut slot_cost = self.slots[i].cost();

		for peer in peers.iter() {
			let peer_cost = peer.cost(&self.slots[i].seed);
			if self.slots[i].peer.is_none() || peer_cost < slot_cost {
				trace!(
					"Best match for slot {}: {}@{} (cost {})",
					i,
					hex::encode(peer.id),
					peer.addr,
					hex::encode(peer_cost)
				);
				self.slots[i].peer = Some(*peer);
				slot_cost = peer_cost;
			}
		}
	}

	/// Runs [`BasaltView::update_slot`] on every slot with the same candidates.
	fn update_all_slots(&mut self, peers: &[Peer]) {
		for i in 0..self.slots.len() {
			self.update_slot(i, peers);
		}
	}

	/// Empties every slot held by node `id`, then refills those slots from the peers
	/// that remain in the view.
	fn disconnected(&mut self, id: NodeID) {
		let mut cleared_slots = vec![];
		for i in 0..self.slots.len() {
			if let Some(p) = self.slots[i].peer {
				if p.id == id {
					self.slots[i].peer = None;
					cleared_slots.push(i);
				}
			}
		}

		let remaining_peers = self.current_peers_vec();

		for i in cleared_slots {
			self.update_slot(i, &remaining_peers[..]);
		}
	}

	/// Returns the candidates worth connecting to.
	///
	/// If any slot is empty, every candidate is returned. Otherwise, for each slot the
	/// candidate that most improves it (and was not already picked for an earlier slot)
	/// is selected; candidates that improve no slot are dropped.
	fn should_try_list(&self, peers: &[Peer]) -> Vec<Peer> {
		if self.slots.iter().any(|s| s.peer.is_none()) {
			return peers.to_vec();
		}

		let mut ret = HashSet::new();
		for slot in self.slots.iter() {
			let mut min_cost = slot.cost();
			let mut min_peer = None;
			for peer in peers.iter() {
				if ret.contains(peer) {
					continue;
				}
				let peer_cost = peer.cost(&slot.seed);
				if peer_cost < min_cost {
					min_cost = peer_cost;
					min_peer = Some(*peer);
				}
			}
			if let Some(p) = min_peer {
				ret.insert(p);
				// Every candidate already selected: nothing left to evaluate.
				if ret.len() == peers.len() {
					break;
				}
			}
		}

		ret.into_iter().collect::<Vec<_>>()
	}

	/// Gives `count` slots (round-robin, starting at `i_reset`) a fresh random seed.
	///
	/// Current occupants are kept; callers re-rank them afterwards. An empty view is left
	/// untouched: indexing into it and the modulo by its length would otherwise panic.
	fn reset_some_slots(&mut self, count: usize) {
		if self.slots.is_empty() {
			return;
		}
		for _i in 0..count {
			trace!("Reset slot {}", self.i_reset);
			self.slots[self.i_reset].seed = rand_seed();
			self.i_reset = (self.i_reset + 1) % self.slots.len();
		}
	}
}

/// Tuning parameters for [`Basalt`].
#[derive(Clone, Debug)]
pub struct BasaltParams {
	/// Number of slots in the view (maximum number of distinct view peers).
	pub view_size: usize,
	/// Capacity of the LRU backlog of previously connected peers.
	pub cache_size: NonZeroUsize,
	/// Delay between two push/pull exchange rounds.
	pub exchange_interval: Duration,
	/// Delay between two slot-reset rounds.
	pub reset_interval: Duration,
	/// Number of slots reseeded on each reset round.
	pub reset_count: usize,
}

/// Peer sampling service built on a [`NetApp`] instance.
///
/// Keeps a fixed-size view of outgoing connections, periodically exchanges views with
/// view peers, and exposes random samples of the view through [`Basalt::sample`].
pub struct Basalt {
	netapp: Arc<NetApp>,
	// Endpoints for the pull (request a view) and push (send our view) exchanges.
	pull_endpoint: Arc<Endpoint<PullMessage, Self>>,
	push_endpoint: Arc<Endpoint<PushMessage, Self>>,

	param: BasaltParams,
	// Configured peers; connected at startup and reconsidered on every reset round.
	bootstrap_peers: Vec<Peer>,

	// Current slot assignment.
	view: RwLock<BasaltView>,
	// Peers with a connection attempt in flight, used to skip duplicate attempts.
	current_attempts: RwLock<HashSet<Peer>>,
	// Peers we connected to (outgoing); failed attempts are removed from it.
	backlog: RwLock<LruCache<Peer, ()>>,
}

impl Basalt {
	/// Creates the service, registers its Pull/Push endpoint handlers and hooks its
	/// connect/disconnect callbacks into `netapp`.
	///
	/// Nothing is connected until [`Basalt::run`] is started.
	// NOTE(review): the endpoint handlers and netapp callbacks hold `Arc` clones of the
	// returned value while it holds `netapp` and the endpoints (a reference cycle);
	// presumably the service is meant to live for the whole process — confirm.
	pub fn new(
		netapp: Arc<NetApp>,
		bootstrap_list: Vec<(NodeID, SocketAddr)>,
		param: BasaltParams,
	) -> Arc<Self> {
		let bootstrap_peers = bootstrap_list
			.iter()
			.map(|(id, addr)| Peer {
				id: *id,
				addr: *addr,
			})
			.collect::<Vec<_>>();

		let view = BasaltView::new(param.view_size);
		let backlog = LruCache::new(param.cache_size);

		let basalt = Arc::new(Self {
			netapp: netapp.clone(),
			pull_endpoint: netapp.endpoint("__netapp/peering/basalt.rs/Pull".into()),
			push_endpoint: netapp.endpoint("__netapp/peering/basalt.rs/Push".into()),
			param,
			bootstrap_peers,
			view: RwLock::new(view),
			current_attempts: RwLock::new(HashSet::new()),
			backlog: RwLock::new(backlog),
		});

		basalt.pull_endpoint.set_handler(basalt.clone());
		basalt.push_endpoint.set_handler(basalt.clone());

		let basalt2 = basalt.clone();
		netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| {
			basalt2.on_connected(id, addr, is_incoming);
		});

		let basalt2 = basalt.clone();
		netapp.on_disconnected(move |id: NodeID, is_incoming: bool| {
			basalt2.on_disconnected(id, is_incoming);
		});

		basalt
	}

	/// Returns the node IDs of `count` peers sampled from the view (with replacement, so
	/// duplicates are possible); empty if the view holds no peer.
	pub fn sample(&self, count: usize) -> Vec<NodeID> {
		self.view
			.read()
			.unwrap()
			.sample(count)
			.iter()
			.map(|p| {
				debug!("KYEV S {}", hex::encode(p.id));
				p.id
			})
			.collect::<Vec<_>>()
	}

	/// Spawns connection attempts to all bootstrap peers, then runs the exchange loop and
	/// the reset loop concurrently until both observe `must_exit`.
	pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
		for peer in self.bootstrap_peers.iter() {
			tokio::spawn(self.clone().try_connect(*peer));
		}

		tokio::join!(
			self.clone().run_pushpull_loop(must_exit.clone()),
			self.clone().run_reset_loop(must_exit.clone()),
		);
	}

	/// Every `exchange_interval`, pulls the view of one sampled peer and pushes our view to
	/// another. `must_exit` is only checked between rounds, after the sleep.
	async fn run_pushpull_loop(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
		while !*must_exit.borrow() {
			tokio::time::sleep(self.param.exchange_interval).await;

			// Sampling is with replacement, so both picks may be the same peer.
			let peers = self.view.read().unwrap().sample(2);
			if peers.len() == 2 {
				tokio::spawn(self.clone().do_pull(peers[0].id));
				tokio::spawn(self.clone().do_push(peers[1].id));
			}
		}
	}

	/// Asks `peer` for its view and treats the returned peers as connection candidates.
	/// Failures are only logged.
	async fn do_pull(self: Arc<Self>, peer: NodeID) {
		match self
			.pull_endpoint
			.call(&peer, PullMessage {}, PRIO_NORMAL)
			.await
		{
			Ok(resp) => {
				self.handle_peer_list(&resp.peers[..]);
				trace!("KYEV PEXi {}", hex::encode(peer));
			}
			Err(e) => {
				warn!("Error during pull exchange: {}", e);
			}
		};
	}

	/// Sends our current view to `peer`. Failures are only logged.
	async fn do_push(self: Arc<Self>, peer: NodeID) {
		let push_msg = self.make_push_message();
		match self.push_endpoint.call(&peer, push_msg, PRIO_NORMAL).await {
			Ok(_) => {
				trace!("KYEV PEXo {}", hex::encode(peer));
			}
			Err(e) => {
				warn!("Error during push exchange: {}", e);
			}
		}
	}

	/// Builds a [`PushMessage`] containing the distinct peers of the current view.
	fn make_push_message(&self) -> PushMessage {
		let current_peers = self.view.read().unwrap().current_peers_vec();
		PushMessage {
			peers: current_peers,
		}
	}

	/// Every `reset_interval`: reseeds `reset_count` slots, re-ranks the peers already in
	/// the view, disconnects peers that lost all their slots, then reconsiders bootstrap
	/// and backlog peers as connection candidates.
	async fn run_reset_loop(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
		while !*must_exit.borrow() {
			tokio::time::sleep(self.param.reset_interval).await;

			{
				debug!("KYEV R {}", self.param.reset_count);

				let mut view = self.view.write().unwrap();
				let prev_peers = view.current_peers();
				let prev_peers_vec = prev_peers.iter().cloned().collect::<Vec<_>>();

				view.reset_some_slots(self.param.reset_count);
				view.update_all_slots(&prev_peers_vec[..]);

				let new_peers = view.current_peers();
				// Release the view lock before disconnecting evicted peers.
				drop(view);

				self.close_all_diff(&prev_peers, &new_peers);
			}

			// Bootstrap peers first, then backlog peers not already listed.
			let mut to_retry_maybe = self.bootstrap_peers.clone();
			for (peer, _) in self.backlog.read().unwrap().iter() {
				if !self.bootstrap_peers.contains(peer) {
					to_retry_maybe.push(*peer);
				}
			}
			self.handle_peer_list(&to_retry_maybe[..]);
		}
	}

	/// Spawns a connection attempt for every candidate the view considers worth trying.
	fn handle_peer_list(self: &Arc<Self>, peers: &[Peer]) {
		let to_connect = self.view.read().unwrap().should_try_list(peers);

		for peer in to_connect.iter() {
			tokio::spawn(self.clone().try_connect(*peer));
		}
	}

	/// Opens an outgoing connection to `peer` unless it is already in the view or an
	/// attempt to it is already in flight. On failure the peer is removed from the backlog;
	/// on success the slots are updated by the `on_connected` callback.
	async fn try_connect(self: Arc<Self>, peer: Peer) {
		// Lock order here is view, then current_attempts. Both guards are dropped at the
		// end of this block, before the `.await` below.
		{
			let view = self.view.read().unwrap();
			let mut attempts = self.current_attempts.write().unwrap();

			if view.slots.iter().any(|x| x.peer == Some(peer)) {
				return;
			}
			if attempts.contains(&peer) {
				return;
			}

			attempts.insert(peer);
		}
		let res = self.netapp.clone().try_connect(peer.addr, peer.id).await;
		trace!("Connection attempt to {}: {:?}", peer.addr, res);

		self.current_attempts.write().unwrap().remove(&peer);

		if res.is_err() {
			self.backlog.write().unwrap().pop(&peer);
		}
	}

	/// Connection callback. An incoming connection only makes the remote a candidate for an
	/// outgoing connection; an outgoing connection is recorded in the backlog and offered
	/// to every slot, and peers that lose all their slots are disconnected.
	// NOTE(review): if the new outgoing peer wins no slot, it is in neither `prev_peers`
	// nor `new_peers`, so its connection is left open — confirm this is intended.
	fn on_connected(self: &Arc<Self>, id: NodeID, addr: SocketAddr, is_incoming: bool) {
		if is_incoming {
			self.handle_peer_list(&[Peer { id, addr }][..]);
		} else {
			info!("KYEV C {} {}", hex::encode(id), addr);
			let peer = Peer { id, addr };

			let mut backlog = self.backlog.write().unwrap();
			if backlog.get(&peer).is_none() {
				backlog.put(peer, ());
			}
			// Release the backlog lock before taking the view lock.
			drop(backlog);

			let mut view = self.view.write().unwrap();
			let prev_peers = view.current_peers();

			view.update_all_slots(&[peer][..]);

			let new_peers = view.current_peers();
			// Release the view lock before disconnecting evicted peers.
			drop(view);

			self.close_all_diff(&prev_peers, &new_peers);
		}
	}

	/// Disconnection callback: only outgoing connections occupy slots, so only those free
	/// (and refill) slots in the view.
	fn on_disconnected(&self, id: NodeID, is_incoming: bool) {
		if !is_incoming {
			info!("KYEV D {}", hex::encode(id));
			self.view.write().unwrap().disconnected(id);
		}
	}

	/// Disconnects every peer present in `prev_peers` but absent from `new_peers`.
	fn close_all_diff(&self, prev_peers: &HashSet<Peer>, new_peers: &HashSet<Peer>) {
		for peer in prev_peers.iter() {
			if !new_peers.contains(peer) {
				self.netapp.disconnect(&peer.id);
			}
		}
	}
}

#[async_trait]
impl EndpointHandler<PullMessage> for Basalt {
	async fn handle(self: &Arc<Self>, _pullmsg: &PullMessage, _from: NodeID) -> PushMessage {
		self.make_push_message()
	}
}

#[async_trait]
impl EndpointHandler<PushMessage> for Basalt {
	async fn handle(self: &Arc<Self>, pushmsg: &PushMessage, _from: NodeID) {
		self.handle_peer_list(&pushmsg.peers[..]);
	}
}

/// Draws a fresh 32-byte slot seed from sodiumoxide's `randombytes` generator.
fn rand_seed() -> Seed {
	let mut bytes: Seed = [0; 32];
	sodiumoxide::randombytes::randombytes_into(&mut bytes);
	bytes
}