diff options
Diffstat (limited to 'src/server.rs')
-rw-r--r-- | src/server.rs | 101 |
1 files changed, 46 insertions, 55 deletions
diff --git a/src/server.rs b/src/server.rs index c7d99b5..f23b810 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,10 +1,12 @@ use std::net::SocketAddr; -use std::sync::{Arc}; +use std::sync::Arc; +use arc_swap::ArcSwapOption; use bytes::Bytes; use log::{debug, trace}; use tokio::net::TcpStream; +use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; @@ -42,12 +44,15 @@ pub(crate) struct ServerConn { netapp: Arc<NetApp>, - resp_send: mpsc::UnboundedSender<Option<(RequestID, RequestPriority, Vec<u8>)>>, - close_send: watch::Sender<bool>, + resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, } impl ServerConn { - pub(crate) async fn run(netapp: Arc<NetApp>, socket: TcpStream) -> Result<(), Error> { + pub(crate) async fn run( + netapp: Arc<NetApp>, + socket: TcpStream, + must_exit: watch::Receiver<bool>, + ) -> Result<(), Error> { let remote_addr = socket.peer_addr()?; let mut socket = socket.compat(); @@ -73,47 +78,33 @@ impl ServerConn { let (resp_send, resp_recv) = mpsc::unbounded_channel(); - let (close_send, close_recv) = watch::channel(false); - let conn = Arc::new(ServerConn { netapp: netapp.clone(), remote_addr, peer_id, - resp_send, - close_send, + resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))), }); netapp.connected_as_server(peer_id, conn.clone()); let conn2 = conn.clone(); - let conn3 = conn.clone(); - let close_recv2 = close_recv.clone(); - tokio::try_join!( - async move { - tokio::select!( - r = conn2.recv_loop(read) => r, - _ = await_exit(close_recv) => Ok(()), - ) - }, - async move { - tokio::select!( - r = conn3.send_loop(resp_recv, write) => r, - _ = await_exit(close_recv2) => Ok(()), - ) - }, - ) - .map(|_| ()) - .log_err("ServerConn recv_loop/send_loop"); + let recv_future = tokio::spawn(async move { + select! { + r = conn2.recv_loop(read) => r, + _ = await_exit(must_exit) => Ok(()) + } + }); + let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write)); + + recv_future.await.log_err("ServerConn recv_loop"); + conn.resp_send.store(None); + send_future.await.log_err("ServerConn send_loop"); netapp.disconnected_as_server(&peer_id, conn); Ok(()) } - pub fn close(&self) { - self.close_send.send(true).unwrap(); - } - async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { if bytes.len() < 2 { return Err(Error::Message("Invalid protocol message".into())); @@ -146,33 +137,33 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - async fn recv_handler(self: Arc<Self>, id: RequestID, bytes: Vec<u8>) { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); - - let prio = if !bytes.is_empty() { - bytes[0] - } else { - 0u8 - }; - let resp = self.recv_handler_aux(&bytes[..]).await; - - let mut resp_bytes = vec![]; - match resp { - Ok(rb) => { - resp_bytes.push(0u8); - resp_bytes.extend(&rb[..]); - } - Err(e) => { - resp_bytes.push(e.code()); + fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) { + let resp_send = self.resp_send.load_full().unwrap(); + + let self2 = self.clone(); + tokio::spawn(async move { + trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); + let bytes: Bytes = bytes.into(); + + let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; + let resp = self2.recv_handler_aux(&bytes[..]).await; + + let mut resp_bytes = vec![]; + match resp { + Ok(rb) => { + resp_bytes.push(0u8); + resp_bytes.extend(&rb[..]); + } + Err(e) => { + resp_bytes.push(e.code()); + } } - } - trace!("ServerConn sending response to {}: ", id); + trace!("ServerConn sending response to {}: ", id); - self.resp_send - .send(Some((id, prio, resp_bytes))) - .log_err("ServerConn recv_handler send resp"); + resp_send + .send((id, prio, resp_bytes)) + .log_err("ServerConn recv_handler send resp"); + }); } } - |