aboutsummaryrefslogtreecommitdiff
path: root/src/proto.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/proto.rs')
-rw-r--r--src/proto.rs59
1 files changed, 47 insertions, 12 deletions
diff --git a/src/proto.rs b/src/proto.rs
index d6dc35a..92d8d80 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -5,7 +5,7 @@ use std::task::{Context, Poll};
use log::trace;
-use futures::channel::mpsc::{unbounded, UnboundedSender};
+use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::{AsyncReadExt, AsyncWriteExt};
use futures::{Stream, StreamExt};
use kuska_handshake::async_std::BoxStreamWrite;
@@ -15,7 +15,7 @@ use tokio::sync::mpsc;
use async_trait::async_trait;
use crate::error::*;
-use crate::util::AssociatedStream;
+use crate::util::{AssociatedStream, Packet};
/// Priority of a request (click to read more about priorities).
///
@@ -67,7 +67,7 @@ struct SendQueueItem {
struct DataReader {
#[pin]
reader: AssociatedStream,
- packet: Result<Vec<u8>, u8>,
+ packet: Packet,
pos: usize,
buf: Vec<u8>,
eos: bool,
@@ -370,7 +370,7 @@ impl Framing {
}
}
- pub async fn from_stream<S: Stream<Item = Result<Vec<u8>, u8>> + Unpin + Send + 'static>(
+ pub async fn from_stream<S: Stream<Item = Packet> + Unpin + Send + 'static>(
mut stream: S,
) -> Result<Self, Error> {
let mut packet = stream
@@ -422,6 +422,39 @@ impl Framing {
}
}
+/// Structure to warn when the sender is dropped before end of stream was reached, like when
+/// connection to some remote drops while transmitting data
+struct Sender {
+ inner: UnboundedSender<Packet>,
+ closed: bool,
+}
+
+impl Sender {
+ fn new(inner: UnboundedSender<Packet>) -> Self {
+ Sender {
+ inner,
+ closed: false,
+ }
+ }
+
+ fn send(&self, packet: Packet) {
+ let _ = self.inner.unbounded_send(packet);
+ }
+
+ fn end(&mut self) {
+ self.closed = true;
+ }
+}
+
+impl Drop for Sender {
+ fn drop(&mut self) {
+ if !self.closed {
+ self.send(Err(255));
+ }
+ self.inner.close_channel();
+ }
+}
+
/// The RecvLoop trait, which is implemented both by the client and the server
/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
/// and a prototype of a handler for received messages `.recv_handler()` that
@@ -431,13 +464,13 @@ impl Framing {
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
- fn recv_handler(self: &Arc<Self>, id: RequestID, stream: AssociatedStream);
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
R: AsyncReadExt + Unpin + Send + Sync,
{
- let mut streams: HashMap<RequestID, UnboundedSender<Result<Vec<u8>, u8>>> = HashMap::new();
+ let mut streams: HashMap<RequestID, Sender> = HashMap::new();
loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; RequestID::BITS as usize / 8];
@@ -466,20 +499,22 @@ pub(crate) trait RecvLoop: Sync + 'static {
Ok(next_slice)
};
- let sender = if let Some(send) = streams.remove(&(id)) {
+ let mut sender = if let Some(send) = streams.remove(&(id)) {
send
} else {
let (send, recv) = unbounded();
- self.recv_handler(id, Box::pin(recv));
- send
+ self.recv_handler(id, recv);
+ Sender::new(send)
};
// if we get an error, the receiving end is disconnected. We still need to
// reach eos before dropping this sender
- let _ = sender.unbounded_send(packet);
+ sender.send(packet);
if has_cont {
streams.insert(id, sender);
+ } else {
+ sender.end();
}
}
Ok(())
@@ -491,9 +526,9 @@ mod test {
use super::*;
fn empty_data() -> DataReader {
- type Item = Result<Vec<u8>, u8>;
+ type Item = Packet;
let stream: Pin<Box<dyn futures::Stream<Item = Item> + Send + 'static>> =
- Box::pin(futures::stream::empty::<Result<Vec<u8>, u8>>());
+ Box::pin(futures::stream::empty::<Packet>());
stream.into()
}