diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/conn.rs | 5 | ||||
-rw-r--r-- | src/netapp.rs | 152 | ||||
-rw-r--r-- | src/peering/basalt.rs | 40 | ||||
-rw-r--r-- | src/peering/fullmesh.rs | 27 |
4 files changed, 155 insertions, 69 deletions
diff --git a/src/conn.rs b/src/conn.rs index 9b60d2a..df3b7cf 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -107,7 +107,8 @@ impl RecvLoop for ServerConn { let kind = u32::from_be_bytes(kind_bytes); if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) { - let resp = handler(self.peer_pk.clone(), bytes.slice(5..)).await; + let net_handler = &handler.net_handler; + let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await; self.resp_send .send((id, prio, resp)) .log_err("ServerConn recv_handler send resp"); @@ -240,7 +241,7 @@ impl ClientConn { Ok(()) } - pub async fn request<T>( + pub(crate) async fn request<T>( self: Arc<Self>, rq: T, prio: RequestPriority, diff --git a/src/netapp.rs b/src/netapp.rs index 6f174b4..25c3b5a 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -1,3 +1,4 @@ +use std::any::Any; use std::collections::HashMap; use std::net::SocketAddr; use std::pin::Pin; @@ -20,6 +21,18 @@ use crate::message::*; use crate::proto::*; use crate::util::*; +type DynMsg = Box<dyn Any + Send + Sync + 'static>; + +pub(crate) struct Handler { + pub(crate) local_handler: + Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>, + pub(crate) net_handler: Box< + dyn Fn(ed25519::PublicKey, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> + + Sync + + Send, + >, +} + pub struct NetApp { pub listen_addr: SocketAddr, pub netid: auth::Key, @@ -27,29 +40,21 @@ pub struct NetApp { pub privkey: ed25519::SecretKey, pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>, pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>, - pub(crate) msg_handlers: ArcSwap< - HashMap< - MessageKind, - Arc< - dyn Fn( - ed25519::PublicKey, - Bytes, - ) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> - + Sync - + Send, - >, - >, - >, + pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>, pub(crate) on_connected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>, pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>, } -async fn handler_aux<M, F, R>(handler: Arc<F>, remote: ed25519::PublicKey, bytes: Bytes) -> Vec<u8> +async fn net_handler_aux<M, F, R>( + handler: Arc<F>, + remote: ed25519::PublicKey, + bytes: Bytes, +) -> Vec<u8> where M: Message + 'static, F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static, - R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync, + R: Future<Output = <M as Message>::Response> + Send + Sync, { debug!( "Handling message of kind {:08x} from {}", @@ -57,13 +62,28 @@ where hex::encode(remote) ); let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) { - Ok(msg) => handler(remote.clone(), msg).await, - Err(e) => Err(e.into()), + Ok(msg) => Ok(handler(remote, msg).await), + Err(e) => Err(e.to_string()), }; - let res = res.map_err(|e| format!("{}", e)); rmp_to_vec_all_named(&res).unwrap_or(vec![]) } +async fn local_handler_aux<M, F, R>( + handler: Arc<F>, + remote: ed25519::PublicKey, + msg: DynMsg, +) -> DynMsg +where + M: Message + 'static, + F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static, + R: Future<Output = <M as Message>::Response> + Send + Sync, +{ + debug!("Handling message of kind {:08x} from ourself", M::KIND,); + let msg = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap(); + let res = handler(remote, *msg).await; + Box::new(res) +} + impl NetApp { pub fn new( listen_addr: SocketAddr, @@ -87,7 +107,7 @@ impl NetApp { netapp.add_msg_handler::<HelloMessage, _, _>( move |from: ed25519::PublicKey, msg: HelloMessage| { netapp2.handle_hello_message(from, msg); - async { Ok(()) } + async { () } }, ); @@ -98,16 +118,31 @@ impl NetApp { where M: Message + 'static, F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static, - R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync + 'static, + R: Future<Output = <M as Message>::Response> + Send + Sync + 'static, { let handler = Arc::new(handler); - let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| { + + let handler1 = handler.clone(); + let net_handler = Box::new(move |remote: ed25519::PublicKey, bytes: Bytes| { let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> = - Box::pin(handler_aux(handler.clone(), remote, bytes)); + Box::pin(net_handler_aux(handler1.clone(), remote, bytes)); + fun + }); + + let self_id = self.pubkey.clone(); + let local_handler = Box::new(move |msg: DynMsg| { + let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> = + Box::pin(local_handler_aux(handler.clone(), self_id, msg)); fun }); + + let funs = Arc::new(Handler { + net_handler, + local_handler, + }); + let mut handlers = self.msg_handlers.load().as_ref().clone(); - handlers.insert(M::KIND, fun); + handlers.insert(M::KIND, funs); self.msg_handlers.store(Arc::new(handlers)); } @@ -136,23 +171,48 @@ impl NetApp { ip: SocketAddr, pk: ed25519::PublicKey, ) -> Result<(), Error> { + if pk == self.pubkey { + // Don't connect to ourself, we don't care + // but pretend we did + tokio::spawn(async move { + if let Some(h) = self.on_connected.load().as_ref() { + h(pk, ip, false); + } + }); + return Ok(()); + } + + // Don't connect if already connected if self.client_conns.read().unwrap().contains_key(&pk) { return Ok(()); } + let socket = TcpStream::connect(ip).await?; info!("Connected to {}, negotiating handshake...", ip); ClientConn::init(self, socket, pk.clone()).await?; Ok(()) } - pub fn disconnect(self: Arc<Self>, id: &ed25519::PublicKey) { - let conn = self.client_conns.read().unwrap().get(id).cloned(); + pub fn disconnect(self: Arc<Self>, pk: &ed25519::PublicKey) { + if *pk == self.pubkey { + let pk = *pk; + tokio::spawn(async move { + if let Some(h) = self.on_disconnected.load().as_ref() { + h(pk, false); + } + }); + return; + } + + let conn = self.client_conns.read().unwrap().get(pk).cloned(); if let Some(c) = conn { c.close(); } } pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) { + info!("Accepted connection from {}", hex::encode(id)); + let mut conn_list = self.server_conns.write().unwrap(); conn_list.insert(id.clone(), conn); } @@ -167,6 +227,8 @@ impl NetApp { } pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) { + info!("Connection from {} closed", hex::encode(id)); + let mut conn_list = self.server_conns.write().unwrap(); if let Some(c) = conn_list.get(id) { if Arc::ptr_eq(c, &conn) { @@ -180,6 +242,8 @@ impl NetApp { } pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) { + info!("Connection established to {}", hex::encode(id)); + { let mut conn_list = self.client_conns.write().unwrap(); if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) { @@ -200,6 +264,7 @@ impl NetApp { } pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) { + info!("Connection to {} closed", hex::encode(id)); let mut conn_list = self.client_conns.write().unwrap(); if let Some(c) = conn_list.get(id) { if Arc::ptr_eq(c, &conn) { @@ -211,4 +276,41 @@ impl NetApp { } } } + + pub async fn request<T>( + &self, + target: &ed25519::PublicKey, + rq: T, + prio: RequestPriority, + ) -> Result<<T as Message>::Response, Error> + where + T: Message + 'static, + { + if *target == self.pubkey { + let handler = self.msg_handlers.load().get(&T::KIND).cloned(); + match handler { + None => Err(Error::Message(format!( + "No handler registered for message kind {:08x}", + T::KIND + ))), + Some(h) => { + let local_handler = &h.local_handler; + let res = local_handler(Box::new(rq)).await; + let res_t = (res as Box<dyn Any + 'static>) + .downcast::<<T as Message>::Response>() + .unwrap(); + Ok(*res_t) + } + } + } else { + let conn = self.client_conns.read().unwrap().get(target).cloned(); + match conn { + None => Err(Error::Message(format!( + "Not connected: {}", + hex::encode(target) + ))), + Some(c) => c.request(rq, prio).await, + } + } + } } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index be807a8..27461ab 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -11,7 +11,6 @@ use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::hash; use sodiumoxide::crypto::sign::ed25519; -use crate::conn::*; use crate::message::*; use crate::netapp::*; use crate::proto::*; @@ -282,7 +281,7 @@ impl Basalt { netapp.add_msg_handler::<PullMessage, _, _>( move |_from: ed25519::PublicKey, _pullmsg: PullMessage| { let push_msg = basalt2.make_push_message(); - async move { Ok(push_msg) } + async move { push_msg } }, ); @@ -290,7 +289,7 @@ impl Basalt { netapp.add_msg_handler::<PushMessage, _, _>( move |_from: ed25519::PublicKey, push_msg: PushMessage| { basalt2.handle_peer_list(&push_msg.peers[..]); - async move { Ok(()) } + async move { () } }, ); @@ -323,25 +322,18 @@ impl Basalt { let peers = self.view.read().unwrap().sample(2); if peers.len() == 2 { - let (c1, c2) = { - let client_conns = self.netapp.client_conns.read().unwrap(); - ( - client_conns.get(&peers[0].id).cloned(), - client_conns.get(&peers[1].id).cloned(), - ) - }; - if let Some(c) = c1 { - tokio::spawn(self.clone().do_pull(c)); - } - if let Some(c) = c2 { - tokio::spawn(self.clone().do_push(c)); - } + tokio::spawn(self.clone().do_pull(peers[0].id)); + tokio::spawn(self.clone().do_push(peers[1].id)); } } } - async fn do_pull(self: Arc<Self>, peer: Arc<ClientConn>) { - match peer.request(PullMessage {}, prio::NORMAL).await { + async fn do_pull(self: Arc<Self>, peer: ed25519::PublicKey) { + match self + .netapp + .request(&peer, PullMessage {}, prio::NORMAL) + .await + { Ok(resp) => { self.handle_peer_list(&resp.peers[..]); } @@ -351,9 +343,9 @@ impl Basalt { }; } - async fn do_push(self: Arc<Self>, peer: Arc<ClientConn>) { + async fn do_push(self: Arc<Self>, peer: ed25519::PublicKey) { let push_msg = self.make_push_message(); - if let Err(e) = peer.request(push_msg, prio::NORMAL).await { + if let Err(e) = self.netapp.request(&peer, push_msg, prio::NORMAL).await { warn!("Error during push exchange: {}", e); } } @@ -427,7 +419,7 @@ impl Basalt { fn on_connected(self: &Arc<Self>, pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool) { if is_incoming { - self.handle_peer_list(&[Peer{id: pk, addr}][..]); + self.handle_peer_list(&[Peer { id: pk, addr }][..]); } else { let peer = Peer { id: pk, addr }; @@ -460,7 +452,11 @@ impl Basalt { for peer in prev_peers.iter() { if !new_peers.contains(peer) { if let Some(c) = client_conns.get(&peer.id) { - debug!("Closing connection to {} ({})", hex::encode(peer.id), peer.addr); + debug!( + "Closing connection to {} ({})", + hex::encode(peer.id), + peer.addr + ); c.close(); } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index e04beb6..4e9a78d 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -10,7 +10,6 @@ use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::hash; use sodiumoxide::crypto::sign::ed25519; -use crate::conn::*; use crate::message::*; use crate::netapp::*; use crate::proto::*; @@ -162,10 +161,8 @@ impl FullMeshPeeringStrategy { id: ping.id, peer_list_hash: strat2.known_hosts.read().unwrap().hash, }; - async move { - debug!("Ping from {}", hex::encode(&from)); - Ok(ping_resp) - } + debug!("Ping from {}", hex::encode(&from)); + async move { ping_resp } }, ); @@ -175,7 +172,7 @@ impl FullMeshPeeringStrategy { strat2.handle_peer_list(&peer_list.list[..]); let peer_list = KnownHosts::map_into_vec(&strat2.known_hosts.read().unwrap().list); let resp = PeerListMessage { list: peer_list }; - async move { Ok(resp) } + async move { resp } }, ); @@ -260,16 +257,6 @@ impl FullMeshPeeringStrategy { } async fn ping(self: Arc<Self>, id: ed25519::PublicKey) { - let peer = { - match self.netapp.client_conns.read().unwrap().get(&id) { - None => { - warn!("Should ping {}, but no connection", hex::encode(id)); - return; - } - Some(peer) => peer.clone(), - } - }; - let peer_list_hash = self.known_hosts.read().unwrap().hash; let ping_id = self.next_ping_id.fetch_add(1u64, atomic::Ordering::Relaxed); let ping_time = Instant::now(); @@ -284,7 +271,7 @@ impl FullMeshPeeringStrategy { hex::encode(id), ping_time ); - match peer.clone().request(ping_msg, prio::HIGH).await { + match self.netapp.request(&id, ping_msg, prio::HIGH).await { Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e), Ok(ping_resp) => { let resp_time = Instant::now(); @@ -304,16 +291,16 @@ impl FullMeshPeeringStrategy { } } if ping_resp.peer_list_hash != peer_list_hash { - self.exchange_peers(peer).await; + self.exchange_peers(&id).await; } } } } - async fn exchange_peers(self: Arc<Self>, peer: Arc<ClientConn>) { + async fn exchange_peers(self: Arc<Self>, id: &ed25519::PublicKey) { let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); let pex_message = PeerListMessage { list: peer_list }; - match peer.request(pex_message, prio::BACKGROUND).await { + match self.netapp.request(id, pex_message, prio::BACKGROUND).await { Err(e) => warn!("Error doing peer exchange: {}", e), Ok(resp) => { self.handle_peer_list(&resp.list[..]); |