aboutsummaryrefslogtreecommitdiff
path: root/src/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/server.rs')
-rw-r--r--src/server.rs110
1 files changed, 61 insertions, 49 deletions
diff --git a/src/server.rs b/src/server.rs
index a835959..f9eb121 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,9 +1,17 @@
+use std::collections::HashMap;
use std::net::SocketAddr;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
use arc_swap::ArcSwapOption;
-use bytes::Bytes;
-use log::{debug, trace};
+use async_trait::async_trait;
+use log::*;
+
+use futures::io::{AsyncReadExt, AsyncWriteExt};
+use kuska_handshake::async_std::{handshake_server, BoxStream};
+use tokio::net::TcpStream;
+use tokio::select;
+use tokio::sync::{mpsc, watch};
+use tokio_util::compat::*;
#[cfg(feature = "telemetry")]
use opentelemetry::{
@@ -15,21 +23,12 @@ use opentelemetry_contrib::trace::propagator::binary::*;
#[cfg(feature = "telemetry")]
use rand::{thread_rng, Rng};
-use tokio::net::TcpStream;
-use tokio::select;
-use tokio::sync::{mpsc, watch};
-use tokio_util::compat::*;
-
-use futures::io::{AsyncReadExt, AsyncWriteExt};
-
-use async_trait::async_trait;
-
-use kuska_handshake::async_std::{handshake_server, BoxStream};
-
use crate::error::*;
+use crate::message::*;
use crate::netapp::*;
-use crate::proto::*;
-use crate::proto2::*;
+use crate::recv::*;
+use crate::send::*;
+use crate::stream::*;
use crate::util::*;
// The client and server connection structs (client.rs and server.rs)
@@ -55,7 +54,8 @@ pub(crate) struct ServerConn {
netapp: Arc<NetApp>,
- resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
+ resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>,
+ running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>,
}
impl ServerConn {
@@ -101,6 +101,7 @@ impl ServerConn {
remote_addr,
peer_id,
resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
+ running_handlers: Mutex::new(HashMap::new()),
});
netapp.connected_as_server(peer_id, conn.clone());
@@ -126,13 +127,12 @@ impl ServerConn {
Ok(())
}
- async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
- let msg = QueryMessage::decode(bytes)?;
- let path = String::from_utf8(msg.path.to_vec())?;
+ async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> {
+ let path = String::from_utf8(req_enc.path.to_vec())?;
let handler_opt = {
let endpoints = self.netapp.endpoints.read().unwrap();
- endpoints.get(&path).map(|e| e.clone_endpoint())
+ endpoints.get(&path[..]).map(|e| e.clone_endpoint())
};
if let Some(handler) = handler_opt {
@@ -140,9 +140,9 @@ impl ServerConn {
if #[cfg(feature = "telemetry")] {
let tracer = opentelemetry::global::tracer("netapp");
- let mut span = if let Some(telemetry_id) = msg.telemetry_id {
+ let mut span = if !req_enc.telemetry_id.is_empty() {
let propagator = BinaryPropagator::new();
- let context = propagator.from_bytes(telemetry_id);
+ let context = propagator.from_bytes(req_enc.telemetry_id.to_vec());
let context = Context::new().with_remote_span_context(context);
tracer.span_builder(format!(">> RPC {}", path))
.with_kind(SpanKind::Server)
@@ -157,13 +157,13 @@ impl ServerConn {
.start(&tracer)
};
span.set_attribute(KeyValue::new("path", path.to_string()));
- span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));
+ span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64));
- handler.handle(msg.body, self.peer_id)
+ handler.handle(req_enc, self.peer_id)
.with_context(Context::current_with_span(span))
.await
} else {
- handler.handle(msg.body, self.peer_id).await
+ handler.handle(req_enc, self.peer_id).await
}
}
} else {
@@ -176,35 +176,47 @@ impl SendLoop for ServerConn {}
#[async_trait]
impl RecvLoop for ServerConn {
- fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
- let resp_send = self.resp_send.load_full().unwrap();
+ fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) {
+ let resp_send = match self.resp_send.load_full() {
+ Some(c) => c,
+ None => return,
+ };
+
+ let mut rh = self.running_handlers.lock().unwrap();
let self2 = self.clone();
- tokio::spawn(async move {
- trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
- let bytes: Bytes = bytes.into();
-
- let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 };
- let resp = self2.recv_handler_aux(&bytes[..]).await;
-
- let resp_bytes = match resp {
- Ok(rb) => {
- let mut resp_bytes = vec![0u8];
- resp_bytes.extend(rb);
- resp_bytes
- }
- Err(e) => {
- let mut resp_bytes = vec![e.code()];
- resp_bytes.extend(e.to_string().into_bytes());
- resp_bytes
- }
+ let jh = tokio::spawn(async move {
+ debug!("server: recv_handler got {}", id);
+
+ let (prio, resp_enc_result) = match ReqEnc::decode(stream).await {
+ Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await),
+ Err(e) => (PRIO_HIGH, Err(e)),
};
- trace!("ServerConn sending response to {}: ", id);
+ debug!("server: sending response to {}", id);
+ let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result);
resp_send
- .send((id, prio, resp_bytes))
- .log_err("ServerConn recv_handler send resp");
+ .send(SendItem::Stream(id, prio, resp_order, resp_stream))
+ .log_err("ServerConn recv_handler send resp bytes");
+
+ self2.running_handlers.lock().unwrap().remove(&id);
});
+
+ rh.insert(id, jh);
+ }
+
+ fn cancel_handler(self: &Arc<Self>, id: RequestID) {
+ trace!("received cancel for request {}", id);
+
+ // If the handler is still running, abort it now
+ if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) {
+ jh.abort();
+ }
+
+ // Inform the response sender that we don't need to send the response
+ if let Some(resp_send) = self.resp_send.load_full() {
+ let _ = resp_send.send(SendItem::Cancel(id));
+ }
}
}