diff options
Diffstat (limited to 'src/netapp.rs')
-rw-r--r-- | src/netapp.rs | 216 |
1 files changed, 63 insertions, 153 deletions
diff --git a/src/netapp.rs b/src/netapp.rs index e49f599..8415c58 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -1,43 +1,36 @@ -use std::any::Any; use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; -use std::pin::Pin; use std::sync::{Arc, RwLock}; -use std::time::Instant; - -use std::future::Future; use log::{debug, info}; -use arc_swap::{ArcSwap, ArcSwapOption}; -use bytes::Bytes; +use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::auth; use sodiumoxide::crypto::sign::ed25519; use tokio::net::{TcpListener, TcpStream}; use crate::conn::*; +use crate::endpoint::*; use crate::error::*; -use crate::message::*; use crate::proto::*; use crate::util::*; -type DynMsg = Box<dyn Any + Send + Sync + 'static>; +#[derive(Serialize, Deserialize)] +pub(crate) struct HelloMessage { + pub server_addr: Option<IpAddr>, + pub server_port: u16, +} + +impl Message for HelloMessage { + type Response = (); +} type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>; type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>; -pub(crate) type LocalHandler = - Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>; -pub(crate) type NetHandler = Box< - dyn Fn(NodeID, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> + Sync + Send, ->; - -pub(crate) struct Handler { - pub(crate) local_handler: LocalHandler, - pub(crate) net_handler: NetHandler, -} - /// NetApp is the main class that handles incoming and outgoing connections. /// /// The `request()` method can be used to send a message to any peer to which we have @@ -60,10 +53,12 @@ pub struct NetApp { /// Private key associated with our peer ID pub privkey: ed25519::SecretKey, - server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>, - client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>, + pub(crate) server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>, + pub(crate) client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>, + + pub(crate) endpoints: RwLock<HashMap<String, DynEndpoint>>, + hello_endpoint: ArcSwapOption<Endpoint<HelloMessage, NetApp>>, - pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>, on_connected_handler: ArcSwapOption<OnConnectHandler>, on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>, } @@ -73,44 +68,6 @@ struct ListenParams { public_addr: Option<IpAddr>, } -async fn net_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, bytes: Bytes) -> Vec<u8> -where - M: Message + 'static, - F: Fn(NodeID, M) -> R + Send + Sync + 'static, - R: Future<Output = <M as Message>::Response> + Send + Sync, -{ - debug!( - "Handling message of kind {:08x} from {}", - M::KIND, - hex::encode(remote) - ); - let begin_time = Instant::now(); - let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) { - Ok(msg) => Ok(handler(remote, msg).await), - Err(e) => Err(e.to_string()), - }; - let end_time = Instant::now(); - debug!( - "Request {:08x} from {} handled in {}msec", - M::KIND, - hex::encode(remote), - (end_time - begin_time).as_millis() - ); - rmp_to_vec_all_named(&res).unwrap_or_default() -} - -async fn local_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, msg: DynMsg) -> DynMsg -where - M: Message + 'static, - F: Fn(NodeID, M) -> R + Send + Sync + 'static, - R: Future<Output = <M as Message>::Response> + Send + Sync, -{ - debug!("Handling message of kind {:08x} from ourself", M::KIND); - let msg = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap(); - let res = handler(remote, *msg).await; - Box::new(res) -} - impl NetApp { /// Creates a new instance of NetApp, which can serve either as a full p2p node, /// or just as a passive client. To upgrade to a full p2p node, spawn a listener @@ -126,16 +83,20 @@ impl NetApp { privkey, server_conns: RwLock::new(HashMap::new()), client_conns: RwLock::new(HashMap::new()), - msg_handlers: ArcSwap::new(Arc::new(HashMap::new())), + endpoints: RwLock::new(HashMap::new()), + hello_endpoint: ArcSwapOption::new(None), on_connected_handler: ArcSwapOption::new(None), on_disconnected_handler: ArcSwapOption::new(None), }); - let netapp2 = netapp.clone(); - netapp.add_msg_handler::<HelloMessage, _, _>(move |from: NodeID, msg: HelloMessage| { - netapp2.handle_hello_message(from, msg); - async {} - }); + netapp + .hello_endpoint + .swap(Some(netapp.endpoint("__netapp/netapp.rs/Hello".into()))); + netapp + .hello_endpoint + .load_full() + .unwrap() + .set_handler(netapp.clone()); netapp } @@ -162,40 +123,23 @@ impl NetApp { .store(Some(Arc::new(Box::new(handler)))); } - /// Add a handler for a certain message type. Note that only one handler - /// can be specified for each message type. - /// The handler is an asynchronous function, i.e. a function that returns - /// a future. - pub fn add_msg_handler<M, F, R>(&self, handler: F) + pub fn endpoint<M, H>(self: &Arc<Self>, name: String) -> Arc<Endpoint<M, H>> where M: Message + 'static, - F: Fn(NodeID, M) -> R + Send + Sync + 'static, - R: Future<Output = <M as Message>::Response> + Send + Sync + 'static, + H: EndpointHandler<M> + 'static, { - let handler = Arc::new(handler); - - let handler2 = handler.clone(); - let net_handler = Box::new(move |remote: NodeID, bytes: Bytes| { - let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> = - Box::pin(net_handler_aux(handler2.clone(), remote, bytes)); - fun - }); - - let self_id = self.id; - let local_handler = Box::new(move |msg: DynMsg| { - let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> = - Box::pin(local_handler_aux(handler.clone(), self_id, msg)); - fun - }); - - let funs = Arc::new(Handler { - net_handler, - local_handler, - }); - - let mut handlers = self.msg_handlers.load().as_ref().clone(); - handlers.insert(M::KIND, funs); - self.msg_handlers.store(Arc::new(handlers)); + let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), name.clone())); + let endpoint_arc = EndpointArc(endpoint.clone()); + if self + .endpoints + .write() + .unwrap() + .insert(name.clone(), Box::new(endpoint_arc)) + .is_some() + { + panic!("Redefining endpoint: {}", name); + }; + endpoint } /// Main listening process for our app. This future runs during the whole @@ -318,15 +262,6 @@ impl NetApp { // At this point we know they are a full network member, and not just a client, // and we call the on_connected handler so that the peering strategy knows // we have a new potential peer - fn handle_hello_message(&self, id: NodeID, msg: HelloMessage) { - if let Some(h) = self.on_connected_handler.load().as_ref() { - if let Some(c) = self.server_conns.read().unwrap().get(&id) { - let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip()); - let remote_addr = SocketAddr::new(remote_ip, msg.server_port); - h(id, remote_addr, true); - } - } - } // Called from conn.rs when an incoming connection is closed. // We deregister the connection from server_conns and call the @@ -371,16 +306,19 @@ impl NetApp { if let Some(lp) = self.listen_params.load_full() { let server_addr = lp.public_addr; let server_port = lp.listen_addr.port(); + let hello_endpoint = self.hello_endpoint.load_full().unwrap(); tokio::spawn(async move { - conn.request( - HelloMessage { - server_addr, - server_port, - }, - PRIO_NORMAL, - ) - .await - .log_err("Sending hello message"); + hello_endpoint + .call( + &conn.peer_id, + HelloMessage { + server_addr, + server_port, + }, + PRIO_NORMAL, + ) + .await + .log_err("Sending hello message"); }); } } @@ -404,44 +342,16 @@ impl NetApp { // else case: happens if connection was removed in .disconnect() // in which case on_disconnected_handler was already called } +} - /// Send a message to a remote host to which a client connection is already - /// established, and await their response. The target is the id of the peer we - /// want to send the message to. - /// The priority is an `u8`, with lower numbers meaning highest priority. - pub async fn request<T>( - &self, - target: &NodeID, - rq: T, - prio: RequestPriority, - ) -> Result<<T as Message>::Response, Error> - where - T: Message + 'static, - { - if *target == self.id { - let handler = self.msg_handlers.load().get(&T::KIND).cloned(); - match handler { - None => Err(Error::Message(format!( - "No handler registered for message kind {:08x}", - T::KIND - ))), - Some(h) => { - let local_handler = &h.local_handler; - let res = local_handler(Box::new(rq)).await; - let res_t = (res as Box<dyn Any + 'static>) - .downcast::<<T as Message>::Response>() - .unwrap(); - Ok(*res_t) - } - } - } else { - let conn = self.client_conns.read().unwrap().get(target).cloned(); - match conn { - None => Err(Error::Message(format!( - "Not connected: {}", - hex::encode(target) - ))), - Some(c) => c.request(rq, prio).await, +#[async_trait] +impl EndpointHandler<HelloMessage> for NetApp { + async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) { + if let Some(h) = self.on_connected_handler.load().as_ref() { + if let Some(c) = self.server_conns.read().unwrap().get(&from) { + let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip()); + let remote_addr = SocketAddr::new(remote_ip, msg.server_port); + h(from, remote_addr, true); } } } |