aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2021-10-12 17:59:46 +0200
committerAlex Auvolat <alex@adnab.me>2021-10-12 17:59:46 +0200
commitf87dbe73dc12f2d6eb13850a3bc4b012aadd3c9b (patch)
tree5407c8eab331d066e66f5193d51f6fd66bedb9bb /src
parent040231d554b74e981644e606c096ced6fc36a2ad (diff)
downloadnetapp-f87dbe73dc12f2d6eb13850a3bc4b012aadd3c9b.tar.gz
netapp-f87dbe73dc12f2d6eb13850a3bc4b012aadd3c9b.zip
WIP v0.3.0 with changed API
Diffstat (limited to 'src')
-rw-r--r--src/conn.rs90
-rw-r--r--src/endpoint.rs125
-rw-r--r--src/error.rs23
-rw-r--r--src/lib.rs2
-rw-r--r--src/message.rs36
-rw-r--r--src/netapp.rs216
-rw-r--r--src/peering/fullmesh.rs64
-rw-r--r--src/proto.rs36
8 files changed, 343 insertions, 249 deletions
diff --git a/src/conn.rs b/src/conn.rs
index c2c9c8b..64318dc 100644
--- a/src/conn.rs
+++ b/src/conn.rs
@@ -1,6 +1,6 @@
use std::collections::HashMap;
use std::net::SocketAddr;
-use std::sync::atomic::{self, AtomicBool, AtomicU16};
+use std::sync::atomic::{self, AtomicBool, AtomicU32};
use std::sync::{Arc, Mutex};
use bytes::Bytes;
@@ -16,12 +16,22 @@ use async_trait::async_trait;
use kuska_handshake::async_std::{handshake_client, handshake_server, BoxStream};
+use crate::endpoint::*;
use crate::error::*;
-use crate::message::*;
use crate::netapp::*;
use crate::proto::*;
use crate::util::*;
+// Request message format (client -> server):
+// - u8 priority
+// - u8 path length
+// - [u8; path length] path
+// - [u8; *] data
+
+// Response message format (server -> client):
+// - u8 response code
+// - [u8; *] response
+
pub(crate) struct ServerConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID,
@@ -99,30 +109,60 @@ impl ServerConn {
pub fn close(&self) {
self.close_send.send(true).unwrap();
}
+
+ async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
+ if bytes.len() < 2 {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ // byte 0 is the request priority, we don't care here
+ let path_length = bytes[1] as usize;
+ if bytes.len() < 2 + path_length {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let path = &bytes[2..2 + path_length];
+ let path = String::from_utf8(path.to_vec())?;
+ let data = &bytes[2 + path_length..];
+
+ let handler_opt = {
+ let endpoints = self.netapp.endpoints.read().unwrap();
+ endpoints.get(&path).map(|e| e.clone_endpoint())
+ };
+
+ if let Some(handler) = handler_opt {
+ handler.handle(data, self.peer_id).await
+ } else {
+ Err(Error::NoHandler)
+ }
+ }
}
impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
- async fn recv_handler(self: Arc<Self>, id: u16, bytes: Vec<u8>) {
+ async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) {
trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
-
let bytes: Bytes = bytes.into();
+ let resp = self.recv_handler_aux(&bytes[..]).await;
let prio = bytes[0];
- let mut kind_bytes = [0u8; 4];
- kind_bytes.copy_from_slice(&bytes[1..5]);
- let kind = u32::from_be_bytes(kind_bytes);
-
- if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) {
- let net_handler = &handler.net_handler;
- let resp = net_handler(self.peer_id, bytes.slice(5..)).await;
- self.resp_send
- .send(Some((id, prio, resp)))
- .log_err("ServerConn recv_handler send resp");
+ let mut resp_bytes = vec![];
+ match resp {
+ Ok(rb) => {
+ resp_bytes.push(0u8);
+ resp_bytes.extend(&rb[..]);
+ }
+ Err(e) => {
+ resp_bytes.push(e.code());
+ }
}
+
+ self.resp_send
+ .send(Some((id, prio, resp_bytes)))
+ .log_err("ServerConn recv_handler send resp");
}
}
pub(crate) struct ClientConn {
@@ -131,7 +171,7 @@ pub(crate) struct ClientConn {
query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
- next_query_number: AtomicU16,
+ next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
must_exit: AtomicBool,
stop_recv_loop: watch::Sender<bool>,
@@ -173,7 +213,7 @@ impl ClientConn {
let conn = Arc::new(ClientConn {
remote_addr,
peer_id,
- next_query_number: AtomicU16::from(0u16),
+ next_query_number: AtomicU32::from(RequestID::default()),
query_send,
inflight: Mutex::new(HashMap::new()),
must_exit: AtomicBool::new(false),
@@ -212,9 +252,10 @@ impl ClientConn {
}
}
- pub(crate) async fn request<T>(
+ pub(crate) async fn call<T>(
self: Arc<Self>,
rq: T,
+ path: &str,
prio: RequestPriority,
) -> Result<<T as Message>::Response, Error>
where
@@ -222,9 +263,9 @@ impl ClientConn {
{
let id = self
.next_query_number
- .fetch_add(1u16, atomic::Ordering::Relaxed);
- let mut bytes = vec![prio];
- bytes.extend_from_slice(&u32::to_be_bytes(T::KIND)[..]);
+ .fetch_add(1, atomic::Ordering::Relaxed);
+ let mut bytes = vec![prio, path.as_bytes().len() as u8];
+ bytes.extend_from_slice(path.as_bytes());
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
let (resp_send, resp_recv) = oneshot::channel();
@@ -243,8 +284,15 @@ impl ClientConn {
let resp = resp_recv.await?;
- rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(&resp[..])?
+ let code = resp[0];
+ if code == 0 {
+ rmp_serde::decode::from_read_ref::<_, Result<<T as Message>::Response, String>>(
+ &resp[1..],
+ )?
.map_err(Error::Remote)
+ } else {
+ Err(Error::Remote(format!("Remote error code {}", code)))
+ }
}
}
diff --git a/src/endpoint.rs b/src/endpoint.rs
new file mode 100644
index 0000000..0e1f5c8
--- /dev/null
+++ b/src/endpoint.rs
@@ -0,0 +1,125 @@
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+
+use serde::{Deserialize, Serialize};
+
+use crate::error::Error;
+use crate::netapp::*;
+use crate::proto::*;
+use crate::util::*;
+
+/// This trait should be implemented by all messages your application
+/// wants to handle (click to read more).
+pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
+ type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
+}
+
+pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
+
+#[async_trait]
+pub trait EndpointHandler<M>: Send + Sync
+where
+ M: Message,
+{
+ async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
+}
+
+pub struct Endpoint<M, H>
+where
+ M: Message,
+ H: EndpointHandler<M>,
+{
+ phantom: PhantomData<M>,
+ netapp: Arc<NetApp>,
+ path: String,
+ handler: ArcSwapOption<H>,
+}
+
+impl<M, H> Endpoint<M, H>
+where
+ M: Message,
+ H: EndpointHandler<M>,
+{
+ pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
+ Self {
+ phantom: PhantomData::default(),
+ netapp,
+ path,
+ handler: ArcSwapOption::from(None),
+ }
+ }
+ pub fn set_handler(&self, h: Arc<H>) {
+ self.handler.swap(Some(h));
+ }
+ pub async fn call(
+ &self,
+ target: &NodeID,
+ req: M,
+ prio: RequestPriority,
+ ) -> Result<<M as Message>::Response, Error> {
+ if *target == self.netapp.id {
+ match self.handler.load_full() {
+ None => Err(Error::NoHandler),
+ Some(h) => Ok(h.handle(req, self.netapp.id).await),
+ }
+ } else {
+ let conn = self
+ .netapp
+ .client_conns
+ .read()
+ .unwrap()
+ .get(target)
+ .cloned();
+ match conn {
+ None => Err(Error::Message(format!(
+ "Not connected: {}",
+ hex::encode(target)
+ ))),
+ Some(c) => c.call(req, self.path.as_str(), prio).await,
+ }
+ }
+ }
+}
+
+#[async_trait]
+pub(crate) trait GenericEndpoint {
+ async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
+ fn clear_handler(&self);
+ fn clone_endpoint(&self) -> DynEndpoint;
+}
+
+#[derive(Clone)]
+pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
+where
+ M: Message,
+ H: EndpointHandler<M>;
+
+#[async_trait]
+impl<M, H> GenericEndpoint for EndpointArc<M, H>
+where
+ M: Message + 'static,
+ H: EndpointHandler<M> + 'static,
+{
+ async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
+ match self.0.handler.load_full() {
+ None => Err(Error::NoHandler),
+ Some(h) => {
+ let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
+ let res = h.handle(req, from).await;
+ let res_bytes = rmp_to_vec_all_named(&res)?;
+ Ok(res_bytes)
+ }
+ }
+ }
+
+ fn clear_handler(&self) {
+ self.0.handler.swap(None);
+ }
+
+ fn clone_endpoint(&self) -> DynEndpoint {
+ Box::new(Self(self.0.clone()))
+ }
+}
diff --git a/src/error.rs b/src/error.rs
index 469670a..14c6187 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -22,13 +22,36 @@ pub enum Error {
#[error(display = "Handshake error: {}", _0)]
Handshake(#[error(source)] kuska_handshake::async_std::Error),
+ #[error(display = "UTF8 error: {}", _0)]
+ UTF8(#[error(source)] std::string::FromUtf8Error),
+
#[error(display = "{}", _0)]
Message(String),
+ #[error(display = "No handler / shutting down")]
+ NoHandler,
+
#[error(display = "Remote error: {}", _0)]
Remote(String),
}
+impl Error {
+ pub fn code(&self) -> u8 {
+ match self {
+ Self::Io(_) => 100,
+ Self::TokioJoin(_) => 110,
+ Self::OneshotRecv(_) => 111,
+ Self::RMPEncode(_) => 10,
+ Self::RMPDecode(_) => 11,
+ Self::UTF8(_) => 12,
+ Self::NoHandler => 20,
+ Self::Handshake(_) => 30,
+ Self::Remote(_) => 40,
+ Self::Message(_) => 99,
+ }
+ }
+}
+
impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error {
Error::Message("Watch send error".into())
diff --git a/src/lib.rs b/src/lib.rs
index a0bec32..e5251c5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -18,7 +18,7 @@
pub mod error;
pub mod util;
-pub mod message;
+pub mod endpoint;
pub mod proto;
mod conn;
diff --git a/src/message.rs b/src/message.rs
deleted file mode 100644
index 9ab14f9..0000000
--- a/src/message.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-use std::net::IpAddr;
-
-use serde::{Deserialize, Serialize};
-
-pub type MessageKind = u32;
-
-/// This trait should be implemented by all messages your application
-/// wants to handle (click to read more).
-///
-/// It defines a `KIND`, which should be a **unique**
-/// `u32` that distinguishes these messages from other types of messages
-/// (it is used by our communication protocol), as well as an associated
-/// `Response` type that defines the type of the response that is given
-/// to the message. It is your responsibility to ensure that `KIND` is a
-/// unique `u32` that is not used by any other protocol messages.
-/// All `KIND` values of the form `0x42xxxxxx` are reserved by the netapp
-/// crate for internal purposes.
-///
-/// A handler for this message has type `Self -> Self::Response`.
-/// If you need to return an error, the `Response` type should be
-/// a `Result<_, _>`.
-pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
- const KIND: MessageKind;
- type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
-}
-
-#[derive(Serialize, Deserialize)]
-pub(crate) struct HelloMessage {
- pub server_addr: Option<IpAddr>,
- pub server_port: u16,
-}
-
-impl Message for HelloMessage {
- const KIND: MessageKind = 0x42000001;
- type Response = ();
-}
diff --git a/src/netapp.rs b/src/netapp.rs
index e49f599..8415c58 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -1,43 +1,36 @@
-use std::any::Any;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
-use std::pin::Pin;
use std::sync::{Arc, RwLock};
-use std::time::Instant;
-
-use std::future::Future;
use log::{debug, info};
-use arc_swap::{ArcSwap, ArcSwapOption};
-use bytes::Bytes;
+use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
use tokio::net::{TcpListener, TcpStream};
use crate::conn::*;
+use crate::endpoint::*;
use crate::error::*;
-use crate::message::*;
use crate::proto::*;
use crate::util::*;
-type DynMsg = Box<dyn Any + Send + Sync + 'static>;
+#[derive(Serialize, Deserialize)]
+pub(crate) struct HelloMessage {
+ pub server_addr: Option<IpAddr>,
+ pub server_port: u16,
+}
+
+impl Message for HelloMessage {
+ type Response = ();
+}
type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>;
type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>;
-pub(crate) type LocalHandler =
- Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>;
-pub(crate) type NetHandler = Box<
- dyn Fn(NodeID, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> + Sync + Send,
->;
-
-pub(crate) struct Handler {
- pub(crate) local_handler: LocalHandler,
- pub(crate) net_handler: NetHandler,
-}
-
/// NetApp is the main class that handles incoming and outgoing connections.
///
/// The `request()` method can be used to send a message to any peer to which we have
@@ -60,10 +53,12 @@ pub struct NetApp {
/// Private key associated with our peer ID
pub privkey: ed25519::SecretKey,
- server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>,
- client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>,
+ pub(crate) server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>,
+ pub(crate) client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>,
+
+ pub(crate) endpoints: RwLock<HashMap<String, DynEndpoint>>,
+ hello_endpoint: ArcSwapOption<Endpoint<HelloMessage, NetApp>>,
- pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
on_connected_handler: ArcSwapOption<OnConnectHandler>,
on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>,
}
@@ -73,44 +68,6 @@ struct ListenParams {
public_addr: Option<IpAddr>,
}
-async fn net_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, bytes: Bytes) -> Vec<u8>
-where
- M: Message + 'static,
- F: Fn(NodeID, M) -> R + Send + Sync + 'static,
- R: Future<Output = <M as Message>::Response> + Send + Sync,
-{
- debug!(
- "Handling message of kind {:08x} from {}",
- M::KIND,
- hex::encode(remote)
- );
- let begin_time = Instant::now();
- let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
- Ok(msg) => Ok(handler(remote, msg).await),
- Err(e) => Err(e.to_string()),
- };
- let end_time = Instant::now();
- debug!(
- "Request {:08x} from {} handled in {}msec",
- M::KIND,
- hex::encode(remote),
- (end_time - begin_time).as_millis()
- );
- rmp_to_vec_all_named(&res).unwrap_or_default()
-}
-
-async fn local_handler_aux<M, F, R>(handler: Arc<F>, remote: NodeID, msg: DynMsg) -> DynMsg
-where
- M: Message + 'static,
- F: Fn(NodeID, M) -> R + Send + Sync + 'static,
- R: Future<Output = <M as Message>::Response> + Send + Sync,
-{
- debug!("Handling message of kind {:08x} from ourself", M::KIND);
- let msg = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap();
- let res = handler(remote, *msg).await;
- Box::new(res)
-}
-
impl NetApp {
/// Creates a new instance of NetApp, which can serve either as a full p2p node,
/// or just as a passive client. To upgrade to a full p2p node, spawn a listener
@@ -126,16 +83,20 @@ impl NetApp {
privkey,
server_conns: RwLock::new(HashMap::new()),
client_conns: RwLock::new(HashMap::new()),
- msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
+ endpoints: RwLock::new(HashMap::new()),
+ hello_endpoint: ArcSwapOption::new(None),
on_connected_handler: ArcSwapOption::new(None),
on_disconnected_handler: ArcSwapOption::new(None),
});
- let netapp2 = netapp.clone();
- netapp.add_msg_handler::<HelloMessage, _, _>(move |from: NodeID, msg: HelloMessage| {
- netapp2.handle_hello_message(from, msg);
- async {}
- });
+ netapp
+ .hello_endpoint
+ .swap(Some(netapp.endpoint("__netapp/netapp.rs/Hello".into())));
+ netapp
+ .hello_endpoint
+ .load_full()
+ .unwrap()
+ .set_handler(netapp.clone());
netapp
}
@@ -162,40 +123,23 @@ impl NetApp {
.store(Some(Arc::new(Box::new(handler))));
}
- /// Add a handler for a certain message type. Note that only one handler
- /// can be specified for each message type.
- /// The handler is an asynchronous function, i.e. a function that returns
- /// a future.
- pub fn add_msg_handler<M, F, R>(&self, handler: F)
+ pub fn endpoint<M, H>(self: &Arc<Self>, name: String) -> Arc<Endpoint<M, H>>
where
M: Message + 'static,
- F: Fn(NodeID, M) -> R + Send + Sync + 'static,
- R: Future<Output = <M as Message>::Response> + Send + Sync + 'static,
+ H: EndpointHandler<M> + 'static,
{
- let handler = Arc::new(handler);
-
- let handler2 = handler.clone();
- let net_handler = Box::new(move |remote: NodeID, bytes: Bytes| {
- let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
- Box::pin(net_handler_aux(handler2.clone(), remote, bytes));
- fun
- });
-
- let self_id = self.id;
- let local_handler = Box::new(move |msg: DynMsg| {
- let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> =
- Box::pin(local_handler_aux(handler.clone(), self_id, msg));
- fun
- });
-
- let funs = Arc::new(Handler {
- net_handler,
- local_handler,
- });
-
- let mut handlers = self.msg_handlers.load().as_ref().clone();
- handlers.insert(M::KIND, funs);
- self.msg_handlers.store(Arc::new(handlers));
+ let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), name.clone()));
+ let endpoint_arc = EndpointArc(endpoint.clone());
+ if self
+ .endpoints
+ .write()
+ .unwrap()
+ .insert(name.clone(), Box::new(endpoint_arc))
+ .is_some()
+ {
+ panic!("Redefining endpoint: {}", name);
+ };
+ endpoint
}
/// Main listening process for our app. This future runs during the whole
@@ -318,15 +262,6 @@ impl NetApp {
// At this point we know they are a full network member, and not just a client,
// and we call the on_connected handler so that the peering strategy knows
// we have a new potential peer
- fn handle_hello_message(&self, id: NodeID, msg: HelloMessage) {
- if let Some(h) = self.on_connected_handler.load().as_ref() {
- if let Some(c) = self.server_conns.read().unwrap().get(&id) {
- let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip());
- let remote_addr = SocketAddr::new(remote_ip, msg.server_port);
- h(id, remote_addr, true);
- }
- }
- }
// Called from conn.rs when an incoming connection is closed.
// We deregister the connection from server_conns and call the
@@ -371,16 +306,19 @@ impl NetApp {
if let Some(lp) = self.listen_params.load_full() {
let server_addr = lp.public_addr;
let server_port = lp.listen_addr.port();
+ let hello_endpoint = self.hello_endpoint.load_full().unwrap();
tokio::spawn(async move {
- conn.request(
- HelloMessage {
- server_addr,
- server_port,
- },
- PRIO_NORMAL,
- )
- .await
- .log_err("Sending hello message");
+ hello_endpoint
+ .call(
+ &conn.peer_id,
+ HelloMessage {
+ server_addr,
+ server_port,
+ },
+ PRIO_NORMAL,
+ )
+ .await
+ .log_err("Sending hello message");
});
}
}
@@ -404,44 +342,16 @@ impl NetApp {
// else case: happens if connection was removed in .disconnect()
// in which case on_disconnected_handler was already called
}
+}
- /// Send a message to a remote host to which a client connection is already
- /// established, and await their response. The target is the id of the peer we
- /// want to send the message to.
- /// The priority is an `u8`, with lower numbers meaning highest priority.
- pub async fn request<T>(
- &self,
- target: &NodeID,
- rq: T,
- prio: RequestPriority,
- ) -> Result<<T as Message>::Response, Error>
- where
- T: Message + 'static,
- {
- if *target == self.id {
- let handler = self.msg_handlers.load().get(&T::KIND).cloned();
- match handler {
- None => Err(Error::Message(format!(
- "No handler registered for message kind {:08x}",
- T::KIND
- ))),
- Some(h) => {
- let local_handler = &h.local_handler;
- let res = local_handler(Box::new(rq)).await;
- let res_t = (res as Box<dyn Any + 'static>)
- .downcast::<<T as Message>::Response>()
- .unwrap();
- Ok(*res_t)
- }
- }
- } else {
- let conn = self.client_conns.read().unwrap().get(target).cloned();
- match conn {
- None => Err(Error::Message(format!(
- "Not connected: {}",
- hex::encode(target)
- ))),
- Some(c) => c.request(rq, prio).await,
+#[async_trait]
+impl EndpointHandler<HelloMessage> for NetApp {
+ async fn handle(self: &Arc<Self>, msg: HelloMessage, from: NodeID) {
+ if let Some(h) = self.on_connected_handler.load().as_ref() {
+ if let Some(c) = self.server_conns.read().unwrap().get(&from) {
+ let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip());
+ let remote_addr = SocketAddr::new(remote_ip, msg.server_port);
+ h(from, remote_addr, true);
}
}
}
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index 8b1c802..b579654 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -4,12 +4,13 @@ use std::sync::atomic::{self, AtomicU64};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
+use async_trait::async_trait;
use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::hash;
-use crate::message::*;
+use crate::endpoint::*;
use crate::netapp::*;
use crate::proto::*;
use crate::NodeID;
@@ -28,7 +29,6 @@ struct PingMessage {
}
impl Message for PingMessage {
- const KIND: MessageKind = 0x42001000;
type Response = PingMessage;
}
@@ -38,7 +38,6 @@ struct PeerListMessage {
}
impl Message for PeerListMessage {
- const KIND: MessageKind = 0x42001001;
type Response = PeerListMessage;
}
@@ -124,6 +123,9 @@ pub struct FullMeshPeeringStrategy {
netapp: Arc<NetApp>,
known_hosts: RwLock<KnownHosts>,
next_ping_id: AtomicU64,
+
+ ping_endpoint: Arc<Endpoint<PingMessage, Self>>,
+ peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>,
}
impl FullMeshPeeringStrategy {
@@ -147,27 +149,12 @@ impl FullMeshPeeringStrategy {
netapp: netapp.clone(),
known_hosts: RwLock::new(known_hosts),
next_ping_id: AtomicU64::new(42),
+ ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()),
+ peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()),
});
- let strat2 = strat.clone();
- netapp.add_msg_handler::<PingMessage, _, _>(move |from: NodeID, ping: PingMessage| {
- let ping_resp = PingMessage {
- id: ping.id,
- peer_list_hash: strat2.known_hosts.read().unwrap().hash,
- };
- debug!("Ping from {}", hex::encode(&from));
- async move { ping_resp }
- });
-
- let strat2 = strat.clone();
- netapp.add_msg_handler::<PeerListMessage, _, _>(
- move |_from: NodeID, peer_list: PeerListMessage| {
- strat2.handle_peer_list(&peer_list.list[..]);
- let peer_list = KnownHosts::map_into_vec(&strat2.known_hosts.read().unwrap().list);
- let resp = PeerListMessage { list: peer_list };
- async move { resp }
- },
- );
+ strat.ping_endpoint.set_handler(strat.clone());
+ strat.peer_list_endpoint.set_handler(strat.clone());
let strat2 = strat.clone();
netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| {
@@ -262,7 +249,7 @@ impl FullMeshPeeringStrategy {
hex::encode(id),
ping_time
);
- match self.netapp.request(&id, ping_msg, PRIO_HIGH).await {
+ match self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH).await {
Err(e) => warn!("Error pinging {}: {}", hex::encode(id), e),
Ok(ping_resp) => {
let resp_time = Instant::now();
@@ -291,7 +278,11 @@ impl FullMeshPeeringStrategy {
async fn exchange_peers(self: Arc<Self>, id: &NodeID) {
let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
let pex_message = PeerListMessage { list: peer_list };
- match self.netapp.request(id, pex_message, PRIO_BACKGROUND).await {
+ match self
+ .peer_list_endpoint
+ .call(id, pex_message, PRIO_BACKGROUND)
+ .await
+ {
Err(e) => warn!("Error doing peer exchange: {}", e),
Ok(resp) => {
self.handle_peer_list(&resp.list[..]);
@@ -408,3 +399,28 @@ impl FullMeshPeeringStrategy {
}
}
}
+
+#[async_trait]
+impl EndpointHandler<PingMessage> for FullMeshPeeringStrategy {
+ async fn handle(self: &Arc<Self>, ping: PingMessage, from: NodeID) -> PingMessage {
+ let ping_resp = PingMessage {
+ id: ping.id,
+ peer_list_hash: self.known_hosts.read().unwrap().hash,
+ };
+ debug!("Ping from {}", hex::encode(&from));
+ ping_resp
+ }
+}
+
+#[async_trait]
+impl EndpointHandler<PeerListMessage> for FullMeshPeeringStrategy {
+ async fn handle(
+ self: &Arc<Self>,
+ peer_list: PeerListMessage,
+ _from: NodeID,
+ ) -> PeerListMessage {
+ self.handle_peer_list(&peer_list.list[..]);
+ let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
+ PeerListMessage { list: peer_list }
+ }
+}
diff --git a/src/proto.rs b/src/proto.rs
index ef3b31c..5b71ba5 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -38,9 +38,10 @@ pub const PRIO_PRIMARY: RequestPriority = 0x00;
/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
pub const PRIO_SECONDARY: RequestPriority = 0x01;
-const MAX_CHUNK_SIZE: usize = 0x4000;
-
-pub(crate) type RequestID = u16;
+pub(crate) type RequestID = u32;
+type ChunkLength = u16;
+const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
+const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
struct SendQueueItem {
id: RequestID,
@@ -85,6 +86,12 @@ impl SendQueue {
}
}
+// Messages are sent by chunks
+// Chunk format:
+// - u32 BE: request id (same for request and response)
+// - u16 BE: chunk length
+// - [u8; chunk_length] chunk data
+
#[async_trait]
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
@@ -117,22 +124,23 @@ pub(crate) trait SendLoop: Sync {
item.data.len(),
item.cursor
);
- let header_id = u16::to_be_bytes(item.id);
+ let header_id = RequestID::to_be_bytes(item.id);
write.write_all(&header_id[..]).await?;
- if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
- let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
+ if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
+ let header_size =
+ ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
write.write_all(&header_size[..]).await?;
- let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
+ let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
write.write_all(&item.data[item.cursor..new_cursor]).await?;
item.cursor = new_cursor;
sending.push(item);
} else {
- let send_len = (item.data.len() - item.cursor) as u16;
+ let send_len = (item.data.len() - item.cursor) as ChunkLength;
- let header_size = u16::to_be_bytes(send_len);
+ let header_size = ChunkLength::to_be_bytes(send_len);
write.write_all(&header_size[..]).await?;
write.write_all(&item.data[item.cursor..]).await?;
@@ -172,18 +180,18 @@ pub(crate) trait RecvLoop: Sync + 'static {
let mut receiving = HashMap::new();
loop {
trace!("recv_loop: reading packet");
- let mut header_id = [0u8; 2];
+ let mut header_id = [0u8; RequestID::BITS as usize / 8];
read.read_exact(&mut header_id[..]).await?;
let id = RequestID::from_be_bytes(header_id);
trace!("recv_loop: got header id: {:04x}", id);
- let mut header_size = [0u8; 2];
+ let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
read.read_exact(&mut header_size[..]).await?;
- let size = RequestID::from_be_bytes(header_size);
+ let size = ChunkLength::from_be_bytes(header_size);
trace!("recv_loop: got header size: {:04x}", size);
- let has_cont = (size & 0x8000) != 0;
- let size = size & !0x8000;
+ let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
+ let size = size & !CHUNK_HAS_CONTINUATION;
let mut next_slice = vec![0; size as usize];
read.read_exact(&mut next_slice[..]).await?;