From cdff8ae1beab44a22d0eb0eb00c624e49971b6ca Mon Sep 17 00:00:00 2001 From: trinity-1686a Date: Mon, 18 Jul 2022 15:21:13 +0200 Subject: add detection of premature eos --- src/client.rs | 7 ++++--- src/proto.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++++------------ src/server.rs | 3 ++- src/util.rs | 8 +++++--- 4 files changed, 58 insertions(+), 19 deletions(-) (limited to 'src') diff --git a/src/client.rs b/src/client.rs index a630f87..6d49f5c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use log::{debug, error, trace}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -41,7 +42,7 @@ pub(crate) struct ClientConn { ArcSwapOption>, next_query_number: AtomicU32, - inflight: Mutex>>, + inflight: Mutex>>>, } impl ClientConn { @@ -186,7 +187,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(Box::pin(futures::stream::empty())).is_err() { + if old_ch.send(unbounded().1).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -232,7 +233,7 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); diff --git a/src/proto.rs b/src/proto.rs index d6dc35a..92d8d80 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedSender}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; -use crate::util::AssociatedStream; +use crate::util::{AssociatedStream, Packet}; /// Priority of a request (click to read more about priorities). /// @@ -67,7 +67,7 @@ struct SendQueueItem { struct DataReader { #[pin] reader: AssociatedStream, - packet: Result, u8>, + packet: Packet, pos: usize, buf: Vec, eos: bool, @@ -370,7 +370,7 @@ impl Framing { } } - pub async fn from_stream, u8>> + Unpin + Send + 'static>( + pub async fn from_stream + Unpin + Send + 'static>( mut stream: S, ) -> Result { let mut packet = stream @@ -422,6 +422,39 @@ impl Framing { } } +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -431,13 +464,13 @@ impl Framing { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream); + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver); async fn recv_loop(self: Arc, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut streams: HashMap, u8>>> = HashMap::new(); + let mut streams: HashMap = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static { Ok(next_slice) }; - let sender = if let Some(send) = streams.remove(&(id)) { + let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { let (send, recv) = unbounded(); - self.recv_handler(id, Box::pin(recv)); - send + self.recv_handler(id, recv); + Sender::new(send) }; // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - let _ = sender.unbounded_send(packet); + sender.send(packet); if has_cont { streams.insert(id, sender); + } else { + sender.end(); } } Ok(()) @@ -491,9 +526,9 @@ mod test { use super::*; fn empty_data() -> DataReader { - type Item = Result, u8>; + type Item = Packet; let stream: Pin + Send + 'static>> = - Box::pin(futures::stream::empty::, u8>>()); + Box::pin(futures::stream::empty::()); stream.into() } diff --git a/src/server.rs b/src/server.rs index 86e5156..8075484 100644 --- a/src/server.rs +++ b/src/server.rs @@ -19,6 +19,7 @@ use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; +use futures::channel::mpsc::UnboundedReceiver; use futures::io::{AsyncReadExt, AsyncWriteExt}; use async_trait::async_trait; @@ -176,7 +177,7 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc, id: RequestID, stream: UnboundedReceiver) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); diff --git a/src/util.rs b/src/util.rs index 76d7ecf..186678d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -25,9 +25,11 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// When sent through Netapp, the Vec may be split in smaller chunk in such a way /// consecutive Vec may get merged, but Vec and error code may not be reordered /// -/// The error code have no predefined meaning, it's up to you application to define their -/// semantic. -pub type AssociatedStream = Pin, u8>> + Send>>; +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type AssociatedStream = Pin + Send>>; + +pub type Packet = Result, u8>; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. -- cgit v1.2.3