aboutsummaryrefslogtreecommitdiff
path: root/src/client.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/client.rs')
-rw-r--r--src/client.rs68
1 files changed, 66 insertions, 2 deletions
diff --git a/src/client.rs b/src/client.rs
index 9726125..d82c91e 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,7 +1,9 @@
use std::collections::HashMap;
use std::net::SocketAddr;
+use std::pin::Pin;
use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex};
+use std::task::Poll;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
@@ -9,6 +11,7 @@ use bytes::Bytes;
use log::{debug, error, trace};
use futures::io::AsyncReadExt;
+use futures::Stream;
use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream;
use tokio::select;
@@ -35,7 +38,7 @@ pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID,
- query_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>,
+ query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
@@ -193,7 +196,9 @@ impl ClientConn {
#[cfg(feature = "telemetry")]
span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
- query_send.send((id, prio, req_order, req_stream))?;
+ query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?;
+
+ let canceller = CancelOnDrop::new(id, query_send.as_ref().clone());
cfg_if::cfg_if! {
if #[cfg(feature = "telemetry")] {
@@ -205,6 +210,8 @@ impl ClientConn {
}
}
+ let stream = Box::pin(canceller.for_stream(stream));
+
let resp_enc = RespEnc::decode(stream).await?;
debug!("client: got response to request {} (path {})", id, path);
Resp::from_enc(resp_enc)
@@ -223,6 +230,63 @@ impl RecvLoop for ClientConn {
if ch.send(stream).is_err() {
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
}
+ } else {
+ debug!("Got unexpected response to request {}, dropping it", id);
+ }
+ }
+}
+
+// ----
+
+struct CancelOnDrop {
+ id: RequestID,
+ query_send: mpsc::UnboundedSender<SendItem>,
+}
+
+impl CancelOnDrop {
+ fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self {
+ Self { id, query_send }
+ }
+ fn for_stream(self, stream: ByteStream) -> CancelOnDropStream {
+ CancelOnDropStream {
+ cancel: Some(self),
+ stream: stream,
+ }
+ }
+}
+
+impl Drop for CancelOnDrop {
+ fn drop(&mut self) {
+ trace!("cancelling request {}", self.id);
+ let _ = self.query_send.send(SendItem::Cancel(self.id));
+ }
+}
+
+#[pin_project::pin_project]
+struct CancelOnDropStream {
+ cancel: Option<CancelOnDrop>,
+ #[pin]
+ stream: ByteStream,
+}
+
+impl Stream for CancelOnDropStream {
+ type Item = Packet;
+
+ fn poll_next(
+ self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let this = self.project();
+ let res = this.stream.poll_next(cx);
+ if matches!(res, Poll::Ready(None)) {
+ if let Some(c) = this.cancel.take() {
+ std::mem::forget(c)
+ }
}
+ res
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.stream.size_hint()
}
}