diff options
Diffstat (limited to 'src/recv.rs')
-rw-r--r-- | src/recv.rs | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/src/recv.rs b/src/recv.rs new file mode 100644 index 0000000..628612b --- /dev/null +++ b/src/recv.rs @@ -0,0 +1,114 @@ +use std::collections::HashMap; + +use std::sync::Arc; + +use log::trace; + +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures::AsyncReadExt; + +use async_trait::async_trait; + +use crate::error::*; + +use crate::send::*; +use crate::util::Packet; + +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender<Packet>, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender<Packet>) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + +/// The RecvLoop trait, which is implemented both by the client and the server +/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` +/// and a prototype of a handler for received messages `.recv_handler()` that +/// must be filled by implementors. `.recv_loop()` receives messages in a loop +/// according to the protocol defined above: chunks of message in progress of being +/// received are stored in a buffer, and when the last chunk of a message is received, +/// the full message is passed to the receive handler. +#[async_trait] +pub(crate) trait RecvLoop: Sync + 'static { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>); + + async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> + where + R: AsyncReadExt + Unpin + Send + Sync, + { + let mut streams: HashMap<RequestID, Sender> = HashMap::new(); + loop { + trace!("recv_loop: reading packet"); + let mut header_id = [0u8; RequestID::BITS as usize / 8]; + match read.read_exact(&mut header_id[..]).await { + Ok(_) => (), + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err(e.into()), + }; + let id = RequestID::from_be_bytes(header_id); + trace!("recv_loop: got header id: {:04x}", id); + + let mut header_size = [0u8; ChunkLength::BITS as usize / 8]; + read.read_exact(&mut header_size[..]).await?; + let size = ChunkLength::from_be_bytes(header_size); + trace!("recv_loop: got header size: {:04x}", size); + + let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; + let is_error = (size & ERROR_MARKER) != 0; + let packet = if is_error { + Err(size as u8) + } else { + let size = size & !CHUNK_HAS_CONTINUATION; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + trace!("recv_loop: read {} bytes", next_slice.len()); + Ok(next_slice) + }; + + let mut sender = if let Some(send) = streams.remove(&(id)) { + send + } else { + let (send, recv) = unbounded(); + self.recv_handler(id, recv); + Sender::new(send) + }; + + // if we get an error, the receiving end is disconnected. We still need to + // reach eos before dropping this sender + sender.send(packet); + + if has_cont { + streams.insert(id, sender); + } else { + sender.end(); + } + } + Ok(()) + } +} |