use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use log::{debug, warn};
use lru::LruCache;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::hash;
use sodiumoxide::crypto::sign::ed25519;
use crate::message::*;
use crate::netapp::*;
use crate::proto::*;
// -- Protocol messages --
/// Request asking the remote node for a list of peers.
///
/// Carries no payload; the declared response type is [`PushMessage`].
#[derive(Serialize, Deserialize)]
struct PullMessage {}

impl Message for PullMessage {
    const KIND: MessageKind = 0x4200_1100;
    type Response = PushMessage;
}
/// Message carrying a list of peers advertised by the sender.
///
/// It is used both as a standalone push and as the reply to a `PullMessage`;
/// as a standalone message it expects no meaningful response (`()`).
#[derive(Serialize, Deserialize)]
struct PushMessage {
    /// Peers advertised by the sender.
    peers: Vec<Peer>,
}

impl Message for PushMessage {
    const KIND: MessageKind = 0x4200_1101;
    type Response = ();
}
// -- Algorithm data structures --
/// 32-byte value that a view slot mixes into peer hashes when ranking peers
/// (see `Peer::cost`); fresh seeds are produced by `rand_seed`.
type Seed = [u8; 32];
/// A node of the network, identified by its public key and reachable at a
/// socket address.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Hash, Serialize, Deserialize)]
struct Peer {
    /// Public key identifying the node.
    id: ed25519::PublicKey,
    /// Address used to connect to the node.
    addr: SocketAddr,
}
/// Ranking value of a peer for a given seed: 40 bytes compared
/// lexicographically, lower being better.
type Cost = [u8; 40];

