diff options
Diffstat (limited to 'src/proto.rs')
-rw-r--r-- | src/proto.rs | 364 |
1 files changed, 167 insertions, 197 deletions
diff --git a/src/proto.rs b/src/proto.rs index e3f9be8..d6dc35a 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -3,11 +3,11 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use log::{trace, warn}; +use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; -use futures::Stream; +use futures::channel::mpsc::{unbounded, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; +use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; use tokio::sync::mpsc; @@ -63,39 +63,24 @@ struct SendQueueItem { data: DataReader, } -pub(crate) enum Data { - Full(Vec<u8>), - Streaming(AssociatedStream), +#[pin_project::pin_project] +struct DataReader { + #[pin] + reader: AssociatedStream, + packet: Result<Vec<u8>, u8>, + pos: usize, + buf: Vec<u8>, + eos: bool, } -#[pin_project::pin_project(project = DataReaderProj)] -enum DataReader { - Full { - #[pin] - data: Vec<u8>, - pos: usize, - }, - Streaming { - #[pin] - reader: AssociatedStream, - packet: Result<Vec<u8>, u8>, - pos: usize, - buf: Vec<u8>, - eos: bool, - }, -} - -impl From<Data> for DataReader { - fn from(data: Data) -> DataReader { - match data { - Data::Full(data) => DataReader::Full { data, pos: 0 }, - Data::Streaming(reader) => DataReader::Streaming { - reader, - packet: Ok(Vec::new()), - pos: 0, - buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), - eos: false, - }, +impl From<AssociatedStream> for DataReader { + fn from(data: AssociatedStream) -> DataReader { + DataReader { + reader: data, + packet: Ok(Vec::new()), + pos: 0, + buf: Vec::with_capacity(MAX_CHUNK_LENGTH as usize), + eos: false, } } } @@ -155,82 +140,60 @@ impl Stream for DataReader { type Item = DataReaderItem; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match self.project() { - DataReaderProj::Full { data, pos } => { - let len = std::cmp::min(MAX_CHUNK_LENGTH as usize, data.len() - *pos); - let end = *pos + len; - - if len == 0 { - Poll::Ready(None) - } else { - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - body[..len].copy_from_slice(&data[*pos..end]); - *pos = end; - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: end < data.len(), - })) + let mut this = self.project(); + + if *this.eos { + // eos was reached at previous call to poll_next, where a partial packet + // was returned. Now return None + return Poll::Ready(None); + } + + loop { + let packet = match this.packet { + Ok(v) => v, + Err(e) => { + let e = *e; + *this.packet = Ok(Vec::new()); + return Poll::Ready(Some(DataReaderItem { + data: DataFrame::Error(e), + may_have_more: true, + })); } + }; + let packet_left = packet.len() - *this.pos; + let buf_left = MAX_CHUNK_LENGTH as usize - this.buf.len(); + let to_read = std::cmp::min(buf_left, packet_left); + this.buf + .extend_from_slice(&packet[*this.pos..*this.pos + to_read]); + *this.pos += to_read; + if this.buf.len() == MAX_CHUNK_LENGTH as usize { + // we have a full buf, ready to send + break; } - DataReaderProj::Streaming { - mut reader, - packet: res_packet, - pos, - buf, - eos, - } => { - if *eos { - // eos was reached at previous call to poll_next, where a partial packet - // was returned. Now return None - return Poll::Ready(None); - } - loop { - let packet = match res_packet { - Ok(v) => v, - Err(e) => { - let e = *e; - *res_packet = Ok(Vec::new()); - return Poll::Ready(Some(DataReaderItem { - data: DataFrame::Error(e), - may_have_more: true, - })); - } - }; - let packet_left = packet.len() - *pos; - let buf_left = MAX_CHUNK_LENGTH as usize - buf.len(); - let to_read = std::cmp::min(buf_left, packet_left); - buf.extend_from_slice(&packet[*pos..*pos + to_read]); - *pos += to_read; - if buf.len() == MAX_CHUNK_LENGTH as usize { - // we have a full buf, ready to send - break; - } - // we don't have a full buf, packet is empty; try receive more - if let Some(p) = futures::ready!(reader.as_mut().poll_next(cx)) { - *res_packet = p; - *pos = 0; - // if buf is empty, we will loop and return the error directly. If buf - // isn't empty, send it before by breaking. - if res_packet.is_err() && !buf.is_empty() { - break; - } - } else { - *eos = true; - break; - } + // we don't have a full buf, packet is empty; try receive more + if let Some(p) = futures::ready!(this.reader.as_mut().poll_next(cx)) { + *this.packet = p; + *this.pos = 0; + // if buf is empty, we will loop and return the error directly. If buf + // isn't empty, send it before by breaking. + if this.packet.is_err() && !this.buf.is_empty() { + break; } - - let mut body = [0; MAX_CHUNK_LENGTH as usize]; - let len = buf.len(); - body[..len].copy_from_slice(buf); - buf.clear(); - Poll::Ready(Some(DataReaderItem { - data: DataFrame::Data { data: body, len }, - may_have_more: !*eos, - })) + } else { + *this.eos = true; + break; } } + + let mut body = [0; MAX_CHUNK_LENGTH as usize]; + let len = this.buf.len(); + body[..len].copy_from_slice(this.buf); + this.buf.clear(); + Poll::Ready(Some(DataReaderItem { + data: DataFrame::Data { data: body, len }, + may_have_more: !*this.eos, + })) } } @@ -334,7 +297,7 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { pub(crate) trait SendLoop: Sync { async fn send_loop<W>( self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Data)>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, AssociatedStream)>, mut write: BoxStreamWrite<W>, ) -> Result<(), Error> where @@ -380,38 +343,82 @@ pub(crate) trait SendLoop: Sync { } } -struct ChannelPair { - receiver: Option<UnboundedReceiver<Vec<u8>>>, - sender: Option<UnboundedSender<Vec<u8>>>, +pub(crate) struct Framing { + direct: Vec<u8>, + stream: Option<AssociatedStream>, } -impl ChannelPair { - fn take_receiver(&mut self) -> Option<UnboundedReceiver<Vec<u8>>> { - self.receiver.take() +impl Framing { + pub fn new(direct: Vec<u8>, stream: Option<AssociatedStream>) -> Self { + assert!(direct.len() <= u32::MAX as usize); + Framing { direct, stream } } - fn take_sender(&mut self) -> Option<UnboundedSender<Vec<u8>>> { - self.sender.take() - } + pub fn into_stream(self) -> AssociatedStream { + use futures::stream; + let len = self.direct.len() as u32; + // required because otherwise the borrow-checker complains + let Framing { direct, stream } = self; - fn ref_sender(&mut self) -> Option<&UnboundedSender<Vec<u8>>> { - self.sender.as_ref().take() - } + let res = stream::once(async move { Ok(u32::to_be_bytes(len).to_vec()) }) + .chain(stream::once(async move { Ok(direct) })); - fn insert_into(self, map: &mut HashMap<RequestID, ChannelPair>, index: RequestID) { - if self.receiver.is_some() || self.sender.is_some() { - map.insert(index, self); + if let Some(stream) = stream { + Box::pin(res.chain(stream)) + } else { + Box::pin(res) } } -} -impl Default for ChannelPair { - fn default() -> Self { - let (send, recv) = unbounded(); - ChannelPair { - receiver: Some(recv), - sender: Some(send), + pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>( + mut stream: S, + ) -> Result<Self, Error> { + let mut packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; + if packet.len() < 4 { + return Err(Error::Framing); + } + + let mut len = [0; 4]; + len.copy_from_slice(&packet[..4]); + let len = u32::from_be_bytes(len); + packet.drain(..4); + + let mut buffer = Vec::new(); + let len = len as usize; + loop { + let max_cp = std::cmp::min(len - buffer.len(), packet.len()); + + buffer.extend_from_slice(&packet[..max_cp]); + if buffer.len() == len { + packet.drain(..max_cp); + break; + } + packet = stream + .next() + .await + .ok_or(Error::Framing)? + .map_err(|_| Error::Framing)?; } + + let stream: AssociatedStream = if packet.is_empty() { + Box::pin(stream) + } else { + Box::pin(futures::stream::once(async move { Ok(packet) }).chain(stream)) + }; + + Ok(Framing { + direct: buffer, + stream: Some(stream), + }) + } + + pub fn into_parts(self) -> (Vec<u8>, AssociatedStream) { + let Framing { direct, stream } = self; + (direct, stream.unwrap_or(Box::pin(futures::stream::empty()))) } } @@ -424,14 +431,13 @@ impl Default for ChannelPair { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>, stream: AssociatedStream); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream); async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut receiving: HashMap<RequestID, Vec<u8>> = HashMap::new(); - let mut streams: HashMap<RequestID, ChannelPair> = HashMap::new(); + let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -450,55 +456,30 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; - let size = if !is_error { - size & !CHUNK_HAS_CONTINUATION + let packet = if is_error { + Err(size as u8) } else { - 0 + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) }; - // TODO propagate errors - - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; - trace!("recv_loop: read {} bytes", next_slice.len()); - - if id & 1 == 0 { - // main stream - let mut msg_bytes = receiving.remove(&id).unwrap_or_default(); - msg_bytes.extend_from_slice(&next_slice[..]); - - if has_cont { - receiving.insert(id, msg_bytes); - } else { - let mut channel_pair = streams.remove(&(id | 1)).unwrap_or_default(); - - if let Some(receiver) = channel_pair.take_receiver() { - use futures::StreamExt; - self.recv_handler(id, msg_bytes, Box::pin(receiver.map(|v| Ok(v)))); - } else { - warn!("Couldn't take receiver part of stream") - } - channel_pair.insert_into(&mut streams, id | 1); - } + let sender = if let Some(send) = streams.remove(&(id)) { + send } else { - // associated stream - let mut channel_pair = streams.remove(&(id)).unwrap_or_default(); - - // if we get an error, the receiving end is disconnected. We still need to - // reach eos before dropping this sender - if !next_slice.is_empty() { - if let Some(sender) = channel_pair.ref_sender() { - let _ = sender.unbounded_send(next_slice); - } else { - warn!("Couldn't take sending part of stream") - } - } + let (send, recv) = unbounded(); + self.recv_handler(id, Box::pin(recv)); + send + }; - if !has_cont { - channel_pair.take_sender(); - } + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + let _ = sender.unbounded_send(packet); - channel_pair.insert_into(&mut streams, id); + if has_cont { + streams.insert(id, sender); } } Ok(()) @@ -509,55 +490,44 @@ pub(crate) trait RecvLoop: Sync + 'static { mod test { use super::*; + fn empty_data() -> DataReader { + type Item = Result<Vec<u8>, u8>; + let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> = + Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>()); + stream.into() + } + #[test] fn test_priority_queue() { let i1 = SendQueueItem { id: 1, prio: PRIO_NORMAL, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2 = SendQueueItem { id: 2, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i2bis = SendQueueItem { id: 20, prio: PRIO_HIGH, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i3 = SendQueueItem { id: 3, prio: PRIO_HIGH | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i4 = SendQueueItem { id: 4, prio: PRIO_BACKGROUND | PRIO_SECONDARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let i5 = SendQueueItem { id: 5, prio: PRIO_BACKGROUND | PRIO_PRIMARY, - data: DataReader::Full { - data: vec![], - pos: 0, - }, + data: empty_data(), }; let mut q = SendQueue::new(); |