aboutsummaryrefslogtreecommitdiff
path: root/src/endpoint.rs
blob: 81ed0362c2d26c312513358aea8f9164c67d6bf6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
use std::borrow::Borrow;
use std::marker::PhantomData;
use std::sync::Arc;

use arc_swap::ArcSwapOption;
use async_trait::async_trait;

use serde::de::Error as DeError;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::error::Error;
use crate::netapp::*;
use crate::proto::*;
use crate::util::*;

/// This trait should be implemented by all messages your application
/// wants to handle.
///
/// A message must be (de)serializable through `SerializeMessage` and
/// must be `Send + Sync`.
pub trait Message: SerializeMessage + Send + Sync {
	/// Type of the value produced in reply to this message (this is what
	/// `EndpointHandler::handle` returns). It is subject to the same
	/// bounds as the message itself.
	type Response: SerializeMessage + Send + Sync;
}

/// A trait for de/serializing messages, with possible associated stream.
///
/// Compared to plain serde (de)serialization, both directions also carry
/// an `AssociatedStream`: serialization may return one alongside the
/// serializer output, and deserialization receives one as input.
#[async_trait]
pub trait SerializeMessage: Sized {
	/// Serialize `self` using `serializer`.
	///
	/// On success, returns the serializer's output together with the
	/// stream associated with this message, or `None` if it has none.
	fn serialize_msg<S: Serializer>(
		&self,
		serializer: S,
	) -> Result<(S::Ok, Option<AssociatedStream>), S::Error>;

	/// Rebuild a message from `deserializer` and the `stream` that
	/// accompanied it.
	///
	/// NOTE(review): how a message without a stream is represented is not
	/// visible from this declaration; the blanket impl for serde types
	/// treats a stream that yields no item as "no stream".
	async fn deserialize_msg<'de, D: Deserializer<'de> + Send>(
		deserializer: D,
		stream: AssociatedStream,
	) -> Result<Self, D::Error>;
}

/// Blanket implementation for every type that is already (de)serializable
/// with serde: such messages never carry an associated stream.
#[async_trait]
impl<T> SerializeMessage for T
where
	T: Serialize + for<'de> Deserialize<'de> + Send + Sync,
{
	/// Serialize the value with serde; the stream part of the result is
	/// always `None`.
	fn serialize_msg<S: Serializer>(
		&self,
		serializer: S,
	) -> Result<(S::Ok, Option<AssociatedStream>), S::Error> {
		let output = self.serialize(serializer)?;
		Ok((output, None))
	}

	/// Deserialize the value with serde, then check that `stream` yields
	/// nothing: receiving an item for a plain serde message is an error.
	async fn deserialize_msg<'de, D: Deserializer<'de> + Send>(
		deserializer: D,
		mut stream: AssociatedStream,
	) -> Result<Self, D::Error> {
		use futures::StreamExt;

		let value = Self::deserialize(deserializer)?;

		// Only the first poll matters: any item at all means a stream was
		// attached to a message type that cannot hold one.
		let unexpected_stream = stream.next().await.is_some();
		if unexpected_stream {
			Err(D::Error::custom(
				"failed to deserialize: found associated stream when none expected",
			))
		} else {
			Ok(value)
		}
	}
}

/// This trait should be implemented by an object of your application
/// that can handle a message of type `M`.
///
/// The handler object should be in an Arc, see `Endpoint::set_handler`
#[async_trait]
pub trait EndpointHandler<M>: Send + Sync
where
	M: Message,
{
	/// Handle the message `m` sent by node `from`, and produce the
	/// response to send back.
	///
	/// When `Endpoint::call` targets the local node, `from` is the local
	/// node's own ID.
	async fn handle(self: &Arc<Self>, m: &M, from: NodeID) -> M::Response;
}

/// Handler type for endpoints that are only used as clients.
///
/// When an endpoint is used solely to send requests, and never serves
/// them on the local node, use the unit type `()` as its handler type.
/// Such a handler must never be invoked: it panics as soon as it is
/// made to handle a request.
#[async_trait]
impl<M> EndpointHandler<M> for ()
where
	M: Message + 'static,
{
	async fn handle(self: &Arc<()>, _msg: &M, _sender: NodeID) -> M::Response {
		panic!("This endpoint should not have a local handler.");
	}
}

