use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::Future;
use futures::{Stream, StreamExt};
/// A stream of associated data, made of byte chunks and error codes.
///
/// When sent through Netapp, a chunk may be split into smaller chunks and
/// consecutive chunks may get merged, but chunks and error codes are never
/// reordered relative to each other.
///
/// Error code 255 means the stream was cut before its end. Other codes have no predefined
/// meaning; it is up to your application to define their semantics.
pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
/// A single item of a [`ByteStream`]: `Ok` carries a chunk of bytes, `Err` carries an
/// error code.
pub type Packet = Result<Bytes, u8>;
/// Buffering reader over a [`ByteStream`] that hands out byte strings of an exact,
/// caller-chosen length regardless of how the underlying stream is chunked.
pub struct ByteStreamReader {
// Underlying stream; polled for more chunks whenever the buffer cannot satisfy a read.
stream: ByteStream,
// Chunks received from `stream` that have not been handed out yet, oldest first.
buf: VecDeque<Bytes>,
// Total number of bytes currently held in `buf` (sum of the chunk lengths).
buf_len: usize,
// Set once `stream` has yielded `None` or an error; `stream` is not polled after that.
eos: bool,
// Error code yielded by `stream`, if any.
err: Option<u8>,
}
impl ByteStreamReader {
pub fn new(stream: ByteStream) -> Self {
ByteStreamReader {
stream,
buf: VecDeque::with_capacity(8),
buf_len: 0,
eos: false,
err: None,
}
}
pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
ByteStreamReadExact {
reader: self,
read_len,
fail_on_eos: true,
}
}
pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
ByteStreamReadExact {
reader: self,
read_len,
fail_on_eos: false,
}
}
pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
Ok(self.read_exact(1).await?[0])
}
pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
let bytes = self.read_exact(2).await?;
let mut b = [0u8; 2];
b.copy_from_slice(&bytes[..]);
Ok(u16::from_be_bytes(b))
}
pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
let bytes = self.read_exact(4).await?;
let mut b = [0u8; 4];
b.copy_from_slice(&bytes[..]);
Ok(u32::from_be_bytes(b))
}
pub fn into_stream(self) -> ByteStream {
let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok));
if let Some(err) = self.err {
Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
} else if self.eos {
Box::pin(buf_stream)
} else {
Box::pin(buf_stream.chain(self.stream))
}
}
fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
if self.buf_len >= read_len {
let mut slices = Vec::with_capacity(self.buf.len());
let mut taken = 0;
while taken < read_len {
let front = self.buf.pop_front().unwrap();
if taken + front.len() <= read_len {
taken += front.len();
self.buf_len -= front.len();
slices.push(front);
} else {
let front_take = read_len - taken;
slices.push(front.slice(..front_take));
self.buf.push_front(front.slice(front_take..));
self.buf_len -= front_take;
break;
}
}
Some(
slices
.iter()
.map(|x| &x[..])
.collect::<Vec<_>>()
.concat()
.into(),
)
} else {
None
}
}
}
/// Reason why a read of an exact number of bytes could not be satisfied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadExactError {
    /// The stream ended before the requested number of bytes was available.
    UnexpectedEos,
    /// The stream yielded this error code before the requested number of bytes
    /// was available.
    Stream(u8),
}
/// Future returned by `ByteStreamReader::read_exact` and
/// `ByteStreamReader::read_exact_or_eos`.
///
/// It resolves to the requested bytes once the reader has buffered enough data.
#[pin_project::pin_project]
pub struct ByteStreamReadExact<'a> {
// Reader whose buffer and underlying stream are used to satisfy the read.
#[pin]
reader: &'a mut ByteStreamReader,
// Number of bytes requested.
read_len: usize,
// When `true`, the stream ending with fewer than `read_len` bytes buffered is an
// `UnexpectedEos` error; when `false`, the remaining buffered bytes are returned instead.
fail_on_eos: bool,
}
impl<'a> Future for ByteStreamReadExact<'a> {
    type Output = Result<Bytes, ReadExactError>;

    /// Drives the read to completion.
    ///
    /// On each iteration the checks run in this order: buffered data that
    /// satisfies the request, then a stored stream error, then end of stream,
    /// and only then is the underlying stream polled for another item.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
        let mut this = self.project();
        let wanted = *this.read_len;
        let strict = *this.fail_on_eos;

        loop {
            // Enough data buffered: done, even if an error is already stored.
            if let Some(bytes) = this.reader.try_get(wanted) {
                return Poll::Ready(Ok(bytes));
            }

            if let Some(code) = this.reader.err {
                return Poll::Ready(Err(ReadExactError::Stream(code)));
            }

            if this.reader.eos {
                if strict {
                    return Poll::Ready(Err(ReadExactError::UnexpectedEos));
                }
                // Lenient mode: hand out whatever is left, merged into one buffer.
                let mut rest: Vec<u8> = Vec::with_capacity(this.reader.buf_len);
                for chunk in this.reader.buf.drain(..) {
                    rest.extend_from_slice(&chunk[..]);
                }
                this.reader.buf_len = 0;
                return Poll::Ready(Ok(Bytes::from(rest)));
            }

            // Not enough data yet and the stream is still open: pull one more item.
            match this.reader.stream.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(Ok(chunk))) => {
                    this.reader.buf_len += chunk.len();
                    this.reader.buf.push_back(chunk);
                }
                Poll::Ready(Some(Err(code))) => {
                    // An error also marks the end of the stream for this reader.
                    this.reader.err = Some(code);
                    this.reader.eos = true;
                }
                Poll::Ready(None) => {
                    this.reader.eos = true;
                }
            }
        }
    }
}