aboutsummaryrefslogtreecommitdiff
path: root/src/stream.rs
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2022-07-22 12:45:38 +0200
committerAlex Auvolat <alex@adnab.me>2022-07-22 12:45:38 +0200
commit0b71ca12f910c17eaf2291076438dff3b70dc9cd (patch)
tree28c4239938b1bd585052c9a1b8b6a752b9c3bbe0 /src/stream.rs
parentc358fe3c92da8a8454e461484737efe2a14dfd73 (diff)
downloadnetapp-0b71ca12f910c17eaf2291076438dff3b70dc9cd.tar.gz
netapp-0b71ca12f910c17eaf2291076438dff3b70dc9cd.zip
Clean up framing protocol
Diffstat (limited to 'src/stream.rs')
-rw-r--r--src/stream.rs176
1 files changed, 176 insertions, 0 deletions
diff --git a/src/stream.rs b/src/stream.rs
new file mode 100644
index 0000000..6c23f4a
--- /dev/null
+++ b/src/stream.rs
@@ -0,0 +1,176 @@
+use std::collections::VecDeque;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use bytes::Bytes;
+
+use futures::Future;
+use futures::{Stream, StreamExt};
+
+/// A stream of associated data.
+///
+/// When sent through Netapp, the Vec may be split in smaller chunk in such a way
+/// consecutive Vec may get merged, but Vec and error code may not be reordered
+///
+/// Error code 255 means the stream was cut before its end. Other codes have no predefined
+/// meaning, it's up to your application to define their semantic.
+pub type ByteStream = Pin<Box<dyn Stream<Item = Packet> + Send + Sync>>;
+
+pub type Packet = Result<Bytes, u8>;
+
+pub struct ByteStreamReader {
+ stream: ByteStream,
+ buf: VecDeque<Bytes>,
+ buf_len: usize,
+ eos: bool,
+ err: Option<u8>,
+}
+
+impl ByteStreamReader {
+ pub fn new(stream: ByteStream) -> Self {
+ ByteStreamReader {
+ stream,
+ buf: VecDeque::with_capacity(8),
+ buf_len: 0,
+ eos: false,
+ err: None,
+ }
+ }
+
+ pub fn read_exact(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: true,
+ }
+ }
+
+ pub fn read_exact_or_eos(&mut self, read_len: usize) -> ByteStreamReadExact<'_> {
+ ByteStreamReadExact {
+ reader: self,
+ read_len,
+ fail_on_eos: false,
+ }
+ }
+
+ pub async fn read_u8(&mut self) -> Result<u8, ReadExactError> {
+ Ok(self.read_exact(1).await?[0])
+ }
+
+ pub async fn read_u16(&mut self) -> Result<u16, ReadExactError> {
+ let bytes = self.read_exact(2).await?;
+ let mut b = [0u8; 2];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u16::from_be_bytes(b))
+ }
+
+ pub async fn read_u32(&mut self) -> Result<u32, ReadExactError> {
+ let bytes = self.read_exact(4).await?;
+ let mut b = [0u8; 4];
+ b.copy_from_slice(&bytes[..]);
+ Ok(u32::from_be_bytes(b))
+ }
+
+ pub fn into_stream(self) -> ByteStream {
+ let buf_stream = futures::stream::iter(self.buf.into_iter().map(Ok));
+ if let Some(err) = self.err {
+ Box::pin(buf_stream.chain(futures::stream::once(async move { Err(err) })))
+ } else if self.eos {
+ Box::pin(buf_stream)
+ } else {
+ Box::pin(buf_stream.chain(self.stream))
+ }
+ }
+
+ fn try_get(&mut self, read_len: usize) -> Option<Bytes> {
+ if self.buf_len >= read_len {
+ let mut slices = Vec::with_capacity(self.buf.len());
+ let mut taken = 0;
+ while taken < read_len {
+ let front = self.buf.pop_front().unwrap();
+ if taken + front.len() <= read_len {
+ taken += front.len();
+ self.buf_len -= front.len();
+ slices.push(front);
+ } else {
+ let front_take = read_len - taken;
+ slices.push(front.slice(..front_take));
+ self.buf.push_front(front.slice(front_take..));
+ self.buf_len -= front_take;
+ break;
+ }
+ }
+ Some(
+ slices
+ .iter()
+ .map(|x| &x[..])
+ .collect::<Vec<_>>()
+ .concat()
+ .into(),
+ )
+ } else {
+ None
+ }
+ }
+}
+
+pub enum ReadExactError {
+ UnexpectedEos,
+ Stream(u8),
+}
+
+#[pin_project::pin_project]
+pub struct ByteStreamReadExact<'a> {
+ #[pin]
+ reader: &'a mut ByteStreamReader,
+ read_len: usize,
+ fail_on_eos: bool,
+}
+
+impl<'a> Future for ByteStreamReadExact<'a> {
+ type Output = Result<Bytes, ReadExactError>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<Bytes, ReadExactError>> {
+ let mut this = self.project();
+
+ loop {
+ if let Some(bytes) = this.reader.try_get(*this.read_len) {
+ return Poll::Ready(Ok(bytes));
+ }
+ if let Some(err) = this.reader.err {
+ return Poll::Ready(Err(ReadExactError::Stream(err)));
+ }
+ if this.reader.eos {
+ if *this.fail_on_eos {
+ return Poll::Ready(Err(ReadExactError::UnexpectedEos));
+ } else {
+ let bytes = Bytes::from(
+ this.reader
+ .buf
+ .iter()
+ .map(|x| &x[..])
+ .collect::<Vec<_>>()
+ .concat(),
+ );
+ this.reader.buf.clear();
+ this.reader.buf_len = 0;
+ return Poll::Ready(Ok(bytes));
+ }
+ }
+
+ match futures::ready!(this.reader.stream.as_mut().poll_next(cx)) {
+ Some(Ok(slice)) => {
+ this.reader.buf_len += slice.len();
+ this.reader.buf.push_back(slice);
+ }
+ Some(Err(e)) => {
+ this.reader.err = Some(e);
+ this.reader.eos = true;
+ }
+ None => {
+ this.reader.eos = true;
+ }
+ }
+ }
+ }
+}