diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 7 | ||||
-rw-r--r-- | src/proto.rs | 59 | ||||
-rw-r--r-- | src/server.rs | 3 | ||||
-rw-r--r-- | src/util.rs | 8 |
4 files changed, 58 insertions, 19 deletions
diff --git a/src/client.rs b/src/client.rs index a630f87..6d49f5c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,7 @@ use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; use log::{debug, error, trace}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -41,7 +42,7 @@ pub(crate) struct ClientConn { ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, AssociatedStream)>>, next_query_number: AtomicU32, - inflight: Mutex<HashMap<RequestID, oneshot::Sender<AssociatedStream>>>, + inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>, } impl ClientConn { @@ -186,7 +187,7 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(Box::pin(futures::stream::empty())).is_err() { + if old_ch.send(unbounded().1).is_err() { debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); } } @@ -232,7 +233,7 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) { trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); diff --git a/src/proto.rs b/src/proto.rs index d6dc35a..92d8d80 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use log::trace; -use futures::channel::mpsc::{unbounded, UnboundedSender}; +use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::{AsyncReadExt, AsyncWriteExt}; use futures::{Stream, StreamExt}; use kuska_handshake::async_std::BoxStreamWrite; @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use async_trait::async_trait; use crate::error::*; -use crate::util::AssociatedStream; +use crate::util::{AssociatedStream, Packet}; /// Priority of a request (click to read more about priorities). /// @@ -67,7 +67,7 @@ struct SendQueueItem { struct DataReader { #[pin] reader: AssociatedStream, - packet: Result<Vec<u8>, u8>, + packet: Packet, pos: usize, buf: Vec<u8>, eos: bool, @@ -370,7 +370,7 @@ impl Framing { } } - pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>( + pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>( mut stream: S, ) -> Result<Self, Error> { let mut packet = stream @@ -422,6 +422,39 @@ impl Framing { } } +/// Structure to warn when the sender is dropped before end of stream was reached, like when +/// connection to some remote drops while transmitting data +struct Sender { + inner: UnboundedSender<Packet>, + closed: bool, +} + +impl Sender { + fn new(inner: UnboundedSender<Packet>) -> Self { + Sender { + inner, + closed: false, + } + } + + fn send(&self, packet: Packet) { + let _ = self.inner.unbounded_send(packet); + } + + fn end(&mut self) { + self.closed = true; + } +} + +impl Drop for Sender { + fn drop(&mut self) { + if !self.closed { + self.send(Err(255)); + } + self.inner.close_channel(); + } +} + /// The RecvLoop trait, which is implemented both by the client and the server /// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()` /// and a prototype of a handler for received messages `.recv_handler()` that @@ -431,13 +464,13 @@ impl Framing { /// the full message is passed to the receive handler. #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { - fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>); async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> where R: AsyncReadExt + Unpin + Send + Sync, { - let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new(); + let mut streams: HashMap<RequestID, Sender> = HashMap::new(); loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; RequestID::BITS as usize / 8]; @@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static { Ok(next_slice) }; - let sender = if let Some(send) = streams.remove(&(id)) { + let mut sender = if let Some(send) = streams.remove(&(id)) { send } else { let (send, recv) = unbounded(); - self.recv_handler(id, Box::pin(recv)); - send + self.recv_handler(id, recv); + Sender::new(send) }; // if we get an error, the receiving end is disconnected. We still need to // reach eos before dropping this sender - let _ = sender.unbounded_send(packet); + sender.send(packet); if has_cont { streams.insert(id, sender); + } else { + sender.end(); } } Ok(()) @@ -491,9 +526,9 @@ mod test { use super::*; fn empty_data() -> DataReader { - type Item = Result<Vec<u8>, u8>; + type Item = Packet; let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> = - Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>()); + Box::pin(futures::stream::empty::<Packet>()); stream.into() } diff --git a/src/server.rs b/src/server.rs index 86e5156..8075484 100644 --- a/src/server.rs +++ b/src/server.rs @@ -19,6 +19,7 @@ use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; +use futures::channel::mpsc::UnboundedReceiver; use futures::io::{AsyncReadExt, AsyncWriteExt}; use async_trait::async_trait; @@ -176,7 +177,7 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream) { + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) { let resp_send = self.resp_send.load_full().unwrap(); let self2 = self.clone(); diff --git a/src/util.rs b/src/util.rs index 76d7ecf..186678d 100644 --- a/src/util.rs +++ b/src/util.rs @@ -25,9 +25,11 @@ pub type NetworkKey = sodiumoxide::crypto::auth::Key; /// When sent through Netapp, the Vec may be split in smaller chunk in such a way /// consecutive Vec may get merged, but Vec and error code may not be reordered /// -/// The error code have no predefined meaning, it's up to you application to define their -/// semantic. -pub type AssociatedStream = Pin<Box<dyn Stream<Item = Result<Vec<u8>, u8>> + Send>>; +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type AssociatedStream = Pin<Box<dyn Stream<Item = Packet> + Send>>; + +pub type Packet = Result<Vec<u8>, u8>; /// Utility function: encodes any serializable value in MessagePack binary format /// using the RMP library. |