/// Largest possible cost (every byte set); used as the cost of an empty slot.
const MAX_COST: Cost = [u8::MAX; 40];
impl Peer {
    /// Computes the ranking cost of this peer under `seed`.
    ///
    /// The result is made of five 8-byte chunks, each being the first 8 bytes
    /// of a hash of `seed` followed by some data derived from the address:
    ///
    /// * chunks 0 to 3 use growing prefixes of the IP address bytes
    ///   (1, 2, 3, 4 bytes for IPv4; 2, 3, 4, 5 bytes for IPv6);
    /// * chunk 4 uses the textual form of the full socket address.
    ///
    /// Only `addr` is hashed; `id` does not take part in the cost.
    fn cost(&self, seed: &Seed) -> Cost {
        // Hash state pre-loaded with the seed, shared by all five chunks.
        let mut seeded = hash::State::new();
        seeded.update(&seed[..]);

        // Digest of `seed || data`.
        let digest_of = |data: &[u8]| {
            let mut state = seeded.clone();
            state.update(data);
            state.finalize()
        };

        let mut cost: Cost = [0u8; 40];
        match self.addr {
            SocketAddr::V4(v4addr) => {
                let octets = v4addr.ip().octets();
                // The zip stops after four chunks; the fifth is filled below.
                for (len, out) in (1..=4).zip(cost.chunks_exact_mut(8)) {
                    out.copy_from_slice(&digest_of(&octets[..len])[..8]);
                }
            }
            SocketAddr::V6(v6addr) => {
                let octets = v6addr.ip().octets();
                for (len, out) in (2..=5).zip(cost.chunks_exact_mut(8)) {
                    out.copy_from_slice(&digest_of(&octets[..len])[..8]);
                }
            }
        }

        let addr_text = format!("{}", self.addr);
        cost[32..40].copy_from_slice(&digest_of(addr_text.as_bytes())[..8]);
        cost
    }
}
/// One slot of the view: a seed and the peer currently selected for it.
struct BasaltSlot {
    /// Seed under which candidate peers are ranked for this slot.
    seed: Seed,
    /// Peer currently occupying the slot, if any.
    peer: Option<Peer>,
}
impl BasaltSlot {
    /// Cost of the peer held in this slot under the slot's own seed, or
    /// `MAX_COST` when the slot is empty.
    fn cost(&self) -> Cost {
        match self.peer {
            Some(occupant) => occupant.cost(&self.seed),
            None => MAX_COST,
        }
    }
}
/// The local view: a fixed set of slots, each holding at most one peer.
struct BasaltView {
    /// Index of the next slot whose seed will be regenerated.
    i_reset: usize,
    /// The slots making up the view.
    slots: Vec<BasaltSlot>,
}
impl BasaltView {
fn new(size: usize) -> Self {
let slots = (0..size)
.map(|_| BasaltSlot {
seed: rand_seed(),
peer: None,
})
.collect::<Vec<_>>();
Self { i_reset: 0, slots }
}
fn current_peers(&self) -> HashSet<Peer> {
self.slots
.iter()
.filter(|s| s.peer.is_some())
.map(|s| s.peer.unwrap().clone())
.collect::<HashSet<_>>()
}
fn current_peers_vec(&self) -> Vec<Peer> {
self.current_peers().drain().collect::<Vec<_>>()
}
fn sample(&self, count: usize) -> Vec<Peer> {
let possibles = self
.slots
.iter()
.enumerate()
.filter(|(_i, s)| s.peer.is_some())
.map(|(i, _s)| i)
.collect::<Vec<_>>();
if possibles.len() == 0 {
vec![]
} else {
let mut ret = vec![];
let mut rng = thread_rng();
for _i in 0..count {
let idx = rng.gen_range(0, possibles.len());
ret.push(self.slots[possibles[idx]].peer.unwrap());
}
ret
}
}
fn update_slot(&mut self, i: usize, peers: &[Peer]) {
let mut slot_cost = self.slots[i].cost();
for peer in peers.iter() {
let peer_cost = peer.cost(&self.slots[i].seed);
if self.slots[i].peer.is_none() || peer_cost < slot_cost {
self.slots[i].peer = Some(*peer);
slot_cost = peer_cost;
}
}
}
fn update_all_slots(&mut self, peers: &[Peer]) {
for i in 0..self.slots.len() {
self.update_slot(i, peers);
}
}
fn disconnected(&mut self, id: ed25519::PublicKey) {
let mut cleared_slots = vec![];
for i in 0..self.slots.len() {
if let Some(p) = self.slots[i].peer {
if p.id == id {
self.slots[i].peer = None;
cleared_slots.push(i);
}
}
}
let remaining_peers = self.current_peers_vec();
for i in cleared_slots {
self.update_slot(i, &remaining_peers[..]);
}
}
fn should_try_list(&self, peers: &[Peer]) -> Vec<Peer> {
// Select peers that have lower cost than any of our slots
let mut ret = HashSet::new();
for i in 0..self.slots.len() {
if self.slots[i].peer.is_none() {
return peers.to_vec();
}
let mut min_cost = self.slots[i].cost();
let mut min_peer = None;
for peer in peers.iter() {
if ret.contains(peer) {
continue;
}
let peer_cost = peer.cost(&self.slots[i].seed);
if peer_cost < min_cost {
min_cost = peer_cost;
min_peer = Some(*peer);
}
}
if let Some(p) = min_peer {
ret.insert(p);
if ret.len() == peers.len() {
break;
}
}
}
ret.drain().collect::<Vec<_>>()
}
fn reset_some_slots(&mut self, count: usize) {
for _i in 0..count {
self.slots[self.i_reset].seed = rand_seed();
self.i_reset = (self.i_reset + 1) % self.slots.len();
}
}
}
/// Tunable parameters of the peer-sampling service.
pub struct BasaltParams {
    /// Number of slots in the view.
    pub view_size: usize,
    /// Capacity of the LRU backlog of peers.
    pub cache_size: usize,
    /// Delay between two rounds of pull/push exchanges.
    pub exchange_interval: Duration,
    /// Delay between two rounds of slot resets.
    pub reset_interval: Duration,
    /// Number of slots whose seed is regenerated at each reset round.
    pub reset_count: usize,
}
/// Peer-sampling service state, shared between the network callbacks and the
/// background loops.
pub struct Basalt {
    /// Network layer used to connect to peers and exchange messages.
    netapp: Arc<NetApp>,
    /// Configuration of the service.
    param: BasaltParams,
    /// Peers contacted at startup and retried at every reset round.
    bootstrap_peers: Vec<Peer>,
    /// The current view.
    view: RwLock<BasaltView>,
    /// Peers for which a connection attempt is in flight.
    current_attempts: RwLock<HashSet<Peer>>,
    /// LRU set of peers we managed to connect to; they are retried at every
    /// reset round and dropped after a failed connection attempt.
    backlog: RwLock<LruCache<Peer, ()>>,
}
impl Basalt {
/// Creates the service and hooks it into `netapp`.
///
/// `bootstrap_list` gives the peers to contact first; `param` sizes the view
/// and the backlog and sets the timing of the background loops. The returned
/// instance is registered as the connection/disconnection callback of
/// `netapp` and as the handler of `PullMessage` and `PushMessage`. Nothing
/// runs until [`Basalt::run`] is awaited.
pub fn new(
netapp: Arc<NetApp>,
bootstrap_list: Vec<(ed25519::PublicKey, SocketAddr)>,
param: BasaltParams,
) -> Arc<Self> {
let bootstrap_peers = bootstrap_list
.iter()
.map(|(id, addr)| Peer {
id: *id,
addr: *addr,
})
.collect::<Vec<_>>();
let view = BasaltView::new(param.view_size);
let backlog = LruCache::new(param.cache_size);
let basalt = Arc::new(Self {
netapp: netapp.clone(),
param,
bootstrap_peers,
view: RwLock::new(view),
current_attempts: RwLock::new(HashSet::new()),
backlog: RwLock::new(backlog),
});
// Each callback/handler below owns its own clone of the `Arc`.
let basalt2 = basalt.clone();
netapp.on_connected.store(Some(Arc::new(Box::new(
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
basalt2.on_connected(pk, addr, is_incoming);
},
))));
let basalt2 = basalt.clone();
netapp.on_disconnected.store(Some(Arc::new(Box::new(
move |pk: ed25519::PublicKey, is_incoming: bool| {
basalt2.on_disconnected(pk, is_incoming);
},
))));
// A pull request is answered with the peers currently in our view.
let basalt2 = basalt.clone();
netapp.add_msg_handler::<PullMessage, _, _>(
move |_from: ed25519::PublicKey, _pullmsg: PullMessage| {
let push_msg = basalt2.make_push_message();
async move { push_msg }
},
);
// A pushed peer list is treated as a list of connection candidates.
let basalt2 = basalt.clone();
netapp.add_msg_handler::<PushMessage, _, _>(
move |_from: ed25519::PublicKey, push_msg: PushMessage| {
basalt2.handle_peer_list(&push_msg.peers[..]);
async move { () }
},
);
basalt
}
/// Returns the public keys of `count` peers drawn at random from the view.
///
/// Draws are made with replacement among the occupied slots, so the result
/// may contain duplicates; it is empty when the view holds no peer.
pub fn sample(&self, count: usize) -> Vec<ed25519::PublicKey> {
self.view
.read()
.unwrap()
.sample(count)
.iter()
.map(|p| p.id)
.collect::<Vec<_>>()
}
/// Runs the service: spawns a connection attempt to every bootstrap peer,
/// then drives the pull/push loop and the reset loop concurrently.
///
/// Both loops are infinite, so this future does not complete on its own.
pub async fn run(self: Arc<Self>) {
for peer in self.bootstrap_peers.iter() {
tokio::spawn(self.clone().try_connect(*peer));
}
let pushpull_loop = self.clone().run_pushpull_loop();
let reset_loop = self.run_reset_loop();
tokio::join!(pushpull_loop, reset_loop);
}
/// Every `exchange_interval`, samples two peers from the view, pulls the
/// peer list of the first and pushes ours to the second.
async fn run_pushpull_loop(self: Arc<Self>) {
loop {
tokio::time::delay_for(self.param.exchange_interval).await;
// The sample is either empty (no peer in the view) or of the requested
// size; both entries may designate the same peer.
let peers = self.view.read().unwrap().sample(2);
if peers.len() == 2 {
tokio::spawn(self.clone().do_pull(peers[0].id));
tokio::spawn(self.clone().do_push(peers[1].id));
}
}
}
/// Sends a `PullMessage` to `peer` and processes the peer list it returns.
/// Failures are logged and otherwise ignored.
async fn do_pull(self: Arc<Self>, peer: ed25519::PublicKey) {
match self
.netapp
.request(&peer, PullMessage {}, prio::NORMAL)
.await
{
Ok(resp) => {
self.handle_peer_list(&resp.peers[..]);
}
Err(e) => {
warn!("Error during pull exchange: {}", e);
}
};
}
/// Sends our current peer list to `peer`. Failures are logged and otherwise
/// ignored.
async fn do_push(self: Arc<Self>, peer: ed25519::PublicKey) {
let push_msg = self.make_push_message();
if let Err(e) = self.netapp.request(&peer, push_msg, prio::NORMAL).await {
warn!("Error during push exchange: {}", e);
}
}
/// Builds a `PushMessage` listing the distinct peers currently in the view.
fn make_push_message(&self) -> PushMessage {
let current_peers = self.view.read().unwrap().current_peers_vec();
PushMessage {
peers: current_peers,
}
}
/// Every `reset_interval`: regenerates the seed of `reset_count` slots,
/// re-ranks the peers of the view against the new seeds, closes the
/// connections to peers that fell out of the view, then retries the
/// bootstrap peers and the peers of the backlog.
async fn run_reset_loop(self: Arc<Self>) {
loop {
tokio::time::delay_for(self.param.reset_interval).await;
{
let mut view = self.view.write().unwrap();
let prev_peers = view.current_peers();
let prev_peers_vec = prev_peers.iter().cloned().collect::<Vec<_>>();
view.reset_some_slots(self.param.reset_count);
// Reseeded slots keep their peer until a better-ranked one is offered;
// offer them every peer the view already knows.
view.update_all_slots(&prev_peers_vec[..]);
let new_peers = view.current_peers();
// Release the view lock before touching the connection table.
drop(view);
self.close_all_diff(&prev_peers, &new_peers);
}
// Candidates to retry: all bootstrap peers, plus backlog peers that are
// not already in the bootstrap list.
let mut to_retry_maybe = self.bootstrap_peers.clone();
for (peer, _) in self.backlog.read().unwrap().iter() {
if !self.bootstrap_peers.contains(peer) {
to_retry_maybe.push(*peer);
}
}
self.handle_peer_list(&to_retry_maybe[..]);
}
}
/// Spawns a connection attempt to each peer of `peers` that the view
/// considers worth trying (see `BasaltView::should_try_list`).
fn handle_peer_list(self: &Arc<Self>, peers: &[Peer]) {
let to_connect = self.view.read().unwrap().should_try_list(peers);
for peer in to_connect.iter() {
tokio::spawn(self.clone().try_connect(*peer));
}
}
/// Attempts to connect to `peer`, unless it is already in the view or an
/// attempt to it is already in flight. A peer whose attempt fails is removed
/// from the backlog.
async fn try_connect(self: Arc<Self>, peer: Peer) {
{
// Both guards are confined to this block so that neither is held across
// the `.await` below.
let view = self.view.read().unwrap();
let mut attempts = self.current_attempts.write().unwrap();
if view.slots.iter().any(|x| x.peer == Some(peer)) {
return;
}
if attempts.contains(&peer) {
return;
}
attempts.insert(peer);
}
let res = self.netapp.clone().try_connect(peer.addr, peer.id).await;
debug!("Connection attempt to {}: {:?}", peer.addr, res);
self.current_attempts.write().unwrap().remove(&peer);
if res.is_err() {
self.backlog.write().unwrap().pop(&peer);
}
}
/// Connection callback.
///
/// An incoming connection only makes the remote node a candidate: it goes
/// through `handle_peer_list` like any advertised peer. An outgoing
/// connection is recorded in the backlog and offered to every slot of the
/// view; connections to peers evicted from the view as a result are closed.
fn on_connected(self: &Arc<Self>, pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool) {
if is_incoming {
self.handle_peer_list(&[Peer { id: pk, addr }][..]);
} else {
let peer = Peer { id: pk, addr };
let mut backlog = self.backlog.write().unwrap();
if backlog.get(&peer).is_none() {
backlog.put(peer, ());
}
// Release the backlog lock before taking the view lock.
drop(backlog);
let mut view = self.view.write().unwrap();
let prev_peers = view.current_peers();
view.update_all_slots(&[peer][..]);
let new_peers = view.current_peers();
// Release the view lock before touching the connection table.
drop(view);
self.close_all_diff(&prev_peers, &new_peers);
}
}
/// Disconnection callback: when an outgoing connection ends, removes the
/// peer from the view. Incoming connections are ignored.
fn on_disconnected(&self, pk: ed25519::PublicKey, is_incoming: bool) {
if !is_incoming {
self.view.write().unwrap().disconnected(pk);
}
}
/// Closes the client connection to every peer present in `prev_peers` but
/// absent from `new_peers`. Peers without a client connection are skipped.
fn close_all_diff(&self, prev_peers: &HashSet<Peer>, new_peers: &HashSet<Peer>) {
// NOTE(review): the read guard on `client_conns` is held while `close()`
// is called; confirm `close()` never needs a write lock on that table.
let client_conns = self.netapp.client_conns.read().unwrap();
for peer in prev_peers.iter() {
if !new_peers.contains(peer) {
if let Some(c) = client_conns.get(&peer.id) {
debug!(
"Closing connection to {} ({})",
hex::encode(peer.id),
peer.addr
);
c.close();
}
}
}
}
}
/// Draws a fresh seed from sodiumoxide's random byte generator.
fn rand_seed() -> Seed {
    let mut bytes: Seed = [0; 32];
    sodiumoxide::randombytes::randombytes_into(&mut bytes);
    bytes
}