diff options
author | Alex Auvolat <alex@adnab.me> | 2020-12-02 13:30:47 +0100 |
---|---|---|
committer | Alex Auvolat <alex@adnab.me> | 2020-12-02 13:30:47 +0100 |
commit | d4de2ffc40fe9d003f12139053ca070eda0b7221 (patch) | |
tree | e95476f0b7a6d1c75cc462b3ea7eee74c4faf09f /src/netapp.rs | |
download | netapp-d4de2ffc40fe9d003f12139053ca070eda0b7221.tar.gz netapp-d4de2ffc40fe9d003f12139053ca070eda0b7221.zip |
First commit
Diffstat (limited to 'src/netapp.rs')
-rw-r--r-- | src/netapp.rs | 214 |
1 files changed, 214 insertions, 0 deletions
diff --git a/src/netapp.rs b/src/netapp.rs new file mode 100644 index 0000000..6f174b4 --- /dev/null +++ b/src/netapp.rs @@ -0,0 +1,214 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; + +use std::future::Future; + +use log::{debug, info}; + +use arc_swap::{ArcSwap, ArcSwapOption}; +use bytes::Bytes; + +use sodiumoxide::crypto::auth; +use sodiumoxide::crypto::sign::ed25519; +use tokio::net::{TcpListener, TcpStream}; + +use crate::conn::*; +use crate::error::*; +use crate::message::*; +use crate::proto::*; +use crate::util::*; + +pub struct NetApp { + pub listen_addr: SocketAddr, + pub netid: auth::Key, + pub pubkey: ed25519::PublicKey, + pub privkey: ed25519::SecretKey, + pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>, + pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>, + pub(crate) msg_handlers: ArcSwap< + HashMap< + MessageKind, + Arc< + dyn Fn( + ed25519::PublicKey, + Bytes, + ) -> Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> + + Sync + + Send, + >, + >, + >, + pub(crate) on_connected: + ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Send + Sync>>, + pub(crate) on_disconnected: ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Send + Sync>>, +} + +async fn handler_aux<M, F, R>(handler: Arc<F>, remote: ed25519::PublicKey, bytes: Bytes) -> Vec<u8> +where + M: Message + 'static, + F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static, + R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync, +{ + debug!( + "Handling message of kind {:08x} from {}", + M::KIND, + hex::encode(remote) + ); + let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) { + Ok(msg) => handler(remote.clone(), msg).await, + Err(e) => Err(e.into()), + }; + let res = res.map_err(|e| format!("{}", e)); + rmp_to_vec_all_named(&res).unwrap_or(vec![]) +} + +impl NetApp { + pub fn new( + listen_addr: SocketAddr, + netid: auth::Key, + privkey: ed25519::SecretKey, + ) -> Arc<Self> { + let pubkey = privkey.public_key(); + let netapp = Arc::new(Self { + listen_addr, + netid, + pubkey, + privkey, + server_conns: RwLock::new(HashMap::new()), + client_conns: RwLock::new(HashMap::new()), + msg_handlers: ArcSwap::new(Arc::new(HashMap::new())), + on_connected: ArcSwapOption::new(None), + on_disconnected: ArcSwapOption::new(None), + }); + + let netapp2 = netapp.clone(); + netapp.add_msg_handler::<HelloMessage, _, _>( + move |from: ed25519::PublicKey, msg: HelloMessage| { + netapp2.handle_hello_message(from, msg); + async { Ok(()) } + }, + ); + + netapp + } + + pub fn add_msg_handler<M, F, R>(&self, handler: F) + where + M: Message + 'static, + F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static, + R: Future<Output = Result<<M as Message>::Response, Error>> + Send + Sync + 'static, + { + let handler = Arc::new(handler); + let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| { + let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> = + Box::pin(handler_aux(handler.clone(), remote, bytes)); + fun + }); + let mut handlers = self.msg_handlers.load().as_ref().clone(); + handlers.insert(M::KIND, fun); + self.msg_handlers.store(Arc::new(handlers)); + } + + pub async fn listen(self: Arc<Self>) { + let mut listener = TcpListener::bind(self.listen_addr).await.unwrap(); + info!("Listening on {}", self.listen_addr); + + loop { + // The second item contains the IP and port of the new connection. + let (socket, _) = listener.accept().await.unwrap(); + info!( + "Incoming connection from {}, negotiating handshake...", + socket.peer_addr().unwrap() + ); + let self2 = self.clone(); + tokio::spawn(async move { + ServerConn::run(self2, socket) + .await + .log_err("ServerConn::run"); + }); + } + } + + pub async fn try_connect( + self: Arc<Self>, + ip: SocketAddr, + pk: ed25519::PublicKey, + ) -> Result<(), Error> { + if self.client_conns.read().unwrap().contains_key(&pk) { + return Ok(()); + } + let socket = TcpStream::connect(ip).await?; + info!("Connected to {}, negotiating handshake...", ip); + ClientConn::init(self, socket, pk.clone()).await?; + Ok(()) + } + + pub fn disconnect(self: Arc<Self>, id: &ed25519::PublicKey) { + let conn = self.client_conns.read().unwrap().get(id).cloned(); + if let Some(c) = conn { + c.close(); + } + } + + pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) { + let mut conn_list = self.server_conns.write().unwrap(); + conn_list.insert(id.clone(), conn); + } + + fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) { + if let Some(h) = self.on_connected.load().as_ref() { + if let Some(c) = self.server_conns.read().unwrap().get(&id) { + let remote_addr = SocketAddr::new(c.remote_addr.ip(), msg.server_port); + h(id, remote_addr, true); + } + } + } + + pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) { + let mut conn_list = self.server_conns.write().unwrap(); + if let Some(c) = conn_list.get(id) { + if Arc::ptr_eq(c, &conn) { + conn_list.remove(id); + } + + if let Some(h) = self.on_disconnected.load().as_ref() { + h(conn.peer_pk, true); + } + } + } + + pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) { + { + let mut conn_list = self.client_conns.write().unwrap(); + if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) { + tokio::spawn(async move { old_c.close() }); + } + } + + if let Some(h) = self.on_connected.load().as_ref() { + h(conn.peer_pk, conn.remote_addr, false); + } + + tokio::spawn(async move { + let server_port = conn.netapp.listen_addr.port(); + conn.request(HelloMessage { server_port }, prio::NORMAL) + .await + .log_err("Sending hello message"); + }); + } + + pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) { + let mut conn_list = self.client_conns.write().unwrap(); + if let Some(c) = conn_list.get(id) { + if Arc::ptr_eq(c, &conn) { + conn_list.remove(id); + } + + if let Some(h) = self.on_disconnected.load().as_ref() { + h(conn.peer_pk, false); + } + } + } +} |