diff options
Diffstat (limited to 'src/net')
-rw-r--r-- | src/net/Cargo.toml | 45 | ||||
-rw-r--r-- | src/net/bytes_buf.rs | 186 | ||||
-rw-r--r-- | src/net/client.rs | 292 | ||||
-rw-r--r-- | src/net/endpoint.rs | 201 | ||||
-rw-r--r-- | src/net/error.rs | 126 | ||||
-rw-r--r-- | src/net/lib.rs | 35 | ||||
-rw-r--r-- | src/net/message.rs | 522 | ||||
-rw-r--r-- | src/net/netapp.rs | 452 | ||||
-rw-r--r-- | src/net/peering.rs | 614 | ||||
-rw-r--r-- | src/net/recv.rs | 153 | ||||
-rw-r--r-- | src/net/send.rs | 356 | ||||
-rw-r--r-- | src/net/server.rs | 222 | ||||
-rw-r--r-- | src/net/stream.rs | 202 | ||||
-rw-r--r-- | src/net/test.rs | 118 | ||||
-rw-r--r-- | src/net/util.rs | 96 |
15 files changed, 3620 insertions, 0 deletions
diff --git a/src/net/Cargo.toml b/src/net/Cargo.toml new file mode 100644 index 00000000..a2674498 --- /dev/null +++ b/src/net/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "garage_net" +version = "0.9.1" +authors = ["Alex Auvolat <alex@adnab.me>"] +edition = "2018" +license-file = "AGPL-3.0" +description = "Networking library for Garage RPC communication, forked from Netapp" +repository = "https://git.deuxfleurs.fr/Deuxfleurs/garage" +readme = "../../README.md" + +[lib] +path = "lib.rs" + +[features] +default = [] +telemetry = ["opentelemetry", "opentelemetry-contrib"] + +[dependencies] +futures.workspace = true +pin-project.workspace = true +tokio.workspace = true +tokio-util.workspace = true +tokio-stream.workspace = true + +serde.workspace = true +rmp-serde.workspace = true +hex.workspace = true + +rand.workspace = true + +log.workspace = true +arc-swap.workspace = true +async-trait.workspace = true +err-derive.workspace = true +bytes.workspace = true +cfg-if.workspace = true + +sodiumoxide.workspace = true +kuska-handshake.workspace = true + +opentelemetry = { workspace = true, optional = true } +opentelemetry-contrib = { workspace = true, optional = true } + +[dev-dependencies] +pretty_env_logger.workspace = true diff --git a/src/net/bytes_buf.rs b/src/net/bytes_buf.rs new file mode 100644 index 00000000..3929a860 --- /dev/null +++ b/src/net/bytes_buf.rs @@ -0,0 +1,186 @@ +use std::cmp::Ordering; +use std::collections::VecDeque; + +use bytes::BytesMut; + +pub use bytes::Bytes; + +/// A circular buffer of bytes, internally represented as a list of Bytes +/// for optimization, but that for all intent and purposes acts just like +/// a big byte slice which can be extended on the right and from which +/// stuff can be taken on the left. +pub struct BytesBuf { + buf: VecDeque<Bytes>, + buf_len: usize, +} + +impl BytesBuf { + /// Creates a new empty BytesBuf + pub fn new() -> Self { + Self { + buf: VecDeque::new(), + buf_len: 0, + } + } + + /// Returns the number of bytes stored in the BytesBuf + #[inline] + pub fn len(&self) -> usize { + self.buf_len + } + + /// Returns true iff the BytesBuf contains zero bytes + #[inline] + pub fn is_empty(&self) -> bool { + self.buf_len == 0 + } + + /// Adds some bytes to the right of the buffer + pub fn extend(&mut self, b: Bytes) { + if !b.is_empty() { + self.buf_len += b.len(); + self.buf.push_back(b); + } + } + + /// Takes the whole content of the buffer and returns it as a single Bytes unit + pub fn take_all(&mut self) -> Bytes { + if self.buf.is_empty() { + Bytes::new() + } else if self.buf.len() == 1 { + self.buf_len = 0; + self.buf.pop_back().unwrap() + } else { + let mut ret = BytesMut::with_capacity(self.buf_len); + for b in self.buf.iter() { + ret.extend_from_slice(&b[..]); + } + self.buf.clear(); + self.buf_len = 0; + ret.freeze() + } + } + + /// Takes at most max_len bytes from the left of the buffer + pub fn take_max(&mut self, max_len: usize) -> Bytes { + if self.buf_len <= max_len { + self.take_all() + } else { + self.take_exact_ok(max_len) + } + } + + /// Take exactly len bytes from the left of the buffer, returns None if + /// the BytesBuf doesn't contain enough data + pub fn take_exact(&mut self, len: usize) -> Option<Bytes> { + if self.buf_len < len { + None + } else { + Some(self.take_exact_ok(len)) + } + } + + fn take_exact_ok(&mut self, len: usize) -> Bytes { + assert!(len <= self.buf_len); + let front = self.buf.pop_front().unwrap(); + match front.len().cmp(&len) { + Ordering::Greater => { + self.buf.push_front(front.slice(len..)); + self.buf_len -= len; + front.slice(..len) + } + Ordering::Equal => { + self.buf_len -= len; + front + } + Ordering::Less => { + let mut ret = BytesMut::with_capacity(len); + ret.extend_from_slice(&front[..]); + self.buf_len -= front.len(); + while ret.len() < len { + let front = self.buf.pop_front().unwrap(); + if front.len() > len - ret.len() { + let take = len - ret.len(); + ret.extend_from_slice(&front[..take]); + self.buf.push_front(front.slice(take..)); + self.buf_len -= take; + break; + } else { + ret.extend_from_slice(&front[..]); + self.buf_len -= front.len(); + } + } + ret.freeze() + } + } + } + + /// Return the internal sequence of Bytes slices that make up the buffer + pub fn into_slices(self) -> VecDeque<Bytes> { + self.buf + } +} + +impl Default for BytesBuf { + fn default() -> Self { + Self::new() + } +} + +impl From<Bytes> for BytesBuf { + fn from(b: Bytes) -> BytesBuf { + let mut ret = BytesBuf::new(); + ret.extend(b); + ret + } +} + +impl From<BytesBuf> for Bytes { + fn from(mut b: BytesBuf) -> Bytes { + b.take_all() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_bytes_buf() { + let mut buf = BytesBuf::new(); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 13); + assert!(!buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!( + buf.take_all(), + Bytes::from(b"Hello, world!1234567890".to_vec()) + ); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!(buf.take_max(12), Bytes::from(b"1234567890He".to_vec())); + assert!(buf.len() == 11); + + assert_eq!(buf.take_exact(12), None); + assert!(buf.len() == 11); + assert_eq!( + buf.take_exact(11), + Some(Bytes::from(b"llo, world!".to_vec())) + ); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + } +} diff --git a/src/net/client.rs b/src/net/client.rs new file mode 100644 index 00000000..607dd173 --- /dev/null +++ b/src/net/client.rs @@ -0,0 +1,292 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::atomic::{self, AtomicU32}; +use std::sync::{Arc, Mutex}; +use std::task::Poll; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use bytes::Bytes; +use log::{debug, error, trace}; + +use futures::io::AsyncReadExt; +use futures::Stream; +use kuska_handshake::async_std::{handshake_client, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, oneshot, watch}; +use tokio_util::compat::*; + +#[cfg(feature = "telemetry")] +use opentelemetry::{ + trace::{FutureExt, Span, SpanKind, TraceContextExt, Tracer}, + Context, KeyValue, +}; +#[cfg(feature = "telemetry")] +use opentelemetry_contrib::trace::propagator::binary::*; + +use crate::error::*; +use crate::message::*; +use crate::netapp::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; +use crate::util::*; + +pub(crate) struct ClientConn { + pub(crate) remote_addr: SocketAddr, + pub(crate) peer_id: NodeID, + + query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>, + + next_query_number: AtomicU32, + inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>, +} + +impl ClientConn { + pub(crate) async fn init( + netapp: Arc<NetApp>, + socket: TcpStream, + peer_id: NodeID, + ) -> Result<(), Error> { + let remote_addr = socket.peer_addr()?; + let mut socket = socket.compat(); + + // Do handshake to authenticate and prove our identity to server + let handshake = handshake_client( + &mut socket, + netapp.netid.clone(), + netapp.id, + netapp.privkey.clone(), + peer_id, + ) + .await?; + + debug!( + "Handshake complete (client) with {}@{}", + hex::encode(peer_id), + remote_addr + ); + + // Create BoxStream layer that encodes content + let (read, write) = socket.split(); + let (mut read, write) = + BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); + + // Before doing anything, receive version tag and + // check they are running the same version as us + let mut their_version_tag = VersionTag::default(); + read.read_exact(&mut their_version_tag[..]).await?; + if their_version_tag != netapp.version_tag { + let msg = format!( + "different version tags: {} (theirs) vs. {} (ours)", + hex::encode(their_version_tag), + hex::encode(netapp.version_tag) + ); + error!("Cannot connect to {}: {}", hex::encode(&peer_id[..8]), msg); + return Err(Error::VersionMismatch(msg)); + } + + // Build and launch stuff that manages sending requests client-side + let (query_send, query_recv) = mpsc::unbounded_channel(); + + let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false); + + let conn = Arc::new(ClientConn { + remote_addr, + peer_id, + next_query_number: AtomicU32::from(RequestID::default()), + query_send: ArcSwapOption::new(Some(Arc::new(query_send))), + inflight: Mutex::new(HashMap::new()), + }); + + netapp.connected_as_client(peer_id, conn.clone()); + + let debug_name = format!("CLI {}", hex::encode(&peer_id[..8])); + + tokio::spawn(async move { + let debug_name_2 = debug_name.clone(); + let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write, debug_name_2)); + + let conn2 = conn.clone(); + let recv_future = tokio::spawn(async move { + select! { + r = conn2.recv_loop(read, debug_name) => r, + _ = await_exit(stop_recv_loop_recv) => Ok(()) + } + }); + + send_future.await.log_err("ClientConn send_loop"); + + // FIXME: should do here: wait for inflight requests to all have their response + stop_recv_loop + .send(true) + .log_err("ClientConn send true to stop_recv_loop"); + + recv_future.await.log_err("ClientConn recv_loop"); + + // Make sure we don't wait on any more requests that won't + // have a response + conn.inflight.lock().unwrap().clear(); + + netapp.disconnected_as_client(&peer_id, conn); + }); + + Ok(()) + } + + pub fn close(&self) { + self.query_send.store(None); + } + + pub(crate) async fn call<T>( + self: Arc<Self>, + req: Req<T>, + path: &str, + prio: RequestPriority, + ) -> Result<Resp<T>, Error> + where + T: Message, + { + let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; + + let id = self + .next_query_number + .fetch_add(1, atomic::Ordering::Relaxed); + + cfg_if::cfg_if! { + if #[cfg(feature = "telemetry")] { + let tracer = opentelemetry::global::tracer("netapp"); + let mut span = tracer.span_builder(format!("RPC >> {}", path)) + .with_kind(SpanKind::Client) + .start(&tracer); + let propagator = BinaryPropagator::new(); + let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into(); + } else { + let telemetry_id: Bytes = Bytes::new(); + } + }; + + // Encode request + let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); + let req_msg_len = req_enc.msg.len(); + let (req_stream, req_order) = req_enc.encode(); + + // Send request through + let (resp_send, resp_recv) = oneshot::channel(); + let old = self.inflight.lock().unwrap().insert(id, resp_send); + if let Some(old_ch) = old { + error!( + "Too many inflight requests! RequestID collision. Interrupting previous request." + ); + let _ = old_ch.send(Box::pin(futures::stream::once(async move { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "RequestID collision, too many inflight requests", + )) + }))); + } + + debug!( + "request: query_send {}, path {}, prio {} (serialized message: {} bytes)", + id, path, prio, req_msg_len + ); + + #[cfg(feature = "telemetry")] + span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); + + query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?; + + let canceller = CancelOnDrop::new(id, query_send.as_ref().clone()); + + cfg_if::cfg_if! { + if #[cfg(feature = "telemetry")] { + let stream = resp_recv + .with_context(Context::current_with_span(span)) + .await?; + } else { + let stream = resp_recv.await?; + } + } + + let stream = Box::pin(canceller.for_stream(stream)); + + let resp_enc = RespEnc::decode(stream).await?; + debug!("client: got response to request {} (path {})", id, path); + Resp::from_enc(resp_enc) + } +} + +impl SendLoop for ClientConn {} + +#[async_trait] +impl RecvLoop for ClientConn { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { + trace!("ClientConn recv_handler {}", id); + + let mut inflight = self.inflight.lock().unwrap(); + if let Some(ch) = inflight.remove(&id) { + if ch.send(stream).is_err() { + debug!("Could not send request response, probably because request was interrupted. Dropping response."); + } + } else { + debug!("Got unexpected response to request {}, dropping it", id); + } + } +} + +// ---- + +struct CancelOnDrop { + id: RequestID, + query_send: mpsc::UnboundedSender<SendItem>, +} + +impl CancelOnDrop { + fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self { + Self { id, query_send } + } + fn for_stream(self, stream: ByteStream) -> CancelOnDropStream { + CancelOnDropStream { + cancel: Some(self), + stream, + } + } +} + +impl Drop for CancelOnDrop { + fn drop(&mut self) { + trace!("cancelling request {}", self.id); + let _ = self.query_send.send(SendItem::Cancel(self.id)); + } +} + +#[pin_project::pin_project] +struct CancelOnDropStream { + cancel: Option<CancelOnDrop>, + #[pin] + stream: ByteStream, +} + +impl Stream for CancelOnDropStream { + type Item = Packet; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + let this = self.project(); + let res = this.stream.poll_next(cx); + if matches!(res, Poll::Ready(None)) { + if let Some(c) = this.cancel.take() { + std::mem::forget(c) + } + } + res + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.stream.size_hint() + } +} diff --git a/src/net/endpoint.rs b/src/net/endpoint.rs new file mode 100644 index 00000000..3cafafeb --- /dev/null +++ b/src/net/endpoint.rs @@ -0,0 +1,201 @@ +use std::marker::PhantomData; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; + +use crate::error::Error; +use crate::message::*; +use crate::netapp::*; + +/// This trait should be implemented by an object of your application +/// that can handle a message of type `M`, if it wishes to handle +/// streams attached to the request and/or to send back streams +/// attached to the response.. +/// +/// The handler object should be in an Arc, see `Endpoint::set_handler` +#[async_trait] +pub trait StreamingEndpointHandler<M>: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>; +} + +/// If one simply wants to use an endpoint in a client fashion, +/// without locally serving requests to that endpoint, +/// use the unit type `()` as the handler type: +/// it will panic if it is ever made to handle request. +#[async_trait] +impl<M: Message> EndpointHandler<M> for () { + async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { + panic!("This endpoint should not have a local handler."); + } +} + +// ---- + +/// This trait should be implemented by an object of your application +/// that can handle a message of type `M`, in the cases where it doesn't +/// care about attached stream in the request nor in the response. +#[async_trait] +pub trait EndpointHandler<M>: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response; +} + +#[async_trait] +impl<T, M> StreamingEndpointHandler<M> for T +where + T: EndpointHandler<M>, + M: Message, +{ + async fn handle(self: &Arc<Self>, mut m: Req<M>, from: NodeID) -> Resp<M> { + // Immediately drop stream to ignore all data that comes in, + // instead of buffering it indefinitely + drop(m.take_stream()); + Resp::new(EndpointHandler::handle(self, m.msg(), from).await) + } +} + +// ---- + +/// This struct represents an endpoint for message of type `M`. +/// +/// Creating a new endpoint is done by calling `NetApp::endpoint`. +/// An endpoint is identified primarily by its path, which is specified +/// at creation time. +/// +/// An `Endpoint` is used both to send requests to remote nodes, +/// and to specify the handler for such requests on the local node. +/// The type `H` represents the type of the handler object for +/// endpoint messages (see `StreamingEndpointHandler`). +pub struct Endpoint<M, H> +where + M: Message, + H: StreamingEndpointHandler<M>, +{ + _phantom: PhantomData<M>, + netapp: Arc<NetApp>, + path: String, + handler: ArcSwapOption<H>, +} + +impl<M, H> Endpoint<M, H> +where + M: Message, + H: StreamingEndpointHandler<M>, +{ + pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self { + Self { + _phantom: PhantomData::default(), + netapp, + path, + handler: ArcSwapOption::from(None), + } + } + + /// Get the path of this endpoint + pub fn path(&self) -> &str { + &self.path + } + + /// Set the object that is responsible of handling requests to + /// this endpoint on the local node. + pub fn set_handler(&self, h: Arc<H>) { + self.handler.swap(Some(h)); + } + + /// Call this endpoint on a remote node (or on the local node, + /// for that matter). This function invokes the full version that + /// allows to attach a stream to the request and to + /// receive such a stream attached to the response. + pub async fn call_streaming<T>( + &self, + target: &NodeID, + req: T, + prio: RequestPriority, + ) -> Result<Resp<M>, Error> + where + T: IntoReq<M>, + { + if *target == self.netapp.id { + match self.handler.load_full() { + None => Err(Error::NoHandler), + Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await), + } + } else { + let conn = self + .netapp + .client_conns + .read() + .unwrap() + .get(target) + .cloned(); + match conn { + None => Err(Error::Message(format!( + "Not connected: {}", + hex::encode(&target[..8]) + ))), + Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await, + } + } + } + + /// Call this endpoint on a remote node. This function is the simplified + /// version that doesn't allow to have streams attached to the request + /// or the response; see `call_streaming` for the full version. + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<<M as Message>::Response, Error> { + Ok(self.call_streaming(target, req, prio).await?.into_msg()) + } +} + +// ---- Internal stuff ---- + +pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>; + +#[async_trait] +pub(crate) trait GenericEndpoint { + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>; + fn drop_handler(&self); + fn clone_endpoint(&self) -> DynEndpoint; +} + +#[derive(Clone)] +pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>) +where + M: Message, + H: StreamingEndpointHandler<M>; + +#[async_trait] +impl<M, H> GenericEndpoint for EndpointArc<M, H> +where + M: Message, + H: StreamingEndpointHandler<M> + 'static, +{ + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> { + match self.0.handler.load_full() { + None => Err(Error::NoHandler), + Some(h) => { + let req = Req::from_enc(req_enc)?; + let res = h.handle(req, from).await; + Ok(res.into_enc()?) + } + } + } + + fn drop_handler(&self) { + self.0.handler.swap(None); + } + + fn clone_endpoint(&self) -> DynEndpoint { + Box::new(Self(self.0.clone())) + } +} diff --git a/src/net/error.rs b/src/net/error.rs new file mode 100644 index 00000000..c0aeeacc --- /dev/null +++ b/src/net/error.rs @@ -0,0 +1,126 @@ +use std::io; + +use err_derive::Error; +use log::error; + +#[derive(Debug, Error)] +pub enum Error { + #[error(display = "IO error: {}", _0)] + Io(#[error(source)] io::Error), + + #[error(display = "Messagepack encode error: {}", _0)] + RMPEncode(#[error(source)] rmp_serde::encode::Error), + #[error(display = "Messagepack decode error: {}", _0)] + RMPDecode(#[error(source)] rmp_serde::decode::Error), + + #[error(display = "Tokio join error: {}", _0)] + TokioJoin(#[error(source)] tokio::task::JoinError), + + #[error(display = "oneshot receive error: {}", _0)] + OneshotRecv(#[error(source)] tokio::sync::oneshot::error::RecvError), + + #[error(display = "Handshake error: {}", _0)] + Handshake(#[error(source)] kuska_handshake::async_std::Error), + + #[error(display = "UTF8 error: {}", _0)] + UTF8(#[error(source)] std::string::FromUtf8Error), + + #[error(display = "Framing protocol error")] + Framing, + + #[error(display = "Remote error ({:?}): {}", _0, _1)] + Remote(io::ErrorKind, String), + + #[error(display = "Request ID collision")] + IdCollision, + + #[error(display = "{}", _0)] + Message(String), + + #[error(display = "No handler / shutting down")] + NoHandler, + + #[error(display = "Connection closed")] + ConnectionClosed, + + #[error(display = "Version mismatch: {}", _0)] + VersionMismatch(String), +} + +impl<T> From<tokio::sync::watch::error::SendError<T>> for Error { + fn from(_e: tokio::sync::watch::error::SendError<T>) -> Error { + Error::Message("Watch send error".into()) + } +} + +impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error { + fn from(_e: tokio::sync::mpsc::error::SendError<T>) -> Error { + Error::Message("MPSC send error".into()) + } +} + +/// Ths trait adds a `.log_err()` method on `Result<(), E>` types, +/// which dismisses the error by logging it to stderr. +pub trait LogError { + fn log_err(self, msg: &'static str); +} + +impl<E> LogError for Result<(), E> +where + E: Into<Error>, +{ + fn log_err(self, msg: &'static str) { + if let Err(e) = self { + error!("Error: {}: {}", msg, Into::<Error>::into(e)); + }; + } +} + +impl<E, T> LogError for Result<T, E> +where + T: LogError, + E: Into<Error>, +{ + fn log_err(self, msg: &'static str) { + match self { + Err(e) => error!("Error: {}: {}", msg, Into::<Error>::into(e)), + Ok(x) => x.log_err(msg), + } + } +} + +// ---- Helpers for serializing I/O Errors + +pub(crate) fn u8_to_io_errorkind(v: u8) -> std::io::ErrorKind { + use std::io::ErrorKind; + match v { + 101 => ErrorKind::ConnectionAborted, + 102 => ErrorKind::BrokenPipe, + 103 => ErrorKind::WouldBlock, + 104 => ErrorKind::InvalidInput, + 105 => ErrorKind::InvalidData, + 106 => ErrorKind::TimedOut, + 107 => ErrorKind::Interrupted, + 108 => ErrorKind::UnexpectedEof, + 109 => ErrorKind::OutOfMemory, + 110 => ErrorKind::ConnectionReset, + _ => ErrorKind::Other, + } +} + +pub(crate) fn io_errorkind_to_u8(kind: std::io::ErrorKind) -> u8 { + use std::io::ErrorKind; + match kind { + ErrorKind::ConnectionAborted => 101, + ErrorKind::BrokenPipe => 102, + ErrorKind::WouldBlock => 103, + ErrorKind::InvalidInput => 104, + ErrorKind::InvalidData => 105, + ErrorKind::TimedOut => 106, + ErrorKind::Interrupted => 107, + ErrorKind::UnexpectedEof => 108, + ErrorKind::OutOfMemory => 109, + ErrorKind::ConnectionReset => 110, + _ => 100, + } +} diff --git a/src/net/lib.rs b/src/net/lib.rs new file mode 100644 index 00000000..8e30e40f --- /dev/null +++ b/src/net/lib.rs @@ -0,0 +1,35 @@ +//! Netapp is a Rust library that takes care of a few common tasks in distributed software: +//! +//! - establishing secure connections +//! - managing connection lifetime, reconnecting on failure +//! - checking peer's state +//! - peer discovery +//! - query/response message passing model for communications +//! - multiplexing transfers over a connection +//! - overlay networks: full mesh, and soon other methods +//! +//! Of particular interest, read the documentation for the `netapp::NetApp` type, +//! the `message::Message` trait, and `proto::RequestPriority` to learn more +//! about message priorization. +//! Also check out the examples to learn how to use this crate. + +pub mod bytes_buf; +pub mod error; +pub mod stream; +pub mod util; + +pub mod endpoint; +pub mod message; + +mod client; +mod recv; +mod send; +mod server; + +pub mod netapp; +pub mod peering; + +pub use crate::netapp::*; + +#[cfg(test)] +mod test; diff --git a/src/net/message.rs b/src/net/message.rs new file mode 100644 index 00000000..b0d255c6 --- /dev/null +++ b/src/net/message.rs @@ -0,0 +1,522 @@ +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + +use bytes::{BufMut, Bytes, BytesMut}; +use rand::prelude::*; +use serde::{Deserialize, Serialize}; + +use futures::stream::StreamExt; + +use crate::error::*; +use crate::stream::*; +use crate::util::*; + +/// Priority of a request (click to read more about priorities). +/// +/// This priority value is used to priorize messages +/// in the send queue of the client, and their responses in the send queue of the +/// server. Lower values mean higher priority. +/// +/// This mechanism is usefull for messages bigger than the maximum chunk size +/// (set at `0x4000` bytes), such as large file transfers. +/// In such case, all of the messages in the send queue with the highest priority +/// will take turns to send individual chunks, in a round-robin fashion. +/// Once all highest priority messages are sent successfully, the messages with +/// the next highest priority will begin being sent in the same way. +/// +/// The same priority value is given to a request and to its associated response. +pub type RequestPriority = u8; + +/// Priority class: high +pub const PRIO_HIGH: RequestPriority = 0x20; +/// Priority class: normal +pub const PRIO_NORMAL: RequestPriority = 0x40; +/// Priority class: background +pub const PRIO_BACKGROUND: RequestPriority = 0x80; +/// Priority: primary among given class +pub const PRIO_PRIMARY: RequestPriority = 0x00; +/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) +pub const PRIO_SECONDARY: RequestPriority = 0x01; + +// ---- + +/// An order tag can be added to a message or a response to indicate +/// whether it should be sent after or before other messages with order tags +/// referencing a same stream +#[derive(Clone, Copy, Serialize, Deserialize, Debug)] +pub struct OrderTag(pub(crate) u64, pub(crate) u64); + +/// A stream is an opaque identifier that defines a set of messages +/// or responses that are ordered wrt one another using to order tags. +#[derive(Clone, Copy)] +pub struct OrderTagStream(u64); + +impl OrderTag { + /// Create a new stream from which to generate order tags. Example: + /// ```ignore + /// let stream = OrderTag.stream(); + /// let tag_1 = stream.order(1); + /// let tag_2 = stream.order(2); + /// ``` + pub fn stream() -> OrderTagStream { + OrderTagStream(thread_rng().gen()) + } +} +impl OrderTagStream { + /// Create the order tag for message `order` in this stream + pub fn order(&self, order: u64) -> OrderTag { + OrderTag(self.0, order) + } +} + +// ---- + +/// This trait should be implemented by all messages your application +/// wants to handle. It specifies which data type should be sent +/// as a response to this message in the RPC protocol. +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { + /// The type of the response that is sent in response to this message + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static; +} + +// ---- + +/// The Req<M> is a helper object used to create requests and attach them +/// a stream of data. If the stream is a fixed Bytes and not a ByteStream, +/// Req<M> is cheaply clonable to allow the request to be sent to different +/// peers (Clone will panic if the stream is a ByteStream). +pub struct Req<M: Message> { + pub(crate) msg: Arc<M>, + pub(crate) msg_ser: Option<Bytes>, + pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option<OrderTag>, +} + +impl<M: Message> Req<M> { + /// Creates a new request from a base message `M` + pub fn new(v: M) -> Result<Self, Error> { + Ok(v.into_req()?) + } + + /// Attach a stream to message in request, where the stream is streamed + /// from a fixed `Bytes` buffer + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { + Self { + stream: AttachedStream::Fixed(b), + ..self + } + } + + /// Attach a stream to message in request, where the stream is + /// an instance of `ByteStream`. Note than when a `Req<M>` has an attached + /// stream which is a `ByteStream` instance, it can no longer be cloned + /// to be sent to different nodes (`.clone()` will panic) + pub fn with_stream(self, b: ByteStream) -> Self { + Self { + stream: AttachedStream::Stream(b), + ..self + } + } + + /// Add an order tag to this request to indicate in which order it should + /// be sent. + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + + /// Get a reference to the message `M` contained in this request + pub fn msg(&self) -> &M { + &self.msg + } + + /// Takes out the stream attached to this request, if any + pub fn take_stream(&mut self) -> Option<ByteStream> { + std::mem::replace(&mut self.stream, AttachedStream::None).into_stream() + } + + pub(crate) fn into_enc( + self, + prio: RequestPriority, + path: Bytes, + telemetry_id: Bytes, + ) -> ReqEnc { + ReqEnc { + prio, + path, + telemetry_id, + msg: self.msg_ser.unwrap(), + stream: self.stream.into_stream(), + order_tag: self.order_tag, + } + } + + pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> { + let msg = rmp_serde::decode::from_slice(&enc.msg)?; + Ok(Req { + msg: Arc::new(msg), + msg_ser: Some(enc.msg), + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) + } +} + +/// `IntoReq<M>` represents any object that can be transformed into `Req<M>` +pub trait IntoReq<M: Message> { + /// Transform the object into a `Req<M>`, serializing the message M + /// to be sent to remote nodes + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>; + /// Transform the object into a `Req<M>`, skipping the serialization + /// of message M, in the case we are not sending this RPC message to + /// a remote node + fn into_req_local(self) -> Req<M>; +} + +impl<M: Message> IntoReq<M> for M { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> { + let msg_ser = rmp_to_vec_all_named(&self)?; + Ok(Req { + msg: Arc::new(self), + msg_ser: Some(Bytes::from(msg_ser)), + stream: AttachedStream::None, + order_tag: None, + }) + } + fn into_req_local(self) -> Req<M> { + Req { + msg: Arc::new(self), + msg_ser: None, + stream: AttachedStream::None, + order_tag: None, + } + } +} + +impl<M: Message> IntoReq<M> for Req<M> { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> { + Ok(self) + } + fn into_req_local(self) -> Req<M> { + self + } +} + +impl<M: Message> Clone for Req<M> { + fn clone(&self) -> Self { + let stream = match &self.stream { + AttachedStream::None => AttachedStream::None, + AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()), + AttachedStream::Stream(_) => { + panic!("Cannot clone a Req<_> with a non-buffer attached stream") + } + }; + Self { + msg: self.msg.clone(), + msg_ser: self.msg_ser.clone(), + stream, + order_tag: self.order_tag, + } + } +} + +impl<M> fmt::Debug for Req<M> +where + M: Message + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Req[{:?}", self.msg)?; + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), + } + } +} + +// ---- + +/// The Resp<M> represents a full response from a RPC that may have +/// an attached stream. +pub struct Resp<M: Message> { + pub(crate) _phantom: PhantomData<M>, + pub(crate) msg: M::Response, + pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option<OrderTag>, +} + +impl<M: Message> Resp<M> { + /// Creates a new response from a base response message + pub fn new(v: M::Response) -> Self { + Resp { + _phantom: Default::default(), + msg: v, + stream: AttachedStream::None, + order_tag: None, + } + } + + /// Attach a stream to message in response, where the stream is streamed + /// from a fixed `Bytes` buffer + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { + Self { + stream: AttachedStream::Fixed(b), + ..self + } + } + + /// Attach a stream to message in response, where the stream is + /// an instance of `ByteStream`. + pub fn with_stream(self, b: ByteStream) -> Self { + Self { + stream: AttachedStream::Stream(b), + ..self + } + } + + /// Add an order tag to this response to indicate in which order it should + /// be sent. + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + + /// Get a reference to the response message contained in this request + pub fn msg(&self) -> &M::Response { + &self.msg + } + + /// Transforms the `Resp<M>` into the response message it contains, + /// dropping everything else (including attached data stream) + pub fn into_msg(self) -> M::Response { + self.msg + } + + /// Transforms the `Resp<M>` into, on the one side, the response message + /// it contains, and on the other side, the associated data stream + /// if it exists + pub fn into_parts(self) -> (M::Response, Option<ByteStream>) { + (self.msg, self.stream.into_stream()) + } + + pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> { + Ok(RespEnc { + msg: rmp_to_vec_all_named(&self.msg)?.into(), + stream: self.stream.into_stream(), + order_tag: self.order_tag, + }) + } + + pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> { + let msg = rmp_serde::decode::from_slice(&enc.msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) + } +} + +impl<M> fmt::Debug for Resp<M> +where + M: Message, + <M as Message>::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), + } + } +} + +// ---- + +pub(crate) enum AttachedStream { + None, + Fixed(Bytes), + Stream(ByteStream), +} + +impl AttachedStream { + pub fn into_stream(self) -> Option<ByteStream> { + match self { + AttachedStream::None => None, + AttachedStream::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + AttachedStream::Stream(s) => Some(s), + } + } +} + +// ---- ---- + +/// Encoding for requests into a ByteStream: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +pub(crate) struct ReqEnc { + pub(crate) prio: RequestPriority, + pub(crate) path: Bytes, + pub(crate) telemetry_id: Bytes, + pub(crate) msg: Bytes, + pub(crate) stream: Option<ByteStream>, + pub(crate) order_tag: Option<OrderTag>, +} + +impl ReqEnc { + pub(crate) fn encode(self) -> (ByteStream, Option<OrderTag>) { + let mut buf = BytesMut::with_capacity( + self.path.len() + self.telemetry_id.len() + self.msg.len() + 16, + ); + + buf.put_u8(self.prio); + + buf.put_u8(self.path.len() as u8); + buf.put(self.path); + + buf.put_u8(self.telemetry_id.len() as u8); + buf.put(&self.telemetry_id[..]); + + buf.put_u32(self.msg.len() as u32); + + let header = buf.freeze(); + + let res_stream: ByteStream = if let Some(stream) = self.stream { + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream)) + } else { + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)])) + }; + (res_stream, self.order_tag) + } + + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) + } + + async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); + + let prio = reader.read_u8().await?; + + let path_len = reader.read_u8().await?; + let path = reader.read_exact(path_len as usize).await?; + + let telemetry_id_len = reader.read_u8().await?; + let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?; + + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; + + Ok(Self { + prio, + path, + telemetry_id, + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} + +/// Encoding for responses into a ByteStream: +/// IF SUCCESS: +/// - 0: u8 +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +/// IF ERROR: +/// - message length + 1: u8 +/// - error code: u8 +/// - message: [u8; message_length] +pub(crate) struct RespEnc { + msg: Bytes, + stream: Option<ByteStream>, + order_tag: Option<OrderTag>, +} + +impl RespEnc { + pub(crate) fn encode(resp: Result<Self, Error>) -> (ByteStream, Option<OrderTag>) { + match resp { + Ok(Self { + msg, + stream, + order_tag, + }) => { + let mut buf = BytesMut::with_capacity(4); + buf.put_u32(msg.len() as u32); + let header = buf.freeze(); + + let res_stream: ByteStream = if let Some(stream) = stream { + Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream)) + } else { + Box::pin(futures::stream::iter([Ok(header), Ok(msg)])) + }; + (res_stream, order_tag) + } + Err(err) => { + let err = std::io::Error::new( + std::io::ErrorKind::Other, + format!("netapp error: {}", err), + ); + ( + Box::pin(futures::stream::once(async move { Err(err) })), + None, + ) + } + } + } + + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) + } + + async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); + + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; + + // Check whether the response stream still has data or not. + // If no more data is coming, this will defuse the request canceller. + // If we didn't do this, and the client doesn't try to read from the stream, + // the request canceller doesn't know that we read everything and + // sends a cancellation message to the server (which they don't care about). + reader.fill_buffer().await; + + Ok(Self { + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} + +fn read_exact_error_to_error(e: ReadExactError) -> Error { + match e { + ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()), + ReadExactError::UnexpectedEos => Error::Framing, + } +} diff --git a/src/net/netapp.rs b/src/net/netapp.rs new file mode 100644 index 00000000..b1ad9db8 --- /dev/null +++ b/src/net/netapp.rs @@ -0,0 +1,452 @@ +use std::collections::HashMap; +use std::net::{IpAddr, SocketAddr}; +use std::sync::{Arc, RwLock}; + +use log::{debug, error, info, trace, warn}; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; + +use serde::{Deserialize, Serialize}; +use sodiumoxide::crypto::auth; +use sodiumoxide::crypto::sign::ed25519; + +use futures::stream::futures_unordered::FuturesUnordered; +use futures::stream::StreamExt; +use tokio::net::{TcpListener, TcpStream}; +use tokio::select; +use tokio::sync::{mpsc, watch}; + +use crate::client::*; +use crate::endpoint::*; +use crate::error::*; +use crate::message::*; +use crate::server::*; + +/// A node's identifier, which is also its public cryptographic key +pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; +/// A node's secret key +pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; +/// A network key +pub type NetworkKey = sodiumoxide::crypto::auth::Key; + +/// Tag which is exchanged between client and server upon connection establishment +/// to check that they are running compatible versions of Netapp, +/// composed of 8 bytes for Netapp version and 8 bytes for client version +pub(crate) type VersionTag = [u8; 16]; + +/// Value of the Netapp version used in the version tag +pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005 + +#[derive(Serialize, Deserialize, Debug)] +pub(crate) struct HelloMessage { + pub server_addr: Option<IpAddr>, + pub server_port: u16, +} + +impl Message for HelloMessage { + type Response = (); +} + +type OnConnectHandler = Box<dyn Fn(NodeID, SocketAddr, bool) + Send + Sync>; +type OnDisconnectHandler = Box<dyn Fn(NodeID, bool) + Send + Sync>; + +/// NetApp is the main class that handles incoming and outgoing connections. +/// +/// NetApp can be used in a stand-alone fashion or together with a peering strategy. +/// If using it alone, you will want to set `on_connect` and `on_disconnect` events +/// in order to manage information about the current peer list. +/// +/// It is generally not necessary to use NetApp stand-alone, as the provided full mesh +/// and RPS peering strategies take care of the most common use cases. +pub struct NetApp { + listen_params: ArcSwapOption<ListenParams>, + + /// Version tag, 8 bytes for netapp version, 8 bytes for app version + pub version_tag: VersionTag, + /// Network secret key + pub netid: auth::Key, + /// Our peer ID + pub id: NodeID, + /// Private key associated with our peer ID + pub privkey: ed25519::SecretKey, + + pub(crate) server_conns: RwLock<HashMap<NodeID, Arc<ServerConn>>>, + pub(crate) client_conns: RwLock<HashMap<NodeID, Arc<ClientConn>>>, + + pub(crate) endpoints: RwLock<HashMap<String, DynEndpoint>>, + hello_endpoint: ArcSwapOption<Endpoint<HelloMessage, NetApp>>, + + on_connected_handler: ArcSwapOption<OnConnectHandler>, + on_disconnected_handler: ArcSwapOption<OnDisconnectHandler>, +} + +struct ListenParams { + listen_addr: SocketAddr, + public_addr: Option<IpAddr>, +} + +impl NetApp { + /// Creates a new instance of NetApp, which can serve either as a full p2p node, + /// or just as a passive client. To upgrade to a full p2p node, spawn a listener + /// using `.listen()` + /// + /// Our Peer ID is the public key associated to the secret key given here. + pub fn new(app_version_tag: u64, netid: auth::Key, privkey: ed25519::SecretKey) -> Arc<Self> { + let mut version_tag = [0u8; 16]; + version_tag[0..8].copy_from_slice(&u64::to_be_bytes(NETAPP_VERSION_TAG)[..]); + version_tag[8..16].copy_from_slice(&u64::to_be_bytes(app_version_tag)[..]); + + let id = privkey.public_key(); + let netapp = Arc::new(Self { + listen_params: ArcSwapOption::new(None), + version_tag, + netid, + id, + privkey, + server_conns: RwLock::new(HashMap::new()), + client_conns: RwLock::new(HashMap::new()), + endpoints: RwLock::new(HashMap::new()), + hello_endpoint: ArcSwapOption::new(None), + on_connected_handler: ArcSwapOption::new(None), + on_disconnected_handler: ArcSwapOption::new(None), + }); + + netapp + .hello_endpoint + .swap(Some(netapp.endpoint("__netapp/netapp.rs/Hello".into()))); + netapp + .hello_endpoint + .load_full() + .unwrap() + .set_handler(netapp.clone()); + + netapp + } + + /// Set the handler to be called when a new connection (incoming or outgoing) has + /// been successfully established. Do not set this if using a peering strategy, + /// as the peering strategy will need to set this itself. + pub fn on_connected<F>(&self, handler: F) + where + F: Fn(NodeID, SocketAddr, bool) + Sized + Send + Sync + 'static, + { + self.on_connected_handler + .store(Some(Arc::new(Box::new(handler)))); + } + + /// Set the handler to be called when an existing connection (incoming or outgoing) has + /// been closed by either party. Do not set this if using a peering strategy, + /// as the peering strategy will need to set this itself. + pub fn on_disconnected<F>(&self, handler: F) + where + F: Fn(NodeID, bool) + Sized + Send + Sync + 'static, + { + self.on_disconnected_handler + .store(Some(Arc::new(Box::new(handler)))); + } + + /// Create a new endpoint with path `path`, + /// that handles messages of type `M`. + /// `H` is the type of the object that should handle requests + /// to this endpoint on the local node. If you don't want + /// to handle request on the local node (e.g. if this node + /// is only a client in the network), define the type `H` + /// to be `()`. + /// This function will panic if the endpoint has already been + /// created. + pub fn endpoint<M, H>(self: &Arc<Self>, path: String) -> Arc<Endpoint<M, H>> + where + M: Message + 'static, + H: StreamingEndpointHandler<M> + 'static, + { + let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), path.clone())); + let endpoint_arc = EndpointArc(endpoint.clone()); + if self + .endpoints + .write() + .unwrap() + .insert(path.clone(), Box::new(endpoint_arc)) + .is_some() + { + panic!("Redefining endpoint: {}", path); + }; + endpoint + } + + /// Main listening process for our app. This future runs during the whole + /// run time of our application. + /// If this is not called, the NetApp instance remains a passive client. + pub async fn listen( + self: Arc<Self>, + listen_addr: SocketAddr, + public_addr: Option<IpAddr>, + mut must_exit: watch::Receiver<bool>, + ) { + let listen_params = ListenParams { + listen_addr, + public_addr, + }; + if self + .listen_params + .swap(Some(Arc::new(listen_params))) + .is_some() + { + error!("Trying to listen on NetApp but we're already listening!"); + } + + let listener = TcpListener::bind(listen_addr).await.unwrap(); + info!("Listening on {}", listen_addr); + + let (conn_in, mut conn_out) = mpsc::unbounded_channel(); + let connection_collector = tokio::spawn(async move { + let mut collection = FuturesUnordered::new(); + loop { + if collection.is_empty() { + match conn_out.recv().await { + Some(f) => collection.push(f), + None => break, + } + } else { + select! { + new_fut = conn_out.recv() => { + match new_fut { + Some(f) => collection.push(f), + None => break, + } + } + result = collection.next() => { + trace!("Collected connection: {:?}", result); + } + } + } + } + debug!("Collecting last open server connections."); + while let Some(conn_res) = collection.next().await { + trace!("Collected connection: {:?}", conn_res); + } + debug!("No more server connections to collect"); + }); + + while !*must_exit.borrow_and_update() { + let (socket, peer_addr) = select! { + sockres = listener.accept() => { + match sockres { + Ok(x) => x, + Err(e) => { + warn!("Error in listener.accept: {}", e); + continue; + } + } + }, + _ = must_exit.changed() => continue, + }; + + info!( + "Incoming connection from {}, negotiating handshake...", + peer_addr + ); + let self2 = self.clone(); + let must_exit2 = must_exit.clone(); + conn_in + .send(tokio::spawn(async move { + ServerConn::run(self2, socket, must_exit2) + .await + .log_err("ServerConn::run"); + })) + .log_err("Failed to send connection to connection collector"); + } + + drop(conn_in); + + connection_collector + .await + .log_err("Failed to await for connection collector"); + } + + /// Drop all endpoint handlers, as well as handlers for connection/disconnection + /// events. (This disables the peering strategy) + /// + /// Use this when terminating to break reference cycles + pub fn drop_all_handlers(&self) { + for (_, endpoint) in self.endpoints.read().unwrap().iter() { + endpoint.drop_handler(); + } + self.on_connected_handler.store(None); + self.on_disconnected_handler.store(None); + } + + /// Attempt to connect to a peer, given by its ip:port and its public key. + /// The public key will be checked during the secret handshake process. + /// This function returns once the connection has been established and a + /// successfull handshake was made. At this point we can send messages to + /// the other node with `Netapp::request` + pub async fn try_connect(self: Arc<Self>, ip: SocketAddr, id: NodeID) -> Result<(), Error> { + // Don't connect to ourself, we don't care + // but pretend we did + if id == self.id { + tokio::spawn(async move { + if let Some(h) = self.on_connected_handler.load().as_ref() { + h(id, ip, false); + } + }); + return Ok(()); + } + + // Don't connect if already connected + if self.client_conns.read().unwrap().contains_key(&id) { + return Ok(()); + } + + let socket = TcpStream::connect(ip).await?; + info!("Connected to {}, negotiating handshake...", ip); + ClientConn::init(self, socket, id).await?; + Ok(()) + } + + /// Close the outgoing connection we have to a node specified by its public key, + /// if such a connection is currently open. + pub fn disconnect(self: &Arc<Self>, id: &NodeID) { + // If id is ourself, we're not supposed to have a connection open + if *id != self.id { + let conn = self.client_conns.write().unwrap().remove(id); + if let Some(c) = conn { + debug!( + "Closing connection to {} ({})", + hex::encode(&c.peer_id[..8]), + c.remote_addr + ); + c.close(); + } else { + return; + } + } + + // call on_disconnected_handler immediately, since the connection + // was removed + // (if id == self.id, we pretend we disconnected) + let id = *id; + let self2 = self.clone(); + tokio::spawn(async move { + if let Some(h) = self2.on_disconnected_handler.load().as_ref() { + h(id, false); + } + }); + } + + // Called from conn.rs when an incoming connection is successfully established + // Registers the connection in our list of connections + // Do not yet call the on_connected handler, because we don't know if the remote + // has an actual IP address and port we can call them back on. + // We will know this when they send a Hello message, which is handled below. + pub(crate) fn connected_as_server(&self, id: NodeID, conn: Arc<ServerConn>) { + info!( + "Accepted connection from {} at {}", + hex::encode(&id[..8]), + conn.remote_addr + ); + + self.server_conns.write().unwrap().insert(id, conn); + } + + // Handle hello message from a client. This message is used for them to tell us + // that they are listening on a certain port number on which we can call them back. + // At this point we know they are a full network member, and not just a client, + // and we call the on_connected handler so that the peering strategy knows + // we have a new potential peer + + // Called from conn.rs when an incoming connection is closed. + // We deregister the connection from server_conns and call the + // handler registered by on_disconnected + pub(crate) fn disconnected_as_server(&self, id: &NodeID, conn: Arc<ServerConn>) { + info!("Connection from {} closed", hex::encode(&id[..8])); + + let mut conn_list = self.server_conns.write().unwrap(); + if let Some(c) = conn_list.get(id) { + if Arc::ptr_eq(c, &conn) { + conn_list.remove(id); + drop(conn_list); + + if let Some(h) = self.on_disconnected_handler.load().as_ref() { + h(conn.peer_id, true); + } + } + } + } + + // Called from conn.rs when an outgoinc connection is successfully established. + // The connection is registered in self.client_conns, and the + // on_connected handler is called. + // + // Since we are ourself listening, we send them a Hello message so that + // they know on which port to call us back. (TODO: don't do this if we are + // just a simple client and not a full p2p node) + pub(crate) fn connected_as_client(&self, id: NodeID, conn: Arc<ClientConn>) { + info!("Connection established to {}", hex::encode(&id[..8])); + + { + let old_c_opt = self.client_conns.write().unwrap().insert(id, conn.clone()); + if let Some(old_c) = old_c_opt { + tokio::spawn(async move { old_c.close() }); + } + } + + if let Some(h) = self.on_connected_handler.load().as_ref() { + h(conn.peer_id, conn.remote_addr, false); + } + + if let Some(lp) = self.listen_params.load_full() { + let server_addr = lp.public_addr; + let server_port = lp.listen_addr.port(); + let hello_endpoint = self.hello_endpoint.load_full().unwrap(); + tokio::spawn(async move { + hello_endpoint + .call( + &conn.peer_id, + HelloMessage { + server_addr, + server_port, + }, + PRIO_NORMAL, + ) + .await + .map(|_| ()) + .log_err("Sending hello message"); + }); + } + } + + // Called from conn.rs when an outgoinc connection is closed. + // The connection is removed from conn_list, and the on_disconnected handler + // is called. + pub(crate) fn disconnected_as_client(&self, id: &NodeID, conn: Arc<ClientConn>) { + info!("Connection to {} closed", hex::encode(&id[..8])); + let mut conn_list = self.client_conns.write().unwrap(); + if let Some(c) = conn_list.get(id) { + if Arc::ptr_eq(c, &conn) { + conn_list.remove(id); + drop(conn_list); + + if let Some(h) = self.on_disconnected_handler.load().as_ref() { + h(conn.peer_id, false); + } + } + } + // else case: happens if connection was removed in .disconnect() + // in which case on_disconnected_handler was already called + } +} + +#[async_trait] +impl EndpointHandler<HelloMessage> for NetApp { + async fn handle(self: &Arc<Self>, msg: &HelloMessage, from: NodeID) { + debug!("Hello from {:?}: {:?}", hex::encode(&from[..8]), msg); + if let Some(h) = self.on_connected_handler.load().as_ref() { + if let Some(c) = self.server_conns.read().unwrap().get(&from) { + let remote_ip = msg.server_addr.unwrap_or_else(|| c.remote_addr.ip()); + let remote_addr = SocketAddr::new(remote_ip, msg.server_port); + h(from, remote_addr, true); + } + } + } +} diff --git a/src/net/peering.rs b/src/net/peering.rs new file mode 100644 index 00000000..32199cf8 --- /dev/null +++ b/src/net/peering.rs @@ -0,0 +1,614 @@ +use std::collections::{HashMap, VecDeque}; +use std::net::SocketAddr; +use std::sync::atomic::{self, AtomicU64}; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, Instant}; + +use arc_swap::ArcSwap; +use async_trait::async_trait; +use log::{debug, info, trace, warn}; +use serde::{Deserialize, Serialize}; + +use tokio::select; +use tokio::sync::watch; + +use sodiumoxide::crypto::hash; + +use crate::endpoint::*; +use crate::error::*; +use crate::netapp::*; + +use crate::message::*; +use crate::NodeID; + +const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30); +const CONN_MAX_RETRIES: usize = 10; +const PING_INTERVAL: Duration = Duration::from_secs(15); +const LOOP_DELAY: Duration = Duration::from_secs(1); +const FAILED_PING_THRESHOLD: usize = 4; + +const DEFAULT_PING_TIMEOUT_MILLIS: u64 = 10_000; + +// -- Protocol messages -- + +#[derive(Serialize, Deserialize)] +struct PingMessage { + pub id: u64, + pub peer_list_hash: hash::Digest, +} + +impl Message for PingMessage { + type Response = PingMessage; +} + +#[derive(Serialize, Deserialize)] +struct PeerListMessage { + pub list: Vec<(NodeID, SocketAddr)>, +} + +impl Message for PeerListMessage { + type Response = PeerListMessage; +} + +// -- Algorithm data structures -- + +#[derive(Debug)] +struct PeerInfoInternal { + // addr is the currently connected address, + // or the last address we were connected to, + // or an arbitrary address some other peer gave us + addr: SocketAddr, + // all_addrs contains all of the addresses everyone gave us + all_addrs: Vec<SocketAddr>, + + state: PeerConnState, + last_send_ping: Option<Instant>, + last_seen: Option<Instant>, + ping: VecDeque<Duration>, + failed_pings: usize, +} + +impl PeerInfoInternal { + fn new(addr: SocketAddr, state: PeerConnState) -> Self { + Self { + addr, + all_addrs: vec![addr], + state, + last_send_ping: None, + last_seen: None, + ping: VecDeque::new(), + failed_pings: 0, + } + } +} + +/// Information that the full mesh peering strategy can return about the peers it knows of +#[derive(Copy, Clone, Debug)] +pub struct PeerInfo { + /// The node's identifier (its public key) + pub id: NodeID, + /// The node's network address + pub addr: SocketAddr, + /// The current status of our connection to this node + pub state: PeerConnState, + /// The last time at which the node was seen + pub last_seen: Option<Instant>, + /// The average ping to this node on recent observations (if at least one ping value is known) + pub avg_ping: Option<Duration>, + /// The maximum observed ping to this node on recent observations (if at least one + /// ping value is known) + pub max_ping: Option<Duration>, + /// The median ping to this node on recent observations (if at least one ping value + /// is known) + pub med_ping: Option<Duration>, +} + +impl PeerInfo { + /// Returns true if we can currently send requests to this peer + pub fn is_up(&self) -> bool { + self.state.is_up() + } +} + +/// PeerConnState: possible states for our tentative connections to given peer +/// This structure is only interested in recording connection info for outgoing +/// TCP connections +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PeerConnState { + /// This entry represents ourself (the local node) + Ourself, + + /// We currently have a connection to this peer + Connected, + + /// Our next connection tentative (the nth, where n is the first value of the tuple) + /// will be at given Instant + Waiting(usize, Instant), + + /// A connection tentative is in progress (the nth, where n is the value stored) + Trying(usize), + + /// We abandonned trying to connect to this peer (too many failed attempts) + Abandonned, +} + +impl PeerConnState { + /// Returns true if we can currently send requests to this peer + pub fn is_up(&self) -> bool { + matches!(self, Self::Ourself | Self::Connected) + } +} + +struct KnownHosts { + list: HashMap<NodeID, PeerInfoInternal>, + hash: hash::Digest, +} + +impl KnownHosts { + fn new() -> Self { + let list = HashMap::new(); + let hash = Self::calculate_hash(&list); + Self { list, hash } + } + fn update_hash(&mut self) { + self.hash = Self::calculate_hash(&self.list); + } + fn map_into_vec(input: &HashMap<NodeID, PeerInfoInternal>) -> Vec<(NodeID, SocketAddr)> { + let mut list = Vec::with_capacity(input.len()); + for (id, peer) in input.iter() { + if peer.state == PeerConnState::Connected || peer.state == PeerConnState::Ourself { + list.push((*id, peer.addr)); + } + } + list + } + fn calculate_hash(input: &HashMap<NodeID, PeerInfoInternal>) -> hash::Digest { + let mut list = Self::map_into_vec(input); + list.sort(); + let mut hash_state = hash::State::new(); + for (id, addr) in list { + hash_state.update(&id[..]); + hash_state.update(&format!("{}\n", addr).into_bytes()[..]); + } + hash_state.finalize() + } +} + +/// A "Full Mesh" peering strategy is a peering strategy that tries +/// to establish and maintain a direct connection with all of the +/// known nodes in the network. +pub struct PeeringManager { + netapp: Arc<NetApp>, + known_hosts: RwLock<KnownHosts>, + public_peer_list: ArcSwap<Vec<PeerInfo>>, + + next_ping_id: AtomicU64, + ping_endpoint: Arc<Endpoint<PingMessage, Self>>, + peer_list_endpoint: Arc<Endpoint<PeerListMessage, Self>>, + + ping_timeout_millis: AtomicU64, +} + +impl PeeringManager { + /// Create a new Full Mesh peering strategy. + /// The strategy will not be run until `.run()` is called and awaited. + /// Once that happens, the peering strategy will try to connect + /// to all of the nodes specified in the bootstrap list. + pub fn new( + netapp: Arc<NetApp>, + bootstrap_list: Vec<(NodeID, SocketAddr)>, + our_addr: Option<SocketAddr>, + ) -> Arc<Self> { + let mut known_hosts = KnownHosts::new(); + for (id, addr) in bootstrap_list { + if id != netapp.id { + known_hosts.list.insert( + id, + PeerInfoInternal::new(addr, PeerConnState::Waiting(0, Instant::now())), + ); + } + } + + if let Some(addr) = our_addr { + known_hosts.list.insert( + netapp.id, + PeerInfoInternal::new(addr, PeerConnState::Ourself), + ); + } + + // TODO for v0.10 / v1.0 : rename the endpoint (it will break compatibility) + let strat = Arc::new(Self { + netapp: netapp.clone(), + known_hosts: RwLock::new(known_hosts), + public_peer_list: ArcSwap::new(Arc::new(Vec::new())), + next_ping_id: AtomicU64::new(42), + ping_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/Ping".into()), + peer_list_endpoint: netapp.endpoint("__netapp/peering/fullmesh.rs/PeerList".into()), + ping_timeout_millis: DEFAULT_PING_TIMEOUT_MILLIS.into(), + }); + + strat.update_public_peer_list(&strat.known_hosts.read().unwrap()); + + strat.ping_endpoint.set_handler(strat.clone()); + strat.peer_list_endpoint.set_handler(strat.clone()); + + let strat2 = strat.clone(); + netapp.on_connected(move |id: NodeID, addr: SocketAddr, is_incoming: bool| { + let strat2 = strat2.clone(); + strat2.on_connected(id, addr, is_incoming); + }); + + let strat2 = strat.clone(); + netapp.on_disconnected(move |id: NodeID, is_incoming: bool| { + let strat2 = strat2.clone(); + strat2.on_disconnected(id, is_incoming); + }); + + strat + } + + /// Run the full mesh peering strategy. + /// This future exits when the `must_exit` watch becomes true. + pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) { + while !*must_exit.borrow() { + // 1. Read current state: get list of connected peers (ping them) + let (to_ping, to_retry) = { + let known_hosts = self.known_hosts.read().unwrap(); + trace!("known_hosts: {} peers", known_hosts.list.len()); + + let mut to_ping = vec![]; + let mut to_retry = vec![]; + for (id, info) in known_hosts.list.iter() { + trace!("{}, {:?}", hex::encode(&id[..8]), info); + match info.state { + PeerConnState::Connected => { + let must_ping = match info.last_send_ping { + None => true, + Some(t) => Instant::now() - t > PING_INTERVAL, + }; + if must_ping { + to_ping.push(*id); + } + } + PeerConnState::Waiting(_, t) => { + if Instant::now() >= t { + to_retry.push(*id); + } + } + _ => (), + } + } + (to_ping, to_retry) + }; + + // 2. Dispatch ping to hosts + trace!("to_ping: {} peers", to_ping.len()); + if !to_ping.is_empty() { + let mut known_hosts = self.known_hosts.write().unwrap(); + for id in to_ping.iter() { + known_hosts.list.get_mut(id).unwrap().last_send_ping = Some(Instant::now()); + } + drop(known_hosts); + for id in to_ping { + tokio::spawn(self.clone().ping(id)); + } + } + + // 3. Try reconnects + trace!("to_retry: {} peers", to_retry.len()); + if !to_retry.is_empty() { + let mut known_hosts = self.known_hosts.write().unwrap(); + for id in to_retry { + if let Some(h) = known_hosts.list.get_mut(&id) { + if let PeerConnState::Waiting(i, _) = h.state { + info!( + "Retrying connection to {} at {} ({})", + hex::encode(&id[..8]), + h.all_addrs + .iter() + .map(|x| format!("{}", x)) + .collect::<Vec<_>>() + .join(", "), + i + 1 + ); + h.state = PeerConnState::Trying(i); + + let alternate_addrs = h + .all_addrs + .iter() + .filter(|x| **x != h.addr) + .cloned() + .collect::<Vec<_>>(); + tokio::spawn(self.clone().try_connect(id, h.addr, alternate_addrs)); + } + } + } + self.update_public_peer_list(&known_hosts); + } + + // 4. Sleep before next loop iteration + tokio::time::sleep(LOOP_DELAY).await; + } + } + + /// Returns a list of currently known peers in the network. + pub fn get_peer_list(&self) -> Arc<Vec<PeerInfo>> { + self.public_peer_list.load_full() + } + + /// Set the timeout for ping messages, in milliseconds + pub fn set_ping_timeout_millis(&self, timeout: u64) { + self.ping_timeout_millis + .store(timeout, atomic::Ordering::Relaxed); + } + + // -- internal stuff -- + + fn update_public_peer_list(&self, known_hosts: &KnownHosts) { + let mut pub_peer_list = Vec::with_capacity(known_hosts.list.len()); + for (id, info) in known_hosts.list.iter() { + let mut pings = info.ping.iter().cloned().collect::<Vec<_>>(); + pings.sort(); + if !pings.is_empty() { + pub_peer_list.push(PeerInfo { + id: *id, + addr: info.addr, + state: info.state, + last_seen: info.last_seen, + avg_ping: Some( + pings + .iter() + .fold(Duration::from_secs(0), |x, y| x + *y) + .div_f64(pings.len() as f64), + ), + max_ping: pings.last().cloned(), + med_ping: Some(pings[pings.len() / 2]), + }); + } else { + pub_peer_list.push(PeerInfo { + id: *id, + addr: info.addr, + state: info.state, + last_seen: info.last_seen, + avg_ping: None, + max_ping: None, + med_ping: None, + }); + } + } + self.public_peer_list.store(Arc::new(pub_peer_list)); + } + + async fn ping(self: Arc<Self>, id: NodeID) { + let peer_list_hash = self.known_hosts.read().unwrap().hash; + let ping_id = self.next_ping_id.fetch_add(1u64, atomic::Ordering::Relaxed); + let ping_time = Instant::now(); + let ping_timeout = + Duration::from_millis(self.ping_timeout_millis.load(atomic::Ordering::Relaxed)); + let ping_msg = PingMessage { + id: ping_id, + peer_list_hash, + }; + + debug!( + "Sending ping {} to {} at {:?}", + ping_id, + hex::encode(&id[..8]), + ping_time + ); + let ping_response = select! { + r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r, + _ = tokio::time::sleep(ping_timeout) => Err(Error::Message("Ping timeout".into())), + }; + + match ping_response { + Err(e) => { + warn!("Error pinging {}: {}", hex::encode(&id[..8]), e); + let mut known_hosts = self.known_hosts.write().unwrap(); + if let Some(host) = known_hosts.list.get_mut(&id) { + host.failed_pings += 1; + if host.failed_pings > FAILED_PING_THRESHOLD { + warn!( + "Too many failed pings from {}, closing connection.", + hex::encode(&id[..8]) + ); + // this will later update info in known_hosts + // through the disconnection handler + self.netapp.disconnect(&id); + } + } + } + Ok(ping_resp) => { + let resp_time = Instant::now(); + debug!( + "Got ping response from {} at {:?}", + hex::encode(&id[..8]), + resp_time + ); + { + let mut known_hosts = self.known_hosts.write().unwrap(); + if let Some(host) = known_hosts.list.get_mut(&id) { + host.failed_pings = 0; + host.last_seen = Some(resp_time); + host.ping.push_back(resp_time - ping_time); + while host.ping.len() > 10 { + host.ping.pop_front(); + } + self.update_public_peer_list(&known_hosts); + } + } + if ping_resp.peer_list_hash != peer_list_hash { + self.exchange_peers(&id).await; + } + } + } + } + + async fn exchange_peers(self: Arc<Self>, id: &NodeID) { + let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); + let pex_message = PeerListMessage { list: peer_list }; + match self + .peer_list_endpoint + .call(id, pex_message, PRIO_BACKGROUND) + .await + { + Err(e) => warn!("Error doing peer exchange: {}", e), + Ok(resp) => { + self.handle_peer_list(&resp.list[..]); + } + } + } + + fn handle_peer_list(&self, list: &[(NodeID, SocketAddr)]) { + let mut known_hosts = self.known_hosts.write().unwrap(); + + let mut changed = false; + for (id, addr) in list.iter() { + if let Some(kh) = known_hosts.list.get_mut(id) { + if !kh.all_addrs.contains(addr) { + kh.all_addrs.push(*addr); + changed = true; + } + } else { + known_hosts.list.insert(*id, self.new_peer(id, *addr)); + changed = true; + } + } + + if changed { + known_hosts.update_hash(); + self.update_public_peer_list(&known_hosts); + } + } + + async fn try_connect( + self: Arc<Self>, + id: NodeID, + default_addr: SocketAddr, + alternate_addrs: Vec<SocketAddr>, + ) { + let conn_addr = { + let mut ret = None; + for addr in [default_addr].iter().chain(alternate_addrs.iter()) { + debug!("Trying address {} for peer {}", addr, hex::encode(&id[..8])); + match self.netapp.clone().try_connect(*addr, id).await { + Ok(()) => { + ret = Some(*addr); + break; + } + Err(e) => { + debug!( + "Error connecting to {} at {}: {}", + hex::encode(&id[..8]), + addr, + e + ); + } + } + } + ret + }; + + if let Some(ok_addr) = conn_addr { + self.on_connected(id, ok_addr, false); + } else { + warn!( + "Could not connect to peer {} ({} addresses tried)", + hex::encode(&id[..8]), + 1 + alternate_addrs.len() + ); + let mut known_hosts = self.known_hosts.write().unwrap(); + if let Some(host) = known_hosts.list.get_mut(&id) { + host.state = match host.state { + PeerConnState::Trying(i) => { + if i >= CONN_MAX_RETRIES { + PeerConnState::Abandonned + } else { + PeerConnState::Waiting(i + 1, Instant::now() + CONN_RETRY_INTERVAL) + } + } + _ => PeerConnState::Waiting(0, Instant::now() + CONN_RETRY_INTERVAL), + }; + self.update_public_peer_list(&known_hosts); + } + } + } + + fn on_connected(self: Arc<Self>, id: NodeID, addr: SocketAddr, is_incoming: bool) { + let mut known_hosts = self.known_hosts.write().unwrap(); + if is_incoming { + if let Some(host) = known_hosts.list.get_mut(&id) { + if !host.all_addrs.contains(&addr) { + host.all_addrs.push(addr); + } + } else { + known_hosts.list.insert(id, self.new_peer(&id, addr)); + } + } else { + info!( + "Successfully connected to {} at {}", + hex::encode(&id[..8]), + addr + ); + if let Some(host) = known_hosts.list.get_mut(&id) { + host.state = PeerConnState::Connected; + host.addr = addr; + if !host.all_addrs.contains(&addr) { + host.all_addrs.push(addr); + } + } else { + known_hosts + .list + .insert(id, PeerInfoInternal::new(addr, PeerConnState::Connected)); + } + } + known_hosts.update_hash(); + self.update_public_peer_list(&known_hosts); + } + + fn on_disconnected(self: Arc<Self>, id: NodeID, is_incoming: bool) { + if !is_incoming { + info!("Connection to {} was closed", hex::encode(&id[..8])); + let mut known_hosts = self.known_hosts.write().unwrap(); + if let Some(host) = known_hosts.list.get_mut(&id) { + host.state = PeerConnState::Waiting(0, Instant::now()); + known_hosts.update_hash(); + self.update_public_peer_list(&known_hosts); + } + } + } + + fn new_peer(&self, id: &NodeID, addr: SocketAddr) -> PeerInfoInternal { + let state = if *id == self.netapp.id { + PeerConnState::Ourself + } else { + PeerConnState::Waiting(0, Instant::now()) + }; + PeerInfoInternal::new(addr, state) + } +} + +#[async_trait] +impl EndpointHandler<PingMessage> for PeeringManager { + async fn handle(self: &Arc<Self>, ping: &PingMessage, from: NodeID) -> PingMessage { + let ping_resp = PingMessage { + id: ping.id, + peer_list_hash: self.known_hosts.read().unwrap().hash, + }; + debug!("Ping from {}", hex::encode(&from[..8])); + ping_resp + } +} + +#[async_trait] +impl EndpointHandler<PeerListMessage> for PeeringManager { + async fn handle( + self: &Arc<Self>, + peer_list: &PeerListMessage, + _from: NodeID, + ) -> PeerListMessage { + self.handle_peer_list(&peer_list.list[..]); + let peer_list = KnownHosts::map_into_vec(&self.known_hosts.read().unwrap().list); + PeerListMessage { list: peer_list } + } +} diff --git a/src/net/recv.rs b/src/net/recv.rs new file mode 100644 index 00000000..0de7bef2 --- /dev/null +++ b/src/net/recv.rs @@ -0,0 +1,153 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::Bytes; +use log::*; + +use futures::AsyncReadExt; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::send::*; +use crate::stream::*; + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: Option<mpsc::UnboundedSender<Packet>>, +} + +impl Sender { + fn new(inner: mpsc::UnboundedSender<Packet>) -> Self { + Sender { inner: Some(inner) } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.as_ref().unwrap().send(packet); + } + + fn end(&mut self) { + self.inner = None; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + let _ = inner.send(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Netapp connection dropped before end of stream", + ))); + } + } +} + +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream); + fn cancel_handler(self: &Arc<Self>, _id: RequestID) {} + + async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error> + where + R: AsyncReadExt + Unpin + Send + Sync, + { + let mut streams: HashMap<RequestID, Sender> = HashMap::new(); + loop { + trace!( + "recv_loop({}): in_progress = {:?}", + debug_name, + streams.iter().map(|(id, _)| id).collect::<Vec<_>>() + ); + + let mut header_id = [0u8; RequestID::BITS as usize / 8]; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; + let id = RequestID::from_be_bytes(header_id); + + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; + read.read_exact(&mut header_size[..]).await?; + let size = ChunkLength::from_be_bytes(header_size); + + if size == CANCEL_REQUEST { + if let Some(mut stream) = streams.remove(&id) { + let _ = stream.send(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "netapp: cancel requested", + ))); + stream.end(); + } + self.cancel_handler(id); + continue; + } + + let has_cont = (size & CHUNK_FLAG_HAS_CONTINUATION) != 0; + let is_error = (size & CHUNK_FLAG_ERROR) != 0; + let size = (size & CHUNK_LENGTH_MASK) as usize; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + + let packet = if is_error { + let kind = u8_to_io_errorkind(next_slice[0]); + let msg = + std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>"); + debug!( + "recv_loop({}): got id {}, error {:?}: {}", + debug_name, id, kind, msg + ); + Some(Err(std::io::Error::new(kind, msg.to_string()))) + } else { + trace!( + "recv_loop({}): got id {}, size {}, has_cont {}", + debug_name, + id, + size, + has_cont + ); + if !next_slice.is_empty() { + Some(Ok(Bytes::from(next_slice))) + } else { + None + } + }; + + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = mpsc::unbounded_channel(); + trace!("recv_loop({}): id {} is new channel", debug_name, id); + self.recv_handler( + id, + Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(recv)), + ); + Sender::new(send) + }; + + if let Some(packet) = packet { + // If we cannot put packet in channel, it means that the + // receiving end of the channel is disconnected. + // We still need to reach eos before dropping this sender + let _ = sender.send(packet); + } + + if has_cont { + assert!(!is_error); + streams.insert(id, sender); + } else { + trace!("recv_loop({}): close channel id {}", debug_name, id); + sender.end(); + } + } + Ok(()) + } +} diff --git a/src/net/send.rs b/src/net/send.rs new file mode 100644 index 00000000..0db0ba77 --- /dev/null +++ b/src/net/send.rs @@ -0,0 +1,356 @@ +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; +use log::*; + +use futures::{AsyncWriteExt, Future}; +use kuska_handshake::async_std::BoxStreamWrite; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::message::*; +use crate::stream::*; + +// Messages are sent by chunks +// Chunk format: +// - u32 BE: request id (same for request and response) +// - u16 BE: chunk length + flags: +// CHUNK_FLAG_HAS_CONTINUATION when this is not the last chunk of the stream +// CHUNK_FLAG_ERROR if this chunk denotes an error +// (these two flags are exclusive, an error denotes the end of the stream) +// **special value** 0xFFFF indicates a CANCEL message +// - [u8; chunk_length], either +// - if not error: chunk data +// - if error: +// - u8: error kind, encoded using error::io_errorkind_to_u8 +// - rest: error message +// - absent for cancel messag + +pub(crate) type RequestID = u32; +pub(crate) type ChunkLength = u16; + +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +pub(crate) const CHUNK_FLAG_ERROR: ChunkLength = 0x4000; +pub(crate) const CHUNK_FLAG_HAS_CONTINUATION: ChunkLength = 0x8000; +pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; +pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF; + +pub(crate) enum SendItem { + Stream(RequestID, RequestPriority, Option<OrderTag>, ByteStream), + Cancel(RequestID), +} + +// ---- + +struct SendQueue { + items: Vec<(u8, SendQueuePriority)>, +} + +struct SendQueuePriority { + items: VecDeque<SendQueueItem>, + order: HashMap<u64, VecDeque<u64>>, +} + +struct SendQueueItem { + id: RequestID, + prio: RequestPriority, + order_tag: Option<OrderTag>, + data: ByteStreamReader, + sent: usize, +} + +impl SendQueue { + fn new() -> Self { + Self { + items: Vec::with_capacity(64), + } + } + fn push(&mut self, item: SendQueueItem) { + let prio = item.prio; + let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { + Ok(i) => i, + Err(i) => { + self.items.insert(i, (prio, SendQueuePriority::new())); + i + } + }; + self.items[pos_prio].1.push(item); + } + fn remove(&mut self, id: RequestID) { + for (_, prioq) in self.items.iter_mut() { + prioq.remove(id); + } + self.items.retain(|(_prio, q)| !q.is_empty()); + } + fn is_empty(&self) -> bool { + self.items.iter().all(|(_k, v)| v.is_empty()) + } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +impl SendQueuePriority { + fn new() -> Self { + Self { + items: VecDeque::new(), + order: HashMap::new(), + } + } + fn push(&mut self, item: SendQueueItem) { + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.entry(stream).or_default(); + let i = order_vec.iter().take_while(|o2| **o2 < order).count(); + order_vec.insert(i, order); + } + self.items.push_front(item); + } + fn remove(&mut self, id: RequestID) { + if let Some(i) = self.items.iter().position(|x| x.id == id) { + let item = self.items.remove(i).unwrap(); + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.get_mut(&stream).unwrap(); + let j = order_vec.iter().position(|x| *x == order).unwrap(); + order_vec.remove(j).unwrap(); + if order_vec.is_empty() { + self.order.remove(&stream); + } + } + } + } + fn is_empty(&self) -> bool { + self.items.is_empty() + } + fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> { + for (j, item) in self.items.iter_mut().enumerate() { + if let Some(OrderTag(stream, order)) = item.order_tag { + if order > *self.order.get(&stream).unwrap().front().unwrap() { + continue; + } + } + + let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); + if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) { + let id = item.id; + let eos = item.data.eos(); + + let packet = bytes_or_err.map_err(|e| match e { + ReadExactError::Stream(err) => err, + _ => unreachable!(), + }); + + let is_err = packet.is_err(); + let data_frame = DataFrame::from_packet(packet, !eos); + item.sent += data_frame.data().len(); + + if eos || is_err { + // If item had an order tag, remove it from the corresponding ordering list + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_stream = self.order.get_mut(&stream).unwrap(); + assert_eq!(order_stream.pop_front(), Some(order)); + if order_stream.is_empty() { + self.order.remove(&stream); + } + } + // Remove item from sending queue + self.items.remove(j); + } else { + // Move item later in send queue to implement LAS scheduling + // (LAS = Least Attained Service) + for k in j..self.items.len() - 1 { + if self.items[k].sent >= self.items[k + 1].sent { + self.items.swap(k, k + 1); + } else { + break; + } + } + } + + return Poll::Ready((id, data_frame)); + } + } + + Poll::Pending + } + fn dump(&self, prio: u8) -> String { + self.items + .iter() + .map(|i| format!("[{} {} {:?} @{}]", prio, i.id, i.order_tag, i.sent)) + .collect::<Vec<_>>() + .join(" ") + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataFrame); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { + for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { + if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) { + if items_at_prio.is_empty() { + self.queue.items.remove(i); + } + return Poll::Ready(res); + } + } + // If the queue is empty, this futures is eternally pending. + // This is ok because we use it in a select with another future + // that can interrupt it. + Poll::Pending + } +} + +enum DataFrame { + /// a fixed size buffer containing some data + a boolean indicating whether + /// there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + Data(Bytes, bool), + /// An error code automatically signals the end of the stream + Error(Bytes), +} + +impl DataFrame { + fn from_packet(p: Packet, has_cont: bool) -> Self { + match p { + Ok(bytes) => { + assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize); + Self::Data(bytes, has_cont) + } + Err(e) => { + let mut buf = BytesMut::new(); + buf.put_u8(io_errorkind_to_u8(e.kind())); + + let msg = format!("{}", e).into_bytes(); + if msg.len() > (MAX_CHUNK_LENGTH - 1) as usize { + buf.put(&msg[..(MAX_CHUNK_LENGTH - 1) as usize]); + } else { + buf.put(&msg[..]); + } + + Self::Error(buf.freeze()) + } + } + } + + fn header(&self) -> [u8; 2] { + let header_u16 = match self { + DataFrame::Data(data, false) => data.len() as u16, + DataFrame::Data(data, true) => data.len() as u16 | CHUNK_FLAG_HAS_CONTINUATION, + DataFrame::Error(msg) => msg.len() as u16 | CHUNK_FLAG_ERROR, + }; + ChunkLength::to_be_bytes(header_u16) + } + + fn data(&self) -> &[u8] { + match self { + DataFrame::Data(ref data, _) => &data[..], + DataFrame::Error(ref msg) => &msg[..], + } + } +} + +/// The SendLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` +/// that takes a channel of messages to send and an asynchronous writer, +/// and sends messages from the channel to the async writer, putting them in a queue +/// before being sent and doing the round-robin sending strategy. +/// +/// The `.send_loop()` exits when the sending end of the channel is closed, +/// or if there is an error at any time writing to the async writer. +#[async_trait] +pub(crate) trait SendLoop: Sync { + async fn send_loop<W>( + self: Arc<Self>, + msg_recv: mpsc::UnboundedReceiver<SendItem>, + mut write: BoxStreamWrite<W>, + debug_name: String, + ) -> Result<(), Error> + where + W: AsyncWriteExt + Unpin + Send + Sync, + { + let mut sending = SendQueue::new(); + let mut msg_recv = Some(msg_recv); + while msg_recv.is_some() || !sending.is_empty() { + trace!( + "send_loop({}): queue = {:?}", + debug_name, + sending + .items + .iter() + .map(|(prio, i)| i.dump(*prio)) + .collect::<Vec<_>>() + .join(" ; ") + ); + + let recv_fut = async { + if let Some(chan) = &mut msg_recv { + chan.recv().await + } else { + futures::future::pending().await + } + }; + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + tokio::select! { + biased; // always read incomming channel first if it has data + sth = recv_fut => { + match sth { + Some(SendItem::Stream(id, prio, order_tag, data)) => { + trace!("send_loop({}): add stream {} to send", debug_name, id); + sending.push(SendQueueItem { + id, + prio, + order_tag, + data: ByteStreamReader::new(data), + sent: 0, + }) + } + Some(SendItem::Cancel(id)) => { + trace!("send_loop({}): cancelling {}", debug_name, id); + sending.remove(id); + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?; + write.flush().await?; + } + None => { + msg_recv = None; + } + }; + } + (id, data) = send_fut => { + trace!( + "send_loop({}): id {}, send {} bytes, header_size {}", + debug_name, + id, + data.data().len(), + hex::encode(data.header()) + ); + + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; + write.flush().await?; + } + } + } + + let _ = write.goodbye().await; + Ok(()) + } +} diff --git a/src/net/server.rs b/src/net/server.rs new file mode 100644 index 00000000..55b9e678 --- /dev/null +++ b/src/net/server.rs @@ -0,0 +1,222 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; + +use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use log::*; + +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::{handshake_server, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, watch}; +use tokio_util::compat::*; + +#[cfg(feature = "telemetry")] +use opentelemetry::{ + trace::{FutureExt, Span, SpanKind, TraceContextExt, TraceId, Tracer}, + Context, KeyValue, +}; +#[cfg(feature = "telemetry")] +use opentelemetry_contrib::trace::propagator::binary::*; +#[cfg(feature = "telemetry")] +use rand::{thread_rng, Rng}; + +use crate::error::*; +use crate::message::*; +use crate::netapp::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; +use crate::util::*; + +// The client and server connection structs (client.rs and server.rs) +// build upon the chunking mechanism which is exclusively contained +// in proto.rs. +// Here, we just care about sending big messages without size limit. +// The format of these messages is described below. +// Chunking happens independently. + +// Request message format (client -> server): +// - u8 priority +// - u8 path length +// - [u8; path length] path +// - [u8; *] data + +// Response message format (server -> client): +// - u8 response code +// - [u8; *] response + +pub(crate) struct ServerConn { + pub(crate) remote_addr: SocketAddr, + pub(crate) peer_id: NodeID, + + netapp: Arc<NetApp>, + + resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>, + running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>, +} + +impl ServerConn { + pub(crate) async fn run( + netapp: Arc<NetApp>, + socket: TcpStream, + must_exit: watch::Receiver<bool>, + ) -> Result<(), Error> { + let remote_addr = socket.peer_addr()?; + let mut socket = socket.compat(); + + // Do handshake to authenticate client + let handshake = handshake_server( + &mut socket, + netapp.netid.clone(), + netapp.id, + netapp.privkey.clone(), + ) + .await?; + let peer_id = handshake.peer_pk; + + debug!( + "Handshake complete (server) with {}@{}", + hex::encode(peer_id), + remote_addr + ); + + // Create BoxStream layer that encodes content + let (read, write) = socket.split(); + let (read, mut write) = + BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); + + // Before doing anything, send version tag, so that client + // can check and disconnect if version is wrong + write.write_all(&netapp.version_tag[..]).await?; + write.flush().await?; + + // Build and launch stuff that handles requests server-side + let (resp_send, resp_recv) = mpsc::unbounded_channel(); + + let conn = Arc::new(ServerConn { + netapp: netapp.clone(), + remote_addr, + peer_id, + resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))), + running_handlers: Mutex::new(HashMap::new()), + }); + + netapp.connected_as_server(peer_id, conn.clone()); + + let debug_name = format!("SRV {}", hex::encode(&peer_id[..8])); + let debug_name_2 = debug_name.clone(); + + let conn2 = conn.clone(); + let recv_future = tokio::spawn(async move { + select! { + r = conn2.recv_loop(read, debug_name_2) => r, + _ = await_exit(must_exit) => Ok(()) + } + }); + let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write, debug_name)); + + recv_future.await.log_err("ServerConn recv_loop"); + conn.resp_send.store(None); + send_future.await.log_err("ServerConn send_loop"); + + netapp.disconnected_as_server(&peer_id, conn); + + Ok(()) + } + + async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> { + let path = String::from_utf8(req_enc.path.to_vec())?; + + let handler_opt = { + let endpoints = self.netapp.endpoints.read().unwrap(); + endpoints.get(&path).map(|e| e.clone_endpoint()) + }; + + if let Some(handler) = handler_opt { + cfg_if::cfg_if! { + if #[cfg(feature = "telemetry")] { + let tracer = opentelemetry::global::tracer("netapp"); + + let mut span = if !req_enc.telemetry_id.is_empty() { + let propagator = BinaryPropagator::new(); + let context = propagator.from_bytes(req_enc.telemetry_id.to_vec()); + let context = Context::new().with_remote_span_context(context); + tracer.span_builder(format!(">> RPC {}", path)) + .with_kind(SpanKind::Server) + .start_with_context(&tracer, &context) + } else { + let mut rng = thread_rng(); + let trace_id = TraceId::from_bytes(rng.gen()); + tracer + .span_builder(format!(">> RPC {}", path)) + .with_kind(SpanKind::Server) + .with_trace_id(trace_id) + .start(&tracer) + }; + span.set_attribute(KeyValue::new("path", path.to_string())); + span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64)); + + handler.handle(req_enc, self.peer_id) + .with_context(Context::current_with_span(span)) + .await + } else { + handler.handle(req_enc, self.peer_id).await + } + } + } else { + Err(Error::NoHandler) + } + } +} + +impl SendLoop for ServerConn {} + +#[async_trait] +impl RecvLoop for ServerConn { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { + let resp_send = match self.resp_send.load_full() { + Some(c) => c, + None => return, + }; + + let mut rh = self.running_handlers.lock().unwrap(); + + let self2 = self.clone(); + let jh = tokio::spawn(async move { + debug!("server: recv_handler got {}", id); + + let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { + Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await), + Err(e) => (PRIO_HIGH, Err(e)), + }; + + debug!("server: sending response to {}", id); + + let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); + resp_send + .send(SendItem::Stream(id, prio, resp_order, resp_stream)) + .log_err("ServerConn recv_handler send resp bytes"); + + self2.running_handlers.lock().unwrap().remove(&id); + }); + + rh.insert(id, jh); + } + + fn cancel_handler(self: &Arc<Self>, id: RequestID) { + trace!("received cancel for request {}", id); + + // If the handler is still running, abort it now + if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) { + jh.abort(); + } + + // Inform the response sender that we don't need to send the response + if let Some(resp_send) = self.resp_send.load_full() { + let _ = resp_send.send(SendItem::Cancel(id)); + } + } +} diff --git a/src/net/stream.rs b/src/net/stream.rs new file mode 100644 index 00000000..88c3fed4 --- /dev/null +++ b/src/net/stream.rs @@ -0,0 +1,202 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; + +use futures::Future; +use futures::{Stream, StreamExt}; +use tokio::io::AsyncRead; + +use crate::bytes_buf::BytesBuf; + +/// A stream of bytes (click to read more). +/// +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// Items sent in the ByteStream may be errors of type `std::io::Error`. +/// An error indicates the end of the ByteStream: a reader should no longer read +/// after recieving an error, and a writer should stop writing after sending an error. +pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; + +/// A packet sent in a ByteStream, which may contain either +/// a Bytes object or an error +pub type Packet = Result<Bytes, std::io::Error>; + +// ---- + +/// A helper struct to read defined lengths of data from a BytesStream +pub struct ByteStreamReader { + stream: ByteStream, + buf: BytesBuf, + eos: bool, + err: Option<std::io::Error>, +} + +impl ByteStreamReader { + /// Creates a new `ByteStreamReader` from a `ByteStream` + pub fn new(stream: ByteStream) -> Self { + ByteStreamReader { + stream, + buf: BytesBuf::new(), + eos: false, + err: None, + } + } + + /// Read exactly `read_len` bytes from the underlying stream + /// (returns a future) + pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: true, + } + } + + /// Read at most `read_len` bytes from the underlying stream, or less + /// if the end of the stream is reached (returns a future) + pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: false, + } + } + + /// Read exactly one byte from the underlying stream and returns it + /// as an u8 + pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> { + Ok(self.read_exact(1).await?[0]) + } + + /// Read exactly two bytes from the underlying stream and returns them as an u16 (using + /// big-endian decoding) + pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> { + let bytes = self.read_exact(2).await?; + let mut b = [0u8; 2]; + b.copy_from_slice(&bytes[..]); + Ok(u16::from_be_bytes(b)) + } + + /// Read exactly four bytes from the underlying stream and returns them as an u32 (using + /// big-endian decoding) + pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> { + let bytes = self.read_exact(4).await?; + let mut b = [0u8; 4]; + b.copy_from_slice(&bytes[..]); + Ok(u32::from_be_bytes(b)) + } + + /// Transforms the stream reader back into the underlying stream (starting + /// after everything that the reader has read) + pub fn into_stream(self) -> ByteStream { + let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok)); + if let Some(err) = self.err { + Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) + } else if self.eos { + Box::pin(buf_stream) + } else { + Box::pin(buf_stream.chain(self.stream)) + } + } + + /// Tries to fill the internal read buffer from the underlying stream if it is empty. + /// Calling this might be necessary to ensure that `.eos()` returns a correct + /// result, otherwise the reader might not be aware that the underlying + /// stream has nothing left to return. + pub async fn fill_buffer(&mut self) { + if self.buf.is_empty() { + let packet = self.stream.next().await; + self.add_stream_next(packet); + } + } + + /// Clears the internal read buffer and returns its content + pub fn take_buffer(&mut self) -> Bytes { + self.buf.take_all() + } + + /// Returns true if the end of the underlying stream has been reached + pub fn eos(&self) -> bool { + self.buf.is_empty() && self.eos + } + + fn try_get(&mut self, read_len: usize) -> Option<Bytes> { + self.buf.take_exact(read_len) + } + + fn add_stream_next(&mut self, packet: Option<Packet>) { + match packet { + Some(Ok(slice)) => { + self.buf.extend(slice); + } + Some(Err(e)) => { + self.err = Some(e); + self.eos = true; + } + None => { + self.eos = true; + } + } + } +} + +/// The error kind that can be returned by `ByteStreamReader::read_exact` and +/// `ByteStreamReader::read_exact_or_eos` +pub enum ReadExactError { + /// The end of the stream was reached before the requested number of bytes could be read + UnexpectedEos, + /// The underlying data stream returned an IO error when trying to read + Stream(std::io::Error), +} + +/// The future returned by `ByteStreamReader::read_exact` and +/// `ByteStreamReader::read_exact_or_eos` +#[pin_project::pin_project] +pub struct ByteStreamReadExact<'a> { + #[pin] + reader: &'a mut ByteStreamReader, + read_len: usize, + fail_on_eos: bool, +} + +impl<'a> Future for ByteStreamReadExact<'a> { + type Output = Result<Bytes, ReadExactError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> { + let mut this = self.project(); + + loop { + if let Some(bytes) = this.reader.try_get(*this.read_len) { + return Poll::Ready(Ok(bytes)); + } + if let Some(err) = &this.reader.err { + let err = std::io::Error::new(err.kind(), format!("{}", err)); + return Poll::Ready(Err(ReadExactError::Stream(err))); + } + if this.reader.eos { + if *this.fail_on_eos { + return Poll::Ready(Err(ReadExactError::UnexpectedEos)); + } else { + return Poll::Ready(Ok(this.reader.take_buffer())); + } + } + + let next_packet = futures::ready!(this.reader.stream.as_mut().poll_next(cx)); + this.reader.add_stream_next(next_packet); + } + } +} + +// ---- + +/// Turns a `tokio::io::AsyncRead` asynchronous reader into a `ByteStream` +pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream { + Box::pin(tokio_util::io::ReaderStream::new(reader)) +} + +/// Turns a `ByteStream` into a `tokio::io::AsyncRead` asynchronous reader +pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { + tokio_util::io::StreamReader::new(stream) +} diff --git a/src/net/test.rs b/src/net/test.rs new file mode 100644 index 00000000..c6259752 --- /dev/null +++ b/src/net/test.rs @@ -0,0 +1,118 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use tokio::select; +use tokio::sync::watch; + +use sodiumoxide::crypto::auth; +use sodiumoxide::crypto::sign::ed25519; + +use crate::netapp::*; +use crate::peering::*; +use crate::NodeID; + +#[tokio::test(flavor = "current_thread")] +async fn test_with_basic_scheduler() { + pretty_env_logger::init(); + run_test(19980).await +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn test_with_threaded_scheduler() { + run_test(19990).await +} + +async fn run_test(port_base: u16) { + select! { + _ = run_test_inner(port_base) => (), + _ = tokio::time::sleep(Duration::from_secs(20)) => panic!("timeout"), + } +} + +async fn run_test_inner(port_base: u16) { + let netid = auth::gen_key(); + + let (pk1, sk1) = ed25519::gen_keypair(); + let (pk2, sk2) = ed25519::gen_keypair(); + let (pk3, sk3) = ed25519::gen_keypair(); + + let addr1: SocketAddr = SocketAddr::new("127.0.0.1".parse().unwrap(), port_base); + let addr2: SocketAddr = SocketAddr::new("127.0.0.1".parse().unwrap(), port_base + 1); + let addr3: SocketAddr = SocketAddr::new("127.0.0.1".parse().unwrap(), port_base + 2); + + let (stop_tx, stop_rx) = watch::channel(false); + + let (thread1, _netapp1, peering1) = + run_netapp(netid.clone(), pk1, sk1, addr1, vec![], stop_rx.clone()); + tokio::time::sleep(Duration::from_secs(2)).await; + + // Connect second node and check it peers with everyone + let (thread2, _netapp2, peering2) = run_netapp( + netid.clone(), + pk2, + sk2, + addr2, + vec![(pk1, addr1)], + stop_rx.clone(), + ); + tokio::time::sleep(Duration::from_secs(3)).await; + + let pl1 = peering1.get_peer_list(); + println!("A pl1: {:?}", pl1); + assert_eq!(pl1.len(), 2); + + let pl2 = peering2.get_peer_list(); + println!("A pl2: {:?}", pl2); + assert_eq!(pl2.len(), 2); + + // Connect third ndoe and check it peers with everyone + let (thread3, _netapp3, peering3) = + run_netapp(netid, pk3, sk3, addr3, vec![(pk2, addr2)], stop_rx.clone()); + tokio::time::sleep(Duration::from_secs(3)).await; + + let pl1 = peering1.get_peer_list(); + println!("B pl1: {:?}", pl1); + assert_eq!(pl1.len(), 3); + + let pl2 = peering2.get_peer_list(); + println!("B pl2: {:?}", pl2); + assert_eq!(pl2.len(), 3); + + let pl3 = peering3.get_peer_list(); + println!("B pl3: {:?}", pl3); + assert_eq!(pl3.len(), 3); + + // Send stop signal and wait for everyone to finish + stop_tx.send(true).unwrap(); + thread1.await.unwrap(); + thread2.await.unwrap(); + thread3.await.unwrap(); +} + +fn run_netapp( + netid: auth::Key, + _pk: NodeID, + sk: ed25519::SecretKey, + listen_addr: SocketAddr, + bootstrap_peers: Vec<(NodeID, SocketAddr)>, + must_exit: watch::Receiver<bool>, +) -> ( + tokio::task::JoinHandle<()>, + Arc<NetApp>, + Arc<PeeringManager>, +) { + let netapp = NetApp::new(0u64, netid, sk); + let peering = PeeringManager::new(netapp.clone(), bootstrap_peers, None); + + let peering2 = peering.clone(); + let netapp2 = netapp.clone(); + let fut = tokio::spawn(async move { + tokio::join!( + netapp2.listen(listen_addr, None, must_exit.clone()), + peering2.run(must_exit.clone()), + ); + }); + + (fut, netapp, peering) +} diff --git a/src/net/util.rs b/src/net/util.rs new file mode 100644 index 00000000..56230b73 --- /dev/null +++ b/src/net/util.rs @@ -0,0 +1,96 @@ +use std::net::SocketAddr; + +use log::info; +use serde::Serialize; + +use tokio::sync::watch; + +use crate::netapp::*; + +/// Utility function: encodes any serializable value in MessagePack binary format +/// using the RMP library. +/// +/// Field names and variant names are included in the serialization. +/// This is used internally by the netapp communication protocol. +pub fn rmp_to_vec_all_named<T>(val: &T) -> Result<Vec<u8>, rmp_serde::encode::Error> +where + T: Serialize + ?Sized, +{ + let mut wr = Vec::with_capacity(128); + let mut se = rmp_serde::Serializer::new(&mut wr).with_struct_map(); + val.serialize(&mut se)?; + Ok(wr) +} + +/// This async function returns only when a true signal was received +/// from a watcher that tells us when to exit. +/// +/// Usefull in a select statement to interrupt another +/// future: +/// ```ignore +/// select!( +/// _ = a_long_task() => Success, +/// _ = await_exit(must_exit) => Interrupted, +/// ) +/// ``` +pub async fn await_exit(mut must_exit: watch::Receiver<bool>) { + while !*must_exit.borrow_and_update() { + if must_exit.changed().await.is_err() { + break; + } + } +} + +/// Creates a watch that contains `false`, and that changes +/// to `true` when a Ctrl+C signal is received. +pub fn watch_ctrl_c() -> watch::Receiver<bool> { + let (send_cancel, watch_cancel) = watch::channel(false); + tokio::spawn(async move { + tokio::signal::ctrl_c() + .await + .expect("failed to install CTRL+C signal handler"); + info!("Received CTRL+C, shutting down."); + send_cancel.send(true).unwrap(); + }); + watch_cancel +} + +/// Parse a peer's address including public key, written in the format: +/// `<public key hex>@<ip>:<port>` +pub fn parse_peer_addr(peer: &str) -> Option<(NodeID, SocketAddr)> { + let delim = peer.find('@')?; + let (key, ip) = peer.split_at(delim); + let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?; + let ip = ip[1..].parse::<SocketAddr>().ok()?; + Some((pubkey, ip)) +} + +/// Parse and resolve a peer's address including public key, written in the format: +/// `<public key hex>@<ip or hostname>:<port>` +pub fn parse_and_resolve_peer_addr(peer: &str) -> Option<(NodeID, Vec<SocketAddr>)> { + use std::net::ToSocketAddrs; + + let delim = peer.find('@')?; + let (key, host) = peer.split_at(delim); + let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?; + let hosts = host[1..].to_socket_addrs().ok()?.collect::<Vec<_>>(); + if hosts.is_empty() { + return None; + } + Some((pubkey, hosts)) +} + +/// async version of parse_and_resolve_peer_addr +pub async fn parse_and_resolve_peer_addr_async(peer: &str) -> Option<(NodeID, Vec<SocketAddr>)> { + let delim = peer.find('@')?; + let (key, host) = peer.split_at(delim); + let pubkey = NodeID::from_slice(&hex::decode(key).ok()?)?; + let hosts = tokio::net::lookup_host(&host[1..]) + .await + .ok()? + .collect::<Vec<_>>(); + if hosts.is_empty() { + return None; + } + Some((pubkey, hosts)) +} |