diff options
Diffstat (limited to 'src/server.rs')
-rw-r--r-- | src/server.rs | 108 |
1 files changed, 60 insertions, 48 deletions
diff --git a/src/server.rs b/src/server.rs index a835959..cd367c4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,9 +1,17 @@ +use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use arc_swap::ArcSwapOption; -use bytes::Bytes; -use log::{debug, trace}; +use async_trait::async_trait; +use log::*; + +use futures::io::{AsyncReadExt, AsyncWriteExt}; +use kuska_handshake::async_std::{handshake_server, BoxStream}; +use tokio::net::TcpStream; +use tokio::select; +use tokio::sync::{mpsc, watch}; +use tokio_util::compat::*; #[cfg(feature = "telemetry")] use opentelemetry::{ @@ -15,21 +23,12 @@ use opentelemetry_contrib::trace::propagator::binary::*; #[cfg(feature = "telemetry")] use rand::{thread_rng, Rng}; -use tokio::net::TcpStream; -use tokio::select; -use tokio::sync::{mpsc, watch}; -use tokio_util::compat::*; - -use futures::io::{AsyncReadExt, AsyncWriteExt}; - -use async_trait::async_trait; - -use kuska_handshake::async_std::{handshake_server, BoxStream}; - use crate::error::*; +use crate::message::*; use crate::netapp::*; -use crate::proto::*; -use crate::proto2::*; +use crate::recv::*; +use crate::send::*; +use crate::stream::*; use crate::util::*; // The client and server connection structs (client.rs and server.rs) @@ -55,7 +54,8 @@ pub(crate) struct ServerConn { netapp: Arc<NetApp>, - resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>, + resp_send: ArcSwapOption<mpsc::UnboundedSender<SendItem>>, + running_handlers: Mutex<HashMap<RequestID, tokio::task::JoinHandle<()>>>, } impl ServerConn { @@ -101,6 +101,7 @@ impl ServerConn { remote_addr, peer_id, resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))), + running_handlers: Mutex::new(HashMap::new()), }); netapp.connected_as_server(peer_id, conn.clone()); @@ -126,9 +127,8 @@ impl ServerConn { Ok(()) } - async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> { - let msg = QueryMessage::decode(bytes)?; - let path = String::from_utf8(msg.path.to_vec())?; + async fn recv_handler_aux(self: &Arc<Self>, req_enc: ReqEnc) -> Result<RespEnc, Error> { + let path = String::from_utf8(req_enc.path.to_vec())?; let handler_opt = { let endpoints = self.netapp.endpoints.read().unwrap(); @@ -140,9 +140,9 @@ impl ServerConn { if #[cfg(feature = "telemetry")] { let tracer = opentelemetry::global::tracer("netapp"); - let mut span = if let Some(telemetry_id) = msg.telemetry_id { + let mut span = if !req_enc.telemetry_id.is_empty() { let propagator = BinaryPropagator::new(); - let context = propagator.from_bytes(telemetry_id); + let context = propagator.from_bytes(req_enc.telemetry_id.to_vec()); let context = Context::new().with_remote_span_context(context); tracer.span_builder(format!(">> RPC {}", path)) .with_kind(SpanKind::Server) @@ -157,13 +157,13 @@ impl ServerConn { .start(&tracer) }; span.set_attribute(KeyValue::new("path", path.to_string())); - span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64)); + span.set_attribute(KeyValue::new("len_query_msg", req_enc.msg.len() as i64)); - handler.handle(msg.body, self.peer_id) + handler.handle(req_enc, self.peer_id) .with_context(Context::current_with_span(span)) .await } else { - handler.handle(msg.body, self.peer_id).await + handler.handle(req_enc, self.peer_id).await } } } else { @@ -176,35 +176,47 @@ impl SendLoop for ServerConn {} #[async_trait] impl RecvLoop for ServerConn { - fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) { - let resp_send = self.resp_send.load_full().unwrap(); + fn recv_handler(self: &Arc<Self>, id: RequestID, stream: ByteStream) { + let resp_send = match self.resp_send.load_full() { + Some(c) => c, + None => return, + }; + + let mut rh = self.running_handlers.lock().unwrap(); let self2 = self.clone(); - tokio::spawn(async move { - trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len()); - let bytes: Bytes = bytes.into(); - - let prio = if !bytes.is_empty() { bytes[0] } else { 0u8 }; - let resp = self2.recv_handler_aux(&bytes[..]).await; - - let resp_bytes = match resp { - Ok(rb) => { - let mut resp_bytes = vec![0u8]; - resp_bytes.extend(rb); - resp_bytes - } - Err(e) => { - let mut resp_bytes = vec![e.code()]; - resp_bytes.extend(e.to_string().into_bytes()); - resp_bytes - } + let jh = tokio::spawn(async move { + debug!("server: recv_handler got {}", id); + + let (prio, resp_enc_result) = match ReqEnc::decode(stream).await { + Ok(req_enc) => (req_enc.prio, self2.recv_handler_aux(req_enc).await), + Err(e) => (PRIO_HIGH, Err(e)), }; - trace!("ServerConn sending response to {}: ", id); + debug!("server: sending response to {}", id); + let (resp_stream, resp_order) = RespEnc::encode(resp_enc_result); resp_send - .send((id, prio, resp_bytes)) - .log_err("ServerConn recv_handler send resp"); + .send(SendItem::Stream(id, prio, resp_order, resp_stream)) + .log_err("ServerConn recv_handler send resp bytes"); + + self2.running_handlers.lock().unwrap().remove(&id); }); + + rh.insert(id, jh); + } + + fn cancel_handler(self: &Arc<Self>, id: RequestID) { + trace!("received cancel for request {}", id); + + // If the handler is still running, abort it now + if let Some(jh) = self.running_handlers.lock().unwrap().remove(&id) { + jh.abort(); + } + + // Inform the response sender that we don't need to send the response + if let Some(resp_send) = self.resp_send.load_full() { + let _ = resp_send.send(SendItem::Cancel(id)); + } } } |