aboutsummaryrefslogblamecommitdiff
path: root/src/server.rs
blob: eb700573bfd7a5202cf9fc5acf06fdc25a0710ba (plain) (tree)
1
2
3
4
5
6
7
8
9
                         
                   
 
                            


                        
                             
                    
                                                                             




                                                        

                            
                          
                  











                                                              
                     
























                                                                     
                                                                                               


                 




                                                 
























                                                                                                     



                                                
                                                                                 




                                                                  










                                                                                         





                                                              
                                                                                             

                                                                 






                                                                              

                                                                  

                                                                                             
                                                                                                     
                                                                                         
                                                                                                  



                                                                                                               

                                                                           





                                                                                                 
                                          
                                                                                                    
                                                                                                              
 
                                                                              


                                                                                               
                                                                                    

                                 









                                             










                                                                                         
                                                     
                                           


                                                                       

                                           
                                                      
                                 
                          
 
                                                                          
 



                                                                              

         
use std::net::SocketAddr;
use std::sync::Arc;

use arc_swap::ArcSwapOption;
use bytes::Bytes;
use log::{debug, trace};

#[cfg(feature = "telemetry")]
use opentelemetry::{
	trace::{FutureExt, Span, SpanKind, TraceContextExt, TraceId, Tracer},
	Context, KeyValue,
};
#[cfg(feature = "telemetry")]
use opentelemetry_contrib::trace::propagator::binary::*;
#[cfg(feature = "telemetry")]
use rand::{thread_rng, Rng};

use tokio::net::TcpStream;
use tokio::select;
use tokio::sync::{mpsc, watch};
use tokio_util::compat::*;

use futures::io::AsyncReadExt;

use async_trait::async_trait;

use kuska_handshake::async_std::{handshake_server, BoxStream};

use crate::error::*;
use crate::netapp::*;
use crate::proto::*;
use crate::proto2::*;
use crate::util::*;

// The client and server connection structs (client.rs and server.rs)
// build upon the chunking mechanism which is exclusively contained
// in proto.rs.
// Here, we just care about sending big messages without size limit.
// The format of these messages is described below.
// Chunking happens independently.

// Request message format (client -> server):
// - u8 priority
// - u8 path length
// - [u8; path length] path
// - [u8; *] data

// Response message format (server -> client):
// - u8 response code
// - [u8; *] response

/// Server-side state for one accepted peer connection.
///
/// An instance is created by `ServerConn::run` once the handshake with the
/// remote peer has succeeded, and is shared (behind an `Arc`) between the
/// receive loop, the send loop and the per-request handler tasks.
pub(crate) struct ServerConn {
	/// Address of the remote end, as reported by `TcpStream::peer_addr`.
	pub(crate) remote_addr: SocketAddr,
	/// Identity of the remote peer, taken from the handshake result (`peer_pk`).
	pub(crate) peer_id: NodeID,

	/// The application this connection belongs to; used to look up endpoints.
	netapp: Arc<NetApp>,

	/// Sending half of the channel on which `(request id, priority, response bytes)`
	/// tuples are queued for the send loop. `run` replaces it with `None` once
	/// the receive loop has finished.
	resp_send: ArcSwapOption<mpsc::UnboundedSender<(RequestID, RequestPriority, Vec<u8>)>>,
}

impl ServerConn {
	pub(crate) async fn run(
		netapp: Arc<NetApp>,
		socket: TcpStream,
		must_exit: watch::Receiver<bool>,
	) -> Result<(), Error> {
		let remote_addr = socket.peer_addr()?;
		let mut socket = socket.compat();

		let handshake = handshake_server(
			&mut socket,
			netapp.netid.clone(),
			netapp.id,
			netapp.privkey.clone(),
		)
		.await?;
		let peer_id = handshake.peer_pk;

		debug!(
			"Handshake complete (server) with {}@{}",
			hex::encode(&peer_id),
			remote_addr
		);

		let (read, write) = socket.split();

		let (read, write) =
			BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write();

		let (resp_send, resp_recv) = mpsc::unbounded_channel();

		let conn = Arc::new(ServerConn {
			netapp: netapp.clone(),
			remote_addr,
			peer_id,
			resp_send: ArcSwapOption::new(Some(Arc::new(resp_send))),
		});

		netapp.connected_as_server(peer_id, conn.clone());

		let conn2 = conn.clone();
		let recv_future = tokio::spawn(async move {
			select! {
				r = conn2.recv_loop(read) => r,
				_ = await_exit(must_exit) => Ok(())
			}
		});
		let send_future = tokio::spawn(conn.clone().send_loop(resp_recv, write));

		recv_future.await.log_err("ServerConn recv_loop");
		conn.resp_send.store(None);
		send_future.await.log_err("ServerConn send_loop");

		netapp.disconnected_as_server(&peer_id, conn);

		Ok(())
	}

