diff options
author | trinity-1686a <trinity@deuxfleurs.fr> | 2022-06-05 15:33:43 +0200 |
---|---|---|
committer | trinity-1686a <trinity@deuxfleurs.fr> | 2022-06-05 15:33:43 +0200 |
commit | 368ba908794901bc793c6a087c02241be046bdf2 (patch) | |
tree | 389910f1e1476c9531a01d2e53060e1056cca266 /src/proto.rs | |
parent | 648e015e3a73b96973343e0a1f861c9ea41cc24d (diff) | |
download | netapp-368ba908794901bc793c6a087c02241be046bdf2.tar.gz netapp-368ba908794901bc793c6a087c02241be046bdf2.zip |
initial work on associated stream
still require testing, and fixing a few kinks:
- sending packets > 16k truncate them
- send one more packet than it could at eos
- probably update documentation
/!\ contains breaking changes
Diffstat (limited to 'src/proto.rs')
-rw-r--r-- | src/proto.rs | 260 |
1 files changed, 216 insertions, 44 deletions
diff --git a/src/proto.rs b/src/proto.rs index e843bff..b45ff13 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,9 +1,13 @@ use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; -use log::trace; +use log::{trace, warn}; -use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::Stream; +use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -11,6 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; +use crate::util::AssociatedStream; /// Priority of a request (click to read more about priorities). /// @@ -48,14 +53,73 @@ pub const PRIO_SECONDARY: RequestPriority = 0x01; pub(crate) type RequestID = u32; type ChunkLength = u16; -const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; +pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; struct SendQueueItem { id: RequestID, prio: RequestPriority, - data: Vec<u8>, - cursor: usize, + data: DataReader, +} + +pub(crate) enum Data { + Full(Vec<u8>), + Streaming(AssociatedStream), +} + +#[pin_project::pin_project(project = DataReaderProj)] +enum DataReader { + Full { + #[pin] + data: Vec<u8>, + pos: usize, + }, + Streaming { + #[pin] + reader: AssociatedStream, + }, +} + +impl From<Data> for DataReader { + fn from(data: Data) -> DataReader { + match data { + Data::Full(data) => DataReader::Full { data, pos: 0 }, + Data::Streaming(reader) => DataReader::Streaming { reader }, + } + } +} + +impl Stream for DataReader { + type Item = ([u8; MAX_CHUNK_LENGTH as usize], usize); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match self.project() { + DataReaderProj::Full { data, pos } => { + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); + let end = *pos + len; + + if len == 0 { + Poll::Ready(None) + } else { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + body[..len].copy_from_slice(&data[*pos..end]); + *pos = end; + Poll::Ready(Some((body, len))) + } + } + DataReaderProj::Streaming { reader } => { + reader.poll_next(cx).map(|opt| { + opt.map(|v| { + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, v.len()); + // TODO this can throw away long vec, they should be splited instead + body[..len].copy_from_slice(&v[..len]); + (body, len) + }) + }) + } + } + } } struct SendQueue { @@ -108,7 +172,7 @@ impl SendQueue { pub(crate) trait SendLoop: Sync { async fn send_loop<W>( self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, mut write: BoxStreamWrite<W>, ) -> Result<(), Error> where @@ -118,51 +182,78 @@ pub(crate) trait SendLoop: Sync { let mut should_exit = false; while !should_exit || !sending.is_empty() { if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else if let Some(mut item) = sending.pop() { trace!( - "send_loop: sending bytes for {} ({} bytes, {} already sent)", - item.id, - item.data.len(), - item.cursor + "send_loop: sending bytes for {}", + item.id, ); + + let data = futures::select! { + data = item.data.next().fuse() => data, + default => { + // nothing to send yet; re-schedule and find something else to do + sending.push(item); + continue; + + // TODO if every SendQueueItem is waiting on data, use select_all to await + // something to do + // TODO find some way to not require sending empty last chunk + } + }; + let header_id = RequestID::to_be_bytes(item.id); write.write_all(&header_id[..]).await?; - if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize { + let data = match data.as_ref() { + Some((data, len)) => &data[..*len], + None => &[], + }; + + if !data.is_empty() { let size_header = - ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION); + ChunkLength::to_be_bytes(data.len() as u16 | CHUNK_HAS_CONTINUATION); write.write_all(&size_header[..]).await?; - let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize; - write.write_all(&item.data[item.cursor..new_cursor]).await?; - item.cursor = new_cursor; + write.write_all(data).await?; sending.push(item); } else { - let send_len = (item.data.len() - item.cursor) as ChunkLength; - - let size_header = ChunkLength::to_be_bytes(send_len); + // this is always zero for now, but may be more when above TODO get fixed + let size_header = ChunkLength::to_be_bytes(data.len() as u16); write.write_all(&size_header[..]).await?; - write.write_all(&item.data[item.cursor..]).await?; + write.write_all(data).await?; } + write.flush().await?; } else { let sth = msg_recv.recv().await; if let Some((id, prio, data)) = sth { - trace!("send_loop: got {}, {} bytes", id, data.len()); + match &data { + Data::Full(data) => { + trace!("send_loop: got {}, {} bytes", id, data.len()); + } + Data::Streaming(_) => { + trace!("send_loop: got {}, unknown size", id); + } + } sending.push(SendQueueItem { id, prio, - data, - cursor: 0, + data: data.into(), }); } else { should_exit = true; @@ -175,6 +266,41 @@ pub(crate) trait SendLoop: Sync { } } +struct ChannelPair { + receiver: Option<UnboundedReceiver<Vec<u8>>>, + sender: Option<UnboundedSender<Vec<u8>>>, +} + +impl ChannelPair { + fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> { + self.receiver.take() + } + + fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> { + self.sender.take() + } + + fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> { + self.sender.as_ref().take() + } + + fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) { + if self.receiver.is_some() || self.sender.is_some() { + map.insert(index, self); + } + } +} + +impl Default for ChannelPair { + fn default() -> Self { + let (send, recv) = unbounded(); + ChannelPair { + receiver: Some(recv), + sender: Some(send), + } + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -184,13 +310,17 @@ pub(crate) trait SendLoop: Sync { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>); + fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream); async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving = HashMap::new(); + let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new(); + let mut streams: HashMap< + RequestID, + ChannelPair, + > = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -214,13 +344,43 @@ pub(crate) trait RecvLoop: Sync + 'static { read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", next_slice.len()); - let mut msg_bytes: Vec<_> = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); + if id & 1 == 0 { + // main stream + let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); + msg_bytes.extend_from_slice(&next_slice[..]); - if has_cont { - receiving.insert(id, msg_bytes); + if has_cont { + receiving.insert(id, msg_bytes); + } else { + let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); + + if let Some(receiver) = channel_pair.take_receiver() { + self.recv_handler(id, msg_bytes, Box::pin(receiver)); + } else { + warn!("Couldn't take receiver part of stream") + } + + channel_pair.insert_into(&mut streams, id | 1); + } } else { - self.recv_handler(id, msg_bytes); + // associated stream + let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + if !next_slice.is_empty() { + if let Some(sender) = channel_pair.ref_sender() { + let _ = sender.unbounded_send(next_slice); + } else { + warn!("Couldn't take sending part of stream") + } + } + + if !has_cont { + channel_pair.take_sender(); + } + + channel_pair.insert_into(&mut streams, id); } } Ok(()) @@ -236,38 +396,50 @@ mod test { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: vec![], - cursor: 0, + data: DataReader::Full { + data: vec![], + pos: 0, + }, }; let mut q = SendQueue::new(); |