Skip to main content

citadel_sdk/prefabs/server/
internal_service.rs

1//! Internal Service Integration
2//!
3//! This module provides a network kernel that enables integration of internal services,
4//! such as HTTP servers, within the Citadel Protocol network. It's particularly useful
5//! for implementing web services that need to communicate over secure Citadel channels.
6//!
7//! # Features
8//! - Internal service integration
9//! - HTTP server support
10//! - Custom service handlers
11//! - Asynchronous processing
12//! - Type-safe communication
13//! - Automatic channel management
14//! - Service lifecycle handling
15//!
16//! # Example
17//! ```rust
18//! use std::convert::Infallible;
19//! use std::net::SocketAddr;
20//! use citadel_sdk::prelude::*;
21//! use hyper::{Response, Body, Server, Request};
22//! use hyper::server::conn::AddrStream;
23//! use hyper::service::{make_service_fn, service_fn};
24//! use citadel_sdk::prefabs::server::internal_service::InternalServiceKernel;
25//!
26//! // Create a kernel with an HTTP server
27//! let kernel = InternalServiceKernel::<_, _, StackedRatchet>::new(|comm| async move {
28//!
29//!     let make_svc = make_service_fn(|socket: &AddrStream| {
30//!         let remote_addr = socket.remote_addr();
31//!         async move {
32//!             Ok::<_, Infallible>(service_fn(move |_: Request<Body>| async move {
33//!                 Ok::<_, Infallible>(
34//!                     Response::new(Body::from(format!("Hello, {}!", remote_addr)))
35//!                 )
36//!             }))
37//!         }
38//!     });
39//!
40//!     // Start the HTTP server
41//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
42//!     let server = Server::bind(&addr).serve(make_svc);
43//!     // Run this server indefinitely
44//!     if let Err(e) = server.await {
45//!         eprintln!("server error: {}", e);
46//!     }
47//!
48//!    Ok(())
49//! });
50//! ```
51//!
52//! # Important Notes
53//! - Services run in isolated contexts
54//! - Communication is bidirectional
55//! - Supports HTTP/1.1 and HTTP/2
56//! - Automatic error handling
57//! - Resource cleanup on shutdown
58//!
59//! # Related Components
60//! - [`NetKernel`]: Base trait for network kernels
61//! - [`InternalServerCommunicator`]: Service communication
62//! - [`ClientConnectListenerKernel`]: Connection handling
63//! - [`NodeResult`]: Network event handling
64//!
65//! [`NetKernel`]: crate::prelude::NetKernel
66//! [`InternalServerCommunicator`]: crate::prefabs::shared::internal_service::InternalServerCommunicator
67//! [`ClientConnectListenerKernel`]: crate::prefabs::server::client_connect_listener::ClientConnectListenerKernel
68//! [`NodeResult`]: crate::prelude::NodeResult
69
70use crate::prefabs::shared::internal_service::InternalServerCommunicator;
71use crate::prelude::*;
72use std::future::Future;
73use std::marker::PhantomData;
74
75pub struct InternalServiceKernel<'a, F, Fut, R: Ratchet = StackedRatchet> {
76    inner_kernel: Box<dyn NetKernel<R> + 'a>,
77    _pd: PhantomData<fn() -> (&'a F, Fut)>,
78}
79
80impl<F, Fut, R: Ratchet> InternalServiceKernel<'_, F, Fut, R>
81where
82    F: Send + Copy + Sync + FnOnce(InternalServerCommunicator) -> Fut,
83    Fut: Send + Sync + Future<Output = Result<(), NetworkError>>,
84{
85    pub fn new(on_create_webserver: F) -> Self {
86        Self {
87            _pd: Default::default(),
88            inner_kernel: Box::new(
89                super::client_connect_listener::ClientConnectListenerKernel::new(
90                    move |connect_success| async move {
91                        crate::prefabs::shared::internal_service::internal_service(
92                            connect_success,
93                            on_create_webserver,
94                        )
95                        .await
96                    },
97                ),
98            ),
99        }
100    }
101}
102
103#[async_trait]
104impl<F, Fut, R: Ratchet> NetKernel<R> for InternalServiceKernel<'_, F, Fut, R> {
105    fn load_remote(&mut self, node_remote: NodeRemote<R>) -> Result<(), NetworkError> {
106        self.inner_kernel.load_remote(node_remote)
107    }
108
109    async fn on_start(&self) -> Result<(), NetworkError> {
110        self.inner_kernel.on_start().await
111    }
112
113    async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
114        self.inner_kernel.on_node_event_received(message).await
115    }
116
117    async fn on_stop(&mut self) -> Result<(), NetworkError> {
118        self.inner_kernel.on_stop().await
119    }
120}
121
122#[cfg(all(test, feature = "localhost-testing"))]
123mod test {
124    use crate::prefabs::client::single_connection::SingleClientServerConnectionKernel;
125    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
126    use crate::prefabs::server::internal_service::InternalServiceKernel;
127    use crate::prefabs::shared::internal_service::InternalServerCommunicator;
128    use crate::prelude::*;
129    use crate::test_common::TestBarrier;
130    use citadel_io::tokio;
131    use citadel_logging::setup_log;
132    use hyper::client::conn::Builder;
133    use hyper::server::conn::Http;
134    use hyper::service::service_fn;
135    use hyper::{Body, Error, Request, Response, StatusCode};
136    use rstest::rstest;
137    use std::convert::Infallible;
138    use std::sync::atomic::{AtomicUsize, Ordering};
139    use std::time::Duration;
140
141    #[derive(serde::Serialize, serde::Deserialize)]
142    struct TestPacket {
143        packet: Vec<u8>,
144    }
145
146    fn from_hyper_error(e: Error) -> NetworkError {
147        citadel_io::error!(
148            citadel_io::ErrorCode::InternalServiceHyperError,
149            e.to_string()
150        )
151    }
152
153    async fn test_write_and_read_one_packet(
154        barrier: &TestBarrier,
155        internal_server_communicator: &mut InternalServerCommunicator,
156        message: &Vec<u8>,
157        success_count: &AtomicUsize,
158    ) -> Result<(), NetworkError> {
159        barrier.wait().await;
160        let packet = TestPacket {
161            packet: message.clone(),
162        }
163        .serialize_to_vector()
164        .unwrap();
165        let internal_server_communicator =
166            write_one_packet(internal_server_communicator, packet).await?;
167        let (_, response) =
168            read_one_packet_as_framed::<_, TestPacket>(internal_server_communicator).await?;
169        barrier.wait().await;
170
171        if &response.packet != message {
172            return Err(citadel_io::error!(
173                citadel_io::ErrorCode::InternalServiceResponseMismatch
174            ));
175        }
176
177        let _ = success_count.fetch_add(1, Ordering::SeqCst);
178        barrier.wait().await;
179
180        Ok(())
181    }
182
183    #[rstest]
184    #[timeout(Duration::from_secs(60))]
185    #[citadel_io::tokio::test]
186    async fn test_internal_service_basic_bytes() {
187        setup_log();
188        let barrier = &TestBarrier::new(2);
189        let success_count = &AtomicUsize::new(0);
190        let message = &(0..4096usize)
191            .map(|r| (r % u8::MAX as usize) as u8)
192            .collect::<Vec<u8>>();
193        let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0")
194            .expect("Failed to get TCP listener");
195        let server_bind_addr = server_listener.local_addr().unwrap();
196        let server_kernel =
197            InternalServiceKernel::new(|mut internal_server_communicator| async move {
198                test_write_and_read_one_packet(
199                    barrier,
200                    &mut internal_server_communicator,
201                    message,
202                    success_count,
203                )
204                .await
205            });
206
207        let server_connection_settings =
208            DefaultServerConnectionSettingsBuilder::transient(server_bind_addr)
209                .build()
210                .unwrap();
211
212        let client_kernel = SingleClientServerConnectionKernel::new(
213            server_connection_settings,
214            |connection| async move {
215                crate::prefabs::shared::internal_service::internal_service(
216                    connection,
217                    |mut internal_server_communicator| async move {
218                        test_write_and_read_one_packet(
219                            barrier,
220                            &mut internal_server_communicator,
221                            message,
222                            success_count,
223                        )
224                        .await
225                    },
226                )
227                .await
228            },
229        );
230
231        let client = DefaultNodeBuilder::default()
232            .with_node_type(NodeType::Peer)
233            .build(client_kernel)
234            .unwrap();
235
236        let server = DefaultNodeBuilder::default()
237            .with_node_type(NodeType::Server(server_bind_addr))
238            .with_underlying_protocol(ServerMode::OrderedReliable(
239                NativeOrderedReliableConfig::from_tokio_listener(server_listener).unwrap(),
240            ))
241            .build(server_kernel)
242            .unwrap();
243
244        let res = citadel_io::tokio::select! {
245            res0 = server => {
246                citadel_logging::info!(target: "citadel", "Server exited");
247                res0.map(|_|())
248            },
249
250            res1 = client => {
251                citadel_logging::info!(target: "citadel", "Client exited");
252                res1.map(|_|())
253            }
254        };
255
256        res.unwrap();
257
258        assert_eq!(success_count.load(Ordering::SeqCst), 2);
259    }
260
261    #[rstest]
262    #[timeout(Duration::from_secs(60))]
263    #[citadel_io::tokio::test]
264    async fn test_internal_service_http() {
265        setup_log();
266        let barrier = &TestBarrier::new(2);
267        let success_count = &AtomicUsize::new(0);
268        let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0")
269            .expect("Failed to get TCP listener");
270        let server_bind_addr = server_listener.local_addr().unwrap();
271
272        let server_kernel = InternalServiceKernel::new(|internal_server_communicator| async move {
273            barrier.wait().await;
274
275            async fn hello(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
276                Ok(Response::new(Body::from("Hello World!")))
277            }
278
279            Http::new()
280                .serve_connection(internal_server_communicator, service_fn(hello))
281                .await
282                .map_err(from_hyper_error)?;
283
284            Ok(())
285        });
286
287        let server_connection_settings =
288            DefaultServerConnectionSettingsBuilder::transient(server_bind_addr)
289                .build()
290                .unwrap();
291
292        let client_kernel = SingleClientServerConnectionKernel::new(
293            server_connection_settings,
294            |connection| async move {
295                crate::prefabs::shared::internal_service::internal_service(
296                    connection,
297                    |internal_server_communicator| async move {
298                        barrier.wait().await;
299                        // wait for the server
300                        citadel_io::time::sleep(Duration::from_millis(500)).await;
301                        let (mut request_sender, connection) = Builder::new()
302                            .handshake(internal_server_communicator)
303                            .await
304                            .map_err(from_hyper_error)?;
305
306                        // spawn a task to poll the connection and drive the HTTP state
307                        drop(citadel_io::tokio::spawn(async move {
308                            if let Err(e) = connection.await {
309                                citadel_logging::error!(target: "citadel", "Error in connection: {e}");
310                                std::process::exit(-1);
311                            }
312                        }));
313
314                        // give time for task to spawn
315                        citadel_io::time::sleep(Duration::from_millis(100)).await;
316                        let request = Request::builder()
317                            // We need to manually add the host header because SendRequest does not
318                            .header("Host", "example.com")
319                            .method("GET")
320                            .body(Body::from(""))
321                            .map_err(|err| citadel_io::error!(citadel_io::ErrorCode::InternalServiceHyperError, err.to_string()))?;
322                        let response = request_sender.send_request(request).await.map_err(from_hyper_error)?;
323                        assert_eq!(response.status(), StatusCode::OK);
324
325                        let body_bytes = hyper::body::to_bytes(response.into_body()).await.map_err(from_hyper_error)?;
326                        assert_eq!(&body_bytes, b"Hello World!" as &[u8]);
327                        let _ = success_count.fetch_add(1, Ordering::SeqCst);
328
329                        // To send via the same connection again, it may not work as it may not be ready,
330                        // so we have to wait until the request_sender becomes ready. (requires tower)
331                        // request_sender.ready().await.map_err(from_hyper_error)?;
332                        Ok(())
333                    },
334                )
335                    .await
336            },
337        );
338
339        let client = DefaultNodeBuilder::default()
340            .with_node_type(NodeType::Peer)
341            .build(client_kernel)
342            .unwrap();
343
344        let server = DefaultNodeBuilder::default()
345            .with_node_type(NodeType::Server(server_bind_addr))
346            .with_underlying_protocol(ServerMode::OrderedReliable(
347                NativeOrderedReliableConfig::from_tokio_listener(server_listener).unwrap(),
348            ))
349            .build(server_kernel)
350            .unwrap();
351
352        let res = citadel_io::tokio::select! {
353            res0 = server => {
354                citadel_logging::info!(target: "citadel", "Server exited");
355                res0.map(|_|())
356            },
357
358            res1 = client => {
359                citadel_logging::info!(target: "citadel", "Client exited");
360                res1.map(|_|())
361            }
362        };
363
364        res.unwrap();
365
366        assert_eq!(success_count.load(Ordering::SeqCst), 1);
367    }
368}