aboutsummaryrefslogtreecommitdiff
path: root/src/netapp.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/netapp.rs')
-rw-r--r--src/netapp.rs214
1 files changed, 214 insertions, 0 deletions
diff --git a/src/netapp.rs b/src/netapp.rs
new file mode 100644
index 0000000..6f174b4
--- /dev/null
+++ b/src/netapp.rs
@@ -0,0 +1,214 @@
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::pin::Pin;
+use std::sync::{Arc, RwLock};
+
+use std::future::Future;
+
+use log::{debug, info};
+
+use arc_swap::{ArcSwap, ArcSwapOption};
+use bytes::Bytes;
+
+use sodiumoxide::crypto::auth;
+use sodiumoxide::crypto::sign::ed25519;
+use tokio::net::{TcpListener, TcpStream};
+
+use crate::conn::*;
+use crate::error::*;
+use crate::message::*;
+use crate::proto::*;
+use crate::util::*;
+
+pub struct NetApp {
+ pub listen_addr: SocketAddr,
+ pub netid: auth::Key,
+ pub pubkey: ed25519::PublicKey,
+ pub privkey: ed25519::SecretKey,
+ pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
+ pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
+ pub(crate) msg_handlers: ArcSwap<
+ HashMap<
+ MessageKind,
+ Arc<
+ dyn Fn(
+ ed25519::PublicKey,
+ Bytes,
+ ) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>>
+ + Sync
+ + Send,
+ >,
+ >,
+ >,
+ pub(crate) on_connected:
+ ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>,
+ pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>,
+}
+
+async fn handler_aux<M, F, R>(handler: Arc<F>, remote: ed25519::PublicKey, bytes: Bytes) -> Vec<u8>
+where
+ M: Message + 'static,
+ F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
+ R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync,
+{
+ debug!(
+ "Handling message of kind {:08x} from {}",
+ M::KIND,
+ hex::encode(remote)
+ );
+ let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
+ Ok(msg) => handler(remote.clone(), msg).await,
+ Err(e) => Err(e.into()),
+ };
+ let res = res.map_err(|e| format!("{}", e));
+ rmp_to_vec_all_named(&res).unwrap_or(vec![])
+}
+
+impl NetApp {
+ pub fn new(
+ listen_addr: SocketAddr,
+ netid: auth::Key,
+ privkey: ed25519::SecretKey,
+ ) -> Arc<Self> {
+ let pubkey = privkey.public_key();
+ let netapp = Arc::new(Self {
+ listen_addr,
+ netid,
+ pubkey,
+ privkey,
+ server_conns: RwLock::new(HashMap::new()),
+ client_conns: RwLock::new(HashMap::new()),
+ msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
+ on_connected: ArcSwapOption::new(None),
+ on_disconnected: ArcSwapOption::new(None),
+ });
+
+ let netapp2 = netapp.clone();
+ netapp.add_msg_handler::<HelloMessage, _, _>(
+ move |from: ed25519::PublicKey, msg: HelloMessage| {
+ netapp2.handle_hello_message(from, msg);
+ async { Ok(()) }
+ },
+ );
+
+ netapp
+ }
+
+ pub fn add_msg_handler<M, F, R>(&self, handler: F)
+ where
+ M: Message + 'static,
+ F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
+ R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync + 'static,
+ {
+ let handler = Arc::new(handler);
+ let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
+ let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
+ Box::pin(handler_aux(handler.clone(), remote, bytes));
+ fun
+ });
+ let mut handlers = self.msg_handlers.load().as_ref().clone();
+ handlers.insert(M::KIND, fun);
+ self.msg_handlers.store(Arc::new(handlers));
+ }
+
+ pub async fn listen(self: Arc<Self>) {
+ let mut listener = TcpListener::bind(self.listen_addr).await.unwrap();
+ info!("Listening on {}", self.listen_addr);
+
+ loop {
+ // The second item contains the IP and port of the new connection.
+ let (socket, _) = listener.accept().await.unwrap();
+ info!(
+ "Incoming connection from {}, negotiating handshake...",
+ socket.peer_addr().unwrap()
+ );
+ let self2 = self.clone();
+ tokio::spawn(async move {
+ ServerConn::run(self2, socket)
+ .await
+ .log_err("ServerConn::run");
+ });
+ }
+ }
+
+ pub async fn try_connect(
+ self: Arc<Self>,
+ ip: SocketAddr,
+ pk: ed25519::PublicKey,
+ ) -> Result<(), Error> {
+ if self.client_conns.read().unwrap().contains_key(&pk) {
+ return Ok(());
+ }
+ let socket = TcpStream::connect(ip).await?;
+ info!("Connected to {}, negotiating handshake...", ip);
+ ClientConn::init(self, socket, pk.clone()).await?;
+ Ok(())
+ }
+
+ pub fn disconnect(self: Arc<Self>, id: &ed25519::PublicKey) {
+ let conn = self.client_conns.read().unwrap().get(id).cloned();
+ if let Some(c) = conn {
+ c.close();
+ }
+ }
+
+ pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) {
+ let mut conn_list = self.server_conns.write().unwrap();
+ conn_list.insert(id.clone(), conn);
+ }
+
+ fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) {
+ if let Some(h) = self.on_connected.load().as_ref() {
+ if let Some(c) = self.server_conns.read().unwrap().get(&id) {
+ let remote_addr = SocketAddr::new(c.remote_addr.ip(), msg.server_port);
+ h(id, remote_addr, true);
+ }
+ }
+ }
+
+ pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) {
+ let mut conn_list = self.server_conns.write().unwrap();
+ if let Some(c) = conn_list.get(id) {
+ if Arc::ptr_eq(c, &conn) {
+ conn_list.remove(id);
+ }
+
+ if let Some(h) = self.on_disconnected.load().as_ref() {
+ h(conn.peer_pk, true);
+ }
+ }
+ }
+
+ pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) {
+ {
+ let mut conn_list = self.client_conns.write().unwrap();
+ if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) {
+ tokio::spawn(async move { old_c.close() });
+ }
+ }
+
+ if let Some(h) = self.on_connected.load().as_ref() {
+ h(conn.peer_pk, conn.remote_addr, false);
+ }
+
+ tokio::spawn(async move {
+ let server_port = conn.netapp.listen_addr.port();
+ conn.request(HelloMessage { server_port }, prio::NORMAL)
+ .await
+ .log_err("Sending hello message");
+ });
+ }
+
+ pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) {
+ let mut conn_list = self.client_conns.write().unwrap();
+ if let Some(c) = conn_list.get(id) {
+ if Arc::ptr_eq(c, &conn) {
+ conn_list.remove(id);
+ }
+
+ if let Some(h) = self.on_disconnected.load().as_ref() {
+ h(conn.peer_pk, false);
+ }
+ }
+ }
+}