1use citadel_proto::prelude::*;
39
40use citadel_io::ServerMode;
41use citadel_proto::kernel::KernelExecutorArguments;
42use citadel_proto::macros::{ContextRequirements, LocalContextRequirements};
43use citadel_types::crypto::{HeaderObfuscatorSettings, PreSharedKey};
44use futures::Future;
45use std::fmt::{Debug, Formatter};
46use std::marker::PhantomData;
47use std::pin::Pin;
48use std::task::{Context, Poll};
49
50pub struct NodeBuilder<R: Ratchet = StackedRatchet, T: PlatformOps = DefaultTransport> {
52 hypernode_type: Option<NodeType>,
53 underlying_protocol: Option<ServerMode<T>>,
54 backend_type: Option<BackendType>,
55 server_argon_settings: Option<ArgonDefaultServerSettings>,
56 #[cfg(feature = "google-services")]
57 services: Option<ServicesConfig>,
58 server_misc_settings: Option<ServerMiscSettings>,
59 client_tls_config: Option<T::ClientConfig>,
60 kernel_executor_settings: Option<KernelExecutorSettings>,
61 stun_servers: Option<Vec<String>>,
62 turn_servers: Option<Vec<TurnServerConfig>>,
63 local_only_server_settings: Option<ServerOnlySessionInitSettings>,
64 websocket_listen_addr: Option<std::net::SocketAddr>,
65 #[cfg(target_family = "wasm")]
66 serverless_config: Option<ServerlessConfig>,
67 _ratchet: PhantomData<R>,
68 _transport: PhantomData<T>,
69}
70
71pub type DefaultNodeBuilder = NodeBuilder<StackedRatchet, DefaultTransport>;
73pub type LightweightNodeBuilder = NodeBuilder<MonoRatchet, DefaultTransport>;
75
76impl<R: Ratchet, T: PlatformOps> Default for NodeBuilder<R, T> {
77 fn default() -> Self {
78 Self {
79 hypernode_type: None,
80 underlying_protocol: None,
81 backend_type: None,
82 server_argon_settings: None,
83 #[cfg(feature = "google-services")]
84 services: None,
85 server_misc_settings: None,
86 client_tls_config: None,
87 kernel_executor_settings: None,
88 stun_servers: None,
89 turn_servers: None,
90 local_only_server_settings: None,
91 websocket_listen_addr: None,
92 #[cfg(target_family = "wasm")]
93 serverless_config: None,
94 _ratchet: Default::default(),
95 _transport: Default::default(),
96 }
97 }
98}
99
100pub struct NodeFuture<'a, K> {
102 inner: Pin<Box<dyn FutureContextRequirements<'a, Result<K, NetworkError>>>>,
103 _pd: PhantomData<fn() -> K>,
104}
105
106#[cfg(feature = "multi-threaded")]
107trait FutureContextRequirements<'a, Output>:
108 Future<Output = Output> + Send + LocalContextRequirements<'a>
109{
110}
111#[cfg(feature = "multi-threaded")]
112impl<'a, T: Future<Output = Output> + Send + LocalContextRequirements<'a>, Output>
113 FutureContextRequirements<'a, Output> for T
114{
115}
116
117#[cfg(not(feature = "multi-threaded"))]
118trait FutureContextRequirements<'a, Output>:
119 Future<Output = Output> + LocalContextRequirements<'a>
120{
121}
122#[cfg(not(feature = "multi-threaded"))]
123impl<'a, T: Future<Output = Output> + LocalContextRequirements<'a>, Output>
124 crate::builder::node_builder::FutureContextRequirements<'a, Output> for T
125{
126}
127
128impl<K> Debug for NodeFuture<'_, K> {
129 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
130 write!(f, "NodeFuture")
131 }
132}
133
134impl<K> Future for NodeFuture<'_, K> {
135 type Output = Result<K, NetworkError>;
136
137 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138 self.inner.as_mut().poll(cx)
139 }
140}
141
142impl<R: Ratchet + ContextRequirements, T: PlatformOps> NodeBuilder<R, T> {
143 pub fn build<'a, 'b: 'a, K: NetKernel<R> + 'b>(
145 &'a mut self,
146 kernel: K,
147 ) -> anyhow::Result<NodeFuture<'b, K>> {
148 self.check()?;
149 let hypernode_type = self.hypernode_type.take().unwrap_or_default();
150 let backend_type = self.backend_type.take().unwrap_or_default();
151 let server_argon_settings = self.server_argon_settings.take();
152 #[cfg(feature = "google-services")]
153 let server_services_cfg = self.services.take();
154 #[cfg(not(feature = "google-services"))]
155 let server_services_cfg = None;
156 let server_misc_settings = self.server_misc_settings.take();
157 let client_config = self.client_tls_config.take();
158 let kernel_executor_settings = self.kernel_executor_settings.take().unwrap_or_default();
159 let stun_servers = self.stun_servers.take();
160 let turn_servers = self.turn_servers.take();
161 let underlying_proto = self.underlying_protocol.take();
162 let server_only_session_init_settings = self.local_only_server_settings.take();
163 let websocket_listen_addr = self.websocket_listen_addr.take();
164 #[cfg(target_family = "wasm")]
165 let serverless_config = self.serverless_config.take();
166
167 Ok(NodeFuture {
168 _pd: Default::default(),
169 inner: Box::pin(async move {
170 let underlying_proto = match underlying_proto {
171 Some(proto) => proto,
172 None => T::default_server_config().await.map_err(|err| {
173 NetworkError::Generic(format!(
174 "Failed to create default server config: {err}"
175 ))
176 })?,
177 };
178
179 T::config_warnings(&underlying_proto);
180
181 #[cfg(target_family = "wasm")]
184 let (pre_built_listener, client_config, hypernode_type) = if let Some(sl_config) =
185 serverless_config
186 {
187 let conn = establish_serverless_connection(
188 sl_config.signaling.as_ref(),
189 &sl_config.room_token,
190 &sl_config.ice_servers,
191 sl_config.poll_interval_ms,
192 sl_config.timeout_ms,
193 )
194 .await
195 .map_err(|e: std::io::Error| NetworkError::Generic(e.to_string()))?;
196
197 T::setup_serverless_transport(conn.stream, conn.is_server_role, client_config)
198 } else {
199 (None, client_config, hypernode_type)
200 };
201
202 #[cfg(not(target_family = "wasm"))]
203 let pre_built_listener = None;
204
205 log::trace!(target: "citadel", "[NodeBuilder] Checking Tokio runtime ...");
206 let rt = citadel_io::try_current_runtime().map_err(NetworkError::Generic)?;
207 log::trace!(target: "citadel", "[NodeBuilder] Creating account manager ...");
208 let account_manager = AccountManager::new(
209 backend_type,
210 server_argon_settings,
211 server_services_cfg,
212 server_misc_settings,
213 )
214 .await?;
215
216 let args: KernelExecutorArguments<_, _, T> = KernelExecutorArguments {
217 rt,
218 hypernode_type,
219 account_manager,
220 kernel,
221 underlying_proto,
222 client_config,
223 kernel_executor_settings,
224 stun_servers,
225 turn_servers,
226 server_only_session_init_settings,
227 websocket_listen_addr,
228 pre_built_listener,
229 };
230
231 log::trace!(target: "citadel", "[NodeBuilder] Creating KernelExecutor ...");
232 let kernel_executor = KernelExecutor::<_, R>::new(args).await?;
233 log::trace!(target: "citadel", "[NodeBuilder] Executing kernel");
234 kernel_executor.execute().await
235 }),
236 })
237 }
238
239 pub fn with_node_type(&mut self, node_type: NodeType) -> &mut Self {
247 self.hypernode_type = Some(node_type);
248 self
249 }
250
251 pub fn with_backend(&mut self, backend_type: BackendType) -> &mut Self {
255 self.backend_type = Some(backend_type);
256 self
257 }
258
259 pub fn with_kernel_executor_settings(
261 &mut self,
262 kernel_executor_settings: KernelExecutorSettings,
263 ) -> &mut Self {
264 self.kernel_executor_settings = Some(kernel_executor_settings);
265 self
266 }
267
268 pub fn with_server_argon_settings(
270 &mut self,
271 settings: ArgonDefaultServerSettings,
272 ) -> &mut Self {
273 self.server_argon_settings = Some(settings);
274 self
275 }
276
277 #[cfg(feature = "google-services")]
279 pub fn with_google_services_json_path<V: Into<String>>(&mut self, path: V) -> &mut Self {
280 let cfg = self.get_or_create_services();
281 cfg.google_services_json_path = Some(path.into());
282 self
283 }
284
285 pub fn with_server_misc_settings(&mut self, misc_settings: ServerMiscSettings) -> &mut Self {
287 self.server_misc_settings = Some(misc_settings);
288 self
289 }
290
291 #[cfg(feature = "google-services")]
294 pub fn with_google_realtime_database_config<V: Into<String>, W: Into<String>>(
295 &mut self,
296 url: V,
297 api_key: W,
298 ) -> &mut Self {
299 let cfg = self.get_or_create_services();
300 cfg.google_rtdb = Some(RtdbConfig {
301 url: url.into(),
302 api_key: api_key.into(),
303 });
304 self
305 }
306
307 pub fn with_underlying_protocol(&mut self, proto: ServerMode<T>) -> &mut Self {
310 self.underlying_protocol = Some(proto);
311 self
312 }
313
314 pub fn with_client_config(&mut self, config: T::ClientConfig) -> &mut Self {
316 self.client_tls_config = Some(config);
317 self
318 }
319
320 #[cfg(feature = "google-services")]
321 fn get_or_create_services(&mut self) -> &mut ServicesConfig {
322 if self.services.is_some() {
323 self.services.as_mut().unwrap()
324 } else {
325 let cfg = ServicesConfig::default();
326 self.services = Some(cfg);
327 self.services.as_mut().unwrap()
328 }
329 }
330
331 pub fn with_stun_servers<V: Into<String>, S: Into<Vec<V>>>(&mut self, servers: S) -> &mut Self {
333 self.stun_servers = Some(servers.into().into_iter().map(|t| t.into()).collect());
334 self
335 }
336
337 pub fn with_turn_servers<S: Into<Vec<TurnServerConfig>>>(&mut self, servers: S) -> &mut Self {
343 self.turn_servers = Some(servers.into());
344 self
345 }
346
347 pub fn with_websocket_listener(&mut self, addr: std::net::SocketAddr) -> &mut Self {
355 self.websocket_listen_addr = Some(addr);
356 self
357 }
358
359 #[cfg(target_family = "wasm")]
366 pub fn with_no_central_server(&mut self, config: ServerlessConfig) -> &mut Self {
367 self.serverless_config = Some(config);
368 self
369 }
370
371 pub fn with_server_password<V: Into<PreSharedKey>>(&mut self, password: V) -> &mut Self {
376 let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
377 server_only_settings.declared_pre_shared_key = Some(password.into());
378 self.local_only_server_settings = Some(server_only_settings);
379 self
380 }
381
382 pub fn with_server_declared_header_obfuscation<V: Into<HeaderObfuscatorSettings>>(
384 &mut self,
385 header_obfuscator_settings: V,
386 ) -> &mut Self {
387 let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
388 server_only_settings.declared_header_obfuscation_setting =
389 header_obfuscator_settings.into();
390 self.local_only_server_settings = Some(server_only_settings);
391 self
392 }
393
394 fn check(&self) -> anyhow::Result<()> {
395 #[cfg(feature = "google-services")]
396 if let Some(svc) = self.services.as_ref() {
397 if svc.google_rtdb.is_some() && svc.google_services_json_path.is_none() {
398 return Err(anyhow::Error::msg(
399 "Google realtime database is enabled, yet, a services path is not provided",
400 ));
401 }
402 }
403
404 if let Some(stun_servers) = self.stun_servers.as_ref() {
405 if stun_servers.len() != 3 {
406 return Err(anyhow::Error::msg(
407 "There must be exactly 3 specified STUN servers",
408 ));
409 }
410 }
411
412 Ok(())
413 }
414}
415
416#[cfg(not(target_family = "wasm"))]
418impl<R: Ratchet + ContextRequirements> NodeBuilder<R, NativeIO> {
419 pub async fn with_native_certs(&mut self) -> anyhow::Result<&mut Self> {
423 let certs = citadel_proto::re_imports::load_native_certs_async().await?;
424 self.client_tls_config = Some(std::sync::Arc::new(
425 citadel_proto::re_imports::cert_vec_to_secure_client_config(&certs)?,
426 ));
427 Ok(self)
428 }
429
430 pub fn with_insecure_skip_cert_verification(&mut self) -> &mut Self {
433 self.client_tls_config = Some(std::sync::Arc::new(
434 citadel_proto::re_imports::insecure::rustls_client_config(),
435 ));
436 self
437 }
438
439 pub fn with_custom_certs<V: AsRef<[u8]>>(
442 &mut self,
443 custom_certs: &[V],
444 ) -> anyhow::Result<&mut Self> {
445 let cfg = citadel_proto::re_imports::create_rustls_client_config(custom_certs)?;
446 self.client_tls_config = Some(std::sync::Arc::new(cfg));
447 Ok(self)
448 }
449
450 #[cfg(feature = "std")]
452 pub async fn with_pem_file<P: AsRef<std::path::Path>>(
453 &mut self,
454 path: P,
455 ) -> anyhow::Result<&mut Self> {
456 use citadel_wire::exports::{Certificate, PemObject};
457 let mut der = std::io::Cursor::new(citadel_io::tokio::fs::read(path).await?);
458 let certs: Vec<Certificate<'static>> =
459 Certificate::pem_reader_iter(&mut der).collect::<Result<Vec<_>, _>>()?;
460 self.client_tls_config = Some(std::sync::Arc::new(
461 citadel_proto::re_imports::create_rustls_client_config(&certs)?,
462 ));
463 Ok(self)
464 }
465}
466
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use crate::builder::node_builder::DefaultNodeBuilder;
    use crate::prefabs::server::empty::EmptyKernel;
    use crate::prelude::{BackendType, NodeType};
    use citadel_io::tokio;
    use citadel_proto::prelude::{
        KernelExecutorSettings, NativeIO, NativeP2PConfig, NativeSecureConfig, ServerMode,
    };
    use rstest::rstest;
    use std::str::FromStr;

    /// A realtime database config accompanied by a services JSON path passes validation.
    #[test]
    #[cfg(feature = "google-services")]
    fn okay_config() {
        let _ = DefaultNodeBuilder::default()
            .with_google_realtime_database_config("123", "456")
            .with_google_services_json_path("abc")
            .build(EmptyKernel::default())
            .unwrap();
    }

    /// A realtime database config without a services JSON path is rejected by `build`.
    #[test]
    #[cfg(feature = "google-services")]
    fn bad_config() {
        assert!(DefaultNodeBuilder::default()
            .with_google_realtime_database_config("123", "456")
            .build(EmptyKernel::default())
            .is_err());
    }

    /// Fewer than three STUN servers is rejected by `build`.
    #[test]
    fn bad_config2() {
        assert!(DefaultNodeBuilder::default()
            .with_stun_servers(["dummy1", "dummy2"])
            .build(EmptyKernel::default())
            .is_err());
    }

    /// More than three STUN servers is rejected as well, while exactly three is accepted.
    /// `build` only constructs the future, so nothing is polled here.
    #[test]
    fn stun_server_count_must_be_exactly_three() {
        assert!(DefaultNodeBuilder::default()
            .with_stun_servers(["dummy1", "dummy2", "dummy3", "dummy4"])
            .build(EmptyKernel::default())
            .is_err());
        assert!(DefaultNodeBuilder::default()
            .with_stun_servers(["dummy1", "dummy2", "dummy3"])
            .build(EmptyKernel::default())
            .is_ok());
    }

    /// Every combination of protocol, node type, executor settings and backend is stored by
    /// the setters and accepted by `build`.
    #[rstest]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(60))]
    // The `&mut Self` returned at the end of the setter chain is intentionally discarded; the
    // builder itself is inspected below.
    #[allow(clippy::let_underscore_must_use)]
    async fn test_options(
        #[values(
            ServerMode::P2P(NativeP2PConfig::self_signed()),
            ServerMode::OrderedReliableSecure(NativeSecureConfig::self_signed().unwrap())
        )]
        underlying_protocol: ServerMode<NativeIO>,
        #[values(
            NodeType::Peer,
            NodeType::Server(std::net::SocketAddr::from_str("127.0.0.1:9999").unwrap())
        )]
        node_type: NodeType,
        #[values(
            KernelExecutorSettings::default(),
            KernelExecutorSettings::default().with_max_concurrency(2)
        )]
        kernel_settings: KernelExecutorSettings,
        #[values(BackendType::InMemory, BackendType::new("file:/hello_world/path/").unwrap())]
        backend_type: BackendType,
    ) {
        let mut builder = DefaultNodeBuilder::default();
        let _ = builder
            .with_underlying_protocol(underlying_protocol.clone())
            .with_backend(backend_type.clone())
            .with_node_type(node_type)
            .with_kernel_executor_settings(kernel_settings.clone())
            .with_insecure_skip_cert_verification()
            // Three distinct entries: `build` requires exactly three STUN servers.
            .with_stun_servers(["dummy1", "dummy2", "dummy3"])
            .with_native_certs()
            .await
            .unwrap();

        assert!(builder.underlying_protocol.is_some());
        assert_eq!(backend_type, builder.backend_type.clone().unwrap());
        assert_eq!(node_type, builder.hypernode_type.unwrap());
        assert_eq!(
            kernel_settings,
            builder.kernel_executor_settings.clone().unwrap()
        );

        drop(builder.build(EmptyKernel::default()).unwrap());
    }
}