aboutsummaryrefslogtreecommitdiff
path: root/src/proto.rs
diff options
context:
space:
mode:
authortrinity-1686a <trinity@deuxfleurs.fr>2022-06-05 15:33:43 +0200
committertrinity-1686a <trinity@deuxfleurs.fr>2022-06-05 15:33:43 +0200
commit368ba908794901bc793c6a087c02241be046bdf2 (patch)
tree389910f1e1476c9531a01d2e53060e1056cca266 /src/proto.rs
parent648e015e3a73b96973343e0a1f861c9ea41cc24d (diff)
downloadnetapp-368ba908794901bc793c6a087c02241be046bdf2.tar.gz
netapp-368ba908794901bc793c6a087c02241be046bdf2.zip
initial work on associated stream
still require testing, and fixing a few kinks: - sending packets > 16k truncate them - send one more packet than it could at eos - probably update documentation /!\ contains breaking changes
Diffstat (limited to 'src/proto.rs')
-rw-r--r--src/proto.rs260
1 files changed, 216 insertions, 44 deletions
diff --git a/src/proto.rs b/src/proto.rs
index e843bff..b45ff13 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -1,9 +1,13 @@
use std::collections::{HashMap, VecDeque};
+use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll};
-use log::trace;
+use log::{trace, warn};
-use futures::{AsyncReadExt, AsyncWriteExt};
+use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
+use futures::Stream;
+use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt};
use kuska_handshake::async_std::BoxStreamWrite;
use tokio::sync::mpsc;
@@ -11,6 +15,7 @@ use tokio::sync::mpsc;
use async_trait::async_trait;
use crate::error::*;
+use crate::util::AssociatedStream;
/// Priority of a request (click to read more about priorities).
///
@@ -48,14 +53,73 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01;
pub(crate) type RequestID = u32;
type ChunkLength = u16;
-const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
+pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
struct SendQueueItem {
id: RequestID,
prio: RequestPriority,
- data: Vec<u8>,
- cursor: usize,
+ data: DataReader,
+}
+
+pub(crate) enum Data {
+ Full(Vec<u8>),
+ Streaming(AssociatedStream),
+}
+
+#[pin_project::pin_project(project = DataReaderProj)]
+enum DataReader {
+ Full {
+ #[pin]
+ data: Vec<u8>,
+ pos: usize,
+ },
+ Streaming {
+ #[pin]
+ reader: AssociatedStream,
+ },
+}
+
+impl From<Data> for DataReader {
+ fn from(data: Data) -> DataReader {
+ match data {
+ Data::Full(data) => DataReader::Full { data, pos: 0 },
+ Data::Streaming(reader) => DataReader::Streaming { reader },
+ }
+ }
+}
+
+impl Stream for DataReader {
+ type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize);
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ match self.project() {
+ DataReaderProj::Full { data, pos } => {
+ let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos);
+ let end = *pos + len;
+
+ if len == 0 {
+ Poll::Ready(None)
+ } else {
+ let mut body = [0; MAX_CHUNK_LENGTH as usize];
+ body[..len].copy_from_slice(&data[*pos..end]);
+ *pos = end;
+ Poll::Ready(Some((body, len)))
+ }
+ }
+ DataReaderProj::Streaming { reader } => {
+ reader.poll_next(cx).map(|opt| {
+ opt.map(|v| {
+ let mut body = [0; MAX_CHUNK_LENGTH as usize];
+ let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len());
+ // TODO this can throw away long vec, they should be splited instead
+ body[..len].copy_from_slice(&v[..len]);
+ (body, len)
+ })
+ })
+ }
+ }
+ }
}
struct SendQueue {
@@ -108,7 +172,7 @@ impl SendQueue {
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
self: Arc<Self>,
- mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
+ mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>,
mut write: BoxStreamWrite<W>,
) -> Result<(), Error>
where
@@ -118,51 +182,78 @@ pub(crate) trait SendLoop: Sync {
let mut should_exit = false;
while !should_exit || !sending.is_empty() {
if let Ok((id, prio, data)) = msg_recv.try_recv() {
- trace!("send_loop: got {}, {} bytes", id, data.len());
+ match &data {
+ Data::Full(data) => {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ }
+ Data::Streaming(_) => {
+ trace!("send_loop: got {}, unknown size", id);
+ }
+ }
sending.push(SendQueueItem {
id,
prio,
- data,
- cursor: 0,
+ data: data.into(),
});
} else if let Some(mut item) = sending.pop() {
trace!(
- "send_loop: sending bytes for {} ({} bytes, {} already sent)",
- item.id,
- item.data.len(),
- item.cursor
+ "send_loop: sending bytes for {}",
+ item.id,
);
+
+ let data = futures::select! {
+ data = item.data.next().fuse() => data,
+ default => {
+ // nothing to send yet; re-schedule and find something else to do
+ sending.push(item);
+ continue;
+
+ // TODO if every SendQueueItem is waiting on data, use select_all to await
+ // something to do
+ // TODO find some way to not require sending empty last chunk
+ }
+ };
+
let header_id = RequestID::to_be_bytes(item.id);
write.write_all(&header_id[..]).await?;
- if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
+ let data = match data.as_ref() {
+ Some((data, len)) => &data[..*len],
+ None => &[],
+ };
+
+ if !data.is_empty() {
let size_header =
- ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
+ ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION);
write.write_all(&size_header[..]).await?;
- let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
- write.write_all(&item.data[item.cursor..new_cursor]).await?;
- item.cursor = new_cursor;
+ write.write_all(data).await?;
sending.push(item);
} else {
- let send_len = (item.data.len() - item.cursor) as ChunkLength;
-
- let size_header = ChunkLength::to_be_bytes(send_len);
+ // this is always zero for now, but may be more when above TODO get fixed
+ let size_header = ChunkLength::to_be_bytes(data.len() as u16);
write.write_all(&size_header[..]).await?;
- write.write_all(&item.data[item.cursor..]).await?;
+ write.write_all(data).await?;
}
+
write.flush().await?;
} else {
let sth = msg_recv.recv().await;
if let Some((id, prio, data)) = sth {
- trace!("send_loop: got {}, {} bytes", id, data.len());
+ match &data {
+ Data::Full(data) => {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ }
+ Data::Streaming(_) => {
+ trace!("send_loop: got {}, unknown size", id);
+ }
+ }
sending.push(SendQueueItem {
id,
prio,
- data,
- cursor: 0,
+ data: data.into(),
});
} else {
should_exit = true;
@@ -175,6 +266,41 @@ pub(crate) trait SendLoop: Sync {
}
}
+struct ChannelPair {
+ receiver: Option<UnboundedReceiver<Vec<u8>>>,
+ sender: Option<UnboundedSender<Vec<u8>>>,
+}
+
+impl ChannelPair {
+ fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> {
+ self.receiver.take()
+ }
+
+ fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> {
+ self.sender.take()
+ }
+
+ fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> {
+ self.sender.as_ref().take()
+ }
+
+ fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) {
+ if self.receiver.is_some() || self.sender.is_some() {
+ map.insert(index, self);
+ }
+ }
+}
+
+impl Default for ChannelPair {
+ fn default() -> Self {
+ let (send, recv) = unbounded();
+ ChannelPair {
+ receiver: Some(recv),
+ sender: Some(send),
+ }
+ }
+}
+
/// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
/// and a prototype of a handler for received messages `.recv_handler()` that
@@ -184,13 +310,17 @@ pub(crate) trait SendLoop: Sync {
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
- fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
+ fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
R: AsyncReadExt + Unpin + Send + Sync,
{
- let mut receiving = HashMap::new();
+ let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new();
+ let mut streams: HashMap<
+ RequestID,
+ ChannelPair,
+ > = HashMap::new();
loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; RequestID::BITS as usize / 8];
@@ -214,13 +344,43 @@ pub(crate) trait RecvLoop: Sync + 'static {
read.read_exact(&mut next_slice[..]).await?;
trace!("recv_loop: read {} bytes", next_slice.len());
- let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default();
- msg_bytes.extend_from_slice(&next_slice[..]);
+ if id & 1 == 0 {
+ // main stream
+ let mut msg_bytes = receiving.remove(&id).unwrap_or_default();
+ msg_bytes.extend_from_slice(&next_slice[..]);
- if has_cont {
- receiving.insert(id, msg_bytes);
+ if has_cont {
+ receiving.insert(id, msg_bytes);
+ } else {
+ let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default();
+
+ if let Some(receiver) = channel_pair.take_receiver() {
+ self.recv_handler(id, msg_bytes, Box::pin(receiver));
+ } else {
+ warn!("Couldn't take receiver part of stream")
+ }
+
+ channel_pair.insert_into(&mut streams, id | 1);
+ }
} else {
- self.recv_handler(id, msg_bytes);
+ // associated stream
+ let mut channel_pair = streams.remove(&(id)).unwrap_or_default();
+
+ // if we get an error, the receiving end is disconnected. We still need to
+ // reach eos before dropping this sender
+ if !next_slice.is_empty() {
+ if let Some(sender) = channel_pair.ref_sender() {
+ let _ = sender.unbounded_send(next_slice);
+ } else {
+ warn!("Couldn't take sending part of stream")
+ }
+ }
+
+ if !has_cont {
+ channel_pair.take_sender();
+ }
+
+ channel_pair.insert_into(&mut streams, id);
}
}
Ok(())
@@ -236,38 +396,50 @@ mod test {
let i1 = SendQueueItem {
id: 1,
prio: PRIO_NORMAL,
- data: vec![],
- cursor: 0,
+ data: DataReader::Full {
+ data: vec![],
+ pos: 0,
+ },
};
let i2 = SendQueueItem {
id: 2,
prio: PRIO_HIGH,
- data: vec![],
- cursor: 0,
+ data: DataReader::Full {
+ data: vec![],
+ pos: 0,
+ },
};
let i2bis = SendQueueItem {
id: 20,
prio: PRIO_HIGH,
- data: vec![],
- cursor: 0,
+ data: DataReader::Full {
+ data: vec![],
+ pos: 0,
+ },
};
let i3 = SendQueueItem {
id: 3,
prio: PRIO_HIGH | PRIO_SECONDARY,
- data: vec![],
- cursor: 0,
+ data: DataReader::Full {
+ data: vec![],
+ pos: 0,
+ },
};
let i4 = SendQueueItem {
id: 4,
prio: PRIO_BACKGROUND | PRIO_SECONDARY,
- data: vec![],
- cursor: 0,
+ data: DataReader::Full {
+ data: vec![],
+ pos: 0,
+ },
};
let i5 = SendQueueItem {
id: 5,
prio: PRIO_BACKGROUND | PRIO_PRIMARY,
- data: vec![],
- cursor: 0,
+ data: DataReader::Full {
+ data: vec![],
+ pos: 0,
+ },
};
let mut q = SendQueue::new();