diff options
Diffstat (limited to 'src/proto.rs')
-rw-r--r-- | src/proto.rs | 59 |
1 files changed, 47 insertions, 12 deletions
diff --git a/src/proto.rs b/src/proto.rs index d6dc35a..92d8d80 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedSender}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; -use crate::util::AssociatedStream; +use crate::util::{AssociatedStream, Packet}; /// Priority of a request (click to read more about priorities). /// @@ -67,7 +67,7 @@ struct SendQueueItem { struct DataReader { #[pin] reader: AssociatedStream, - packet: Result<Vec<u8>, u8>, + packet: Packet, pos: usize, buf: Vec<u8>, eos: bool, @@ -370,7 +370,7 @@ impl Framing { } } - pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>( + pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>( mut stream: S, ) -> Result<Self, Error> { let mut packet = stream @@ -422,6 +422,39 @@ impl Framing { } } +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender<Packet>, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender<Packet>) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -431,13 +464,13 @@ impl Framing { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>); async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new(); + let mut streams: HashMap<RequestID, Sender> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static { Ok(next_slice) }; - let sender = if let Some(send) = streams.remove(&(id)) { + let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { let (send, recv) = unbounded(); - self.recv_handler(id, Box::pin(recv)); - send + self.recv_handler(id, recv); + Sender::new(send) }; // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - let _ = sender.unbounded_send(packet); + sender.send(packet); if has_cont { streams.insert(id, sender); + } else { + sender.end(); } } Ok(()) @@ -491,9 +526,9 @@ mod test { use super::*; fn empty_data() -> DataReader { - type Item = Result<Vec<u8>, u8>; + type Item = Packet; let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> = - Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>()); + Box::pin(futures::stream::empty::<Packet>()); stream.into() } |