/// This struct represents an endpoint for message of type `M`.
///
/// Creating a new endpoint is done by calling `NetApp::endpoint`.
/// An endpoint is identified primarily by its path, which is specified
/// at creation time.
///
/// An `Endpoint` is used both to send requests to remote nodes,
/// and to specify the handler for such requests on the local node.
/// The type `H` represents the type of the handler object for
/// endpoint messages (see `EndpointHandler`).
pub struct Endpoint<M, H>
where
	M: Message,
	H: EndpointHandler<M>,
{
	/// Marker tying the endpoint to its message type `M`; no value of
	/// type `M` is stored.
	phantom: PhantomData<M>,
	/// The `NetApp` instance this endpoint belongs to; `call` reads the
	/// local node ID and the client connections from it.
	netapp: Arc<NetApp>,
	/// Path identifying this endpoint.
	path: String,
	/// Handler for requests served on the local node. Empty at creation,
	/// filled in by `set_handler`.
	handler: ArcSwapOption<H>,
}

impl<M, H> Endpoint<M, H>
where
	M: Message,
	H: EndpointHandler<M>,
{
	pub(crate) fn new(netapp: Arc<NetApp>, path: String) -> Self {
		Self {
			phantom: PhantomData::default(),
			netapp,
			path,
			handler: ArcSwapOption::from(None),
		}
	}

	/// Get the path of this endpoint
	pub fn path(&self) -> &str {
		&self.path
	}

	/// Set the object that is responsible of handling requests to
	/// this endpoint on the local node.
	pub fn set_handler(&self, h: Arc<H>) {
		self.handler.swap(Some(h));
	}

	/// Call this endpoint on a remote node (or on the local node,
	/// for that matter)
	pub async fn call<B>(
		&self,
		target: &NodeID,
		req: B,
		prio: RequestPriority,
	) -> Result<<M as Message>::Response, Error>
	where
		B: Borrow<M>,
	{
		if *target == self.netapp.id {
			match self.handler.load_full() {
				None => Err(Error::NoHandler),
				Some(h) => Ok(h.handle(req.borrow(), self.netapp.id).await),
			}
		} else {
			let conn = self
				.netapp
				.client_conns
				.read()
				.unwrap()
				.get(target)
				.cloned();
			match conn {
				None => Err(Error::Message(format!(
					"Not connected: {}",
					hex::encode(&target[..8])
				))),
				Some(c) => c.call(req, self.path.as_str(), prio).await,
			}
		}
	}
}

// ---- Internal stuff ----

/// A boxed, type-erased endpoint (see `GenericEndpoint`) that is also
/// `Send + Sync`.
pub(crate) type DynEndpoint = Box<dyn GenericEndpoint + Send + Sync>;

/// Type-erased view of an endpoint: it works on serialized requests and
/// responses, so it does not mention the message or handler types and
/// can be used as a trait object (see `DynEndpoint`).
#[async_trait]
pub(crate) trait GenericEndpoint {
	/// Handle a request received from node `from`: `buf` holds the
	/// serialized request and `stream` the stream that accompanies it.
	///
	/// On success, returns the serialized response together with its
	/// optional associated stream.
	async fn handle(
		&self,
		buf: &[u8],
		stream: AssociatedStream,
		from: NodeID,
	) -> Result<(Vec<u8>, Option<AssociatedStream>), Error>;
	/// Detach the handler currently set on the endpoint, if any.
	fn drop_handler(&self);
	/// Return a new boxed `DynEndpoint` for this endpoint.
	fn clone_endpoint(&self) -> DynEndpoint;
}

/// Newtype around a shared (`Arc`) `Endpoint`; this is the type on which
/// `GenericEndpoint` is implemented.
#[derive(Clone)]
pub(crate) struct EndpointArc<M, H>(pub(crate) Arc<Endpoint<M, H>>)
where
	M: Message,
	H: EndpointHandler<M>;

#[async_trait]
impl<M, H> GenericEndpoint for EndpointArc<M, H>
where
	M: Message + 'static,
	H: EndpointHandler<M> + 'static,
{
	async fn handle(
		&self,
		buf: &[u8],
		stream: AssociatedStream,
		from: NodeID,
	) -> Result<(Vec<u8>, Option<AssociatedStream>), Error> {
		match self.0.handler.load_full() {
			None => Err(Error::NoHandler),
			Some(h) => {
				let mut deser = rmp_serde::decode::Deserializer::from_read_ref(buf);
				let req = M::deserialize_msg(&mut deser, stream).await?;
				let res = h.handle(&req, from).await;
				let res_bytes = rmp_to_vec_all_named(&res)?;
				Ok(res_bytes)
			}
		}
	}

	fn drop_handler(&self) {
		self.0.handler.swap(None);
	}

	fn clone_endpoint(&self) -> DynEndpoint {
		Box::new(Self(self.0.clone()))
	}
}