use std::collections::HashMap; use std::net::SocketAddr; use std::pin::Pin; use std::sync::{Arc, RwLock}; use std::future::Future; use log::{debug, info}; use arc_swap::{ArcSwap, ArcSwapOption}; use bytes::Bytes; use sodiumoxide::crypto::auth; use sodiumoxide::crypto::sign::ed25519; use tokio::net::{TcpListener, TcpStream}; use crate::conn::*; use crate::error::*; use crate::message::*; use crate::proto::*; use crate::util::*; pub struct NetApp { pub listen_addr: SocketAddr, pub netid: auth::Key, pub pubkey: ed25519::PublicKey, pub privkey: ed25519::SecretKey, pub server_conns: RwLock>>, pub client_conns: RwLock>>, pub(crate) msg_handlers: ArcSwap< HashMap< MessageKind, Arc< dyn Fn( ed25519::PublicKey, Bytes, ) -> Pin> + Sync + Send>> + Sync + Send, >, >, >, pub(crate) on_connected: ArcSwapOption>, pub(crate) on_disconnected: ArcSwapOption>, } async fn handler_aux(handler: Arc, remote: ed25519::PublicKey, bytes: Bytes) -> Vec where M: Message + 'static, F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static, R: Future::Response, Error>> + Send + Sync, { debug!( "Handling message of kind {:08x} from {}", M::KIND, hex::encode(remote) ); let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) { Ok(msg) => handler(remote.clone(), msg).await, Err(e) => Err(e.into()), }; let res = res.map_err(|e| format!("{}", e)); rmp_to_vec_all_named(&res).unwrap_or(vec![]) } impl NetApp { pub fn new( listen_addr: SocketAddr, netid: auth::Key, privkey: ed25519::SecretKey, ) -> Arc { let pubkey = privkey.public_key(); let netapp = Arc::new(Self { listen_addr, netid, pubkey, privkey, server_conns: RwLock::new(HashMap::new()), client_conns: RwLock::new(HashMap::new()), msg_handlers: ArcSwap::new(Arc::new(HashMap::new())), on_connected: ArcSwapOption::new(None), on_disconnected: ArcSwapOption::new(None), }); let netapp2 = netapp.clone(); netapp.add_msg_handler::( move |from: ed25519::PublicKey, msg: HelloMessage| { netapp2.handle_hello_message(from, msg); async { Ok(()) } }, ); netapp } pub fn add_msg_handler(&self, handler: F) where M: Message + 'static, F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static, R: Future::Response, Error>> + Send + Sync + 'static, { let handler = Arc::new(handler); let fun = Arc::new(move |remote: ed25519::PublicKey, bytes: Bytes| { let fun: Pin> + Sync + Send>> = Box::pin(handler_aux(handler.clone(), remote, bytes)); fun }); let mut handlers = self.msg_handlers.load().as_ref().clone(); handlers.insert(M::KIND, fun); self.msg_handlers.store(Arc::new(handlers)); } pub async fn listen(self: Arc) { let mut listener = TcpListener::bind(self.listen_addr).await.unwrap(); info!("Listening on {}", self.listen_addr); loop { // The second item contains the IP and port of the new connection. let (socket, _) = listener.accept().await.unwrap(); info!( "Incoming connection from {}, negotiating handshake...", socket.peer_addr().unwrap() ); let self2 = self.clone(); tokio::spawn(async move { ServerConn::run(self2, socket) .await .log_err("ServerConn::run"); }); } } pub async fn try_connect( self: Arc, ip: SocketAddr, pk: ed25519::PublicKey, ) -> Result<(), Error> { if self.client_conns.read().unwrap().contains_key(&pk) { return Ok(()); } let socket = TcpStream::connect(ip).await?; info!("Connected to {}, negotiating handshake...", ip); ClientConn::init(self, socket, pk.clone()).await?; Ok(()) } pub fn disconnect(self: Arc, id: &ed25519::PublicKey) { let conn = self.client_conns.read().unwrap().get(id).cloned(); if let Some(c) = conn { c.close(); } } pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc) { let mut conn_list = self.server_conns.write().unwrap(); conn_list.insert(id.clone(), conn); } fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) { if let Some(h) = self.on_connected.load().as_ref() { if let Some(c) = self.server_conns.read().unwrap().get(&id) { let remote_addr = SocketAddr::new(c.remote_addr.ip(), msg.server_port); h(id, remote_addr, true); } } } pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc) { let mut conn_list = self.server_conns.write().unwrap(); if let Some(c) = conn_list.get(id) { if Arc::ptr_eq(c, &conn) { conn_list.remove(id); } if let Some(h) = self.on_disconnected.load().as_ref() { h(conn.peer_pk, true); } } } pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc) { { let mut conn_list = self.client_conns.write().unwrap(); if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) { tokio::spawn(async move { old_c.close() }); } } if let Some(h) = self.on_connected.load().as_ref() { h(conn.peer_pk, conn.remote_addr, false); } tokio::spawn(async move { let server_port = conn.netapp.listen_addr.port(); conn.request(HelloMessage { server_port }, prio::NORMAL) .await .log_err("Sending hello message"); }); } pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc) { let mut conn_list = self.client_conns.write().unwrap(); if let Some(c) = conn_list.get(id) { if Arc::ptr_eq(c, &conn) { conn_list.remove(id); } if let Some(h) = self.on_disconnected.load().as_ref() { h(conn.peer_pk, false); } } } }