aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/bytes_buf.rs173
-rw-r--r--src/client.rs158
-rw-r--r--src/endpoint.rs106
-rw-r--r--src/error.rs69
-rw-r--r--src/lib.rs8
-rw-r--r--src/message.rs472
-rw-r--r--src/netapp.rs19
-rw-r--r--src/peering/basalt.rs6
-rw-r--r--src/peering/fullmesh.rs11
-rw-r--r--src/proto.rs358
-rw-r--r--src/proto2.rs75
-rw-r--r--src/recv.rs153
-rw-r--r--src/send.rs334
-rw-r--r--src/server.rs110
-rw-r--r--src/stream.rs169
-rw-r--r--src/test.rs1
-rw-r--r--src/util.rs10
17 files changed, 1609 insertions, 623 deletions
diff --git a/src/bytes_buf.rs b/src/bytes_buf.rs
new file mode 100644
index 0000000..857be9d
--- /dev/null
+++ b/src/bytes_buf.rs
@@ -0,0 +1,173 @@
+use std::collections::VecDeque;
+
+pub use bytes::Bytes;
+
+/// A circular buffer of bytes, internally represented as a list of Bytes
+/// for optimization, but that for all intent and purposes acts just like
+/// a big byte slice which can be extended on the right and from which
+/// one can take on the left.
+pub struct BytesBuf {
+ buf: VecDeque<Bytes>,
+ buf_len: usize,
+}
+
+impl BytesBuf {
+ /// Creates a new empty BytesBuf
+ pub fn new() -> Self {
+ Self {
+ buf: VecDeque::new(),
+ buf_len: 0,
+ }
+ }
+
+ /// Returns the number of bytes stored in the BytesBuf
+ #[inline]
+ pub fn len(&self) -> usize {
+ self.buf_len
+ }
+
+ /// Returns true iff the BytesBuf contains zero bytes
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.buf_len == 0
+ }
+
+ /// Adds some bytes to the right of the buffer
+ pub fn extend(&mut self, b: Bytes) {
+ if !b.is_empty() {
+ self.buf_len += b.len();
+ self.buf.push_back(b);
+ }
+ }
+
+ /// Takes the whole content of the buffer and returns it as a single Bytes unit
+ pub fn take_all(&mut self) -> Bytes {
+ if self.buf.len() == 0 {
+ Bytes::new()
+ } else if self.buf.len() == 1 {
+ self.buf_len = 0;
+ self.buf.pop_back().unwrap()
+ } else {
+ let mut ret = Vec::with_capacity(self.buf_len);
+ for b in self.buf.iter() {
+ ret.extend(&b[..]);
+ }
+ self.buf.clear();
+ self.buf_len = 0;
+ Bytes::from(ret)
+ }
+ }
+
+ /// Takes at most max_len bytes from the left of the buffer
+ pub fn take_max(&mut self, max_len: usize) -> Bytes {
+ if self.buf_len <= max_len {
+ self.take_all()
+ } else {
+ self.take_exact_ok(max_len)
+ }
+ }
+
+ /// Take exactly len bytes from the left of the buffer, returns None if
+ /// the BytesBuf doesn't contain enough data
+ pub fn take_exact(&mut self, len: usize) -> Option<Bytes> {
+ if self.buf_len < len {
+ None
+ } else {
+ Some(self.take_exact_ok(len))
+ }
+ }
+
+ fn take_exact_ok(&mut self, len: usize) -> Bytes {
+ assert!(len <= self.buf_len);
+ let front = self.buf.pop_front().unwrap();
+ if front.len() > len {
+ self.buf.push_front(front.slice(len..));
+ self.buf_len -= len;
+ front.slice(..len)
+ } else if front.len() == len {
+ self.buf_len -= len;
+ front
+ } else {
+ let mut ret = Vec::with_capacity(len);
+ ret.extend(&front[..]);
+ self.buf_len -= front.len();
+ while ret.len() < len {
+ let front = self.buf.pop_front().unwrap();
+ if front.len() > len - ret.len() {
+ let take = len - ret.len();
+ ret.extend(front.slice(..take));
+ self.buf.push_front(front.slice(take..));
+ self.buf_len -= take;
+ break;
+ } else {
+ ret.extend(&front[..]);
+ self.buf_len -= front.len();
+ }
+ }
+ Bytes::from(ret)
+ }
+ }
+
+ /// Return the internal sequence of Bytes slices that make up the buffer
+ pub fn into_slices(self) -> VecDeque<Bytes> {
+ self.buf
+ }
+}
+
+impl From<Bytes> for BytesBuf {
+ fn from(b: Bytes) -> BytesBuf {
+ let mut ret = BytesBuf::new();
+ ret.extend(b);
+ ret
+ }
+}
+
+impl From<BytesBuf> for Bytes {
+ fn from(mut b: BytesBuf) -> Bytes {
+ b.take_all()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_bytes_buf() {
+ let mut buf = BytesBuf::new();
+ assert!(buf.len() == 0);
+ assert!(buf.is_empty());
+
+ buf.extend(Bytes::from(b"Hello, world!".to_vec()));
+ assert!(buf.len() == 13);
+ assert!(!buf.is_empty());
+
+ buf.extend(Bytes::from(b"1234567890".to_vec()));
+ assert!(buf.len() == 23);
+ assert!(!buf.is_empty());
+
+ assert_eq!(
+ buf.take_all(),
+ Bytes::from(b"Hello, world!1234567890".to_vec())
+ );
+ assert!(buf.len() == 0);
+ assert!(buf.is_empty());
+
+ buf.extend(Bytes::from(b"1234567890".to_vec()));
+ buf.extend(Bytes::from(b"Hello, world!".to_vec()));
+ assert!(buf.len() == 23);
+ assert!(!buf.is_empty());
+
+ assert_eq!(buf.take_max(12), Bytes::from(b"1234567890He".to_vec()));
+ assert!(buf.len() == 11);
+
+ assert_eq!(buf.take_exact(12), None);
+ assert!(buf.len() == 11);
+ assert_eq!(
+ buf.take_exact(11),
+ Some(Bytes::from(b"llo, world!".to_vec()))
+ );
+ assert!(buf.len() == 0);
+ assert!(buf.is_empty());
+ }
+}
diff --git a/src/client.rs b/src/client.rs
index 5c5a05b..d82c91e 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,12 +1,18 @@
-use std::borrow::Borrow;
use std::collections::HashMap;
use std::net::SocketAddr;
+use std::pin::Pin;
use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex};
+use std::task::Poll;
use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+use bytes::Bytes;
use log::{debug, error, trace};
+use futures::io::AsyncReadExt;
+use futures::Stream;
+use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
@@ -20,27 +26,22 @@ use opentelemetry::{
#[cfg(feature = "telemetry")]
use opentelemetry_contrib::trace::propagator::binary::*;
-use futures::io::AsyncReadExt;
-
-use async_trait::async_trait;
-
-use kuska_handshake::async_std::{handshake_client, BoxStream};
-
-use crate::endpoint::*;
use crate::error::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
-use crate::proto2::*;
+use crate::recv::*;
+use crate::send::*;
+use crate::stream::*;
use crate::util::*;
pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID,
- query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
+ query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
next_query_number: AtomicU32,
- inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
+ inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
}
impl ClientConn {
@@ -139,15 +140,14 @@ impl ClientConn {
self.query_send.store(None);
}
- pub(crate) async fn call<T, B>(
+ pub(crate) async fn call<T>(
self: Arc<Self>,
- rq: B,
+ req: Req<T>,
path: &str,
prio: RequestPriority,
- ) -> Result<<T as Message>::Response, Error>
+ ) -> Result<Resp<T>, Error>
where
T: Message,
- B: Borrow<T>,
{
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@@ -162,24 +162,16 @@ impl ClientConn {
.with_kind(SpanKind::Client)
.start(&tracer);
let propagator = BinaryPropagator::new();
- let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec());
+ let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into();
} else {
- let telemetry_id: Option<Vec<u8>> = None;
+ let telemetry_id: Bytes = Bytes::new();
}
};
// Encode request
- let body = rmp_to_vec_all_named(rq.borrow())?;
- drop(rq);
-
- let request = QueryMessage {
- prio,
- path: path.as_bytes(),
- telemetry_id,
- body: &body[..],
- };
- let bytes = request.encode();
- drop(body);
+ let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id);
+ let req_msg_len = req_enc.msg.len();
+ let (req_stream, req_order) = req_enc.encode();
// Send request through
let (resp_send, resp_recv) = oneshot::channel();
@@ -188,46 +180,41 @@ impl ClientConn {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
- if old_ch.send(vec![]).is_err() {
- debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
- }
+ let _ = old_ch.send(Box::pin(futures::stream::once(async move {
+ Err(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "RequestID collision, too many inflight requests",
+ ))
+ })));
}
- trace!("request: query_send {}, {} bytes", id, bytes.len());
+ debug!(
+ "request: query_send {}, path {}, prio {} (serialized message: {} bytes)",
+ id, path, prio, req_msg_len
+ );
#[cfg(feature = "telemetry")]
- span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
+ span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
- query_send.send((id, prio, bytes))?;
+ query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?;
+
+ let canceller = CancelOnDrop::new(id, query_send.as_ref().clone());
cfg_if::cfg_if! {
if #[cfg(feature = "telemetry")] {
- let resp = resp_recv
+ let stream = resp_recv
.with_context(Context::current_with_span(span))
.await?;
} else {
- let resp = resp_recv.await?;
+ let stream = resp_recv.await?;
}
}
- if resp.is_empty() {
- return Err(Error::Message(
- "Response is 0 bytes, either a collision or a protocol error".into(),
- ));
- }
-
- trace!("request response {}: ", id);
+ let stream = Box::pin(canceller.for_stream(stream));
- let code = resp[0];
- if code == 0 {
- Ok(rmp_serde::decode::from_read_ref::<
- _,
- <T as Message>::Response,
- >(&resp[1..])?)
- } else {
- let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
- Err(Error::Remote(code, msg))
- }
+ let resp_enc = RespEnc::decode(stream).await?;
+ debug!("client: got response to request {} (path {})", id, path);
+ Resp::from_enc(resp_enc)
}
}
@@ -235,14 +222,71 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
- fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
- trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
+ trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap();
if let Some(ch) = inflight.remove(&id) {
- if ch.send(msg).is_err() {
+ if ch.send(stream).is_err() {
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
}
+ } else {
+ debug!("Got unexpected response to request {}, dropping it", id);
}
}
}
+
+// ----
+
+struct CancelOnDrop {
+ id: RequestID,
+ query_send: mpsc::UnboundedSender<SendItem>,
+}
+
+impl CancelOnDrop {
+ fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self {
+ Self { id, query_send }
+ }
+ fn for_stream(self, stream: ByteStream) -> CancelOnDropStream {
+ CancelOnDropStream {
+ cancel: Some(self),
+ stream: stream,
+ }
+ }
+}
+
+impl Drop for CancelOnDrop {
+ fn drop(&mut self) {
+ trace!("cancelling request {}", self.id);
+ let _ = self.query_send.send(SendItem::Cancel(self.id));
+ }
+}
+
+#[pin_project::pin_project]
+struct CancelOnDropStream {
+ cancel: Option<CancelOnDrop>,
+ #[pin]
+ stream: ByteStream,
+}
+
+impl Stream for CancelOnDropStream {
+ type Item = Packet;
+
+ fn poll_next(
+ self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let this = self.project();
+ let res = this.stream.poll_next(cx);
+ if matches!(res, Poll::Ready(None)) {
+ if let Some(c) = this.cancel.take() {
+ std::mem::forget(c)
+ }
+ }
+ res
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.stream.size_hint()
+ }
+}
diff --git a/src/endpoint.rs b/src/endpoint.rs
index 42e9a98..bb768de 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -1,33 +1,25 @@
-use std::borrow::Borrow;
use std::marker::PhantomData;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
-use serde::{Deserialize, Serialize};
-
use crate::error::Error;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
-use crate::util::*;
-
-/// This trait should be implemented by all messages your application
-/// wants to handle
-pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync {
- type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync;
-}
/// This trait should be implemented by an object of your application
-/// that can handle a message of type `M`.
+/// that can handle a message of type `M`, if it wishes to handle
+/// streams attached to the request and/or to send back streams
+/// attached to the response..
///
/// The handler object should be in an Arc, see `Endpoint::set_handler`
#[async_trait]
-pub trait EndpointHandler<M>: Send + Sync
+pub trait StreamingEndpointHandler<M>: Send + Sync
where
M: Message,
{
- async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
+ async fn handle(self: &Arc<Self>, m: Req<M>, from: NodeID) -> Resp<M>;
}
/// If one simply wants to use an endpoint in a client fashion,
@@ -35,12 +27,41 @@ where
/// use the unit type `()` as the handler type:
/// it will panic if it is ever made to handle request.
#[async_trait]
-impl<M: Message + 'static> EndpointHandler<M> for () {
+impl<M: Message> EndpointHandler<M> for () {
async fn handle(self: &Arc<()>, _m: &M, _from: NodeID) -> M::Response {
panic!("This endpoint should not have a local handler.");
}
}
+// ----
+
+/// This trait should be implemented by an object of your application
+/// that can handle a message of type `M`, in the cases where it doesn't
+/// care about attached stream in the request nor in the response.
+#[async_trait]
+pub trait EndpointHandler<M>: Send + Sync
+where
+ M: Message,
+{
+ async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> <M as Message>::Response;
+}
+
+#[async_trait]
+impl<T, M> StreamingEndpointHandler<M> for T
+where
+ T: EndpointHandler<M>,
+ M: Message,
+{
+ async fn handle(self: &Arc<Self>, mut m: Req<M>, from: NodeID) -> Resp<M> {
+ // Immediately drop stream to avoid backpressure if a stream was sent
+ // (this will make all data sent to the stream be ignored immediately)
+ drop(m.take_stream());
+ Resp::new(EndpointHandler::handle(self, m.msg(), from).await)
+ }
+}
+
+// ----
+
/// This struct represents an endpoint for message of type `M`.
///
/// Creating a new endpoint is done by calling `NetApp::endpoint`.
@@ -50,13 +71,13 @@ impl<M: Message + 'static> EndpointHandler<M> for () {
/// An `Endpoint` is used both to send requests to remote nodes,
/// and to specify the handler for such requests on the local node.
/// The type `H` represents the type of the handler object for
-/// endpoint messages (see `EndpointHandler`).
+/// endpoint messages (see `StreamingEndpointHandler`).
pub struct Endpoint<M, H>
where
M: Message,
- H: EndpointHandler<M>,
+ H: StreamingEndpointHandler<M>,
{
- phantom: PhantomData<M>,
+ _phantom: PhantomData<M>,
netapp: Arc<NetApp>,
path: String,
handler: ArcSwapOption<H>,
@@ -65,11 +86,11 @@ where
impl<M, H> Endpoint<M, H>
where
M: Message,
- H: EndpointHandler<M>,
+ H: StreamingEndpointHandler<M>,
{
pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
Self {
- phantom: PhantomData::default(),
+ _phantom: PhantomData::default(),
netapp,
path,
handler: ArcSwapOption::from(None),
@@ -88,20 +109,22 @@ where
}
/// Call this endpoint on a remote node (or on the local node,
- /// for that matter)
- pub async fn call<B>(
+ /// for that matter). This function invokes the full version that
+ /// allows to attach a stream to the request and to
+ /// receive such a stream attached to the response.
+ pub async fn call_streaming<T>(
&self,
target: &NodeID,
- req: B,
+ req: T,
prio: RequestPriority,
- ) -> Result<<M as Message>::Response, Error>
+ ) -> Result<Resp<M>, Error>
where
- B: Borrow<M>,
+ T: IntoReq<M>,
{
if *target == self.netapp.id {
match self.handler.load_full() {
None => Err(Error::NoHandler),
- Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
+ Some(h) => Ok(h.handle(req.into_req_local(), self.netapp.id).await),
}
} else {
let conn = self
@@ -116,10 +139,22 @@ where
"Not connected: {}",
hex::encode(&target[..8])
))),
- Some(c) => c.call(req, self.path.as_str(), prio).await,
+ Some(c) => c.call(req.into_req()?, self.path.as_str(), prio).await,
}
}
}
+
+ /// Call this endpoint on a remote node. This function is the simplified
+ /// version that doesn't allow to have streams attached to the request
+ /// or the response; see `call_streaming` for the full version.
+ pub async fn call(
+ &self,
+ target: &NodeID,
+ req: M,
+ prio: RequestPriority,
+ ) -> Result<<M as Message>::Response, Error> {
+ Ok(self.call_streaming(target, req, prio).await?.into_msg())
+ }
}
// ---- Internal stuff ----
@@ -128,7 +163,7 @@ pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;
#[async_trait]
pub(crate) trait GenericEndpoint {
- async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error>;
+ async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error>;
fn drop_handler(&self);
fn clone_endpoint(&self) -> DynEndpoint;
}
@@ -137,22 +172,21 @@ pub(crate) trait GenericEndpoint {
pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
where
M: Message,
- H: EndpointHandler<M>;
+ H: StreamingEndpointHandler<M>;
#[async_trait]
impl<M, H> GenericEndpoint for EndpointArc<M, H>
where
- M: Message + 'static,
- H: EndpointHandler<M> + 'static,
+ M: Message,
+ H: StreamingEndpointHandler<M> + 'static,
{
- async fn handle(&self, buf: &[u8], from: NodeID) -> Result<Vec<u8>, Error> {
+ async fn handle(&self, req_enc: ReqEnc, from: NodeID) -> Result<RespEnc, Error> {
match self.0.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => {
- let req = rmp_serde::decode::from_read_ref::<_, M>(buf)?;
- let res = h.handle(&req, from).await;
- let res_bytes = rmp_to_vec_all_named(&res)?;
- Ok(res_bytes)
+ let req = Req::from_enc(req_enc)?;
+ let res = h.handle(req, from).await;
+ Ok(res.into_enc()?)
}
}
}
diff --git a/src/error.rs b/src/error.rs
index 99acdd1..c0aeeac 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,6 +1,6 @@
-use err_derive::Error;
use std::io;
+use err_derive::Error;
use log::error;
#[derive(Debug, Error)]
@@ -25,6 +25,15 @@ pub enum Error {
#[error(display = "UTF8 error: {}", _0)]
UTF8(#[error(source)] std::string::FromUtf8Error),
+ #[error(display = "Framing protocol error")]
+ Framing,
+
+ #[error(display = "Remote error ({:?}): {}", _0, _1)]
+ Remote(io::ErrorKind, String),
+
+ #[error(display = "Request ID collision")]
+ IdCollision,
+
#[error(display = "{}", _0)]
Message(String),
@@ -36,28 +45,6 @@ pub enum Error {
#[error(display = "Version mismatch: {}", _0)]
VersionMismatch(String),
-
- #[error(display = "Remote error {}: {}", _0, _1)]
- Remote(u8, String),
-}
-
-impl Error {
- pub fn code(&self) -> u8 {
- match self {
- Self::Io(_) => 100,
- Self::TokioJoin(_) => 110,
- Self::OneshotRecv(_) => 111,
- Self::RMPEncode(_) => 10,
- Self::RMPDecode(_) => 11,
- Self::UTF8(_) => 12,
- Self::NoHandler => 20,
- Self::ConnectionClosed => 21,
- Self::Handshake(_) => 30,
- Self::VersionMismatch(_) => 31,
- Self::Remote(c, _) => *c,
- Self::Message(_) => 99,
- }
- }
}
impl<T> From<tokio::sync::watch::error::SendError<T>> for Error {
@@ -101,3 +88,39 @@ where
}
}
}
+
+// ---- Helpers for serializing I/O Errors
+
+pub(crate) fn u8_to_io_errorkind(v: u8) -> std::io::ErrorKind {
+ use std::io::ErrorKind;
+ match v {
+ 101 => ErrorKind::ConnectionAborted,
+ 102 => ErrorKind::BrokenPipe,
+ 103 => ErrorKind::WouldBlock,
+ 104 => ErrorKind::InvalidInput,
+ 105 => ErrorKind::InvalidData,
+ 106 => ErrorKind::TimedOut,
+ 107 => ErrorKind::Interrupted,
+ 108 => ErrorKind::UnexpectedEof,
+ 109 => ErrorKind::OutOfMemory,
+ 110 => ErrorKind::ConnectionReset,
+ _ => ErrorKind::Other,
+ }
+}
+
+pub(crate) fn io_errorkind_to_u8(kind: std::io::ErrorKind) -> u8 {
+ use std::io::ErrorKind;
+ match kind {
+ ErrorKind::ConnectionAborted => 101,
+ ErrorKind::BrokenPipe => 102,
+ ErrorKind::WouldBlock => 103,
+ ErrorKind::InvalidInput => 104,
+ ErrorKind::InvalidData => 105,
+ ErrorKind::TimedOut => 106,
+ ErrorKind::Interrupted => 107,
+ ErrorKind::UnexpectedEof => 108,
+ ErrorKind::OutOfMemory => 109,
+ ErrorKind::ConnectionReset => 110,
+ _ => 100,
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index cb24337..8e30e40 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,21 +13,23 @@
//! about message priorization.
//! Also check out the examples to learn how to use this crate.
+pub mod bytes_buf;
pub mod error;
+pub mod stream;
pub mod util;
pub mod endpoint;
-pub mod proto;
+pub mod message;
mod client;
-mod proto2;
+mod recv;
+mod send;
mod server;
pub mod netapp;
pub mod peering;
pub use crate::netapp::*;
-pub use util::{NetworkKey, NodeID, NodeKey};
#[cfg(test)]
mod test;
diff --git a/src/message.rs b/src/message.rs
new file mode 100644
index 0000000..ec9433a
--- /dev/null
+++ b/src/message.rs
@@ -0,0 +1,472 @@
+use std::fmt;
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+use bytes::{BufMut, Bytes, BytesMut};
+use rand::prelude::*;
+use serde::{Deserialize, Serialize};
+
+use futures::stream::StreamExt;
+
+use crate::error::*;
+use crate::stream::*;
+use crate::util::*;
+
+/// Priority of a request (click to read more about priorities).
+///
+/// This priority value is used to priorize messages
+/// in the send queue of the client, and their responses in the send queue of the
+/// server. Lower values mean higher priority.
+///
+/// This mechanism is usefull for messages bigger than the maximum chunk size
+/// (set at `0x4000` bytes), such as large file transfers.
+/// In such case, all of the messages in the send queue with the highest priority
+/// will take turns to send individual chunks, in a round-robin fashion.
+/// Once all highest priority messages are sent successfully, the messages with
+/// the next highest priority will begin being sent in the same way.
+///
+/// The same priority value is given to a request and to its associated response.
+pub type RequestPriority = u8;
+
+/// Priority class: high
+pub const PRIO_HIGH: RequestPriority = 0x20;
+/// Priority class: normal
+pub const PRIO_NORMAL: RequestPriority = 0x40;
+/// Priority class: background
+pub const PRIO_BACKGROUND: RequestPriority = 0x80;
+/// Priority: primary among given class
+pub const PRIO_PRIMARY: RequestPriority = 0x00;
+/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
+pub const PRIO_SECONDARY: RequestPriority = 0x01;
+
+// ----
+
+#[derive(Clone, Copy)]
+pub struct OrderTagStream(u64);
+#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
+pub struct OrderTag(pub(crate) u64, pub(crate) u64);
+
+impl OrderTag {
+ pub fn stream() -> OrderTagStream {
+ OrderTagStream(thread_rng().gen())
+ }
+}
+impl OrderTagStream {
+ pub fn order(&self, order: u64) -> OrderTag {
+ OrderTag(self.0, order)
+ }
+}
+
+// ----
+
+/// This trait should be implemented by all messages your application
+/// wants to handle
+pub trait Message: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {
+ type Response: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static;
+}
+
+// ----
+
+/// The Req<M> is a helper object used to create requests and attach them
+/// a stream of data. If the stream is a fixed Bytes and not a ByteStream,
+/// Req<M> is cheaply clonable to allow the request to be sent to different
+/// peers (Clone will panic if the stream is a ByteStream).
+pub struct Req<M: Message> {
+ pub(crate) msg: Arc<M>,
+ pub(crate) msg_ser: Option<Bytes>,
+ pub(crate) stream: AttachedStream,
+ pub(crate) order_tag: Option<OrderTag>,
+}
+
+impl<M: Message> Req<M> {
+ pub fn new(v: M) -> Result<Self, Error> {
+ Ok(v.into_req()?)
+ }
+
+ pub fn with_stream_from_buffer(self, b: Bytes) -> Self {
+ Self {
+ stream: AttachedStream::Fixed(b),
+ ..self
+ }
+ }
+
+ pub fn with_stream(self, b: ByteStream) -> Self {
+ Self {
+ stream: AttachedStream::Stream(b),
+ ..self
+ }
+ }
+
+ pub fn with_order_tag(self, order_tag: OrderTag) -> Self {
+ Self {
+ order_tag: Some(order_tag),
+ ..self
+ }
+ }
+
+ pub fn msg(&self) -> &M {
+ &self.msg
+ }
+
+ pub fn take_stream(&mut self) -> Option<ByteStream> {
+ std::mem::replace(&mut self.stream, AttachedStream::None).into_stream()
+ }
+
+ pub(crate) fn into_enc(
+ self,
+ prio: RequestPriority,
+ path: Bytes,
+ telemetry_id: Bytes,
+ ) -> ReqEnc {
+ ReqEnc {
+ prio,
+ path,
+ telemetry_id,
+ msg: self.msg_ser.unwrap(),
+ stream: self.stream.into_stream(),
+ order_tag: self.order_tag,
+ }
+ }
+
+ pub(crate) fn from_enc(enc: ReqEnc) -> Result<Self, rmp_serde::decode::Error> {
+ let msg = rmp_serde::decode::from_read_ref(&enc.msg)?;
+ Ok(Req {
+ msg: Arc::new(msg),
+ msg_ser: Some(enc.msg),
+ stream: enc
+ .stream
+ .map(AttachedStream::Stream)
+ .unwrap_or(AttachedStream::None),
+ order_tag: enc.order_tag,
+ })
+ }
+}
+
+pub trait IntoReq<M: Message> {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error>;
+ fn into_req_local(self) -> Req<M>;
+}
+
+impl<M: Message> IntoReq<M> for M {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
+ let msg_ser = rmp_to_vec_all_named(&self)?;
+ Ok(Req {
+ msg: Arc::new(self),
+ msg_ser: Some(Bytes::from(msg_ser)),
+ stream: AttachedStream::None,
+ order_tag: None,
+ })
+ }
+ fn into_req_local(self) -> Req<M> {
+ Req {
+ msg: Arc::new(self),
+ msg_ser: None,
+ stream: AttachedStream::None,
+ order_tag: None,
+ }
+ }
+}
+
+impl<M: Message> IntoReq<M> for Req<M> {
+ fn into_req(self) -> Result<Req<M>, rmp_serde::encode::Error> {
+ Ok(self)
+ }
+ fn into_req_local(self) -> Req<M> {
+ self
+ }
+}
+
+impl<M: Message> Clone for Req<M> {
+ fn clone(&self) -> Self {
+ let stream = match &self.stream {
+ AttachedStream::None => AttachedStream::None,
+ AttachedStream::Fixed(b) => AttachedStream::Fixed(b.clone()),
+ AttachedStream::Stream(_) => {
+ panic!("Cannot clone a Req<_> with a non-buffer attached stream")
+ }
+ };
+ Self {
+ msg: self.msg.clone(),
+ msg_ser: self.msg_ser.clone(),
+ stream,
+ order_tag: self.order_tag,
+ }
+ }
+}
+
+impl<M> fmt::Debug for Req<M>
+where
+ M: Message + fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ write!(f, "Req[{:?}", self.msg)?;
+ match &self.stream {
+ AttachedStream::None => write!(f, "]"),
+ AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()),
+ AttachedStream::Stream(_) => write!(f, "; stream]"),
+ }
+ }
+}
+
+// ----
+
+/// The Resp<M> represents a full response from a RPC that may have
+/// an attached stream.
+pub struct Resp<M: Message> {
+ pub(crate) _phantom: PhantomData<M>,
+ pub(crate) msg: M::Response,
+ pub(crate) stream: AttachedStream,
+ pub(crate) order_tag: Option<OrderTag>,
+}
+
+impl<M: Message> Resp<M> {
+ pub fn new(v: M::Response) -> Self {
+ Resp {
+ _phantom: Default::default(),
+ msg: v,
+ stream: AttachedStream::None,
+ order_tag: None,
+ }
+ }
+
+ pub fn with_stream_from_buffer(self, b: Bytes) -> Self {
+ Self {
+ stream: AttachedStream::Fixed(b),
+ ..self
+ }
+ }
+
+ pub fn with_stream(self, b: ByteStream) -> Self {
+ Self {
+ stream: AttachedStream::Stream(b),
+ ..self
+ }
+ }
+
+ pub fn with_order_tag(self, order_tag: OrderTag) -> Self {
+ Self {
+ order_tag: Some(order_tag),
+ ..self
+ }
+ }
+
+ pub fn msg(&self) -> &M::Response {
+ &self.msg
+ }
+
+ pub fn into_msg(self) -> M::Response {
+ self.msg
+ }
+
+ pub fn into_parts(self) -> (M::Response, Option<ByteStream>) {
+ (self.msg, self.stream.into_stream())
+ }
+
+ pub(crate) fn into_enc(self) -> Result<RespEnc, rmp_serde::encode::Error> {
+ Ok(RespEnc {
+ msg: rmp_to_vec_all_named(&self.msg)?.into(),
+ stream: self.stream.into_stream(),
+ order_tag: self.order_tag,
+ })
+ }
+
+ pub(crate) fn from_enc(enc: RespEnc) -> Result<Self, Error> {
+ let msg = rmp_serde::decode::from_read_ref(&enc.msg)?;
+ Ok(Self {
+ _phantom: Default::default(),
+ msg,
+ stream: enc
+ .stream
+ .map(AttachedStream::Stream)
+ .unwrap_or(AttachedStream::None),
+ order_tag: enc.order_tag,
+ })
+ }
+}
+
+impl<M> fmt::Debug for Resp<M>
+where
+ M: Message,
+ <M as Message>::Response: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
+ write!(f, "Resp[{:?}", self.msg)?;
+ match &self.stream {
+ AttachedStream::None => write!(f, "]"),
+ AttachedStream::Fixed(b) => write!(f, "; stream=buf:{}]", b.len()),
+ AttachedStream::Stream(_) => write!(f, "; stream]"),
+ }
+ }
+}
+
+// ----
+
+pub(crate) enum AttachedStream {
+ None,
+ Fixed(Bytes),
+ Stream(ByteStream),
+}
+
+impl AttachedStream {
+ pub fn into_stream(self) -> Option<ByteStream> {
+ match self {
+ AttachedStream::None => None,
+ AttachedStream::Fixed(b) => Some(Box::pin(futures::stream::once(async move { Ok(b) }))),
+ AttachedStream::Stream(s) => Some(s),
+ }
+ }
+}
+
+// ---- ----
+
+/// Encoding for requests into a ByteStream:
+/// - priority: u8
+/// - path length: u8
+/// - path: [u8; path length]
+/// - telemetry id length: u8
+/// - telemetry id: [u8; telemetry id length]
+/// - msg len: u32
+/// - msg [u8; ..]
+/// - the attached stream as the rest of the encoded stream
+pub(crate) struct ReqEnc {
+ pub(crate) prio: RequestPriority,
+ pub(crate) path: Bytes,
+ pub(crate) telemetry_id: Bytes,
+ pub(crate) msg: Bytes,
+ pub(crate) stream: Option<ByteStream>,
+ pub(crate) order_tag: Option<OrderTag>,
+}
+
+impl ReqEnc {
+ pub(crate) fn encode(self) -> (ByteStream, Option<OrderTag>) {
+ let mut buf = BytesMut::with_capacity(
+ self.path.len() + self.telemetry_id.len() + self.msg.len() + 16,
+ );
+
+ buf.put_u8(self.prio);
+
+ buf.put_u8(self.path.len() as u8);
+ buf.put(self.path);
+
+ buf.put_u8(self.telemetry_id.len() as u8);
+ buf.put(&self.telemetry_id[..]);
+
+ buf.put_u32(self.msg.len() as u32);
+
+ let header = buf.freeze();
+
+ let res_stream: ByteStream = if let Some(stream) = self.stream {
+ Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]).chain(stream))
+ } else {
+ Box::pin(futures::stream::iter([Ok(header), Ok(self.msg)]))
+ };
+ (res_stream, self.order_tag)
+ }
+
+ pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
+ Self::decode_aux(stream)
+ .await
+ .map_err(read_exact_error_to_error)
+ }
+
+ async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
+ let mut reader = ByteStreamReader::new(stream);
+
+ let prio = reader.read_u8().await?;
+
+ let path_len = reader.read_u8().await?;
+ let path = reader.read_exact(path_len as usize).await?;
+
+ let telemetry_id_len = reader.read_u8().await?;
+ let telemetry_id = reader.read_exact(telemetry_id_len as usize).await?;
+
+ let msg_len = reader.read_u32().await?;
+ let msg = reader.read_exact(msg_len as usize).await?;
+
+ Ok(Self {
+ prio,
+ path,
+ telemetry_id,
+ msg,
+ stream: Some(reader.into_stream()),
+ order_tag: None,
+ })
+ }
+}
+
+/// Encoding for responses into a ByteStream:
+/// IF SUCCESS:
+/// - 0: u8
+/// - msg len: u32
+/// - msg [u8; ..]
+/// - the attached stream as the rest of the encoded stream
+/// IF ERROR:
+/// - message length + 1: u8
+/// - error code: u8
+/// - message: [u8; message_length]
+pub(crate) struct RespEnc {
+ msg: Bytes,
+ stream: Option<ByteStream>,
+ order_tag: Option<OrderTag>,
+}
+
+impl RespEnc {
+ pub(crate) fn encode(resp: Result<Self, Error>) -> (ByteStream, Option<OrderTag>) {
+ match resp {
+ Ok(Self {
+ msg,
+ stream,
+ order_tag,
+ }) => {
+ let mut buf = BytesMut::with_capacity(4);
+ buf.put_u32(msg.len() as u32);
+ let header = buf.freeze();
+
+ let res_stream: ByteStream = if let Some(stream) = stream {
+ Box::pin(futures::stream::iter([Ok(header), Ok(msg)]).chain(stream))
+ } else {
+ Box::pin(futures::stream::iter([Ok(header), Ok(msg)]))
+ };
+ (res_stream, order_tag)
+ }
+ Err(err) => {
+ let err = std::io::Error::new(
+ std::io::ErrorKind::Other,
+ format!("netapp error: {}", err),
+ );
+ (
+ Box::pin(futures::stream::once(async move { Err(err) })),
+ None,
+ )
+ }
+ }
+ }
+
+ pub(crate) async fn decode(stream: ByteStream) -> Result<Self, Error> {
+ Self::decode_aux(stream)
+ .await
+ .map_err(read_exact_error_to_error)
+ }
+
+ async fn decode_aux(stream: ByteStream) -> Result<Self, ReadExactError> {
+ let mut reader = ByteStreamReader::new(stream);
+
+ let msg_len = reader.read_u32().await?;
+ let msg = reader.read_exact(msg_len as usize).await?;
+
+ reader.fill_buffer().await;
+
+ Ok(Self {
+ msg,
+ stream: Some(reader.into_stream()),
+ order_tag: None,
+ })
+ }
+}
+
+fn read_exact_error_to_error(e: ReadExactError) -> Error {
+ match e {
+ ReadExactError::Stream(err) => Error::Remote(err.kind(), err.to_string()),
+ ReadExactError::UnexpectedEos => Error::Framing,
+ }
+}
diff --git a/src/netapp.rs b/src/netapp.rs
index e9efa2e..f1e14ed 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -20,9 +20,15 @@ use tokio::sync::{mpsc, watch};
use crate::client::*;
use crate::endpoint::*;
use crate::error::*;
-use crate::proto::*;
+use crate::message::*;
use crate::server::*;
-use crate::util::*;
+
+/// A node's identifier, which is also its public cryptographic key
+pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
+/// A node's secret key
+pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
+/// A network key
+pub type NetworkKey = sodiumoxide::crypto::auth::Key;
/// Tag which is exchanged between client and server upon connection establishment
/// to check that they are running compatible versions of Netapp,
@@ -30,9 +36,9 @@ use crate::util::*;
pub(crate) type VersionTag = [u8; 16];
/// Value of the Netapp version used in the version tag
-pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004
+pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700005; // netapp 0x0005
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct HelloMessage {
pub server_addr: Option<IpAddr>,
pub server_port: u16,
@@ -152,7 +158,7 @@ impl NetApp {
pub fn endpoint<M, H>(self: &Arc<Self>, path: String) -> Arc<Endpoint<M, H>>
where
M: Message + 'static,
- H: EndpointHandler<M> + 'static,
+ H: StreamingEndpointHandler<M> + 'static,
{
let endpoint = Arc::new(Endpoint::<M, H>::new(self.clone(), path.clone()));
let endpoint_arc = EndpointArc(endpoint.clone());
@@ -397,13 +403,14 @@ impl NetApp {
hello_endpoint
.call(
&conn.peer_id,
- &HelloMessage {
+ HelloMessage {
server_addr,
server_port,
},
PRIO_NORMAL,
)
.await
+ .map(|_| ())
.log_err("Sending hello message");
});
}
diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs
index cd70c2b..2c27dba 100644
--- a/src/peering/basalt.rs
+++ b/src/peering/basalt.rs
@@ -14,8 +14,8 @@ use sodiumoxide::crypto::hash;
use tokio::sync::watch;
use crate::endpoint::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
use crate::NodeID;
// -- Protocol messages --
@@ -331,7 +331,7 @@ impl Basalt {
async fn do_pull(self: Arc<Self>, peer: NodeID) {
match self
.pull_endpoint
- .call(&peer, &PullMessage {}, PRIO_NORMAL)
+ .call(&peer, PullMessage {}, PRIO_NORMAL)
.await
{
Ok(resp) => {
@@ -346,7 +346,7 @@ impl Basalt {
async fn do_push(self: Arc<Self>, peer: NodeID) {
let push_msg = self.make_push_message();
- match self.push_endpoint.call(&peer, &push_msg, PRIO_NORMAL).await {
+ match self.push_endpoint.call(&peer, push_msg, PRIO_NORMAL).await {
Ok(_) => {
trace!("KYEV PEXo {}", hex::encode(peer));
}
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index 208cfe4..7f1c065 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -17,7 +17,8 @@ use sodiumoxide::crypto::hash;
use crate::endpoint::*;
use crate::error::*;
use crate::netapp::*;
-use crate::proto::*;
+
+use crate::message::*;
use crate::NodeID;
const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30);
@@ -29,7 +30,7 @@ const FAILED_PING_THRESHOLD: usize = 4;
// -- Protocol messages --
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
struct PingMessage {
pub id: u64,
pub peer_list_hash: hash::Digest,
@@ -39,7 +40,7 @@ impl Message for PingMessage {
type Response = PingMessage;
}
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
struct PeerListMessage {
pub list: Vec<(NodeID, SocketAddr)>,
}
@@ -382,7 +383,7 @@ impl FullMeshPeeringStrategy {
ping_time
);
let ping_response = select! {
- r = self.ping_endpoint.call(&id, &ping_msg, PRIO_HIGH) => r,
+ r = self.ping_endpoint.call(&id, ping_msg, PRIO_HIGH) => r,
_ = tokio::time::sleep(PING_TIMEOUT) => Err(Error::Message("Ping timeout".into())),
};
@@ -434,7 +435,7 @@ impl FullMeshPeeringStrategy {
let pex_message = PeerListMessage { list: peer_list };
match self
.peer_list_endpoint
- .call(id, &pex_message, PRIO_BACKGROUND)
+ .call(id, pex_message, PRIO_BACKGROUND)
.await
{
Err(e) => warn!("Error doing peer exchange: {}", e),
diff --git a/src/proto.rs b/src/proto.rs
deleted file mode 100644
index 8f7e70f..0000000
--- a/src/proto.rs
+++ /dev/null
@@ -1,358 +0,0 @@
-use std::collections::{HashMap, VecDeque};
-use std::fmt::Write;
-use std::sync::Arc;
-
-use log::trace;
-
-use futures::{AsyncReadExt, AsyncWriteExt};
-use kuska_handshake::async_std::BoxStreamWrite;
-
-use tokio::sync::mpsc;
-
-use async_trait::async_trait;
-
-use crate::error::*;
-
-/// Priority of a request (click to read more about priorities).
-///
-/// This priority value is used to priorize messages
-/// in the send queue of the client, and their responses in the send queue of the
-/// server. Lower values mean higher priority.
-///
-/// This mechanism is usefull for messages bigger than the maximum chunk size
-/// (set at `0x4000` bytes), such as large file transfers.
-/// In such case, all of the messages in the send queue with the highest priority
-/// will take turns to send individual chunks, in a round-robin fashion.
-/// Once all highest priority messages are sent successfully, the messages with
-/// the next highest priority will begin being sent in the same way.
-///
-/// The same priority value is given to a request and to its associated response.
-pub type RequestPriority = u8;
-
-/// Priority class: high
-pub const PRIO_HIGH: RequestPriority = 0x20;
-/// Priority class: normal
-pub const PRIO_NORMAL: RequestPriority = 0x40;
-/// Priority class: background
-pub const PRIO_BACKGROUND: RequestPriority = 0x80;
-/// Priority: primary among given class
-pub const PRIO_PRIMARY: RequestPriority = 0x00;
-/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
-pub const PRIO_SECONDARY: RequestPriority = 0x01;
-
-// Messages are sent by chunks
-// Chunk format:
-// - u32 BE: request id (same for request and response)
-// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag
-// when this is not the last chunk of the message
-// - [u8; chunk_length] chunk data
-
-pub(crate) type RequestID = u32;
-type ChunkLength = u16;
-const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
-const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
-
-struct SendQueueItem {
- id: RequestID,
- prio: RequestPriority,
- data: Vec<u8>,
- cursor: usize,
-}
-
-struct SendQueue {
- items: VecDeque<(u8, VecDeque<SendQueueItem>)>,
-}
-
-impl SendQueue {
- fn new() -> Self {
- Self {
- items: VecDeque::with_capacity(64),
- }
- }
- fn push(&mut self, item: SendQueueItem) {
- let prio = item.prio;
- let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
- Ok(i) => i,
- Err(i) => {
- self.items.insert(i, (prio, VecDeque::new()));
- i
- }
- };
- self.items[pos_prio].1.push_back(item);
- }
- fn pop(&mut self) -> Option<SendQueueItem> {
- match self.items.pop_front() {
- None => None,
- Some((prio, mut items_at_prio)) => {
- let ret = items_at_prio.pop_front();
- if !items_at_prio.is_empty() {
- self.items.push_front((prio, items_at_prio));
- }
- ret.or_else(|| self.pop())
- }
- }
- }
- fn is_empty(&self) -> bool {
- self.items.iter().all(|(_k, v)| v.is_empty())
- }
- fn dump(&self) -> String {
- let mut ret = String::new();
- for (prio, q) in self.items.iter() {
- for item in q.iter() {
- write!(
- &mut ret,
- " [{} {} ({})]",
- prio,
- item.data.len() - item.cursor,
- item.id
- )
- .unwrap();
- }
- }
- ret
- }
-}
-
-/// The SendLoop trait, which is implemented both by the client and the server
-/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()`
-/// that takes a channel of messages to send and an asynchronous writer,
-/// and sends messages from the channel to the async writer, putting them in a queue
-/// before being sent and doing the round-robin sending strategy.
-///
-/// The `.send_loop()` exits when the sending end of the channel is closed,
-/// or if there is an error at any time writing to the async writer.
-#[async_trait]
-pub(crate) trait SendLoop: Sync {
- async fn send_loop<W>(
- self: Arc<Self>,
- mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
- mut write: BoxStreamWrite<W>,
- debug_name: String,
- ) -> Result<(), Error>
- where
- W: AsyncWriteExt + Unpin + Send + Sync,
- {
- let mut sending = SendQueue::new();
- let mut should_exit = false;
- while !should_exit || !sending.is_empty() {
- trace!("send_loop({}): queue = {}", debug_name, sending.dump());
- if let Ok((id, prio, data)) = msg_recv.try_recv() {
- trace!(
- "send_loop({}): new message to send, id = {}, prio = {}, {} bytes",
- debug_name,
- id,
- prio,
- data.len()
- );
- sending.push(SendQueueItem {
- id,
- prio,
- data,
- cursor: 0,
- });
- } else if let Some(mut item) = sending.pop() {
- trace!(
- "send_loop({}): sending bytes for {} ({} bytes, {} already sent)",
- debug_name,
- item.id,
- item.data.len(),
- item.cursor
- );
- let header_id = RequestID::to_be_bytes(item.id);
- write.write_all(&header_id[..]).await?;
-
- if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
- let size_header =
- ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
- write.write_all(&size_header[..]).await?;
-
- let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
- write.write_all(&item.data[item.cursor..new_cursor]).await?;
- item.cursor = new_cursor;
-
- sending.push(item);
- } else {
- let send_len = (item.data.len() - item.cursor) as ChunkLength;
-
- let size_header = ChunkLength::to_be_bytes(send_len);
- write.write_all(&size_header[..]).await?;
-
- write.write_all(&item.data[item.cursor..]).await?;
- }
- write.flush().await?;
- } else {
- let sth = msg_recv.recv().await;
- if let Some((id, prio, data)) = sth {
- trace!(
- "send_loop({}): new message to send, id = {}, prio = {}, {} bytes",
- debug_name,
- id,
- prio,
- data.len()
- );
- sending.push(SendQueueItem {
- id,
- prio,
- data,
- cursor: 0,
- });
- } else {
- should_exit = true;
- }
- }
- }
-
- let _ = write.goodbye().await;
- Ok(())
- }
-}
-
-/// The RecvLoop trait, which is implemented both by the client and the server
-/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
-/// and a prototype of a handler for received messages `.recv_handler()` that
-/// must be filled by implementors. `.recv_loop()` receives messages in a loop
-/// according to the protocol defined above: chunks of message in progress of being
-/// received are stored in a buffer, and when the last chunk of a message is received,
-/// the full message is passed to the receive handler.
-#[async_trait]
-pub(crate) trait RecvLoop: Sync + 'static {
- fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
-
- async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error>
- where
- R: AsyncReadExt + Unpin + Send + Sync,
- {
- let mut receiving = HashMap::new();
- loop {
- let mut header_id = [0u8; RequestID::BITS as usize / 8];
- match read.read_exact(&mut header_id[..]).await {
- Ok(_) => (),
- Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
- Err(e) => return Err(e.into()),
- };
- let id = RequestID::from_be_bytes(header_id);
-
- let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
- read.read_exact(&mut header_size[..]).await?;
- let size = ChunkLength::from_be_bytes(header_size);
- trace!(
- "recv_loop({}): got header id = {}, size = 0x{:04x} ({} bytes)",
- debug_name,
- id,
- size,
- size & !CHUNK_HAS_CONTINUATION
- );
-
- let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
- let size = size & !CHUNK_HAS_CONTINUATION;
-
- let mut next_slice = vec![0; size as usize];
- read.read_exact(&mut next_slice[..]).await?;
- trace!("recv_loop({}): read {} bytes", debug_name, next_slice.len());
-
- let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default();
- msg_bytes.extend_from_slice(&next_slice[..]);
-
- if has_cont {
- receiving.insert(id, msg_bytes);
- } else {
- self.recv_handler(id, msg_bytes);
- }
- }
- Ok(())
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
-
- #[test]
- fn test_priority_queue() {
- let i1 = SendQueueItem {
- id: 1,
- prio: PRIO_NORMAL,
- data: vec![],
- cursor: 0,
- };
- let i2 = SendQueueItem {
- id: 2,
- prio: PRIO_HIGH,
- data: vec![],
- cursor: 0,
- };
- let i2bis = SendQueueItem {
- id: 20,
- prio: PRIO_HIGH,
- data: vec![],
- cursor: 0,
- };
- let i3 = SendQueueItem {
- id: 3,
- prio: PRIO_HIGH | PRIO_SECONDARY,
- data: vec![],
- cursor: 0,
- };
- let i4 = SendQueueItem {
- id: 4,
- prio: PRIO_BACKGROUND | PRIO_SECONDARY,
- data: vec![],
- cursor: 0,
- };
- let i5 = SendQueueItem {
- id: 5,
- prio: PRIO_BACKGROUND | PRIO_PRIMARY,
- data: vec![],
- cursor: 0,
- };
-
- let mut q = SendQueue::new();
-
- q.push(i1); // 1
- let a = q.pop().unwrap(); // empty -> 1
- assert_eq!(a.id, 1);
- assert!(q.pop().is_none());
-
- q.push(a); // 1
- q.push(i2); // 2 1
- q.push(i2bis); // [2 20] 1
- let a = q.pop().unwrap(); // 20 1 -> 2
- assert_eq!(a.id, 2);
- let b = q.pop().unwrap(); // 1 -> 20
- assert_eq!(b.id, 20);
- let c = q.pop().unwrap(); // empty -> 1
- assert_eq!(c.id, 1);
- assert!(q.pop().is_none());
-
- q.push(a); // 2
- q.push(b); // [2 20]
- q.push(c); // [2 20] 1
- q.push(i3); // [2 20] 3 1
- q.push(i4); // [2 20] 3 1 4
- q.push(i5); // [2 20] 3 1 5 4
-
- let a = q.pop().unwrap(); // 20 3 1 5 4 -> 2
- assert_eq!(a.id, 2);
- q.push(a); // [20 2] 3 1 5 4
-
- let a = q.pop().unwrap(); // 2 3 1 5 4 -> 20
- assert_eq!(a.id, 20);
- let b = q.pop().unwrap(); // 3 1 5 4 -> 2
- assert_eq!(b.id, 2);
- q.push(b); // 2 3 1 5 4
- let b = q.pop().unwrap(); // 3 1 5 4 -> 2
- assert_eq!(b.id, 2);
- let c = q.pop().unwrap(); // 1 5 4 -> 3
- assert_eq!(c.id, 3);
- q.push(b); // 2 1 5 4
- let b = q.pop().unwrap(); // 1 5 4 -> 2
- assert_eq!(b.id, 2);
- let e = q.pop().unwrap(); // 5 4 -> 1
- assert_eq!(e.id, 1);
- let f = q.pop().unwrap(); // 4 -> 5
- assert_eq!(f.id, 5);
- let g = q.pop().unwrap(); // empty -> 4
- assert_eq!(g.id, 4);
- assert!(q.pop().is_none());
- }
-}
diff --git a/src/proto2.rs b/src/proto2.rs
deleted file mode 100644
index 7210781..0000000
--- a/src/proto2.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-use crate::error::*;
-use crate::proto::*;
-
-pub(crate) struct QueryMessage<'a> {
- pub(crate) prio: RequestPriority,
- pub(crate) path: &'a [u8],
- pub(crate) telemetry_id: Option<Vec<u8>>,
- pub(crate) body: &'a [u8],
-}
-
-/// QueryMessage encoding:
-/// - priority: u8
-/// - path length: u8
-/// - path: [u8; path length]
-/// - telemetry id length: u8
-/// - telemetry id: [u8; telemetry id length]
-/// - body [u8; ..]
-impl<'a> QueryMessage<'a> {
- pub(crate) fn encode(self) -> Vec<u8> {
- let tel_len = match &self.telemetry_id {
- Some(t) => t.len(),
- None => 0,
- };
-
- let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
-
- ret.push(self.prio);
-
- ret.push(self.path.len() as u8);
- ret.extend_from_slice(self.path);
-
- if let Some(t) = self.telemetry_id {
- ret.push(t.len() as u8);
- ret.extend(t);
- } else {
- ret.push(0u8);
- }
-
- ret.extend_from_slice(self.body);
-
- ret
- }
-
- pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
- if bytes.len() < 3 {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- let path_length = bytes[1] as usize;
- if bytes.len() < 3 + path_length {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- let telemetry_id_len = bytes[2 + path_length] as usize;
- if bytes.len() < 3 + path_length + telemetry_id_len {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- let path = &bytes[2..2 + path_length];
- let telemetry_id = if telemetry_id_len > 0 {
- Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
- } else {
- None
- };
-
- let body = &bytes[3 + path_length + telemetry_id_len..];
-
- Ok(Self {
- prio: bytes[0],
- path,
- telemetry_id,
- body,
- })
- }
-}
diff --git a/src/recv.rs b/src/recv.rs
new file mode 100644
index 0000000..8909190
--- /dev/null
+++ b/src/recv.rs
@@ -0,0 +1,153 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use log::*;
+
+use futures::AsyncReadExt;
+use tokio::sync::mpsc;
+
+use crate::error::*;
+use crate::send::*;
+use crate::stream::*;
+
+/// Structure to warn when the sender is dropped before end of stream was reached, like when
+/// connection to some remote drops while transmitting data
+struct Sender {
+ inner: Option<mpsc::UnboundedSender<Packet>>,
+}
+
+impl Sender {
+ fn new(inner: mpsc::UnboundedSender<Packet>) -> Self {
+ Sender { inner: Some(inner) }
+ }
+
+ fn send(&self, packet: Packet) {
+ let _ = self.inner.as_ref().unwrap().send(packet);
+ }
+
+ fn end(&mut self) {
+ self.inner = None;
+ }
+}
+
+impl Drop for Sender {
+ fn drop(&mut self) {
+ if let Some(inner) = self.inner.take() {
+ let _ = inner.send(Err(std::io::Error::new(
+ std::io::ErrorKind::BrokenPipe,
+ "Netapp connection dropped before end of stream",
+ )));
+ }
+ }
+}
+
+/// The RecvLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
+/// and a prototype of a handler for received messages `.recv_handler()` that
+/// must be filled by implementors. `.recv_loop()` receives messages in a loop
+/// according to the protocol defined above: chunks of message in progress of being
+/// received are stored in a buffer, and when the last chunk of a message is received,
+/// the full message is passed to the receive handler.
+#[async_trait]
+pub(crate) trait RecvLoop: Sync + 'static {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
+ fn cancel_handler(self: &Arc<Self>, _id: RequestID) {}
+
+ async fn recv_loop<R>(self: Arc<Self>, mut read: R, debug_name: String) -> Result<(), Error>
+ where
+ R: AsyncReadExt + Unpin + Send + Sync,
+ {
+ let mut streams: HashMap<RequestID, Sender> = HashMap::new();
+ loop {
+ trace!(
+ "recv_loop({}): in_progress = {:?}",
+ debug_name,
+ streams.iter().map(|(id, _)| id).collect::<Vec<_>>()
+ );
+
+ let mut header_id = [0u8; RequestID::BITS as usize / 8];
+ match read.read_exact(&mut header_id[..]).await {
+ Ok(_) => (),
+ Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
+ Err(e) => return Err(e.into()),
+ };
+ let id = RequestID::from_be_bytes(header_id);
+
+ let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
+ read.read_exact(&mut header_size[..]).await?;
+ let size = ChunkLength::from_be_bytes(header_size);
+
+ if size == CANCEL_REQUEST {
+ if let Some(mut stream) = streams.remove(&id) {
+ let _ = stream.send(Err(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "netapp: cancel requested",
+ )));
+ stream.end();
+ }
+ self.cancel_handler(id);
+ continue;
+ }
+
+ let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
+ let is_error = (size & ERROR_MARKER) != 0;
+ let size = (size & CHUNK_LENGTH_MASK) as usize;
+ let mut next_slice = vec![0; size as usize];
+ read.read_exact(&mut next_slice[..]).await?;
+
+ let packet = if is_error {
+ let kind = u8_to_io_errorkind(next_slice[0]);
+ let msg =
+ std::str::from_utf8(&next_slice[1..]).unwrap_or("<invalid utf8 error message>");
+ debug!(
+ "recv_loop({}): got id {}, error {:?}: {}",
+ debug_name, id, kind, msg
+ );
+ Some(Err(std::io::Error::new(kind, msg.to_string())))
+ } else {
+ trace!(
+ "recv_loop({}): got id {}, size {}, has_cont {}",
+ debug_name,
+ id,
+ size,
+ has_cont
+ );
+ if !next_slice.is_empty() {
+ Some(Ok(Bytes::from(next_slice)))
+ } else {
+ None
+ }
+ };
+
+ let mut sender = if let Some(send) = streams.remove(&(id)) {
+ send
+ } else {
+ let (send, recv) = mpsc::unbounded_channel();
+ trace!("recv_loop({}): id {} is new channel", debug_name, id);
+ self.recv_handler(
+ id,
+ Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(recv)),
+ );
+ Sender::new(send)
+ };
+
+ if let Some(packet) = packet {
+ // If we cannot put packet in channel, it means that the
+ // receiving end of the channel is disconnected.
+ // We still need to reach eos before dropping this sender
+ let _ = sender.send(packet);
+ }
+
+ if has_cont {
+ assert!(!is_error);
+ streams.insert(id, sender);
+ } else {
+ trace!("recv_loop({}): close channel id {}", debug_name, id);
+ sender.end();
+ }
+ }
+ Ok(())
+ }
+}
diff --git a/src/send.rs b/src/send.rs
new file mode 100644
index 0000000..780bbcf
--- /dev/null
+++ b/src/send.rs
@@ -0,0 +1,334 @@
+use std::collections::{HashMap, VecDeque};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use async_trait::async_trait;
+use bytes::{BufMut, Bytes, BytesMut};
+use log::*;
+
+use futures::{AsyncWriteExt, Future};
+use kuska_handshake::async_std::BoxStreamWrite;
+use tokio::sync::mpsc;
+
+use crate::error::*;
+use crate::message::*;
+use crate::stream::*;
+
+// Messages are sent by chunks
+// Chunk format:
+// - u32 BE: request id (same for request and response)
+// - u16 BE: chunk length + flags:
+// CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream
+// ERROR_MARKER if this chunk denotes an error
+// (these two flags are exclusive, an error denotes the end of the stream)
+// **special value** 0xFFFF indicates a CANCEL message
+// - [u8; chunk_length], either
+// - if not error: chunk data
+// - if error:
+// - u8: error kind, encoded using error::io_errorkind_to_u8
+// - rest: error message
+
+pub(crate) type RequestID = u32;
+pub(crate) type ChunkLength = u16;
+
+pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
+pub(crate) const ERROR_MARKER: ChunkLength = 0x4000;
+pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
+pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF;
+pub(crate) const CANCEL_REQUEST: ChunkLength = 0xFFFF;
+
+pub(crate) enum SendItem {
+ Stream(RequestID, RequestPriority, Option<OrderTag>, ByteStream),
+ Cancel(RequestID),
+}
+
+// ----
+
+struct SendQueue {
+ items: Vec<(u8, SendQueuePriority)>,
+}
+
+struct SendQueuePriority {
+ items: VecDeque<SendQueueItem>,
+ order: HashMap<u64, VecDeque<u64>>,
+}
+
+struct SendQueueItem {
+ id: RequestID,
+ prio: RequestPriority,
+ order_tag: Option<OrderTag>,
+ data: ByteStreamReader,
+}
+
+impl SendQueue {
+ fn new() -> Self {
+ Self {
+ items: Vec::with_capacity(64),
+ }
+ }
+ fn push(&mut self, item: SendQueueItem) {
+ let prio = item.prio;
+ let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
+ Ok(i) => i,
+ Err(i) => {
+ self.items.insert(i, (prio, SendQueuePriority::new()));
+ i
+ }
+ };
+ self.items[pos_prio].1.push(item);
+ }
+ fn remove(&mut self, id: RequestID) {
+ for (_, prioq) in self.items.iter_mut() {
+ prioq.remove(id);
+ }
+ }
+ fn is_empty(&self) -> bool {
+ self.items.iter().all(|(_k, v)| v.is_empty())
+ }
+
+ // this is like an async fn, but hand implemented
+ fn next_ready(&mut self) -> SendQueuePollNextReady<'_> {
+ SendQueuePollNextReady { queue: self }
+ }
+}
+
+impl SendQueuePriority {
+ fn new() -> Self {
+ Self {
+ items: VecDeque::new(),
+ order: HashMap::new(),
+ }
+ }
+ fn push(&mut self, item: SendQueueItem) {
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ let order_vec = self.order.entry(stream).or_default();
+ let i = order_vec.iter().take_while(|o2| **o2 < order).count();
+ order_vec.insert(i, order);
+ }
+ self.items.push_back(item);
+ }
+ fn remove(&mut self, id: RequestID) {
+ if let Some(i) = self.items.iter().position(|x| x.id == id) {
+ let item = self.items.remove(i).unwrap();
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ let order_vec = self.order.get_mut(&stream).unwrap();
+ let j = order_vec.iter().position(|x| *x == order).unwrap();
+ order_vec.remove(j).unwrap();
+ }
+ }
+ }
+ fn is_empty(&self) -> bool {
+ self.items.is_empty()
+ }
+ fn poll_next_ready(&mut self, ctx: &mut Context<'_>) -> Poll<(RequestID, DataFrame)> {
+ for (j, item) in self.items.iter_mut().enumerate() {
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ if order > *self.order.get(&stream).unwrap().front().unwrap() {
+ continue;
+ }
+ }
+
+ let mut item_reader = item.data.read_exact_or_eos(MAX_CHUNK_LENGTH as usize);
+ if let Poll::Ready(bytes_or_err) = Pin::new(&mut item_reader).poll(ctx) {
+ let id = item.id;
+ let eos = item.data.eos();
+
+ let packet = bytes_or_err.map_err(|e| match e {
+ ReadExactError::Stream(err) => err,
+ _ => unreachable!(),
+ });
+
+ if eos || packet.is_err() {
+ if let Some(OrderTag(stream, order)) = item.order_tag {
+ assert_eq!(
+ self.order.get_mut(&stream).unwrap().pop_front(),
+ Some(order)
+ )
+ }
+ self.items.remove(j);
+ }
+
+ let data_frame = DataFrame::from_packet(packet, !eos);
+
+ return Poll::Ready((id, data_frame));
+ }
+ }
+
+ Poll::Pending
+ }
+ fn dump(&self, prio: u8) -> String {
+ self.items
+ .iter()
+ .map(|i| format!("[{} {} {:?}]", prio, i.id, i.order_tag))
+ .collect::<Vec<_>>()
+ .join(" ")
+ }
+}
+
+struct SendQueuePollNextReady<'a> {
+ queue: &'a mut SendQueue,
+}
+
+impl<'a> futures::Future for SendQueuePollNextReady<'a> {
+ type Output = (RequestID, DataFrame);
+
+ fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
+ for (i, (_prio, items_at_prio)) in self.queue.items.iter_mut().enumerate() {
+ if let Poll::Ready(res) = items_at_prio.poll_next_ready(ctx) {
+ if items_at_prio.is_empty() {
+ self.queue.items.remove(i);
+ }
+ return Poll::Ready(res);
+ }
+ }
+ // If the queue is empty, this futures is eternally pending.
+ // This is ok because we use it in a select with another future
+ // that can interrupt it.
+ Poll::Pending
+ }
+}
+
+enum DataFrame {
+ /// a fixed size buffer containing some data + a boolean indicating whether
+ /// there may be more data comming from this stream. Can be used for some
+ /// optimization. It's an error to set it to false if there is more data, but it is correct
+ /// (albeit sub-optimal) to set it to true if there is nothing coming after
+ Data(Bytes, bool),
+ /// An error code automatically signals the end of the stream
+ Error(Bytes),
+}
+
+impl DataFrame {
+ fn from_packet(p: Packet, has_cont: bool) -> Self {
+ match p {
+ Ok(bytes) => {
+ assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize);
+ Self::Data(bytes, has_cont)
+ }
+ Err(e) => {
+ let mut buf = BytesMut::new();
+ buf.put_u8(io_errorkind_to_u8(e.kind()));
+
+ let msg = format!("{}", e).into_bytes();
+ if msg.len() > (MAX_CHUNK_LENGTH - 1) as usize {
+ buf.put(&msg[..(MAX_CHUNK_LENGTH - 1) as usize]);
+ } else {
+ buf.put(&msg[..]);
+ }
+
+ Self::Error(buf.freeze())
+ }
+ }
+ }
+
+ fn header(&self) -> [u8; 2] {
+ let header_u16 = match self {
+ DataFrame::Data(data, false) => data.len() as u16,
+ DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION,
+ DataFrame::Error(msg) => msg.len() as u16 | ERROR_MARKER,
+ };
+ ChunkLength::to_be_bytes(header_u16)
+ }
+
+ fn data(&self) -> &[u8] {
+ match self {
+ DataFrame::Data(ref data, _) => &data[..],
+ DataFrame::Error(ref msg) => &msg[..],
+ }
+ }
+}
+
+/// The SendLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()`
+/// that takes a channel of messages to send and an asynchronous writer,
+/// and sends messages from the channel to the async writer, putting them in a queue
+/// before being sent and doing the round-robin sending strategy.
+///
+/// The `.send_loop()` exits when the sending end of the channel is closed,
+/// or if there is an error at any time writing to the async writer.
+#[async_trait]
+pub(crate) trait SendLoop: Sync {
+ async fn send_loop<W>(
+ self: Arc<Self>,
+ msg_recv: mpsc::UnboundedReceiver<SendItem>,
+ mut write: BoxStreamWrite<W>,
+ debug_name: String,
+ ) -> Result<(), Error>
+ where
+ W: AsyncWriteExt + Unpin + Send + Sync,
+ {
+ let mut sending = SendQueue::new();
+ let mut msg_recv = Some(msg_recv);
+ while msg_recv.is_some() || !sending.is_empty() {
+ trace!(
+ "send_loop({}): queue = {:?}",
+ debug_name,
+ sending
+ .items
+ .iter()
+ .map(|(prio, i)| i.dump(*prio))
+ .collect::<Vec<_>>()
+ .join(" ; ")
+ );
+
+ let recv_fut = async {
+ if let Some(chan) = &mut msg_recv {
+ chan.recv().await
+ } else {
+ futures::future::pending().await
+ }
+ };
+ let send_fut = sending.next_ready();
+
+ // recv_fut is cancellation-safe according to tokio doc,
+ // send_fut is cancellation-safe as implemented above?
+ tokio::select! {
+ biased; // always read incomming channel first if it has data
+ sth = recv_fut => {
+ match sth {
+ Some(SendItem::Stream(id, prio, order_tag, data)) => {
+ trace!("send_loop({}): add stream {} to send", debug_name, id);
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ order_tag,
+ data: ByteStreamReader::new(data),
+ })
+ }
+ Some(SendItem::Cancel(id)) => {
+ trace!("send_loop({}): cancelling {}", debug_name, id);
+ sending.remove(id);
+ let header_id = RequestID::to_be_bytes(id);
+ write.write_all(&header_id[..]).await?;
+ write.write_all(&ChunkLength::to_be_bytes(CANCEL_REQUEST)).await?;
+ write.flush().await?;
+ }
+ None => {
+ msg_recv = None;
+ }
+ };
+ }
+ (id, data) = send_fut => {
+ trace!(
+ "send_loop({}): id {}, send {} bytes, header_size {}",
+ debug_name,
+ id,
+ data.data().len(),
+ hex::encode(data.header())
+ );
+
+ let header_id = RequestID::to_be_bytes(id);
+ write.write_all(&header_id[..]).await?;
+
+ write.write_all(&data.header()).await?;
+ write.write_all(data.data()).await?;
+ write.flush().await?;
+ }
+ }
+ }
+
+ let _ = write.goodbye().await;
+ Ok(())
+ }
+}
diff --git a/src/server.rs b/src/server.rs
index a835959..f9eb121 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,9 +1,17 @@
+use std::collections::HashMap;
use std::net::SocketAddr;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption;
-use bytes::Bytes;
-use log::{debug, trace};
+use async_trait::async_trait;
+use log::*;
+
+use futures::io::{AsyncReadExt, AsyncWriteExt};
+use kuska_handshake::async_std::{handshake_server, BoxStream};
+use tokio::net::TcpStream;
+use tokio::select;
+use tokio::sync::{mpsc, watch};
+use tokio_util::compat::*;
#[cfg(feature = "telemetry")]
use opentelemetry::{
@@ -15,21 +23,12 @@ use opentelemetry_contrib::trace::propagator::binary::*;
#[cfg(feature = "telemetry")]
use rand::{thread_rng, Rng};
-use tokio::net::TcpStream;
-use tokio::select;
-use tokio::sync::{mpsc, watch};
-use tokio_util::compat::*;
-
-use futures::io::{AsyncReadExt, AsyncWriteExt};
-
-use async_trait::async_trait;
-
-use kuska_handshake::async_std::{handshake_server, BoxStream};
-
use crate::error::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
-use crate::proto2::*;
+use crate::recv::*;
+use crate::send::*;
+use crate::stream::*;
use crate::util::*;
// The client and server connection structs (client.rs and server.rs)
@@ -55,7 +54,8 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
- resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
+ resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
+ running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>,
}
impl ServerConn {
@@ -101,6 +101,7 @@ impl ServerConn {
remote_addr,
peer_id,
resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
+ running_handlers: Mutex::new(HashMap::new()),
});
netapp.connected_as_server(peer_id, conn.clone());
@@ -126,13 +127,12 @@ impl ServerConn {
Ok(())
}
- async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
- let msg = QueryMessage::decode(bytes)?;
- let path = String::from_utf8(msg.path.to_vec())?;
+ async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> {
+ let path = String::from_utf8(req_enc.path.to_vec())?;
let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap();
- endpoints.get(&path).map(|e| e.clone_endpoint())
+ endpoints.get(&path[..]).map(|e| e.clone_endpoint())
};
if let Some(handler) = handler_opt {
@@ -140,9 +140,9 @@ impl ServerConn {
if #[cfg(feature = "telemetry")] {
let tracer = opentelemetry::global::tracer("netapp");
- let mut span = if let Some(telemetry_id) = msg.telemetry_id {
+ let mut span = if !req_enc.telemetry_id.is_empty() {
let propagator = BinaryPropagator::new();
- let context = propagator.from_bytes(telemetry_id);
+ let context = propagator.from_bytes(req_enc.telemetry_id.to_vec());
let context = Context::new().with_remote_span_context(context);
tracer.span_builder(format!(">> RPC {}", path))
.with_kind(SpanKind::Server)
@@ -157,13 +157,13 @@ impl ServerConn {
.start(&tracer)
};
span.set_attribute(KeyValue::new("path", path.to_string()));
- span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
+ span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64));
- handler.handle(msg.body, self.peer_id)
+ handler.handle(req_enc, self.peer_id)
.with_context(Context::current_with_span(span))
.await
} else {
- handler.handle(msg.body, self.peer_id).await
+ handler.handle(req_enc, self.peer_id).await
}
}
} else {
@@ -176,35 +176,47 @@ impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
- fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
- let resp_send = self.resp_send.load_full().unwrap();
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
+ let resp_send = match self.resp_send.load_full() {
+ Some(c) => c,
+ None => return,
+ };
+
+ let mut rh = self.running_handlers.lock().unwrap();
let self2 = self.clone();
- tokio::spawn(async move {
- trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
- let bytes: Bytes = bytes.into();
-
- let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
- let resp = self2.recv_handler_aux(&bytes[..]).await;
-
- let resp_bytes = match resp {
- Ok(rb) => {
- let mut resp_bytes = vec![0u8];
- resp_bytes.extend(rb);
- resp_bytes
- }
- Err(e) => {
- let mut resp_bytes = vec![e.code()];
- resp_bytes.extend(e.to_string().into_bytes());
- resp_bytes
- }
+ let jh = tokio::spawn(async move {
+ debug!("server: recv_handler got {}", id);
+
+ let (prio, resp_enc_result) = match ReqEnc::decode(stream).await {
+ Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await),
+ Err(e) => (PRIO_HIGH, Err(e)),
};
- trace!("ServerConn sending response to {}: ", id);
+ debug!("server: sending response to {}", id);
+ let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result);
resp_send
- .send((id, prio, resp_bytes))
- .log_err("ServerConn recv_handler send resp");
+ .send(SendItem::Stream(id, prio, resp_order, resp_stream))
+ .log_err("ServerConn recv_handler send resp bytes");
+
+ self2.running_handlers.lock().unwrap().remove(&id);
});
+
+ rh.insert(id, jh);
+ }
+
+ fn cancel_handler(self: &Arc<Self>, id: RequestID) {
+ trace!("received cancel for request {}", id);
+
+ // If the handler is still running, abort it now
+ if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) {
+ jh.abort();
+ }
+
+ // Inform the response sender that we don't need to send the response
+ if let Some(resp_send) = self.resp_send.load_full() {
+ let _ = resp_send.send(SendItem::Cancel(id));
+ }
}
}
diff --git a/src/stream.rs b/src/stream.rs
new file mode 100644
index 0000000..05ee051
--- /dev/null
+++ b/src/stream.rs
@@ -0,0 +1,169 @@
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use bytes::Bytes;
+
+use futures::Future;
+use futures::{Stream, StreamExt};
+use tokio::io::AsyncRead;
+
+use crate::bytes_buf::BytesBuf;
+
+/// A stream of associated data.
+///
+/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
+/// consecutive Vec may get merged, but Vec and error code may not be reordered
+///
+/// Error code 255 means the stream was cut before its end. Other codes have no predefined
+/// meaning, it's up to your application to define their semantic.
+pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
+
+pub type Packet = Result<Bytes, std::io::Error>;
+
+// ----
+
+pub struct ByteStreamReader {
+ stream: ByteStream,
+ buf: BytesBuf,
+ eos: bool,
+ err: Option<std::io::Error>,
+}
+
+impl ByteStreamReader {
+ pub fn new(stream: ByteStream) -> Self {
+ ByteStreamReader {
+ stream,
+ buf: BytesBuf::new(),
+ eos: false,
+ err: None,
+ }
+ }
+
+ pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: true,
+ }
+ }
+
+ pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: false,
+ }
+ }
+
+ pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
+ Ok(self.read_exact(1).await?[0])
+ }
+
+ pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
+ let bytes = self.read_exact(2).await?;
+ let mut b = [0u8; 2];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u16::from_be_bytes(b))
+ }
+
+ pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
+ let bytes = self.read_exact(4).await?;
+ let mut b = [0u8; 4];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u32::from_be_bytes(b))
+ }
+
+ pub fn into_stream(self) -> ByteStream {
+ let buf_stream = futures::stream::iter(self.buf.into_slices().into_iter().map(Ok));
+ if let Some(err) = self.err {
+ Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
+ } else if self.eos {
+ Box::pin(buf_stream)
+ } else {
+ Box::pin(buf_stream.chain(self.stream))
+ }
+ }
+
+ pub fn take_buffer(&mut self) -> Bytes {
+ self.buf.take_all()
+ }
+
+ pub fn eos(&self) -> bool {
+ self.buf.is_empty() && self.eos
+ }
+
+ fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
+ self.buf.take_exact(read_len)
+ }
+
+ fn add_stream_next(&mut self, packet: Option<Packet>) {
+ match packet {
+ Some(Ok(slice)) => {
+ self.buf.extend(slice);
+ }
+ Some(Err(e)) => {
+ self.err = Some(e);
+ self.eos = true;
+ }
+ None => {
+ self.eos = true;
+ }
+ }
+ }
+
+ pub async fn fill_buffer(&mut self) {
+ let packet = self.stream.next().await;
+ self.add_stream_next(packet);
+ }
+}
+
+pub enum ReadExactError {
+ UnexpectedEos,
+ Stream(std::io::Error),
+}
+
+#[pin_project::pin_project]
+pub struct ByteStreamReadExact<'a> {
+ #[pin]
+ reader: &'a mut ByteStreamReader,
+ read_len: usize,
+ fail_on_eos: bool,
+}
+
+impl<'a> Future for ByteStreamReadExact<'a> {
+ type Output = Result<Bytes, ReadExactError>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
+ let mut this = self.project();
+
+ loop {
+ if let Some(bytes) = this.reader.try_get(*this.read_len) {
+ return Poll::Ready(Ok(bytes));
+ }
+ if let Some(err) = &this.reader.err {
+ let err = std::io::Error::new(err.kind(), format!("{}", err));
+ return Poll::Ready(Err(ReadExactError::Stream(err)));
+ }
+ if this.reader.eos {
+ if *this.fail_on_eos {
+ return Poll::Ready(Err(ReadExactError::UnexpectedEos));
+ } else {
+ return Poll::Ready(Ok(this.reader.take_buffer()));
+ }
+ }
+
+ let next_packet = futures::ready!(this.reader.stream.as_mut().poll_next(cx));
+ this.reader.add_stream_next(next_packet);
+ }
+ }
+}
+
+// ----
+
+pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream {
+ Box::pin(tokio_util::io::ReaderStream::new(reader))
+}
+
+pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static {
+ tokio_util::io::StreamReader::new(stream)
+}
diff --git a/src/test.rs b/src/test.rs
index 82c7ba6..19759cc 100644
--- a/src/test.rs
+++ b/src/test.rs
@@ -14,6 +14,7 @@ use crate::NodeID;
#[tokio::test(flavor = "current_thread")]
async fn test_with_basic_scheduler() {
+ env_logger::init();
run_test().await
}
diff --git a/src/util.rs b/src/util.rs
index f4dfac7..425d26f 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,18 +1,12 @@
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
-use serde::Serialize;
-
use log::info;
+use serde::Serialize;
use tokio::sync::watch;
-/// A node's identifier, which is also its public cryptographic key
-pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
-/// A node's secret key
-pub type NodeKey = sodiumoxide::crypto::sign::ed25519::SecretKey;
-/// A network key
-pub type NetworkKey = sodiumoxide::crypto::auth::Key;
+use crate::netapp::*;
/// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library.