aboutsummaryrefslogblamecommitdiff
path: root/src/netapp.rs
blob: 25c3b5a055542e893877e61ffbd99558bc542e81 (plain) (tree)
1
                  





















                                         











                                                                                                         






                                                                               
                                                                             




                                                                                                       




                                   


                                                                  
                                                                   






                                                                              

                                                          
          


                                                    















                                                                           






















                                                                             
                                            









                                                                          
                                                                                     

                                                


                                                                                            
                                                                                       







                                                                                           

                           





                                             
                                                                             
                                               



























                                                                                          











                                                                                    


                                                                        
 





                                                                       











                                                                                       





                                                                                                 

                                                                      













                                                                                                       

                                                                    












                                                                                                 

                                                                       



















                                                                                                     
                                                                  










                                                                               




































                                                                                          
 
use std::any::Any;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, RwLock};

use std::future::Future;

use log::{debug, info};

use arc_swap::{ArcSwap, ArcSwapOption};
use bytes::Bytes;

use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
use tokio::net::{TcpListener, TcpStream};

use crate::conn::*;
use crate::error::*;
use crate::message::*;
use crate::proto::*;
use crate::util::*;

/// Type-erased message or response value.
///
/// Used when a request is dispatched to a handler on the local node: the
/// message is boxed, handed to the handler, and the boxed response is
/// downcast back to its concrete type by the caller.
type DynMsg = Box<dyn Any + Send + Sync + 'static>;

/// The pair of entry points registered for one message kind.
///
/// Both closures are built by `NetApp::add_msg_handler` from the same
/// user-provided handler function.
pub(crate) struct Handler {
	/// Entry point for requests this node sends to itself: takes the boxed
	/// message and resolves to the boxed response.
	pub(crate) local_handler:
		Box<dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> + Sync + Send>,
	/// Entry point for requests coming from a peer: takes the peer's public
	/// key and the raw request bytes, and resolves to the encoded response.
	pub(crate) net_handler: Box<
		dyn Fn(ed25519::PublicKey, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>>
			+ Sync
			+ Send,
	>,
}

/// Shared state of one network node: its identity, its open connections and
/// the handlers and callbacks registered on it.
pub struct NetApp {
	/// Address the TCP listener binds to in `listen`.
	pub listen_addr: SocketAddr,
	/// NOTE(review): not used in this file; presumably the shared network key
	/// used by the connection handshake -- confirm in the `conn` module.
	pub netid: auth::Key,
	/// Public key of this node, derived from `privkey` in `new`.
	pub pubkey: ed25519::PublicKey,
	/// Secret key of this node.
	pub privkey: ed25519::SecretKey,
	/// Connections accepted from peers, indexed by the peer's public key.
	pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
	/// Connections opened towards peers, indexed by the peer's public key.
	/// These are the connections `request` sends messages on.
	pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
	/// Handlers registered with `add_msg_handler`, indexed by message kind.
	pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
	/// Optional callback invoked with (peer key, peer address, flag). In this
	/// file the flag is `true` when the notification comes from the hello
	/// message of a peer that connected to us, `false` otherwise.
	pub(crate) on_connected:
		ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
	/// Optional callback invoked with (peer key, flag). In this file the flag
	/// is `true` when a server-side connection closed, `false` otherwise.
	pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>,
}

/// Decodes a request received from `remote`, runs `handler` on it and
/// returns the serialized outcome.
///
/// The serialized value is a `Result`: `Ok(response)` when `bytes` could be
/// decoded as an `M`, or `Err(description)` when decoding failed (in which
/// case `handler` is never invoked). If the outcome itself cannot be
/// serialized, an empty buffer is returned.
async fn net_handler_aux<M, F, R>(
	handler: Arc<F>,
	remote: ed25519::PublicKey,
	bytes: Bytes,
) -> Vec<u8>
where
	M: Message + 'static,
	F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
	R: Future<Output = <M as Message>::Response> + Send + Sync,
{
	debug!(
		"Handling message of kind {:08x} from {}",
		M::KIND,
		hex::encode(remote)
	);

	let decoded = rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]);
	let outcome = match decoded {
		Err(decode_err) => Err(decode_err.to_string()),
		Ok(request) => Ok(handler(remote, request).await),
	};

	rmp_to_vec_all_named(&outcome).unwrap_or_default()
}

