citadel_sdk/prefabs/server/
internal_service.rs1use crate::prefabs::shared::internal_service::InternalServerCommunicator;
71use crate::prelude::*;
72use std::future::Future;
73use std::marker::PhantomData;
74
75pub struct InternalServiceKernel<'a, F, Fut, R: Ratchet = StackedRatchet> {
76 inner_kernel: Box<dyn NetKernel<R> + 'a>,
77 _pd: PhantomData<fn() -> (&'a F, Fut)>,
78}
79
80impl<F, Fut, R: Ratchet> InternalServiceKernel<'_, F, Fut, R>
81where
82 F: Send + Copy + Sync + FnOnce(InternalServerCommunicator) -> Fut,
83 Fut: Send + Sync + Future<Output = Result<(), NetworkError>>,
84{
85 pub fn new(on_create_webserver: F) -> Self {
86 Self {
87 _pd: Default::default(),
88 inner_kernel: Box::new(
89 super::client_connect_listener::ClientConnectListenerKernel::new(
90 move |connect_success| async move {
91 crate::prefabs::shared::internal_service::internal_service(
92 connect_success,
93 on_create_webserver,
94 )
95 .await
96 },
97 ),
98 ),
99 }
100 }
101}
102
103#[async_trait]
104impl<F, Fut, R: Ratchet> NetKernel<R> for InternalServiceKernel<'_, F, Fut, R> {
105 fn load_remote(&mut self, node_remote: NodeRemote<R>) -> Result<(), NetworkError> {
106 self.inner_kernel.load_remote(node_remote)
107 }
108
109 async fn on_start(&self) -> Result<(), NetworkError> {
110 self.inner_kernel.on_start().await
111 }
112
113 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
114 self.inner_kernel.on_node_event_received(message).await
115 }
116
117 async fn on_stop(&mut self) -> Result<(), NetworkError> {
118 self.inner_kernel.on_stop().await
119 }
120}
121
#[cfg(test)]
mod test {
    use crate::prefabs::client::single_connection::SingleClientServerConnectionKernel;
    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
    use crate::prefabs::server::internal_service::InternalServiceKernel;
    use crate::prefabs::shared::internal_service::InternalServerCommunicator;
    use crate::prelude::*;
    use crate::test_common::TestBarrier;
    use citadel_io::tokio;
    use citadel_logging::setup_log;
    use hyper::client::conn::Builder;
    use hyper::server::conn::Http;
    use hyper::service::service_fn;
    use hyper::{Body, Error, Request, Response, StatusCode};
    use rstest::rstest;
    use std::convert::Infallible;
    use std::future::Future;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Payload exchanged between the two sides in the raw-bytes test.
    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestPacket {
        packet: Vec<u8>,
    }

    /// Wraps a hyper error into a [`NetworkError`] carrying its display text.
    fn from_hyper_error(e: Error) -> NetworkError {
        NetworkError::msg(format!("Hyper error: {e}"))
    }

    /// Awaits both node futures concurrently and returns the result of whichever
    /// finishes first, discarding its success value. The other future is dropped.
    async fn first_exit<S, C, A, B, E>(server: S, client: C) -> Result<(), E>
    where
        S: Future<Output = Result<A, E>>,
        C: Future<Output = Result<B, E>>,
    {
        citadel_io::tokio::select! {
            res0 = server => {
                citadel_logging::info!(target: "citadel", "Server exited");
                res0.map(|_| ())
            },

            res1 = client => {
                citadel_logging::info!(target: "citadel", "Client exited");
                res1.map(|_| ())
            }
        }
    }

    /// Sends `message` wrapped in a [`TestPacket`] over the communicator, reads one
    /// [`TestPacket`] back and checks that its payload equals `message`.
    ///
    /// The function waits on `barrier` three times: before writing, after reading, and
    /// after bumping `success_count`. Both sides of a test call this function, so the
    /// waits are presumably rendezvous points between them (assumes `TestBarrier`
    /// behaves like a barrier; confirm in `test_common`).
    ///
    /// # Errors
    /// Returns a [`NetworkError`] if writing or reading fails, or if the received
    /// payload differs from `message`. In the mismatch case the final barrier wait is
    /// skipped and `success_count` is not incremented.
    async fn test_write_and_read_one_packet(
        barrier: &TestBarrier,
        internal_server_communicator: &mut InternalServerCommunicator,
        message: &[u8],
        success_count: &AtomicUsize,
    ) -> Result<(), NetworkError> {
        barrier.wait().await;
        let packet = TestPacket {
            packet: message.to_vec(),
        }
        .serialize_to_vector()
        .unwrap();
        let internal_server_communicator =
            write_one_packet(internal_server_communicator, packet).await?;
        let (_, response) =
            read_one_packet_as_framed::<_, TestPacket>(internal_server_communicator).await?;
        barrier.wait().await;

        if response.packet.as_slice() != message {
            return Err(NetworkError::msg("Response did not match request"));
        }

        let _ = success_count.fetch_add(1, Ordering::SeqCst);
        barrier.wait().await;

        Ok(())
    }

