aboutsummaryrefslogtreecommitdiff
path: root/src/recv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/recv.rs')
-rw-r--r--src/recv.rs114
1 files changed, 114 insertions, 0 deletions
diff --git a/src/recv.rs b/src/recv.rs
new file mode 100644
index 0000000..628612b
--- /dev/null
+++ b/src/recv.rs
@@ -0,0 +1,114 @@
+use std::collections::HashMap;
+
+use std::sync::Arc;
+
+use log::trace;
+
+use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
+use futures::AsyncReadExt;
+
+use async_trait::async_trait;
+
+use crate::error::*;
+
+use crate::send::*;
+use crate::util::Packet;
+
+/// Structure to warn when the sender is dropped before end of stream was reached, like when
+/// connection to some remote drops while transmitting data
+struct Sender {
+ inner: UnboundedSender<Packet>,
+ closed: bool,
+}
+
+impl Sender {
+ fn new(inner: UnboundedSender<Packet>) -> Self {
+ Sender {
+ inner,
+ closed: false,
+ }
+ }
+
+ fn send(&self, packet: Packet) {
+ let _ = self.inner.unbounded_send(packet);
+ }
+
+ fn end(&mut self) {
+ self.closed = true;
+ }
+}
+
+impl Drop for Sender {
+ fn drop(&mut self) {
+ if !self.closed {
+ self.send(Err(255));
+ }
+ self.inner.close_channel();
+ }
+}
+
+/// The RecvLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
+/// and a prototype of a handler for received messages `.recv_handler()` that
+/// must be filled by implementors. `.recv_loop()` receives messages in a loop
+/// according to the protocol defined above: chunks of message in progress of being
+/// received are stored in a buffer, and when the last chunk of a message is received,
+/// the full message is passed to the receive handler.
+#[async_trait]
+pub(crate) trait RecvLoop: Sync + 'static {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
+
+ async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
+ where
+ R: AsyncReadExt + Unpin + Send + Sync,
+ {
+ let mut streams: HashMap<RequestID, Sender> = HashMap::new();
+ loop {
+ trace!("recv_loop: reading packet");
+ let mut header_id = [0u8; RequestID::BITS as usize / 8];
+ match read.read_exact(&mut header_id[..]).await {
+ Ok(_) => (),
+ Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
+ Err(e) => return Err(e.into()),
+ };
+ let id = RequestID::from_be_bytes(header_id);
+ trace!("recv_loop: got header id: {:04x}", id);
+
+ let mut header_size = [0u8; ChunkLength::BITS as usize / 8];
+ read.read_exact(&mut header_size[..]).await?;
+ let size = ChunkLength::from_be_bytes(header_size);
+ trace!("recv_loop: got header size: {:04x}", size);
+
+ let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0;
+ let is_error = (size & ERROR_MARKER) != 0;
+ let packet = if is_error {
+ Err(size as u8)
+ } else {
+ let size = size & !CHUNK_HAS_CONTINUATION;
+ let mut next_slice = vec![0; size as usize];
+ read.read_exact(&mut next_slice[..]).await?;
+ trace!("recv_loop: read {} bytes", next_slice.len());
+ Ok(next_slice)
+ };
+
+ let mut sender = if let Some(send) = streams.remove(&(id)) {
+ send
+ } else {
+ let (send, recv) = unbounded();
+ self.recv_handler(id, recv);
+ Sender::new(send)
+ };
+
+ // if we get an error, the receiving end is disconnected. We still need to
+ // reach eos before dropping this sender
+ sender.send(packet);
+
+ if has_cont {
+ streams.insert(id, sender);
+ } else {
+ sender.end();
+ }
+ }
+ Ok(())
+ }
+}