diff options
Diffstat (limited to 'src/stream.rs')
-rw-r--r-- | src/stream.rs | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..6c23f4a --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,176 @@ +use std::collections::VecDeque; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; + +use futures::Future; +use futures::{Stream, StreamExt}; + +/// A stream of associated data. +/// +/// When sent through Netapp, the Vec may be split in smaller chunk in such a way +/// consecutive Vec may get merged, but Vec and error code may not be reordered +/// +/// Error code 255 means the stream was cut before its end. Other codes have no predefined +/// meaning, it's up to your application to define their semantic. +pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>; + +pub type Packet = Result<Bytes, u8>; + +pub struct ByteStreamReader { + stream: ByteStream, + buf: VecDeque<Bytes>, + buf_len: usize, + eos: bool, + err: Option<u8>, +} + +impl ByteStreamReader { + pub fn new(stream: ByteStream) -> Self { + ByteStreamReader { + stream, + buf: VecDeque::with_capacity(8), + buf_len: 0, + eos: false, + err: None, + } + } + + pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: true, + } + } + + pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> { + ByteStreamReadExact { + reader: self, + read_len, + fail_on_eos: false, + } + } + + pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> { + Ok(self.read_exact(1).await?[0]) + } + + pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> { + let bytes = self.read_exact(2).await?; + let mut b = [0u8; 2]; + b.copy_from_slice(&bytes[..]); + Ok(u16::from_be_bytes(b)) + } + + pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> { + let bytes = self.read_exact(4).await?; + let mut b = [0u8; 4]; + b.copy_from_slice(&bytes[..]); + Ok(u32::from_be_bytes(b)) + } + + pub fn into_stream(self) -> ByteStream { + let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok)); + if let Some(err) = self.err { + Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) }))) + } else if self.eos { + Box::pin(buf_stream) + } else { + Box::pin(buf_stream.chain(self.stream)) + } + } + + fn try_get(&mut self, read_len: usize) -> Option<Bytes> { + if self.buf_len >= read_len { + let mut slices = Vec::with_capacity(self.buf.len()); + let mut taken = 0; + while taken < read_len { + let front = self.buf.pop_front().unwrap(); + if taken + front.len() <= read_len { + taken += front.len(); + self.buf_len -= front.len(); + slices.push(front); + } else { + let front_take = read_len - taken; + slices.push(front.slice(..front_take)); + self.buf.push_front(front.slice(front_take..)); + self.buf_len -= front_take; + break; + } + } + Some( + slices + .iter() + .map(|x| &x[..]) + .collect::<Vec<_>>() + .concat() + .into(), + ) + } else { + None + } + } +} + +pub enum ReadExactError { + UnexpectedEos, + Stream(u8), +} + +#[pin_project::pin_project] +pub struct ByteStreamReadExact<'a> { + #[pin] + reader: &'a mut ByteStreamReader, + read_len: usize, + fail_on_eos: bool, +} + +impl<'a> Future for ByteStreamReadExact<'a> { + type Output = Result<Bytes, ReadExactError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> { + let mut this = self.project(); + + loop { + if let Some(bytes) = this.reader.try_get(*this.read_len) { + return Poll::Ready(Ok(bytes)); + } + if let Some(err) = this.reader.err { + return Poll::Ready(Err(ReadExactError::Stream(err))); + } + if this.reader.eos { + if *this.fail_on_eos { + return Poll::Ready(Err(ReadExactError::UnexpectedEos)); + } else { + let bytes = Bytes::from( + this.reader + .buf + .iter() + .map(|x| &x[..]) + .collect::<Vec<_>>() + .concat(), + ); + this.reader.buf.clear(); + this.reader.buf_len = 0; + return Poll::Ready(Ok(bytes)); + } + } + + match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) { + Some(Ok(slice)) => { + this.reader.buf_len += slice.len(); + this.reader.buf.push_back(slice); + } + Some(Err(e)) => { + this.reader.err = Some(e); + this.reader.eos = true; + } + None => { + this.reader.eos = true; + } + } + } + } +} |