    /// Server and client each send a 4096-byte payload through the internal service
    /// and read one back; both sides must record a success.
    #[rstest]
    #[timeout(Duration::from_secs(60))]
    #[citadel_io::tokio::test]
    async fn test_internal_service_basic_bytes() {
        setup_log();
        let barrier = &TestBarrier::new(2);
        let success_count = &AtomicUsize::new(0);
        // Deterministic payload: byte values cycle through 0..=254.
        let message = &(0..4096usize)
            .map(|r| (r % u8::MAX as usize) as u8)
            .collect::<Vec<u8>>();
        // Port 0: the listener reports the actual bound address via `local_addr`.
        let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0")
            .expect("Failed to get TCP listener");
        let server_bind_addr = server_listener.local_addr().unwrap();
        let server_kernel =
            InternalServiceKernel::new(|mut internal_server_communicator| async move {
                test_write_and_read_one_packet(
                    barrier,
                    &mut internal_server_communicator,
                    message,
                    success_count,
                )
                .await
            });

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient(server_bind_addr)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |connection| async move {
                crate::prefabs::shared::internal_service::internal_service(
                    connection,
                    |mut internal_server_communicator| async move {
                        test_write_and_read_one_packet(
                            barrier,
                            &mut internal_server_communicator,
                            message,
                            success_count,
                        )
                        .await
                    },
                )
                .await
            },
        );

        let client = DefaultNodeBuilder::default()
            .with_node_type(NodeType::Peer)
            .build(client_kernel)
            .unwrap();

        let server = DefaultNodeBuilder::default()
            .with_node_type(NodeType::Server(server_bind_addr))
            .with_underlying_protocol(
                ServerUnderlyingProtocol::from_tokio_tcp_listener(server_listener).unwrap(),
            )
            .build(server_kernel)
            .unwrap();

        let res = first_exit(server, client).await;

        res.unwrap();

        // One increment from the server side and one from the client side.
        assert_eq!(success_count.load(Ordering::SeqCst), 2);
    }

    /// The server serves a hyper HTTP/1 connection over the internal service stream and
    /// the client issues a single GET through it, expecting `200 OK` with body
    /// `Hello World!`.
    #[rstest]
    #[timeout(Duration::from_secs(60))]
    #[citadel_io::tokio::test]
    async fn test_internal_service_http() {
        setup_log();
        let barrier = &TestBarrier::new(2);
        let success_count = &AtomicUsize::new(0);
        let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0")
            .expect("Failed to get TCP listener");
        let server_bind_addr = server_listener.local_addr().unwrap();

        let server_kernel = InternalServiceKernel::new(|internal_server_communicator| async move {
            barrier.wait().await;

            /// Answers every request with a fixed body.
            async fn hello(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
                Ok(Response::new(Body::from("Hello World!")))
            }

            Http::new()
                .serve_connection(internal_server_communicator, service_fn(hello))
                .await
                .map_err(from_hyper_error)?;

            Ok(())
        });

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient(server_bind_addr)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |connection| async move {
                crate::prefabs::shared::internal_service::internal_service(
                    connection,
                    |internal_server_communicator| async move {
                        barrier.wait().await;
                        // NOTE(review): presumably gives the server side time to start
                        // serving before the handshake; confirm whether it is required.
                        citadel_io::tokio::time::sleep(Duration::from_millis(500)).await;
                        let (mut request_sender, connection) = Builder::new()
                            .handshake(internal_server_communicator)
                            .await
                            .map_err(from_hyper_error)?;

                        // The connection future has to be polled for requests to make
                        // progress, so it runs on its own detached task.
                        drop(citadel_io::tokio::spawn(async move {
                            if let Err(e) = connection.await {
                                citadel_logging::error!(target: "citadel", "Error in connection: {e}");
                                // NOTE(review): this terminates the whole test process,
                                // not only this test.
                                std::process::exit(-1);
                            }
                        }));

                        citadel_io::tokio::time::sleep(Duration::from_millis(100)).await;
                        // NOTE(review): the Host header is set by hand, presumably because
                        // hyper's low-level connection API does not add one; confirm.
                        let request = Request::builder()
                            .header("Host", "example.com")
                            .method("GET")
                            .body(Body::from(""))
                            .map_err(|err| NetworkError::msg(format!("hyper error: {err}")))?;
                        let response = request_sender
                            .send_request(request)
                            .await
                            .map_err(from_hyper_error)?;
                        assert_eq!(response.status(), StatusCode::OK);

                        let body_bytes = hyper::body::to_bytes(response.into_body())
                            .await
                            .map_err(from_hyper_error)?;
                        assert_eq!(&body_bytes, b"Hello World!" as &[u8]);
                        let _ = success_count.fetch_add(1, Ordering::SeqCst);

                        Ok(())
                    },
                )
                .await
            },
        );

        let client = DefaultNodeBuilder::default()
            .with_node_type(NodeType::Peer)
            .build(client_kernel)
            .unwrap();

        let server = DefaultNodeBuilder::default()
            .with_node_type(NodeType::Server(server_bind_addr))
            .with_underlying_protocol(
                ServerUnderlyingProtocol::from_tokio_tcp_listener(server_listener).unwrap(),
            )
            .build(server_kernel)
            .unwrap();

        let res = first_exit(server, client).await;

        res.unwrap();

        // Only the client side increments the counter in this test.
        assert_eq!(success_count.load(Ordering::SeqCst), 1);
    }
}