aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/client.rs18
-rw-r--r--src/endpoint.rs78
-rw-r--r--src/error.rs2
-rw-r--r--src/lib.rs5
-rw-r--r--src/message.rs255
-rw-r--r--src/netapp.rs2
-rw-r--r--src/peering/basalt.rs3
-rw-r--r--src/peering/fullmesh.rs3
-rw-r--r--src/proto2.rs75
-rw-r--r--src/recv.rs114
-rw-r--r--src/send.rs (renamed from src/proto.rs)235
-rw-r--r--src/server.rs32
-rw-r--r--src/util.rs12
13 files changed, 424 insertions, 410 deletions
diff --git a/src/client.rs b/src/client.rs
index 6d49f5c..663a3e4 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -5,9 +5,12 @@ use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
use log::{debug, error, trace};
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
+use futures::io::AsyncReadExt;
+use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
@@ -21,25 +24,18 @@ use opentelemetry::{
#[cfg(feature = "telemetry")]
use opentelemetry_contrib::trace::propagator::binary::*;
-use futures::io::AsyncReadExt;
-
-use async_trait::async_trait;
-
-use kuska_handshake::async_std::{handshake_client, BoxStream};
-
-use crate::endpoint::*;
use crate::error::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
-use crate::proto2::*;
+use crate::recv::*;
+use crate::send::*;
use crate::util::*;
pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID,
- query_send:
- ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
+ query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
diff --git a/src/endpoint.rs b/src/endpoint.rs
index f31141d..e6b2236 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -5,79 +5,11 @@ use std::sync::Arc;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
-use serde::{Deserialize, Serialize};
-
use crate::error::Error;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
use crate::util::*;
-/// This trait should be implemented by all messages your application
-/// wants to handle
-pub trait Message: SerializeMessage + Send + Sync {
- type Response: SerializeMessage + Send + Sync;
-}
-
-/// A trait for de/serializing messages, with possible associated stream.
-#[async_trait]
-pub trait SerializeMessage: Sized {
- type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
-
- fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>);
-
- // TODO should return Result
- async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self;
-}
-
-pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
-
-#[async_trait]
-impl<T> SerializeMessage for T
-where
- T: AutoSerialize,
-{
- type SerializableSelf = Self;
- fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
- (self.clone(), None)
- }
-
- async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: AssociatedStream) -> Self {
- // TODO verify no stream
- ser_self
- }
-}
-
-impl AutoSerialize for () {}
-
-#[async_trait]
-impl<T, E> SerializeMessage for Result<T, E>
-where
- T: SerializeMessage + Send,
- E: SerializeMessage + Send,
-{
- type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
-
- fn serialize_msg(&self) -> (Self::SerializableSelf, Option<AssociatedStream>) {
- match self {
- Ok(ok) => {
- let (msg, stream) = ok.serialize_msg();
- (Ok(msg), stream)
- }
- Err(err) => {
- let (msg, stream) = err.serialize_msg();
- (Err(msg), stream)
- }
- }
- }
-
- async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: AssociatedStream) -> Self {
- match ser_self {
- Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
- Err(err) => Err(E::deserialize_msg(err, stream).await),
- }
- }
-}
-
/// This trait should be implemented by an object of your application
/// that can handle a message of type `M`.
///
@@ -191,9 +123,9 @@ pub(crate) trait GenericEndpoint {
async fn handle(
&self,
buf: &[u8],
- stream: AssociatedStream,
+ stream: ByteStream,
from: NodeID,
- ) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>;
+ ) -> Result<(Vec<u8>, Option<ByteStream>), Error>;
fn drop_handler(&self);
fn clone_endpoint(&self) -> DynEndpoint;
}
@@ -213,9 +145,9 @@ where
async fn handle(
&self,
buf: &[u8],
- stream: AssociatedStream,
+ stream: ByteStream,
from: NodeID,
- ) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
+ ) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
match self.0.handler.load_full() {
None => Err(Error::NoHandler),
Some(h) => {
diff --git a/src/error.rs b/src/error.rs
index 7911c29..665647c 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,6 +1,6 @@
-use err_derive::Error;
use std::io;
+use err_derive::Error;
use log::error;
#[derive(Debug, Error)]
diff --git a/src/lib.rs b/src/lib.rs
index cb24337..1edb919 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -17,10 +17,11 @@ pub mod error;
pub mod util;
pub mod endpoint;
-pub mod proto;
+pub mod message;
mod client;
-mod proto2;
+mod recv;
+mod send;
mod server;
pub mod netapp;
diff --git a/src/message.rs b/src/message.rs
new file mode 100644
index 0000000..dbcc857
--- /dev/null
+++ b/src/message.rs
@@ -0,0 +1,255 @@
+use async_trait::async_trait;
+use futures::stream::{Stream, StreamExt};
+use serde::{Deserialize, Serialize};
+
+use crate::error::*;
+use crate::util::*;
+
+/// Priority of a request (click to read more about priorities).
+///
+/// This priority value is used to priorize messages
+/// in the send queue of the client, and their responses in the send queue of the
+/// server. Lower values mean higher priority.
+///
+/// This mechanism is usefull for messages bigger than the maximum chunk size
+/// (set at `0x4000` bytes), such as large file transfers.
+/// In such case, all of the messages in the send queue with the highest priority
+/// will take turns to send individual chunks, in a round-robin fashion.
+/// Once all highest priority messages are sent successfully, the messages with
+/// the next highest priority will begin being sent in the same way.
+///
+/// The same priority value is given to a request and to its associated response.
+pub type RequestPriority = u8;
+
+/// Priority class: high
+pub const PRIO_HIGH: RequestPriority = 0x20;
+/// Priority class: normal
+pub const PRIO_NORMAL: RequestPriority = 0x40;
+/// Priority class: background
+pub const PRIO_BACKGROUND: RequestPriority = 0x80;
+/// Priority: primary among given class
+pub const PRIO_PRIMARY: RequestPriority = 0x00;
+/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
+pub const PRIO_SECONDARY: RequestPriority = 0x01;
+
+// ----
+
+/// This trait should be implemented by all messages your application
+/// wants to handle
+pub trait Message: SerializeMessage + Send + Sync {
+ type Response: SerializeMessage + Send + Sync;
+}
+
+/// A trait for de/serializing messages, with possible associated stream.
+#[async_trait]
+pub trait SerializeMessage: Sized {
+ type SerializableSelf: Serialize + for<'de> Deserialize<'de> + Send;
+
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>);
+
+ // TODO should return Result
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self;
+}
+
+pub trait AutoSerialize: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync {}
+
+#[async_trait]
+impl<T> SerializeMessage for T
+where
+ T: AutoSerialize,
+{
+ type SerializableSelf = Self;
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ (self.clone(), None)
+ }
+
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, _stream: ByteStream) -> Self {
+ // TODO verify no stream
+ ser_self
+ }
+}
+
+impl AutoSerialize for () {}
+
+#[async_trait]
+impl<T, E> SerializeMessage for Result<T, E>
+where
+ T: SerializeMessage + Send,
+ E: SerializeMessage + Send,
+{
+ type SerializableSelf = Result<T::SerializableSelf, E::SerializableSelf>;
+
+ fn serialize_msg(&self) -> (Self::SerializableSelf, Option<ByteStream>) {
+ match self {
+ Ok(ok) => {
+ let (msg, stream) = ok.serialize_msg();
+ (Ok(msg), stream)
+ }
+ Err(err) => {
+ let (msg, stream) = err.serialize_msg();
+ (Err(msg), stream)
+ }
+ }
+ }
+
+ async fn deserialize_msg(ser_self: Self::SerializableSelf, stream: ByteStream) -> Self {
+ match ser_self {
+ Ok(ok) => Ok(T::deserialize_msg(ok, stream).await),
+ Err(err) => Err(E::deserialize_msg(err, stream).await),
+ }
+ }
+}
+
+// ----
+
+pub(crate) struct QueryMessage<'a> {
+ pub(crate) prio: RequestPriority,
+ pub(crate) path: &'a [u8],
+ pub(crate) telemetry_id: Option<Vec<u8>>,
+ pub(crate) body: &'a [u8],
+}
+
+/// QueryMessage encoding:
+/// - priority: u8
+/// - path length: u8
+/// - path: [u8; path length]
+/// - telemetry id length: u8
+/// - telemetry id: [u8; telemetry id length]
+/// - body [u8; ..]
+impl<'a> QueryMessage<'a> {
+ pub(crate) fn encode(self) -> Vec<u8> {
+ let tel_len = match &self.telemetry_id {
+ Some(t) => t.len(),
+ None => 0,
+ };
+
+ let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
+
+ ret.push(self.prio);
+
+ ret.push(self.path.len() as u8);
+ ret.extend_from_slice(self.path);
+
+ if let Some(t) = self.telemetry_id {
+ ret.push(t.len() as u8);
+ ret.extend(t);
+ } else {
+ ret.push(0u8);
+ }
+
+ ret.extend_from_slice(self.body);
+
+ ret
+ }
+
+ pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
+ if bytes.len() < 3 {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let path_length = bytes[1] as usize;
+ if bytes.len() < 3 + path_length {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let telemetry_id_len = bytes[2 + path_length] as usize;
+ if bytes.len() < 3 + path_length + telemetry_id_len {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let path = &bytes[2..2 + path_length];
+ let telemetry_id = if telemetry_id_len > 0 {
+ Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
+ } else {
+ None
+ };
+
+ let body = &bytes[3 + path_length + telemetry_id_len..];
+
+ Ok(Self {
+ prio: bytes[0],
+ path,
+ telemetry_id,
+ body,
+ })
+ }
+}
+
+pub(crate) struct Framing {
+ direct: Vec<u8>,
+ stream: Option<ByteStream>,
+}
+
+impl Framing {
+ pub fn new(direct: Vec<u8>, stream: Option<ByteStream>) -> Self {
+ assert!(direct.len() <= u32::MAX as usize);
+ Framing { direct, stream }
+ }
+
+ pub fn into_stream(self) -> ByteStream {
+ use futures::stream;
+ let len = self.direct.len() as u32;
+ // required because otherwise the borrow-checker complains
+ let Framing { direct, stream } = self;
+
+ let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
+ .chain(stream::once(async move { Ok(direct) }));
+
+ if let Some(stream) = stream {
+ Box::pin(res.chain(stream))
+ } else {
+ Box::pin(res)
+ }
+ }
+
+ pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
+ mut stream: S,
+ ) -> Result<Self, Error> {
+ let mut packet = stream
+ .next()
+ .await
+ .ok_or(Error::Framing)?
+ .map_err(|_| Error::Framing)?;
+ if packet.len() < 4 {
+ return Err(Error::Framing);
+ }
+
+ let mut len = [0; 4];
+ len.copy_from_slice(&packet[..4]);
+ let len = u32::from_be_bytes(len);
+ packet.drain(..4);
+
+ let mut buffer = Vec::new();
+ let len = len as usize;
+ loop {
+ let max_cp = std::cmp::min(len - buffer.len(), packet.len());
+
+ buffer.extend_from_slice(&packet[..max_cp]);
+ if buffer.len() == len {
+ packet.drain(..max_cp);
+ break;
+ }
+ packet = stream
+ .next()
+ .await
+ .ok_or(Error::Framing)?
+ .map_err(|_| Error::Framing)?;
+ }
+
+ let stream: ByteStream = if packet.is_empty() {
+ Box::pin(stream)
+ } else {
+ Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
+ };
+
+ Ok(Framing {
+ direct: buffer,
+ stream: Some(stream),
+ })
+ }
+
+ pub fn into_parts(self) -> (Vec<u8>, ByteStream) {
+ let Framing { direct, stream } = self;
+ (direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
+ }
+}
diff --git a/src/netapp.rs b/src/netapp.rs
index 27f17e6..dd22d90 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -20,7 +20,7 @@ use tokio::sync::{mpsc, watch};
use crate::client::*;
use crate::endpoint::*;
use crate::error::*;
-use crate::proto::*;
+use crate::message::*;
use crate::server::*;
use crate::util::*;
diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs
index 7f77995..98977a3 100644
--- a/src/peering/basalt.rs
+++ b/src/peering/basalt.rs
@@ -14,8 +14,9 @@ use sodiumoxide::crypto::hash;
use tokio::sync::watch;
use crate::endpoint::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
+use crate::send::*;
use crate::NodeID;
// -- Protocol messages --
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index 7dfc5c4..5b489ae 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -17,7 +17,8 @@ use sodiumoxide::crypto::hash;
use crate::endpoint::*;
use crate::error::*;
use crate::netapp::*;
-use crate::proto::*;
+
+use crate::message::*;
use crate::NodeID;
const CONN_RETRY_INTERVAL: Duration = Duration::from_secs(30);
diff --git a/src/proto2.rs b/src/proto2.rs
deleted file mode 100644
index 7210781..0000000
--- a/src/proto2.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-use crate::error::*;
-use crate::proto::*;
-
-pub(crate) struct QueryMessage<'a> {
- pub(crate) prio: RequestPriority,
- pub(crate) path: &'a [u8],
- pub(crate) telemetry_id: Option<Vec<u8>>,
- pub(crate) body: &'a [u8],
-}
-
-/// QueryMessage encoding:
-/// - priority: u8
-/// - path length: u8
-/// - path: [u8; path length]
-/// - telemetry id length: u8
-/// - telemetry id: [u8; telemetry id length]
-/// - body [u8; ..]
-impl<'a> QueryMessage<'a> {
- pub(crate) fn encode(self) -> Vec<u8> {
- let tel_len = match &self.telemetry_id {
- Some(t) => t.len(),
- None => 0,
- };
-
- let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
-
- ret.push(self.prio);
-
- ret.push(self.path.len() as u8);
- ret.extend_from_slice(self.path);
-
- if let Some(t) = self.telemetry_id {
- ret.push(t.len() as u8);
- ret.extend(t);
- } else {
- ret.push(0u8);
- }
-
- ret.extend_from_slice(self.body);
-
- ret
- }
-
- pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
- if bytes.len() < 3 {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- let path_length = bytes[1] as usize;
- if bytes.len() < 3 + path_length {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- let telemetry_id_len = bytes[2 + path_length] as usize;
- if bytes.len() < 3 + path_length + telemetry_id_len {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- let path = &bytes[2..2 + path_length];
- let telemetry_id = if telemetry_id_len > 0 {
- Some(bytes[3 + path_length..3 + path_length + telemetry_id_len].to_vec())
- } else {
- None
- };
-
- let body = &bytes[3 + path_length + telemetry_id_len..];
-
- Ok(Self {
- prio: bytes[0],
- path,
- telemetry_id,
- body,
- })
- }
-}
diff --git a/src/recv.rs b/src/recv.rs
new file mode 100644
index 0000000..628612b
--- /dev/null
+++ b/src/recv.rs
@@ -0,0 +1,114 @@
+use std::collections::HashMap;
+
+use std::sync::Arc;
+
+use log::trace;
+
+use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
+use futures::AsyncReadExt;
+
+use async_trait::async_trait;
+
+use crate::error::*;
+
+use crate::send::*;
+use crate::util::Packet;
+
+/// Structure to warn when the sender is dropped before end of stream was reached, like when
+/// connection to some remote drops while transmitting data
+struct Sender {
+ inner: UnboundedSender<Packet>,
+ closed: bool,
+}
+
+impl Sender {
+ fn new(inner: UnboundedSender<Packet>) -> Self {
+ Sender {
+ inner,
+ closed: false,
+ }
+ }
+
+ fn send(&self, packet: Packet) {
+ let _ = self.inner.unbounded_send(packet);
+ }
+
+ fn end(&mut self) {
+ self.closed = true;
+ }
+}
+
+impl Drop for Sender {
+ fn drop(&mut self) {
+ if !self.closed {
+ self.send(Err(255));
+ }
+ self.inner.close_channel();
+ }
+}
+
+/// The RecvLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
+/// and a prototype of a handler for received messages `.recv_handler()` that
+/// must be filled by implementors. `.recv_loop()` receives messages in a loop
+/// according to the protocol defined above: chunks of message in progress of being
+/// received are stored in a buffer, and when the last chunk of a message is received,
+/// the full message is passed to the receive handler.
+#[async_trait]
+pub(crate) trait RecvLoop: Sync + 'static {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
+
+ async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
+ where
+ R: AsyncReadExt + Unpin + Send + Sync,
+ {
+ let mut streams: HashMap<RequestID, Sender> = HashMap::new();
+ loop {
+ trace!("recv_loop: reading packet");
+ let mut header_id = [0u8; RequestID::BITS as usize / 8];
+ match read.read_exact(&mut header_id[..]).await {
+ Ok(_) => (),
+ Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
+ Err(e) => return Err(e.into()),
+ };
+ let id = RequestID::from_be_bytes(header_id);
+ trace!("recv_loop: got header id: {:04x}", id);
+
+ let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
+ read.read_exact(&mut header_size[..]).await?;
+ let size = ChunkLength::from_be_bytes(header_size);
+ trace!("recv_loop: got header size: {:04x}", size);
+
+ let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
+ let is_error = (size & ERROR_MARKER) != 0;
+ let packet = if is_error {
+ Err(size as u8)
+ } else {
+ let size = size & !CHUNK_HAS_CONTINUATION;
+ let mut next_slice = vec![0; size as usize];
+ read.read_exact(&mut next_slice[..]).await?;
+ trace!("recv_loop: read {} bytes", next_slice.len());
+ Ok(next_slice)
+ };
+
+ let mut sender = if let Some(send) = streams.remove(&(id)) {
+ send
+ } else {
+ let (send, recv) = unbounded();
+ self.recv_handler(id, recv);
+ Sender::new(send)
+ };
+
+ // if we get an error, the receiving end is disconnected. We still need to
+ // reach eos before dropping this sender
+ sender.send(packet);
+
+ if has_cont {
+ streams.insert(id, sender);
+ } else {
+ sender.end();
+ }
+ }
+ Ok(())
+ }
+}
diff --git a/src/proto.rs b/src/send.rs
index 92d8d80..330d41d 100644
--- a/src/proto.rs
+++ b/src/send.rs
@@ -1,48 +1,19 @@
-use std::collections::{HashMap, VecDeque};
+use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
+use async_trait::async_trait;
use log::trace;
-use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
-use futures::{AsyncReadExt, AsyncWriteExt};
-use futures::{Stream, StreamExt};
+use futures::AsyncWriteExt;
+use futures::Stream;
use kuska_handshake::async_std::BoxStreamWrite;
-
use tokio::sync::mpsc;
-use async_trait::async_trait;
-
use crate::error::*;
-use crate::util::{AssociatedStream, Packet};
-
-/// Priority of a request (click to read more about priorities).
-///
-/// This priority value is used to priorize messages
-/// in the send queue of the client, and their responses in the send queue of the
-/// server. Lower values mean higher priority.
-///
-/// This mechanism is usefull for messages bigger than the maximum chunk size
-/// (set at `0x4000` bytes), such as large file transfers.
-/// In such case, all of the messages in the send queue with the highest priority
-/// will take turns to send individual chunks, in a round-robin fashion.
-/// Once all highest priority messages are sent successfully, the messages with
-/// the next highest priority will begin being sent in the same way.
-///
-/// The same priority value is given to a request and to its associated response.
-pub type RequestPriority = u8;
-
-/// Priority class: high
-pub const PRIO_HIGH: RequestPriority = 0x20;
-/// Priority class: normal
-pub const PRIO_NORMAL: RequestPriority = 0x40;
-/// Priority class: background
-pub const PRIO_BACKGROUND: RequestPriority = 0x80;
-/// Priority: primary among given class
-pub const PRIO_PRIMARY: RequestPriority = 0x00;
-/// Priority: secondary among given class (ex: `PRIO_HIGH | PRIO_SECONDARY`)
-pub const PRIO_SECONDARY: RequestPriority = 0x01;
+use crate::message::*;
+use crate::util::{ByteStream, Packet};
// Messages are sent by chunks
// Chunk format:
@@ -52,10 +23,10 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
// - [u8; chunk_length] chunk data
pub(crate) type RequestID = u32;
-type ChunkLength = u16;
-const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
-const ERROR_MARKER: ChunkLength = 0x4000;
-const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
+pub(crate) type ChunkLength = u16;
+pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0;
+pub(crate) const ERROR_MARKER: ChunkLength = 0x4000;
+pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
struct SendQueueItem {
id: RequestID,
@@ -66,15 +37,15 @@ struct SendQueueItem {
#[pin_project::pin_project]
struct DataReader {
#[pin]
- reader: AssociatedStream,
+ reader: ByteStream,
packet: Packet,
pos: usize,
buf: Vec<u8>,
eos: bool,
}
-impl From<AssociatedStream> for DataReader {
- fn from(data: AssociatedStream) -> DataReader {
+impl From<ByteStream> for DataReader {
+ fn from(data: ByteStream) -> DataReader {
DataReader {
reader: data,
packet: Ok(Vec::new()),
@@ -297,7 +268,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> {
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
self: Arc<Self>,
- mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>,
+ mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, ByteStream)>,
mut write: BoxStreamWrite<W>,
) -> Result<(), Error>
where
@@ -343,184 +314,6 @@ pub(crate) trait SendLoop: Sync {
}
}
-pub(crate) struct Framing {
- direct: Vec<u8>,
- stream: Option<AssociatedStream>,
-}
-
-impl Framing {
- pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self {
- assert!(direct.len() <= u32::MAX as usize);
- Framing { direct, stream }
- }
-
- pub fn into_stream(self) -> AssociatedStream {
- use futures::stream;
- let len = self.direct.len() as u32;
- // required because otherwise the borrow-checker complains
- let Framing { direct, stream } = self;
-
- let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) })
- .chain(stream::once(async move { Ok(direct) }));
-
- if let Some(stream) = stream {
- Box::pin(res.chain(stream))
- } else {
- Box::pin(res)
- }
- }
-
- pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
- mut stream: S,
- ) -> Result<Self, Error> {
- let mut packet = stream
- .next()
- .await
- .ok_or(Error::Framing)?
- .map_err(|_| Error::Framing)?;
- if packet.len() < 4 {
- return Err(Error::Framing);
- }
-
- let mut len = [0; 4];
- len.copy_from_slice(&packet[..4]);
- let len = u32::from_be_bytes(len);
- packet.drain(..4);
-
- let mut buffer = Vec::new();
- let len = len as usize;
- loop {
- let max_cp = std::cmp::min(len - buffer.len(), packet.len());
-
- buffer.extend_from_slice(&packet[..max_cp]);
- if buffer.len() == len {
- packet.drain(..max_cp);
- break;
- }
- packet = stream
- .next()
- .await
- .ok_or(Error::Framing)?
- .map_err(|_| Error::Framing)?;
- }
-
- let stream: AssociatedStream = if packet.is_empty() {
- Box::pin(stream)
- } else {
- Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream))
- };
-
- Ok(Framing {
- direct: buffer,
- stream: Some(stream),
- })
- }
-
- pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) {
- let Framing { direct, stream } = self;
- (direct, stream.unwrap_or(Box::pin(futures::stream::empty())))
- }
-}
-
-/// Structure to warn when the sender is dropped before end of stream was reached, like when
-/// connection to some remote drops while transmitting data
-struct Sender {
- inner: UnboundedSender<Packet>,
- closed: bool,
-}
-
-impl Sender {
- fn new(inner: UnboundedSender<Packet>) -> Self {
- Sender {
- inner,
- closed: false,
- }
- }
-
- fn send(&self, packet: Packet) {
- let _ = self.inner.unbounded_send(packet);
- }
-
- fn end(&mut self) {
- self.closed = true;
- }
-}
-
-impl Drop for Sender {
- fn drop(&mut self) {
- if !self.closed {
- self.send(Err(255));
- }
- self.inner.close_channel();
- }
-}
-
-/// The RecvLoop trait, which is implemented both by the client and the server
-/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
-/// and a prototype of a handler for received messages `.recv_handler()` that
-/// must be filled by implementors. `.recv_loop()` receives messages in a loop
-/// according to the protocol defined above: chunks of message in progress of being
-/// received are stored in a buffer, and when the last chunk of a message is received,
-/// the full message is passed to the receive handler.
-#[async_trait]
-pub(crate) trait RecvLoop: Sync + 'static {
- fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
-
- async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
- where
- R: AsyncReadExt + Unpin + Send + Sync,
- {
- let mut streams: HashMap<RequestID, Sender> = HashMap::new();
- loop {
- trace!("recv_loop: reading packet");
- let mut header_id = [0u8; RequestID::BITS as usize / 8];
- match read.read_exact(&mut header_id[..]).await {
- Ok(_) => (),
- Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
- Err(e) => return Err(e.into()),
- };
- let id = RequestID::from_be_bytes(header_id);
- trace!("recv_loop: got header id: {:04x}", id);
-
- let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
- read.read_exact(&mut header_size[..]).await?;
- let size = ChunkLength::from_be_bytes(header_size);
- trace!("recv_loop: got header size: {:04x}", size);
-
- let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
- let is_error = (size & ERROR_MARKER) != 0;
- let packet = if is_error {
- Err(size as u8)
- } else {
- let size = size & !CHUNK_HAS_CONTINUATION;
- let mut next_slice = vec![0; size as usize];
- read.read_exact(&mut next_slice[..]).await?;
- trace!("recv_loop: read {} bytes", next_slice.len());
- Ok(next_slice)
- };
-
- let mut sender = if let Some(send) = streams.remove(&(id)) {
- send
- } else {
- let (send, recv) = unbounded();
- self.recv_handler(id, recv);
- Sender::new(send)
- };
-
- // if we get an error, the receiving end is disconnected. We still need to
- // reach eos before dropping this sender
- sender.send(packet);
-
- if has_cont {
- streams.insert(id, sender);
- } else {
- sender.end();
- }
- }
- Ok(())
- }
-}
-
#[cfg(test)]
mod test {
use super::*;
diff --git a/src/server.rs b/src/server.rs
index 8075484..1f1c22a 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -2,8 +2,17 @@ use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
use log::{debug, trace};
+use futures::channel::mpsc::UnboundedReceiver;
+use futures::io::{AsyncReadExt, AsyncWriteExt};
+use kuska_handshake::async_std::{handshake_server, BoxStream};
+use tokio::net::TcpStream;
+use tokio::select;
+use tokio::sync::{mpsc, watch};
+use tokio_util::compat::*;
+
#[cfg(feature = "telemetry")]
use opentelemetry::{
trace::{FutureExt, Span, SpanKind, TraceContextExt, TraceId, Tracer},
@@ -14,22 +23,11 @@ use opentelemetry_contrib::trace::propagator::binary::*;
#[cfg(feature = "telemetry")]
use rand::{thread_rng, Rng};
-use tokio::net::TcpStream;
-use tokio::select;
-use tokio::sync::{mpsc, watch};
-use tokio_util::compat::*;
-
-use futures::channel::mpsc::UnboundedReceiver;
-use futures::io::{AsyncReadExt, AsyncWriteExt};
-
-use async_trait::async_trait;
-
-use kuska_handshake::async_std::{handshake_server, BoxStream};
-
use crate::error::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
-use crate::proto2::*;
+use crate::recv::*;
+use crate::send::*;
use crate::util::*;
// The client and server connection structs (client.rs and server.rs)
@@ -55,7 +53,7 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
- resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>,
+ resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
}
impl ServerConn {
@@ -126,8 +124,8 @@ impl ServerConn {
async fn recv_handler_aux(
self: &Arc<Self>,
bytes: &[u8],
- stream: AssociatedStream,
- ) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
+ stream: ByteStream,
+ ) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
let msg = QueryMessage::decode(bytes)?;
let path = String::from_utf8(msg.path.to_vec())?;
diff --git a/src/util.rs b/src/util.rs
index 186678d..6fbafe6 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,17 +1,15 @@
-use crate::endpoint::SerializeMessage;
-
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
-use futures::Stream;
-
use log::info;
-
use serde::Serialize;
+use futures::Stream;
use tokio::sync::watch;
+use crate::message::SerializeMessage;
+
/// A node's identifier, which is also its public cryptographic key
pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
/// A node's secret key
@@ -27,7 +25,7 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key;
///
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// meaning, it's up to your application to define their semantic.
-pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
+pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send>>;
pub type Packet = Result<Vec<u8>, u8>;
@@ -38,7 +36,7 @@ pub type Packet = Result<Vec<u8>, u8>;
/// This is used internally by the netapp communication protocol.
pub fn rmp_to_vec_all_named<T>(
val: &T,
-) -> Result<(Vec<u8>, Option<AssociatedStream>), rmp_serde::encode::Error>
+) -> Result<(Vec<u8>, Option<ByteStream>), rmp_serde::encode::Error>
where
T: SerializeMessage + ?Sized,
{