/// Runs `handler` on a request that this node addressed to itself.
///
/// `msg` is the type-erased request and must actually contain an `M`; the
/// response is returned boxed and type-erased in the same way.
///
/// # Panics
///
/// Panics if `msg` does not contain a value of type `M`.
async fn local_handler_aux<M, F, R>(
	handler: Arc<F>,
	remote: ed25519::PublicKey,
	msg: DynMsg,
) -> DynMsg
where
	M: Message + 'static,
	F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
	R: Future<Output = <M as Message>::Response> + Send + Sync,
{
	debug!("Handling message of kind {:08x} from ourself", M::KIND,);

	// Recover the concrete message type from the type-erased box.
	let typed: Box<M> = (msg as Box<dyn Any + 'static>).downcast::<M>().unwrap();
	Box::new(handler(remote, *typed).await)
}

impl NetApp {
	pub fn new(
		listen_addr: SocketAddr,
		netid: auth::Key,
		privkey: ed25519::SecretKey,
	) -> Arc<Self> {
		let pubkey = privkey.public_key();
		let netapp = Arc::new(Self {
			listen_addr,
			netid,
			pubkey,
			privkey,
			server_conns: RwLock::new(HashMap::new()),
			client_conns: RwLock::new(HashMap::new()),
			msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
			on_connected: ArcSwapOption::new(None),
			on_disconnected: ArcSwapOption::new(None),
		});

		let netapp2 = netapp.clone();
		netapp.add_msg_handler::<HelloMessage, _, _>(
			move |from: ed25519::PublicKey, msg: HelloMessage| {
				netapp2.handle_hello_message(from, msg);
				async { () }
			},
		);

		netapp
	}

	pub fn add_msg_handler<M, F, R>(&self, handler: F)
	where
		M: Message + 'static,
		F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
		R: Future<Output = <M as Message>::Response> + Send + Sync + 'static,
	{
		let handler = Arc::new(handler);

		let handler1 = handler.clone();
		let net_handler = Box::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
			let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
				Box::pin(net_handler_aux(handler1.clone(), remote, bytes));
			fun
		});

		let self_id = self.pubkey.clone();
		let local_handler = Box::new(move |msg: DynMsg| {
			let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> =
				Box::pin(local_handler_aux(handler.clone(), self_id, msg));
			fun
		});

		let funs = Arc::new(Handler {
			net_handler,
			local_handler,
		});

		let mut handlers = self.msg_handlers.load().as_ref().clone();
		handlers.insert(M::KIND, funs);
		self.msg_handlers.store(Arc::new(handlers));
	}

	pub async fn listen(self: Arc<Self>) {
		let mut listener = TcpListener::bind(self.listen_addr).await.unwrap();
		info!("Listening on {}", self.listen_addr);

		loop {
			// The second item contains the IP and port of the new connection.
			let (socket, _) = listener.accept().await.unwrap();
			info!(
				"Incoming connection from {}, negotiating handshake...",
				socket.peer_addr().unwrap()
			);
			let self2 = self.clone();
			tokio::spawn(async move {
				ServerConn::run(self2, socket)
					.await
					.log_err("ServerConn::run");
			});
		}
	}

	/// Opens a client connection to the peer `pk` listening at `ip`.
	///
	/// Returns `Ok(())` without opening any socket when:
	/// * `pk` is our own public key -- the `on_connected` callback, if set, is
	///   still invoked (from a spawned task) as if the connection had been
	///   made;
	/// * a client connection to `pk` is already registered.
	///
	/// # Errors
	///
	/// Returns an error if the TCP connection cannot be established or if
	/// `ClientConn::init` fails.
	pub async fn try_connect(
		self: Arc<Self>,
		ip: SocketAddr,
		pk: ed25519::PublicKey,
	) -> Result<(), Error> {
		let is_self = pk == self.pubkey;
		if is_self {
			// No real connection to ourself is needed; only notify.
			tokio::spawn(async move {
				if let Some(notify) = self.on_connected.load().as_ref() {
					notify(pk, ip, false);
				}
			});
			return Ok(());
		}

		let already_connected = self.client_conns.read().unwrap().contains_key(&pk);
		if already_connected {
			return Ok(());
		}

		let stream = TcpStream::connect(ip).await?;
		info!("Connected to {}, negotiating handshake...", ip);
		ClientConn::init(self, stream, pk).await?;
		Ok(())
	}

	/// Closes the client connection to the peer `pk`, if one is registered.
	///
	/// When `pk` is our own public key there is no connection to close; the
	/// `on_disconnected` callback, if set, is invoked from a spawned task
	/// instead.
	pub fn disconnect(self: Arc<Self>, pk: &ed25519::PublicKey) {
		let peer = *pk;

		if peer != self.pubkey {
			// Clone the connection out so the lock is not held during `close`.
			let existing = self.client_conns.read().unwrap().get(&peer).cloned();
			if let Some(conn) = existing {
				conn.close();
			}
			return;
		}

		tokio::spawn(async move {
			if let Some(notify) = self.on_disconnected.load().as_ref() {
				notify(peer, false);
			}
		});
	}

	/// Records `conn` as the server-side connection for the peer `id`,
	/// replacing any connection previously recorded for that peer.
	pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) {
		info!("Accepted connection from {}", hex::encode(id));

		let mut registered = self.server_conns.write().unwrap();
		registered.insert(id, conn);
	}

	/// Handles the hello message sent by the peer `id` after it connected to
	/// us, by invoking the `on_connected` callback (if set).
	///
	/// The address given to the callback combines the IP of the peer's
	/// server-side connection with the port announced in `msg`. Nothing
	/// happens if no server-side connection is recorded for `id`.
	fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) {
		if let Some(h) = self.on_connected.load().as_ref() {
			// Only read the IP while holding the lock; it is released at the
			// end of this statement, so the callback runs without it and may
			// freely access the connection table.
			let remote_ip = self
				.server_conns
				.read()
				.unwrap()
				.get(&id)
				.map(|c| c.remote_addr.ip());

			if let Some(ip) = remote_ip {
				h(id, SocketAddr::new(ip, msg.server_port), true);
			}
		}
	}

	/// Called when the server-side connection `conn` with the peer `id` has
	/// closed.
	///
	/// If `conn` is the connection currently recorded for `id`, it is removed
	/// and the `on_disconnected` callback (if set) is invoked. If another
	/// connection has been recorded for `id` in the meantime, or none is
	/// recorded, nothing changes and no notification is sent, since the peer
	/// has not been lost through this connection.
	pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) {
		info!("Connection from {} closed", hex::encode(id));

		// Update the table in its own scope so the write lock is released
		// before the callback runs: the callback may itself need the table.
		let was_registered = {
			let mut conn_list = self.server_conns.write().unwrap();
			let is_current = conn_list
				.get(id)
				.map_or(false, |c| Arc::ptr_eq(c, &conn));
			if is_current {
				conn_list.remove(id);
			}
			is_current
		};

		if was_registered {
			if let Some(h) = self.on_disconnected.load().as_ref() {
				h(conn.peer_pk, true);
			}
		}
	}

	/// Records `conn` as the client-side connection to the peer `id`.
	///
	/// A connection previously recorded for the same peer is closed from a
	/// spawned task. The `on_connected` callback (if set) is then invoked,
	/// and a `HelloMessage` carrying our listening port is sent to the peer
	/// in the background.
	pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) {
		info!("Connection established to {}", hex::encode(id));

		{
			let mut registered = self.client_conns.write().unwrap();
			let previous = registered.insert(id, conn.clone());
			if let Some(old_c) = previous {
				tokio::spawn(async move { old_c.close() });
			}
		}

		if let Some(notify) = self.on_connected.load().as_ref() {
			notify(conn.peer_pk, conn.remote_addr, false);
		}

		// Announce the port we listen on to the peer.
		tokio::spawn(async move {
			let hello = HelloMessage {
				server_port: conn.netapp.listen_addr.port(),
			};
			conn.request(hello, prio::NORMAL)
				.await
				.log_err("Sending hello message");
		});
	}

	/// Called when the client-side connection `conn` to the peer `id` has
	/// closed.
	///
	/// If `conn` is the connection currently recorded for `id`, it is removed
	/// and the `on_disconnected` callback (if set) is invoked. If it has been
	/// replaced by a newer connection in the meantime, or none is recorded,
	/// nothing changes and no notification is sent, since the peer is not
	/// lost through this connection.
	pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) {
		info!("Connection to {} closed", hex::encode(id));

		// Update the table in its own scope so the write lock is released
		// before the callback runs: a callback that reconnects or sends a
		// request needs to read the table and would otherwise deadlock.
		let was_registered = {
			let mut conn_list = self.client_conns.write().unwrap();
			let is_current = conn_list
				.get(id)
				.map_or(false, |c| Arc::ptr_eq(c, &conn));
			if is_current {
				conn_list.remove(id);
			}
			is_current
		};

		if was_registered {
			if let Some(h) = self.on_disconnected.load().as_ref() {
				h(conn.peer_pk, false);
			}
		}
	}

	pub async fn request<T>(
		&self,
		target: &ed25519::PublicKey,
		rq: T,
		prio: RequestPriority,
	) -> Result<<T as Message>::Response, Error>
	where
		T: Message + 'static,
	{
		if *target == self.pubkey {
			let handler = self.msg_handlers.load().get(&T::KIND).cloned();
			match handler {
				None => Err(Error::Message(format!(
					"No handler registered for message kind {:08x}",
					T::KIND
				))),
				Some(h) => {
					let local_handler = &h.local_handler;
					let res = local_handler(Box::new(rq)).await;
					let res_t = (res as Box<dyn Any + 'static>)
						.downcast::<<T as Message>::Response>()
						.unwrap();
					Ok(*res_t)
				}
			}
		} else {
			let conn = self.client_conns.read().unwrap().get(target).cloned();
			match conn {
				None => Err(Error::Message(format!(
					"Not connected: {}",
					hex::encode(target)
				))),
				Some(c) => c.request(rq, prio).await,
			}
		}
	}
}