diff options
Diffstat (limited to 'src/client.rs')
-rw-r--r-- | src/client.rs | 68 |
1 files changed, 66 insertions, 2 deletions
diff --git a/src/client.rs b/src/client.rs index 9726125..d82c91e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,9 @@ use std::collections::HashMap; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; +use std::task::Poll; use arc_swap::ArcSwapOption; use async_trait::async_trait; @@ -9,6 +11,7 @@ use bytes::Bytes; use log::{debug, error, trace}; use futures::io::AsyncReadExt; +use futures::Stream; use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; @@ -35,7 +38,7 @@ pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption<mpsc::UnboundedSender<SendStream>>, + query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>, next_query_number: AtomicU32, inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>, @@ -193,7 +196,9 @@ impl ClientConn { #[cfg(feature = "telemetry")] span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, req_order, req_stream))?; + query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?; + + let canceller = CancelOnDrop::new(id, query_send.as_ref().clone()); cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { @@ -205,6 +210,8 @@ impl ClientConn { } } + let stream = Box::pin(canceller.for_stream(stream)); + let resp_enc = RespEnc::decode(stream).await?; debug!("client: got response to request {} (path {})", id, path); Resp::from_enc(resp_enc) @@ -223,6 +230,63 @@ impl RecvLoop for ClientConn { if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } + } else { + debug!("Got unexpected response to request {}, dropping it", id); + } + } +} + +// ---- + +struct CancelOnDrop { + id: RequestID, + query_send: mpsc::UnboundedSender<SendItem>, +} + +impl CancelOnDrop { + fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self { + Self { id, query_send } + } + fn for_stream(self, stream: ByteStream) -> CancelOnDropStream { + CancelOnDropStream { + cancel: Some(self), + stream: stream, + } + } +} + +impl Drop for CancelOnDrop { + fn drop(&mut self) { + trace!("cancelling request {}", self.id); + let _ = self.query_send.send(SendItem::Cancel(self.id)); + } +} + +#[pin_project::pin_project] +struct CancelOnDropStream { + cancel: Option<CancelOnDrop>, + #[pin] + stream: ByteStream, +} + +impl Stream for CancelOnDropStream { + type Item = Packet; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + let this = self.project(); + let res = this.stream.poll_next(cx); + if matches!(res, Poll::Ready(None)) { + if let Some(c) = this.cancel.take() { + std::mem::forget(c) + } } + res + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.stream.size_hint() } } |