diff options
Diffstat (limited to 'src/proto.rs')
-rw-r--r-- | src/proto.rs | 150 |
1 files changed, 43 insertions, 107 deletions
diff --git a/src/proto.rs b/src/proto.rs index b044280..d90042f 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -3,14 +3,14 @@ use std::sync::Arc; use log::trace; -use async_trait::async_trait; - use async_std::io::prelude::WriteExt; use async_std::io::ReadExt; use tokio::io::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc; + +use async_trait::async_trait; use crate::error::*; @@ -85,26 +85,33 @@ impl SendQueue { } } } + fn is_empty(&self) -> bool { + self.items.iter().all(|(_k, v)| v.is_empty()) + } } #[async_trait] pub(crate) trait SendLoop: Sync { async fn send_loop( self: Arc<Self>, - mut msg_recv: mpsc::UnboundedReceiver<(RequestID, RequestPriority, Vec<u8>)>, + mut msg_recv: mpsc::UnboundedReceiver<Option<(RequestID, RequestPriority, Vec<u8>)>>, mut write: BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>, - mut must_exit: watch::Receiver<bool>, ) -> Result<(), Error> { let mut sending = SendQueue::new(); - while !*must_exit.borrow() { - if let Ok((id, prio, data)) = msg_recv.try_recv() { - trace!("send_loop: got {}, {} bytes", id, data.len()); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); + let mut should_exit = false; + while !should_exit || !sending.is_empty() { + if let Ok(sth) = msg_recv.try_recv() { + if let Some((id, prio, data)) = sth { + trace!("send_loop: got {}, {} bytes", id, data.len()); + sending.push(SendQueueItem { + id, + prio, + data, + cursor: 0, + }); + } else { + should_exit = true; + } } else if let Some(mut item) = sending.pop() { trace!( "send_loop: sending bytes for {} ({} bytes, {} already sent)", @@ -113,33 +120,14 @@ pub(crate) trait SendLoop: Sync { item.cursor ); let header_id = u16::to_be_bytes(item.id); - if write_all_or_exit(&header_id[..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&header_id[..]).await?; if item.data.len() - item.cursor > MAX_CHUNK_SIZE { let header_size = u16::to_be_bytes(MAX_CHUNK_SIZE as u16 | 0x8000); - if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&header_size[..]).await?; let new_cursor = item.cursor + MAX_CHUNK_SIZE as usize; - if write_all_or_exit( - &item.data[item.cursor..new_cursor], - &mut write, - &mut must_exit, - ) - .await? - .is_none() - { - break; - } + write.write_all(&item.data[item.cursor..new_cursor]).await?; item.cursor = new_cursor; sending.push(item); @@ -147,33 +135,27 @@ pub(crate) trait SendLoop: Sync { let send_len = (item.data.len() - item.cursor) as u16; let header_size = u16::to_be_bytes(send_len); - if write_all_or_exit(&header_size[..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&header_size[..]).await?; - if write_all_or_exit(&item.data[item.cursor..], &mut write, &mut must_exit) - .await? - .is_none() - { - break; - } + write.write_all(&item.data[item.cursor..]).await?; } write.flush().await.log_err("Could not flush in send_loop"); } else { - let (id, prio, data) = msg_recv + let sth = msg_recv .recv() .await .ok_or(Error::Message("Connection closed.".into()))?; - trace!("send_loop: got {}, {} bytes", id, data.len()); - sending.push(SendQueueItem { - id, - prio, - data, - cursor: 0, - }); + if let Some((id, prio, data)) = sth { + trace!("send_loop: got {}, {} bytes", id, data.len()); + sending.push(SendQueueItem { + id, + prio, + data, + cursor: 0, + }); + } else { + should_exit = true; + } } } Ok(()) @@ -182,33 +164,23 @@ pub(crate) trait SendLoop: Sync { #[async_trait] pub(crate) trait RecvLoop: Sync + 'static { + // Returns true if we should stop receiving after this async fn recv_handler(self: Arc<Self>, id: RequestID, msg: Vec<u8>); async fn recv_loop( self: Arc<Self>, mut read: BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>, - mut must_exit: watch::Receiver<bool>, ) -> Result<(), Error> { let mut receiving = HashMap::new(); - while !*must_exit.borrow() { + loop { trace!("recv_loop: reading packet"); let mut header_id = [0u8; 2]; - if read_exact_or_exit(&mut header_id[..], &mut read, &mut must_exit) - .await? - .is_none() - { - break; - } + read.read_exact(&mut header_id[..]).await?; let id = RequestID::from_be_bytes(header_id); trace!("recv_loop: got header id: {:04x}", id); let mut header_size = [0u8; 2]; - if read_exact_or_exit(&mut header_size[..], &mut read, &mut must_exit) - .await? - .is_none() - { - break; - } + read.read_exact(&mut header_size[..]).await?; let size = RequestID::from_be_bytes(header_size); trace!("recv_loop: got header size: {:04x}", id); @@ -216,12 +188,7 @@ pub(crate) trait RecvLoop: Sync + 'static { let size = size & !0x8000; let mut next_slice = vec![0; size as usize]; - if read_exact_or_exit(&mut next_slice[..], &mut read, &mut must_exit) - .await? - .is_none() - { - break; - } + read.read_exact(&mut next_slice[..]).await?; trace!("recv_loop: read {} bytes", size); let mut msg_bytes = receiving.remove(&id).unwrap_or(vec![]); @@ -233,36 +200,5 @@ pub(crate) trait RecvLoop: Sync + 'static { tokio::spawn(self.clone().recv_handler(id, msg_bytes)); } } - Ok(()) - } -} - -async fn read_exact_or_exit( - buf: &mut [u8], - read: &mut BoxStreamRead<TokioCompat<ReadHalf<TcpStream>>>, - must_exit: &mut watch::Receiver<bool>, -) -> Result<Option<()>, Error> { - tokio::select!( - res = read.read_exact(buf) => Ok(Some(res?)), - _ = await_exit(must_exit) => Ok(None), - ) -} - -async fn write_all_or_exit( - buf: &[u8], - write: &mut BoxStreamWrite<TokioCompat<WriteHalf<TcpStream>>>, - must_exit: &mut watch::Receiver<bool>, -) -> Result<Option<()>, Error> { - tokio::select!( - res = write.write_all(buf) => Ok(Some(res?)), - _ = await_exit(must_exit) => Ok(None), - ) -} - -async fn await_exit(must_exit: &mut watch::Receiver<bool>) { - loop { - if must_exit.recv().await == Some(true) { - return; - } } } |