aboutsummaryrefslogblamecommitdiff
path: root/src/server.rs
blob: 937d65a0d1ce8fdb8d9bb3d5ed7a062a9398ec7e (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                         
                   
 
                            


                        


                             
                          
                  




































                                                                     
                                                                                               


                 




                                                 
























                                                                                                     



                                                
                                                                                 




                                                                  










                                                                                         





                                                              












                                                                                             



                                                                    






                                                                              































                                                                                                          









                                             



















                                                                                         
                         
 
                                                                          
 



                                                                              

         
use std::net::SocketAddr;
use std::sync::Arc;

use arc_swap::ArcSwapOption;
use bytes::Bytes;
use log::{debug, trace};

#[cfg(feature = "telemetry")]
use rand::{thread_rng, Rng};

use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio_util::compat::*;

use futures::io::AsyncReadExt;

use async_trait::async_trait;

use kuska_handshake::async_std::{handshake_server, BoxStream};

use crate::error::*;
use crate::netapp::*;
use crate::proto::*;
use crate::util::*;

// The client and server connection structs (client.rs and server.rs)
// build upon the chunking mechanism which is exclusively contained
// in proto.rs.
// Here, we just care about sending big messages without size limit.
// The format of these messages is described below.
// Chunking happens independently.

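// Request message format (client -> server):
// - u8 priority
// - u8 path length
// - [u8; path length] path
// - u8 trace id length
// - [u8; trace id length] trace id (used for telemetry only when exactly 16 bytes)
// - [u8; *] data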

// Response message format (server -> client):
// - u8 response code
// - [u8; *] response

/// Server side of a connection accepted from a remote peer.
///
/// Instances are created by `ServerConn::run` once the handshake with the
/// peer has completed.
pub(crate) struct ServerConn {
	/// Socket address of the remote peer (from `TcpStream::peer_addr`).
	pub(crate) remote_addr: SocketAddr,
	/// Identity of the peer: the public key it presented during the handshake.
	pub(crate) peer_id: NodeID,

	/// Shared application state; used to look up registered endpoints.
	netapp: Arc<NetApp>,

	/// Sender for outgoing responses `(request id, priority, encoded response)`,
	/// drained by the send loop. `run` resets it to `None` once the receive
	/// loop has ended.
	resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
}

impl ServerConn {
	pub(crate) async fn run(
		netapp: Arc<NetApp>,
		socket: TcpStream,
		must_exit: watch::Receiver<bool>,
	) -> Result<(), Error> {
		let remote_addr = socket.peer_addr()?;
		let mut socket = socket.compat();

		let handshake = handshake_server(
			&mut socket,
			netapp.netid.clone(),
			netapp.id,
			netapp.privkey.clone(),
		)
		.await?;
		let peer_id = handshake.peer_pk;

		debug!(
			"Handshake complete (server) with {}@{}",
			hex::encode(&peer_id),
			remote_addr
		);

		let (read, write) = socket.split();

		let (read, write) =
			BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();

		let (resp_send, resp_recv) = mpsc::unbounded_channel();

		let conn = Arc::new(ServerConn {
			netapp: netapp.clone(),
			remote_addr,
			peer_id,
			resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
		});

		netapp.connected_as_server(peer_id, conn.clone());

		let conn2 = conn.clone();
		let recv_future = tokio::spawn(async move {
			select! {
				r = conn2.recv_loop(read) => r,
				_ = await_exit(must_exit) => Ok(())
			}
		});
		let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write));

		recv_future.await.log_err("ServerConn recv_loop");
		conn.resp_send.store(None);
		send_future.await.log_err("ServerConn send_loop");

		netapp.disconnected_as_server(&peer_id, conn);

		Ok(())
	}

	async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
		if bytes.len() < 2 {
			return Err(Error::Message("Invalid protocol message".into()));
		}

		// byte 0 is the request priority, we don't care here
		let path_length = bytes[1] as usize;
		if bytes.len() < 2 + path_length {
			return Err(Error::Message("Invalid protocol message".into()));
		}

		let path = &bytes[2..2 + path_length];
		let path = String::from_utf8(path.to_vec())?;

		let trace_id_len = bytes[2 + path_length] as usize;

		let data = &bytes[3 + path_length + trace_id_len..];

		let handler_opt = {
			let endpoints = self.netapp.endpoints.read().unwrap();
			endpoints.get(&path).map(|e| e.clone_endpoint())
		};

		if let Some(handler) = handler_opt {
			cfg_if::cfg_if! {
				if #[cfg(feature = "telemetry")] {
					use opentelemetry::{
						KeyValue,
						trace::{FutureExt, TraceContextExt, Tracer},
						Context, trace::TraceId
					};
					let trace_id = if trace_id_len == 16 {
						let mut by = [0u8; 16];
						by.copy_from_slice(&bytes[3+path_length..19+path_length]);
						TraceId::from_bytes(by)
					} else {
						let mut rng = thread_rng();
						TraceId::from_bytes(rng.gen())
					};

					let tracer = opentelemetry::global::tracer("garage");
					let span = tracer
						.span_builder(format!("RPC handler {}", path))
						.with_trace_id(trace_id)
						.with_attributes(vec![
							KeyValue::new("path", path),
						])
						.start(&tracer);

					handler.handle(data, self.peer_id)
						.with_context(Context::current_with_span(span))
						.await
				} else {
					handler.handle(data, self.peer_id).await
				}
			}
		} else {
			Err(Error::NoHandler)
		}
	}
}

// Empty impl: `ServerConn` uses `SendLoop`'s provided methods as-is
// (presumably including the `send_loop` that `run` spawns; see proto.rs).
impl SendLoop for ServerConn {}

#[async_trait]
impl RecvLoop for ServerConn {
	/// Handles one incoming request without blocking the receive loop.
	///
	/// Spawns a task that dispatches `bytes` through `recv_handler_aux` and
	/// queues the encoded response on `resp_send`, using the request's first
	/// byte as priority (0 for an empty request). The response is `0u8`
	/// followed by the handler output on success, or the single byte
	/// `e.code()` on error.
	fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
		// `run` clears `resp_send` only after the receive loop has finished;
		// this method is presumably only invoked from that loop, so the sender
		// must still be present.
		let resp_send = self
			.resp_send
			.load_full()
			.expect("ServerConn::recv_handler called after resp_send was cleared");

		let self2 = self.clone();
		tokio::spawn(async move {
			trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
			let bytes: Bytes = bytes.into();

			let prio = bytes.first().copied().unwrap_or(0u8);
			let resp = self2.recv_handler_aux(&bytes[..]).await;

			let resp_bytes = match resp {
				Ok(rb) => {
					let mut out = Vec::with_capacity(1 + rb.len());
					out.push(0u8);
					out.extend_from_slice(&rb);
					out
				}
				Err(e) => vec![e.code()],
			};

			trace!(
				"ServerConn sending response to {}: {} bytes",
				id,
				resp_bytes.len()
			);

			resp_send
				.send((id, prio, resp_bytes))
				.log_err("ServerConn recv_handler send resp");
		});
	}
}