diff options
author | Alex Auvolat <alex@adnab.me> | 2021-10-13 17:12:13 +0200 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2021-10-13 17:12:13 +0200 |
commit | 70839d70d86354232f168e63ce4062219acb85c7 (patch) | |
tree | 9c956af0339aa048f487c3a4e54c320be8d13647 | |
parent | 8dede69dee20b812ad1dcab5b374c60232409f4f (diff) | |
download | netapp-70839d70d86354232f168e63ce4062219acb85c7.tar.gz netapp-70839d70d86354232f168e63ce4062219acb85c7.zip |
Try to handle termination and closing of stuff properly
-rw-r--r-- | Cargo.lock | 23 | ||||
-rw-r--r-- | Cargo.toml | 3 | ||||
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | examples/basalt.rs | 13 | ||||
-rw-r--r-- | examples/fullmesh.rs | 7 | ||||
-rw-r--r-- | src/client.rs | 78 | ||||
-rw-r--r-- | src/endpoint.rs | 1 | ||||
-rw-r--r-- | src/error.rs | 17 | ||||
-rw-r--r-- | src/lib.rs | 4 | ||||
-rw-r--r-- | src/netapp.rs | 107 | ||||
-rw-r--r-- | src/peering/basalt.rs | 2 | ||||
-rw-r--r-- | src/peering/fullmesh.rs | 6 | ||||
-rw-r--r-- | src/proto.rs | 56 | ||||
-rw-r--r-- | src/server.rs | 101 | ||||
-rw-r--r-- | src/util.rs | 14 |
15 files changed, 266 insertions, 168 deletions
@@ -446,6 +446,7 @@ dependencies = [ "serde", "structopt", "tokio", + "tokio-stream", "tokio-util", ] @@ -667,6 +668,15 @@ dependencies = [ ] [[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + +[[package]] name = "slab" version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -779,7 +789,9 @@ dependencies = [ "memchr", "mio", "num_cpus", + "once_cell", "pin-project-lite", + "signal-hook-registry", "tokio-macros", "winapi", ] @@ -796,6 +808,17 @@ dependencies = [ ] [[package]] +name = "tokio-stream" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] name = "tokio-util" version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -20,8 +20,9 @@ basalt = ["lru", "rand"] [dependencies] futures = "0.3.17" -tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util"] } +tokio = { version = "1.0", default-features = false, features = ["net", "rt", "rt-multi-thread", "sync", "time", "macros", "io-util", "signal"] } tokio-util = { version = "0.6.8", default-features = false, features = ["compat"] } +tokio-stream = "0.1.7" serde = { version = "1.0", default-features = false, features = ["derive"] } rmp-serde = "0.14.3" @@ -2,6 +2,6 @@ all: cargo build --all-features cargo build --example fullmesh cargo build --all-features --example basalt - RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7 + RUST_LOG=netapp=trace,fullmesh=trace cargo run --example fullmesh -- -n 3242ce79e05e8b6a0e43441fbd140a906e13f335f298ae3a52f29784abbab500 -p 6c304114a0e1018bbe60502a34d33f4f439f370856c3333dda2726da01eb93a4894b7ef7249a71f11d342b69702f1beb7c93ec95fbcf122ad1eca583bb0629e7 #RUST_LOG=netapp=debug,fullmesh=debug cargo run --example fullmesh diff --git a/examples/basalt.rs b/examples/basalt.rs index 7093e05..63b4b4c 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -5,9 +5,9 @@ use std::time::Duration; use log::{debug, info, warn}; +use async_trait::async_trait; use serde::{Deserialize, Serialize}; use structopt::StructOpt; -use async_trait::async_trait; use sodiumoxide::crypto::auth; use sodiumoxide::crypto::sign::ed25519; @@ -122,9 +122,15 @@ async fn main() { let listen_addr = opt.listen_addr.parse().unwrap(); let public_addr = opt.public_addr.map(|x| x.parse().unwrap()); + + let watch_cancel = netapp::util::watch_ctrl_c(); + tokio::join!( example.clone().sampling_loop(), - example.netapp.clone().listen(listen_addr, public_addr), + example + .netapp + .clone() + .listen(listen_addr, public_addr, watch_cancel), example.basalt.clone().run(), ); } @@ -141,7 +147,8 @@ impl Example { let self2 = self.clone(); tokio::spawn(async move { match self2 - .example_endpoint.call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL) + .example_endpoint + .call(&p, ExampleMessage { example_field: 42 }, PRIO_NORMAL) .await { Ok(resp) => debug!("Got example response: {:?}", resp), diff --git a/examples/fullmesh.rs b/examples/fullmesh.rs index f40591a..67861a7 100644 --- a/examples/fullmesh.rs +++ b/examples/fullmesh.rs @@ -87,6 +87,11 @@ async fn main() { hex::encode(&privkey.public_key()), listen_addr); + let watch_cancel = netapp::util::watch_ctrl_c(); + let public_addr = opt.public_addr.map(|x| x.parse().unwrap()); - tokio::join!(netapp.listen(listen_addr, public_addr), peering.run(),); + tokio::join!( + netapp.listen(listen_addr, public_addr, watch_cancel.clone()), + peering.run(watch_cancel), + ); } diff --git a/src/client.rs b/src/client.rs index 127ff46..773fa9d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,11 +1,13 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::atomic::{self, AtomicBool, AtomicU32}; +use std::sync::atomic::{self, AtomicU32}; use std::sync::{Arc, Mutex}; +use arc_swap::ArcSwapOption; use log::{debug, error, trace}; use tokio::net::TcpStream; +use tokio::select; use tokio::sync::{mpsc, oneshot, watch}; use tokio_util::compat::*; @@ -21,17 +23,14 @@ use crate::netapp::*; use crate::proto::*; use crate::util::*; - pub(crate) struct ClientConn { pub(crate) remote_addr: SocketAddr, pub(crate) peer_id: NodeID, - query_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>, + query_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, next_query_number: AtomicU32, inflight: Mutex<HashMap<RequestID, oneshot::Sender<Vec<u8>>>>, - must_exit: AtomicBool, - stop_recv_loop: watch::Sender<bool>, } impl ClientConn { @@ -71,25 +70,35 @@ impl ClientConn { remote_addr, peer_id, next_query_number: AtomicU32::from(RequestID::default()), - query_send, + query_send: ArcSwapOption::new(Some(Arc::new(query_send))), inflight: Mutex::new(HashMap::new()), - must_exit: AtomicBool::new(false), - stop_recv_loop, }); netapp.connected_as_client(peer_id, conn.clone()); tokio::spawn(async move { + let send_future = tokio::spawn(conn.clone().send_loop(query_recv, write)); + let conn2 = conn.clone(); - let conn3 = conn.clone(); - tokio::try_join!(conn2.send_loop(query_recv, write), async move { - tokio::select!( - r = conn3.recv_loop(read) => r, - _ = await_exit(stop_recv_loop_recv) => Ok(()), - ) - }) - .map(|_| ()) - .log_err("ClientConn send_loop/recv_loop/dispatch_loop"); + let recv_future = tokio::spawn(async move { + select! { + r = conn2.recv_loop(read) => r, + _ = await_exit(stop_recv_loop_recv) => Ok(()) + } + }); + + send_future.await.log_err("ClientConn send_loop"); + + // TODO here: wait for inflight requests to all have their response + stop_recv_loop + .send(true) + .log_err("ClientConn send true to stop_recv_loop"); + + recv_future.await.log_err("ClientConn recv_loop"); + + // Make sure we don't wait on any more requests that won't + // have a response + conn.inflight.lock().unwrap().clear(); netapp.disconnected_as_client(&peer_id, conn); }); @@ -98,15 +107,7 @@ impl ClientConn { } pub fn close(&self) { - self.must_exit.store(true, atomic::Ordering::SeqCst); - self.query_send - .send(None) - .log_err("could not write None in query_send"); - if self.inflight.lock().unwrap().is_empty() { - self.stop_recv_loop - .send(true) - .log_err("could not write true to stop_recv_loop"); - } + self.query_send.store(None); } pub(crate) async fn call<T>( @@ -118,6 +119,8 @@ impl ClientConn { where T: Message, { + let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?; + let id = self .next_query_number .fetch_add(1, atomic::Ordering::Relaxed); @@ -138,20 +141,23 @@ impl ClientConn { } trace!("request: query_send {}, {} bytes", id, bytes.len()); - self.query_send.send(Some((id, prio, bytes)))?; + query_send.send((id, prio, bytes))?; let resp = resp_recv.await?; - if resp.len() == 0 { - return Err(Error::Message("Response is 0 bytes, either a collision or a protocol error".into())); + if resp.is_empty() { + return Err(Error::Message( + "Response is 0 bytes, either a collision or a protocol error".into(), + )); } trace!("request response {}: ", id); let code = resp[0]; if code == 0 { - Ok(rmp_serde::decode::from_read_ref::<_, <T as Message>::Response>( - &resp[1..], - )?) + Ok(rmp_serde::decode::from_read_ref::< + _, + <T as Message>::Response, + >(&resp[1..])?) } else { Err(Error::Remote(format!("Remote error code {}", code))) } @@ -162,7 +168,7 @@ impl SendLoop for ClientConn {} #[async_trait] impl RecvLoop for ClientConn { - async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>) { + fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>) { trace!("ClientConn recv_handler {} ({} bytes)", id, msg.len()); let mut inflight = self.inflight.lock().unwrap(); @@ -171,11 +177,5 @@ impl RecvLoop for ClientConn { debug!("Could not send request response, probably because request was interrupted. Dropping response."); } } - - if inflight.is_empty() && self.must_exit.load(atomic::Ordering::SeqCst) { - self.stop_recv_loop - .send(true) - .log_err("could not write true to stop_recv_loop"); - } } } diff --git a/src/endpoint.rs b/src/endpoint.rs index 83957e2..0e1f5c8 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -123,4 +123,3 @@ where Box::new(Self(self.0.clone())) } } - diff --git a/src/error.rs b/src/error.rs index 14c6187..0ed30a5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -31,6 +31,9 @@ pub enum Error { #[error(display = "No handler / shutting down")] NoHandler, + #[error(display = "Connection closed")] + ConnectionClosed, + #[error(display = "Remote error: {}", _0)] Remote(String), } @@ -45,6 +48,7 @@ impl Error { Self::RMPDecode(_) => 11, Self::UTF8(_) => 12, Self::NoHandler => 20, + Self::ConnectionClosed => 21, Self::Handshake(_) => 30, Self::Remote(_) => 40, Self::Message(_) => 99, @@ -80,3 +84,16 @@ where }; } } + +impl<E, T> LogError for Result<T, E> +where + T: LogError, + E: Into<Error>, +{ + fn log_err(self, msg: &'static str) { + match self { + Err(e) => error!("Error: {}: {}", msg, Into::<Error>::into(e)), + Ok(x) => x.log_err(msg), + } + } +} @@ -13,16 +13,14 @@ //! about message priorization. //! Also check out the examples to learn how to use this crate. -#![feature(map_first_last)] - pub mod error; pub mod util; pub mod endpoint; pub mod proto; -mod server; mod client; +mod server; pub mod netapp; pub mod peering; diff --git a/src/netapp.rs b/src/netapp.rs index b6994ea..bffa0e1 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::{Arc, RwLock}; -use log::{debug, info, error}; +use log::{debug, error, info, trace, warn}; use arc_swap::ArcSwapOption; use async_trait::async_trait; @@ -10,13 +10,18 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use sodiumoxide::crypto::auth; use sodiumoxide::crypto::sign::ed25519; + +use futures::stream::futures_unordered::FuturesUnordered; +use futures::stream::StreamExt; use tokio::net::{TcpListener, TcpStream}; +use tokio::select; +use tokio::sync::{mpsc, watch}; use crate::client::*; -use crate::server::*; use crate::endpoint::*; use crate::error::*; use crate::proto::*; +use crate::server::*; use crate::util::*; #[derive(Serialize, Deserialize)] @@ -142,35 +147,91 @@ impl NetApp { /// Main listening process for our app. This future runs during the whole /// run time of our application. /// If this is not called, the NetApp instance remains a passive client. - pub async fn listen(self: Arc<Self>, listen_addr: SocketAddr, public_addr: Option<IpAddr>) { + pub async fn listen( + self: Arc<Self>, + listen_addr: SocketAddr, + public_addr: Option<IpAddr>, + mut must_exit: watch::Receiver<bool>, + ) { let listen_params = ListenParams { listen_addr, public_addr, }; - if self.listen_params.swap(Some(Arc::new(listen_params))).is_some() { + if self + .listen_params + .swap(Some(Arc::new(listen_params))) + .is_some() + { error!("Trying to listen on NetApp but we're already listening!"); } let listener = TcpListener::bind(listen_addr).await.unwrap(); info!("Listening on {}", listen_addr); - loop { - // The second item contains the IP and port of the new connection. - let (socket, _) = listener.accept().await.unwrap(); + let (conn_in, mut conn_out) = mpsc::unbounded_channel(); + let connection_collector = tokio::spawn(async move { + let mut collection = FuturesUnordered::new(); + loop { + if collection.is_empty() { + match conn_out.recv().await { + Some(f) => collection.push(f), + None => break, + } + } else { + select! { + new_fut = conn_out.recv() => { + match new_fut { + Some(f) => collection.push(f), + None => break, + } + } + result = collection.next() => { + trace!("Collected connection: {:?}", result); + } + } + } + } + debug!("Collecting last open server connections."); + while let Some(conn_res) = collection.next().await { + trace!("Collected connection: {:?}", conn_res); + } + debug!("No more server connections to collect"); + }); + + while !*must_exit.borrow_and_update() { + let (socket, peer_addr) = select! { + sockres = listener.accept() => { + match sockres { + Ok(x) => x, + Err(e) => { + warn!("Error in listener.accept: {}", e); + continue; + } + } + }, + _ = must_exit.changed() => continue, + }; + info!( "Incoming connection from {}, negotiating handshake...", - match socket.peer_addr() { - Ok(x) => format!("{}", x), - Err(e) => format!("<invalid addr: {}>", e), - } + peer_addr ); let self2 = self.clone(); - tokio::spawn(async move { - ServerConn::run(self2, socket) - .await - .log_err("ServerConn::run"); - }); + let must_exit2 = must_exit.clone(); + conn_in + .send(tokio::spawn(async move { + ServerConn::run(self2, socket, must_exit2) + .await + .log_err("ServerConn::run"); + })) + .log_err("Failed to send connection to connection collector"); } + + drop(conn_in); + + connection_collector + .await + .log_err("Failed to await for connection collector"); } /// Attempt to connect to a peer, given by its ip:port and its public key. @@ -231,20 +292,6 @@ impl NetApp { }); } - /// Close the incoming connection from a certain client to us, - /// if such a connection is currently open. - pub fn server_disconnect(self: &Arc<Self>, id: &NodeID) { - let conn = self.server_conns.read().unwrap().get(id).cloned(); - if let Some(c) = conn { - debug!( - "Closing incoming connection from {} ({})", - hex::encode(c.peer_id), - c.remote_addr - ); - c.close(); - } - } - // Called from conn.rs when an incoming connection is successfully established // Registers the connection in our list of connections // Do not yet call the on_connected handler, because we don't know if the remote diff --git a/src/peering/basalt.rs b/src/peering/basalt.rs index e0c8301..efbf6e6 100644 --- a/src/peering/basalt.rs +++ b/src/peering/basalt.rs @@ -3,11 +3,11 @@ use std::net::SocketAddr; use std::sync::{Arc, RwLock}; use std::time::Duration; +use async_trait::async_trait; use log::{debug, info, trace, warn}; use lru::LruCache; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; -use async_trait::async_trait; use sodiumoxide::crypto::hash; diff --git a/src/peering/fullmesh.rs b/src/peering/fullmesh.rs index b579654..793eeb2 100644 --- a/src/peering/fullmesh.rs +++ b/src/peering/fullmesh.rs @@ -8,6 +8,8 @@ use async_trait::async_trait; use log::{debug, info, trace, warn}; use serde::{Deserialize, Serialize}; +use tokio::sync::watch; + use sodiumoxide::crypto::hash; use crate::endpoint::*; @@ -171,8 +173,8 @@ impl FullMeshPeeringStrategy { strat } - pub async fn run(self: Arc<Self>) { - loop { + pub async fn run(self: Arc<Self>, must_exit: watch::Receiver<bool>) { + while !*must_exit.borrow() { // 1. Read current state: get list of connected peers (ping them) let (to_ping, to_retry) = { let known_hosts = self.known_hosts.read().unwrap(); diff --git a/src/proto.rs b/src/proto.rs index 3811e3f..f91ffc7 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use log::trace; @@ -50,7 +50,6 @@ type ChunkLength = u16; const MAX_CHUNK_LENGTH: ChunkLength = 0x4000; const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; - struct SendQueueItem { id: RequestID, prio: RequestPriority, @@ -59,31 +58,33 @@ struct SendQueueItem { } struct SendQueue { - items: BTreeMap<u8, VecDeque<SendQueueItem>>, + items: VecDeque<(u8, VecDeque<SendQueueItem>)>, } impl SendQueue { fn new() -> Self { Self { - items: BTreeMap::new(), + items: VecDeque::with_capacity(64), } } fn push(&mut self, item: SendQueueItem) { let prio = item.prio; - let mut items_at_prio = self - .items - .remove(&prio) - .unwrap_or_else(|| VecDeque::with_capacity(4)); - items_at_prio.push_back(item); - self.items.insert(prio, items_at_prio); + let pos_prio = match self.items.binary_search_by(|(p, _)| p.cmp(&prio)) { + Ok(i) => i, + Err(i) => { + self.items.insert(i, (prio, VecDeque::new())); + i + } + }; + self.items[pos_prio].1.push_back(item); } fn pop(&mut self) -> Option<SendQueueItem> { - match self.items.pop_first() { + match self.items.pop_front() { None => None, Some((prio, mut items_at_prio)) => { let ret = items_at_prio.pop_front(); if !items_at_prio.is_empty() { - self.items.insert(prio, items_at_prio); + self.items.push_front((prio, items_at_prio)); } ret.or_else(|| self.pop()) } @@ -98,7 +99,7 @@ impl SendQueue { pub(crate) trait SendLoop: Sync { async fn send_loop<W>( self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>, + mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>, mut write: W, ) -> Result<(), Error> where @@ -107,18 +108,14 @@ pub(crate) trait SendLoop: Sync { let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { - if let Ok(sth) = msg_recv.try_recv() { - if let Some((id, prio, data)) = sth { - trace!("send_loop: got {}, {} bytes", id, data.len()); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); - } else { - should_exit = true; - } + if let Ok((id, prio, data)) = msg_recv.try_recv() { + trace!("send_loop: got {}, {} bytes", id, data.len()); + sending.push(SendQueueItem { + id, + prio, + data, + cursor: 0, + }); } else if let Some(mut item) = sending.pop() { trace!( "send_loop: sending bytes for {} ({} bytes, {} already sent)", @@ -149,10 +146,7 @@ pub(crate) trait SendLoop: Sync { } write.flush().await?; } else { - let sth = msg_recv - .recv() - .await - .ok_or_else(|| Error::Message("Connection closed.".into()))?; + let sth = msg_recv.recv().await; if let Some((id, prio, data)) = sth { trace!("send_loop: got {}, {} bytes", id, data.len()); sending.push(SendQueueItem { @@ -173,7 +167,7 @@ pub(crate) trait SendLoop: Sync { #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { // Returns true if we should stop receiving after this - async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>); + fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>); async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error> where @@ -205,7 +199,7 @@ pub(crate) trait RecvLoop: Sync + 'static { if has_cont { receiving.insert(id, msg_bytes); } else { - tokio::spawn(self.clone().recv_handler(id, msg_bytes)); + self.recv_handler(id, msg_bytes); } } } diff --git a/src/server.rs b/src/server.rs index c7d99b5..f23b810 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,10 +1,12 @@ use std::net::SocketAddr; -use std::sync::{Arc}; +use std::sync::Arc; +use arc_swap::ArcSwapOption; use bytes::Bytes; use log::{debug, trace}; use tokio::net::TcpStream; +use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; @@ -42,12 +44,15 @@ pub(crate) struct ServerConn { netapp: Arc<NetApp>, - resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>, - close_send: watch::Sender<bool>, + resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, } impl ServerConn { - pub(crate) async fn run(netapp: Arc<NetApp>, socket: TcpStream) -> Result<(), Error> { + pub(crate) async fn run( + netapp: Arc<NetApp>, + socket: TcpStream, + must_exit: watch::Receiver<bool>, + ) -> Result<(), Error> { let remote_addr = socket.peer_addr()?; let mut socket = socket.compat(); @@ -73,47 +78,33 @@ impl ServerConn { let (resp_send, resp_recv) = mpsc::unbounded_channel(); - let (close_send, close_recv) = watch::channel(false); - let conn = Arc::new(ServerConn { netapp: netapp.clone(), remote_addr, peer_id, - resp_send, - close_send, + resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))), }); netapp.connected_as_server(peer_id, conn.clone()); let conn2 = conn.clone(); - let conn3 = conn.clone(); - let close_recv2 = close_recv.clone(); - tokio::try_join!( - async move { - tokio::select!( - r = conn2.recv_loop(read) => r, - _ = await_exit(close_recv) => Ok(()), - ) - }, - async move { - tokio::select!( - r = conn3.send_loop(resp_recv, write) => r, - _ = await_exit(close_recv2) => Ok(()), - ) - }, - ) - .map(|_| ()) - .log_err("ServerConn recv_loop/send_loop"); + let recv_future = tokio::spawn(async move { + select! { + r = conn2.recv_loop(read) => r, + _ = await_exit(must_exit) => Ok(()) + } + }); + let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write)); + + recv_future.await.log_err("ServerConn recv_loop"); + conn.resp_send.store(None); + send_future.await.log_err("ServerConn send_loop"); netapp.disconnected_as_server(&peer_id, conn); Ok(()) } - pub fn close(&self) { - self.close_send.send(true).unwrap(); - } - async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { if bytes.len() < 2 { return Err(Error::Message("Invalid protocol message".into())); @@ -146,33 +137,33 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); - - let prio = if !bytes.is_empty() { - bytes[0] - } else { - 0u8 - }; - let resp = self.recv_handler_aux(&bytes[..]).await; - - let mut resp_bytes = vec![]; - match resp { - Ok(rb) => { - resp_bytes.push(0u8); - resp_bytes.extend(&rb[..]); - } - Err(e) => { - resp_bytes.push(e.code()); + fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) { + let resp_send = self.resp_send.load_full().unwrap(); + + let self2 = self.clone(); + tokio::spawn(async move { + trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); + let bytes: Bytes = bytes.into(); + + let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; + let resp = self2.recv_handler_aux(&bytes[..]).await; + + let mut resp_bytes = vec![]; + match resp { + Ok(rb) => { + resp_bytes.push(0u8); + resp_bytes.extend(&rb[..]); + } + Err(e) => { + resp_bytes.push(e.code()); + } } - } - trace!("ServerConn sending response to {}: ", id); + trace!("ServerConn sending response to {}: ", id); - self.resp_send - .send(Some((id, prio, resp_bytes))) - .log_err("ServerConn recv_handler send resp"); + resp_send + .send((id, prio, resp_bytes)) + .log_err("ServerConn recv_handler send resp"); + }); } } - diff --git a/src/util.rs b/src/util.rs index ba485bf..e5b57ec 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,5 +1,7 @@ use serde::Serialize; +use log::info; + use tokio::sync::watch; pub type NodeID = sodiumoxide::crypto::sign::ed25519::PublicKey; @@ -38,3 +40,15 @@ pub async fn await_exit(mut must_exit: watch::Receiver<bool>) { } } } + +pub fn watch_ctrl_c() -> watch::Receiver<bool> { + let (send_cancel, watch_cancel) = watch::channel(false); + tokio::spawn(async move { + tokio::signal::ctrl_c() + .await + .expect("failed to install CTRL+C signal handler"); + info!("Received CTRL+C, shutting down."); + send_cancel.send(true).unwrap(); + }); + watch_cancel +} |