diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/bytes_buf.rs | 173 | ||||
-rw-r--r-- | src/client.rs | 96 | ||||
-rw-r--r-- | src/endpoint.rs | 106 | ||||
-rw-r--r-- | src/error.rs | 69 | ||||
-rw-r--r-- | src/lib.rs | 8 | ||||
-rw-r--r-- | src/message.rs | 470 | ||||
-rw-r--r-- | src/netapp.rs | 19 | ||||
-rw-r--r-- | src/peering/basalt.rs | 6 | ||||
-rw-r--r-- | src/peering/fullmesh.rs | 11 | ||||
-rw-r--r-- | src/proto.rs | 358 | ||||
-rw-r--r-- | src/proto2.rs | 75 | ||||
-rw-r--r-- | src/recv.rs | 137 | ||||
-rw-r--r-- | src/send.rs | 300 | ||||
-rw-r--r-- | src/server.rs | 77 | ||||
-rw-r--r-- | src/stream.rs | 159 | ||||
-rw-r--r-- | src/test.rs | 1 | ||||
-rw-r--r-- | src/util.rs | 10 |
17 files changed, 1455 insertions, 620 deletions
diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs new file mode 100644 index 0000000..857be9d --- /dev/null +++ b/src/bytes_buf.rs @@ -0,0 +1,173 @@ +use std::collections::VecDeque; + +pub use bytes::Bytes; + +/// A circular buffer of bytes, internally represented as a list of Bytes +/// for optimization, but that for all intent and purposes acts just like +/// a big byte slice which can be extended on the right and from which +/// one can take on the left. +pub struct BytesBuf { + buf: VecDeque<Bytes>, + buf_len: usize, +} + +impl BytesBuf { + /// Creates a new empty BytesBuf + pub fn new() -> Self { + Self { + buf: VecDeque::new(), + buf_len: 0, + } + } + + /// Returns the number of bytes stored in the BytesBuf + #[inline] + pub fn len(&self) -> usize { + self.buf_len + } + + /// Returns true iff the BytesBuf contains zero bytes + #[inline] + pub fn is_empty(&self) -> bool { + self.buf_len == 0 + } + + /// Adds some bytes to the right of the buffer + pub fn extend(&mut self, b: Bytes) { + if !b.is_empty() { + self.buf_len += b.len(); + self.buf.push_back(b); + } + } + + /// Takes the whole content of the buffer and returns it as a single Bytes unit + pub fn take_all(&mut self) -> Bytes { + if self.buf.len() == 0 { + Bytes::new() + } else if self.buf.len() == 1 { + self.buf_len = 0; + self.buf.pop_back().unwrap() + } else { + let mut ret = Vec::with_capacity(self.buf_len); + for b in self.buf.iter() { + ret.extend(&b[..]); + } + self.buf.clear(); + self.buf_len = 0; + Bytes::from(ret) + } + } + + /// Takes at most max_len bytes from the left of the buffer + pub fn take_max(&mut self, max_len: usize) -> Bytes { + if self.buf_len <= max_len { + self.take_all() + } else { + self.take_exact_ok(max_len) + } + } + + /// Take exactly len bytes from the left of the buffer, returns None if + /// the BytesBuf doesn't contain enough data + pub fn take_exact(&mut self, len: usize) -> Option<Bytes> { + if self.buf_len < len { + None + } else { + Some(self.take_exact_ok(len)) + } + } + + fn take_exact_ok(&mut self, len: usize) -> Bytes { + assert!(len <= self.buf_len); + let front = self.buf.pop_front().unwrap(); + if front.len() > len { + self.buf.push_front(front.slice(len..)); + self.buf_len -= len; + front.slice(..len) + } else if front.len() == len { + self.buf_len -= len; + front + } else { + let mut ret = Vec::with_capacity(len); + ret.extend(&front[..]); + self.buf_len -= front.len(); + while ret.len() < len { + let front = self.buf.pop_front().unwrap(); + if front.len() > len - ret.len() { + let take = len - ret.len(); + ret.extend(front.slice(..take)); + self.buf.push_front(front.slice(take..)); + self.buf_len -= take; + break; + } else { + ret.extend(&front[..]); + self.buf_len -= front.len(); + } + } + Bytes::from(ret) + } + } + + /// Return the internal sequence of Bytes slices that make up the buffer + pub fn into_slices(self) -> VecDeque<Bytes> { + self.buf + } +} + +impl From<Bytes> for BytesBuf { + fn from(b: Bytes) -> BytesBuf { + let mut ret = BytesBuf::new(); + ret.extend(b); + ret + } +} + +impl From<BytesBuf> for Bytes { + fn from(mut b: BytesBuf) -> Bytes { + b.take_all() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_bytes_buf() { + let mut buf = BytesBuf::new(); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 13); + assert!(!buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!( + buf.take_all(), + Bytes::from(b"Hello, world!1234567890".to_vec()) + ); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + + buf.extend(Bytes::from(b"1234567890".to_vec())); + buf.extend(Bytes::from(b"Hello, world!".to_vec())); + assert!(buf.len() == 23); + assert!(!buf.is_empty()); + + assert_eq!(buf.take_max(12), Bytes::from(b"1234567890He".to_vec())); + assert!(buf.len() == 11); + + assert_eq!(buf.take_exact(12), None); + assert!(buf.len() == 11); + assert_eq!( + buf.take_exact(11), + Some(Bytes::from(b"llo, world!".to_vec())) + ); + assert!(buf.len() == 0); + assert!(buf.is_empty()); + } +} diff --git a/src/client.rs b/src/client.rs index 5c5a05b..9726125 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,12 +1,15 @@ -use std::borrow::Borrow; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use bytes::Bytes; use log::{debug, error, trace}; +use futures::io::AsyncReadExt; +use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -20,27 +23,22 @@ use opentelemetry::{ #[cfg(feature = "telemetry")] use opentelemetry_contrib::trace::propagator::binary::*; -use futures::io::AsyncReadExt; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_client, BoxStream}; - -use crate::endpoint::*; use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; use crate::util::*; pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, + query_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>, next_query_number: AtomicU32, - inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>, + inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>, } impl ClientConn { @@ -139,15 +137,14 @@ impl ClientConn { self.query_send.store(None); } - pub(crate) async fn call<T, B>( + pub(crate) async fn call<T>( self: Arc<Self>, - rq: B, + req: Req<T>, path: &str, prio: RequestPriority, - ) -> Result<<T as Message>::Response, Error> + ) -> Result<Resp<T>, Error> where T: Message, - B: Borrow<T>, { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; @@ -162,24 +159,16 @@ impl ClientConn { .with_kind(SpanKind::Client) .start(&tracer); let propagator = BinaryPropagator::new(); - let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec()); + let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into(); } else { - let telemetry_id: Option<Vec<u8>> = None; + let telemetry_id: Bytes = Bytes::new(); } }; // Encode request - let body = rmp_to_vec_all_named(rq.borrow())?; - drop(rq); - - let request = QueryMessage { - prio, - path: path.as_bytes(), - telemetry_id, - body: &body[..], - }; - let bytes = request.encode(); - drop(body); + let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); + let req_msg_len = req_enc.msg.len(); + let (req_stream, req_order) = req_enc.encode(); // Send request through let (resp_send, resp_recv) = oneshot::channel(); @@ -188,46 +177,37 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(vec![]).is_err() { - debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); - } + let _ = old_ch.send(Box::pin(futures::stream::once(async move { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "RequestID collision, too many inflight requests", + )) + }))); } - trace!("request: query_send {}, {} bytes", id, bytes.len()); + debug!( + "request: query_send {}, path {}, prio {} (serialized message: {} bytes)", + id, path, prio, req_msg_len + ); #[cfg(feature = "telemetry")] - span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, bytes))?; + query_send.send((id, prio, req_order, req_stream))?; cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let resp = resp_recv + let stream = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let resp = resp_recv.await?; + let stream = resp_recv.await?; } } - if resp.is_empty() { - return Err(Error::Message( - "Response is 0 bytes, either a collision or a protocol error".into(), - )); - } - - trace!("request response {}: ", id); - - let code = resp[0]; - if code == 0 { - Ok(rmp_serde::decode::from_read_ref::< - _, - <T as Message>::Response, - >(&resp[1..])?) - } else { - let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); - Err(Error::Remote(code, msg)) - } + let resp_enc = RespEnc::decode(stream).await?; + debug!("client: got response to request {} (path {})", id, path); + Resp::from_enc(resp_enc) } } @@ -235,12 +215,12 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) { - trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { + trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send(msg).is_err() { + if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 42e9a98..bb768de 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,33 +1,25 @@ -use std::borrow::Borrow; use std::marker::PhantomData; use std::sync::Arc; use arc_swap::ArcSwapOption; use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - use crate::error::Error; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::util::*; - -/// This trait should be implemented by all messages your application -/// wants to handle -pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync { - type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync; -} /// This trait should be implemented by an object of your application -/// that can handle a message of type `M`. +/// that can handle a message of type `M`, if it wishes to handle +/// streams attached to the request and/or to send back streams +/// attached to the response.. /// /// The handler object should be in an Arc, see `Endpoint::set_handler` #[async_trait] -pub trait EndpointHandler<M>: Send + Sync +pub trait StreamingEndpointHandler<M>: Send + Sync where M: Message, { - async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response; + async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>; } /// If one simply wants to use an endpoint in a client fashion, @@ -35,12 +27,41 @@ where /// use the unit type `()` as the handler type: /// it will panic if it is ever made to handle request. #[async_trait] -impl<M: Message + 'static> EndpointHandler<M> for () { +impl<M: Message> EndpointHandler<M> for () { async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response { panic!("This endpoint should not have a local handler."); } } +// ---- + +/// This trait should be implemented by an object of your application +/// that can handle a message of type `M`, in the cases where it doesn't +/// care about attached stream in the request nor in the response. +#[async_trait] +pub trait EndpointHandler<M>: Send + Sync +where + M: Message, +{ + async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> <M as Message>::Response; +} + +#[async_trait] +impl<T, M> StreamingEndpointHandler<M> for T +where + T: EndpointHandler<M>, + M: Message, +{ + async fn handle(self: &Arc<Self>, mut m: Req<M>, from: NodeID) -> Resp<M> { + // Immediately drop stream to avoid backpressure if a stream was sent + // (this will make all data sent to the stream be ignored immediately) + drop(m.take_stream()); + Resp::new(EndpointHandler::handle(self, m.msg(), from).await) + } +} + +// ---- + /// This struct represents an endpoint for message of type `M`. /// /// Creating a new endpoint is done by calling `NetApp::endpoint`. @@ -50,13 +71,13 @@ impl<M: Message + 'static> EndpointHandler<M> for () { /// An `Endpoint` is used both to send requests to remote nodes, /// and to specify the handler for such requests on the local node. /// The type `H` represents the type of the handler object for -/// endpoint messages (see `EndpointHandler`). +/// endpoint messages (see `StreamingEndpointHandler`). pub struct Endpoint<M, H> where M: Message, - H: EndpointHandler<M>, + H: StreamingEndpointHandler<M>, { - phantom: PhantomData<M>, + _phantom: PhantomData<M>, netapp: Arc<NetApp>, path: String, handler: ArcSwapOption<H>, @@ -65,11 +86,11 @@ where impl<M, H> Endpoint<M, H> where M: Message, - H: EndpointHandler<M>, + H: StreamingEndpointHandler<M>, { pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self { Self { - phantom: PhantomData::default(), + _phantom: PhantomData::default(), netapp, path, handler: ArcSwapOption::from(None), @@ -88,20 +109,22 @@ where } /// Call this endpoint on a remote node (or on the local node, - /// for that matter) - pub async fn call<B>( + /// for that matter). This function invokes the full version that + /// allows to attach a stream to the request and to + /// receive such a stream attached to the response. + pub async fn call_streaming<T>( &self, target: &NodeID, - req: B, + req: T, prio: RequestPriority, - ) -> Result<<M as Message>::Response, Error> + ) -> Result<Resp<M>, Error> where - B: Borrow<M>, + T: IntoReq<M>, { if *target == self.netapp.id { match self.handler.load_full() { None => Err(Error::NoHandler), - Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await), + Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await), } } else { let conn = self @@ -116,10 +139,22 @@ where "Not connected: {}", hex::encode(&target[..8]) ))), - Some(c) => c.call(req, self.path.as_str(), prio).await, + Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await, } } } + + /// Call this endpoint on a remote node. This function is the simplified + /// version that doesn't allow to have streams attached to the request + /// or the response; see `call_streaming` for the full version. + pub async fn call( + &self, + target: &NodeID, + req: M, + prio: RequestPriority, + ) -> Result<<M as Message>::Response, Error> { + Ok(self.call_streaming(target, req, prio).await?.into_msg()) + } } // ---- Internal stuff ---- @@ -128,7 +163,7 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>; #[async_trait] pub(crate) trait GenericEndpoint { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>; + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>; fn drop_handler(&self); fn clone_endpoint(&self) -> DynEndpoint; } @@ -137,22 +172,21 @@ pub(crate) trait GenericEndpoint { pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>) where M: Message, - H: EndpointHandler<M>; + H: StreamingEndpointHandler<M>; #[async_trait] impl<M, H> GenericEndpoint for EndpointArc<M, H> where - M: Message + 'static, - H: EndpointHandler<M> + 'static, + M: Message, + H: StreamingEndpointHandler<M> + 'static, { - async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> { + async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> { match self.0.handler.load_full() { None => Err(Error::NoHandler), Some(h) => { - let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?; - let res = h.handle(&req, from).await; - let res_bytes = rmp_to_vec_all_named(&res)?; - Ok(res_bytes) + let req = Req::from_enc(req_enc)?; + let res = h.handle(req, from).await; + Ok(res.into_enc()?) } } } diff --git a/src/error.rs b/src/error.rs index 99acdd1..c0aeeac 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ -use err_derive::Error; use std::io; +use err_derive::Error; use log::error; #[derive(Debug, Error)] @@ -25,6 +25,15 @@ pub enum Error { #[error(display = "UTF8 error: {}", _0)] UTF8(#[error(source)] std::string::FromUtf8Error), + #[error(display = "Framing protocol error")] + Framing, + + #[error(display = "Remote error ({:?}): {}", _0, _1)] + Remote(io::ErrorKind, String), + + #[error(display = "Request ID collision")] + IdCollision, + #[error(display = "{}", _0)] Message(String), @@ -36,28 +45,6 @@ pub enum Error { #[error(display = "Version mismatch: {}", _0)] VersionMismatch(String), - - #[error(display = "Remote error {}: {}", _0, _1)] - Remote(u8, String), -} - -impl Error { - pub fn code(&self) -> u8 { - match self { - Self::Io(_) => 100, - Self::TokioJoin(_) => 110, - Self::OneshotRecv(_) => 111, - Self::RMPEncode(_) => 10, - Self::RMPDecode(_) => 11, - Self::UTF8(_) => 12, - Self::NoHandler => 20, - Self::ConnectionClosed => 21, - Self::Handshake(_) => 30, - Self::VersionMismatch(_) => 31, - Self::Remote(c, _) => *c, - Self::Message(_) => 99, - } - } } impl<T> From<tokio::sync::watch::error::SendError<T>> for Error { @@ -101,3 +88,39 @@ where } } } + +// ---- Helpers for serializing I/O Errors + +pub(crate) fn u8_to_io_errorkind(v: u8) -> std::io::ErrorKind { + use std::io::ErrorKind; + match v { + 101 => ErrorKind::ConnectionAborted, + 102 => ErrorKind::BrokenPipe, + 103 => ErrorKind::WouldBlock, + 104 => ErrorKind::InvalidInput, + 105 => ErrorKind::InvalidData, + 106 => ErrorKind::TimedOut, + 107 => ErrorKind::Interrupted, + 108 => ErrorKind::UnexpectedEof, + 109 => ErrorKind::OutOfMemory, + 110 => ErrorKind::ConnectionReset, + _ => ErrorKind::Other, + } +} + +pub(crate) fn io_errorkind_to_u8(kind: std::io::ErrorKind) -> u8 { + use std::io::ErrorKind; + match kind { + ErrorKind::ConnectionAborted => 101, + ErrorKind::BrokenPipe => 102, + ErrorKind::WouldBlock => 103, + ErrorKind::InvalidInput => 104, + ErrorKind::InvalidData => 105, + ErrorKind::TimedOut => 106, + ErrorKind::Interrupted => 107, + ErrorKind::UnexpectedEof => 108, + ErrorKind::OutOfMemory => 109, + ErrorKind::ConnectionReset => 110, + _ => 100, + } +} @@ -13,21 +13,23 @@ //! about message priorization. //! Also check out the examples to learn how to use this crate. +pub mod bytes_buf; pub mod error; +pub mod stream; pub mod util; pub mod endpoint; -pub mod proto; +pub mod message; mod client; -mod proto2; +mod recv; +mod send; mod server; pub mod netapp; pub mod peering; pub use crate::netapp::*; -pub use util::{NetworkKey, NodeID, NodeKey}; #[cfg(test)] mod test; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..1834f28 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,470 @@ +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + +use bytes::{BufMut, Bytes, BytesMut}; +use rand::prelude::*; +use serde::{Deserialize, Serialize}; + +use futures::stream::StreamExt; + +use crate::error::*; +use crate::stream::*; +use crate::util::*; + +/// Priority of a request (click to read more about priorities). +/// +/// This priority value is used to priorize messages +/// in the send queue of the client, and their responses in the send queue of the +/// server. Lower values mean higher priority. +/// +/// This mechanism is usefull for messages bigger than the maximum chunk size +/// (set at `0x4000` bytes), such as large file transfers. +/// In such case, all of the messages in the send queue with the highest priority +/// will take turns to send individual chunks, in a round-robin fashion. +/// Once all highest priority messages are sent successfully, the messages with +/// the next highest priority will begin being sent in the same way. +/// +/// The same priority value is given to a request and to its associated response. +pub type RequestPriority = u8; + +/// Priority class: high +pub const PRIO_HIGH: RequestPriority = 0x20; +/// Priority class: normal +pub const PRIO_NORMAL: RequestPriority = 0x40; +/// Priority class: background +pub const PRIO_BACKGROUND: RequestPriority = 0x80; +/// Priority: primary among given class +pub const PRIO_PRIMARY: RequestPriority = 0x00; +/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) +pub const PRIO_SECONDARY: RequestPriority = 0x01; + +// ---- + +#[derive(Clone, Copy)] +pub struct OrderTagStream(u64); +#[derive(Clone, Copy, Serialize, Deserialize, Debug)] +pub struct OrderTag(pub(crate) u64, pub(crate) u64); + +impl OrderTag { + pub fn stream() -> OrderTagStream { + OrderTagStream(thread_rng().gen()) + } +} +impl OrderTagStream { + pub fn order(&self, order: u64) -> OrderTag { + OrderTag(self.0, order) + } +} + +// ---- + +/// This trait should be implemented by all messages your application +/// wants to handle +pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static { + type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static; +} + +// ---- + +/// The Req<M> is a helper object used to create requests and attach them +/// a stream of data. If the stream is a fixed Bytes and not a ByteStream, +/// Req<M> is cheaply clonable to allow the request to be sent to different +/// peers (Clone will panic if the stream is a ByteStream). +pub struct Req<M: Message> { + pub(crate) msg: Arc<M>, + pub(crate) msg_ser: Option<Bytes>, + pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option<OrderTag>, +} + +impl<M: Message> Req<M> { + pub fn new(v: M) -> Result<Self, Error> { + Ok(v.into_req()?) + } + + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { + Self { + stream: AttachedStream::Fixed(b), + ..self + } + } + + pub fn with_stream(self, b: ByteStream) -> Self { + Self { + stream: AttachedStream::Stream(b), + ..self + } + } + + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + + pub fn msg(&self) -> &M { + &self.msg + } + + pub fn take_stream(&mut self) -> Option<ByteStream> { + std::mem::replace(&mut self.stream, AttachedStream::None).into_stream() + } + + pub(crate) fn into_enc( + self, + prio: RequestPriority, + path: Bytes, + telemetry_id: Bytes, + ) -> ReqEnc { + ReqEnc { + prio, + path, + telemetry_id, + msg: self.msg_ser.unwrap(), + stream: self.stream.into_stream(), + order_tag: self.order_tag, + } + } + + pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> { + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Req { + msg: Arc::new(msg), + msg_ser: Some(enc.msg), + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) + } +} + +pub trait IntoReq<M: Message> { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>; + fn into_req_local(self) -> Req<M>; +} + +impl<M: Message> IntoReq<M> for M { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> { + let msg_ser = rmp_to_vec_all_named(&self)?; + Ok(Req { + msg: Arc::new(self), + msg_ser: Some(Bytes::from(msg_ser)), + stream: AttachedStream::None, + order_tag: None, + }) + } + fn into_req_local(self) -> Req<M> { + Req { + msg: Arc::new(self), + msg_ser: None, + stream: AttachedStream::None, + order_tag: None, + } + } +} + +impl<M: Message> IntoReq<M> for Req<M> { + fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> { + Ok(self) + } + fn into_req_local(self) -> Req<M> { + self + } +} + +impl<M: Message> Clone for Req<M> { + fn clone(&self) -> Self { + let stream = match &self.stream { + AttachedStream::None => AttachedStream::None, + AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()), + AttachedStream::Stream(_) => { + panic!("Cannot clone a Req<_> with a non-buffer attached stream") + } + }; + Self { + msg: self.msg.clone(), + msg_ser: self.msg_ser.clone(), + stream, + order_tag: self.order_tag, + } + } +} + +impl<M> fmt::Debug for Req<M> +where + M: Message + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Req[{:?}", self.msg)?; + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), + } + } +} + +// ---- + +/// The Resp<M> represents a full response from a RPC that may have +/// an attached stream. +pub struct Resp<M: Message> { + pub(crate) _phantom: PhantomData<M>, + pub(crate) msg: M::Response, + pub(crate) stream: AttachedStream, + pub(crate) order_tag: Option<OrderTag>, +} + +impl<M: Message> Resp<M> { + pub fn new(v: M::Response) -> Self { + Resp { + _phantom: Default::default(), + msg: v, + stream: AttachedStream::None, + order_tag: None, + } + } + + pub fn with_stream_from_buffer(self, b: Bytes) -> Self { + Self { + stream: AttachedStream::Fixed(b), + ..self + } + } + + pub fn with_stream(self, b: ByteStream) -> Self { + Self { + stream: AttachedStream::Stream(b), + ..self + } + } + + pub fn with_order_tag(self, order_tag: OrderTag) -> Self { + Self { + order_tag: Some(order_tag), + ..self + } + } + + pub fn msg(&self) -> &M::Response { + &self.msg + } + + pub fn into_msg(self) -> M::Response { + self.msg + } + + pub fn into_parts(self) -> (M::Response, Option<ByteStream>) { + (self.msg, self.stream.into_stream()) + } + + pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> { + Ok(RespEnc { + msg: rmp_to_vec_all_named(&self.msg)?.into(), + stream: self.stream.into_stream(), + order_tag: self.order_tag, + }) + } + + pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> { + let msg = rmp_serde::decode::from_read_ref(&enc.msg)?; + Ok(Self { + _phantom: Default::default(), + msg, + stream: enc + .stream + .map(AttachedStream::Stream) + .unwrap_or(AttachedStream::None), + order_tag: enc.order_tag, + }) + } +} + +impl<M> fmt::Debug for Resp<M> +where + M: Message, + <M as Message>::Response: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Resp[{:?}", self.msg)?; + match &self.stream { + AttachedStream::None => write!(f, "]"), + AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()), + AttachedStream::Stream(_) => write!(f, "; stream]"), + } + } +} + +// ---- + +pub(crate) enum AttachedStream { + None, + Fixed(Bytes), + Stream(ByteStream), +} + +impl AttachedStream { + pub fn into_stream(self) -> Option<ByteStream> { + match self { + AttachedStream::None => None, + AttachedStream::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))), + AttachedStream::Stream(s) => Some(s), + } + } +} + +// ---- ---- + +/// Encoding for requests into a ByteStream: +/// - priority: u8 +/// - path length: u8 +/// - path: [u8; path length] +/// - telemetry id length: u8 +/// - telemetry id: [u8; telemetry id length] +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +pub(crate) struct ReqEnc { + pub(crate) prio: RequestPriority, + pub(crate) path: Bytes, + pub(crate) telemetry_id: Bytes, + pub(crate) msg: Bytes, + pub(crate) stream: Option<ByteStream>, + pub(crate) order_tag: Option<OrderTag>, +} + +impl ReqEnc { + pub(crate) fn encode(self) -> (ByteStream, Option<OrderTag>) { + let mut buf = BytesMut::with_capacity( + self.path.len() + self.telemetry_id.len() + self.msg.len() + 16, + ); + + buf.put_u8(self.prio); + + buf.put_u8(self.path.len() as u8); + buf.put(self.path); + + buf.put_u8(self.telemetry_id.len() as u8); + buf.put(&self.telemetry_id[..]); + + buf.put_u32(self.msg.len() as u32); + + let header = buf.freeze(); + + let res_stream: ByteStream = if let Some(stream) = self.stream { + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream)) + } else { + Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)])) + }; + (res_stream, self.order_tag) + } + + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) + } + + async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); + + let prio = reader.read_u8().await?; + + let path_len = reader.read_u8().await?; + let path = reader.read_exact(path_len as usize).await?; + + let telemetry_id_len = reader.read_u8().await?; + let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?; + + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; + + Ok(Self { + prio, + path, + telemetry_id, + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} + +/// Encoding for responses into a ByteStream: +/// IF SUCCESS: +/// - 0: u8 +/// - msg len: u32 +/// - msg [u8; ..] +/// - the attached stream as the rest of the encoded stream +/// IF ERROR: +/// - message length + 1: u8 +/// - error code: u8 +/// - message: [u8; message_length] +pub(crate) struct RespEnc { + msg: Bytes, + stream: Option<ByteStream>, + order_tag: Option<OrderTag>, +} + +impl RespEnc { + pub(crate) fn encode(resp: Result<Self, Error>) -> (ByteStream, Option<OrderTag>) { + match resp { + Ok(Self { + msg, + stream, + order_tag, + }) => { + let mut buf = BytesMut::with_capacity(4); + buf.put_u32(msg.len() as u32); + let header = buf.freeze(); + + let res_stream: ByteStream = if let Some(stream) = stream { + Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream)) + } else { + Box::pin(futures::stream::iter([Ok(header), Ok(msg)])) + }; + (res_stream, order_tag) + } + Err(err) => { + let err = std::io::Error::new( + std::io::ErrorKind::Other, + format!("netapp error: {}", err), + ); + ( + Box::pin(futures::stream::once(async move { Err(err) })), + None, + ) + } + } + } + + pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> { + Self::decode_aux(stream) + .await + .map_err(read_exact_error_to_error) + } + + async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> { + let mut reader = ByteStreamReader::new(stream); + + let msg_len = reader.read_u32().await?; + let msg = reader.read_exact(msg_len as usize).await?; + + Ok(Self { + msg, + stream: Some(reader.into_stream()), + order_tag: None, + }) + } +} + +fn read_exact_error_to_error(e: ReadExactError) -> Error { + match e { + ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()), + ReadExactError::UnexpectedEos => Error::Framing, + } +} diff --git a/src/netapp.rs b/src/netapp.rs index e9efa2e..f1e14ed 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -20,9 +20,15 @@ use tokio::sync::{mpsc, watch}; use crate::client::*; use crate::endpoint::*; use crate::error::*; -use crate::proto::*; +use crate::message::*; use crate::server::*; -use crate::util::*; + +/// A node's identifier, which is also its public cryptographic key +pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; +/// A node's secret key +pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; +/// A network key +pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// Tag which is exchanged between client and server upon connection establishment /// to check that they are running compatible versions of Netapp, @@ -30,9 +36,9 @@ use crate::util::*; pub(crate) type VersionTag = [u8; 16]; /// Value of the Netapp version used in the version tag -pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 +pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005 -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub(crate) struct HelloMessage { pub server_addr: Option<IpAddr>, pub server_port: u16, @@ -152,7 +158,7 @@ impl NetApp { pub fn endpoint<M, H>(self: &Arc<Self>, path: String) -> Arc<Endpoint<M, H>> where M: Message + 'static, - H: EndpointHandler<M> + 'static, + H: StreamingEndpointHandler<M> + 'static, { let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), path.clone())); let endpoint_arc = EndpointArc(endpoint.clone()); @@ -397,13 +403,14 @@ impl NetApp { hello_endpoint .call( &conn.peer_id, - &HelloMessage { + HelloMessage { server_addr, server_port, }, PRIO_NORMAL, ) .await + .map(|_| ()) .log_err("Sending hello message"); }); } diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index 7f77995..310077f 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -14,8 +14,8 @@ use sodiumoxide::crypto::hash; use tokio::sync::watch; use crate::endpoint::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; use crate::NodeID; // -- Protocol messages -- @@ -331,7 +331,7 @@ impl Basalt { async fn do_pull(self: Arc<Self>, peer: NodeID) { match self .pull_endpoint - .call(&peer, &PullMessage {}, PRIO_NORMAL) + .call(&peer, PullMessage {}, PRIO_NORMAL) .await { Ok(resp) => { @@ -346,7 +346,7 @@ impl Basalt { async fn do_push(self: Arc<Self>, peer: NodeID) { let push_msg = self.make_push_message(); - match self.push_endpoint.call(&peer, &push_msg, PRIO_NORMAL).await { + match self.push_endpoint.call(&peer, push_msg, PRIO_NORMAL).await { Ok(_) => { trace!("KYEV PEXo {}", hex::encode(peer)); } diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index 208cfe4..7f1c065 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -17,7 +17,8 @@ use sodiumoxide::crypto::hash; use crate::endpoint::*; use crate::error::*; use crate::netapp::*; -use crate::proto::*; + +use crate::message::*; use crate::NodeID; const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30); @@ -29,7 +30,7 @@ const FAILED_PING_THRESHOLD: usize = 4; // -- Protocol messages -- -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] struct PingMessage { pub id: u64, pub peer_list_hash: hash::Digest, @@ -39,7 +40,7 @@ impl Message for PingMessage { type Response = PingMessage; } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] struct PeerListMessage { pub list: Vec<(NodeID, SocketAddr)>, } @@ -382,7 +383,7 @@ impl FullMeshPeeringStrategy { ping_time ); let ping_response = select! { - r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r, + r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r, _ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())), }; @@ -434,7 +435,7 @@ impl FullMeshPeeringStrategy { let pex_message = PeerListMessage { list: peer_list }; match self .peer_list_endpoint - .call(id, &pex_message, PRIO_BACKGROUND) + .call(id, pex_message, PRIO_BACKGROUND) .await { Err(e) => warn!("Error doing peer exchange: {}", e), diff --git a/src/proto.rs b/src/proto.rs deleted file mode 100644 index 8f7e70f..0000000 --- a/src/proto.rs +++ /dev/null @@ -1,358 +0,0 @@ -use std::collections::{HashMap, VecDeque}; -use std::fmt::Write; -use std::sync::Arc; - -use log::trace; - -use futures::{AsyncReadExt, AsyncWriteExt}; -use kuska_handshake::async_std::BoxStreamWrite; - -use tokio::sync::mpsc; - -use async_trait::async_trait; - -use crate::error::*; - -/// Priority of a request (click to read more about priorities). -/// -/// This priority value is used to priorize messages -/// in the send queue of the client, and their responses in the send queue of the -/// server. Lower values mean higher priority. -/// -/// This mechanism is usefull for messages bigger than the maximum chunk size -/// (set at `0x4000` bytes), such as large file transfers. -/// In such case, all of the messages in the send queue with the highest priority -/// will take turns to send individual chunks, in a round-robin fashion. -/// Once all highest priority messages are sent successfully, the messages with -/// the next highest priority will begin being sent in the same way. -/// -/// The same priority value is given to a request and to its associated response. -pub type RequestPriority = u8; - -/// Priority class: high -pub const PRIO_HIGH: RequestPriority = 0x20; -/// Priority class: normal -pub const PRIO_NORMAL: RequestPriority = 0x40; -/// Priority class: background -pub const PRIO_BACKGROUND: RequestPriority = 0x80; -/// Priority: primary among given class -pub const PRIO_PRIMARY: RequestPriority = 0x00; -/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`) -pub const PRIO_SECONDARY: RequestPriority = 0x01; - -// Messages are sent by chunks -// Chunk format: -// - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data - -pub(crate) type RequestID = u32; -type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; -const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; - -struct SendQueueItem { - id: RequestID, - prio: RequestPriority, - data: Vec<u8>, - cursor: usize, -} - -struct SendQueue { - items: VecDeque<(u8, VecDeque<SendQueueItem>)>, -} - -impl SendQueue { - fn new() -> Self { - Self { - items: VecDeque::with_capacity(64), - } - } - fn push(&mut self, item: SendQueueItem) { - let prio = item.prio; - let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { - Ok(i) => i, - Err(i) => { - self.items.insert(i, (prio, VecDeque::new())); - i - } - }; - self.items[pos_prio].1.push_back(item); - } - fn pop(&mut self) -> Option<SendQueueItem> { - match self.items.pop_front() { - None => None, - Some((prio, mut items_at_prio)) => { - let ret = items_at_prio.pop_front(); - if !items_at_prio.is_empty() { - self.items.push_front((prio, items_at_prio)); - } - ret.or_else(|| self.pop()) - } - } - } - fn is_empty(&self) -> bool { - self.items.iter().all(|(_k, v)| v.is_empty()) - } - fn dump(&self) -> String { - let mut ret = String::new(); - for (prio, q) in self.items.iter() { - for item in q.iter() { - write!( - &mut ret, - " [{} {} ({})]", - prio, - item.data.len() - item.cursor, - item.id - ) - .unwrap(); - } - } - ret - } -} - -/// The SendLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` -/// that takes a channel of messages to send and an asynchronous writer, -/// and sends messages from the channel to the async writer, putting them in a queue -/// before being sent and doing the round-robin sending strategy. -/// -/// The `.send_loop()` exits when the sending end of the channel is closed, -/// or if there is an error at any time writing to the async writer. -#[async_trait] -pub(crate) trait SendLoop: Sync { - async fn send_loop<W>( - self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>, - mut write: BoxStreamWrite<W>, - debug_name: String, - ) -> Result<(), Error> - where - W: AsyncWriteExt + Unpin + Send + Sync, - { - let mut sending = SendQueue::new(); - let mut should_exit = false; - while !should_exit || !sending.is_empty() { - trace!("send_loop({}): queue = {}", debug_name, sending.dump()); - if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!( - "send_loop({}): new message to send, id = {}, prio = {}, {} bytes", - debug_name, - id, - prio, - data.len() - ); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); - } else if let Some(mut item) = sending.pop() { - trace!( - "send_loop({}): sending bytes for {} ({} bytes, {} already sent)", - debug_name, - item.id, - item.data.len(), - item.cursor - ); - let header_id = RequestID::to_be_bytes(item.id); - write.write_all(&header_id[..]).await?; - - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { - let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); - write.write_all(&size_header[..]).await?; - - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; - - sending.push(item); - } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); - write.write_all(&size_header[..]).await?; - - write.write_all(&item.data[item.cursor..]).await?; - } - write.flush().await?; - } else { - let sth = msg_recv.recv().await; - if let Some((id, prio, data)) = sth { - trace!( - "send_loop({}): new message to send, id = {}, prio = {}, {} bytes", - debug_name, - id, - prio, - data.len() - ); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); - } else { - should_exit = true; - } - } - } - - let _ = write.goodbye().await; - Ok(()) - } -} - -/// The RecvLoop trait, which is implemented both by the client and the server -/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` -/// and a prototype of a handler for received messages `.recv_handler()` that -/// must be filled by implementors. `.recv_loop()` receives messages in a loop -/// according to the protocol defined above: chunks of message in progress of being -/// received are stored in a buffer, and when the last chunk of a message is received, -/// the full message is passed to the receive handler. -#[async_trait] -pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>); - - async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error> - where - R: AsyncReadExt + Unpin + Send + Sync, - { - let mut receiving = HashMap::new(); - loop { - let mut header_id = [0u8; RequestID::BITS as usize / 8]; - match read.read_exact(&mut header_id[..]).await { - Ok(_) => (), - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, - Err(e) => return Err(e.into()), - }; - let id = RequestID::from_be_bytes(header_id); - - let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; - read.read_exact(&mut header_size[..]).await?; - let size = ChunkLength::from_be_bytes(header_size); - trace!( - "recv_loop({}): got header id = {}, size = 0x{:04x} ({} bytes)", - debug_name, - id, - size, - size & !CHUNK_HAS_CONTINUATION - ); - - let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; - let size = size & !CHUNK_HAS_CONTINUATION; - - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop({}): read {} bytes", debug_name, next_slice.len()); - - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); - - if has_cont { - receiving.insert(id, msg_bytes); - } else { - self.recv_handler(id, msg_bytes); - } - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_priority_queue() { - let i1 = SendQueueItem { - id: 1, - prio: PRIO_NORMAL, - data: vec![], - cursor: 0, - }; - let i2 = SendQueueItem { - id: 2, - prio: PRIO_HIGH, - data: vec![], - cursor: 0, - }; - let i2bis = SendQueueItem { - id: 20, - prio: PRIO_HIGH, - data: vec![], - cursor: 0, - }; - let i3 = SendQueueItem { - id: 3, - prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, - }; - let i4 = SendQueueItem { - id: 4, - prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, - }; - let i5 = SendQueueItem { - id: 5, - prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, - }; - - let mut q = SendQueue::new(); - - q.push(i1); // 1 - let a = q.pop().unwrap(); // empty -> 1 - assert_eq!(a.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 1 - q.push(i2); // 2 1 - q.push(i2bis); // [2 20] 1 - let a = q.pop().unwrap(); // 20 1 -> 2 - assert_eq!(a.id, 2); - let b = q.pop().unwrap(); // 1 -> 20 - assert_eq!(b.id, 20); - let c = q.pop().unwrap(); // empty -> 1 - assert_eq!(c.id, 1); - assert!(q.pop().is_none()); - - q.push(a); // 2 - q.push(b); // [2 20] - q.push(c); // [2 20] 1 - q.push(i3); // [2 20] 3 1 - q.push(i4); // [2 20] 3 1 4 - q.push(i5); // [2 20] 3 1 5 4 - - let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2 - assert_eq!(a.id, 2); - q.push(a); // [20 2] 3 1 5 4 - - let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20 - assert_eq!(a.id, 20); - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - q.push(b); // 2 3 1 5 4 - let b = q.pop().unwrap(); // 3 1 5 4 -> 2 - assert_eq!(b.id, 2); - let c = q.pop().unwrap(); // 1 5 4 -> 3 - assert_eq!(c.id, 3); - q.push(b); // 2 1 5 4 - let b = q.pop().unwrap(); // 1 5 4 -> 2 - assert_eq!(b.id, 2); - let e = q.pop().unwrap(); // 5 4 -> 1 - assert_eq!(e.id, 1); - let f = q.pop().unwrap(); // 4 -> 5 - assert_eq!(f.id, 5); - let g = q.pop().unwrap(); // empty -> 4 - assert_eq!(g.id, 4); - assert!(q.pop().is_none()); - } -} diff --git a/src/proto2.rs b/src/proto2.rs deleted file mode 100644 index 7210781..0000000 --- a/src/proto2.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::error::*; -use crate::proto::*; - -pub(crate) struct QueryMessage<'a> { - pub(crate) prio: RequestPriority, - pub(crate) path: &'a [u8], - pub(crate) telemetry_id: Option<Vec<u8>>, - pub(crate) body: &'a [u8], -} - -/// QueryMessage encoding: -/// - priority: u8 -/// - path length: u8 -/// - path: [u8; path length] -/// - telemetry id length: u8 -/// - telemetry id: [u8; telemetry id length] -/// - body [u8; ..] -impl<'a> QueryMessage<'a> { - pub(crate) fn encode(self) -> Vec<u8> { - let tel_len = match &self.telemetry_id { - Some(t) => t.len(), - None => 0, - }; - - let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len()); - - ret.push(self.prio); - - ret.push(self.path.len() as u8); - ret.extend_from_slice(self.path); - - if let Some(t) = self.telemetry_id { - ret.push(t.len() as u8); - ret.extend(t); - } else { - ret.push(0u8); - } - - ret.extend_from_slice(self.body); - - ret - } - - pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> { - if bytes.len() < 3 { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path_length = bytes[1] as usize; - if bytes.len() < 3 + path_length { - return Err(Error::Message("Invalid protocol message".into())); - } - - let telemetry_id_len = bytes[2 + path_length] as usize; - if bytes.len() < 3 + path_length + telemetry_id_len { - return Err(Error::Message("Invalid protocol message".into())); - } - - let path = &bytes[2..2 + path_length]; - let telemetry_id = if telemetry_id_len > 0 { - Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec()) - } else { - None - }; - - let body = &bytes[3 + path_length + telemetry_id_len..]; - - Ok(Self { - prio: bytes[0], - path, - telemetry_id, - body, - }) - } -} diff --git a/src/recv.rs b/src/recv.rs new file mode 100644 index 0000000..ac93c4b --- /dev/null +++ b/src/recv.rs @@ -0,0 +1,137 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::Bytes; +use log::*; + +use futures::AsyncReadExt; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::send::*; +use crate::stream::*; + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: Option<mpsc::UnboundedSender<Packet>>, +} + +impl Sender { + fn new(inner: mpsc::UnboundedSender<Packet>) -> Self { + Sender { inner: Some(inner) } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.as_ref().unwrap().send(packet); + } + + fn end(&mut self) { + self.inner = None; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + let _ = inner.send(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Netapp connection dropped before end of stream", + ))); + } + } +} + +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream); + + async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error> + where + R: AsyncReadExt + Unpin + Send + Sync, + { + let mut streams: HashMap<RequestID, Sender> = HashMap::new(); + loop { + trace!( + "recv_loop({}): in_progress = {:?}", + debug_name, + streams.iter().map(|(id, _)| id).collect::<Vec<_>>() + ); + + let mut header_id = [0u8; RequestID::BITS as usize / 8]; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; + let id = RequestID::from_be_bytes(header_id); + + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; + read.read_exact(&mut header_size[..]).await?; + let size = ChunkLength::from_be_bytes(header_size); + + let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; + let is_error = (size & ERROR_MARKER) != 0; + let size = (size & CHUNK_LENGTH_MASK) as usize; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + + let packet = if is_error { + let kind = u8_to_io_errorkind(next_slice[0]); + let msg = + std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>"); + debug!("recv_loop({}): got id {}, error {:?}: {}", debug_name, id, kind, msg); + Some(Err(std::io::Error::new(kind, msg.to_string()))) + } else { + trace!( + "recv_loop({}): got id {}, size {}, has_cont {}", + debug_name, + id, + size, + has_cont + ); + if !next_slice.is_empty() { + Some(Ok(Bytes::from(next_slice))) + } else { + None + } + }; + + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = mpsc::unbounded_channel(); + trace!("recv_loop({}): id {} is new channel", debug_name, id); + self.recv_handler( + id, + Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(recv)), + ); + Sender::new(send) + }; + + if let Some(packet) = packet { + // If we cannot put packet in channel, it means that the + // receiving end of the channel is disconnected. + // We still need to reach eos before dropping this sender + let _ = sender.send(packet); + } + + if has_cont { + assert!(!is_error); + streams.insert(id, sender); + } else { + trace!("recv_loop({}): close channel id {}", debug_name, id); + sender.end(); + } + } + Ok(()) + } +} diff --git a/src/send.rs b/src/send.rs new file mode 100644 index 0000000..3b01cb5 --- /dev/null +++ b/src/send.rs @@ -0,0 +1,300 @@ +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; +use log::*; + +use futures::{AsyncWriteExt, Future}; +use kuska_handshake::async_std::BoxStreamWrite; +use tokio::sync::mpsc; + +use crate::error::*; +use crate::message::*; +use crate::stream::*; + +// Messages are sent by chunks +// Chunk format: +// - u32 BE: request id (same for request and response) +// - u16 BE: chunk length + flags: +// CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream +// ERROR_MARKER if this chunk denotes an error +// (these two flags are exclusive, an error denotes the end of the stream) +// - [u8; chunk_length], either +// - if not error: chunk data +// - if error: +// - u8: error kind, encoded using error::io_errorkind_to_u8 +// - rest: error message + +pub(crate) type RequestID = u32; +pub(crate) type ChunkLength = u16; + +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; +pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; +pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; +pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; + +pub(crate) type SendStream = (RequestID, RequestPriority, Option<OrderTag>, ByteStream); + +struct SendQueue { + items: Vec<(u8, SendQueuePriority)>, +} + +struct SendQueuePriority { + items: VecDeque<SendQueueItem>, + order: HashMap<u64, VecDeque<u64>>, +} + +struct SendQueueItem { + id: RequestID, + prio: RequestPriority, + order_tag: Option<OrderTag>, + data: ByteStreamReader, +} + +impl SendQueue { + fn new() -> Self { + Self { + items: Vec::with_capacity(64), + } + } + fn push(&mut self, item: SendQueueItem) { + let prio = item.prio; + let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { + Ok(i) => i, + Err(i) => { + self.items.insert(i, (prio, SendQueuePriority::new())); + i + } + }; + self.items[pos_prio].1.push(item); + } + fn is_empty(&self) -> bool { + self.items.iter().all(|(_k, v)| v.is_empty()) + } + + // this is like an async fn, but hand implemented + fn next_ready(&mut self) -> SendQueuePollNextReady<'_> { + SendQueuePollNextReady { queue: self } + } +} + +impl SendQueuePriority { + fn new() -> Self { + Self { + items: VecDeque::new(), + order: HashMap::new(), + } + } + fn push(&mut self, item: SendQueueItem) { + if let Some(OrderTag(stream, order)) = item.order_tag { + let order_vec = self.order.entry(stream).or_default(); + let i = order_vec.iter().take_while(|o2| **o2 < order).count(); + order_vec.insert(i, order); + } + self.items.push_back(item); + } + fn is_empty(&self) -> bool { + self.items.is_empty() + } + fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> { + for (j, item) in self.items.iter_mut().enumerate() { + if let Some(OrderTag(stream, order)) = item.order_tag { + if order > *self.order.get(&stream).unwrap().front().unwrap() { + continue; + } + } + + let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize); + if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) { + let id = item.id; + let eos = item.data.eos(); + + let packet = bytes_or_err.map_err(|e| match e { + ReadExactError::Stream(err) => err, + _ => unreachable!(), + }); + + if eos || packet.is_err() { + if let Some(OrderTag(stream, order)) = item.order_tag { + assert_eq!( + self.order.get_mut(&stream).unwrap().pop_front(), + Some(order) + ) + } + self.items.remove(j); + } + + let data_frame = DataFrame::from_packet(packet, !eos); + + return Poll::Ready((id, data_frame)); + } + } + + Poll::Pending + } + fn dump(&self, prio: u8) -> String { + self.items + .iter() + .map(|i| format!("[{} {} {:?}]", prio, i.id, i.order_tag)) + .collect::<Vec<_>>() + .join(" ") + } +} + +struct SendQueuePollNextReady<'a> { + queue: &'a mut SendQueue, +} + +impl<'a> futures::Future for SendQueuePollNextReady<'a> { + type Output = (RequestID, DataFrame); + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { + for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() { + if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) { + if items_at_prio.is_empty() { + self.queue.items.remove(i); + } + return Poll::Ready(res); + } + } + // If the queue is empty, this futures is eternally pending. + // This is ok because we use it in a select with another future + // that can interrupt it. + Poll::Pending + } +} + +enum DataFrame { + /// a fixed size buffer containing some data + a boolean indicating whether + /// there may be more data comming from this stream. Can be used for some + /// optimization. It's an error to set it to false if there is more data, but it is correct + /// (albeit sub-optimal) to set it to true if there is nothing coming after + Data(Bytes, bool), + /// An error code automatically signals the end of the stream + Error(Bytes), +} + +impl DataFrame { + fn from_packet(p: Packet, has_cont: bool) -> Self { + match p { + Ok(bytes) => { + assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize); + Self::Data(bytes, has_cont) + } + Err(e) => { + let mut buf = BytesMut::new(); + buf.put_u8(io_errorkind_to_u8(e.kind())); + + let msg = format!("{}", e).into_bytes(); + if msg.len() > (MAX_CHUNK_LENGTH - 1) as usize { + buf.put(&msg[..(MAX_CHUNK_LENGTH - 1) as usize]); + } else { + buf.put(&msg[..]); + } + + Self::Error(buf.freeze()) + } + } + } + + fn header(&self) -> [u8; 2] { + let header_u16 = match self { + DataFrame::Data(data, false) => data.len() as u16, + DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION, + DataFrame::Error(msg) => msg.len() as u16 | ERROR_MARKER, + }; + ChunkLength::to_be_bytes(header_u16) + } + + fn data(&self) -> &[u8] { + match self { + DataFrame::Data(ref data, _) => &data[..], + DataFrame::Error(ref msg) => &msg[..], + } + } +} + +/// The SendLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()` +/// that takes a channel of messages to send and an asynchronous writer, +/// and sends messages from the channel to the async writer, putting them in a queue +/// before being sent and doing the round-robin sending strategy. +/// +/// The `.send_loop()` exits when the sending end of the channel is closed, +/// or if there is an error at any time writing to the async writer. +#[async_trait] +pub(crate) trait SendLoop: Sync { + async fn send_loop<W>( + self: Arc<Self>, + msg_recv: mpsc::UnboundedReceiver<SendStream>, + mut write: BoxStreamWrite<W>, + debug_name: String, + ) -> Result<(), Error> + where + W: AsyncWriteExt + Unpin + Send + Sync, + { + let mut sending = SendQueue::new(); + let mut msg_recv = Some(msg_recv); + while msg_recv.is_some() || !sending.is_empty() { + trace!( + "send_loop({}): queue = {:?}", + debug_name, + sending + .items + .iter() + .map(|(prio, i)| i.dump(*prio)) + .collect::<Vec<_>>() + ); + + let recv_fut = async { + if let Some(chan) = &mut msg_recv { + chan.recv().await + } else { + futures::future::pending().await + } + }; + let send_fut = sending.next_ready(); + + // recv_fut is cancellation-safe according to tokio doc, + // send_fut is cancellation-safe as implemented above? + tokio::select! { + biased; // always read incomming channel first if it has data + sth = recv_fut => { + if let Some((id, prio, order_tag, data)) = sth { + trace!("send_loop({}): add stream {} to send", debug_name, id); + sending.push(SendQueueItem { + id, + prio, + order_tag, + data: ByteStreamReader::new(data), + }); + } else { + msg_recv = None; + }; + } + (id, data) = send_fut => { + trace!( + "send_loop({}): id {}, send {} bytes, header_size {}", + debug_name, + id, + data.data().len(), + hex::encode(data.header()) + ); + + let header_id = RequestID::to_be_bytes(id); + write.write_all(&header_id[..]).await?; + + write.write_all(&data.header()).await?; + write.write_all(data.data()).await?; + write.flush().await?; + } + } + } + + let _ = write.goodbye().await; + Ok(()) + } +} diff --git a/src/server.rs b/src/server.rs index a835959..2c12d9d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,8 +2,15 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwapOption; -use bytes::Bytes; -use log::{debug, trace}; +use async_trait::async_trait; +use log::*; + +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::{handshake_server, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, watch}; +use tokio_util::compat::*; #[cfg(feature = "telemetry")] use opentelemetry::{ @@ -15,21 +22,12 @@ use opentelemetry_contrib::trace::propagator::binary::*; #[cfg(feature = "telemetry")] use rand::{thread_rng, Rng}; -use tokio::net::TcpStream; -use tokio::select; -use tokio::sync::{mpsc, watch}; -use tokio_util::compat::*; - -use futures::io::{AsyncReadExt, AsyncWriteExt}; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_server, BoxStream}; - use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -55,7 +53,7 @@ pub(crate) struct ServerConn { netapp: Arc<NetApp>, - resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, + resp_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>, } impl ServerConn { @@ -126,13 +124,12 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { - let msg = QueryMessage::decode(bytes)?; - let path = String::from_utf8(msg.path.to_vec())?; + async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> { + let path = String::from_utf8(req_enc.path.to_vec())?; let handler_opt = { let endpoints = self.netapp.endpoints.read().unwrap(); - endpoints.get(&path).map(|e| e.clone_endpoint()) + endpoints.get(&path[..]).map(|e| e.clone_endpoint()) }; if let Some(handler) = handler_opt { @@ -140,9 +137,9 @@ impl ServerConn { if #[cfg(feature = "telemetry")] { let tracer = opentelemetry::global::tracer("netapp"); - let mut span = if let Some(telemetry_id) = msg.telemetry_id { + let mut span = if !req_enc.telemetry_id.is_empty() { let propagator = BinaryPropagator::new(); - let context = propagator.from_bytes(telemetry_id); + let context = propagator.from_bytes(req_enc.telemetry_id.to_vec()); let context = Context::new().with_remote_span_context(context); tracer.span_builder(format!(">> RPC {}", path)) .with_kind(SpanKind::Server) @@ -157,13 +154,13 @@ impl ServerConn { .start(&tracer) }; span.set_attribute(KeyValue::new("path", path.to_string())); - span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64)); - handler.handle(msg.body, self.peer_id) + handler.handle(req_enc, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, self.peer_id).await + handler.handle(req_enc, self.peer_id).await } } } else { @@ -176,35 +173,25 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); tokio::spawn(async move { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); - - let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..]).await; + debug!("server: recv_handler got {}", id); - let resp_bytes = match resp { - Ok(rb) => { - let mut resp_bytes = vec![0u8]; - resp_bytes.extend(rb); - resp_bytes - } - Err(e) => { - let mut resp_bytes = vec![e.code()]; - resp_bytes.extend(e.to_string().into_bytes()); - resp_bytes - } + let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { + Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await), + Err(e) => (PRIO_HIGH, Err(e)), }; - trace!("ServerConn sending response to {}: ", id); + debug!("server: sending response to {}", id); + let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_bytes)) - .log_err("ServerConn recv_handler send resp"); + .send((id, prio, resp_order, resp_stream)) + .log_err("ServerConn recv_handler send resp bytes"); + Ok::<_, Error>(()) }); } } diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..efa0ebc --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,159 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; + +use futures::Future; +use futures::{Stream, StreamExt}; +use tokio::io::AsyncRead; + +use crate::bytes_buf::BytesBuf; + +/// A stream of associated data. +/// +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; + +pub type Packet = Result<Bytes, std::io::Error>; + +// ---- + +pub struct ByteStreamReader { + stream: ByteStream, + buf: BytesBuf, + eos: bool, + err: Option<std::io::Error>, +} + +impl ByteStreamReader { + pub fn new(stream: ByteStream) -> Self { + ByteStreamReader { + stream, + buf: BytesBuf::new(), + eos: false, + err: None, + } + } + + pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: true, + } + } + + pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: false, + } + } + + pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> { + Ok(self.read_exact(1).await?[0]) + } + + pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> { + let bytes = self.read_exact(2).await?; + let mut b = [0u8; 2]; + b.copy_from_slice(&bytes[..]); + Ok(u16::from_be_bytes(b)) + } + + pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> { + let bytes = self.read_exact(4).await?; + let mut b = [0u8; 4]; + b.copy_from_slice(&bytes[..]); + Ok(u32::from_be_bytes(b)) + } + + pub fn into_stream(self) -> ByteStream { + let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok)); + if let Some(err) = self.err { + Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) + } else if self.eos { + Box::pin(buf_stream) + } else { + Box::pin(buf_stream.chain(self.stream)) + } + } + + pub fn take_buffer(&mut self) -> Bytes { + self.buf.take_all() + } + + pub fn eos(&self) -> bool { + self.buf.is_empty() && self.eos + } + + fn try_get(&mut self, read_len: usize) -> Option<Bytes> { + self.buf.take_exact(read_len) + } +} + +pub enum ReadExactError { + UnexpectedEos, + Stream(std::io::Error), +} + +#[pin_project::pin_project] +pub struct ByteStreamReadExact<'a> { + #[pin] + reader: &'a mut ByteStreamReader, + read_len: usize, + fail_on_eos: bool, +} + +impl<'a> Future for ByteStreamReadExact<'a> { + type Output = Result<Bytes, ReadExactError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> { + let mut this = self.project(); + + loop { + if let Some(bytes) = this.reader.try_get(*this.read_len) { + return Poll::Ready(Ok(bytes)); + } + if let Some(err) = &this.reader.err { + let err = std::io::Error::new(err.kind(), format!("{}", err)); + return Poll::Ready(Err(ReadExactError::Stream(err))); + } + if this.reader.eos { + if *this.fail_on_eos { + return Poll::Ready(Err(ReadExactError::UnexpectedEos)); + } else { + return Poll::Ready(Ok(this.reader.take_buffer())); + } + } + + match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { + Some(Ok(slice)) => { + this.reader.buf.extend(slice); + } + Some(Err(e)) => { + this.reader.err = Some(e); + this.reader.eos = true; + } + None => { + this.reader.eos = true; + } + } + } + } +} + +// ---- + +pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream { + Box::pin(tokio_util::io::ReaderStream::new(reader)) +} + +pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { + tokio_util::io::StreamReader::new(stream) +} diff --git a/src/test.rs b/src/test.rs index 82c7ba6..ecd5450 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,6 +14,7 @@ use crate::NodeID; #[tokio::test(flavor = "current_thread")] async fn test_with_basic_scheduler() { + pretty_env_logger::init(); run_test().await } diff --git a/src/util.rs b/src/util.rs index f4dfac7..425d26f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,18 +1,12 @@ use std::net::SocketAddr; use std::net::ToSocketAddrs; -use serde::Serialize; - use log::info; +use serde::Serialize; use tokio::sync::watch; -/// A node's identifier, which is also its public cryptographic key -pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; -/// A node's secret key -pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey; -/// A network key -pub type NetworkKey = sodiumoxide::crypto::auth::Key; +use crate::netapp::*; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. |