aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/client.rs33
-rw-r--r--src/endpoint.rs10
-rw-r--r--src/lib.rs1
-rw-r--r--src/proto.rs24
-rw-r--r--src/proto2.rs75
-rw-r--r--src/server.rs42
6 files changed, 132 insertions, 53 deletions
diff --git a/src/client.rs b/src/client.rs
index d6caf68..e2d5d84 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{self, AtomicU32};
use std::sync::{Arc, Mutex};
+use std::borrow::Borrow;
use arc_swap::ArcSwapOption;
use log::{debug, error, trace};
@@ -29,6 +30,7 @@ use crate::endpoint::*;
use crate::error::*;
use crate::netapp::*;
use crate::proto::*;
+use crate::proto2::*;
use crate::util::*;
pub(crate) struct ClientConn {
@@ -118,14 +120,15 @@ impl ClientConn {
self.query_send.store(None);
}
- pub(crate) async fn call<T>(
+ pub(crate) async fn call<'a, T, B>(
self: Arc<Self>,
- rq: &T,
- path: &str,
+ rq: B,
+ path: &'a str,
prio: RequestPriority,
) -> Result<<T as Message>::Response, Error>
where
T: Message,
+ B: Borrow<T>,
{
let query_send = self.query_send.load_full().ok_or(Error::ConnectionClosed)?;
@@ -147,19 +150,17 @@ impl ClientConn {
};
// Encode request
- let mut bytes = vec![];
-
- bytes.extend_from_slice(&[prio, path.as_bytes().len() as u8]);
- bytes.extend_from_slice(path.as_bytes());
-
- if let Some(by) = telemetry_id {
- bytes.push(by.len() as u8);
- bytes.extend(by);
- } else {
- bytes.push(0);
- }
-
- bytes.extend_from_slice(&rmp_to_vec_all_named(rq)?[..]);
+ let body = rmp_to_vec_all_named(rq.borrow())?;
+ drop(rq);
+
+ let request = QueryMessage {
+ prio,
+ path: path.as_bytes(),
+ telemetry_id,
+ body: &body[..],
+ };
+ let bytes = request.encode();
+ drop(body);
// Send request through
let (resp_send, resp_recv) = oneshot::channel();
diff --git a/src/endpoint.rs b/src/endpoint.rs
index 760bf32..b408241 100644
--- a/src/endpoint.rs
+++ b/src/endpoint.rs
@@ -1,5 +1,6 @@
use std::marker::PhantomData;
use std::sync::Arc;
+use std::borrow::Borrow;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
@@ -88,16 +89,17 @@ where
/// Call this endpoint on a remote node (or on the local node,
/// for that matter)
- pub async fn call(
+ pub async fn call<B>(
&self,
target: &NodeID,
- req: &M,
+ req: B,
prio: RequestPriority,
- ) -> Result<<M as Message>::Response, Error> {
+ ) -> Result<<M as Message>::Response, Error>
+ where B: Borrow<M> {
if *target == self.netapp.id {
match self.handler.load_full() {
None => Err(Error::NoHandler),
- Some(h) => Ok(h.handle(req, self.netapp.id).await),
+ Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
}
} else {
let conn = self
diff --git a/src/lib.rs b/src/lib.rs
index 3162c42..89b4f32 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -19,6 +19,7 @@ pub mod util;
pub mod endpoint;
pub mod proto;
+mod proto2;
mod client;
mod server;
diff --git a/src/proto.rs b/src/proto.rs
index 18e7c44..2db3f83 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -96,6 +96,14 @@ impl SendQueue {
}
}
+/// The SendLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConna and ClientConn) adds a method `.send_loop()`
+/// that takes a channel of messages to send and an asynchronous writer,
+/// and sends messages from the channel to the async writer, putting them in a queue
+/// before being sent and doing the round-robin sending strategy.
+///
+/// The `.send_loop()` exits when the sending end of the channel is closed,
+/// or if there is an error at any time writing to the async writer.
#[async_trait]
pub(crate) trait SendLoop: Sync {
async fn send_loop<W>(
@@ -128,9 +136,9 @@ pub(crate) trait SendLoop: Sync {
write.write_all(&header_id[..]).await?;
if item.data.len() - item.cursor > MAX_CHUNK_LENGTH as usize {
- let header_size =
+ let size_header =
ChunkLength::to_be_bytes(MAX_CHUNK_LENGTH | CHUNK_HAS_CONTINUATION);
- write.write_all(&header_size[..]).await?;
+ write.write_all(&size_header[..]).await?;
let new_cursor = item.cursor + MAX_CHUNK_LENGTH as usize;
write.write_all(&item.data[item.cursor..new_cursor]).await?;
@@ -140,8 +148,8 @@ pub(crate) trait SendLoop: Sync {
} else {
let send_len = (item.data.len() - item.cursor) as ChunkLength;
- let header_size = ChunkLength::to_be_bytes(send_len);
- write.write_all(&header_size[..]).await?;
+ let size_header = ChunkLength::to_be_bytes(send_len);
+ write.write_all(&size_header[..]).await?;
write.write_all(&item.data[item.cursor..]).await?;
}
@@ -166,9 +174,15 @@ pub(crate) trait SendLoop: Sync {
}
}
+/// The RecvLoop trait, which is implemented both by the client and the server
+/// connection objects (ServerConn and ClientConn) adds a method `.recv_loop()`
+/// and a prototype of a handler for received messages `.recv_handler()` that
+/// must be filled by implementors. `.recv_loop()` receives messages in a loop
+/// according to the protocol defined above: chunks of message in progress of being
+/// received are stored in a buffer, and when the last chunk of a message is received,
+/// the full message is passed to the receive handler.
#[async_trait]
pub(crate) trait RecvLoop: Sync + 'static {
- // Returns true if we should stop receiving after this
fn recv_handler(self: &Arc<Self>, id: RequestID, msg: Vec<u8>);
async fn recv_loop<R>(self: Arc<Self>, mut read: R) -> Result<(), Error>
diff --git a/src/proto2.rs b/src/proto2.rs
new file mode 100644
index 0000000..4e126d3
--- /dev/null
+++ b/src/proto2.rs
@@ -0,0 +1,75 @@
+use crate::proto::*;
+use crate::error::*;
+
+pub(crate) struct QueryMessage<'a> {
+ pub(crate) prio: RequestPriority,
+ pub(crate) path: &'a [u8],
+ pub(crate) telemetry_id: Option<Vec<u8>>,
+ pub(crate) body: &'a [u8],
+}
+
+/// QueryMessage encoding:
+/// - priority: u8
+/// - path length: u8
+/// - path: [u8; path length]
+/// - telemetry id length: u8
+/// - telemetry id: [u8; telemetry id length]
+/// - body [u8; ..]
+impl<'a> QueryMessage<'a> {
+ pub(crate) fn encode(self) -> Vec<u8> {
+ let tel_len = match &self.telemetry_id {
+ Some(t) => t.len(),
+ None => 0,
+ };
+
+ let mut ret = Vec::with_capacity(10 + self.path.len() + tel_len + self.body.len());
+
+ ret.push(self.prio);
+
+ ret.push(self.path.len() as u8);
+ ret.extend_from_slice(self.path);
+
+ if let Some(t) = self.telemetry_id {
+ ret.push(t.len() as u8);
+ ret.extend(t);
+ } else {
+ ret.push(0u8);
+ }
+
+ ret.extend_from_slice(self.body);
+
+ ret
+ }
+
+ pub(crate) fn decode(bytes: &'a [u8]) -> Result<Self, Error> {
+ if bytes.len() < 3 {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let path_length = bytes[1] as usize;
+ if bytes.len() < 3 + path_length {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let telemetry_id_len = bytes[2 + path_length] as usize;
+ if bytes.len() < 3 + path_length + telemetry_id_len {
+ return Err(Error::Message("Invalid protocol message".into()));
+ }
+
+ let path = &bytes[2..2 + path_length];
+ let telemetry_id = if telemetry_id_len > 0 {
+ Some(bytes[3 + path_length .. 3 + path_length + telemetry_id_len].to_vec())
+ } else {
+ None
+ };
+
+ let body = &bytes[3 + path_length + telemetry_id_len..];
+
+ Ok(Self {
+ prio: bytes[0],
+ path,
+ telemetry_id,
+ body,
+ })
+ }
+}
diff --git a/src/server.rs b/src/server.rs
index 7bf17df..eb70057 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -29,6 +29,7 @@ use kuska_handshake::async_std::{handshake_server, BoxStream};
use crate::error::*;
use crate::netapp::*;
use crate::proto::*;
+use crate::proto2::*;
use crate::util::*;
// The client and server connection structs (client.rs and server.rs)
@@ -116,22 +117,8 @@ impl ServerConn {
}
async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
- if bytes.len() < 2 {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- // byte 0 is the request priority, we don't care here
- let path_length = bytes[1] as usize;
- if bytes.len() < 2 + path_length {
- return Err(Error::Message("Invalid protocol message".into()));
- }
-
- let path = &bytes[2..2 + path_length];
- let path = String::from_utf8(path.to_vec())?;
-
- let telemetry_id_len = bytes[2 + path_length] as usize;
-
- let data = &bytes[3 + path_length + telemetry_id_len..];
+ let msg = QueryMessage::decode(bytes)?;
+ let path = String::from_utf8(msg.path.to_vec())?;
let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap();
@@ -143,10 +130,9 @@ impl ServerConn {
if #[cfg(feature = "telemetry")] {
let tracer = opentelemetry::global::tracer("netapp");
- let mut span = if telemetry_id_len > 0 {
- let by = bytes[3+path_length..3+path_length+telemetry_id_len].to_vec();
+ let mut span = if let Some(telemetry_id) = msg.telemetry_id {
let propagator = BinaryPropagator::new();
- let context = propagator.from_bytes(by);
+ let context = propagator.from_bytes(telemetry_id);
let context = Context::new().with_remote_span_context(context);
tracer.span_builder(format!(">> RPC {}", path))
.with_kind(SpanKind::Server)
@@ -161,13 +147,13 @@ impl ServerConn {
.start(&tracer)
};
span.set_attribute(KeyValue::new("path", path.to_string()));
- span.set_attribute(KeyValue::new("len_query", data.len() as i64));
+ span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
- handler.handle(data, self.peer_id)
+ handler.handle(msg.body, self.peer_id)
.with_context(Context::current_with_span(span))
.await
} else {
- handler.handle(data, self.peer_id).await
+ handler.handle(msg.body, self.peer_id).await
}
}
} else {
@@ -191,16 +177,16 @@ impl RecvLoop for ServerConn {
let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
let resp = self2.recv_handler_aux(&bytes[..]).await;
- let mut resp_bytes = vec![];
- match resp {
+ let resp_bytes = match resp {
Ok(rb) => {
- resp_bytes.push(0u8);
- resp_bytes.extend(&rb[..]);
+ let mut resp_bytes = vec![0u8];
+ resp_bytes.extend(rb);
+ resp_bytes
}
Err(e) => {
- resp_bytes.push(e.code());
+ vec![e.code()]
}
- }
+ };
trace!("ServerConn sending response to {}: ", id);