	async fn recv_handler_aux(self: &Arc<Self>, bytes: &[u8]) -> Result<Vec<u8>, Error> {
		let msg = QueryMessage::decode(bytes)?;
		let path = String::from_utf8(msg.path.to_vec())?;

		let handler_opt = {
			let endpoints = self.netapp.endpoints.read().unwrap();
			endpoints.get(&path).map(|e| e.clone_endpoint())
		};

		if let Some(handler) = handler_opt {
			cfg_if::cfg_if! {
				if #[cfg(feature = "telemetry")] {
					let tracer = opentelemetry::global::tracer("netapp");

					let mut span = if let Some(telemetry_id) = msg.telemetry_id {
						let propagator = BinaryPropagator::new();
						let context = propagator.from_bytes(telemetry_id);
						let context = Context::new().with_remote_span_context(context);
						tracer.span_builder(format!(">> RPC {}", path))
							.with_kind(SpanKind::Server)
							.start_with_context(&tracer, &context)
					} else {
						let mut rng = thread_rng();
						let trace_id = TraceId::from_bytes(rng.gen());
						tracer
							.span_builder(format!(">> RPC {}", path))
							.with_kind(SpanKind::Server)
							.with_trace_id(trace_id)
							.start(&tracer)
					};
					span.set_attribute(KeyValue::new("path", path.to_string()));
					span.set_attribute(KeyValue::new("len_query", msg.body.len() as i64));

					handler.handle(msg.body, self.peer_id)
						.with_context(Context::current_with_span(span))
						.await
				} else {
					handler.handle(msg.body, self.peer_id).await
				}
			}
		} else {
			Err(Error::NoHandler)
		}
	}
}

// Empty impl: ServerConn takes every `SendLoop` item (including the
// `send_loop` used by `ServerConn::run`) from the trait's default
// implementations. The trait itself is defined outside this file.
impl SendLoop for ServerConn {}

#[async_trait]
impl RecvLoop for ServerConn {
	/// Handles one complete request received from the peer.
	///
	/// The request is processed in its own spawned task so that the caller
	/// is not blocked while the endpoint handler runs. The response queued
	/// for the send loop is:
	/// - on success: a `0` byte followed by the handler's output;
	/// - on failure: a single byte, the error's `code()`.
	///
	/// The response is queued with the priority found in the first byte of
	/// the request (or `0` if the request is empty).
	///
	/// If the response channel has already been cleared (`ServerConn::run`
	/// stores `None` once the receive loop has ended), the request is
	/// dropped instead of panicking: no response could be delivered anyway.
	fn recv_handler(self: &Arc<Self>, id: RequestID, bytes: Vec<u8>) {
		let resp_send = match self.resp_send.load_full() {
			Some(sender) => sender,
			None => {
				trace!(
					"ServerConn recv_handler {}: response channel closed, dropping request",
					id
				);
				return;
			}
		};

		let self2 = Arc::clone(self);
		tokio::spawn(async move {
			trace!("ServerConn recv_handler {} ({} bytes)", id, bytes.len());
			let bytes: Bytes = bytes.into();

			// First byte of the request is its priority; reuse it for the response.
			let prio = bytes.first().copied().unwrap_or(0u8);
			let resp = self2.recv_handler_aux(&bytes[..]).await;

			let resp_bytes = match resp {
				Ok(rb) => {
					// Response code 0 (success) followed by the payload.
					let mut out = Vec::with_capacity(1 + rb.len());
					out.push(0u8);
					out.extend(rb);
					out
				}
				// On error only the error code is sent back.
				Err(e) => vec![e.code()],
			};

			trace!("ServerConn sending response to {}: ", id);

			resp_send
				.send((id, prio, resp_bytes))
				.log_err("ServerConn recv_handler send resp");
		});
	}
}