From d4de2ffc40fe9d003f12139053ca070eda0b7221 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 2 Dec 2020 13:30:47 +0100 Subject: First commit --- src/conn.rs | 280 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 src/conn.rs (limited to 'src/conn.rs') diff --git a/src/conn.rs b/src/conn.rs new file mode 100644 index 0000000..9b60d2a --- /dev/null +++ b/src/conn.rs @@ -0,0 +1,280 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::atomic::{self, AtomicU16}; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::Bytes; +use log::{debug, trace}; + +use sodiumoxide::crypto::sign::ed25519; +use tokio::io::split; +use tokio::net::TcpStream; +use tokio::sync::{mpsc, oneshot, watch}; + +use kuska_handshake::async_std::{ + handshake_client, handshake_server, BoxStream, TokioCompatExt, TokioCompatExtRead, + TokioCompatExtWrite, +}; + +use crate::error::*; +use crate::message::*; +use crate::netapp::*; +use crate::proto::*; +use crate::util::*; + +pub struct ServerConn { + netapp: Arc, + pub remote_addr: SocketAddr, + pub peer_pk: ed25519::PublicKey, + resp_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec)>, + close_send: watch::Sender, +} + +impl ServerConn { + pub(crate) async fn run(netapp: Arc, socket: TcpStream) -> Result<(), Error> { + let mut asyncstd_socket = TokioCompatExt::wrap(socket); + let handshake = handshake_server( + &mut asyncstd_socket, + netapp.netid.clone(), + netapp.pubkey.clone(), + netapp.privkey.clone(), + ) + .await?; + let peer_pk = handshake.peer_pk.clone(); + + let tokio_socket = asyncstd_socket.into_inner(); + let remote_addr = tokio_socket.peer_addr().unwrap(); + + debug!( + "Handshake complete (server) with {}@{}", + hex::encode(&peer_pk), + remote_addr + ); + + let (read, write) = split(tokio_socket); + + let read = TokioCompatExtRead::wrap(read); + let write = TokioCompatExtWrite::wrap(write); + + let (box_stream_read, box_stream_write) = + BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); + + let (resp_send, resp_recv) = mpsc::unbounded_channel(); + + let (close_send, close_recv) = watch::channel(false); + + let conn = Arc::new(ServerConn { + netapp: netapp.clone(), + remote_addr, + peer_pk: peer_pk.clone(), + resp_send, + close_send, + }); + + netapp.connected_as_server(peer_pk.clone(), conn.clone()); + + let conn2 = conn.clone(); + let conn3 = conn.clone(); + tokio::try_join!( + conn2.recv_loop(box_stream_read, close_recv.clone()), + conn3.send_loop(resp_recv, box_stream_write, close_recv.clone()), + ) + .map(|_| ()) + .log_err("ServerConn recv_loop/send_loop"); + + netapp.disconnected_as_server(&peer_pk, conn); + + Ok(()) + } + + pub fn close(&self) { + self.close_send.broadcast(true).unwrap(); + } +} + +impl SendLoop for ServerConn {} + +#[async_trait] +impl RecvLoop for ServerConn { + async fn recv_handler(self: Arc, id: u16, bytes: Vec) { + let bytes: Bytes = bytes.into(); + + let prio = bytes[0]; + + let mut kind_bytes = [0u8; 4]; + kind_bytes.copy_from_slice(&bytes[1..5]); + let kind = u32::from_be_bytes(kind_bytes); + + if let Some(handler) = self.netapp.msg_handlers.load().get(&kind) { + let resp = handler(self.peer_pk.clone(), bytes.slice(5..)).await; + self.resp_send + .send((id, prio, resp)) + .log_err("ServerConn recv_handler send resp"); + } + } +} +pub struct ClientConn { + pub netapp: Arc, + pub remote_addr: SocketAddr, + pub peer_pk: ed25519::PublicKey, + query_send: mpsc::UnboundedSender<(RequestID, RequestPriority, Vec)>, + next_query_number: AtomicU16, + resp_send: mpsc::UnboundedSender<(RequestID, Vec)>, + resp_notify_send: mpsc::UnboundedSender<(RequestID, oneshot::Sender>)>, + close_send: watch::Sender, +} + +impl ClientConn { + pub(crate) async fn init( + netapp: Arc, + socket: TcpStream, + remote_pk: ed25519::PublicKey, + ) -> Result<(), Error> { + let mut asyncstd_socket = TokioCompatExt::wrap(socket); + + let handshake = handshake_client( + &mut asyncstd_socket, + netapp.netid.clone(), + netapp.pubkey.clone(), + netapp.privkey.clone(), + remote_pk.clone(), + ) + .await?; + + let tokio_socket = asyncstd_socket.into_inner(); + let remote_addr = tokio_socket.peer_addr().unwrap(); + + debug!( + "Handshake complete (client) with {}@{}", + hex::encode(&remote_pk), + remote_addr + ); + + let (read, write) = split(tokio_socket); + + let read = TokioCompatExtRead::wrap(read); + let write = TokioCompatExtWrite::wrap(write); + + let (box_stream_read, box_stream_write) = + BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); + + let (query_send, query_recv) = mpsc::unbounded_channel(); + let (resp_send, resp_recv) = mpsc::unbounded_channel(); + let (resp_notify_send, resp_notify_recv) = mpsc::unbounded_channel(); + + let (close_send, close_recv) = watch::channel(false); + + let conn = Arc::new(ClientConn { + netapp: netapp.clone(), + remote_addr, + peer_pk: remote_pk.clone(), + next_query_number: AtomicU16::from(0u16), + query_send, + resp_send, + resp_notify_send, + close_send, + }); + + netapp.connected_as_client(remote_pk.clone(), conn.clone()); + + tokio::spawn(async move { + let conn2 = conn.clone(); + let conn3 = conn.clone(); + let conn4 = conn.clone(); + tokio::try_join!( + conn2.send_loop(query_recv, box_stream_write, close_recv.clone()), + conn3.recv_loop(box_stream_read, close_recv.clone()), + conn4.dispatch_resp(resp_recv, resp_notify_recv, close_recv.clone()), + ) + .map(|_| ()) + .log_err("ClientConn send_loop/recv_loop/dispatch_loop"); + + netapp.disconnected_as_client(&remote_pk, conn); + }); + + Ok(()) + } + + pub fn close(&self) { + self.close_send.broadcast(true).unwrap(); + } + + async fn dispatch_resp( + self: Arc, + mut resp_recv: mpsc::UnboundedReceiver<(RequestID, Vec)>, + mut resp_notify_recv: mpsc::UnboundedReceiver<(RequestID, oneshot::Sender>)>, + mut must_exit: watch::Receiver, + ) -> Result<(), Error> { + let mut resps: HashMap> = HashMap::new(); + let mut resp_notify: HashMap>> = HashMap::new(); + while !*must_exit.borrow() { + tokio::select! { + resp = resp_recv.recv() => { + if let Some((id, resp)) = resp { + trace!("dispatch_resp: got resp to {}, {} bytes", id, resp.len()); + if let Some(ch) = resp_notify.remove(&id) { + ch.send(resp).map_err(|_| Error::Message("Could not dispatch reply".to_string()))?; + } else { + resps.insert(id, resp); + } + } + } + resp_ch = resp_notify_recv.recv() => { + if let Some((id, resp_ch)) = resp_ch { + trace!("dispatch_resp: got resp_ch {}", id); + if let Some(rs) = resps.remove(&id) { + resp_ch.send(rs).map_err(|_| Error::Message("Could not dispatch reply".to_string()))?; + } else { + resp_notify.insert(id, resp_ch); + } + } + } + exit = must_exit.recv() => { + if exit == Some(true) { + break; + } + } + } + } + Ok(()) + } + + pub async fn request( + self: Arc, + rq: T, + prio: RequestPriority, + ) -> Result<::Response, Error> + where + T: Message, + { + let id = self + .next_query_number + .fetch_add(1u16, atomic::Ordering::Relaxed); + let mut bytes = vec![prio]; + bytes.extend_from_slice(&u32::to_be_bytes(T::KIND)[..]); + bytes.extend_from_slice(&rmp_to_vec_all_named(&rq)?[..]); + + let (resp_send, resp_recv) = oneshot::channel(); + self.resp_notify_send.send((id, resp_send))?; + + trace!("request: query_send {}, {} bytes", id, bytes.len()); + self.query_send.send((id, prio, bytes))?; + + let resp = resp_recv.await?; + + rmp_serde::decode::from_read_ref::<_, Result<::Response, String>>(&resp[..])? + .map_err(Error::Remote) + } +} + +impl SendLoop for ClientConn {} + +#[async_trait] +impl RecvLoop for ClientConn { + async fn recv_handler(self: Arc, id: RequestID, msg: Vec) { + self.resp_send + .send((id, msg)) + .log_err("ClientConn::recv_handler"); + } +} -- cgit v1.2.3