citadel_sdk/prefabs/server/
internal_service.rs

1//! Internal Service Integration
2//!
3//! This module provides a network kernel that enables integration of internal services,
4//! such as HTTP servers, within the Citadel Protocol network. It's particularly useful
5//! for implementing web services that need to communicate over secure Citadel channels.
6//!
7//! # Features
8//! - Internal service integration
9//! - HTTP server support
10//! - Custom service handlers
11//! - Asynchronous processing
12//! - Type-safe communication
13//! - Automatic channel management
14//! - Service lifecycle handling
15//!
16//! # Example
17//! ```rust
18//! use std::convert::Infallible;
19//! use std::net::SocketAddr;
20//! use citadel_sdk::prelude::*;
21//! use hyper::{Response, Body, Server, Request};
22//! use hyper::server::conn::AddrStream;
23//! use hyper::service::{make_service_fn, service_fn};
24//! use citadel_sdk::prefabs::server::internal_service::InternalServiceKernel;
25//!
26//! // Create a kernel with an HTTP server
27//! let kernel = InternalServiceKernel::<_, _, StackedRatchet>::new(|comm| async move {
28//!
29//!     let make_svc = make_service_fn(|socket: &AddrStream| {
30//!         let remote_addr = socket.remote_addr();
31//!         async move {
32//!             Ok::<_, Infallible>(service_fn(move |_: Request<Body>| async move {
33//!                 Ok::<_, Infallible>(
34//!                     Response::new(Body::from(format!("Hello, {}!", remote_addr)))
35//!                 )
36//!             }))
37//!         }
38//!     });
39//!
40//!     // Start the HTTP server
41//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
42//!     let server = Server::bind(&addr).serve(make_svc);
43//!     // Run this server indefinitely
44//!     if let Err(e) = server.await {
45//!         eprintln!("server error: {}", e);
46//!     }
47//!
//!     Ok(())
49//! });
50//! ```
51//!
52//! # Important Notes
53//! - Services run in isolated contexts
54//! - Communication is bidirectional
55//! - Supports HTTP/1.1 and HTTP/2
56//! - Automatic error handling
57//! - Resource cleanup on shutdown
58//!
59//! # Related Components
60//! - [`NetKernel`]: Base trait for network kernels
61//! - [`InternalServerCommunicator`]: Service communication
62//! - [`ClientConnectListenerKernel`]: Connection handling
63//! - [`NodeResult`]: Network event handling
64//!
65//! [`NetKernel`]: crate::prelude::NetKernel
66//! [`InternalServerCommunicator`]: crate::prefabs::shared::internal_service::InternalServerCommunicator
67//! [`ClientConnectListenerKernel`]: crate::prefabs::server::client_connect_listener::ClientConnectListenerKernel
68//! [`NodeResult`]: crate::prelude::NodeResult
69
70use crate::prefabs::shared::internal_service::InternalServerCommunicator;
71use crate::prelude::*;
72use std::future::Future;
73use std::marker::PhantomData;
74
75pub struct InternalServiceKernel<'a, F, Fut, R: Ratchet = StackedRatchet> {
76    inner_kernel: Box<dyn NetKernel<R> + 'a>,
77    _pd: PhantomData<fn() -> (&'a F, Fut)>,
78}
79
80impl<F, Fut, R: Ratchet> InternalServiceKernel<'_, F, Fut, R>
81where
82    F: Send + Copy + Sync + FnOnce(InternalServerCommunicator) -> Fut,
83    Fut: Send + Sync + Future<Output = Result<(), NetworkError>>,
84{
85    pub fn new(on_create_webserver: F) -> Self {
86        Self {
87            _pd: Default::default(),
88            inner_kernel: Box::new(
89                super::client_connect_listener::ClientConnectListenerKernel::new(
90                    move |connect_success| async move {
91                        crate::prefabs::shared::internal_service::internal_service(
92                            connect_success,
93                            on_create_webserver,
94                        )
95                        .await
96                    },
97                ),
98            ),
99        }
100    }
101}
102
103#[async_trait]
104impl<F, Fut, R: Ratchet> NetKernel<R> for InternalServiceKernel<'_, F, Fut, R> {
105    fn load_remote(&mut self, node_remote: NodeRemote<R>) -> Result<(), NetworkError> {
106        self.inner_kernel.load_remote(node_remote)
107    }
108
109    async fn on_start(&self) -> Result<(), NetworkError> {
110        self.inner_kernel.on_start().await
111    }
112
113    async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
114        self.inner_kernel.on_node_event_received(message).await
115    }
116
117    async fn on_stop(&mut self) -> Result<(), NetworkError> {
118        self.inner_kernel.on_stop().await
119    }
120}
121
122#[cfg(test)]
123mod test {
124    use crate::prefabs::client::single_connection::SingleClientServerConnectionKernel;
125    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
126    use crate::prefabs::server::internal_service::InternalServiceKernel;
127    use crate::prefabs::shared::internal_service::InternalServerCommunicator;
128    use crate::prelude::*;
129    use crate::test_common::TestBarrier;
130    use citadel_io::tokio;
131    use citadel_logging::setup_log;
132    use hyper::client::conn::Builder;
133    use hyper::server::conn::Http;
134    use hyper::service::service_fn;
135    use hyper::{Body, Error, Request, Response, StatusCode};
136    use rstest::rstest;
137    use std::convert::Infallible;
138    use std::sync::atomic::{AtomicUsize, Ordering};
139    use std::time::Duration;
140
141    #[derive(serde::Serialize, serde::Deserialize)]
142    struct TestPacket {
143        packet: Vec<u8>,
144    }
145
146    fn from_hyper_error(e: Error) -> NetworkError {
147        NetworkError::msg(format!("Hyper error: {e}"))
148    }
149
150    async fn test_write_and_read_one_packet(
151        barrier: &TestBarrier,
152        internal_server_communicator: &mut InternalServerCommunicator,
153        message: &Vec<u8>,
154        success_count: &AtomicUsize,
155    ) -> Result<(), NetworkError> {
156        barrier.wait().await;
157        let packet = TestPacket {
158            packet: message.clone(),
159        }
160        .serialize_to_vector()
161        .unwrap();
162        let internal_server_communicator =
163            write_one_packet(internal_server_communicator, packet).await?;
164        let (_, response) =
165            read_one_packet_as_framed::<_, TestPacket>(internal_server_communicator).await?;
166        barrier.wait().await;
167
168        if &response.packet != message {
169            return Err(NetworkError::msg("Response did not match request"));
170        }
171
172        let _ = success_count.fetch_add(1, Ordering::SeqCst);
173        barrier.wait().await;
174
175        Ok(())
176    }
177
178    #[rstest]
179    #[timeout(Duration::from_secs(60))]
180    #[citadel_io::tokio::test]
181    async fn test_internal_service_basic_bytes() {
182        setup_log();
183        let barrier = &TestBarrier::new(2);
184        let success_count = &AtomicUsize::new(0);
185        let message = &(0..4096usize)
186            .map(|r| (r % u8::MAX as usize) as u8)
187            .collect::<Vec<u8>>();
188        let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0")
189            .expect("Failed to get TCP listener");
190        let server_bind_addr = server_listener.local_addr().unwrap();
191        let server_kernel =
192            InternalServiceKernel::new(|mut internal_server_communicator| async move {
193                test_write_and_read_one_packet(
194                    barrier,
195                    &mut internal_server_communicator,
196                    message,
197                    success_count,
198                )
199                .await
200            });
201
202        let server_connection_settings =
203            DefaultServerConnectionSettingsBuilder::transient(server_bind_addr)
204                .build()
205                .unwrap();
206
207        let client_kernel = SingleClientServerConnectionKernel::new(
208            server_connection_settings,
209            |connection| async move {
210                crate::prefabs::shared::internal_service::internal_service(
211                    connection,
212                    |mut internal_server_communicator| async move {
213                        test_write_and_read_one_packet(
214                            barrier,
215                            &mut internal_server_communicator,
216                            message,
217                            success_count,
218                        )
219                        .await
220                    },
221                )
222                .await
223            },
224        );
225
226        let client = DefaultNodeBuilder::default()
227            .with_node_type(NodeType::Peer)
228            .build(client_kernel)
229            .unwrap();
230
231        let server = DefaultNodeBuilder::default()
232            .with_node_type(NodeType::Server(server_bind_addr))
233            .with_underlying_protocol(
234                ServerUnderlyingProtocol::from_tokio_tcp_listener(server_listener).unwrap(),
235            )
236            .build(server_kernel)
237            .unwrap();
238
239        let res = citadel_io::tokio::select! {
240            res0 = server => {
241                citadel_logging::info!(target: "citadel", "Server exited");
242                res0.map(|_|())
243            },
244
245            res1 = client => {
246                citadel_logging::info!(target: "citadel", "Client exited");
247                res1.map(|_|())
248            }
249        };
250
251        res.unwrap();
252
253        assert_eq!(success_count.load(Ordering::SeqCst), 2);
254    }
255
256    #[rstest]
257    #[timeout(Duration::from_secs(60))]
258    #[citadel_io::tokio::test]
259    async fn test_internal_service_http() {
260        setup_log();
261        let barrier = &TestBarrier::new(2);
262        let success_count = &AtomicUsize::new(0);
263        let server_listener = citadel_wire::socket_helpers::get_tcp_listener("0.0.0.0:0")
264            .expect("Failed to get TCP listener");
265        let server_bind_addr = server_listener.local_addr().unwrap();
266
267        let server_kernel = InternalServiceKernel::new(|internal_server_communicator| async move {
268            barrier.wait().await;
269
270            async fn hello(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
271                Ok(Response::new(Body::from("Hello World!")))
272            }
273
274            Http::new()
275                .serve_connection(internal_server_communicator, service_fn(hello))
276                .await
277                .map_err(from_hyper_error)?;
278
279            Ok(())
280        });
281
282        let server_connection_settings =
283            DefaultServerConnectionSettingsBuilder::transient(server_bind_addr)
284                .build()
285                .unwrap();
286
287        let client_kernel = SingleClientServerConnectionKernel::new(
288            server_connection_settings,
289            |connection| async move {
290                crate::prefabs::shared::internal_service::internal_service(
291                    connection,
292                    |internal_server_communicator| async move {
293                        barrier.wait().await;
294                        // wait for the server
295                        citadel_io::tokio::time::sleep(Duration::from_millis(500)).await;
296                        let (mut request_sender, connection) = Builder::new()
297                            .handshake(internal_server_communicator)
298                            .await
299                            .map_err(from_hyper_error)?;
300
301                        // spawn a task to poll the connection and drive the HTTP state
302                        drop(citadel_io::tokio::spawn(async move {
303                            if let Err(e) = connection.await {
304                                citadel_logging::error!(target: "citadel", "Error in connection: {e}");
305                                std::process::exit(-1);
306                            }
307                        }));
308
309                        // give time for task to spawn
310                        citadel_io::tokio::time::sleep(Duration::from_millis(100)).await;
311                        let request = Request::builder()
312                            // We need to manually add the host header because SendRequest does not
313                            .header("Host", "example.com")
314                            .method("GET")
315                            .body(Body::from(""))
316                            .map_err(|err| NetworkError::msg(format!("hyper error: {err}")))?;
317                        let response = request_sender.send_request(request).await.map_err(from_hyper_error)?;
318                        assert_eq!(response.status(), StatusCode::OK);
319
320                        let body_bytes = hyper::body::to_bytes(response.into_body()).await.map_err(from_hyper_error)?;
321                        assert_eq!(&body_bytes, b"Hello World!" as &[u8]);
322                        let _ = success_count.fetch_add(1, Ordering::SeqCst);
323
324                        // To send via the same connection again, it may not work as it may not be ready,
325                        // so we have to wait until the request_sender becomes ready. (requires tower)
326                        // request_sender.ready().await.map_err(from_hyper_error)?;
327                        Ok(())
328                    },
329                )
330                    .await
331            },
332        );
333
334        let client = DefaultNodeBuilder::default()
335            .with_node_type(NodeType::Peer)
336            .build(client_kernel)
337            .unwrap();
338
339        let server = DefaultNodeBuilder::default()
340            .with_node_type(NodeType::Server(server_bind_addr))
341            .with_underlying_protocol(
342                ServerUnderlyingProtocol::from_tokio_tcp_listener(server_listener).unwrap(),
343            )
344            .build(server_kernel)
345            .unwrap();
346
347        let res = citadel_io::tokio::select! {
348            res0 = server => {
349                citadel_logging::info!(target: "citadel", "Server exited");
350                res0.map(|_|())
351            },
352
353            res1 = client => {
354                citadel_logging::info!(target: "citadel", "Client exited");
355                res1.map(|_|())
356            }
357        };
358
359        res.unwrap();
360
361        assert_eq!(success_count.load(Ordering::SeqCst), 1);
362    }
363}