aboutsummaryrefslogtreecommitdiff
path: root/src/endpoint.rs
blob: 3f292d96a59dda92848c0e4dfd34a4af34da3987 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
use std::borrow::Borrow;
use std::marker::PhantomData;
use std::sync::Arc;

use arc_swap::ArcSwapOption;
use async_trait::async_trait;

use crate::error::Error;
use crate::message::*;
use crate::netapp::*;
use crate::util::*;

/// Implemented by an application object that serves requests whose message
/// type is `M`.
///
/// Handlers are shared behind an `Arc`: `Endpoint::set_handler` takes an
/// `Arc<H>`, and `handle` receives `self` as `&Arc<Self>`.
#[async_trait]
pub trait EndpointHandler<M: Message>: Send + Sync {
	/// Process request `m`, sent by node `from`, and produce its response.
	async fn handle(self: &Arc<Self>, m: M, from: NodeID) -> M::Response;
}

/// Placeholder handler for client-only endpoints.
///
/// Use the unit type `()` as the handler type of an endpoint that is only
/// used to send requests and never serves them on the local node.
/// If this handler is ever asked to handle a request, it panics.
#[async_trait]
impl<M> EndpointHandler<M> for ()
where
	M: Message + 'static,
{
	async fn handle(self: &Arc<()>, _m: M, _from: NodeID) -> M::Response {
		panic!("This endpoint should not have a local handler.");
	}
}

/// An endpoint for messages of type `M`, identified primarily by its path.
///
/// Endpoints are created with `NetApp::endpoint`, which fixes the path at
/// creation time. The same `Endpoint` is used both to send requests to
/// remote nodes (`Endpoint::call`) and to install the local handler for
/// incoming requests (`Endpoint::set_handler`). `H` is the type of that
/// handler object (see `EndpointHandler`).
pub struct Endpoint<M: Message, H: EndpointHandler<M>> {
	/// Ties the endpoint to its message type without storing an `M`.
	phantom: PhantomData<M>,
	/// Owning `NetApp`; provides the local node id and client connections.
	netapp: Arc<NetApp>,
	/// Path identifying this endpoint, sent along with remote calls.
	path: String,
	/// Local request handler; `None` while no handler is installed.
	handler: ArcSwapOption<H>,
}

impl<M, H> Endpoint<M, H>
where
	M: Message,
	H: EndpointHandler<M>,
{
	/// Create an endpoint at `path`, attached to `netapp`, with no local
	/// handler installed.
	pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
		Self {
			phantom: PhantomData,
			netapp,
			path,
			handler: ArcSwapOption::from(None),
		}
	}

	/// Get the path of this endpoint.
	pub fn path(&self) -> &str {
		&self.path
	}

	/// Set the object that is responsible for handling requests to
	/// this endpoint on the local node, replacing any previous handler.
	pub fn set_handler(&self, h: Arc<H>) {
		// The previous handler (if any) is not needed, so store instead of swap.
		self.handler.store(Some(h));
	}

	/// Call this endpoint on a remote node (or on the local node,
	/// for that matter).
	///
	/// When `target` is the local node, the request is passed directly to the
	/// local handler without going through a connection. Otherwise it is sent
	/// over the client connection to `target`, with priority `prio`.
	///
	/// # Errors
	///
	/// - `Error::NoHandler` if `target` is the local node and no handler is set.
	/// - `Error::Message` ("Not connected: …") if there is no client
	///   connection to `target`.
	/// - Any error returned by the connection's `call` for remote targets.
	pub async fn call(
		&self,
		target: &NodeID,
		req: M,
		prio: RequestPriority,
	) -> Result<<M as Message>::Response, Error> {
		if *target == self.netapp.id {
			// Local call: hand the request straight to our own handler.
			return match self.handler.load_full() {
				None => Err(Error::NoHandler),
				Some(h) => Ok(h.handle(req, self.netapp.id).await),
			};
		}

		// Clone the connection out in its own statement so that the read
		// guard on `client_conns` is dropped before the `.await` below.
		let conn = self
			.netapp
			.client_conns
			.read()
			.unwrap()
			.get(target)
			.cloned();
		match conn {
			None => Err(Error::Message(format!(
				"Not connected: {}",
				hex::encode(&target[..8])
			))),
			Some(c) => c.call(req, self.path.as_str(), prio).await,
		}
	}
}

// ---- Internal stuff ----

pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;

#[async_trait]
pub(crate) trait GenericEndpoint {
	async fn handle(
		&self,
		buf: &[u8],
		stream: ByteStream,
		from: NodeID,
	) -> Result<(Vec<u8>, Option<ByteStream>), Error>;
	fn drop_handler(&self);
	fn clone_endpoint(&self) -> DynEndpoint;
}

/// Newtype around `Arc<Endpoint<M, H>>` that lets a typed endpoint implement
/// `GenericEndpoint` and be boxed as a `DynEndpoint`.
pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
where
	M: Message,
	H: EndpointHandler<M>;

// Written by hand instead of `#[derive(Clone)]`: the derive would add
// `M: Clone` and `H: Clone` bounds, although cloning only bumps the `Arc`
// reference count and never clones a message or handler.
impl<M, H> Clone for EndpointArc<M, H>
where
	M: Message,
	H: EndpointHandler<M>,
{
	fn clone(&self) -> Self {
		Self(Arc::clone(&self.0))
	}
}

#[async_trait]
impl<M, H> GenericEndpoint for EndpointArc<M, H>
where
	M: Message + 'static,
	H: EndpointHandler<M> + 'static,
{
	async fn handle(
		&self,
		buf: &[u8],
		stream: ByteStream,
		from: NodeID,
	) -> Result<(Vec<u8>, Option<ByteStream>), Error> {
		match self.0.handler.load_full() {
			None => Err(Error::NoHandler),
			Some(h) => {
				let req = rmp_serde::decode::from_read_ref(buf)?;
				let req = M::from_parts(req, stream);
				let res = h.handle(req, from).await;
				let (res, res_stream) = res.into_parts();
				let res_bytes = rmp_to_vec_all_named(&res)?;
				Ok((res_bytes, res_stream))
			}
		}
	}

	fn drop_handler(&self) {
		self.0.handler.swap(None);
	}

	fn clone_endpoint(&self) -> DynEndpoint {
		Box::new(Self(self.0.clone()))
	}
}