aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/client.rs11
-rw-r--r--src/error.rs4
-rw-r--r--src/recv.rs38
-rw-r--r--src/server.rs16
4 files changed, 37 insertions, 32 deletions
diff --git a/src/client.rs b/src/client.rs
index 42eeaa3..d51236b 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -8,7 +8,6 @@ use async_trait::async_trait;
use bytes::Bytes;
use log::{debug, error, trace};
-use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::io::AsyncReadExt;
use kuska_handshake::async_std::{handshake_client, BoxStream};
use tokio::net::TcpStream;
@@ -39,7 +38,7 @@ pub(crate) struct ClientConn {
query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, ByteStream)>>,
next_query_number: AtomicU32,
- inflight: Mutex<HashMap<RequestID, oneshot::Sender<UnboundedReceiver<Packet>>>>,
+ inflight: Mutex<HashMap<RequestID, oneshot::Sender<ByteStream>>>,
}
impl ClientConn {
@@ -175,7 +174,9 @@ impl ClientConn {
error!(
"Too many inflight requests! RequestID collision. Interrupting previous request."
);
- let _ = old_ch.send(unbounded().1);
+ let _ = old_ch.send(Box::pin(futures::stream::once(async move {
+ Err(Error::IdCollision.code())
+ })));
}
trace!(
@@ -199,7 +200,7 @@ impl ClientConn {
}
}
- let resp_enc = RespEnc::decode(Box::pin(stream)).await?;
+ let resp_enc = RespEnc::decode(stream).await?;
trace!("request response {}", id);
Resp::from_enc(resp_enc)
}
@@ -209,7 +210,7 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
- fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
trace!("ClientConn recv_handler {}", id);
let mut inflight = self.inflight.lock().unwrap();
diff --git a/src/error.rs b/src/error.rs
index 665647c..f374341 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -28,6 +28,9 @@ pub enum Error {
#[error(display = "Framing protocol error")]
Framing,
+ #[error(display = "Request ID collision")]
+ IdCollision,
+
#[error(display = "{}", _0)]
Message(String),
@@ -56,6 +59,7 @@ impl Error {
Self::Framing => 13,
Self::NoHandler => 20,
Self::ConnectionClosed => 21,
+ Self::IdCollision => 22,
Self::Handshake(_) => 30,
Self::VersionMismatch(_) => 31,
Self::Remote(c, _) => *c,
diff --git a/src/recv.rs b/src/recv.rs
index 19288f2..b2f5530 100644
--- a/src/recv.rs
+++ b/src/recv.rs
@@ -5,8 +5,8 @@ use async_trait::async_trait;
use bytes::Bytes;
use log::trace;
-use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::AsyncReadExt;
+use tokio::sync::mpsc;
use crate::error::*;
use crate::send::*;
@@ -15,33 +15,28 @@ use crate::stream::*;
/// Structure to warn when the sender is dropped before end of stream was reached, like when
/// connection to some remote drops while transmitting data
struct Sender {
- inner: UnboundedSender<Packet>,
- closed: bool,
+ inner: Option<mpsc::Sender<Packet>>,
}
impl Sender {
- fn new(inner: UnboundedSender<Packet>) -> Self {
- Sender {
- inner,
- closed: false,
- }
+ fn new(inner: mpsc::Sender<Packet>) -> Self {
+ Sender { inner: Some(inner) }
}
- fn send(&self, packet: Packet) {
- let _ = self.inner.unbounded_send(packet);
+ async fn send(&self, packet: Packet) {
+ let _ = self.inner.as_ref().unwrap().send(packet).await;
}
fn end(&mut self) {
- self.closed = true;
+ self.inner = None;
}
}
impl Drop for Sender {
fn drop(&mut self) {
- if !self.closed {
- self.send(Err(255));
+ if let Some(inner) = self.inner.take() {
+ let _ = inner.blocking_send(Err(255));
}
- self.inner.close_channel();
}
}
@@ -54,7 +49,7 @@ impl Drop for Sender {
/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
- fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>);
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
@@ -92,14 +87,17 @@ pub(crate) trait RecvLoop: Sync + 'static {
let mut sender = if let Some(send) = streams.remove(&(id)) {
send
} else {
- let (send, recv) = unbounded();
- self.recv_handler(id, recv);
+ let (send, recv) = mpsc::channel(4);
+ self.recv_handler(
+ id,
+ Box::pin(tokio_stream::wrappers::ReceiverStream::new(recv)),
+ );
Sender::new(send)
};
- // if we get an error, the receiving end is disconnected. We still need to
- // reach eos before dropping this sender
- sender.send(packet);
+ // If we get an error, the receiving end is disconnected.
+ // We still need to reach eos before dropping this sender
+ let _ = sender.send(packet).await;
if has_cont {
streams.insert(id, sender);
diff --git a/src/server.rs b/src/server.rs
index ae1196c..4b232af 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -5,7 +5,6 @@ use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use log::{debug, trace};
-use futures::channel::mpsc::UnboundedReceiver;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use kuska_handshake::async_std::{handshake_server, BoxStream};
use tokio::net::TcpStream;
@@ -171,21 +170,24 @@ impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
- fn recv_handler(self: &Arc<Self>, id: RequestID, stream: UnboundedReceiver<Packet>) {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
let resp_send = self.resp_send.load_full().unwrap();
let self2 = self.clone();
tokio::spawn(async move {
trace!("ServerConn recv_handler {}", id);
- let (prio, resp_enc) = match ReqEnc::decode(Box::pin(stream)).await {
+ let (prio, resp_enc) = match ReqEnc::decode(stream).await {
Ok(req_enc) => {
let prio = req_enc.prio;
let resp = self2.recv_handler_aux(req_enc).await;
- (prio, match resp {
- Ok(resp_enc) => resp_enc,
- Err(e) => RespEnc::from_err(e),
- })
+ (
+ prio,
+ match resp {
+ Ok(resp_enc) => resp_enc,
+ Err(e) => RespEnc::from_err(e),
+ },
+ )
}
Err(e) => (PRIO_NORMAL, RespEnc::from_err(e)),
};