aboutsummaryrefslogtreecommitdiff
path: root/src/client.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/client.rs')
-rw-r--r--src/client.rs158
1 files changed, 101 insertions, 57 deletions
diff --git a/src/client.rs b/src/client.rs
index 5c5a05b..d82c91e 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,12 +1,18 @@
-use std::borrow::Borrow;
use std::collections::HashMap;
use std::net::SocketAddr;
+use std::pin::Pin;
use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex};
+use std::task::Poll;
use arc_swap::ArcSwapOption;
+use async_trait::async_trait;
+use bytes::Bytes;
use log::{debug, error, trace};
+use futures::io::AsyncReadExt;
+use futures::Stream;
+use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
@@ -20,27 +26,22 @@ use opentelemetry::{
#[cfg(feature = "telemetry")]
use opentelemetry_contrib::trace::propagator::binary::*;
-use futures::io::AsyncReadExt;
-
-use async_trait::async_trait;
-
-use kuska_handshake::async_std::{handshake_client, BoxStream};
-
-use crate::endpoint::*;
use crate::error::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
-use crate::proto2::*;
+use crate::recv::*;
+use crate::send::*;
+use crate::stream::*;
use crate::util::*;
pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID,
- query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
+ query_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
next_query_number: AtomicU32,
- inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
+ inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
}
impl ClientConn {
@@ -139,15 +140,14 @@ impl ClientConn {
self.query_send.store(None);
}
- pub(crate) async fn call<T, B>(
+ pub(crate) async fn call<T>(
self: Arc<Self>,
- rq: B,
+ req: Req<T>,
path: &str,
prio: RequestPriority,
- ) -> Result<<T as Message>::Response, Error>
+ ) -> Result<Resp<T>, Error>
where
T: Message,
- B: Borrow<T>,
{
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@@ -162,24 +162,16 @@ impl ClientConn {
.with_kind(SpanKind::Client)
.start(&tracer);
let propagator = BinaryPropagator::new();
- let telemetry_id = Some(propagator.to_bytes(span.span_context()).to_vec());
+ let telemetry_id: Bytes = propagator.to_bytes(span.span_context()).to_vec().into();
} else {
- let telemetry_id: Option<Vec<u8>> = None;
+ let telemetry_id: Bytes = Bytes::new();
}
};
// Encode request
- let body = rmp_to_vec_all_named(rq.borrow())?;
- drop(rq);
-
- let request = QueryMessage {
- prio,
- path: path.as_bytes(),
- telemetry_id,
- body: &body[..],
- };
- let bytes = request.encode();
- drop(body);
+ let req_enc = req.into_enc(prio, path.as_bytes().to_vec().into(), telemetry_id);
+ let req_msg_len = req_enc.msg.len();
+ let (req_stream, req_order) = req_enc.encode();
// Send request through
let (resp_send, resp_recv) = oneshot::channel();
@@ -188,46 +180,41 @@ impl ClientConn {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
- if old_ch.send(vec![]).is_err() {
- debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
- }
+ let _ = old_ch.send(Box::pin(futures::stream::once(async move {
+ Err(std::io::Error::new(
+ std::io::ErrorKind::Other,
+ "RequestID collision, too many inflight requests",
+ ))
+ })));
}
- trace!("request: query_send {}, {} bytes", id, bytes.len());
+ debug!(
+ "request: query_send {}, path {}, prio {} (serialized message: {} bytes)",
+ id, path, prio, req_msg_len
+ );
#[cfg(feature = "telemetry")]
- span.set_attribute(KeyValue::new("len_query", bytes.len() as i64));
+ span.set_attribute(KeyValue::new("len_query_msg", req_msg_len as i64));
- query_send.send((id, prio, bytes))?;
+ query_send.send(SendItem::Stream(id, prio, req_order, req_stream))?;
+
+ let canceller = CancelOnDrop::new(id, query_send.as_ref().clone());
cfg_if::cfg_if! {
if #[cfg(feature = "telemetry")] {
- let resp = resp_recv
+ let stream = resp_recv
.with_context(Context::current_with_span(span))
.await?;
} else {
- let resp = resp_recv.await?;
+ let stream = resp_recv.await?;
}
}
- if resp.is_empty() {
- return Err(Error::Message(
- "Response is 0 bytes, either a collision or a protocol error".into(),
- ));
- }
-
- trace!("request response {}: ", id);
+ let stream = Box::pin(canceller.for_stream(stream));
- let code = resp[0];
- if code == 0 {
- Ok(rmp_serde::decode::from_read_ref::<
- _,
- <T as Message>::Response,
- >(&resp[1..])?)
- } else {
- let msg = String::from_utf8(resp[1..].to_vec()).unwrap_or_default();
- Err(Error::Remote(code, msg))
- }
+ let resp_enc = RespEnc::decode(stream).await?;
+ debug!("client: got response to request {} (path {})", id, path);
+ Resp::from_enc(resp_enc)
}
}
@@ -235,14 +222,71 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
- fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
- trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
+ trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap();
if let Some(ch) = inflight.remove(&id) {
- if ch.send(msg).is_err() {
+ if ch.send(stream).is_err() {
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
}
+ } else {
+ debug!("Got unexpected response to request {}, dropping it", id);
}
}
}
+
+// ----
+
+struct CancelOnDrop {
+ id: RequestID,
+ query_send: mpsc::UnboundedSender<SendItem>,
+}
+
+impl CancelOnDrop {
+ fn new(id: RequestID, query_send: mpsc::UnboundedSender<SendItem>) -> Self {
+ Self { id, query_send }
+ }
+ fn for_stream(self, stream: ByteStream) -> CancelOnDropStream {
+ CancelOnDropStream {
+ cancel: Some(self),
+ stream: stream,
+ }
+ }
+}
+
+impl Drop for CancelOnDrop {
+ fn drop(&mut self) {
+ trace!("cancelling request {}", self.id);
+ let _ = self.query_send.send(SendItem::Cancel(self.id));
+ }
+}
+
+#[pin_project::pin_project]
+struct CancelOnDropStream {
+ cancel: Option<CancelOnDrop>,
+ #[pin]
+ stream: ByteStream,
+}
+
+impl Stream for CancelOnDropStream {
+ type Item = Packet;
+
+ fn poll_next(
+ self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let this = self.project();
+ let res = this.stream.poll_next(cx);
+ if matches!(res, Poll::Ready(None)) {
+ if let Some(c) = this.cancel.take() {
+ std::mem::forget(c)
+ }
+ }
+ res
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.stream.size_hint()
+ }
+}