aboutsummaryrefslogtreecommitdiff
path: root/src/conn.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/conn.rs')
-rw-r--r--src/conn.rs137
1 files changed, 66 insertions, 71 deletions
diff --git a/src/conn.rs b/src/conn.rs
index d4362e5..89bf654 100644
--- a/src/conn.rs
+++ b/src/conn.rs
@@ -1,17 +1,18 @@
use std::collections::HashMap;
use std::net::SocketAddr;
-use std::sync::atomic::{self, AtomicU16};
-use std::sync::Arc;
+use std::sync::atomic::{self, AtomicBool, AtomicU16};
+use std::sync::{Arc, Mutex};
-use async_trait::async_trait;
use bytes::Bytes;
-use log::{debug, trace};
+use log::{debug, error, trace};
use sodiumoxide::crypto::sign::ed25519;
use tokio::io::split;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, watch};
+use async_trait::async_trait;
+
use kuska_handshake::async_std::{
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
TokioCompatExtWrite,
@@ -29,7 +30,7 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
- resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
+ resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
close_send: watch::Sender<bool>,
}
@@ -78,9 +79,20 @@ impl ServerConn {
let conn2 = conn.clone();
let conn3 = conn.clone();
+ let close_recv2 = close_recv.clone();
tokio::try_join!(
- conn2.recv_loop(box_stream_read, close_recv.clone()),
- conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()),
+ async move {
+ tokio::select!(
+ r = conn2.recv_loop(box_stream_read) => r,
+ _ = await_exit(close_recv) => Ok(()),
+ )
+ },
+ async move {
+ tokio::select!(
+ r = conn3.send_loop(resp_recv, box_stream_write) => r,
+ _ = await_exit(close_recv2) => Ok(()),
+ )
+ },
)
.map(|_| ())
.log_err("ServerConn recv_loop/send_loop");
@@ -112,7 +124,7 @@ impl RecvLoop for ServerConn {
let net_handler = &handler.net_handler;
let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
self.resp_send
- .send((id, prio, resp))
+ .send(Some((id, prio, resp)))
.log_err("ServerConn recv_handler send resp");
}
}
@@ -121,11 +133,12 @@ pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_pk: ed25519::PublicKey,
- query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
+ query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
+
next_query_number: AtomicU16,
- resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>,
- resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>,
- close_send: watch::Sender<bool>,
+ inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
+ must_exit: AtomicBool,
+ stop_recv_loop: watch::Sender<bool>,
}
impl ClientConn {
@@ -163,19 +176,17 @@ impl ClientConn {
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
let (query_send, query_recv) = mpsc::unbounded_channel();
- let (resp_send, resp_recv) = mpsc::unbounded_channel();
- let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
- let (close_send, close_recv) = watch::channel(false);
+ let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false);
let conn = Arc::new(ClientConn {
remote_addr,
peer_pk: remote_pk.clone(),
next_query_number: AtomicU16::from(0u16),
query_send,
- resp_send,
- resp_notify_send,
- close_send,
+ inflight: Mutex::new(HashMap::new()),
+ must_exit: AtomicBool::new(false),
+ stop_recv_loop,
});
netapp.connected_as_client(remote_pk.clone(), conn.clone());
@@ -183,11 +194,14 @@ impl ClientConn {
tokio::spawn(async move {
let conn2 = conn.clone();
let conn3 = conn.clone();
- let conn4 = conn.clone();
tokio::try_join!(
- conn2.send_loop(query_recv, box_stream_write, close_recv.clone()),
- conn3.recv_loop(box_stream_read, close_recv.clone()),
- conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()),
+ conn2.send_loop(query_recv, box_stream_write),
+ async move {
+ tokio::select!(
+ r = conn3.recv_loop(box_stream_read) => r,
+ _ = await_exit(stop_recv_loop_recv) => Ok(()),
+ )
+ }
)
.map(|_| ())
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
@@ -199,51 +213,15 @@ impl ClientConn {
}
pub fn close(&self) {
- self.close_send.broadcast(true).unwrap();
- }
-
- async fn dispatch_resp(
- self: Arc<Self>,
- mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>,
- mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>,
- mut must_exit: watch::Receiver<bool>,
- ) -> Result<(), Error> {
- let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
- let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
- while !*must_exit.borrow() {
- tokio::select! {
- resp = resp_recv.recv() => {
- if let Some((id, resp)) = resp {
- trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
- if let Some(ch) = resp_notify.remove(&id) {
- if ch.send(resp).is_err() {
- debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
- }
- } else {
- resps.insert(id, resp);
- }
- }
- }
- resp_ch = resp_notify_recv.recv() => {
- if let Some((id, resp_ch)) = resp_ch {
- trace!("dispatch_resp: got resp_ch {}", id);
- if let Some(rs) = resps.remove(&id) {
- if resp_ch.send(rs).is_err() {
- debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
- }
- } else {
- resp_notify.insert(id, resp_ch);
- }
- }
- }
- exit = must_exit.recv() => {
- if exit == Some(true) {
- break;
- }
- }
- }
+ self.must_exit.store(true, atomic::Ordering::SeqCst);
+ self.query_send
+ .send(None)
+ .log_err("could not write None in query_send");
+ if self.inflight.lock().unwrap().is_empty() {
+ self.stop_recv_loop
+ .broadcast(true)
+ .log_err("could not write true to stop_recv_loop");
}
- Ok(())
}
pub(crate) async fn request<T>(
@@ -262,10 +240,18 @@ impl ClientConn {
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
let (resp_send, resp_recv) = oneshot::channel();
- self.resp_notify_send.send((id, resp_send))?;
+ let old = self.inflight.lock().unwrap().insert(id, resp_send);
+ if let Some(old_ch) = old {
+ error!(
+ "Too many inflight requests! RequestID collision. Interrupting previous request."
+ );
+ if old_ch.send(vec![]).is_err() {
+ debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
+ }
+ }
trace!("request: query_send {}, {} bytes", id, bytes.len());
- self.query_send.send((id, prio, bytes))?;
+ self.query_send.send(Some((id, prio, bytes)))?;
let resp = resp_recv.await?;
@@ -279,8 +265,17 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
- self.resp_send
- .send((id, msg))
- .log_err("ClientConn::recv_handler");
+ let mut inflight = self.inflight.lock().unwrap();
+ if let Some(ch) = inflight.remove(&id) {
+ if ch.send(msg).is_err() {
+ debug!("Could not send request response, probably because request was interrupted. Dropping response.");
+ }
+ }
+
+ if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
+ self.stop_recv_loop
+ .broadcast(true)
+ .log_err("could not write true to stop_recv_loop");
+ }
}
}