diff options
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | src/conn.rs | 90 | ||||
-rw-r--r-- | src/endpoint.rs | 125 | ||||
-rw-r--r-- | src/error.rs | 23 | ||||
-rw-r--r-- | src/lib.rs | 2 | ||||
-rw-r--r-- | src/message.rs | 36 | ||||
-rw-r--r-- | src/netapp.rs | 216 | ||||
-rw-r--r-- | src/peering/fullmesh.rs | 64 | ||||
-rw-r--r-- | src/proto.rs | 36 |
9 files changed, 344 insertions, 250 deletions
@@ -1,6 +1,6 @@ [package] name = "netapp" -version = "0.2.0" +version = "0.3.0" authors = ["Alex Auvolat <alex@adnab.me>"] edition = "2018" license-file = "LICENSE" diff --git a/src/conn.rs b/src/conn.rs index c2c9c8b..64318dc 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::atomic::{self, AtomicBool, AtomicU16}; +use std::sync::atomic::{self, AtomicBool, AtomicU32}; use std::sync::{Arc, Mutex}; use bytes::Bytes; @@ -16,12 +16,22 @@ use async_trait::async_trait; use kuska_handshake::async_std::{handshake_client, handshake_server, BoxStream}; +use crate::endpoint::*; use crate::error::*; -use crate::message::*; use crate::netapp::*; use crate::proto::*; use crate::util::*; +// Request message format (client -> server): +// - u8 priority +// - u8 path length +// - [u8; path length] path +// - [u8; *] data + +// Response message format (server -> client): +// - u8 response code +// - [u8; *] response + pub(crate) struct ServerConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, @@ -99,30 +109,60 @@ impl ServerConn { pub fn close(&self) { self.close_send.send(true).unwrap(); } + + async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { + if bytes.len() < 2 { + return Err(Error::Message("Invalid protocol message".into())); + } + + // byte 0 is the request priority, we don't care here + let path_length = bytes[1] as usize; + if bytes.len() < 2 + path_length { + return Err(Error::Message("Invalid protocol message".into())); + } + + let path = &bytes[2..2 + path_length]; + let path = String::from_utf8(path.to_vec())?; + let data = &bytes[2 + path_length..]; + + let handler_opt = { + let endpoints = self.netapp.endpoints.read().unwrap(); + endpoints.get(&path).map(|e| e.clone_endpoint()) + }; + + if let Some(handler) = handler_opt { + handler.handle(data, self.peer_id).await + } else { + Err(Error::NoHandler) + } + } } impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - async fn recv_handler(self: Arc<Self>, id: u16, bytes: Vec<u8>) { + async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) { trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); + let resp = self.recv_handler_aux(&bytes[..]).await; let prio = bytes[0]; - let mut kind_bytes = [0u8; 4]; - kind_bytes.copy_from_slice(&bytes[1..5]); - let kind = u32::from_be_bytes(kind_bytes); - - if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) { - let net_handler = &handler.net_handler; - let resp = net_handler(self.peer_id, bytes.slice(5..)).await; - self.resp_send - .send(Some((id, prio, resp))) - .log_err("ServerConn recv_handler send resp"); + let mut resp_bytes = vec![]; + match resp { + Ok(rb) => { + resp_bytes.push(0u8); + resp_bytes.extend(&rb[..]); + } + Err(e) => { + resp_bytes.push(e.code()); + } } + + self.resp_send + .send(Some((id, prio, resp_bytes))) + .log_err("ServerConn recv_handler send resp"); } } pub(crate) struct ClientConn { @@ -131,7 +171,7 @@ pub(crate) struct ClientConn { query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>, - next_query_number: AtomicU16, + next_query_number: AtomicU32, inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>, must_exit: AtomicBool, stop_recv_loop: watch::Sender<bool>, @@ -173,7 +213,7 @@ impl ClientConn { let conn = Arc::new(ClientConn { remote_addr, peer_id, - next_query_number: AtomicU16::from(0u16), + next_query_number: AtomicU32::from(RequestID::default()), query_send, inflight: Mutex::new(HashMap::new()), must_exit: AtomicBool::new(false), @@ -212,9 +252,10 @@ impl ClientConn { } } - pub(crate) async fn request<T>( + pub(crate) async fn call<T>( self: Arc<Self>, rq: T, + path: &str, prio: RequestPriority, ) -> Result<<T as Message>::Response, Error> where @@ -222,9 +263,9 @@ impl ClientConn { { let id = self .next_query_number - .fetch_add(1u16, atomic::Ordering::Relaxed); - let mut bytes = vec![prio]; - bytes.extend_from_slice(&u32::to_be_bytes(T::KIND)[..]); + .fetch_add(1, atomic::Ordering::Relaxed); + let mut bytes = vec![prio, path.as_bytes().len() as u8]; + bytes.extend_from_slice(path.as_bytes()); bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]); let (resp_send, resp_recv) = oneshot::channel(); @@ -243,8 +284,15 @@ impl ClientConn { let resp = resp_recv.await?; - rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(&resp[..])? + let code = resp[0]; + if code == 0 { + rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>( + &resp[1..], + )? .map_err(Error::Remote) + } else { + Err(Error::Remote(format!("Remote error code {}", code))) + } } } diff --git a/src/endpoint.rs b/src/endpoint.rs new file mode 100644 index 0000000..0e1f5c8 --- /dev/null +++ b/src/endpoint.rs @@ -0,0 +1,125 @@ +use std::marker::PhantomData; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; + +use serde::{Deserialize, Serialize}; + +use crate::error::Error; +use crate::netapp::*; +use crate::proto::*; +use crate::util::*; + +/// This trait should be implemented by all messages your application +/// wants to handle (click to read more). +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; +} + +pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>; + +#[async_trait] +pub trait EndpointHandler<M>: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response; +} + +pub struct Endpoint<M, H> +where + M: Message, + H: EndpointHandler<M>, +{ + phantom: PhantomData<M>, + netapp: Arc<NetApp>, + path: String, + handler: ArcSwapOption<H>, +} + +impl<M, H> Endpoint<M, H> +where + M: Message, + H: EndpointHandler<M>, +{ + pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self { + Self { + phantom: PhantomData::default(), + netapp, + path, + handler: ArcSwapOption::from(None), + } + } + pub fn set_handler(&self, h: Arc<H>) { + self.handler.swap(Some(h)); + } + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<<M as Message>::Response, Error> { + if *target == self.netapp.id { + match self.handler.load_full() { + None => Err(Error::NoHandler), + Some(h) => Ok(h.handle(req, self.netapp.id).await), + } + } else { + let conn = self + .netapp + .client_conns + .read() + .unwrap() + .get(target) + .cloned(); + match conn { + None => Err(Error::Message(format!( + "Not connected: {}", + hex::encode(target) + ))), + Some(c) => c.call(req, self.path.as_str(), prio).await, + } + } + } +} + +#[async_trait] +pub(crate) trait GenericEndpoint { + async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>; + fn clear_handler(&self); + fn clone_endpoint(&self) -> DynEndpoint; +} + +#[derive(Clone)] +pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>) +where + M: Message, + H: EndpointHandler<M>; + +#[async_trait] +impl<M, H> GenericEndpoint for EndpointArc<M, H> +where + M: Message + 'static, + H: EndpointHandler<M> + 'static, +{ + async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> { + match self.0.handler.load_full() { + None => Err(Error::NoHandler), + Some(h) => { + let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; + let res = h.handle(req, from).await; + let res_bytes = rmp_to_vec_all_named(&res)?; + Ok(res_bytes) + } + } + } + + fn clear_handler(&self) { + self.0.handler.swap(None); + } + + fn clone_endpoint(&self) -> DynEndpoint { + Box::new(Self(self.0.clone())) + } +} diff --git a/src/error.rs b/src/error.rs index 469670a..14c6187 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,13 +22,36 @@ pub enum Error { #[error(display = "Handshake error: {}", _0)] Handshake(#[error(source)] kuska_handshake::async_std::Error), + #[error(display = "UTF8 error: {}", _0)] + UTF8(#[error(source)] std::string::FromUtf8Error), + #[error(display = "{}", _0)] Message(String), + #[error(display = "No handler / shutting down")] + NoHandler, + #[error(display = "Remote error: {}", _0)] Remote(String), } +impl Error { + pub fn code(&self) -> u8 { + match self { + Self::Io(_) => 100, + Self::TokioJoin(_) => 110, + Self::OneshotRecv(_) => 111, + Self::RMPEncode(_) => 10, + Self::RMPDecode(_) => 11, + Self::UTF8(_) => 12, + Self::NoHandler => 20, + Self::Handshake(_) => 30, + Self::Remote(_) => 40, + Self::Message(_) => 99, + } + } +} + impl<T> From<tokio::sync::watch::error::SendError<T>> for Error { fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error { Error::Message("Watch send error".into()) @@ -18,7 +18,7 @@ pub mod error; pub mod util; -pub mod message; +pub mod endpoint; pub mod proto; mod conn; diff --git a/src/message.rs b/src/message.rs deleted file mode 100644 index 9ab14f9..0000000 --- a/src/message.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::net::IpAddr; - -use serde::{Deserialize, Serialize}; - -pub type MessageKind = u32; - -/// This trait should be implemented by all messages your application -/// wants to handle (click to read more). -/// -/// It defines a `KIND`, which should be a **unique** -/// `u32` that distinguishes these messages from other types of messages -/// (it is used by our communication protocol), as well as an associated -/// `Response` type that defines the type of the response that is given -/// to the message. It is your responsibility to ensure that `KIND` is a -/// unique `u32` that is not used by any other protocol messages. -/// All `KIND` values of the form `0x42xxxxxx` are reserved by the netapp -/// crate for internal purposes. -/// -/// A handler for this message has type `Self -> Self::Response`. -/// If you need to return an error, the `Response` type should be -/// a `Result<_, _>`. -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - const KIND: MessageKind; - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; -} - -#[derive(Serialize, Deserialize)] -pub(crate) struct HelloMessage { - pub server_addr: Option<IpAddr>, - pub server_port: u16, -} - -impl Message for HelloMessage { - const KIND: MessageKind = 0x42000001; - type Response = (); -} diff --git a/src/netapp.rs b/src/netapp.rs index e49f599..8415c58 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -1,43 +1,36 @@ -use std::any::Any; use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; -use std::pin::Pin; use std::sync::{Arc, RwLock}; -use std::time::Instant; - -use std::future::Future; use log::{debug, info}; -use arc_swap::{ArcSwap, ArcSwapOption}; -use bytes::Bytes; +use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::auth; use sodiumoxide::crypto::sign::ed25519; use tokio::net::{TcpListener, TcpStream}; use crate::conn::*; +use crate::endpoint::*; use crate::error::*; -use crate::message::*; use crate::proto::*; use crate::util::*; -type DynMsg = Box<dyn Any + Send + Sync + 'static>; +#[derive(Serialize, Deserialize)] +pub(crate) struct HelloMessage { + pub server_addr: Option<IpAddr>, + pub server_port: u16, +} + +impl Message for HelloMessage { + type Response = (); +} type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>; type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>; -pub(crate) type LocalHandler = - Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>; -pub(crate) type NetHandler = Box< - dyn Fn(NodeID, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> + Sync + Send, ->; - -pub(crate) struct Handler { - pub(crate) local_handler: LocalHandler, - pub(crate) net_handler: NetHandler, -} - /// NetApp is the main class that handles incoming and outgoing connections. /// /// The `request()` method can be used to send a message to any peer to which we have @@ -60,10 +53,12 @@ pub struct NetApp { /// Private key associated with our peer ID pub privkey: ed25519::SecretKey, - server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>, - client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>, + pub(crate) server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>, + pub(crate) client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>, + + pub(crate) endpoints: RwLock<HashMap<String, DynEndpoint>>, + hello_endpoint: ArcSwapOption<Endpoint<HelloMessage, NetApp>>, - pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>, on_connected_handler: ArcSwapOption<OnConnectHandler>, on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>, } @@ -73,44 +68,6 @@ struct ListenParams { public_addr: Option<IpAddr>, } -async fn net_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, bytes: Bytes) -> Vec<u8> -where - M: Message + 'static, - F: Fn(NodeID, M) -> R + Send + Sync + 'static, - R: Future<Output = <M as Message>::Response> + Send + Sync, -{ - debug!( - "Handling message of kind {:08x} from {}", - M::KIND, - hex::encode(remote) - ); - let begin_time = Instant::now(); - let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) { - Ok(msg) => Ok(handler(remote, msg).await), - Err(e) => Err(e.to_string()), - }; - let end_time = Instant::now(); - debug!( - "Request {:08x} from {} handled in {}msec", - M::KIND, - hex::encode(remote), - (end_time - begin_time).as_millis() - ); - rmp_to_vec_all_named(&res).unwrap_or_default() -} - -async fn local_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, msg: DynMsg) -> DynMsg -where - M: Message + 'static, - F: Fn(NodeID, M) -> R + Send + Sync + 'static, - R: Future<Output = <M as Message>::Response> + Send + Sync, -{ - debug!("Handling message of kind {:08x} from ourself", M::KIND); - let msg = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap(); - let res = handler(remote, *msg).await; - Box::new(res) -} - impl NetApp { /// Creates a new instance of NetApp, which can serve either as a full p2p node, /// or just as a passive client. To upgrade to a full p2p node, spawn a listener @@ -126,16 +83,20 @@ impl NetApp { privkey, server_conns: RwLock::new(HashMap::new()), client_conns: RwLock::new(HashMap::new()), - msg_handlers: ArcSwap::new(Arc::new(HashMap::new())), + endpoints: RwLock::new(HashMap::new()), + hello_endpoint: ArcSwapOption::new(None), on_connected_handler: ArcSwapOption::new(None), on_disconnected_handler: ArcSwapOption::new(None), }); - let netapp2 = netapp.clone(); - netapp.add_msg_handler::<HelloMessage, _, _>(move |from: NodeID, msg: HelloMessage| { - netapp2.handle_hello_message(from, msg); - async {} - }); + netapp + .hello_endpoint + .swap(Some(netapp.endpoint("__netapp/netapp.rs/Hello".into()))); + netapp + .hello_endpoint + .load_full() + .unwrap() + .set_handler(netapp.clone()); netapp } @@ -162,40 +123,23 @@ impl NetApp { .store(Some(Arc::new(Box::new(handler)))); } - /// Add a handler for a certain message type. Note that only one handler - /// can be specified for each message type. - /// The handler is an asynchronous function, i.e. a function that returns - /// a future. - pub fn add_msg_handler<M, F, R>(&self, handler: F) + pub fn endpoint<M, H>(self: &Arc<Self>, name: String) -> Arc<Endpoint<M, H>> where M: Message + 'static, - F: Fn(NodeID, M) -> R + Send + Sync + 'static, - R: Future<Output = <M as Message>::Response> + Send + Sync + 'static, + H: EndpointHandler<M> + 'static, { - let handler = Arc::new(handler); - - let handler2 = handler.clone(); - let net_handler = Box::new(move |remote: NodeID, bytes: Bytes| { - let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> = - Box::pin(net_handler_aux(handler2.clone(), remote, bytes)); - fun - }); - - let self_id = self.id; - let local_handler = Box::new(move |msg: DynMsg| { - let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> = - Box::pin(local_handler_aux(handler.clone(), self_id, msg)); - fun - }); - - let funs = Arc::new(Handler { - net_handler, - local_handler, - }); - - let mut handlers = self.msg_handlers.load().as_ref().clone(); - handlers.insert(M::KIND, funs); - self.msg_handlers.store(Arc::new(handlers)); + let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), name.clone())); + let endpoint_arc = EndpointArc(endpoint.clone()); + if self + .endpoints + .write() + .unwrap() + .insert(name.clone(), Box::new(endpoint_arc)) + .is_some() + { + panic!("Redefining endpoint: {}", name); + }; + endpoint } /// Main listening process for our app. This future runs during the whole @@ -318,15 +262,6 @@ impl NetApp { // At this point we know they are a full network member, and not just a client, // and we call the on_connected handler so that the peering strategy knows // we have a new potential peer - fn handle_hello_message(&self, id: NodeID, msg: HelloMessage) { - if let Some(h) = self.on_connected_handler.load().as_ref() { - if let Some(c) = self.server_conns.read().unwrap().get(&id) { - let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip()); - let remote_addr = SocketAddr::new(remote_ip, msg.server_port); - h(id, remote_addr, true); - } - } - } // Called from conn.rs when an incoming connection is closed. // We deregister the connection from server_conns and call the @@ -371,16 +306,19 @@ impl NetApp { if let Some(lp) = self.listen_params.load_full() { let server_addr = lp.public_addr; let server_port = lp.listen_addr.port(); + let hello_endpoint = self.hello_endpoint.load_full().unwrap(); tokio::spawn(async move { - conn.request( - HelloMessage { - server_addr, - server_port, - }, - PRIO_NORMAL, - ) - .await - .log_err("Sending hello message"); + hello_endpoint + .call( + &conn.peer_id, + HelloMessage { + server_addr, + server_port, + }, + PRIO_NORMAL, + ) + .await + .log_err("Sending hello message"); }); } } @@ -404,44 +342,16 @@ impl NetApp { // else case: happens if connection was removed in .disconnect() // in which case on_disconnected_handler was already called } +} - /// Send a message to a remote host to which a client connection is already - /// established, and await their response. The target is the id of the peer we - /// want to send the message to. - /// The priority is an `u8`, with lower numbers meaning highest priority. - pub async fn request<T>( - &self, - target: &NodeID, - rq: T, - prio: RequestPriority, - ) -> Result<<T as Message>::Response, Error> - where - T: Message + 'static, - { - if *target == self.id { - let handler = self.msg_handlers.load().get(&T::KIND).cloned(); - match handler { - None => Err(Error::Message(format!( - "No handler registered for message kind {:08x}", - T::KIND - ))), - Some(h) => { - let local_handler = &h.local_handler; - let res = local_handler(Box::new(rq)).await; - let res_t = (res as Box<dyn Any + 'static>) - .downcast::<<T as Message>::Response>() - .unwrap(); - Ok(*res_t) - } - } - } else { - let conn = self.client_conns.read().unwrap().get(target).cloned(); - match conn { - None => Err(Error::Message(format!( - "Not connected: {}", - hex::encode(target) - ))), - Some(c) => c.request(rq, prio).await, +#[async_trait] +impl EndpointHandler<HelloMessage> for NetApp { + async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) { + if let Some(h) = self.on_connected_handler.load().as_ref() { + if let Some(c) = self.server_conns.read().unwrap().get(&from) { + let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip()); + let remote_addr = SocketAddr::new(remote_ip, msg.server_port); + h(from, remote_addr, true); } } } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 8b1c802..b579654 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -4,12 +4,13 @@ use std::sync::atomic::{self, AtomicU64}; use std::sync::{Arc, RwLock}; use std::time::{Duration, Instant}; +use async_trait::async_trait; use log::{debug, info, trace, warn}; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::hash; -use crate::message::*; +use crate::endpoint::*; use crate::netapp::*; use crate::proto::*; use crate::NodeID; @@ -28,7 +29,6 @@ struct PingMessage { } impl Message for PingMessage { - const KIND: MessageKind = 0x42001000; type Response = PingMessage; } @@ -38,7 +38,6 @@ struct PeerListMessage { } impl Message for PeerListMessage { - const KIND: MessageKind = 0x42001001; type Response = PeerListMessage; } @@ -124,6 +123,9 @@ pub struct FullMeshPeeringStrategy { netapp: Arc<NetApp>, known_hosts: RwLock<KnownHosts>, next_ping_id: AtomicU64, + + ping_endpoint: Arc<Endpoint<PingMessage, Self>>, + peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>, } impl FullMeshPeeringStrategy { @@ -147,27 +149,12 @@ impl FullMeshPeeringStrategy { netapp: netapp.clone(), known_hosts: RwLock::new(known_hosts), next_ping_id: AtomicU64::new(42), + ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()), + peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()), }); - let strat2 = strat.clone(); - netapp.add_msg_handler::<PingMessage, _, _>(move |from: NodeID, ping: PingMessage| { - let ping_resp = PingMessage { - id: ping.id, - peer_list_hash: strat2.known_hosts.read().unwrap().hash, - }; - debug!("Ping from {}", hex::encode(&from)); - async move { ping_resp } - }); - - let strat2 = strat.clone(); - netapp.add_msg_handler::<PeerListMessage, _, _>( - move |_from: NodeID, peer_list: PeerListMessage| { - strat2.handle_peer_list(&peer_list.list[..]); - let peer_list = KnownHosts::map_into_vec(&strat2.known_hosts.read().unwrap().list); - let resp = PeerListMessage { list: peer_list }; - async move { resp } - }, - ); + strat.ping_endpoint.set_handler(strat.clone()); + strat.peer_list_endpoint.set_handler(strat.clone()); let strat2 = strat.clone(); netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| { @@ -262,7 +249,7 @@ impl FullMeshPeeringStrategy { hex::encode(id), ping_time ); - match self.netapp.request(&id, ping_msg, PRIO_HIGH).await { + match self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH).await { Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e), Ok(ping_resp) => { let resp_time = Instant::now(); @@ -291,7 +278,11 @@ impl FullMeshPeeringStrategy { async fn exchange_peers(self: Arc<Self>, id: &NodeID) { let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); let pex_message = PeerListMessage { list: peer_list }; - match self.netapp.request(id, pex_message, PRIO_BACKGROUND).await { + match self + .peer_list_endpoint + .call(id, pex_message, PRIO_BACKGROUND) + .await + { Err(e) => warn!("Error doing peer exchange: {}", e), Ok(resp) => { self.handle_peer_list(&resp.list[..]); @@ -408,3 +399,28 @@ impl FullMeshPeeringStrategy { } } } + +#[async_trait] +impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy { + async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage { + let ping_resp = PingMessage { + id: ping.id, + peer_list_hash: self.known_hosts.read().unwrap().hash, + }; + debug!("Ping from {}", hex::encode(&from)); + ping_resp + } +} + +#[async_trait] +impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy { + async fn handle( + self: &Arc<Self>, + peer_list: PeerListMessage, + _from: NodeID, + ) -> PeerListMessage { + self.handle_peer_list(&peer_list.list[..]); + let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); + PeerListMessage { list: peer_list } + } +} diff --git a/src/proto.rs b/src/proto.rs index ef3b31c..5b71ba5 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -38,9 +38,10 @@ pub const PRIO_PRIMARY: RequestPriority = 0x00; /// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) pub const PRIO_SECONDARY: RequestPriority = 0x01; -const MAX_CHUNK_SIZE: usize = 0x4000; - -pub(crate) type RequestID = u16; +pub(crate) type RequestID = u32; +type ChunkLength = u16; +const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { id: RequestID, @@ -85,6 +86,12 @@ impl SendQueue { } } +// Messages are sent by chunks +// Chunk format: +// - u32 BE: request id (same for request and response) +// - u16 BE: chunk length +// - [u8; chunk_length] chunk data + #[async_trait] pub(crate) trait SendLoop: Sync { async fn send_loop<W>( @@ -117,22 +124,23 @@ pub(crate) trait SendLoop: Sync { item.data.len(), item.cursor ); - let header_id = u16::to_be_bytes(item.id); + let header_id = RequestID::to_be_bytes(item.id); write.write_all(&header_id[..]).await?; - if item.data.len() - item.cursor > MAX_CHUNK_SIZE { - let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); + if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { + let header_size = + ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); write.write_all(&header_size[..]).await?; - let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; + let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; write.write_all(&item.data[item.cursor..new_cursor]).await?; item.cursor = new_cursor; sending.push(item); } else { - let send_len = (item.data.len() - item.cursor) as u16; + let send_len = (item.data.len() - item.cursor) as ChunkLength; - let header_size = u16::to_be_bytes(send_len); + let header_size = ChunkLength::to_be_bytes(send_len); write.write_all(&header_size[..]).await?; write.write_all(&item.data[item.cursor..]).await?; @@ -172,18 +180,18 @@ pub(crate) trait RecvLoop: Sync + 'static { let mut receiving = HashMap::new(); loop { trace!("recv_loop: reading packet"); - let mut header_id = [0u8; 2]; + let mut header_id = [0u8; RequestID::BITS as usize / 8]; read.read_exact(&mut header_id[..]).await?; let id = RequestID::from_be_bytes(header_id); trace!("recv_loop: got header id: {:04x}", id); - let mut header_size = [0u8; 2]; + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; read.read_exact(&mut header_size[..]).await?; - let size = RequestID::from_be_bytes(header_size); + let size = ChunkLength::from_be_bytes(header_size); trace!("recv_loop: got header size: {:04x}", size); - let has_cont = (size & 0x8000) != 0; - let size = size & !0x8000; + let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; + let size = size & !CHUNK_HAS_CONTINUATION; let mut next_slice = vec![0; size as usize]; read.read_exact(&mut next_slice[..]).await?; |