aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/conn.rs5
-rw-r--r--src/netapp.rs152
-rw-r--r--src/peering/basalt.rs40
-rw-r--r--src/peering/fullmesh.rs27
4 files changed, 155 insertions, 69 deletions
diff --git a/src/conn.rs b/src/conn.rs
index 9b60d2a..df3b7cf 100644
--- a/src/conn.rs
+++ b/src/conn.rs
@@ -107,7 +107,8 @@ impl RecvLoop for ServerConn {
let kind = u32::from_be_bytes(kind_bytes);
if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) {
- let resp = handler(self.peer_pk.clone(), bytes.slice(5..)).await;
+ let net_handler = &handler.net_handler;
+ let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
self.resp_send
.send((id, prio, resp))
.log_err("ServerConn recv_handler send resp");
@@ -240,7 +241,7 @@ impl ClientConn {
Ok(())
}
- pub async fn request<T>(
+ pub(crate) async fn request<T>(
self: Arc<Self>,
rq: T,
prio: RequestPriority,
diff --git a/src/netapp.rs b/src/netapp.rs
index 6f174b4..25c3b5a 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -1,3 +1,4 @@
+use std::any::Any;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
@@ -20,6 +21,18 @@ use crate::message::*;
use crate::proto::*;
use crate::util::*;
+type DynMsg = Box<dyn Any + Send + Sync + 'static>;
+
+pub(crate) struct Handler {
+ pub(crate) local_handler:
+ Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>,
+ pub(crate) net_handler: Box<
+ dyn Fn(ed25519::PublicKey, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>>
+ + Sync
+ + Send,
+ >,
+}
+
pub struct NetApp {
pub listen_addr: SocketAddr,
pub netid: auth::Key,
@@ -27,29 +40,21 @@ pub struct NetApp {
pub privkey: ed25519::SecretKey,
pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
- pub(crate) msg_handlers: ArcSwap<
- HashMap<
- MessageKind,
- Arc<
- dyn Fn(
- ed25519::PublicKey,
- Bytes,
- ) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>>
- + Sync
- + Send,
- >,
- >,
- >,
+ pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
pub(crate) on_connected:
ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>,
}
-async fn handler_aux<M, F, R>(handler: Arc<F>, remote: ed25519::PublicKey, bytes: Bytes) -> Vec<u8>
+async fn net_handler_aux<M, F, R>(
+ handler: Arc<F>,
+ remote: ed25519::PublicKey,
+ bytes: Bytes,
+) -> Vec<u8>
where
M: Message + 'static,
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
- R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync,
+ R: Future<Output = <M as Message>::Response> + Send + Sync,
{
debug!(
"Handling message of kind {:08x} from {}",
@@ -57,13 +62,28 @@ where
hex::encode(remote)
);
let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
- Ok(msg) => handler(remote.clone(), msg).await,
- Err(e) => Err(e.into()),
+ Ok(msg) => Ok(handler(remote, msg).await),
+ Err(e) => Err(e.to_string()),
};
- let res = res.map_err(|e| format!("{}", e));
rmp_to_vec_all_named(&res).unwrap_or(vec![])
}
+async fn local_handler_aux<M, F, R>(
+ handler: Arc<F>,
+ remote: ed25519::PublicKey,
+ msg: DynMsg,
+) -> DynMsg
+where
+ M: Message + 'static,
+ F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
+ R: Future<Output = <M as Message>::Response> + Send + Sync,
+{
+ debug!("Handling message of kind {:08x} from ourself", M::KIND,);
+ let msg = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap();
+ let res = handler(remote, *msg).await;
+ Box::new(res)
+}
+
impl NetApp {
pub fn new(
listen_addr: SocketAddr,
@@ -87,7 +107,7 @@ impl NetApp {
netapp.add_msg_handler::<HelloMessage, _, _>(
move |from: ed25519::PublicKey, msg: HelloMessage| {
netapp2.handle_hello_message(from, msg);
- async { Ok(()) }
+ async { () }
},
);
@@ -98,16 +118,31 @@ impl NetApp {
where
M: Message + 'static,
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
- R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync + 'static,
+ R: Future<Output = <M as Message>::Response> + Send + Sync + 'static,
{
let handler = Arc::new(handler);
- let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
+
+ let handler1 = handler.clone();
+ let net_handler = Box::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
- Box::pin(handler_aux(handler.clone(), remote, bytes));
+ Box::pin(net_handler_aux(handler1.clone(), remote, bytes));
+ fun
+ });
+
+ let self_id = self.pubkey.clone();
+ let local_handler = Box::new(move |msg: DynMsg| {
+ let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> =
+ Box::pin(local_handler_aux(handler.clone(), self_id, msg));
fun
});
+
+ let funs = Arc::new(Handler {
+ net_handler,
+ local_handler,
+ });
+
let mut handlers = self.msg_handlers.load().as_ref().clone();
- handlers.insert(M::KIND, fun);
+ handlers.insert(M::KIND, funs);
self.msg_handlers.store(Arc::new(handlers));
}
@@ -136,23 +171,48 @@ impl NetApp {
ip: SocketAddr,
pk: ed25519::PublicKey,
) -> Result<(), Error> {
+ if pk == self.pubkey {
+ // Don't connect to ourself, we don't care
+ // but pretend we did
+ tokio::spawn(async move {
+ if let Some(h) = self.on_connected.load().as_ref() {
+ h(pk, ip, false);
+ }
+ });
+ return Ok(());
+ }
+
+ // Don't connect if already connected
if self.client_conns.read().unwrap().contains_key(&pk) {
return Ok(());
}
+
let socket = TcpStream::connect(ip).await?;
info!("Connected to {}, negotiating handshake...", ip);
ClientConn::init(self, socket, pk.clone()).await?;
Ok(())
}
- pub fn disconnect(self: Arc<Self>, id: &ed25519::PublicKey) {
- let conn = self.client_conns.read().unwrap().get(id).cloned();
+ pub fn disconnect(self: Arc<Self>, pk: &ed25519::PublicKey) {
+ if *pk == self.pubkey {
+ let pk = *pk;
+ tokio::spawn(async move {
+ if let Some(h) = self.on_disconnected.load().as_ref() {
+ h(pk, false);
+ }
+ });
+ return;
+ }
+
+ let conn = self.client_conns.read().unwrap().get(pk).cloned();
if let Some(c) = conn {
c.close();
}
}
pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) {
+ info!("Accepted connection from {}", hex::encode(id));
+
let mut conn_list = self.server_conns.write().unwrap();
conn_list.insert(id.clone(), conn);
}
@@ -167,6 +227,8 @@ impl NetApp {
}
pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) {
+ info!("Connection from {} closed", hex::encode(id));
+
let mut conn_list = self.server_conns.write().unwrap();
if let Some(c) = conn_list.get(id) {
if Arc::ptr_eq(c, &conn) {
@@ -180,6 +242,8 @@ impl NetApp {
}
pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) {
+ info!("Connection established to {}", hex::encode(id));
+
{
let mut conn_list = self.client_conns.write().unwrap();
if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) {
@@ -200,6 +264,7 @@ impl NetApp {
}
pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) {
+ info!("Connection to {} closed", hex::encode(id));
let mut conn_list = self.client_conns.write().unwrap();
if let Some(c) = conn_list.get(id) {
if Arc::ptr_eq(c, &conn) {
@@ -211,4 +276,41 @@ impl NetApp {
}
}
}
+
+ pub async fn request<T>(
+ &self,
+ target: &ed25519::PublicKey,
+ rq: T,
+ prio: RequestPriority,
+ ) -> Result<<T as Message>::Response, Error>
+ where
+ T: Message + 'static,
+ {
+ if *target == self.pubkey {
+ let handler = self.msg_handlers.load().get(&T::KIND).cloned();
+ match handler {
+ None => Err(Error::Message(format!(
+ "No handler registered for message kind {:08x}",
+ T::KIND
+ ))),
+ Some(h) => {
+ let local_handler = &h.local_handler;
+ let res = local_handler(Box::new(rq)).await;
+ let res_t = (res as Box<dyn Any + 'static>)
+ .downcast::<<T as Message>::Response>()
+ .unwrap();
+ Ok(*res_t)
+ }
+ }
+ } else {
+ let conn = self.client_conns.read().unwrap().get(target).cloned();
+ match conn {
+ None => Err(Error::Message(format!(
+ "Not connected: {}",
+ hex::encode(target)
+ ))),
+ Some(c) => c.request(rq, prio).await,
+ }
+ }
+ }
}
diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs
index be807a8..27461ab 100644
--- a/src/peering/basalt.rs
+++ b/src/peering/basalt.rs
@@ -11,7 +11,6 @@ use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::hash;
use sodiumoxide::crypto::sign::ed25519;
-use crate::conn::*;
use crate::message::*;
use crate::netapp::*;
use crate::proto::*;
@@ -282,7 +281,7 @@ impl Basalt {
netapp.add_msg_handler::<PullMessage, _, _>(
move |_from: ed25519::PublicKey, _pullmsg: PullMessage| {
let push_msg = basalt2.make_push_message();
- async move { Ok(push_msg) }
+ async move { push_msg }
},
);
@@ -290,7 +289,7 @@ impl Basalt {
netapp.add_msg_handler::<PushMessage, _, _>(
move |_from: ed25519::PublicKey, push_msg: PushMessage| {
basalt2.handle_peer_list(&push_msg.peers[..]);
- async move { Ok(()) }
+ async move { () }
},
);
@@ -323,25 +322,18 @@ impl Basalt {
let peers = self.view.read().unwrap().sample(2);
if peers.len() == 2 {
- let (c1, c2) = {
- let client_conns = self.netapp.client_conns.read().unwrap();
- (
- client_conns.get(&peers[0].id).cloned(),
- client_conns.get(&peers[1].id).cloned(),
- )
- };
- if let Some(c) = c1 {
- tokio::spawn(self.clone().do_pull(c));
- }
- if let Some(c) = c2 {
- tokio::spawn(self.clone().do_push(c));
- }
+ tokio::spawn(self.clone().do_pull(peers[0].id));
+ tokio::spawn(self.clone().do_push(peers[1].id));
}
}
}
- async fn do_pull(self: Arc<Self>, peer: Arc<ClientConn>) {
- match peer.request(PullMessage {}, prio::NORMAL).await {
+ async fn do_pull(self: Arc<Self>, peer: ed25519::PublicKey) {
+ match self
+ .netapp
+ .request(&peer, PullMessage {}, prio::NORMAL)
+ .await
+ {
Ok(resp) => {
self.handle_peer_list(&resp.peers[..]);
}
@@ -351,9 +343,9 @@ impl Basalt {
};
}
- async fn do_push(self: Arc<Self>, peer: Arc<ClientConn>) {
+ async fn do_push(self: Arc<Self>, peer: ed25519::PublicKey) {
let push_msg = self.make_push_message();
- if let Err(e) = peer.request(push_msg, prio::NORMAL).await {
+ if let Err(e) = self.netapp.request(&peer, push_msg, prio::NORMAL).await {
warn!("Error during push exchange: {}", e);
}
}
@@ -427,7 +419,7 @@ impl Basalt {
fn on_connected(self: &Arc<Self>, pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool) {
if is_incoming {
- self.handle_peer_list(&[Peer{id: pk, addr}][..]);
+ self.handle_peer_list(&[Peer { id: pk, addr }][..]);
} else {
let peer = Peer { id: pk, addr };
@@ -460,7 +452,11 @@ impl Basalt {
for peer in prev_peers.iter() {
if !new_peers.contains(peer) {
if let Some(c) = client_conns.get(&peer.id) {
- debug!("Closing connection to {} ({})", hex::encode(peer.id), peer.addr);
+ debug!(
+ "Closing connection to {} ({})",
+ hex::encode(peer.id),
+ peer.addr
+ );
c.close();
}
}
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index e04beb6..4e9a78d 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -10,7 +10,6 @@ use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::hash;
use sodiumoxide::crypto::sign::ed25519;
-use crate::conn::*;
use crate::message::*;
use crate::netapp::*;
use crate::proto::*;
@@ -162,10 +161,8 @@ impl FullMeshPeeringStrategy {
id: ping.id,
peer_list_hash: strat2.known_hosts.read().unwrap().hash,
};
- async move {
- debug!("Ping from {}", hex::encode(&from));
- Ok(ping_resp)
- }
+ debug!("Ping from {}", hex::encode(&from));
+ async move { ping_resp }
},
);
@@ -175,7 +172,7 @@ impl FullMeshPeeringStrategy {
strat2.handle_peer_list(&peer_list.list[..]);
let peer_list = KnownHosts::map_into_vec(&strat2.known_hosts.read().unwrap().list);
let resp = PeerListMessage { list: peer_list };
- async move { Ok(resp) }
+ async move { resp }
},
);
@@ -260,16 +257,6 @@ impl FullMeshPeeringStrategy {
}
async fn ping(self: Arc<Self>, id: ed25519::PublicKey) {
- let peer = {
- match self.netapp.client_conns.read().unwrap().get(&id) {
- None => {
- warn!("Should ping {}, but no connection", hex::encode(id));
- return;
- }
- Some(peer) => peer.clone(),
- }
- };
-
let peer_list_hash = self.known_hosts.read().unwrap().hash;
let ping_id = self.next_ping_id.fetch_add(1u64, atomic::Ordering::Relaxed);
let ping_time = Instant::now();
@@ -284,7 +271,7 @@ impl FullMeshPeeringStrategy {
hex::encode(id),
ping_time
);
- match peer.clone().request(ping_msg, prio::HIGH).await {
+ match self.netapp.request(&id, ping_msg, prio::HIGH).await {
Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e),
Ok(ping_resp) => {
let resp_time = Instant::now();
@@ -304,16 +291,16 @@ impl FullMeshPeeringStrategy {
}
}
if ping_resp.peer_list_hash != peer_list_hash {
- self.exchange_peers(peer).await;
+ self.exchange_peers(&id).await;
}
}
}
}
- async fn exchange_peers(self: Arc<Self>, peer: Arc<ClientConn>) {
+ async fn exchange_peers(self: Arc<Self>, id: &ed25519::PublicKey) {
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
let pex_message = PeerListMessage { list: peer_list };
- match peer.request(pex_message, prio::BACKGROUND).await {
+ match self.netapp.request(id, pex_message, prio::BACKGROUND).await {
Err(e) => warn!("Error doing peer exchange: {}", e),
Ok(resp) => {
self.handle_peer_list(&resp.list[..]);