aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2021-10-13 17:12:13 +0200
committerAlex Auvolat <alex@adnab.me>2021-10-13 17:12:13 +0200
commit70839d70d86354232f168e63ce4062219acb85c7 (patch)
tree9c956af0339aa048f487c3a4e54c320be8d13647
parent8dede69dee20b812ad1dcab5b374c60232409f4f (diff)
downloadnetapp-70839d70d86354232f168e63ce4062219acb85c7.tar.gz
netapp-70839d70d86354232f168e63ce4062219acb85c7.zip
Try to handle termination and closing of stuff properly
-rw-r--r--Cargo.lock23
-rw-r--r--Cargo.toml3
-rw-r--r--Makefile2
-rw-r--r--examples/basalt.rs13
-rw-r--r--examples/fullmesh.rs7
-rw-r--r--src/client.rs78
-rw-r--r--src/endpoint.rs1
-rw-r--r--src/error.rs17
-rw-r--r--src/lib.rs4
-rw-r--r--src/netapp.rs107
-rw-r--r--src/peering/basalt.rs2
-rw-r--r--src/peering/fullmesh.rs6
-rw-r--r--src/proto.rs56
-rw-r--r--src/server.rs101
-rw-r--r--src/util.rs14
15 files changed, 266 insertions, 168 deletions
diff --git a/Cargo.lock b/Cargo.lock
index bce6ea2..fd56286 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -446,6 +446,7 @@ dependencies = [
"serde",
"structopt",
"tokio",
+ "tokio-stream",
"tokio-util",
]
@@ -667,6 +668,15 @@ dependencies = [
]
[[package]]
+name = "signal-hook-registry"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
+dependencies = [
+ "libc",
+]
+
+[[package]]
name = "slab"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -779,7 +789,9 @@ dependencies = [
"memchr",
"mio",
"num_cpus",
+ "once_cell",
"pin-project-lite",
+ "signal-hook-registry",
"tokio-macros",
"winapi",
]
@@ -796,6 +808,17 @@ dependencies = [
]
[[package]]
+name = "tokio-stream"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f"
+dependencies = [
+ "futures-core",
+ "pin-project-lite",
+ "tokio",
+]
+
+[[package]]
name = "tokio-util"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 64e3401..10614e2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,8 +20,9 @@ basalt = ["lru", "rand"]
[dependencies]
futures = "0.3.17"
-tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util"] }
+tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] }
tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] }
+tokio-stream = "0.1.7"
serde = { version = "1.0", default-features = false, features = ["derive"] }
rmp-serde = "0.14.3"
diff --git a/Makefile b/Makefile
index 0f680f3..468f591 100644
--- a/Makefile
+++ b/Makefile
@@ -2,6 +2,6 @@ all:
cargo build --all-features
cargo build --example fullmesh
cargo build --all-features --example basalt
- RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7
+ RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7
#RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh
diff --git a/examples/basalt.rs b/examples/basalt.rs
index 7093e05..63b4b4c 100644
--- a/examples/basalt.rs
+++ b/examples/basalt.rs
@@ -5,9 +5,9 @@ use std::time::Duration;
use log::{debug, info, warn};
+use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use structopt::StructOpt;
-use async_trait::async_trait;
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
@@ -122,9 +122,15 @@ async fn main() {
let listen_addr = opt.listen_addr.parse().unwrap();
let public_addr = opt.public_addr.map(|x| x.parse().unwrap());
+
+ let watch_cancel = netapp::util::watch_ctrl_c();
+
tokio::join!(
example.clone().sampling_loop(),
- example.netapp.clone().listen(listen_addr, public_addr),
+ example
+ .netapp
+ .clone()
+ .listen(listen_addr, public_addr, watch_cancel),
example.basalt.clone().run(),
);
}
@@ -141,7 +147,8 @@ impl Example {
let self2 = self.clone();
tokio::spawn(async move {
match self2
- .example_endpoint.call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
+ .example_endpoint
+ .call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL)
.await
{
Ok(resp) => debug!("Got example response: {:?}", resp),
diff --git a/examples/fullmesh.rs b/examples/fullmesh.rs
index f40591a..67861a7 100644
--- a/examples/fullmesh.rs
+++ b/examples/fullmesh.rs
@@ -87,6 +87,11 @@ async fn main() {
hex::encode(&privkey.public_key()),
listen_addr);
+ let watch_cancel = netapp::util::watch_ctrl_c();
+
let public_addr = opt.public_addr.map(|x| x.parse().unwrap());
- tokio::join!(netapp.listen(listen_addr, public_addr), peering.run(),);
+ tokio::join!(
+ netapp.listen(listen_addr, public_addr, watch_cancel.clone()),
+ peering.run(watch_cancel),
+ );
}
diff --git a/src/client.rs b/src/client.rs
index 127ff46..773fa9d 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,11 +1,13 @@
use std::collections::HashMap;
use std::net::SocketAddr;
-use std::sync::atomic::{self, AtomicBool, AtomicU32};
+use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex};
+use arc_swap::ArcSwapOption;
use log::{debug, error, trace};
use tokio::net::TcpStream;
+use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::compat::*;
@@ -21,17 +23,14 @@ use crate::netapp::*;
use crate::proto::*;
use crate::util::*;
-
pub(crate) struct ClientConn {
pub(crate) remote_addr: SocketAddr,
pub(crate) peer_id: NodeID,
- query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
+ query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
next_query_number: AtomicU32,
inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>,
- must_exit: AtomicBool,
- stop_recv_loop: watch::Sender<bool>,
}
impl ClientConn {
@@ -71,25 +70,35 @@ impl ClientConn {
remote_addr,
peer_id,
next_query_number: AtomicU32::from(RequestID::default()),
- query_send,
+ query_send: ArcSwapOption::new(Some(Arc::new(query_send))),
inflight: Mutex::new(HashMap::new()),
- must_exit: AtomicBool::new(false),
- stop_recv_loop,
});
netapp.connected_as_client(peer_id, conn.clone());
tokio::spawn(async move {
+ let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write));
+
let conn2 = conn.clone();
- let conn3 = conn.clone();
- tokio::try_join!(conn2.send_loop(query_recv, write), async move {
- tokio::select!(
- r = conn3.recv_loop(read) => r,
- _ = await_exit(stop_recv_loop_recv) => Ok(()),
- )
- })
- .map(|_| ())
- .log_err("ClientConn send_loop/recv_loop/dispatch_loop");
+ let recv_future = tokio::spawn(async move {
+ select! {
+ r = conn2.recv_loop(read) => r,
+ _ = await_exit(stop_recv_loop_recv) => Ok(())
+ }
+ });
+
+ send_future.await.log_err("ClientConn send_loop");
+
+ // TODO here: wait for inflight requests to all have their response
+ stop_recv_loop
+ .send(true)
+ .log_err("ClientConn send true to stop_recv_loop");
+
+ recv_future.await.log_err("ClientConn recv_loop");
+
+ // Make sure we don't wait on any more requests that won't
+ // have a response
+ conn.inflight.lock().unwrap().clear();
netapp.disconnected_as_client(&peer_id, conn);
});
@@ -98,15 +107,7 @@ impl ClientConn {
}
pub fn close(&self) {
- self.must_exit.store(true, atomic::Ordering::SeqCst);
- self.query_send
- .send(None)
- .log_err("could not write None in query_send");
- if self.inflight.lock().unwrap().is_empty() {
- self.stop_recv_loop
- .send(true)
- .log_err("could not write true to stop_recv_loop");
- }
+ self.query_send.store(None);
}
pub(crate) async fn call<T>(
@@ -118,6 +119,8 @@ impl ClientConn {
where
T: Message,
{
+ let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
+
let id = self
.next_query_number
.fetch_add(1, atomic::Ordering::Relaxed);
@@ -138,20 +141,23 @@ impl ClientConn {
}
trace!("request: query_send {}, {} bytes", id, bytes.len());
- self.query_send.send(Some((id, prio, bytes)))?;
+ query_send.send((id, prio, bytes))?;
let resp = resp_recv.await?;
- if resp.len() == 0 {
- return Err(Error::Message("Response is 0 bytes, either a collision or a protocol error".into()));
+ if resp.is_empty() {
+ return Err(Error::Message(
+ "Response is 0 bytes, either a collision or a protocol error".into(),
+ ));
}
trace!("request response {}: ", id);
let code = resp[0];
if code == 0 {
- Ok(rmp_serde::decode::from_read_ref::<_, <T as Message>::Response>(
- &resp[1..],
- )?)
+ Ok(rmp_serde::decode::from_read_ref::<
+ _,
+ <T as Message>::Response,
+ >(&resp[1..])?)
} else {
Err(Error::Remote(format!("Remote error code {}", code)))
}
@@ -162,7 +168,7 @@ impl SendLoop for ClientConn {}
#[async_trait]
impl RecvLoop for ClientConn {
- async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) {
+ fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) {
trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len());
let mut inflight = self.inflight.lock().unwrap();
@@ -171,11 +177,5 @@ impl RecvLoop for ClientConn {
debug!("Could not send request response, probably because request was interrupted. Dropping response.");
}
}
-
- if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) {
- self.stop_recv_loop
- .send(true)
- .log_err("could not write true to stop_recv_loop");
- }
}
}
diff --git a/src/endpoint.rs b/src/endpoint.rs
index 83957e2..0e1f5c8 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -123,4 +123,3 @@ where
Box::new(Self(self.0.clone()))
}
}
-
diff --git a/src/error.rs b/src/error.rs
index 14c6187..0ed30a5 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -31,6 +31,9 @@ pub enum Error {
#[error(display = "No handler / shutting down")]
NoHandler,
+ #[error(display = "Connection closed")]
+ ConnectionClosed,
+
#[error(display = "Remote error: {}", _0)]
Remote(String),
}
@@ -45,6 +48,7 @@ impl Error {
Self::RMPDecode(_) => 11,
Self::UTF8(_) => 12,
Self::NoHandler => 20,
+ Self::ConnectionClosed => 21,
Self::Handshake(_) => 30,
Self::Remote(_) => 40,
Self::Message(_) => 99,
@@ -80,3 +84,16 @@ where
};
}
}
+
+impl<E, T> LogError for Result<T, E>
+where
+ T: LogError,
+ E: Into<Error>,
+{
+ fn log_err(self, msg: &'static str) {
+ match self {
+ Err(e) => error!("Error: {}: {}", msg, Into::<Error>::into(e)),
+ Ok(x) => x.log_err(msg),
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index f24b7ac..2ce5a51 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,16 +13,14 @@
//! about message priorization.
//! Also check out the examples to learn how to use this crate.
-#![feature(map_first_last)]
-
pub mod error;
pub mod util;
pub mod endpoint;
pub mod proto;
-mod server;
mod client;
+mod server;
pub mod netapp;
pub mod peering;
diff --git a/src/netapp.rs b/src/netapp.rs
index b6994ea..bffa0e1 100644
--- a/src/netapp.rs
+++ b/src/netapp.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock};
-use log::{debug, info, error};
+use log::{debug, error, info, trace, warn};
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
@@ -10,13 +10,18 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
+
+use futures::stream::futures_unordered::FuturesUnordered;
+use futures::stream::StreamExt;
use tokio::net::{TcpListener, TcpStream};
+use tokio::select;
+use tokio::sync::{mpsc, watch};
use crate::client::*;
-use crate::server::*;
use crate::endpoint::*;
use crate::error::*;
use crate::proto::*;
+use crate::server::*;
use crate::util::*;
#[derive(Serialize, Deserialize)]
@@ -142,35 +147,91 @@ impl NetApp {
/// Main listening process for our app. This future runs during the whole
/// run time of our application.
/// If this is not called, the NetApp instance remains a passive client.
- pub async fn listen(self: Arc<Self>, listen_addr: SocketAddr, public_addr: Option<IpAddr>) {
+ pub async fn listen(
+ self: Arc<Self>,
+ listen_addr: SocketAddr,
+ public_addr: Option<IpAddr>,
+ mut must_exit: watch::Receiver<bool>,
+ ) {
let listen_params = ListenParams {
listen_addr,
public_addr,
};
- if self.listen_params.swap(Some(Arc::new(listen_params))).is_some() {
+ if self
+ .listen_params
+ .swap(Some(Arc::new(listen_params)))
+ .is_some()
+ {
error!("Trying to listen on NetApp but we're already listening!");
}
let listener = TcpListener::bind(listen_addr).await.unwrap();
info!("Listening on {}", listen_addr);
- loop {
- // The second item contains the IP and port of the new connection.
- let (socket, _) = listener.accept().await.unwrap();
+ let (conn_in, mut conn_out) = mpsc::unbounded_channel();
+ let connection_collector = tokio::spawn(async move {
+ let mut collection = FuturesUnordered::new();
+ loop {
+ if collection.is_empty() {
+ match conn_out.recv().await {
+ Some(f) => collection.push(f),
+ None => break,
+ }
+ } else {
+ select! {
+ new_fut = conn_out.recv() => {
+ match new_fut {
+ Some(f) => collection.push(f),
+ None => break,
+ }
+ }
+ result = collection.next() => {
+ trace!("Collected connection: {:?}", result);
+ }
+ }
+ }
+ }
+ debug!("Collecting last open server connections.");
+ while let Some(conn_res) = collection.next().await {
+ trace!("Collected connection: {:?}", conn_res);
+ }
+ debug!("No more server connections to collect");
+ });
+
+ while !*must_exit.borrow_and_update() {
+ let (socket, peer_addr) = select! {
+ sockres = listener.accept() => {
+ match sockres {
+ Ok(x) => x,
+ Err(e) => {
+ warn!("Error in listener.accept: {}", e);
+ continue;
+ }
+ }
+ },
+ _ = must_exit.changed() => continue,
+ };
+
info!(
"Incoming connection from {}, negotiating handshake...",
- match socket.peer_addr() {
- Ok(x) => format!("{}", x),
- Err(e) => format!("<invalid addr: {}>", e),
- }
+ peer_addr
);
let self2 = self.clone();
- tokio::spawn(async move {
- ServerConn::run(self2, socket)
- .await
- .log_err("ServerConn::run");
- });
+ let must_exit2 = must_exit.clone();
+ conn_in
+ .send(tokio::spawn(async move {
+ ServerConn::run(self2, socket, must_exit2)
+ .await
+ .log_err("ServerConn::run");
+ }))
+ .log_err("Failed to send connection to connection collector");
}
+
+ drop(conn_in);
+
+ connection_collector
+ .await
+ .log_err("Failed to await for connection collector");
}
/// Attempt to connect to a peer, given by its ip:port and its public key.
@@ -231,20 +292,6 @@ impl NetApp {
});
}
- /// Close the incoming connection from a certain client to us,
- /// if such a connection is currently open.
- pub fn server_disconnect(self: &Arc<Self>, id: &NodeID) {
- let conn = self.server_conns.read().unwrap().get(id).cloned();
- if let Some(c) = conn {
- debug!(
- "Closing incoming connection from {} ({})",
- hex::encode(c.peer_id),
- c.remote_addr
- );
- c.close();
- }
- }
-
// Called from conn.rs when an incoming connection is successfully established
// Registers the connection in our list of connections
// Do not yet call the on_connected handler, because we don't know if the remote
diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs
index e0c8301..efbf6e6 100644
--- a/src/peering/basalt.rs
+++ b/src/peering/basalt.rs
@@ -3,11 +3,11 @@ use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
+use async_trait::async_trait;
use log::{debug, info, trace, warn};
use lru::LruCache;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
-use async_trait::async_trait;
use sodiumoxide::crypto::hash;
diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs
index b579654..793eeb2 100644
--- a/src/peering/fullmesh.rs
+++ b/src/peering/fullmesh.rs
@@ -8,6 +8,8 @@ use async_trait::async_trait;
use log::{debug, info, trace, warn};
use serde::{Deserialize, Serialize};
+use tokio::sync::watch;
+
use sodiumoxide::crypto::hash;
use crate::endpoint::*;
@@ -171,8 +173,8 @@ impl FullMeshPeeringStrategy {
strat
}
- pub async fn run(self: Arc<Self>) {
- loop {
+ pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) {
+ while !*must_exit.borrow() {
// 1. Read current state: get list of connected peers (ping them)
let (to_ping, to_retry) = {
let known_hosts = self.known_hosts.read().unwrap();
diff --git a/src/proto.rs b/src/proto.rs
index 3811e3f..f91ffc7 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -1,4 +1,4 @@
-use std::collections::{BTreeMap, HashMap, VecDeque};
+use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use log::trace;
@@ -50,7 +50,6 @@ type ChunkLength = u16;
const MAX_CHUNK_LENGTH: ChunkLength = 0x4000;
const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000;
-
struct SendQueueItem {
id: RequestID,
prio: RequestPriority,
@@ -59,31 +58,33 @@ struct SendQueueItem {
}
struct SendQueue {
- items: BTreeMap<u8, VecDeque<SendQueueItem>>,
+ items: VecDeque<(u8, VecDeque<SendQueueItem>)>,
}
impl SendQueue {
fn new() -> Self {
Self {
- items: BTreeMap::new(),
+ items: VecDeque::with_capacity(64),
}
}
fn push(&mut self, item: SendQueueItem) {
let prio = item.prio;
- let mut items_at_prio = self
- .items
- .remove(&prio)
- .unwrap_or_else(|| VecDeque::with_capacity(4));
- items_at_prio.push_back(item);
- self.items.insert(prio, items_at_prio);
+ let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) {
+ Ok(i) => i,
+ Err(i) => {
+ self.items.insert(i, (prio, VecDeque::new()));
+ i
+ }
+ };
+ self.items[pos_prio].1.push_back(item);
}
fn pop(&mut self) -> Option<SendQueueItem> {
- match self.items.pop_first() {
+ match self.items.pop_front() {
None => None,
Some((prio, mut items_at_prio)) => {
let ret = items_at_prio.pop_front();
if !items_at_prio.is_empty() {
- self.items.insert(prio, items_at_prio);
+ self.items.push_front((prio, items_at_prio));
}
ret.or_else(|| self.pop())
}
@@ -98,7 +99,7 @@ impl SendQueue {
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
self: Arc<Self>,
- mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>,
+ mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>,
mut write: W,
) -> Result<(), Error>
where
@@ -107,18 +108,14 @@ pub(crate) trait SendLoop: Sync {
let mut sending = SendQueue::new();
let mut should_exit = false;
while !should_exit || !sending.is_empty() {
- if let Ok(sth) = msg_recv.try_recv() {
- if let Some((id, prio, data)) = sth {
- trace!("send_loop: got {}, {} bytes", id, data.len());
- sending.push(SendQueueItem {
- id,
- prio,
- data,
- cursor: 0,
- });
- } else {
- should_exit = true;
- }
+ if let Ok((id, prio, data)) = msg_recv.try_recv() {
+ trace!("send_loop: got {}, {} bytes", id, data.len());
+ sending.push(SendQueueItem {
+ id,
+ prio,
+ data,
+ cursor: 0,
+ });
} else if let Some(mut item) = sending.pop() {
trace!(
"send_loop: sending bytes for {} ({} bytes, {} already sent)",
@@ -149,10 +146,7 @@ pub(crate) trait SendLoop: Sync {
}
write.flush().await?;
} else {
- let sth = msg_recv
- .recv()
- .await
- .ok_or_else(|| Error::Message("Connection closed.".into()))?;
+ let sth = msg_recv.recv().await;
if let Some((id, prio, data)) = sth {
trace!("send_loop: got {}, {} bytes", id, data.len());
sending.push(SendQueueItem {
@@ -173,7 +167,7 @@ pub(crate) trait SendLoop: Sync {
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
// Returns true if we should stop receiving after this
- async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>);
+ fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
where
@@ -205,7 +199,7 @@ pub(crate) trait RecvLoop: Sync + 'static {
if has_cont {
receiving.insert(id, msg_bytes);
} else {
- tokio::spawn(self.clone().recv_handler(id, msg_bytes));
+ self.recv_handler(id, msg_bytes);
}
}
}
diff --git a/src/server.rs b/src/server.rs
index c7d99b5..f23b810 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,10 +1,12 @@
use std::net::SocketAddr;
-use std::sync::{Arc};
+use std::sync::Arc;
+use arc_swap::ArcSwapOption;
use bytes::Bytes;
use log::{debug, trace};
use tokio::net::TcpStream;
+use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio_util::compat::*;
@@ -42,12 +44,15 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
- resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>,
- close_send: watch::Sender<bool>,
+ resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
}
impl ServerConn {
- pub(crate) async fn run(netapp: Arc<NetApp>, socket: TcpStream) -> Result<(), Error> {
+ pub(crate) async fn run(
+ netapp: Arc<NetApp>,
+ socket: TcpStream,
+ must_exit: watch::Receiver<bool>,
+ ) -> Result<(), Error> {
let remote_addr = socket.peer_addr()?;
let mut socket = socket.compat();
@@ -73,47 +78,33 @@ impl ServerConn {
let (resp_send, resp_recv) = mpsc::unbounded_channel();
- let (close_send, close_recv) = watch::channel(false);
-
let conn = Arc::new(ServerConn {
netapp: netapp.clone(),
remote_addr,
peer_id,
- resp_send,
- close_send,
+ resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
});
netapp.connected_as_server(peer_id, conn.clone());
let conn2 = conn.clone();
- let conn3 = conn.clone();
- let close_recv2 = close_recv.clone();
- tokio::try_join!(
- async move {
- tokio::select!(
- r = conn2.recv_loop(read) => r,
- _ = await_exit(close_recv) => Ok(()),
- )
- },
- async move {
- tokio::select!(
- r = conn3.send_loop(resp_recv, write) => r,
- _ = await_exit(close_recv2) => Ok(()),
- )
- },
- )
- .map(|_| ())
- .log_err("ServerConn recv_loop/send_loop");
+ let recv_future = tokio::spawn(async move {
+ select! {
+ r = conn2.recv_loop(read) => r,
+ _ = await_exit(must_exit) => Ok(())
+ }
+ });
+ let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write));
+
+ recv_future.await.log_err("ServerConn recv_loop");
+ conn.resp_send.store(None);
+ send_future.await.log_err("ServerConn send_loop");
netapp.disconnected_as_server(&peer_id, conn);
Ok(())
}
- pub fn close(&self) {
- self.close_send.send(true).unwrap();
- }
-
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
if bytes.len() < 2 {
return Err(Error::Message("Invalid protocol message".into()));
@@ -146,33 +137,33 @@ impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
- async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) {
- trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
- let bytes: Bytes = bytes.into();
-
- let prio = if !bytes.is_empty() {
- bytes[0]
- } else {
- 0u8
- };
- let resp = self.recv_handler_aux(&bytes[..]).await;
-
- let mut resp_bytes = vec![];
- match resp {
- Ok(rb) => {
- resp_bytes.push(0u8);
- resp_bytes.extend(&rb[..]);
- }
- Err(e) => {
- resp_bytes.push(e.code());
+ fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
+ let resp_send = self.resp_send.load_full().unwrap();
+
+ let self2 = self.clone();
+ tokio::spawn(async move {
+ trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
+ let bytes: Bytes = bytes.into();
+
+ let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
+ let resp = self2.recv_handler_aux(&bytes[..]).await;
+
+ let mut resp_bytes = vec![];
+ match resp {
+ Ok(rb) => {
+ resp_bytes.push(0u8);
+ resp_bytes.extend(&rb[..]);
+ }
+ Err(e) => {
+ resp_bytes.push(e.code());
+ }
}
- }
- trace!("ServerConn sending response to {}: ", id);
+ trace!("ServerConn sending response to {}: ", id);
- self.resp_send
- .send(Some((id, prio, resp_bytes)))
- .log_err("ServerConn recv_handler send resp");
+ resp_send
+ .send((id, prio, resp_bytes))
+ .log_err("ServerConn recv_handler send resp");
+ });
}
}
-
diff --git a/src/util.rs b/src/util.rs
index ba485bf..e5b57ec 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,5 +1,7 @@
use serde::Serialize;
+use log::info;
+
use tokio::sync::watch;
pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey;
@@ -38,3 +40,15 @@ pub async fn await_exit(mut must_exit: watch::Receiver<bool>) {
}
}
}
+
+pub fn watch_ctrl_c() -> watch::Receiver<bool> {
+ let (send_cancel, watch_cancel) = watch::channel(false);
+ tokio::spawn(async move {
+ tokio::signal::ctrl_c()
+ .await
+ .expect("failed to install CTRL+C signal handler");
+ info!("Received CTRL+C, shutting down.");
+ send_cancel.send(true).unwrap();
+ });
+ watch_cancel
+}