aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2020-12-07 13:35:24 +0100
committerAlex Auvolat <alex@adnab.me>2020-12-07 13:35:24 +0100
commit5a9ae8615ee616b11460a046deaa6981b10d69ab (patch)
treef625d976531902fa267c20e7359bda43c452d9c4
parent83789a3076e986782af60ba32b0398414c1c82d7 (diff)
downloadnetapp-5a9ae8615ee616b11460a046deaa6981b10d69ab.tar.gz
netapp-5a9ae8615ee616b11460a046deaa6981b10d69ab.zip
Do not close connections immediately on close signal, await for remaining responses
-rw-r--r--examples/basalt.rs41
-rw-r--r--examples/fullmesh.rs17
-rw-r--r--src/conn.rs137
-rw-r--r--src/lib.rs4
-rw-r--r--src/netapp.rs36
-rw-r--r--src/peering/basalt.rs12
-rw-r--r--src/peering/fullmesh.rs10
-rw-r--r--src/proto.rs150
-rw-r--r--src/util.rs20
9 files changed, 192 insertions, 235 deletions
diff --git a/examples/basalt.rs b/examples/basalt.rs
index 4c86cf8..eaf056b 100644
--- a/examples/basalt.rs
+++ b/examples/basalt.rs
@@ -1,20 +1,20 @@
+use std::io::Write;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
-use std::io::Write;
use log::{debug, info, warn};
-use structopt::StructOpt;
use serde::{Deserialize, Serialize};
+use structopt::StructOpt;
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
-use netapp::NetApp;
-use netapp::peering::basalt::*;
use netapp::message::*;
+use netapp::peering::basalt::*;
use netapp::proto::*;
+use netapp::NetApp;
#[derive(StructOpt, Debug)]
#[structopt(name = "netapp")]
@@ -52,17 +52,17 @@ async fn main() {
env_logger::Builder::new()
.parse_env("RUST_LOG")
.format(|buf, record| {
- writeln!(buf,
- "{} {} {} {}",
- chrono::Local::now().format("%s%.6f"),
- record.module_path().unwrap_or("_"),
- record.level(),
- record.args()
- )
+ writeln!(
+ buf,
+ "{} {} {} {}",
+ chrono::Local::now().format("%s%.6f"),
+ record.module_path().unwrap_or("_"),
+ record.level(),
+ record.args()
+ )
})
.init();
-
let opt = Opt::from_args();
let netid = match &opt.network_key {
@@ -108,10 +108,12 @@ async fn main() {
|_from: ed25519::PublicKey, msg: ExampleMessage| {
debug!("Got example message: {:?}, sending example response", msg);
async {
- ExampleResponse{example_field: false}
+ ExampleResponse {
+ example_field: false,
+ }
}
- }
- );
+ },
+ );
tokio::join!(
sampling_loop(netapp.clone(), peering.clone()),
@@ -120,8 +122,6 @@ async fn main() {
);
}
-
-
async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
loop {
tokio::time::delay_for(Duration::from_secs(10)).await;
@@ -132,9 +132,10 @@ async fn sampling_loop(netapp: Arc<NetApp>, basalt: Arc<Basalt>) {
let netapp2 = netapp.clone();
tokio::spawn(async move {
- match netapp2.request(&p, ExampleMessage{
- example_field: 42,
- }, PRIO_NORMAL).await {
+ match netapp2
+ .request(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
+ .await
+ {
Ok(resp) => debug!("Got example response: {:?}", resp),
Err(e) => warn!("Error with example request: {}", e),
}
diff --git a/examples/fullmesh.rs b/examples/fullmesh.rs
index dfacb89..5addcea 100644
--- a/examples/fullmesh.rs
+++ b/examples/fullmesh.rs
@@ -1,5 +1,5 @@
-use std::net::SocketAddr;
use std::io::Write;
+use std::net::SocketAddr;
use log::info;
@@ -32,13 +32,14 @@ async fn main() {
env_logger::Builder::new()
.parse_env("RUST_LOG")
.format(|buf, record| {
- writeln!(buf,
- "{} {} {} {}",
- chrono::Local::now().format("%s%.6f"),
- record.module_path().unwrap_or("_"),
- record.level(),
- record.args()
- )
+ writeln!(
+ buf,
+ "{} {} {} {}",
+ chrono::Local::now().format("%s%.6f"),
+ record.module_path().unwrap_or("_"),
+ record.level(),
+ record.args()
+ )
})
.init();
diff --git a/src/conn.rs b/src/conn.rs
index d4362e5..89bf654 100644
--- a/src/conn.rs
+++ b/src/conn.rs
@@ -1,17 +1,18 @@
use std::collections::HashMap;
use std::net::SocketAddr;
-use std::sync::atomic::{self, AtomicU16};
-use std::sync::Arc;
+use std::sync::atomic::{self, AtomicBool, AtomicU16};
+use std::sync::{Arc, Mutex};
-use async_trait::async_trait;
use bytes::Bytes;
-use log::{debug, trace};
+use log::{debug, error, trace};
use sodiumoxide::crypto::sign::ed25519;
use tokio::io::split;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, watch};
+use async_trait::async_trait;
+
use kuska_handshake::async_std::{
handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead,
TokioCompatExtWrite,
@@ -29,7 +30,7 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
- resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
+ resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
close_send: watch::Sender<bool>,
}
@@ -78,9 +79,20 @@ impl ServerConn {
let conn2 = conn.clone();
let conn3 = conn.clone();
+ let close_recv2 = close_recv.clone();
tokio::try_join!(
- conn2.recv_loop(box_stream_read, close_recv.clone()),
- conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()),
+ async move {
+ tokio::select!(
+ r = conn2.recv_loop(box_stream_read) => r,
+ _ = await_exit(close_recv) => Ok(()),
+ )
+ },
+ async move {
+ tokio::select!(
+ r = conn3.send_loop(resp_recv, box_stream_write) => r,
+ _ = await_exit(close_recv2) => Ok(()),
+ )
+ },
)
.map(|_| ())
.log_err("ServerConn recv_loop/send_loop");
@@ -112,7 +124,7 @@ impl RecvLoop for ServerConn {
let net_handler = &handler.net_handler;
let resp = net_handler(self.peer_pk.clone(), bytes.slice(5..)).await;
self.resp_send
- .send((id, prio, resp))
+ .send(Some((id, prio, resp)))
.log_err("ServerConn recv_handler send resp");
}
}
@@ -121,11 +133,12 @@ pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_pk: ed25519::PublicKey,
- query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>,
+ query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
+
next_query_number: AtomicU16,
- resp_send: mpsc::UnboundedSender<(RequestID, Vec<u8>)>,
- resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender<Vec<u8>>)>,
- close_send: watch::Sender<bool>,
+ inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
+ must_exit: AtomicBool,
+ stop_recv_loop: watch::Sender<bool>,
}
impl ClientConn {
@@ -163,19 +176,17 @@ impl ClientConn {
BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();
let (query_send, query_recv) = mpsc::unbounded_channel();
- let (resp_send, resp_recv) = mpsc::unbounded_channel();
- let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel();
- let (close_send, close_recv) = watch::channel(false);
+ let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false);
let conn = Arc::new(ClientConn {
remote_addr,
peer_pk: remote_pk.clone(),
next_query_number: AtomicU16::from(0u16),
query_send,
- resp_send,
- resp_notify_send,
- close_send,
+ inflight: Mutex::new(HashMap::new()),
+ must_exit: AtomicBool::new(false),
+ stop_recv_loop,
});
netapp.connected_as_client(remote_pk.clone(), conn.clone());
@@ -183,11 +194,14 @@ impl ClientConn {
tokio::spawn(async move {
let conn2 = conn.clone();
let conn3 = conn.clone();
- let conn4 = conn.clone();
tokio::try_join!(
- conn2.send_loop(query_recv, box_stream_write, close_recv.clone()),
- conn3.recv_loop(box_stream_read, close_recv.clone()),
- conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()),
+ conn2.send_loop(query_recv, box_stream_write),
+ async move {
+ tokio::select!(
+ r = conn3.recv_loop(box_stream_read) => r,
+ _ = await_exit(stop_recv_loop_recv) => Ok(()),
+ )
+ }
)
.map(|_| ())
.log_err("ClientConn send_loop/recv_loop/dispatch_loop");
@@ -199,51 +213,15 @@ impl ClientConn {
}
pub fn close(&self) {
- self.close_send.broadcast(true).unwrap();
- }
-
- async fn dispatch_resp(
- self: Arc<Self>,
- mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec<u8>)>,
- mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender<Vec<u8>>)>,
- mut must_exit: watch::Receiver<bool>,
- ) -> Result<(), Error> {
- let mut resps: HashMap<RequestID, Vec<u8>> = HashMap::new();
- let mut resp_notify: HashMap<RequestID, oneshot::Sender<Vec<u8>>> = HashMap::new();
- while !*must_exit.borrow() {
- tokio::select! {
- resp = resp_recv.recv() => {
- if let Some((id, resp)) = resp {
- trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len());
- if let Some(ch) = resp_notify.remove(&id) {
- if ch.send(resp).is_err() {
- debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
- }
- } else {
- resps.insert(id, resp);
- }
- }
- }
- resp_ch = resp_notify_recv.recv() => {
- if let Some((id, resp_ch)) = resp_ch {
- trace!("dispatch_resp: got resp_ch {}", id);
- if let Some(rs) = resps.remove(&id) {
- if resp_ch.send(rs).is_err() {
- debug!("Could not dispatch reply (channel probably closed, happens if request was canceled)");
- }
- } else {
- resp_notify.insert(id, resp_ch);
- }
- }
- }
- exit = must_exit.recv() => {
- if exit == Some(true) {
- break;
- }
- }
- }
+ self.must_exit.store(true, atomic::Ordering::SeqCst);
+ self.query_send
+ .send(None)
+ .log_err("could not write None in query_send");
+ if self.inflight.lock().unwrap().is_empty() {
+ self.stop_recv_loop
+ .broadcast(true)
+ .log_err("could not write true to stop_recv_loop");
}
- Ok(())
}
pub(crate) async fn request<T>(
@@ -262,10 +240,18 @@ impl ClientConn {
bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]);
let (resp_send, resp_recv) = oneshot::channel();
- self.resp_notify_send.send((id, resp_send))?;
+ let old = self.inflight.lock().unwrap().insert(id, resp_send);
+ if let Some(old_ch) = old {
+ error!(
+ "Too many inflight requests! RequestID collision. Interrupting previous request."
+ );
+ if old_ch.send(vec![]).is_err() {
+ debug!("Could not send empty response to collisionned request, probably because request was interrupted. Dropping response.");
+ }
+ }
trace!("request: query_send {}, {} bytes", id, bytes.len());
- self.query_send.send((id, prio, bytes))?;
+ self.query_send.send(Some((id, prio, bytes)))?;
let resp = resp_recv.await?;
@@ -279,8 +265,17 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
- self.resp_send
- .send((id, msg))
- .log_err("ClientConn::recv_handler");
+ let mut inflight = self.inflight.lock().unwrap();
+ if let Some(ch) = inflight.remove(&id) {
+ if ch.send(msg).is_err() {
+ debug!("Could not send request response, probably because request was interrupted. Dropping response.");
+ }
+ }
+
+ if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
+ self.stop_recv_loop
+ .broadcast(true)
+ .log_err("could not write true to stop_recv_loop");
+ }
}
}
diff --git a/src/lib.rs b/src/lib.rs
index ba365c7..af8fbb8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,5 @@
//! Netapp is a Rust library that takes care of a few common tasks in distributed software:
-//!
+//!
//! - establishing secure connections
//! - managing connection lifetime, reconnecting on failure
//! - checking peer's state
@@ -18,8 +18,8 @@
pub mod error;
pub mod util;
-pub mod proto;
pub mod message;
+pub mod proto;
mod conn;
diff --git a/src/netapp.rs b/src/netapp.rs
index bf9a3f0..967105e 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -53,7 +53,7 @@ pub struct NetApp {
server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
-
+
pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
on_connected_handler:
ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
@@ -133,18 +133,22 @@ impl NetApp {
/// been successfully established. Do not set this if using a peering strategy,
/// as the peering strategy will need to set this itself.
pub fn on_connected<F>(&self, handler: F)
- where F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static
- {
- self.on_connected_handler.store(Some(Arc::new(Box::new(handler))));
+ where
+ F: Fn(ed25519::PublicKey, SocketAddr, bool) + Sized + Send + Sync + 'static,
+ {
+ self.on_connected_handler
+ .store(Some(Arc::new(Box::new(handler))));
}
/// Set the handler to be called when an existing connection (incoming or outgoing) has
/// been closed by either party. Do not set this if using a peering strategy,
/// as the peering strategy will need to set this itself.
pub fn on_disconnected<F>(&self, handler: F)
- where F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static
- {
- self.on_disconnected_handler.store(Some(Arc::new(Box::new(handler))));
+ where
+ F: Fn(ed25519::PublicKey, bool) + Sized + Send + Sync + 'static,
+ {
+ self.on_disconnected_handler
+ .store(Some(Arc::new(Box::new(handler))));
}
/// Add a handler for a certain message type. Note that only one handler
@@ -240,11 +244,13 @@ impl NetApp {
pub fn disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
// If pk is ourself, we're not supposed to have a connection open
if *pk != self.pubkey {
- let conn = self.client_conns.read().unwrap().remove(pk);
+ let conn = self.client_conns.write().unwrap().remove(pk);
if let Some(c) = conn {
- debug!("Closing connection to {} ({})",
- hex::encode(c.peer_pk),
- c.remote_addr);
+ debug!(
+ "Closing connection to {} ({})",
+ hex::encode(c.peer_pk),
+ c.remote_addr
+ );
c.close();
} else {
return;
@@ -268,9 +274,11 @@ impl NetApp {
pub fn server_disconnect(self: &Arc<Self>, pk: &ed25519::PublicKey) {
let conn = self.server_conns.read().unwrap().get(pk).cloned();
if let Some(c) = conn {
- debug!("Closing incoming connection from {} ({})",
- hex::encode(c.peer_pk),
- c.remote_addr);
+ debug!(
+ "Closing incoming connection from {} ({})",
+ hex::encode(c.peer_pk),
+ c.remote_addr
+ );
c.close();
}
}
diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs
index 615b559..4aa34f6 100644
--- a/src/peering/basalt.rs
+++ b/src/peering/basalt.rs
@@ -3,7 +3,7 @@ use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
-use log::{trace, debug, info, warn};
+use log::{debug, info, trace, warn};
use lru::LruCache;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
@@ -267,15 +267,13 @@ impl Basalt {
netapp.on_connected(
move |pk: ed25519::PublicKey, addr: SocketAddr, is_incoming: bool| {
basalt2.on_connected(pk, addr, is_incoming);
- }
+ },
);
let basalt2 = basalt.clone();
- netapp.on_disconnected(
- move |pk: ed25519::PublicKey, is_incoming: bool| {
- basalt2.on_disconnected(pk, is_incoming);
- },
- );
+ netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
+ basalt2.on_disconnected(pk, is_incoming);
+ });
let basalt2 = basalt.clone();
netapp.add_msg_handler::<PullMessage, _, _>(
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index 1b26489..d6ca08a 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -185,12 +185,10 @@ impl FullMeshPeeringStrategy {
);
let strat2 = strat.clone();
- netapp.on_disconnected(
- move |pk: ed25519::PublicKey, is_incoming: bool| {
- let strat2 = strat2.clone();
- tokio::spawn(strat2.on_disconnected(pk, is_incoming));
- },
- );
+ netapp.on_disconnected(move |pk: ed25519::PublicKey, is_incoming: bool| {
+ let strat2 = strat2.clone();
+ tokio::spawn(strat2.on_disconnected(pk, is_incoming));
+ });
strat
}
diff --git a/src/proto.rs b/src/proto.rs
index b044280..d90042f 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -3,14 +3,14 @@ use std::sync::Arc;
use log::trace;
-use async_trait::async_trait;
-
use async_std::io::prelude::WriteExt;
use async_std::io::ReadExt;
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
-use tokio::sync::{mpsc, watch};
+use tokio::sync::mpsc;
+
+use async_trait::async_trait;
use crate::error::*;
@@ -85,26 +85,33 @@ impl SendQueue {
}
}
}
+ fn is_empty(&self) -> bool {
+ self.items.iter().all(|(_k, v)| v.is_empty())
+ }
}
#[async_trait]
pub(crate) trait SendLoop: Sync {
async fn send_loop(
self: Arc<Self>,
- mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
+ mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
- mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut sending = SendQueue::new();
- while !*must_exit.borrow() {
- if let Ok((id, prio, data)) = msg_recv.try_recv() {
- trace!("send_loop: got {}, {} bytes", id, data.len());
- sending.push(SendQueueItem {
- id,
- prio,
- data,
- cursor: 0,
- });
+ let mut should_exit = false;
+ while !should_exit || !sending.is_empty() {
+ if let Ok(sth) = msg_recv.try_recv() {
+ if let Some((id, prio, data)) = sth {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ data,
+ cursor: 0,
+ });
+ } else {
+ should_exit = true;
+ }
} else if let Some(mut item) = sending.pop() {
trace!(
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
@@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync {
item.cursor
);
let header_id = u16::to_be_bytes(item.id);
- if write_all_or_exit(&header_id[..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&header_id[..]).await?;
if item.data.len() - item.cursor > MAX_CHUNK_SIZE {
let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000);
- if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&header_size[..]).await?;
let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize;
- if write_all_or_exit(
- &item.data[item.cursor..new_cursor],
- &mut write,
- &mut must_exit,
- )
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&item.data[item.cursor..new_cursor]).await?;
item.cursor = new_cursor;
sending.push(item);
@@ -147,33 +135,27 @@ pub(crate) trait SendLoop: Sync {
let send_len = (item.data.len() - item.cursor) as u16;
let header_size = u16::to_be_bytes(send_len);
- if write_all_or_exit(&header_size[..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&header_size[..]).await?;
- if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ write.write_all(&item.data[item.cursor..]).await?;
}
write.flush().await.log_err("Could not flush in send_loop");
} else {
- let (id, prio, data) = msg_recv
+ let sth = msg_recv
.recv()
.await
.ok_or(Error::Message("Connection closed.".into()))?;
- trace!("send_loop: got {}, {} bytes", id, data.len());
- sending.push(SendQueueItem {
- id,
- prio,
- data,
- cursor: 0,
- });
+ if let Some((id, prio, data)) = sth {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ data,
+ cursor: 0,
+ });
+ } else {
+ should_exit = true;
+ }
}
}
Ok(())
@@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync {
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
+ // Returns true if we should stop receiving after this
async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop(
self: Arc<Self>,
mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
- mut must_exit: watch::Receiver<bool>,
) -> Result<(), Error> {
let mut receiving = HashMap::new();
- while !*must_exit.borrow() {
+ loop {
trace!("recv_loop: reading packet");
let mut header_id = [0u8; 2];
- if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ read.read_exact(&mut header_id[..]).await?;
let id = RequestID::from_be_bytes(header_id);
trace!("recv_loop: got header id: {:04x}", id);
let mut header_size = [0u8; 2];
- if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ read.read_exact(&mut header_size[..]).await?;
let size = RequestID::from_be_bytes(header_size);
trace!("recv_loop: got header size: {:04x}", id);
@@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
let size = size & !0x8000;
let mut next_slice = vec![0; size as usize];
- if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit)
- .await?
- .is_none()
- {
- break;
- }
+ read.read_exact(&mut next_slice[..]).await?;
trace!("recv_loop: read {} bytes", size);
let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]);
@@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static {
tokio::spawn(self.clone().recv_handler(id, msg_bytes));
}
}
- Ok(())
- }
-}
-
-async fn read_exact_or_exit(
- buf: &mut [u8],
- read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>,
- must_exit: &mut watch::Receiver<bool>,
-) -> Result<Option<()>, Error> {
- tokio::select!(
- res = read.read_exact(buf) => Ok(Some(res?)),
- _ = await_exit(must_exit) => Ok(None),
- )
-}
-
-async fn write_all_or_exit(
- buf: &[u8],
- write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>,
- must_exit: &mut watch::Receiver<bool>,
-) -> Result<Option<()>, Error> {
- tokio::select!(
- res = write.write_all(buf) => Ok(Some(res?)),
- _ = await_exit(must_exit) => Ok(None),
- )
-}
-
-async fn await_exit(must_exit: &mut watch::Receiver<bool>) {
- loop {
- if must_exit.recv().await == Some(true) {
- return;
- }
}
}
diff --git a/src/util.rs b/src/util.rs
index f09a3bc..017ef00 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,5 +1,7 @@
use serde::Serialize;
+use tokio::sync::watch;
+
/// Utility function: encodes any serializable value in MessagePack binary format
/// using the RMP library.
///
@@ -16,3 +18,21 @@ where
val.serialize(&mut se)?;
Ok(wr)
}
+
+/// This async function returns only when a true signal was received
+/// from a watcher that tells us when to exit.
+/// Usefull in a select statement to interrupt another
+/// future:
+/// ```
+/// select!(
+/// _ = a_long_task() => Success,
+/// _ = await_exit(must_exit) => Interrupted,
+/// )
+/// ```
+pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
+ loop {
+ if must_exit.recv().await == Some(true) {
+ return;
+ }
+ }
+}