use std::any::Any;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::future::Future;
use log::{debug, info};
use arc_swap::{ArcSwap, ArcSwapOption};
use bytes::Bytes;
use sodiumoxide::crypto::auth;
use sodiumoxide::crypto::sign::ed25519;
use tokio::net::{TcpListener, TcpStream};
use crate::conn::*;
use crate::error::*;
use crate::message::*;
use crate::proto::*;
use crate::util::*;
/// A type-erased request or response value, used when a node sends a message to
/// itself so that it can be handed to the local handler without serialization.
type DynMsg = Box<dyn Any + Sync + Send>;
/// The two type-erased entry points registered for one message kind.
///
/// Both closures are built by `NetApp::add_msg_handler` around the same user
/// handler; they differ only in how the request arrives and the reply leaves.
pub(crate) struct Handler {
    /// Entry point for a request this node sends to itself: receives the request
    /// boxed as `Any` and resolves to the boxed response.
    pub(crate) local_handler: Box<
        dyn Fn(DynMsg) -> Pin<Box<dyn Future<Output = DynMsg> + Send + Sync>> + Send + Sync,
    >,
    /// Entry point for a request received from a remote peer: receives the
    /// sender's public key plus the encoded request bytes, and resolves to the
    /// encoded reply bytes.
    pub(crate) net_handler: Box<
        dyn Fn(ed25519::PublicKey, Bytes) -> Pin<Box<dyn Future<Output = Vec<u8>> + Send + Sync>>
            + Send
            + Sync,
    >,
}
/// A network node: its identity keys, its open connections in both directions,
/// its registered message handlers and its connection-event callbacks.
pub struct NetApp {
    /// Address the TCP listener binds to in `listen`; its port is also announced
    /// to peers in the hello message.
    pub listen_addr: SocketAddr,
    /// Network key. NOTE(review): not read in this file; presumably used by the
    /// connection handshake in `conn` — confirm.
    pub netid: auth::Key,
    /// This node's public key (derived from `privkey` in `NetApp::new`).
    pub pubkey: ed25519::PublicKey,
    /// This node's secret key.
    pub privkey: ed25519::SecretKey,
    /// Incoming connections, keyed by the remote peer's public key.
    pub server_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ServerConn>>>,
    /// Outgoing connections, keyed by the remote peer's public key.
    pub client_conns: RwLock<HashMap<ed25519::PublicKey, Arc<ClientConn>>>,
    /// Handlers keyed by message kind; the whole map is swapped atomically on update.
    pub(crate) msg_handlers: ArcSwap<HashMap<MessageKind, Arc<Handler>>>,
    /// Optional callback `(peer, peer_address, incoming)` fired when a peer connects;
    /// `incoming` is `true` for connections the peer opened to us.
    pub(crate) on_connected:
        ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, SocketAddr, bool) + Sync + Send>>,
    /// Optional callback `(peer, incoming)` fired when a peer connection closes;
    /// `incoming` is `true` for connections the peer had opened to us.
    pub(crate) on_disconnected:
        ArcSwapOption<Box<dyn Fn(ed25519::PublicKey, bool) + Sync + Send>>,
}
async fn net_handler_aux<M, F, R>(
handler: Arc<F>,
remote: ed25519::PublicKey,
bytes: Bytes,
) -> Vec<u8>
where
M: Message + 'static,
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
R: Future<Output = <M as Message>::Response> + Send + Sync,
{
debug!(
"Handling message of kind {:08x} from {}",
M::KIND,
hex::encode(remote)
);
let res = match rmp_serde::decode::from_read_ref::<_, M>(&bytes[..]) {
Ok(msg) => Ok(handler(remote, msg).await),
Err(e) => Err(e.to_string()),
};
rmp_to_vec_all_named(&res).unwrap_or(vec![])
}
/// Runs `handler` on a message that this node addressed to itself, passing the
/// value in memory instead of serializing it.
///
/// `remote` is reported to the handler as the sender (the caller passes this
/// node's own public key).
///
/// # Panics
/// Panics if `msg` does not hold a value of type `M`. This indicates two message
/// types share the same `KIND`, so the handler found under that kind does not
/// match the request type.
async fn local_handler_aux<M, F, R>(
    handler: Arc<F>,
    remote: ed25519::PublicKey,
    msg: DynMsg,
) -> DynMsg
where
    M: Message + 'static,
    F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
    R: Future<Output = <M as Message>::Response> + Send + Sync,
{
    debug!("Handling message of kind {:08x} from ourself", M::KIND);
    // Upcast to plain `Box<dyn Any>` so `downcast` is available.
    let msg = (msg as Box<dyn Any + 'static>)
        .downcast::<M>()
        .expect("local message type does not match the handler registered for its KIND");
    let res = handler(remote, *msg).await;
    Box::new(res)
}
impl NetApp {
pub fn new(
listen_addr: SocketAddr,
netid: auth::Key,
privkey: ed25519::SecretKey,
) -> Arc<Self> {
let pubkey = privkey.public_key();
let netapp = Arc::new(Self {
listen_addr,
netid,
pubkey,
privkey,
server_conns: RwLock::new(HashMap::new()),
client_conns: RwLock::new(HashMap::new()),
msg_handlers: ArcSwap::new(Arc::new(HashMap::new())),
on_connected: ArcSwapOption::new(None),
on_disconnected: ArcSwapOption::new(None),
});
let netapp2 = netapp.clone();
netapp.add_msg_handler::<HelloMessage, _, _>(
move |from: ed25519::PublicKey, msg: HelloMessage| {
netapp2.handle_hello_message(from, msg);
async { () }
},
);
netapp
}
pub fn add_msg_handler<M, F, R>(&self, handler: F)
where
M: Message + 'static,
F: Fn(ed25519::PublicKey, M) -> R + Send + Sync + 'static,
R: Future<Output = <M as Message>::Response> + Send + Sync + 'static,
{
let handler = Arc::new(handler);
let handler1 = handler.clone();
let net_handler = Box::new(move |remote: ed25519::PublicKey, bytes: Bytes| {
let fun: Pin<Box<dyn Future<Output = Vec<u8>> + Sync + Send>> =
Box::pin(net_handler_aux(handler1.clone(), remote, bytes));
fun
});
let self_id = self.pubkey.clone();
let local_handler = Box::new(move |msg: DynMsg| {
let fun: Pin<Box<dyn Future<Output = DynMsg> + Sync + Send>> =
Box::pin(local_handler_aux(handler.clone(), self_id, msg));
fun
});
let funs = Arc::new(Handler {
net_handler,
local_handler,
});
let mut handlers = self.msg_handlers.load().as_ref().clone();
handlers.insert(M::KIND, funs);
self.msg_handlers.store(Arc::new(handlers));
}
pub async fn listen(self: Arc<Self>) {
let mut listener = TcpListener::bind(self.listen_addr).await.unwrap();
info!("Listening on {}", self.listen_addr);
loop {
// The second item contains the IP and port of the new connection.
let (socket, _) = listener.accept().await.unwrap();
info!(
"Incoming connection from {}, negotiating handshake...",
socket.peer_addr().unwrap()
);
let self2 = self.clone();
tokio::spawn(async move {
ServerConn::run(self2, socket)
.await
.log_err("ServerConn::run");
});
}
}
pub async fn try_connect(
self: Arc<Self>,
ip: SocketAddr,
pk: ed25519::PublicKey,
) -> Result<(), Error> {
if pk == self.pubkey {
// Don't connect to ourself, we don't care
// but pretend we did
tokio::spawn(async move {
if let Some(h) = self.on_connected.load().as_ref() {
h(pk, ip, false);
}
});
return Ok(());
}
// Don't connect if already connected
if self.client_conns.read().unwrap().contains_key(&pk) {
return Ok(());
}
let socket = TcpStream::connect(ip).await?;
info!("Connected to {}, negotiating handshake...", ip);
ClientConn::init(self, socket, pk.clone()).await?;
Ok(())
}
pub fn disconnect(self: Arc<Self>, pk: &ed25519::PublicKey) {
if *pk == self.pubkey {
let pk = *pk;
tokio::spawn(async move {
if let Some(h) = self.on_disconnected.load().as_ref() {
h(pk, false);
}
});
return;
}
let conn = self.client_conns.read().unwrap().get(pk).cloned();
if let Some(c) = conn {
c.close();
}
}
pub(crate) fn connected_as_server(&self, id: ed25519::PublicKey, conn: Arc<ServerConn>) {
info!("Accepted connection from {}", hex::encode(id));
let mut conn_list = self.server_conns.write().unwrap();
conn_list.insert(id.clone(), conn);
}
fn handle_hello_message(&self, id: ed25519::PublicKey, msg: HelloMessage) {
if let Some(h) = self.on_connected.load().as_ref() {
if let Some(c) = self.server_conns.read().unwrap().get(&id) {
let remote_addr = SocketAddr::new(c.remote_addr.ip(), msg.server_port);
h(id, remote_addr, true);
}
}
}
pub(crate) fn disconnected_as_server(&self, id: &ed25519::PublicKey, conn: Arc<ServerConn>) {
info!("Connection from {} closed", hex::encode(id));
let mut conn_list = self.server_conns.write().unwrap();
if let Some(c) = conn_list.get(id) {
if Arc::ptr_eq(c, &conn) {
conn_list.remove(id);
}
if let Some(h) = self.on_disconnected.load().as_ref() {
h(conn.peer_pk, true);
}
}
}
pub(crate) fn connected_as_client(&self, id: ed25519::PublicKey, conn: Arc<ClientConn>) {
info!("Connection established to {}", hex::encode(id));
{
let mut conn_list = self.client_conns.write().unwrap();
if let Some(old_c) = conn_list.insert(id.clone(), conn.clone()) {
tokio::spawn(async move { old_c.close() });
}
}
if let Some(h) = self.on_connected.load().as_ref() {
h(conn.peer_pk, conn.remote_addr, false);
}
tokio::spawn(async move {
let server_port = conn.netapp.listen_addr.port();
conn.request(HelloMessage { server_port }, prio::NORMAL)
.await
.log_err("Sending hello message");
});
}
pub(crate) fn disconnected_as_client(&self, id: &ed25519::PublicKey, conn: Arc<ClientConn>) {
info!("Connection to {} closed", hex::encode(id));
let mut conn_list = self.client_conns.write().unwrap();
if let Some(c) = conn_list.get(id) {
if Arc::ptr_eq(c, &conn) {
conn_list.remove(id);
}
if let Some(h) = self.on_disconnected.load().as_ref() {
h(conn.peer_pk, false);
}
}
}
pub async fn request<T>(
&self,
target: &ed25519::PublicKey,
rq: T,
prio: RequestPriority,
) -> Result<<T as Message>::Response, Error>
where
T: Message + 'static,
{
if *target == self.pubkey {
let handler = self.msg_handlers.load().get(&T::KIND).cloned();
match handler {
None => Err(Error::Message(format!(
"No handler registered for message kind {:08x}",
T::KIND
))),
Some(h) => {
let local_handler = &h.local_handler;
let res = local_handler(Box::new(rq)).await;
let res_t = (res as Box<dyn Any + 'static>)
.downcast::<<T as Message>::Response>()
.unwrap();
Ok(*res_t)
}
}
} else {
let conn = self.client_conns.read().unwrap().get(target).cloned();
match conn {
None => Err(Error::Message(format!(
"Not connected: {}",
hex::encode(target)
))),
Some(c) => c.request(rq, prio).await,
}
}
}
}