diff options
Diffstat (limited to 'src/client.rs')
-rw-r--r-- | src/client.rs | 158 |
1 files changed, 101 insertions, 57 deletions
diff --git a/src/client.rs b/src/client.rs index 5c5a05b..d82c91e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,12 +1,18 @@ -use std::borrow::Borrow; use std::collections::HashMap; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; +use std::task::Poll; use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use bytes::Bytes; use log::{debug, error, trace}; +use futures::io::AsyncReadExt; +use futures::Stream; +use kuska_handshake::async_std::{handshake_client, BoxStream}; use tokio::net::TcpStream; use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; @@ -20,27 +26,22 @@ use opentelemetry::{ #[cfg(feature = "telemetry")] use opentelemetry_contrib::trace::propagator::binary::*; -use futures::io::AsyncReadExt; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_client, BoxStream}; - -use crate::endpoint::*; use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; use crate::util::*; pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, + query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>, next_query_number: AtomicU32, - inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>, + inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>, } impl ClientConn { @@ -139,15 +140,14 @@ impl ClientConn { self.query_send.store(None); } - pub(crate) async fn call<T, B>( + pub(crate) async fn call<T>( self: Arc<Self>, - rq: B, + req: Req<T>, path: &str, prio: RequestPriority, - ) -> Result<<T as Message>::Response, Error> + ) -> Result<Resp<T>, Error> where T: Message, - B: Borrow<T>, { let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; @@ -162,24 +162,16 @@ impl ClientConn { .with_kind(SpanKind::Client) .start(&tracer); let propagator = BinaryPropagator::new(); - let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec()); + let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into(); } else { - let telemetry_id: Option<Vec<u8>> = None; + let telemetry_id: Bytes = Bytes::new(); } }; // Encode request - let body = rmp_to_vec_all_named(rq.borrow())?; - drop(rq); - - let request = QueryMessage { - prio, - path: path.as_bytes(), - telemetry_id, - body: &body[..], - }; - let bytes = request.encode(); - drop(body); + let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id); + let req_msg_len = req_enc.msg.len(); + let (req_stream, req_order) = req_enc.encode(); // Send request through let (resp_send, resp_recv) = oneshot::channel(); @@ -188,46 +180,41 @@ impl ClientConn { error!( "Too many inflight requests! RequestID collision. Interrupting previous request." ); - if old_ch.send(vec![]).is_err() { - debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response."); - } + let _ = old_ch.send(Box::pin(futures::stream::once(async move { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "RequestID collision, too many inflight requests", + )) + }))); } - trace!("request: query_send {}, {} bytes", id, bytes.len()); + debug!( + "request: query_send {}, path {}, prio {} (serialized message: {} bytes)", + id, path, prio, req_msg_len + ); #[cfg(feature = "telemetry")] - span.set_attribute(KeyValue::new("len_query", bytes.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64)); - query_send.send((id, prio, bytes))?; + query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?; + + let canceller = CancelOnDrop::new(id, query_send.as_ref().clone()); cfg_if::cfg_if! { if #[cfg(feature = "telemetry")] { - let resp = resp_recv + let stream = resp_recv .with_context(Context::current_with_span(span)) .await?; } else { - let resp = resp_recv.await?; + let stream = resp_recv.await?; } } - if resp.is_empty() { - return Err(Error::Message( - "Response is 0 bytes, either a collision or a protocol error".into(), - )); - } - - trace!("request response {}: ", id); + let stream = Box::pin(canceller.for_stream(stream)); - let code = resp[0]; - if code == 0 { - Ok(rmp_serde::decode::from_read_ref::< - _, - <T as Message>::Response, - >(&resp[1..])?) - } else { - let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default(); - Err(Error::Remote(code, msg)) - } + let resp_enc = RespEnc::decode(stream).await?; + debug!("client: got response to request {} (path {})", id, path); + Resp::from_enc(resp_enc) } } @@ -235,14 +222,71 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) { - trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { + trace!("ClientConn recv_handler {}", id); let mut inflight = self.inflight.lock().unwrap(); if let Some(ch) = inflight.remove(&id) { - if ch.send(msg).is_err() { + if ch.send(stream).is_err() { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } + } else { + debug!("Got unexpected response to request {}, dropping it", id); } } } + +// ---- + +struct CancelOnDrop { + id: RequestID, + query_send: mpsc::UnboundedSender<SendItem>, +} + +impl CancelOnDrop { + fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self { + Self { id, query_send } + } + fn for_stream(self, stream: ByteStream) -> CancelOnDropStream { + CancelOnDropStream { + cancel: Some(self), + stream: stream, + } + } +} + +impl Drop for CancelOnDrop { + fn drop(&mut self) { + trace!("cancelling request {}", self.id); + let _ = self.query_send.send(SendItem::Cancel(self.id)); + } +} + +#[pin_project::pin_project] +struct CancelOnDropStream { + cancel: Option<CancelOnDrop>, + #[pin] + stream: ByteStream, +} + +impl Stream for CancelOnDropStream { + type Item = Packet; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + let this = self.project(); + let res = this.stream.poll_next(cx); + if matches!(res, Poll::Ready(None)) { + if let Some(c) = this.cancel.take() { + std::mem::forget(c) + } + } + res + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.stream.size_hint() + } +} |