diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/client.rs | 2 | ||||
-rw-r--r-- | src/recv.rs | 38 | ||||
-rw-r--r-- | src/send.rs | 61 | ||||
-rw-r--r-- | src/stream.rs | 17 |
4 files changed, 64 insertions, 54 deletions
diff --git a/src/client.rs b/src/client.rs index 2fccdb8..0dcbdf1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -175,7 +175,7 @@ impl ClientConn { "Too many inflight requests! RequestID collision. Interrupting previous request." ); let _ = old_ch.send(Box::pin(futures::stream::once(async move { - Err(Error::IdCollision.code()) + Err(std::io::Error::new(std::io::ErrorKind::Other, "RequestID collision, too many inflight requests")) }))); } diff --git a/src/recv.rs b/src/recv.rs index 3bea709..f8d68da 100644 --- a/src/recv.rs +++ b/src/recv.rs @@ -35,7 +35,7 @@ impl Sender { impl Drop for Sender { fn drop(&mut self) { if let Some(inner) = self.inner.take() { - let _ = inner.send(Err(255)); + let _ = inner.send(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Netapp connection dropped before end of stream"))); } } } @@ -76,25 +76,26 @@ pub(crate) trait RecvLoop: Sync + 'static { let has_cont = (size & CHUNK_HAS_CONTINUATION) != 0; let is_error = (size & ERROR_MARKER) != 0; + let size = (size & CHUNK_LENGTH_MASK) as usize; + let mut next_slice = vec![0; size as usize]; + read.read_exact(&mut next_slice[..]).await?; + let packet = if is_error { - trace!( - "recv_loop: got id {}, header_size {:04x}, error {}", - id, - size, - size & !ERROR_MARKER - ); - Err((size & !ERROR_MARKER) as u8) + let msg = String::from_utf8(next_slice).unwrap_or("<invalid utf8 error message>".into()); + debug!("recv_loop: got id {}, error: {}", id, msg); + Some(Err(std::io::Error::new(std::io::ErrorKind::Other, msg))) } else { - let size = size & !CHUNK_HAS_CONTINUATION; - let mut next_slice = vec![0; size as usize]; - read.read_exact(&mut next_slice[..]).await?; trace!( - "recv_loop: got id {}, header_size {:04x}, {} bytes", + "recv_loop: got id {}, size {}, has_cont {}", id, size, - next_slice.len() + has_cont ); - Ok(Bytes::from(next_slice)) + if !next_slice.is_empty() { + Some(Ok(Bytes::from(next_slice))) + } else { + None + } }; let mut sender = if let Some(send) = streams.remove(&(id)) { @@ -109,9 +110,12 @@ pub(crate) trait RecvLoop: Sync + 'static { Sender::new(send) }; - // If we get an error, the receiving end is disconnected. - // We still need to reach eos before dropping this sender - let _ = sender.send(packet); + if let Some(packet) = packet { + // If we cannot put packet in channel, it means that the + // receiving end of the channel is disconnected. + // We still need to reach eos before dropping this sender + let _ = sender.send(packet); + } if has_cont { assert!(!is_error); diff --git a/src/send.rs b/src/send.rs index fd415c6..f362962 100644 --- a/src/send.rs +++ b/src/send.rs @@ -18,9 +18,11 @@ use crate::stream::*; // Messages are sent by chunks // Chunk format: // - u32 BE: request id (same for request and response) -// - u16 BE: chunk length, possibly with CHUNK_HAS_CONTINUATION flag -// when this is not the last chunk of the message -// - [u8; chunk_length] chunk data +// - u16 BE: chunk length + flags: +// CHUNK_HAS_CONTINUATION when this is not the last chunk of the stream +// ERROR_MARKER if this chunk denotes an error +// (these two flags are exclusive, an error denotes the end of the stream) +// - [u8; chunk_length] chunk data / error message pub(crate) type RequestID = u32; pub(crate) type ChunkLength = u16; @@ -28,6 +30,7 @@ pub(crate) type ChunkLength = u16; pub(crate) const MAX_CHUNK_LENGTH: ChunkLength = 0x3FF0; pub(crate) const ERROR_MARKER: ChunkLength = 0x4000; pub(crate) const CHUNK_HAS_CONTINUATION: ChunkLength = 0x8000; +pub(crate) const CHUNK_LENGTH_MASK: ChunkLength = 0x3FFF; struct SendQueue { items: Vec<(u8, VecDeque<SendQueueItem>)>, @@ -92,29 +95,12 @@ impl<'a> futures::Future for SendQueuePollNextReady<'a> { let id = item.id; let eos = item.data.eos(); - let data_frame = match bytes_or_err { - Ok(bytes) => { - trace!( - "send queue poll next ready: id {} eos {:?} bytes {}", - id, - eos, - bytes.len() - ); - DataFrame::Data(bytes, !eos) - } - Err(e) => DataFrame::Error(match e { - ReadExactError::Stream(code) => { - trace!( - "send queue poll next ready: id {} eos {:?} ERROR {}", - id, - eos, - code - ); - code - } - _ => unreachable!(), - }), - }; + let packet = bytes_or_err.map_err(|e| match e { + ReadExactError::Stream(err) => err, + _ => unreachable!(), + }); + + let data_frame = DataFrame::from_packet(packet, !eos); if !eos && !matches!(data_frame, DataFrame::Error(_)) { items_at_prio.push_back(item); @@ -139,15 +125,32 @@ enum DataFrame { /// (albeit sub-optimal) to set it to true if there is nothing coming after Data(Bytes, bool), /// An error code automatically signals the end of the stream - Error(u8), + Error(Bytes), } impl DataFrame { + fn from_packet(p: Packet, has_cont: bool) -> Self { + match p { + Ok(bytes) => { + assert!(bytes.len() <= MAX_CHUNK_LENGTH as usize); + Self::Data(bytes, has_cont) + } + Err(e) => { + let msg = format!("{}", e); + let mut msg = Bytes::from(msg.into_bytes()); + if msg.len() > MAX_CHUNK_LENGTH as usize { + msg = msg.slice(..MAX_CHUNK_LENGTH as usize); + } + Self::Error(msg) + } + } + } + fn header(&self) -> [u8; 2] { let header_u16 = match self { DataFrame::Data(data, false) => data.len() as u16, DataFrame::Data(data, true) => data.len() as u16 | CHUNK_HAS_CONTINUATION, - DataFrame::Error(e) => *e as u16 | ERROR_MARKER, + DataFrame::Error(msg) => msg.len() as u16 | ERROR_MARKER, }; ChunkLength::to_be_bytes(header_u16) } @@ -155,7 +158,7 @@ impl DataFrame { fn data(&self) -> &[u8] { match self { DataFrame::Data(ref data, _) => &data[..], - DataFrame::Error(_) => &[], + DataFrame::Error(ref msg) => &msg[..], } } } diff --git a/src/stream.rs b/src/stream.rs index cc664ce..3518246 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,7 +4,7 @@ use std::task::{Context, Poll}; use bytes::Bytes; use futures::Future; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt}; use tokio::io::AsyncRead; use crate::bytes_buf::BytesBuf; @@ -18,7 +18,7 @@ use crate::bytes_buf::BytesBuf; /// meaning, it's up to your application to define their semantic. pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; -pub type Packet = Result<Bytes, u8>; +pub type Packet = Result<Bytes, std::io::Error>; // ---- @@ -26,7 +26,7 @@ pub struct ByteStreamReader { stream: ByteStream, buf: BytesBuf, eos: bool, - err: Option<u8>, + err: Option<std::io::Error>, } impl ByteStreamReader { @@ -99,7 +99,7 @@ impl ByteStreamReader { pub enum ReadExactError { UnexpectedEos, - Stream(u8), + Stream(std::io::Error), } #[pin_project::pin_project] @@ -120,7 +120,8 @@ impl<'a> Future for ByteStreamReadExact<'a> { if let Some(bytes) = this.reader.try_get(*this.read_len) { return Poll::Ready(Ok(bytes)); } - if let Some(err) = this.reader.err { + if let Some(err) = &this.reader.err { + let err = std::io::Error::new(err.kind(), format!("{}", err)); return Poll::Ready(Err(ReadExactError::Stream(err))); } if this.reader.eos { @@ -149,6 +150,7 @@ impl<'a> Future for ByteStreamReadExact<'a> { // ---- +/* fn u8_to_io_error(v: u8) -> std::io::Error { use std::io::{Error, ErrorKind}; let kind = match v { @@ -183,11 +185,12 @@ fn io_error_to_u8(e: std::io::Error) -> u8 { _ => 100, } } +*/ pub fn asyncread_stream<R: AsyncRead + Send + Sync + 'static>(reader: R) -> ByteStream { - Box::pin(tokio_util::io::ReaderStream::new(reader).map_err(io_error_to_u8)) + Box::pin(tokio_util::io::ReaderStream::new(reader)) } pub fn stream_asyncread(stream: ByteStream) -> impl AsyncRead + Send + Sync + 'static { - tokio_util::io::StreamReader::new(stream.map_err(u8_to_io_error)) + tokio_util::io::StreamReader::new(stream) } |