aboutsummaryrefslogtreecommitdiff
path: root/src/net
diff options
context:
space:
mode:
Diffstat (limited to 'src/net')
-rw-r--r--src/net/Cargo.toml45
-rw-r--r--src/net/bytes_buf.rs186
-rw-r--r--src/net/client.rs292
-rw-r--r--src/net/endpoint.rs201
-rw-r--r--src/net/error.rs126
-rw-r--r--src/net/lib.rs35
-rw-r--r--src/net/message.rs522
-rw-r--r--src/net/netapp.rs452
-rw-r--r--src/net/peering.rs614
-rw-r--r--src/net/recv.rs153
-rw-r--r--src/net/send.rs356
-rw-r--r--src/net/server.rs222
-rw-r--r--src/net/stream.rs202
-rw-r--r--src/net/test.rs118
-rw-r--r--src/net/util.rs96
15 files changed, 3620 insertions, 0 deletions
diff --git a/src/net/Cargo.toml b/src/net/Cargo.toml
new file mode 100644
index 00000000..a2674498
--- /dev/null
+++ b/src/net/Cargo.toml
@@ -0,0 +1,45 @@
+[package]
+name = "garage_net"
+version = "0.9.1"
+authors = ["Alex Auvolat <alex@adnab.me>"]
+edition = "2018"
+license-file = "AGPL-3.0"
+description = "Networking library for Garage RPC communication, forked from Netapp"
+repository = "https://git.deuxfleurs.fr/Deuxfleurs/garage"
+readme = "../../README.md"
+
+[lib]
+path = "lib.rs"
+
+[features]
+default = []
+telemetry = ["opentelemetry", "opentelemetry-contrib"]
+
+[dependencies]
+futures.workspace = true
+pin-project.workspace = true
+tokio.workspace = true
+tokio-util.workspace = true
+tokio-stream.workspace = true
+
+serde.workspace = true
+rmp-serde.workspace = true
+hex.workspace = true
+
+rand.workspace = true
+
+log.workspace = true
+arc-swap.workspace = true
+async-trait.workspace = true
+err-derive.workspace = true
+bytes.workspace = true
+cfg-if.workspace = true
+
+sodiumoxide.workspace = true
+kuska-handshake.workspace = true
+
+opentelemetry = { workspace = true, optional = true }
+opentelemetry-contrib = { workspace = true, optional = true }
+
+[dev-dependencies]
+pretty_env_logger.workspace = true
diff --git a/src/net/bytes_buf.rs b/src/net/bytes_buf.rs
new file mode 100644
index 00000000..3929a860
--- /dev/null
+++ b/src/net/bytes_buf.rs
@@ -0,0 +1,186 @@
+use std::cmp::Ordering;
+use std::collections::VecDeque;
+
+use bytes::BytesMut;
+
+pub use bytes::Bytes;
+
+/// A circular buffer of bytes, internally represented as a list of Bytes
+/// for optimization, but that for all intent and purposes acts just like
+/// a big byte slice which can be extended on the right and from which
+/// stuff can be taken on the left.
+pub struct BytesBuf {
+ buf: VecDeque<Bytes>,
+ buf_len: usize,
+}
+
+impl BytesBuf {
+ /// Creates a new empty BytesBuf
+ pub fn new() -> Self {
+ Self {
+ buf: VecDeque::new(),
+ buf_len: 0,
+ }
+ }
+
+ /// Returns the number of bytes stored in the BytesBuf
+ #[inline]
+ pub fn len(&self) -> usize {
+ self.buf_len
+ }
+
+ /// Returns true iff the BytesBuf contains zero bytes
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.buf_len == 0
+ }
+
+ /// Adds some bytes to the right of the buffer
+ pub fn extend(&mut self, b: Bytes) {
+ if !b.is_empty() {
+ self.buf_len += b.len();
+ self.buf.push_back(b);
+ }
+ }
+
+ /// Takes the whole content of the buffer and returns it as a single Bytes unit
+ pub fn take_all(&mut self) -> Bytes {
+ if self.buf.is_empty() {
+ Bytes::new()
+ } else if self.buf.len() == 1 {
+ self.buf_len = 0;
+ self.buf.pop_back().unwrap()
+ } else {
+ let mut ret = BytesMut::with_capacity(self.buf_len);
+ for b in self.buf.iter() {
+ ret.extend_from_slice(&b[..]);
+ }
+ self.buf.clear();
+ self.buf_len = 0;
+ ret.freeze()
+ }
+ }
+
+ /// Takes at most max_len bytes from the left of the buffer
+ pub fn take_max(&mut self, max_len: usize) -> Bytes {
+ if self.buf_len <= max_len {
+ self.take_all()
+ } else {
+ self.take_exact_ok(max_len)
+ }
+ }
+
+ /// Take exactly len bytes from the left of the buffer, returns None if
+ /// the BytesBuf doesn't contain enough data
+ pub fn take_exact(&mut self, len: usize) -> Option<Bytes> {
+ if self.buf_len < len {
+ None
+ } else {
+ Some(self.take_exact_ok(len))
+ }
+ }
+
+ fn take_exact_ok(&mut self, len: usize) -> Bytes {
+ assert!(len <= self.buf_len);
+ let front = self.buf.pop_front().unwrap();
+ match front.len().cmp(&len) {
+ Ordering::Greater => {
+ self.buf.push_front(front.slice(len..));
+ self.buf_len -= len;
+ front.slice(..len)
+ }
+ Ordering::Equal => {
+ self.buf_len -= len;
+ front
+ }
+ Ordering::Less => {
+ let mut ret = BytesMut::with_capacity(len);
+ ret.extend_from_slice(&front[..]);
+ self.buf_len -= front.len();
+ while ret.len() < len {
+ let front = self.buf.pop_front().unwrap();
+ if front.len() > len - ret.len() {
+ let take = len - ret.len();
+ ret.extend_from_slice(&front[..take]);
+ self.buf.push_front(front.slice(take..));
+ self.buf_len -= take;
+ break;
+ } else {
+ ret.extend_from_slice(&front[..]);
+ self.buf_len -= front.len();
+ }
+ }
+ ret.freeze()
+ }
+ }
+ }
+
+ /// Return the internal sequence of Bytes slices that make up the buffer
+ pub fn into_slices(self) -> VecDeque<Bytes> {
+ self.buf
+ }
+}
+
+impl Default for BytesBuf {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl From<Bytes> for BytesBuf {
+ fn from(b: Bytes) -> BytesBuf {
+ let mut ret = BytesBuf::new();
+ ret.extend(b);
+ ret
+ }
+}
+
+impl From<BytesBuf> for Bytes {
+ fn from(mut b: BytesBuf) -> Bytes {
+ b.take_all()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_bytes_buf() {
+ let mut buf = BytesBuf::new();
+ assert!(buf.len() == 0);
+ assert!(buf.is_empty());
+
+ buf.extend(Bytes::from(b"Hello, world!".to_vec()));
+ assert!(buf.len() == 13);
+ assert!(!buf.is_empty());
+
+ buf.extend(Bytes::from(b"1234567890".to_vec()));
+ assert!(buf.len() == 23);
+ assert!(!buf.is_empty());
+
+ assert_eq!(
+ buf.take_all(),
+ Bytes::from(b"Hello, world!1234567890".to_vec())
+ );
+ assert!(buf.len() == 0);
+ assert!(buf.is_empty());
+
+ buf.extend(Bytes::from(b"1234567890".to_vec()));
+ buf.extend(Bytes::from(b"Hello, world!".to_vec()));
+ assert!(buf.len() == 23);
+ assert!(!buf.is_empty());
+
+ assert_eq!(buf.take_max(12), Bytes::from(b"1234567890He".to_vec()));
+ assert!(buf.len() == 11);
+
+ assert_eq!(buf.take_exact(12), None);
+ assert!(buf.len() == 11);
+ assert_eq!(
+ buf.take_exact(11),
+ Some(Bytes::from(b"llo, world!".to_vec()))
+ );
+ assert!(buf.len() == 0);
+ assert!(buf.is_empty());
+ }
+}
diff --git a/src/net/client.rs b/src/net/client.rs
new file mode 100644
index 00000000..607dd173
--- /dev/null
+++ b/src/net/client.rs
@@ -0,0 +1,292 @@
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::pin::Pin;
+use std::sync::atomic::{self, AtomicU32};
+use std::sync::{Arc, Mutex};
+use std::task::Poll;
+
+use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+use bytes::Bytes;
+use log::{debug, error, trace};
+
+use futures::io::AsyncReadExt;
+use futures::Stream;
+use kuska_handshake::async_std::{handshake_client, BoxStream};
+use tokio::net::TcpStream;
+use tokio::select;
+use tokio::sync::{mpsc, oneshot, watch};
+use tokio_util::compat::*;
+
+#[cfg(feature = "telemetry")]
+use opentelemetry::{
+ trace::{FutureExt, Span, SpanKind, TraceContextExt, Tracer},
+ Context, KeyValue,
+};
+#[cfg(feature = "telemetry")]
+use opentelemetry_contrib::trace::propagator::binary::*;
+
+use crate::error::*;
+use crate::message::*;
+use crate::netapp::*;
+use crate::recv::*;
+use crate::send::*;
+use crate::stream::*;
+use crate::util::*;
+
+pub(crate) struct ClientConn {
+ pub(crate) remote_addr: SocketAddr,
+ pub(crate) peer_id: NodeID,
+
+ query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
+
+ next_query_number: AtomicU32,
+ inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
+}
+
+impl ClientConn {
+ pub(crate) async fn init(
+ netapp: Arc<NetApp>,
+ socket: TcpStream,
+ peer_id: NodeID,
+ ) -> Result<(), Error> {
+ let remote_addr = socket.peer_addr()?;
+ let mut socket = socket.compat();
+
+ // Do handshake to authenticate and prove our identity to server
+ let handshake = handshake_client(
+ &mut socket,
+ netapp.netid.clone(),
+ netapp.id,
+ netapp.privkey.clone(),
+ peer_id,
+ )
+ .await?;
+
+ debug!(
+ "Handshake complete (client) with {}@{}",
+ hex::encode(peer_id),
+ remote_addr
+ );
+
+ // Create BoxStream layer that encodes content
+ let (read, write) = socket.split();
+ let (mut read, write) =
+ BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
+
+ // Before doing anything, receive version tag and
+ // check they are running the same version as us
+ let mut their_version_tag = VersionTag::default();
+ read.read_exact(&mut their_version_tag[..]).await?;
+ if their_version_tag != netapp.version_tag {
+ let msg = format!(
+ "different version tags: {} (theirs) vs. {} (ours)",
+ hex::encode(their_version_tag),
+ hex::encode(netapp.version_tag)
+ );
+ error!("Cannot connect to {}: {}", hex::encode(&peer_id[..8]), msg);
+ return Err(Error::VersionMismatch(msg));
+ }
+
+ // Build and launch stuff that manages sending requests client-side
+ let (query_send, query_recv) = mpsc::unbounded_channel();
+
+ let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false);
+
+ let conn = Arc::new(ClientConn {
+ remote_addr,
+ peer_id,
+ next_query_number: AtomicU32::from(RequestID::default()),
+ query_send: ArcSwapOption::new(Some(Arc::new(query_send))),
+ inflight: Mutex::new(HashMap::new()),
+ });
+
+ netapp.connected_as_client(peer_id, conn.clone());
+
+ let debug_name = format!("CLI {}", hex::encode(&peer_id[..8]));
+
+ tokio::spawn(async move {
+ let debug_name_2 = debug_name.clone();
+ let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write, debug_name_2));
+
+ let conn2 = conn.clone();
+ let recv_future = tokio::spawn(async move {
+ select! {
+ r = conn2.recv_loop(read, debug_name) => r,
+ _ = await_exit(stop_recv_loop_recv) => Ok(())
+ }
+ });
+
+ send_future.await.log_err("ClientConn send_loop");
+
+ // FIXME: should do here: wait for inflight requests to all have their response
+ stop_recv_loop
+ .send(true)
+ .log_err("ClientConn send true to stop_recv_loop");
+
+ recv_future.await.log_err("ClientConn recv_loop");
+
+ // Make sure we don't wait on any more requests that won't
+ // have a response
+ conn.inflight.lock().unwrap().clear();
+
+ netapp.disconnected_as_client(&peer_id, conn);
+ });
+
+ Ok(())
+ }
+
+ pub fn close(&self) {
+ self.query_send.store(None);
+ }
+
+ pub(crate) async fn call<T>(
+ self: Arc<Self>,
+ req: Req<T>,
+ path: &str,
+ prio: RequestPriority,
+ ) -> Result<Resp<T>, Error>
+ where
+ T: Message,
+ {
+ let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
+
+ let id = self
+ .next_query_number
+ .fetch_add(1, atomic::Ordering::Relaxed);
+
+ cfg_if::cfg_if! {
+ if #[cfg(feature = "telemetry")] {
+ let tracer = opentelemetry::global::tracer("netapp");
+ let mut span = tracer.span_builder(format!("RPC >> {}", path))
+ .with_kind(SpanKind::Client)
+ .start(&tracer);
+ let propagator = BinaryPropagator::new();
+ let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into();
+ } else {
+ let telemetry_id: Bytes = Bytes::new();
+ }
+ };
+
+ // Encode request
+ let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id);
+ let req_msg_len = req_enc.msg.len();
+ let (req_stream, req_order) = req_enc.encode();
+
+ // Send request through
+ let (resp_send, resp_recv) = oneshot::channel();
+ let old = self.inflight.lock().unwrap().insert(id, resp_send);
+ if let Some(old_ch) = old {
+ error!(
+ "Too many inflight requests! RequestID collision. Interrupting previous request."
+ );
+ let _ = old_ch.send(Box::pin(futures::stream::once(async move {
+ Err(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "RequestID collision, too many inflight requests",
+ ))
+ })));
+ }
+
+ debug!(
+ "request: query_send {}, path {}, prio {} (serialized message: {} bytes)",
+ id, path, prio, req_msg_len
+ );
+
+ #[cfg(feature = "telemetry")]
+ span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
+
+ query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?;
+
+ let canceller = CancelOnDrop::new(id, query_send.as_ref().clone());
+
+ cfg_if::cfg_if! {
+ if #[cfg(feature = "telemetry")] {
+ let stream = resp_recv
+ .with_context(Context::current_with_span(span))
+ .await?;
+ } else {
+ let stream = resp_recv.await?;
+ }
+ }
+
+ let stream = Box::pin(canceller.for_stream(stream));
+
+ let resp_enc = RespEnc::decode(stream).await?;
+ debug!("client: got response to request {} (path {})", id, path);
+ Resp::from_enc(resp_enc)
+ }
+}
+
+impl SendLoop for ClientConn {}
+
+#[async_trait]
+impl RecvLoop for ClientConn {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
+ trace!("ClientConn recv_handler {}", id);
+
+ let mut inflight = self.inflight.lock().unwrap();
+ if let Some(ch) = inflight.remove(&id) {
+ if ch.send(stream).is_err() {
+ debug!("Could not send request response, probably because request was interrupted. Dropping response.");
+ }
+ } else {
+ debug!("Got unexpected response to request {}, dropping it", id);
+ }
+ }
+}
+
+// ----
+
+struct CancelOnDrop {
+ id: RequestID,
+ query_send: mpsc::UnboundedSender<SendItem>,
+}
+
+impl CancelOnDrop {
+ fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self {
+ Self { id, query_send }
+ }
+ fn for_stream(self, stream: ByteStream) -> CancelOnDropStream {
+ CancelOnDropStream {
+ cancel: Some(self),
+ stream,
+ }
+ }
+}
+
+impl Drop for CancelOnDrop {
+ fn drop(&mut self) {
+ trace!("cancelling request {}", self.id);
+ let _ = self.query_send.send(SendItem::Cancel(self.id));
+ }
+}
+
+#[pin_project::pin_project]
+struct CancelOnDropStream {
+ cancel: Option<CancelOnDrop>,
+ #[pin]
+ stream: ByteStream,
+}
+
+impl Stream for CancelOnDropStream {
+ type Item = Packet;
+
+ fn poll_next(
+ self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let this = self.project();
+ let res = this.stream.poll_next(cx);
+ if matches!(res, Poll::Ready(None)) {
+ if let Some(c) = this.cancel.take() {
+ std::mem::forget(c)
+ }
+ }
+ res
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.stream.size_hint()
+ }
+}
diff --git a/src/net/endpoint.rs b/src/net/endpoint.rs
new file mode 100644
index 00000000..3cafafeb
--- /dev/null
+++ b/src/net/endpoint.rs
@@ -0,0 +1,201 @@
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+
+use crate::error::Error;
+use crate::message::*;
+use crate::netapp::*;
+
+/// This trait should be implemented by an object of your application
+/// that can handle a message of type `M`, if it wishes to handle
+/// streams attached to the request and/or to send back streams
+/// attached to the response..
+///
+/// The handler object should be in an Arc, see `Endpoint::set_handler`
+#[async_trait]
+pub trait StreamingEndpointHandler<M>: Send + Sync
+where
+ M: Message,
+{
+ async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>;
+}
+
+/// If one simply wants to use an endpoint in a client fashion,
+/// without locally serving requests to that endpoint,
+/// use the unit type `()` as the handler type:
+/// it will panic if it is ever made to handle request.
+#[async_trait]
+impl<M: Message> EndpointHandler<M> for () {
+ async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response {
+ panic!("This endpoint should not have a local handler.");
+ }
+}
+
+// ----
+
+/// This trait should be implemented by an object of your application
+/// that can handle a message of type `M`, in the cases where it doesn't
+/// care about attached stream in the request nor in the response.
+#[async_trait]
+pub trait EndpointHandler<M>: Send + Sync
+where
+ M: Message,
+{
+ async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
+}
+
+#[async_trait]
+impl<T, M> StreamingEndpointHandler<M> for T
+where
+ T: EndpointHandler<M>,
+ M: Message,
+{
+ async fn handle(self: &Arc<Self>, mut m: Req<M>, from: NodeID) -> Resp<M> {
+ // Immediately drop stream to ignore all data that comes in,
+ // instead of buffering it indefinitely
+ drop(m.take_stream());
+ Resp::new(EndpointHandler::handle(self, m.msg(), from).await)
+ }
+}
+
+// ----
+
+/// This struct represents an endpoint for message of type `M`.
+///
+/// Creating a new endpoint is done by calling `NetApp::endpoint`.
+/// An endpoint is identified primarily by its path, which is specified
+/// at creation time.
+///
+/// An `Endpoint` is used both to send requests to remote nodes,
+/// and to specify the handler for such requests on the local node.
+/// The type `H` represents the type of the handler object for
+/// endpoint messages (see `StreamingEndpointHandler`).
+pub struct Endpoint<M, H>
+where
+ M: Message,
+ H: StreamingEndpointHandler<M>,
+{
+ _phantom: PhantomData<M>,
+ netapp: Arc<NetApp>,
+ path: String,
+ handler: ArcSwapOption<H>,
+}
+
+impl<M, H> Endpoint<M, H>
+where
+ M: Message,
+ H: StreamingEndpointHandler<M>,
+{
+ pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
+ Self {
+ _phantom: PhantomData::default(),
+ netapp,
+ path,
+ handler: ArcSwapOption::from(None),
+ }
+ }
+
+ /// Get the path of this endpoint
+ pub fn path(&self) -> &str {
+ &self.path
+ }
+
+ /// Set the object that is responsible of handling requests to
+ /// this endpoint on the local node.
+ pub fn set_handler(&self, h: Arc<H>) {
+ self.handler.swap(Some(h));
+ }
+
+ /// Call this endpoint on a remote node (or on the local node,
+ /// for that matter). This function invokes the full version that
+ /// allows to attach a stream to the request and to
+ /// receive such a stream attached to the response.
+ pub async fn call_streaming<T>(
+ &self,
+ target: &NodeID,
+ req: T,
+ prio: RequestPriority,
+ ) -> Result<Resp<M>, Error>
+ where
+ T: IntoReq<M>,
+ {
+ if *target == self.netapp.id {
+ match self.handler.load_full() {
+ None => Err(Error::NoHandler),
+ Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await),
+ }
+ } else {
+ let conn = self
+ .netapp
+ .client_conns
+ .read()
+ .unwrap()
+ .get(target)
+ .cloned();
+ match conn {
+ None => Err(Error::Message(format!(
+ "Not connected: {}",
+ hex::encode(&target[..8])
+ ))),
+ Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await,
+ }
+ }
+ }
+
+ /// Call this endpoint on a remote node. This function is the simplified
+ /// version that doesn't allow to have streams attached to the request
+ /// or the response; see `call_streaming` for the full version.
+ pub async fn call(
+ &self,
+ target: &NodeID,
+ req: M,
+ prio: RequestPriority,
+ ) -> Result<<M as Message>::Response, Error> {
+ Ok(self.call_streaming(target, req, prio).await?.into_msg())
+ }
+}
+
+// ---- Internal stuff ----
+
+pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
+
+#[async_trait]
+pub(crate) trait GenericEndpoint {
+ async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>;
+ fn drop_handler(&self);
+ fn clone_endpoint(&self) -> DynEndpoint;
+}
+
+#[derive(Clone)]
+pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
+where
+ M: Message,
+ H: StreamingEndpointHandler<M>;
+
+#[async_trait]
+impl<M, H> GenericEndpoint for EndpointArc<M, H>
+where
+ M: Message,
+ H: StreamingEndpointHandler<M> + 'static,
+{
+ async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> {
+ match self.0.handler.load_full() {
+ None => Err(Error::NoHandler),
+ Some(h) => {
+ let req = Req::from_enc(req_enc)?;
+ let res = h.handle(req, from).await;
+ Ok(res.into_enc()?)
+ }
+ }
+ }
+
+ fn drop_handler(&self) {
+ self.0.handler.swap(None);
+ }
+
+ fn clone_endpoint(&self) -> DynEndpoint {
+ Box::new(Self(self.0.clone()))
+ }
+}
diff --git a/src/net/error.rs b/src/net/error.rs
new file mode 100644
index 00000000..c0aeeacc
--- /dev/null
+++ b/src/net/error.rs
@@ -0,0 +1,126 @@
+use std::io;
+
+use err_derive::Error;
+use log::error;
+
+#[derive(Debug, Error)]
+pub enum Error {
+ #[error(display = "IO error: {}", _0)]
+ Io(#[error(source)] io::Error),
+
+ #[error(display = "Messagepack encode error: {}", _0)]
+ RMPEncode(#[error(source)] rmp_serde::encode::Error),
+ #[error(display = "Messagepack decode error: {}", _0)]
+ RMPDecode(#[error(source)] rmp_serde::decode::Error),
+
+ #[error(display = "Tokio join error: {}", _0)]
+ TokioJoin(#[error(source)] tokio::task::JoinError),
+
+ #[error(display = "oneshot receive error: {}", _0)]
+ OneshotRecv(#[error(source)] tokio::sync::oneshot::error::RecvError),
+
+ #[error(display = "Handshake error: {}", _0)]
+ Handshake(#[error(source)] kuska_handshake::async_std::Error),
+
+ #[error(display = "UTF8 error: {}", _0)]
+ UTF8(#[error(source)] std::string::FromUtf8Error),
+
+ #[error(display = "Framing protocol error")]
+ Framing,
+
+ #[error(display = "Remote error ({:?}): {}", _0, _1)]
+ Remote(io::ErrorKind, String),
+
+ #[error(display = "Request ID collision")]
+ IdCollision,
+
+ #[error(display = "{}", _0)]
+ Message(String),
+
+ #[error(display = "No handler / shutting down")]
+ NoHandler,
+
+ #[error(display = "Connection closed")]
+ ConnectionClosed,
+
+ #[error(display = "Version mismatch: {}", _0)]
+ VersionMismatch(String),
+}
+
+impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
+ fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error {
+ Error::Message("Watch send error".into())
+ }
+}
+
+impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
+ fn from(_e: tokio::sync::mpsc::error::SendError<T>) -> Error {
+ Error::Message("MPSC send error".into())
+ }
+}
+
+/// Ths trait adds a `.log_err()` method on `Result<(), E>` types,
+/// which dismisses the error by logging it to stderr.
+pub trait LogError {
+ fn log_err(self, msg: &'static str);
+}
+
+impl<E> LogError for Result<(), E>
+where
+ E: Into<Error>,
+{
+ fn log_err(self, msg: &'static str) {
+ if let Err(e) = self {
+ error!("Error: {}: {}", msg, Into::<Error>::into(e));
+ };
+ }
+}
+
+impl<E, T> LogError for Result<T, E>
+where
+ T: LogError,
+ E: Into<Error>,
+{
+ fn log_err(self, msg: &'static str) {
+ match self {
+ Err(e) => error!("Error: {}: {}", msg, Into::<Error>::into(e)),
+ Ok(x) => x.log_err(msg),
+ }
+ }
+}
+
+// ---- Helpers for serializing I/O Errors
+
+pub(crate) fn u8_to_io_errorkind(v: u8) -> std::io::ErrorKind {
+ use std::io::ErrorKind;
+ match v {
+ 101 => ErrorKind::ConnectionAborted,
+ 102 => ErrorKind::BrokenPipe,
+ 103 => ErrorKind::WouldBlock,
+ 104 => ErrorKind::InvalidInput,
+ 105 => ErrorKind::InvalidData,
+ 106 => ErrorKind::TimedOut,
+ 107 => ErrorKind::Interrupted,
+ 108 => ErrorKind::UnexpectedEof,
+ 109 => ErrorKind::OutOfMemory,
+ 110 => ErrorKind::ConnectionReset,
+ _ => ErrorKind::Other,
+ }
+}
+
+pub(crate) fn io_errorkind_to_u8(kind: std::io::ErrorKind) -> u8 {
+ use std::io::ErrorKind;
+ match kind {
+ ErrorKind::ConnectionAborted => 101,
+ ErrorKind::BrokenPipe => 102,
+ ErrorKind::WouldBlock => 103,
+ ErrorKind::InvalidInput => 104,
+ ErrorKind::InvalidData => 105,
+ ErrorKind::TimedOut => 106,
+ ErrorKind::Interrupted => 107,
+ ErrorKind::UnexpectedEof => 108,
+ ErrorKind::OutOfMemory => 109,
+ ErrorKind::ConnectionReset => 110,
+ _ => 100,
+ }
+}
diff --git a/src/net/lib.rs b/src/net/lib.rs
new file mode 100644
index 00000000..8e30e40f
--- /dev/null
+++ b/src/net/lib.rs
@@ -0,0 +1,35 @@
+//! Netapp is a Rust library that takes care of a few common tasks in distributed software:
+//!
+//! - establishing secure connections
+//! - managing connection lifetime, reconnecting on failure
+//! - checking peer's state
+//! - peer discovery
+//! - query/response message passing model for communications
+//! - multiplexing transfers over a connection
+//! - overlay networks: full mesh, and soon other methods
+//!
+//! Of particular interest, read the documentation for the `netapp::NetApp` type,
+//! the `message::Message` trait, and `proto::RequestPriority` to learn more
+//! about message priorization.
+//! Also check out the examples to learn how to use this crate.
+
+pub mod bytes_buf;
+pub mod error;
+pub mod stream;
+pub mod util;
+
+pub mod endpoint;
+pub mod message;
+
+mod client;
+mod recv;
+mod send;
+mod server;
+
+pub mod netapp;
+pub mod peering;
+
+pub use crate::netapp::*;
+
+#[cfg(test)]
+mod test;
diff --git a/src/net/message.rs b/src/net/message.rs
new file mode 100644
index 00000000..b0d255c6
--- /dev/null
+++ b/src/net/message.rs
@@ -0,0 +1,522 @@
+use std::fmt;
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+use bytes::{BufMut, Bytes, BytesMut};
+use rand::prelude::*;
+use serde::{Deserialize, Serialize};
+
+use futures::stream::StreamExt;
+
+use crate::error::*;
+use crate::stream::*;
+use crate::util::*;
+
+/// Priority of a request (click to read more about priorities).
+///
+/// This priority value is used to priorize messages
+/// in the send queue of the client, and their responses in the send queue of the
+/// server. Lower values mean higher priority.
+///
+/// This mechanism is usefull for messages bigger than the maximum chunk size
+/// (set at `0x4000` bytes), such as large file transfers.
+/// In such case, all of the messages in the send queue with the highest priority
+/// will take turns to send individual chunks, in a round-robin fashion.
+/// Once all highest priority messages are sent successfully, the messages with
+/// the next highest priority will begin being sent in the same way.
+///
+/// The same priority value is given to a request and to its associated response.
+pub type RequestPriority = u8;
+
+/// Priority class: high
+pub const PRIO_HIGH: RequestPriority = 0x20;
+/// Priority class: normal
+pub const PRIO_NORMAL: RequestPriority = 0x40;
+/// Priority class: background
+pub const PRIO_BACKGROUND: RequestPriority = 0x80;
+/// Priority: primary among given class
+pub const PRIO_PRIMARY: RequestPriority = 0x00;
+/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
+pub const PRIO_SECONDARY: RequestPriority = 0x01;
+
+// ----
+
+/// An order tag can be added to a message or a response to indicate
+/// whether it should be sent after or before other messages with order tags
+/// referencing a same stream
+#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
+pub struct OrderTag(pub(crate) u64, pub(crate) u64);
+
+/// A stream is an opaque identifier that defines a set of messages
+/// or responses that are ordered wrt one another using to order tags.
+#[derive(Clone, Copy)]
+pub struct OrderTagStream(u64);
+
+impl OrderTag {
+ /// Create a new stream from which to generate order tags. Example:
+ /// ```ignore
+ /// let stream = OrderTag.stream();
+ /// let tag_1 = stream.order(1);
+ /// let tag_2 = stream.order(2);
+ /// ```
+ pub fn stream() -> OrderTagStream {
+ OrderTagStream(thread_rng().gen())
+ }
+}
+impl OrderTagStream {
+ /// Create the order tag for message `order` in this stream
+ pub fn order(&self, order: u64) -> OrderTag {
+ OrderTag(self.0, order)
+ }
+}
+
+// ----
+
+/// This trait should be implemented by all messages your application
+/// wants to handle. It specifies which data type should be sent
+/// as a response to this message in the RPC protocol.
+pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {
+ /// The type of the response that is sent in response to this message
+ type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static;
+}
+
+// ----
+
+/// The Req<M> is a helper object used to create requests and attach them
+/// a stream of data. If the stream is a fixed Bytes and not a ByteStream,
+/// Req<M> is cheaply clonable to allow the request to be sent to different
+/// peers (Clone will panic if the stream is a ByteStream).
+pub struct Req<M: Message> {
+ pub(crate) msg: Arc<M>,
+ pub(crate) msg_ser: Option<Bytes>,
+ pub(crate) stream: AttachedStream,
+ pub(crate) order_tag: Option<OrderTag>,
+}
+
+impl<M: Message> Req<M> {
+ /// Creates a new request from a base message `M`
+ pub fn new(v: M) -> Result<Self, Error> {
+ Ok(v.into_req()?)
+ }
+
+ /// Attach a stream to message in request, where the stream is streamed
+ /// from a fixed `Bytes` buffer
+ pub fn with_stream_from_buffer(self, b: Bytes) -> Self {
+ Self {
+ stream: AttachedStream::Fixed(b),
+ ..self
+ }
+ }
+
+ /// Attach a stream to message in request, where the stream is
+ /// an instance of `ByteStream`. Note than when a `Req<M>` has an attached
+ /// stream which is a `ByteStream` instance, it can no longer be cloned
+ /// to be sent to different nodes (`.clone()` will panic)
+ pub fn with_stream(self, b: ByteStream) -> Self {
+ Self {
+ stream: AttachedStream::Stream(b),
+ ..self
+ }
+ }
+
+ /// Add an order tag to this request to indicate in which order it should
+ /// be sent.
+ pub fn with_order_tag(self, order_tag: OrderTag) -> Self {
+ Self {
+ order_tag: Some(order_tag),
+ ..self
+ }
+ }
+
+ /// Get a reference to the message `M` contained in this request
+ pub fn msg(&self) -> &M {
+ &self.msg
+ }
+
+ /// Takes out the stream attached to this request, if any
+ pub fn take_stream(&mut self) -> Option<ByteStream> {
+ std::mem::replace(&mut self.stream, AttachedStream::None).into_stream()
+ }
+
+ pub(crate) fn into_enc(
+ self,
+ prio: RequestPriority,
+ path: Bytes,
+ telemetry_id: Bytes,
+ ) -> ReqEnc {
+ ReqEnc {
+ prio,
+ path,
+ telemetry_id,
+ msg: self.msg_ser.unwrap(),
+ stream: self.stream.into_stream(),
+ order_tag: self.order_tag,
+ }
+ }
+
+ pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> {
+ let msg = rmp_serde::decode::from_slice(&enc.msg)?;
+ Ok(Req {
+ msg: Arc::new(msg),
+ msg_ser: Some(enc.msg),
+ stream: enc
+ .stream
+ .map(AttachedStream::Stream)
+ .unwrap_or(AttachedStream::None),
+ order_tag: enc.order_tag,
+ })
+ }
+}
+
+/// `IntoReq<M>` represents any object that can be transformed into `Req<M>`
+pub trait IntoReq<M: Message> {
+ /// Transform the object into a `Req<M>`, serializing the message M
+ /// to be sent to remote nodes
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>;
+ /// Transform the object into a `Req<M>`, skipping the serialization
+ /// of message M, in the case we are not sending this RPC message to
+ /// a remote node
+ fn into_req_local(self) -> Req<M>;
+}
+
+impl<M: Message> IntoReq<M> for M {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
+ let msg_ser = rmp_to_vec_all_named(&self)?;
+ Ok(Req {
+ msg: Arc::new(self),
+ msg_ser: Some(Bytes::from(msg_ser)),
+ stream: AttachedStream::None,
+ order_tag: None,
+ })
+ }
+ fn into_req_local(self) -> Req<M> {
+ Req {
+ msg: Arc::new(self),
+ msg_ser: None,
+ stream: AttachedStream::None,
+ order_tag: None,
+ }
+ }
+}
+
+impl<M: Message> IntoReq<M> for Req<M> {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
+ Ok(self)
+ }
+ fn into_req_local(self) -> Req<M> {
+ self
+ }
+}
+
+impl<M: Message> Clone for Req<M> {
+ fn clone(&self) -> Self {
+ let stream = match &self.stream {
+ AttachedStream::None => AttachedStream::None,
+ AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()),
+ AttachedStream::Stream(_) => {
+ panic!("Cannot clone a Req<_> with a non-buffer attached stream")
+ }
+ };
+ Self {
+ msg: self.msg.clone(),
+ msg_ser: self.msg_ser.clone(),
+ stream,
+ order_tag: self.order_tag,
+ }
+ }
+}
+
+impl<M> fmt::Debug for Req<M>
+where
+ M: Message + fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ write!(f, "Req[{:?}", self.msg)?;
+ match &self.stream {
+ AttachedStream::None => write!(f, "]"),
+ AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()),
+ AttachedStream::Stream(_) => write!(f, "; stream]"),
+ }
+ }
+}
+
+// ----
+
+/// The Resp<M> represents a full response from a RPC that may have
+/// an attached stream.
+pub struct Resp<M: Message> {
+ pub(crate) _phantom: PhantomData<M>,
+ pub(crate) msg: M::Response,
+ pub(crate) stream: AttachedStream,
+ pub(crate) order_tag: Option<OrderTag>,
+}
+
+impl<M: Message> Resp<M> {
+ /// Creates a new response from a base response message
+ pub fn new(v: M::Response) -> Self {
+ Resp {
+ _phantom: Default::default(),
+ msg: v,
+ stream: AttachedStream::None,
+ order_tag: None,
+ }
+ }
+
+ /// Attach a stream to message in response, where the stream is streamed
+ /// from a fixed `Bytes` buffer
+ pub fn with_stream_from_buffer(self, b: Bytes) -> Self {
+ Self {
+ stream: AttachedStream::Fixed(b),
+ ..self
+ }
+ }
+
+ /// Attach a stream to message in response, where the stream is
+ /// an instance of `ByteStream`.
+ pub fn with_stream(self, b: ByteStream) -> Self {
+ Self {
+ stream: AttachedStream::Stream(b),
+ ..self
+ }
+ }
+
+ /// Add an order tag to this response to indicate in which order it should
+ /// be sent.
+ pub fn with_order_tag(self, order_tag: OrderTag) -> Self {
+ Self {
+ order_tag: Some(order_tag),
+ ..self
+ }
+ }
+
+ /// Get a reference to the response message contained in this request
+ pub fn msg(&self) -> &M::Response {
+ &self.msg
+ }
+
+ /// Transforms the `Resp<M>` into the response message it contains,
+ /// dropping everything else (including attached data stream)
+ pub fn into_msg(self) -> M::Response {
+ self.msg
+ }
+
+ /// Transforms the `Resp<M>` into, on the one side, the response message
+ /// it contains, and on the other side, the associated data stream
+ /// if it exists
+ pub fn into_parts(self) -> (M::Response, Option<ByteStream>) {
+ (self.msg, self.stream.into_stream())
+ }
+
+ pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> {
+ Ok(RespEnc {
+ msg: rmp_to_vec_all_named(&self.msg)?.into(),
+ stream: self.stream.into_stream(),
+ order_tag: self.order_tag,
+ })
+ }
+
+ pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> {
+ let msg = rmp_serde::decode::from_slice(&enc.msg)?;
+ Ok(Self {
+ _phantom: Default::default(),
+ msg,
+ stream: enc
+ .stream
+ .map(AttachedStream::Stream)
+ .unwrap_or(AttachedStream::None),
+ order_tag: enc.order_tag,
+ })
+ }
+}
+
+impl<M> fmt::Debug for Resp<M>
+where
+ M: Message,
+ <M as Message>::Response: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ write!(f, "Resp[{:?}", self.msg)?;
+ match &self.stream {
+ AttachedStream::None => write!(f, "]"),
+ AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()),
+ AttachedStream::Stream(_) => write!(f, "; stream]"),
+ }
+ }
+}
+
+// ----
+
+pub(crate) enum AttachedStream {
+ None,
+ Fixed(Bytes),
+ Stream(ByteStream),
+}
+
+impl AttachedStream {
+ pub fn into_stream(self) -> Option<ByteStream> {
+ match self {
+ AttachedStream::None => None,
+ AttachedStream::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
+ AttachedStream::Stream(s) => Some(s),
+ }
+ }
+}
+
+// ---- ----
+
+/// Encoding for requests into a ByteStream:
+/// - priority: u8
+/// - path length: u8
+/// - path: [u8; path length]
+/// - telemetry id length: u8
+/// - telemetry id: [u8; telemetry id length]
+/// - msg len: u32
+/// - msg [u8; ..]
+/// - the attached stream as the rest of the encoded stream
+pub(crate) struct ReqEnc {
+ pub(crate) prio: RequestPriority,
+ pub(crate) path: Bytes,
+ pub(crate) telemetry_id: Bytes,
+ pub(crate) msg: Bytes,
+ pub(crate) stream: Option<ByteStream>,
+ pub(crate) order_tag: Option<OrderTag>,
+}
+
+impl ReqEnc {
+ pub(crate) fn encode(self) -> (ByteStream, Option<OrderTag>) {
+ let mut buf = BytesMut::with_capacity(
+ self.path.len() + self.telemetry_id.len() + self.msg.len() + 16,
+ );
+
+ buf.put_u8(self.prio);
+
+ buf.put_u8(self.path.len() as u8);
+ buf.put(self.path);
+
+ buf.put_u8(self.telemetry_id.len() as u8);
+ buf.put(&self.telemetry_id[..]);
+
+ buf.put_u32(self.msg.len() as u32);
+
+ let header = buf.freeze();
+
+ let res_stream: ByteStream = if let Some(stream) = self.stream {
+ Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream))
+ } else {
+ Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]))
+ };
+ (res_stream, self.order_tag)
+ }
+
+ pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
+ Self::decode_aux(stream)
+ .await
+ .map_err(read_exact_error_to_error)
+ }
+
+ async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
+ let mut reader = ByteStreamReader::new(stream);
+
+ let prio = reader.read_u8().await?;
+
+ let path_len = reader.read_u8().await?;
+ let path = reader.read_exact(path_len as usize).await?;
+
+ let telemetry_id_len = reader.read_u8().await?;
+ let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?;
+
+ let msg_len = reader.read_u32().await?;
+ let msg = reader.read_exact(msg_len as usize).await?;
+
+ Ok(Self {
+ prio,
+ path,
+ telemetry_id,
+ msg,
+ stream: Some(reader.into_stream()),
+ order_tag: None,
+ })
+ }
+}
+
+/// Encoding for responses into a ByteStream:
+/// IF SUCCESS:
+/// - 0: u8
+/// - msg len: u32
+/// - msg [u8; ..]
+/// - the attached stream as the rest of the encoded stream
+/// IF ERROR:
+/// - message length + 1: u8
+/// - error code: u8
+/// - message: [u8; message_length]
+pub(crate) struct RespEnc {
+ msg: Bytes,
+ stream: Option<ByteStream>,
+ order_tag: Option<OrderTag>,
+}
+
+impl RespEnc {
+ pub(crate) fn encode(resp: Result<Self, Error>) -> (ByteStream, Option<OrderTag>) {
+ match resp {
+ Ok(Self {
+ msg,
+ stream,
+ order_tag,
+ }) => {
+ let mut buf = BytesMut::with_capacity(4);
+ buf.put_u32(msg.len() as u32);
+ let header = buf.freeze();
+
+ let res_stream: ByteStream = if let Some(stream) = stream {
+ Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream))
+ } else {
+ Box::pin(futures::stream::iter([Ok(header), Ok(msg)]))
+ };
+ (res_stream, order_tag)
+ }
+ Err(err) => {
+ let err = std::io::Error::new(
+ std::io::ErrorKind::Other,
+ format!("netapp error: {}", err),
+ );
+ (
+ Box::pin(futures::stream::once(async move { Err(err) })),
+ None,
+ )
+ }
+ }
+ }
+
+ pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
+ Self::decode_aux(stream)
+ .await
+ .map_err(read_exact_error_to_error)
+ }
+
+ async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
+ let mut reader = ByteStreamReader::new(stream);
+
+ let msg_len = reader.read_u32().await?;
+ let msg = reader.read_exact(msg_len as usize).await?;
+
+ // Check whether the response stream still has data or not.
+ // If no more data is coming, this will defuse the request canceller.
+ // If we didn't do this, and the client doesn't try to read from the stream,
+ // the request canceller doesn't know that we read everything and
+ // sends a cancellation message to the server (which they don't care about).
+ reader.fill_buffer().await;
+
+ Ok(Self {
+ msg,
+ stream: Some(reader.into_stream()),
+ order_tag: None,
+ })
+ }
+}
+
+fn read_exact_error_to_error(e: ReadExactError) -> Error {
+ match e {
+ ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()),
+ ReadExactError::UnexpectedEos => Error::Framing,
+ }
+}
diff --git a/src/net/netapp.rs b/src/net/netapp.rs
new file mode 100644
index 00000000..b1ad9db8
--- /dev/null
+++ b/src/net/netapp.rs
@@ -0,0 +1,452 @@
+use std::collections::HashMap;
+use std::net::{IpAddr, SocketAddr};
+use std::sync::{Arc, RwLock};
+
+use log::{debug, error, info, trace, warn};
+
+use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+
+use serde::{Deserialize, Serialize};
+use sodiumoxide::crypto::auth;
+use sodiumoxide::crypto::sign::ed25519;
+
+use futures::stream::futures_unordered::FuturesUnordered;
+use futures::stream::StreamExt;
+use tokio::net::{TcpListener, TcpStream};
+use tokio::select;
+use tokio::sync::{mpsc, watch};
+
+use crate::client::*;
+use crate::endpoint::*;
+use crate::error::*;
+use crate::message::*;
+use crate::server::*;
+
+/// A node's identifier, which is also its public cryptographic key
+pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
+/// A node's secret key
+pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
+/// A network key
+pub type NetworkKey = sodiumoxide::crypto::auth::Key;
+
+/// Tag which is exchanged between client and server upon connection establishment
+/// to check that they are running compatible versions of Netapp,
+/// composed of 8 bytes for Netapp version and 8 bytes for client version
+pub(crate) type VersionTag = [u8; 16];
+
+/// Value of the Netapp version used in the version tag
+pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005
+
+#[derive(Serialize, Deserialize, Debug)]
+pub(crate) struct HelloMessage {
+ pub server_addr: Option<IpAddr>,
+ pub server_port: u16,
+}
+
+impl Message for HelloMessage {
+ type Response = ();
+}
+
+type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>;
+type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>;
+
+/// NetApp is the main class that handles incoming and outgoing connections.
+///
+/// NetApp can be used in a stand-alone fashion or together with a peering strategy.
+/// If using it alone, you will want to set `on_connect` and `on_disconnect` events
+/// in order to manage information about the current peer list.
+///
+/// It is generally not necessary to use NetApp stand-alone, as the provided full mesh
+/// and RPS peering strategies take care of the most common use cases.
+pub struct NetApp {
+ listen_params: ArcSwapOption<ListenParams>,
+
+ /// Version tag, 8 bytes for netapp version, 8 bytes for app version
+ pub version_tag: VersionTag,
+ /// Network secret key
+ pub netid: auth::Key,
+ /// Our peer ID
+ pub id: NodeID,
+ /// Private key associated with our peer ID
+ pub privkey: ed25519::SecretKey,
+
+ pub(crate) server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>,
+ pub(crate) client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>,
+
+ pub(crate) endpoints: RwLock<HashMap<String, DynEndpoint>>,
+ hello_endpoint: ArcSwapOption<Endpoint<HelloMessage, NetApp>>,
+
+ on_connected_handler: ArcSwapOption<OnConnectHandler>,
+ on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>,
+}
+
+struct ListenParams {
+ listen_addr: SocketAddr,
+ public_addr: Option<IpAddr>,
+}
+
+impl NetApp {
+ /// Creates a new instance of NetApp, which can serve either as a full p2p node,
+ /// or just as a passive client. To upgrade to a full p2p node, spawn a listener
+ /// using `.listen()`
+ ///
+ /// Our Peer ID is the public key associated to the secret key given here.
+ pub fn new(app_version_tag: u64, netid: auth::Key, privkey: ed25519::SecretKey) -> Arc<Self> {
+ let mut version_tag = [0u8; 16];
+ version_tag[0..8].copy_from_slice(&u64::to_be_bytes(NETAPP_VERSION_TAG)[..]);
+ version_tag[8..16].copy_from_slice(&u64::to_be_bytes(app_version_tag)[..]);
+
+ let id = privkey.public_key();
+ let netapp = Arc::new(Self {
+ listen_params: ArcSwapOption::new(None),
+ version_tag,
+ netid,
+ id,
+ privkey,
+ server_conns: RwLock::new(HashMap::new()),
+ client_conns: RwLock::new(HashMap::new()),
+ endpoints: RwLock::new(HashMap::new()),
+ hello_endpoint: ArcSwapOption::new(None),
+ on_connected_handler: ArcSwapOption::new(None),
+ on_disconnected_handler: ArcSwapOption::new(None),
+ });
+
+ netapp
+ .hello_endpoint
+ .swap(Some(netapp.endpoint("__netapp/netapp.rs/Hello".into())));
+ netapp
+ .hello_endpoint
+ .load_full()
+ .unwrap()
+ .set_handler(netapp.clone());
+
+ netapp
+ }
+
+ /// Set the handler to be called when a new connection (incoming or outgoing) has
+ /// been successfully established. Do not set this if using a peering strategy,
+ /// as the peering strategy will need to set this itself.
+ pub fn on_connected<F>(&self, handler: F)
+ where
+ F: Fn(NodeID, SocketAddr, bool) + Sized + Send + Sync + 'static,
+ {
+ self.on_connected_handler
+ .store(Some(Arc::new(Box::new(handler))));
+ }
+
+ /// Set the handler to be called when an existing connection (incoming or outgoing) has
+ /// been closed by either party. Do not set this if using a peering strategy,
+ /// as the peering strategy will need to set this itself.
+ pub fn on_disconnected<F>(&self, handler: F)
+ where
+ F: Fn(NodeID, bool) + Sized + Send + Sync + 'static,
+ {
+ self.on_disconnected_handler
+ .store(Some(Arc::new(Box::new(handler))));
+ }
+
+ /// Create a new endpoint with path `path`,
+ /// that handles messages of type `M`.
+ /// `H` is the type of the object that should handle requests
+ /// to this endpoint on the local node. If you don't want
+ /// to handle request on the local node (e.g. if this node
+ /// is only a client in the network), define the type `H`
+ /// to be `()`.
+ /// This function will panic if the endpoint has already been
+ /// created.
+ pub fn endpoint<M, H>(self: &Arc<Self>, path: String) -> Arc<Endpoint<M, H>>
+ where
+ M: Message + 'static,
+ H: StreamingEndpointHandler<M> + 'static,
+ {
+ let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), path.clone()));
+ let endpoint_arc = EndpointArc(endpoint.clone());
+ if self
+ .endpoints
+ .write()
+ .unwrap()
+ .insert(path.clone(), Box::new(endpoint_arc))
+ .is_some()
+ {
+ panic!("Redefining endpoint: {}", path);
+ };
+ endpoint
+ }
+
+ /// Main listening process for our app. This future runs during the whole
+ /// run time of our application.
+ /// If this is not called, the NetApp instance remains a passive client.
+ pub async fn listen(
+ self: Arc<Self>,
+ listen_addr: SocketAddr,
+ public_addr: Option<IpAddr>,
+ mut must_exit: watch::Receiver<bool>,
+ ) {
+ let listen_params = ListenParams {
+ listen_addr,
+ public_addr,
+ };
+ if self
+ .listen_params
+ .swap(Some(Arc::new(listen_params)))
+ .is_some()
+ {
+ error!("Trying to listen on NetApp but we're already listening!");
+ }
+
+ let listener = TcpListener::bind(listen_addr).await.unwrap();
+ info!("Listening on {}", listen_addr);
+
+ let (conn_in, mut conn_out) = mpsc::unbounded_channel();
+ let connection_collector = tokio::spawn(async move {
+ let mut collection = FuturesUnordered::new();
+ loop {
+ if collection.is_empty() {
+ match conn_out.recv().await {
+ Some(f) => collection.push(f),
+ None => break,
+ }
+ } else {
+ select! {
+ new_fut = conn_out.recv() => {
+ match new_fut {
+ Some(f) => collection.push(f),
+ None => break,
+ }
+ }
+ result = collection.next() => {
+ trace!("Collected connection: {:?}", result);
+ }
+ }
+ }
+ }
+ debug!("Collecting last open server connections.");
+ while let Some(conn_res) = collection.next().await {
+ trace!("Collected connection: {:?}", conn_res);
+ }
+ debug!("No more server connections to collect");
+ });
+
+ while !*must_exit.borrow_and_update() {
+ let (socket, peer_addr) = select! {
+ sockres = listener.accept() => {
+ match sockres {
+ Ok(x) => x,
+ Err(e) => {
+ warn!("Error in listener.accept: {}", e);
+ continue;
+ }
+ }
+ },
+ _ = must_exit.changed() => continue,
+ };
+
+ info!(
+ "Incoming connection from {}, negotiating handshake...",
+ peer_addr
+ );
+ let self2 = self.clone();
+ let must_exit2 = must_exit.clone();
+ conn_in
+ .send(tokio::spawn(async move {
+ ServerConn::run(self2, socket, must_exit2)
+ .await
+ .log_err("ServerConn::run");
+ }))
+ .log_err("Failed to send connection to connection collector");
+ }
+
+ drop(conn_in);
+
+ connection_collector
+ .await
+ .log_err("Failed to await for connection collector");
+ }
+
+ /// Drop all endpoint handlers, as well as handlers for connection/disconnection
+ /// events. (This disables the peering strategy)
+ ///
+ /// Use this when terminating to break reference cycles
+ pub fn drop_all_handlers(&self) {
+ for (_, endpoint) in self.endpoints.read().unwrap().iter() {
+ endpoint.drop_handler();
+ }
+ self.on_connected_handler.store(None);
+ self.on_disconnected_handler.store(None);
+ }
+
+ /// Attempt to connect to a peer, given by its ip:port and its public key.
+ /// The public key will be checked during the secret handshake process.
+ /// This function returns once the connection has been established and a
+ /// successfull handshake was made. At this point we can send messages to
+ /// the other node with `Netapp::request`
+ pub async fn try_connect(self: Arc<Self>, ip: SocketAddr, id: NodeID) -> Result<(), Error> {
+ // Don't connect to ourself, we don't care
+ // but pretend we did
+ if id == self.id {
+ tokio::spawn(async move {
+ if let Some(h) = self.on_connected_handler.load().as_ref() {
+ h(id, ip, false);
+ }
+ });
+ return Ok(());
+ }
+
+ // Don't connect if already connected
+ if self.client_conns.read().unwrap().contains_key(&id) {
+ return Ok(());
+ }
+
+ let socket = TcpStream::connect(ip).await?;
+ info!("Connected to {}, negotiating handshake...", ip);
+ ClientConn::init(self, socket, id).await?;
+ Ok(())
+ }
+
+ /// Close the outgoing connection we have to a node specified by its public key,
+ /// if such a connection is currently open.
+ pub fn disconnect(self: &Arc<Self>, id: &NodeID) {
+ // If id is ourself, we're not supposed to have a connection open
+ if *id != self.id {
+ let conn = self.client_conns.write().unwrap().remove(id);
+ if let Some(c) = conn {
+ debug!(
+ "Closing connection to {} ({})",
+ hex::encode(&c.peer_id[..8]),
+ c.remote_addr
+ );
+ c.close();
+ } else {
+ return;
+ }
+ }
+
+ // call on_disconnected_handler immediately, since the connection
+ // was removed
+ // (if id == self.id, we pretend we disconnected)
+ let id = *id;
+ let self2 = self.clone();
+ tokio::spawn(async move {
+ if let Some(h) = self2.on_disconnected_handler.load().as_ref() {
+ h(id, false);
+ }
+ });
+ }
+
+ // Called from conn.rs when an incoming connection is successfully established
+ // Registers the connection in our list of connections
+ // Do not yet call the on_connected handler, because we don't know if the remote
+ // has an actual IP address and port we can call them back on.
+ // We will know this when they send a Hello message, which is handled below.
+ pub(crate) fn connected_as_server(&self, id: NodeID, conn: Arc<ServerConn>) {
+ info!(
+ "Accepted connection from {} at {}",
+ hex::encode(&id[..8]),
+ conn.remote_addr
+ );
+
+ self.server_conns.write().unwrap().insert(id, conn);
+ }
+
+ // Handle hello message from a client. This message is used for them to tell us
+ // that they are listening on a certain port number on which we can call them back.
+ // At this point we know they are a full network member, and not just a client,
+ // and we call the on_connected handler so that the peering strategy knows
+ // we have a new potential peer
+
+ // Called from conn.rs when an incoming connection is closed.
+ // We deregister the connection from server_conns and call the
+ // handler registered by on_disconnected
+ pub(crate) fn disconnected_as_server(&self, id: &NodeID, conn: Arc<ServerConn>) {
+ info!("Connection from {} closed", hex::encode(&id[..8]));
+
+ let mut conn_list = self.server_conns.write().unwrap();
+ if let Some(c) = conn_list.get(id) {
+ if Arc::ptr_eq(c, &conn) {
+ conn_list.remove(id);
+ drop(conn_list);
+
+ if let Some(h) = self.on_disconnected_handler.load().as_ref() {
+ h(conn.peer_id, true);
+ }
+ }
+ }
+ }
+
+ // Called from conn.rs when an outgoinc connection is successfully established.
+ // The connection is registered in self.client_conns, and the
+ // on_connected handler is called.
+ //
+ // Since we are ourself listening, we send them a Hello message so that
+ // they know on which port to call us back. (TODO: don't do this if we are
+ // just a simple client and not a full p2p node)
+ pub(crate) fn connected_as_client(&self, id: NodeID, conn: Arc<ClientConn>) {
+ info!("Connection established to {}", hex::encode(&id[..8]));
+
+ {
+ let old_c_opt = self.client_conns.write().unwrap().insert(id, conn.clone());
+ if let Some(old_c) = old_c_opt {
+ tokio::spawn(async move { old_c.close() });
+ }
+ }
+
+ if let Some(h) = self.on_connected_handler.load().as_ref() {
+ h(conn.peer_id, conn.remote_addr, false);
+ }
+
+ if let Some(lp) = self.listen_params.load_full() {
+ let server_addr = lp.public_addr;
+ let server_port = lp.listen_addr.port();
+ let hello_endpoint = self.hello_endpoint.load_full().unwrap();
+ tokio::spawn(async move {
+ hello_endpoint
+ .call(
+ &conn.peer_id,
+ HelloMessage {
+ server_addr,
+ server_port,
+ },
+ PRIO_NORMAL,
+ )
+ .await
+ .map(|_| ())
+ .log_err("Sending hello message");
+ });
+ }
+ }
+
+ // Called from conn.rs when an outgoinc connection is closed.
+ // The connection is removed from conn_list, and the on_disconnected handler
+ // is called.
+ pub(crate) fn disconnected_as_client(&self, id: &NodeID, conn: Arc<ClientConn>) {
+ info!("Connection to {} closed", hex::encode(&id[..8]));
+ let mut conn_list = self.client_conns.write().unwrap();
+ if let Some(c) = conn_list.get(id) {
+ if Arc::ptr_eq(c, &conn) {
+ conn_list.remove(id);
+ drop(conn_list);
+
+ if let Some(h) = self.on_disconnected_handler.load().as_ref() {
+ h(conn.peer_id, false);
+ }
+ }
+ }
+ // else case: happens if connection was removed in .disconnect()
+ // in which case on_disconnected_handler was already called
+ }
+}
+
+#[async_trait]
+impl EndpointHandler<HelloMessage> for NetApp {
+ async fn handle(self: &Arc<Self>, msg: &HelloMessage, from: NodeID) {
+ debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg);
+ if let Some(h) = self.on_connected_handler.load().as_ref() {
+ if let Some(c) = self.server_conns.read().unwrap().get(&from) {
+ let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip());
+ let remote_addr = SocketAddr::new(remote_ip, msg.server_port);
+ h(from, remote_addr, true);
+ }
+ }
+ }
+}
diff --git a/src/net/peering.rs b/src/net/peering.rs
new file mode 100644
index 00000000..32199cf8
--- /dev/null
+++ b/src/net/peering.rs
@@ -0,0 +1,614 @@
+use std::collections::{HashMap, VecDeque};
+use std::net::SocketAddr;
+use std::sync::atomic::{self, AtomicU64};
+use std::sync::{Arc, RwLock};
+use std::time::{Duration, Instant};
+
+use arc_swap::ArcSwap;
+use async_trait::async_trait;
+use log::{debug, info, trace, warn};
+use serde::{Deserialize, Serialize};
+
+use tokio::select;
+use tokio::sync::watch;
+
+use sodiumoxide::crypto::hash;
+
+use crate::endpoint::*;
+use crate::error::*;
+use crate::netapp::*;
+
+use crate::message::*;
+use crate::NodeID;
+
+const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30);
+const CONN_MAX_RETRIES: usize = 10;
+const PING_INTERVAL: Duration = Duration::from_secs(15);
+const LOOP_DELAY: Duration = Duration::from_secs(1);
+const FAILED_PING_THRESHOLD: usize = 4;
+
+const DEFAULT_PING_TIMEOUT_MILLIS: u64 = 10_000;
+
+// -- Protocol messages --
+
+#[derive(Serialize, Deserialize)]
+struct PingMessage {
+ pub id: u64,
+ pub peer_list_hash: hash::Digest,
+}
+
+impl Message for PingMessage {
+ type Response = PingMessage;
+}
+
+#[derive(Serialize, Deserialize)]
+struct PeerListMessage {
+ pub list: Vec<(NodeID, SocketAddr)>,
+}
+
+impl Message for PeerListMessage {
+ type Response = PeerListMessage;
+}
+
+// -- Algorithm data structures --
+
+#[derive(Debug)]
+struct PeerInfoInternal {
+ // addr is the currently connected address,
+ // or the last address we were connected to,
+ // or an arbitrary address some other peer gave us
+ addr: SocketAddr,
+ // all_addrs contains all of the addresses everyone gave us
+ all_addrs: Vec<SocketAddr>,
+
+ state: PeerConnState,
+ last_send_ping: Option<Instant>,
+ last_seen: Option<Instant>,
+ ping: VecDeque<Duration>,
+ failed_pings: usize,
+}
+
+impl PeerInfoInternal {
+ fn new(addr: SocketAddr, state: PeerConnState) -> Self {
+ Self {
+ addr,
+ all_addrs: vec![addr],
+ state,
+ last_send_ping: None,
+ last_seen: None,
+ ping: VecDeque::new(),
+ failed_pings: 0,
+ }
+ }
+}
+
+/// Information that the full mesh peering strategy can return about the peers it knows of
+#[derive(Copy, Clone, Debug)]
+pub struct PeerInfo {
+ /// The node's identifier (its public key)
+ pub id: NodeID,
+ /// The node's network address
+ pub addr: SocketAddr,
+ /// The current status of our connection to this node
+ pub state: PeerConnState,
+ /// The last time at which the node was seen
+ pub last_seen: Option<Instant>,
+ /// The average ping to this node on recent observations (if at least one ping value is known)
+ pub avg_ping: Option<Duration>,
+ /// The maximum observed ping to this node on recent observations (if at least one
+ /// ping value is known)
+ pub max_ping: Option<Duration>,
+ /// The median ping to this node on recent observations (if at least one ping value
+ /// is known)
+ pub med_ping: Option<Duration>,
+}
+
+impl PeerInfo {
+ /// Returns true if we can currently send requests to this peer
+ pub fn is_up(&self) -> bool {
+ self.state.is_up()
+ }
+}
+
+/// PeerConnState: possible states for our tentative connections to given peer
+/// This structure is only interested in recording connection info for outgoing
+/// TCP connections
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum PeerConnState {
+ /// This entry represents ourself (the local node)
+ Ourself,
+
+ /// We currently have a connection to this peer
+ Connected,
+
+ /// Our next connection tentative (the nth, where n is the first value of the tuple)
+ /// will be at given Instant
+ Waiting(usize, Instant),
+
+ /// A connection tentative is in progress (the nth, where n is the value stored)
+ Trying(usize),
+
+ /// We abandonned trying to connect to this peer (too many failed attempts)
+ Abandonned,
+}
+
+impl PeerConnState {
+ /// Returns true if we can currently send requests to this peer
+ pub fn is_up(&self) -> bool {
+ matches!(self, Self::Ourself | Self::Connected)
+ }
+}
+
+struct KnownHosts {
+ list: HashMap<NodeID, PeerInfoInternal>,
+ hash: hash::Digest,
+}
+
+impl KnownHosts {
+ fn new() -> Self {
+ let list = HashMap::new();
+ let hash = Self::calculate_hash(&list);
+ Self { list, hash }
+ }
+ fn update_hash(&mut self) {
+ self.hash = Self::calculate_hash(&self.list);
+ }
+ fn map_into_vec(input: &HashMap<NodeID, PeerInfoInternal>) -> Vec<(NodeID, SocketAddr)> {
+ let mut list = Vec::with_capacity(input.len());
+ for (id, peer) in input.iter() {
+ if peer.state == PeerConnState::Connected || peer.state == PeerConnState::Ourself {
+ list.push((*id, peer.addr));
+ }
+ }
+ list
+ }
+ fn calculate_hash(input: &HashMap<NodeID, PeerInfoInternal>) -> hash::Digest {
+ let mut list = Self::map_into_vec(input);
+ list.sort();
+ let mut hash_state = hash::State::new();
+ for (id, addr) in list {
+ hash_state.update(&id[..]);
+ hash_state.update(&format!("{}\n", addr).into_bytes()[..]);
+ }
+ hash_state.finalize()
+ }
+}
+
+/// A "Full Mesh" peering strategy is a peering strategy that tries
+/// to establish and maintain a direct connection with all of the
+/// known nodes in the network.
+pub struct PeeringManager {
+ netapp: Arc<NetApp>,
+ known_hosts: RwLock<KnownHosts>,
+ public_peer_list: ArcSwap<Vec<PeerInfo>>,
+
+ next_ping_id: AtomicU64,
+ ping_endpoint: Arc<Endpoint<PingMessage, Self>>,
+ peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>,
+
+ ping_timeout_millis: AtomicU64,
+}
+
+impl PeeringManager {
+ /// Create a new Full Mesh peering strategy.
+ /// The strategy will not be run until `.run()` is called and awaited.
+ /// Once that happens, the peering strategy will try to connect
+ /// to all of the nodes specified in the bootstrap list.
+ pub fn new(
+ netapp: Arc<NetApp>,
+ bootstrap_list: Vec<(NodeID, SocketAddr)>,
+ our_addr: Option<SocketAddr>,
+ ) -> Arc<Self> {
+ let mut known_hosts = KnownHosts::new();
+ for (id, addr) in bootstrap_list {
+ if id != netapp.id {
+ known_hosts.list.insert(
+ id,
+ PeerInfoInternal::new(addr, PeerConnState::Waiting(0, Instant::now())),
+ );
+ }
+ }
+
+ if let Some(addr) = our_addr {
+ known_hosts.list.insert(
+ netapp.id,
+ PeerInfoInternal::new(addr, PeerConnState::Ourself),
+ );
+ }
+
+ // TODO for v0.10 / v1.0 : rename the endpoint (it will break compatibility)
+ let strat = Arc::new(Self {
+ netapp: netapp.clone(),
+ known_hosts: RwLock::new(known_hosts),
+ public_peer_list: ArcSwap::new(Arc::new(Vec::new())),
+ next_ping_id: AtomicU64::new(42),
+ ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()),
+ peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()),
+ ping_timeout_millis: DEFAULT_PING_TIMEOUT_MILLIS.into(),
+ });
+
+ strat.update_public_peer_list(&strat.known_hosts.read().unwrap());
+
+ strat.ping_endpoint.set_handler(strat.clone());
+ strat.peer_list_endpoint.set_handler(strat.clone());
+
+ let strat2 = strat.clone();
+ netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| {
+ let strat2 = strat2.clone();
+ strat2.on_connected(id, addr, is_incoming);
+ });
+
+ let strat2 = strat.clone();
+ netapp.on_disconnected(move |id: NodeID, is_incoming: bool| {
+ let strat2 = strat2.clone();
+ strat2.on_disconnected(id, is_incoming);
+ });
+
+ strat
+ }
+
+ /// Run the full mesh peering strategy.
+ /// This future exits when the `must_exit` watch becomes true.
+ pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
+ while !*must_exit.borrow() {
+ // 1. Read current state: get list of connected peers (ping them)
+ let (to_ping, to_retry) = {
+ let known_hosts = self.known_hosts.read().unwrap();
+ trace!("known_hosts: {} peers", known_hosts.list.len());
+
+ let mut to_ping = vec![];
+ let mut to_retry = vec![];
+ for (id, info) in known_hosts.list.iter() {
+ trace!("{}, {:?}", hex::encode(&id[..8]), info);
+ match info.state {
+ PeerConnState::Connected => {
+ let must_ping = match info.last_send_ping {
+ None => true,
+ Some(t) => Instant::now() - t > PING_INTERVAL,
+ };
+ if must_ping {
+ to_ping.push(*id);
+ }
+ }
+ PeerConnState::Waiting(_, t) => {
+ if Instant::now() >= t {
+ to_retry.push(*id);
+ }
+ }
+ _ => (),
+ }
+ }
+ (to_ping, to_retry)
+ };
+
+ // 2. Dispatch ping to hosts
+ trace!("to_ping: {} peers", to_ping.len());
+ if !to_ping.is_empty() {
+ let mut known_hosts = self.known_hosts.write().unwrap();
+ for id in to_ping.iter() {
+ known_hosts.list.get_mut(id).unwrap().last_send_ping = Some(Instant::now());
+ }
+ drop(known_hosts);
+ for id in to_ping {
+ tokio::spawn(self.clone().ping(id));
+ }
+ }
+
+ // 3. Try reconnects
+ trace!("to_retry: {} peers", to_retry.len());
+ if !to_retry.is_empty() {
+ let mut known_hosts = self.known_hosts.write().unwrap();
+ for id in to_retry {
+ if let Some(h) = known_hosts.list.get_mut(&id) {
+ if let PeerConnState::Waiting(i, _) = h.state {
+ info!(
+ "Retrying connection to {} at {} ({})",
+ hex::encode(&id[..8]),
+ h.all_addrs
+ .iter()
+ .map(|x| format!("{}", x))
+ .collect::<Vec<_>>()
+ .join(", "),
+ i + 1
+ );
+ h.state = PeerConnState::Trying(i);
+
+ let alternate_addrs = h
+ .all_addrs
+ .iter()
+ .filter(|x| **x != h.addr)
+ .cloned()
+ .collect::<Vec<_>>();
+ tokio::spawn(self.clone().try_connect(id, h.addr, alternate_addrs));
+ }
+ }
+ }
+ self.update_public_peer_list(&known_hosts);
+ }
+
+ // 4. Sleep before next loop iteration
+ tokio::time::sleep(LOOP_DELAY).await;
+ }
+ }
+
+ /// Returns a list of currently known peers in the network.
+ pub fn get_peer_list(&self) -> Arc<Vec<PeerInfo>> {
+ self.public_peer_list.load_full()
+ }
+
+ /// Set the timeout for ping messages, in milliseconds
+ pub fn set_ping_timeout_millis(&self, timeout: u64) {
+ self.ping_timeout_millis
+ .store(timeout, atomic::Ordering::Relaxed);
+ }
+
+ // -- internal stuff --
+
+ fn update_public_peer_list(&self, known_hosts: &KnownHosts) {
+ let mut pub_peer_list = Vec::with_capacity(known_hosts.list.len());
+ for (id, info) in known_hosts.list.iter() {
+ let mut pings = info.ping.iter().cloned().collect::<Vec<_>>();
+ pings.sort();
+ if !pings.is_empty() {
+ pub_peer_list.push(PeerInfo {
+ id: *id,
+ addr: info.addr,
+ state: info.state,
+ last_seen: info.last_seen,
+ avg_ping: Some(
+ pings
+ .iter()
+ .fold(Duration::from_secs(0), |x, y| x + *y)
+ .div_f64(pings.len() as f64),
+ ),
+ max_ping: pings.last().cloned(),
+ med_ping: Some(pings[pings.len() / 2]),
+ });
+ } else {
+ pub_peer_list.push(PeerInfo {
+ id: *id,
+ addr: info.addr,
+ state: info.state,
+ last_seen: info.last_seen,
+ avg_ping: None,
+ max_ping: None,
+ med_ping: None,
+ });
+ }
+ }
+ self.public_peer_list.store(Arc::new(pub_peer_list));
+ }
+
+ async fn ping(self: Arc<Self>, id: NodeID) {
+ let peer_list_hash = self.known_hosts.read().unwrap().hash;
+ let ping_id = self.next_ping_id.fetch_add(1u64, atomic::Ordering::Relaxed);
+ let ping_time = Instant::now();
+ let ping_timeout =
+ Duration::from_millis(self.ping_timeout_millis.load(atomic::Ordering::Relaxed));
+ let ping_msg = PingMessage {
+ id: ping_id,
+ peer_list_hash,
+ };
+
+ debug!(
+ "Sending ping {} to {} at {:?}",
+ ping_id,
+ hex::encode(&id[..8]),
+ ping_time
+ );
+ let ping_response = select! {
+ r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r,
+ _ = tokio::time::sleep(ping_timeout) => Err(Error::Message("Ping timeout".into())),
+ };
+
+ match ping_response {
+ Err(e) => {
+ warn!("Error pinging {}: {}", hex::encode(&id[..8]), e);
+ let mut known_hosts = self.known_hosts.write().unwrap();
+ if let Some(host) = known_hosts.list.get_mut(&id) {
+ host.failed_pings += 1;
+ if host.failed_pings > FAILED_PING_THRESHOLD {
+ warn!(
+ "Too many failed pings from {}, closing connection.",
+ hex::encode(&id[..8])
+ );
+ // this will later update info in known_hosts
+ // through the disconnection handler
+ self.netapp.disconnect(&id);
+ }
+ }
+ }
+ Ok(ping_resp) => {
+ let resp_time = Instant::now();
+ debug!(
+ "Got ping response from {} at {:?}",
+ hex::encode(&id[..8]),
+ resp_time
+ );
+ {
+ let mut known_hosts = self.known_hosts.write().unwrap();
+ if let Some(host) = known_hosts.list.get_mut(&id) {
+ host.failed_pings = 0;
+ host.last_seen = Some(resp_time);
+ host.ping.push_back(resp_time - ping_time);
+ while host.ping.len() > 10 {
+ host.ping.pop_front();
+ }
+ self.update_public_peer_list(&known_hosts);
+ }
+ }
+ if ping_resp.peer_list_hash != peer_list_hash {
+ self.exchange_peers(&id).await;
+ }
+ }
+ }
+ }
+
+ async fn exchange_peers(self: Arc<Self>, id: &NodeID) {
+ let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
+ let pex_message = PeerListMessage { list: peer_list };
+ match self
+ .peer_list_endpoint
+ .call(id, pex_message, PRIO_BACKGROUND)
+ .await
+ {
+ Err(e) => warn!("Error doing peer exchange: {}", e),
+ Ok(resp) => {
+ self.handle_peer_list(&resp.list[..]);
+ }
+ }
+ }
+
+ fn handle_peer_list(&self, list: &[(NodeID, SocketAddr)]) {
+ let mut known_hosts = self.known_hosts.write().unwrap();
+
+ let mut changed = false;
+ for (id, addr) in list.iter() {
+ if let Some(kh) = known_hosts.list.get_mut(id) {
+ if !kh.all_addrs.contains(addr) {
+ kh.all_addrs.push(*addr);
+ changed = true;
+ }
+ } else {
+ known_hosts.list.insert(*id, self.new_peer(id, *addr));
+ changed = true;
+ }
+ }
+
+ if changed {
+ known_hosts.update_hash();
+ self.update_public_peer_list(&known_hosts);
+ }
+ }
+
+ async fn try_connect(
+ self: Arc<Self>,
+ id: NodeID,
+ default_addr: SocketAddr,
+ alternate_addrs: Vec<SocketAddr>,
+ ) {
+ let conn_addr = {
+ let mut ret = None;
+ for addr in [default_addr].iter().chain(alternate_addrs.iter()) {
+ debug!("Trying address {} for peer {}", addr, hex::encode(&id[..8]));
+ match self.netapp.clone().try_connect(*addr, id).await {
+ Ok(()) => {
+ ret = Some(*addr);
+ break;
+ }
+ Err(e) => {
+ debug!(
+ "Error connecting to {} at {}: {}",
+ hex::encode(&id[..8]),
+ addr,
+ e
+ );
+ }
+ }
+ }
+ ret
+ };
+
+ if let Some(ok_addr) = conn_addr {
+ self.on_connected(id, ok_addr, false);
+ } else {
+ warn!(
+ "Could not connect to peer {} ({} addresses tried)",
+ hex::encode(&id[..8]),
+ 1 + alternate_addrs.len()
+ );
+ let mut known_hosts = self.known_hosts.write().unwrap();
+ if let Some(host) = known_hosts.list.get_mut(&id) {
+ host.state = match host.state {
+ PeerConnState::Trying(i) => {
+ if i >= CONN_MAX_RETRIES {
+ PeerConnState::Abandonned
+ } else {
+ PeerConnState::Waiting(i + 1, Instant::now() + CONN_RETRY_INTERVAL)
+ }
+ }
+ _ => PeerConnState::Waiting(0, Instant::now() + CONN_RETRY_INTERVAL),
+ };
+ self.update_public_peer_list(&known_hosts);
+ }
+ }
+ }
+
+ fn on_connected(self: Arc<Self>, id: NodeID, addr: SocketAddr, is_incoming: bool) {
+ let mut known_hosts = self.known_hosts.write().unwrap();
+ if is_incoming {
+ if let Some(host) = known_hosts.list.get_mut(&id) {
+ if !host.all_addrs.contains(&addr) {
+ host.all_addrs.push(addr);
+ }
+ } else {
+ known_hosts.list.insert(id, self.new_peer(&id, addr));
+ }
+ } else {
+ info!(
+ "Successfully connected to {} at {}",
+ hex::encode(&id[..8]),
+ addr
+ );
+ if let Some(host) = known_hosts.list.get_mut(&id) {
+ host.state = PeerConnState::Connected;
+ host.addr = addr;
+ if !host.all_addrs.contains(&addr) {
+ host.all_addrs.push(addr);
+ }
+ } else {
+ known_hosts
+ .list
+ .insert(id, PeerInfoInternal::new(addr, PeerConnState::Connected));
+ }
+ }
+ known_hosts.update_hash();
+ self.update_public_peer_list(&known_hosts);
+ }
+
+ fn on_disconnected(self: Arc<Self>, id: NodeID, is_incoming: bool) {
+ if !is_incoming {
+ info!("Connection to {} was closed", hex::encode(&id[..8]));
+ let mut known_hosts = self.known_hosts.write().unwrap();
+ if let Some(host) = known_hosts.list.get_mut(&id) {
+ host.state = PeerConnState::Waiting(0, Instant::now());
+ known_hosts.update_hash();
+ self.update_public_peer_list(&known_hosts);
+ }
+ }
+ }
+
+ fn new_peer(&self, id: &NodeID, addr: SocketAddr) -> PeerInfoInternal {
+ let state = if *id == self.netapp.id {
+ PeerConnState::Ourself
+ } else {
+ PeerConnState::Waiting(0, Instant::now())
+ };
+ PeerInfoInternal::new(addr, state)
+ }
+}
+
+#[async_trait]
+impl EndpointHandler<PingMessage> for PeeringManager {
+ async fn handle(self: &Arc<Self>, ping: &PingMessage, from: NodeID) -> PingMessage {
+ let ping_resp = PingMessage {
+ id: ping.id,
+ peer_list_hash: self.known_hosts.read().unwrap().hash,
+ };
+ debug!("Ping from {}", hex::encode(&from[..8]));
+ ping_resp
+ }
+}
+
+#[async_trait]
+impl EndpointHandler<PeerListMessage> for PeeringManager {
+ async fn handle(
+ self: &Arc<Self>,
+ peer_list: &PeerListMessage,
+ _from: NodeID,
+ ) -> PeerListMessage {
+ self.handle_peer_list(&peer_list.list[..]);
+ let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list);
+ PeerListMessage { list: peer_list }
+ }
+}
diff --git a/src/net/recv.rs b/src/net/recv.rs
new file mode 100644
index 00000000..0de7bef2
--- /dev/null
+++ b/src/net/recv.rs
@@ -0,0 +1,153 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use log::*;
+
+use futures::AsyncReadExt;
+use tokio::sync::mpsc;
+
+use crate::error::*;
+use crate::send::*;
+use crate::stream::*;
+
+/// Structure to warn when the sender is dropped before end of stream was reached, like when
+/// connection to some remote drops while transmitting data
+struct Sender {
+ inner: Option<mpsc::UnboundedSender<Packet>>,
+}
+
+impl Sender {
+ fn new(inner: mpsc::UnboundedSender<Packet>) -> Self {
+ Sender { inner: Some(inner) }
+ }
+
+ fn send(&self, packet: Packet) {
+ let _ = self.inner.as_ref().unwrap().send(packet);
+ }
+
+ fn end(&mut self) {
+ self.inner = None;
+ }
+}
+
+impl Drop for Sender {
+ fn drop(&mut self) {
+ if let Some(inner) = self.inner.take() {
+ let _ = inner.send(Err(std::io::Error::new(
+ std::io::ErrorKind::BrokenPipe,
+ "Netapp connection dropped before end of stream",
+ )));
+ }
+ }
+}
+
+/// The RecvLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
+/// and a prototype of a handler for received messages `.recv_handler()` that
+/// must be filled by implementors. `.recv_loop()` receives messages in a loop
+/// according to the protocol defined above: chunks of message in progress of being
+/// received are stored in a buffer, and when the last chunk of a message is received,
+/// the full message is passed to the receive handler.
+#[async_trait]
+pub(crate) trait RecvLoop: Sync + 'static {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
+ fn cancel_handler(self: &Arc<Self>, _id: RequestID) {}
+
+ async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error>
+ where
+ R: AsyncReadExt + Unpin + Send + Sync,
+ {
+ let mut streams: HashMap<RequestID, Sender> = HashMap::new();
+ loop {
+ trace!(
+ "recv_loop({}): in_progress = {:?}",
+ debug_name,
+ streams.iter().map(|(id, _)| id).collect::<Vec<_>>()
+ );
+
+ let mut header_id = [0u8; RequestID::BITS as usize / 8];
+ match read.read_exact(&mut header_id[..]).await {
+ Ok(_) => (),
+ Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
+ Err(e) => return Err(e.into()),
+ };
+ let id = RequestID::from_be_bytes(header_id);
+
+ let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
+ read.read_exact(&mut header_size[..]).await?;
+ let size = ChunkLength::from_be_bytes(header_size);
+
+ if size == CANCEL_REQUEST {
+ if let Some(mut stream) = streams.remove(&id) {
+ let _ = stream.send(Err(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "netapp: cancel requested",
+ )));
+ stream.end();
+ }
+ self.cancel_handler(id);
+ continue;
+ }
+
+ let has_cont = (size & CHUNK_FLAG_HAS_CONTINUATION) != 0;
+ let is_error = (size & CHUNK_FLAG_ERROR) != 0;
+ let size = (size & CHUNK_LENGTH_MASK) as usize;
+ let mut next_slice = vec![0; size as usize];
+ read.read_exact(&mut next_slice[..]).await?;
+
+ let packet = if is_error {
+ let kind = u8_to_io_errorkind(next_slice[0]);
+ let msg =
+ std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>");
+ debug!(
+ "recv_loop({}): got id {}, error {:?}: {}",
+ debug_name, id, kind, msg
+ );
+ Some(Err(std::io::Error::new(kind, msg.to_string())))
+ } else {
+ trace!(
+ "recv_loop({}): got id {}, size {}, has_cont {}",
+ debug_name,
+ id,
+ size,
+ has_cont
+ );
+ if !next_slice.is_empty() {
+ Some(Ok(Bytes::from(next_slice)))
+ } else {
+ None
+ }
+ };
+
+ let mut sender = if let Some(send) = streams.remove(&(id)) {
+ send
+ } else {
+ let (send, recv) = mpsc::unbounded_channel();
+ trace!("recv_loop({}): id {} is new channel", debug_name, id);
+ self.recv_handler(
+ id,
+ Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(recv)),
+ );
+ Sender::new(send)
+ };
+
+ if let Some(packet) = packet {
+ // If we cannot put packet in channel, it means that the
+ // receiving end of the channel is disconnected.
+ // We still need to reach eos before dropping this sender
+ let _ = sender.send(packet);
+ }
+
+ if has_cont {
+ assert!(!is_error);
+ streams.insert(id, sender);
+ } else {
+ trace!("recv_loop({}): close channel id {}", debug_name, id);
+ sender.end();
+ }
+ }
+ Ok(())
+ }
+}
diff --git a/src/net/send.rs b/src/net/send.rs
new file mode 100644
index 00000000..0db0ba77
--- /dev/null
+++ b/src/net/send.rs
@@ -0,0 +1,356 @@
+use std::collections::{HashMap, VecDeque};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use async_trait::async_trait;
+use bytes::{BufMut, Bytes, BytesMut};
+use log::*;
+
+use futures::{AsyncWriteExt, Future};
+use kuska_handshake::async_std::BoxStreamWrite;
+use tokio::sync::mpsc;
+
+use crate::error::*;
+use crate::message::*;
+use crate::stream::*;
+
+// Messages are sent by chunks
+// Chunk format:
+// - u32 BE: request id (same for request and response)
+// - u16 BE: chunk length + flags:
+// CHUNK_FLAG_HAS_CONTINUATION when this is not the last chunk of the stream
+// CHUNK_FLAG_ERROR if this chunk denotes an error
+// (these two flags are exclusive, an error denotes the end of the stream)
+// **special value** 0xFFFF indicates a CANCEL message
+// - [u8; chunk_length], either
+// - if not error: chunk data
+// - if error:
+// - u8: error kind, encoded using error::io_errorkind_to_u8
+// - rest: error message
+// - absent for cancel messag
+
+pub(crate) type RequestID = u32;
+pub(crate) type ChunkLength = u16;
+
+pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
+pub(crate) const CHUNK_FLAG_ERROR: ChunkLength = 0x4000;
+pub(crate) const CHUNK_FLAG_HAS_CONTINUATION: ChunkLength = 0x8000;
+pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF;
+pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF;
+
+pub(crate) enum SendItem {
+ Stream(RequestID, RequestPriority, Option<OrderTag>, ByteStream),
+ Cancel(RequestID),
+}
+
+// ----
+
+struct SendQueue {
+ items: Vec<(u8, SendQueuePriority)>,
+}
+
+struct SendQueuePriority {
+ items: VecDeque<SendQueueItem>,
+ order: HashMap<u64, VecDeque<u64>>,
+}
+
+struct SendQueueItem {
+ id: RequestID,
+ prio: RequestPriority,
+ order_tag: Option<OrderTag>,
+ data: ByteStreamReader,
+ sent: usize,
+}
+
+impl SendQueue {
+ fn new() -> Self {
+ Self {
+ items: Vec::with_capacity(64),
+ }
+ }
+ fn push(&mut self, item: SendQueueItem) {
+ let prio = item.prio;
+ let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
+ Ok(i) => i,
+ Err(i) => {
+ self.items.insert(i, (prio, SendQueuePriority::new()));
+ i
+ }
+ };
+ self.items[pos_prio].1.push(item);
+ }
+ fn remove(&mut self, id: RequestID) {
+ for (_, prioq) in self.items.iter_mut() {
+ prioq.remove(id);
+ }
+ self.items.retain(|(_prio, q)| !q.is_empty());
+ }
+ fn is_empty(&self) -> bool {
+ self.items.iter().all(|(_k, v)| v.is_empty())
+ }
+
+ // this is like an async fn, but hand implemented
+ fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
+ SendQueuePollNextReady { queue: self }
+ }
+}
+
+impl SendQueuePriority {
+ fn new() -> Self {
+ Self {
+ items: VecDeque::new(),
+ order: HashMap::new(),
+ }
+ }
+ fn push(&mut self, item: SendQueueItem) {
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ let order_vec = self.order.entry(stream).or_default();
+ let i = order_vec.iter().take_while(|o2| **o2 < order).count();
+ order_vec.insert(i, order);
+ }
+ self.items.push_front(item);
+ }
+ fn remove(&mut self, id: RequestID) {
+ if let Some(i) = self.items.iter().position(|x| x.id == id) {
+ let item = self.items.remove(i).unwrap();
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ let order_vec = self.order.get_mut(&stream).unwrap();
+ let j = order_vec.iter().position(|x| *x == order).unwrap();
+ order_vec.remove(j).unwrap();
+ if order_vec.is_empty() {
+ self.order.remove(&stream);
+ }
+ }
+ }
+ }
+ fn is_empty(&self) -> bool {
+ self.items.is_empty()
+ }
+ fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> {
+ for (j, item) in self.items.iter_mut().enumerate() {
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ if order > *self.order.get(&stream).unwrap().front().unwrap() {
+ continue;
+ }
+ }
+
+ let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize);
+ if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) {
+ let id = item.id;
+ let eos = item.data.eos();
+
+ let packet = bytes_or_err.map_err(|e| match e {
+ ReadExactError::Stream(err) => err,
+ _ => unreachable!(),
+ });
+
+ let is_err = packet.is_err();
+ let data_frame = DataFrame::from_packet(packet, !eos);
+ item.sent += data_frame.data().len();
+
+ if eos || is_err {
+ // If item had an order tag, remove it from the corresponding ordering list
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ let order_stream = self.order.get_mut(&stream).unwrap();
+ assert_eq!(order_stream.pop_front(), Some(order));
+ if order_stream.is_empty() {
+ self.order.remove(&stream);
+ }
+ }
+ // Remove item from sending queue
+ self.items.remove(j);
+ } else {
+ // Move item later in send queue to implement LAS scheduling
+ // (LAS = Least Attained Service)
+ for k in j..self.items.len() - 1 {
+ if self.items[k].sent >= self.items[k + 1].sent {
+ self.items.swap(k, k + 1);
+ } else {
+ break;
+ }
+ }
+ }
+
+ return Poll::Ready((id, data_frame));
+ }
+ }
+
+ Poll::Pending
+ }
+ fn dump(&self, prio: u8) -> String {
+ self.items
+ .iter()
+ .map(|i| format!("[{} {} {:?} @{}]", prio, i.id, i.order_tag, i.sent))
+ .collect::<Vec<_>>()
+ .join(" ")
+ }
+}
+
+struct SendQueuePollNextReady<'a> {
+ queue: &'a mut SendQueue,
+}
+
+impl<'a> futures::Future for SendQueuePollNextReady<'a> {
+ type Output = (RequestID, DataFrame);
+
+ fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
+ for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() {
+ if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) {
+ if items_at_prio.is_empty() {
+ self.queue.items.remove(i);
+ }
+ return Poll::Ready(res);
+ }
+ }
+ // If the queue is empty, this futures is eternally pending.
+ // This is ok because we use it in a select with another future
+ // that can interrupt it.
+ Poll::Pending
+ }
+}
+
+enum DataFrame {
+ /// a fixed size buffer containing some data + a boolean indicating whether
+ /// there may be more data comming from this stream. Can be used for some
+ /// optimization. It's an error to set it to false if there is more data, but it is correct
+ /// (albeit sub-optimal) to set it to true if there is nothing coming after
+ Data(Bytes, bool),
+ /// An error code automatically signals the end of the stream
+ Error(Bytes),
+}
+
+impl DataFrame {
+ fn from_packet(p: Packet, has_cont: bool) -> Self {
+ match p {
+ Ok(bytes) => {
+ assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize);
+ Self::Data(bytes, has_cont)
+ }
+ Err(e) => {
+ let mut buf = BytesMut::new();
+ buf.put_u8(io_errorkind_to_u8(e.kind()));
+
+ let msg = format!("{}", e).into_bytes();
+ if msg.len() > (MAX_CHUNK_LENGTH - 1) as usize {
+ buf.put(&msg[..(MAX_CHUNK_LENGTH - 1) as usize]);
+ } else {
+ buf.put(&msg[..]);
+ }
+
+ Self::Error(buf.freeze())
+ }
+ }
+ }
+
+ fn header(&self) -> [u8; 2] {
+ let header_u16 = match self {
+ DataFrame::Data(data, false) => data.len() as u16,
+ DataFrame::Data(data, true) => data.len() as u16 | CHUNK_FLAG_HAS_CONTINUATION,
+ DataFrame::Error(msg) => msg.len() as u16 | CHUNK_FLAG_ERROR,
+ };
+ ChunkLength::to_be_bytes(header_u16)
+ }
+
+ fn data(&self) -> &[u8] {
+ match self {
+ DataFrame::Data(ref data, _) => &data[..],
+ DataFrame::Error(ref msg) => &msg[..],
+ }
+ }
+}
+
+/// The SendLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()`
+/// that takes a channel of messages to send and an asynchronous writer,
+/// and sends messages from the channel to the async writer, putting them in a queue
+/// before being sent and doing the round-robin sending strategy.
+///
+/// The `.send_loop()` exits when the sending end of the channel is closed,
+/// or if there is an error at any time writing to the async writer.
+#[async_trait]
+pub(crate) trait SendLoop: Sync {
+ async fn send_loop<W>(
+ self: Arc<Self>,
+ msg_recv: mpsc::UnboundedReceiver<SendItem>,
+ mut write: BoxStreamWrite<W>,
+ debug_name: String,
+ ) -> Result<(), Error>
+ where
+ W: AsyncWriteExt + Unpin + Send + Sync,
+ {
+ let mut sending = SendQueue::new();
+ let mut msg_recv = Some(msg_recv);
+ while msg_recv.is_some() || !sending.is_empty() {
+ trace!(
+ "send_loop({}): queue = {:?}",
+ debug_name,
+ sending
+ .items
+ .iter()
+ .map(|(prio, i)| i.dump(*prio))
+ .collect::<Vec<_>>()
+ .join(" ; ")
+ );
+
+ let recv_fut = async {
+ if let Some(chan) = &mut msg_recv {
+ chan.recv().await
+ } else {
+ futures::future::pending().await
+ }
+ };
+ let send_fut = sending.next_ready();
+
+ // recv_fut is cancellation-safe according to tokio doc,
+ // send_fut is cancellation-safe as implemented above?
+ tokio::select! {
+ biased; // always read incomming channel first if it has data
+ sth = recv_fut => {
+ match sth {
+ Some(SendItem::Stream(id, prio, order_tag, data)) => {
+ trace!("send_loop({}): add stream {} to send", debug_name, id);
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ order_tag,
+ data: ByteStreamReader::new(data),
+ sent: 0,
+ })
+ }
+ Some(SendItem::Cancel(id)) => {
+ trace!("send_loop({}): cancelling {}", debug_name, id);
+ sending.remove(id);
+ let header_id = RequestID::to_be_bytes(id);
+ write.write_all(&header_id[..]).await?;
+ write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?;
+ write.flush().await?;
+ }
+ None => {
+ msg_recv = None;
+ }
+ };
+ }
+ (id, data) = send_fut => {
+ trace!(
+ "send_loop({}): id {}, send {} bytes, header_size {}",
+ debug_name,
+ id,
+ data.data().len(),
+ hex::encode(data.header())
+ );
+
+ let header_id = RequestID::to_be_bytes(id);
+ write.write_all(&header_id[..]).await?;
+
+ write.write_all(&data.header()).await?;
+ write.write_all(data.data()).await?;
+ write.flush().await?;
+ }
+ }
+ }
+
+ let _ = write.goodbye().await;
+ Ok(())
+ }
+}
diff --git a/src/net/server.rs b/src/net/server.rs
new file mode 100644
index 00000000..55b9e678
--- /dev/null
+++ b/src/net/server.rs
@@ -0,0 +1,222 @@
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::sync::{Arc, Mutex};
+
+use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+use log::*;
+
+use futures::io::{AsyncReadExt, AsyncWriteExt};
+use kuska_handshake::async_std::{handshake_server, BoxStream};
+use tokio::net::TcpStream;
+use tokio::select;
+use tokio::sync::{mpsc, watch};
+use tokio_util::compat::*;
+
+#[cfg(feature = "telemetry")]
+use opentelemetry::{
+ trace::{FutureExt, Span, SpanKind, TraceContextExt, TraceId, Tracer},
+ Context, KeyValue,
+};
+#[cfg(feature = "telemetry")]
+use opentelemetry_contrib::trace::propagator::binary::*;
+#[cfg(feature = "telemetry")]
+use rand::{thread_rng, Rng};
+
+use crate::error::*;
+use crate::message::*;
+use crate::netapp::*;
+use crate::recv::*;
+use crate::send::*;
+use crate::stream::*;
+use crate::util::*;
+
+// The client and server connection structs (client.rs and server.rs)
+// build upon the chunking mechanism which is exclusively contained
+// in proto.rs.
+// Here, we just care about sending big messages without size limit.
+// The format of these messages is described below.
+// Chunking happens independently.
+
+// Request message format (client -> server):
+// - u8 priority
+// - u8 path length
+// - [u8; path length] path
+// - [u8; *] data
+
+// Response message format (server -> client):
+// - u8 response code
+// - [u8; *] response
+
+pub(crate) struct ServerConn {
+ pub(crate) remote_addr: SocketAddr,
+ pub(crate) peer_id: NodeID,
+
+ netapp: Arc<NetApp>,
+
+ resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
+ running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>,
+}
+
+impl ServerConn {
+ pub(crate) async fn run(
+ netapp: Arc<NetApp>,
+ socket: TcpStream,
+ must_exit: watch::Receiver<bool>,
+ ) -> Result<(), Error> {
+ let remote_addr = socket.peer_addr()?;
+ let mut socket = socket.compat();
+
+ // Do handshake to authenticate client
+ let handshake = handshake_server(
+ &mut socket,
+ netapp.netid.clone(),
+ netapp.id,
+ netapp.privkey.clone(),
+ )
+ .await?;
+ let peer_id = handshake.peer_pk;
+
+ debug!(
+ "Handshake complete (server) with {}@{}",
+ hex::encode(peer_id),
+ remote_addr
+ );
+
+ // Create BoxStream layer that encodes content
+ let (read, write) = socket.split();
+ let (read, mut write) =
+ BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
+
+ // Before doing anything, send version tag, so that client
+ // can check and disconnect if version is wrong
+ write.write_all(&netapp.version_tag[..]).await?;
+ write.flush().await?;
+
+ // Build and launch stuff that handles requests server-side
+ let (resp_send, resp_recv) = mpsc::unbounded_channel();
+
+ let conn = Arc::new(ServerConn {
+ netapp: netapp.clone(),
+ remote_addr,
+ peer_id,
+ resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
+ running_handlers: Mutex::new(HashMap::new()),
+ });
+
+ netapp.connected_as_server(peer_id, conn.clone());
+
+ let debug_name = format!("SRV {}", hex::encode(&peer_id[..8]));
+ let debug_name_2 = debug_name.clone();
+
+ let conn2 = conn.clone();
+ let recv_future = tokio::spawn(async move {
+ select! {
+ r = conn2.recv_loop(read, debug_name_2) => r,
+ _ = await_exit(must_exit) => Ok(())
+ }
+ });
+ let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write, debug_name));
+
+ recv_future.await.log_err("ServerConn recv_loop");
+ conn.resp_send.store(None);
+ send_future.await.log_err("ServerConn send_loop");
+
+ netapp.disconnected_as_server(&peer_id, conn);
+
+ Ok(())
+ }
+
+ async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> {
+ let path = String::from_utf8(req_enc.path.to_vec())?;
+
+ let handler_opt = {
+ let endpoints = self.netapp.endpoints.read().unwrap();
+ endpoints.get(&path).map(|e| e.clone_endpoint())
+ };
+
+ if let Some(handler) = handler_opt {
+ cfg_if::cfg_if! {
+ if #[cfg(feature = "telemetry")] {
+ let tracer = opentelemetry::global::tracer("netapp");
+
+ let mut span = if !req_enc.telemetry_id.is_empty() {
+ let propagator = BinaryPropagator::new();
+ let context = propagator.from_bytes(req_enc.telemetry_id.to_vec());
+ let context = Context::new().with_remote_span_context(context);
+ tracer.span_builder(format!(">> RPC {}", path))
+ .with_kind(SpanKind::Server)
+ .start_with_context(&tracer, &context)
+ } else {
+ let mut rng = thread_rng();
+ let trace_id = TraceId::from_bytes(rng.gen());
+ tracer
+ .span_builder(format!(">> RPC {}", path))
+ .with_kind(SpanKind::Server)
+ .with_trace_id(trace_id)
+ .start(&tracer)
+ };
+ span.set_attribute(KeyValue::new("path", path.to_string()));
+ span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64));
+
+ handler.handle(req_enc, self.peer_id)
+ .with_context(Context::current_with_span(span))
+ .await
+ } else {
+ handler.handle(req_enc, self.peer_id).await
+ }
+ }
+ } else {
+ Err(Error::NoHandler)
+ }
+ }
+}
+
+impl SendLoop for ServerConn {}
+
+#[async_trait]
+impl RecvLoop for ServerConn {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
+ let resp_send = match self.resp_send.load_full() {
+ Some(c) => c,
+ None => return,
+ };
+
+ let mut rh = self.running_handlers.lock().unwrap();
+
+ let self2 = self.clone();
+ let jh = tokio::spawn(async move {
+ debug!("server: recv_handler got {}", id);
+
+ let (prio, resp_enc_result) = match ReqEnc::decode(stream).await {
+ Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await),
+ Err(e) => (PRIO_HIGH, Err(e)),
+ };
+
+ debug!("server: sending response to {}", id);
+
+ let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result);
+ resp_send
+ .send(SendItem::Stream(id, prio, resp_order, resp_stream))
+ .log_err("ServerConn recv_handler send resp bytes");
+
+ self2.running_handlers.lock().unwrap().remove(&id);
+ });
+
+ rh.insert(id, jh);
+ }
+
+ fn cancel_handler(self: &Arc<Self>, id: RequestID) {
+ trace!("received cancel for request {}", id);
+
+ // If the handler is still running, abort it now
+ if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) {
+ jh.abort();
+ }
+
+ // Inform the response sender that we don't need to send the response
+ if let Some(resp_send) = self.resp_send.load_full() {
+ let _ = resp_send.send(SendItem::Cancel(id));
+ }
+ }
+}
diff --git a/src/net/stream.rs b/src/net/stream.rs
new file mode 100644
index 00000000..88c3fed4
--- /dev/null
+++ b/src/net/stream.rs
@@ -0,0 +1,202 @@
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use bytes::Bytes;
+
+use futures::Future;
+use futures::{Stream, StreamExt};
+use tokio::io::AsyncRead;
+
+use crate::bytes_buf::BytesBuf;
+
+/// A stream of bytes (click to read more).
+///
+/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
+/// consecutive Vec may get merged, but Vec and error code may not be reordered
+///
+/// Items sent in the ByteStream may be errors of type `std::io::Error`.
+/// An error indicates the end of the ByteStream: a reader should no longer read
+/// after recieving an error, and a writer should stop writing after sending an error.
+pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
+
+/// A packet sent in a ByteStream, which may contain either
+/// a Bytes object or an error
+pub type Packet = Result<Bytes, std::io::Error>;
+
+// ----
+
+/// A helper struct to read defined lengths of data from a BytesStream
+pub struct ByteStreamReader {
+ stream: ByteStream,
+ buf: BytesBuf,
+ eos: bool,
+ err: Option<std::io::Error>,
+}
+
+impl ByteStreamReader {
+ /// Creates a new `ByteStreamReader` from a `ByteStream`
+ pub fn new(stream: ByteStream) -> Self {
+ ByteStreamReader {
+ stream,
+ buf: BytesBuf::new(),
+ eos: false,
+ err: None,
+ }
+ }
+
+ /// Read exactly `read_len` bytes from the underlying stream
+ /// (returns a future)
+ pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: true,
+ }
+ }
+
+ /// Read at most `read_len` bytes from the underlying stream, or less
+ /// if the end of the stream is reached (returns a future)
+ pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: false,
+ }
+ }
+
+ /// Read exactly one byte from the underlying stream and returns it
+ /// as an u8
+ pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
+ Ok(self.read_exact(1).await?[0])
+ }
+
+ /// Read exactly two bytes from the underlying stream and returns them as an u16 (using
+ /// big-endian decoding)
+ pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
+ let bytes = self.read_exact(2).await?;
+ let mut b = [0u8; 2];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u16::from_be_bytes(b))
+ }
+
+ /// Read exactly four bytes from the underlying stream and returns them as an u32 (using
+ /// big-endian decoding)
+ pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
+ let bytes = self.read_exact(4).await?;
+ let mut b = [0u8; 4];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u32::from_be_bytes(b))
+ }
+
+ /// Transforms the stream reader back into the underlying stream (starting
+ /// after everything that the reader has read)
+ pub fn into_stream(self) -> ByteStream {
+ let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok));
+ if let Some(err) = self.err {
+ Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
+ } else if self.eos {
+ Box::pin(buf_stream)
+ } else {
+ Box::pin(buf_stream.chain(self.stream))
+ }
+ }
+
+ /// Tries to fill the internal read buffer from the underlying stream if it is empty.
+ /// Calling this might be necessary to ensure that `.eos()` returns a correct
+ /// result, otherwise the reader might not be aware that the underlying
+ /// stream has nothing left to return.
+ pub async fn fill_buffer(&mut self) {
+ if self.buf.is_empty() {
+ let packet = self.stream.next().await;
+ self.add_stream_next(packet);
+ }
+ }
+
+ /// Clears the internal read buffer and returns its content
+ pub fn take_buffer(&mut self) -> Bytes {
+ self.buf.take_all()
+ }
+
+ /// Returns true if the end of the underlying stream has been reached
+ pub fn eos(&self) -> bool {
+ self.buf.is_empty() && self.eos
+ }
+
+ fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
+ self.buf.take_exact(read_len)
+ }
+
+ fn add_stream_next(&mut self, packet: Option<Packet>) {
+ match packet {
+ Some(Ok(slice)) => {
+ self.buf.extend(slice);
+ }
+ Some(Err(e)) => {
+ self.err = Some(e);
+ self.eos = true;
+ }
+ None => {
+ self.eos = true;
+ }
+ }
+ }
+}
+
+/// The error kind that can be returned by `ByteStreamReader::read_exact` and
+/// `ByteStreamReader::read_exact_or_eos`
+pub enum ReadExactError {
+ /// The end of the stream was reached before the requested number of bytes could be read
+ UnexpectedEos,
+ /// The underlying data stream returned an IO error when trying to read
+ Stream(std::io::Error),
+}
+
+/// The future returned by `ByteStreamReader::read_exact` and
+/// `ByteStreamReader::read_exact_or_eos`
+#[pin_project::pin_project]
+pub struct ByteStreamReadExact<'a> {
+ #[pin]
+ reader: &'a mut ByteStreamReader,
+ read_len: usize,
+ fail_on_eos: bool,
+}
+
+impl<'a> Future for ByteStreamReadExact<'a> {
+ type Output = Result<Bytes, ReadExactError>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
+ let mut this = self.project();
+
+ loop {
+ if let Some(bytes) = this.reader.try_get(*this.read_len) {
+ return Poll::Ready(Ok(bytes));
+ }
+ if let Some(err) = &this.reader.err {
+ let err = std::io::Error::new(err.kind(), format!("{}", err));
+ return Poll::Ready(Err(ReadExactError::Stream(err)));
+ }
+ if this.reader.eos {
+ if *this.fail_on_eos {
+ return Poll::Ready(Err(ReadExactError::UnexpectedEos));
+ } else {
+ return Poll::Ready(Ok(this.reader.take_buffer()));
+ }
+ }
+
+ let next_packet = futures::ready!(this.reader.stream.as_mut().poll_next(cx));
+ this.reader.add_stream_next(next_packet);
+ }
+ }
+}
+
+// ----
+
+/// Turns a `tokio::io::AsyncRead` asynchronous reader into a `ByteStream`
+pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream {
+ Box::pin(tokio_util::io::ReaderStream::new(reader))
+}
+
+/// Turns a `ByteStream` into a `tokio::io::AsyncRead` asynchronous reader
+pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static {
+ tokio_util::io::StreamReader::new(stream)
+}
diff --git a/src/net/test.rs b/src/net/test.rs
new file mode 100644
index 00000000..c6259752
--- /dev/null
+++ b/src/net/test.rs
@@ -0,0 +1,118 @@
+use std::net::SocketAddr;
+use std::sync::Arc;
+use std::time::Duration;
+
+use tokio::select;
+use tokio::sync::watch;
+
+use sodiumoxide::crypto::auth;
+use sodiumoxide::crypto::sign::ed25519;
+
+use crate::netapp::*;
+use crate::peering::*;
+use crate::NodeID;
+
+#[tokio::test(flavor = "current_thread")]
+async fn test_with_basic_scheduler() {
+ pretty_env_logger::init();
+ run_test(19980).await
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn test_with_threaded_scheduler() {
+ run_test(19990).await
+}
+
+async fn run_test(port_base: u16) {
+ select! {
+ _ = run_test_inner(port_base) => (),
+ _ = tokio::time::sleep(Duration::from_secs(20)) => panic!("timeout"),
+ }
+}
+
+async fn run_test_inner(port_base: u16) {
+ let netid = auth::gen_key();
+
+ let (pk1, sk1) = ed25519::gen_keypair();
+ let (pk2, sk2) = ed25519::gen_keypair();
+ let (pk3, sk3) = ed25519::gen_keypair();
+
+ let addr1: SocketAddr = SocketAddr::new("127.0.0.1".parse().unwrap(), port_base);
+ let addr2: SocketAddr = SocketAddr::new("127.0.0.1".parse().unwrap(), port_base + 1);
+ let addr3: SocketAddr = SocketAddr::new("127.0.0.1".parse().unwrap(), port_base + 2);
+
+ let (stop_tx, stop_rx) = watch::channel(false);
+
+ let (thread1, _netapp1, peering1) =
+ run_netapp(netid.clone(), pk1, sk1, addr1, vec![], stop_rx.clone());
+ tokio::time::sleep(Duration::from_secs(2)).await;
+
+ // Connect second node and check it peers with everyone
+ let (thread2, _netapp2, peering2) = run_netapp(
+ netid.clone(),
+ pk2,
+ sk2,
+ addr2,
+ vec![(pk1, addr1)],
+ stop_rx.clone(),
+ );
+ tokio::time::sleep(Duration::from_secs(3)).await;
+
+ let pl1 = peering1.get_peer_list();
+ println!("A pl1: {:?}", pl1);
+ assert_eq!(pl1.len(), 2);
+
+ let pl2 = peering2.get_peer_list();
+ println!("A pl2: {:?}", pl2);
+ assert_eq!(pl2.len(), 2);
+
+ // Connect third ndoe and check it peers with everyone
+ let (thread3, _netapp3, peering3) =
+ run_netapp(netid, pk3, sk3, addr3, vec![(pk2, addr2)], stop_rx.clone());
+ tokio::time::sleep(Duration::from_secs(3)).await;
+
+ let pl1 = peering1.get_peer_list();
+ println!("B pl1: {:?}", pl1);
+ assert_eq!(pl1.len(), 3);
+
+ let pl2 = peering2.get_peer_list();
+ println!("B pl2: {:?}", pl2);
+ assert_eq!(pl2.len(), 3);
+
+ let pl3 = peering3.get_peer_list();
+ println!("B pl3: {:?}", pl3);
+ assert_eq!(pl3.len(), 3);
+
+ // Send stop signal and wait for everyone to finish
+ stop_tx.send(true).unwrap();
+ thread1.await.unwrap();
+ thread2.await.unwrap();
+ thread3.await.unwrap();
+}
+
+fn run_netapp(
+ netid: auth::Key,
+ _pk: NodeID,
+ sk: ed25519::SecretKey,
+ listen_addr: SocketAddr,
+ bootstrap_peers: Vec<(NodeID, SocketAddr)>,
+ must_exit: watch::Receiver<bool>,
+) -> (
+ tokio::task::JoinHandle<()>,
+ Arc<NetApp>,
+ Arc<PeeringManager>,
+) {
+ let netapp = NetApp::new(0u64, netid, sk);
+ let peering = PeeringManager::new(netapp.clone(), bootstrap_peers, None);
+
+ let peering2 = peering.clone();
+ let netapp2 = netapp.clone();
+ let fut = tokio::spawn(async move {
+ tokio::join!(
+ netapp2.listen(listen_addr, None, must_exit.clone()),
+ peering2.run(must_exit.clone()),
+ );
+ });
+
+ (fut, netapp, peering)
+}
diff --git a/src/net/util.rs b/src/net/util.rs
new file mode 100644
index 00000000..56230b73
--- /dev/null
+++ b/src/net/util.rs
@@ -0,0 +1,96 @@
+use std::net::SocketAddr;
+
+use log::info;
+use serde::Serialize;
+
+use tokio::sync::watch;
+
+use crate::netapp::*;
+
+/// Utility function: encodes any serializable value in MessagePack binary format
+/// using the RMP library.
+///
+/// Field names and variant names are included in the serialization.
+/// This is used internally by the netapp communication protocol.
+pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error>
+where
+ T: Serialize + ?Sized,
+{
+ let mut wr = Vec::with_capacity(128);
+ let mut se = rmp_serde::Serializer::new(&mut wr).with_struct_map();
+ val.serialize(&mut se)?;
+ Ok(wr)
+}
+
+/// This async function returns only when a true signal was received
+/// from a watcher that tells us when to exit.
+///
+/// Usefull in a select statement to interrupt another
+/// future:
+/// ```ignore
+/// select!(
+/// _ = a_long_task() => Success,
+/// _ = await_exit(must_exit) => Interrupted,
+/// )
+/// ```
+pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
+ while !*must_exit.borrow_and_update() {
+ if must_exit.changed().await.is_err() {
+ break;
+ }
+ }
+}
+
+/// Creates a watch that contains `false`, and that changes
+/// to `true` when a Ctrl+C signal is received.
+pub fn watch_ctrl_c() -> watch::Receiver<bool> {
+ let (send_cancel, watch_cancel) = watch::channel(false);
+ tokio::spawn(async move {
+ tokio::signal::ctrl_c()
+ .await
+ .expect("failed to install CTRL+C signal handler");
+ info!("Received CTRL+C, shutting down.");
+ send_cancel.send(true).unwrap();
+ });
+ watch_cancel
+}
+
+/// Parse a peer's address including public key, written in the format:
+/// `<public key hex>@<ip>:<port>`
+pub fn parse_peer_addr(peer: &str) -> Option<(NodeID, SocketAddr)> {
+ let delim = peer.find('@')?;
+ let (key, ip) = peer.split_at(delim);
+ let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?;
+ let ip = ip[1..].parse::<SocketAddr>().ok()?;
+ Some((pubkey, ip))
+}
+
+/// Parse and resolve a peer's address including public key, written in the format:
+/// `<public key hex>@<ip or hostname>:<port>`
+pub fn parse_and_resolve_peer_addr(peer: &str) -> Option<(NodeID, Vec<SocketAddr>)> {
+ use std::net::ToSocketAddrs;
+
+ let delim = peer.find('@')?;
+ let (key, host) = peer.split_at(delim);
+ let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?;
+ let hosts = host[1..].to_socket_addrs().ok()?.collect::<Vec<_>>();
+ if hosts.is_empty() {
+ return None;
+ }
+ Some((pubkey, hosts))
+}
+
+/// async version of parse_and_resolve_peer_addr
+pub async fn parse_and_resolve_peer_addr_async(peer: &str) -> Option<(NodeID, Vec<SocketAddr>)> {
+ let delim = peer.find('@')?;
+ let (key, host) = peer.split_at(delim);
+ let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?;
+ let hosts = tokio::net::lookup_host(&host[1..])
+ .await
+ .ok()?
+ .collect::<Vec<_>>();
+ if hosts.is_empty() {
+ return None;
+ }
+ Some((pubkey, hosts))
+}