1use citadel_proto::prelude::*;
39
40use citadel_proto::kernel::KernelExecutorArguments;
41use citadel_proto::macros::{ContextRequirements, LocalContextRequirements};
42use citadel_proto::re_imports::RustlsClientConfig;
43use citadel_types::crypto::{HeaderObfuscatorSettings, PreSharedKey};
44use futures::Future;
45use std::fmt::{Debug, Formatter};
46use std::marker::PhantomData;
47use std::path::Path;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::task::{Context, Poll};
51
/// Builder for a Citadel node.
///
/// Configure the node with the `with_*` methods, then call `build` with a [`NetKernel`] to
/// obtain a [`NodeFuture`] that sets up and runs the node when awaited. `R` selects the
/// ratchet type and defaults to [`StackedRatchet`]. Every option starts out unset (`None`);
/// `build` substitutes defaults for some of them.
pub struct NodeBuilder<R: Ratchet = StackedRatchet> {
    /// Role of the node (e.g. `NodeType::Peer` or `NodeType::Server(addr)`); `None` means
    /// `NodeType::default()` at build time.
    hypernode_type: Option<NodeType>,
    /// Underlying transport; `None` means a self-signed TLS configuration at build time.
    underlying_protocol: Option<ServerUnderlyingProtocol>,
    /// Backend handed to the account manager; `None` means a generated filesystem path (with
    /// the `filesystem` feature) or an in-memory backend at build time.
    backend_type: Option<BackendType>,
    /// Server-side Argon settings handed to the account manager.
    server_argon_settings: Option<ArgonDefaultServerSettings>,
    /// Google services configuration (only present with the `google-services` feature).
    #[cfg(feature = "google-services")]
    services: Option<ServicesConfig>,
    /// Miscellaneous server settings handed to the account manager.
    server_misc_settings: Option<ServerMiscSettings>,
    /// Client-side rustls configuration; wrapped in an `Arc` and handed to the kernel
    /// executor as `client_config` at build time.
    client_tls_config: Option<RustlsClientConfig>,
    /// Kernel executor settings; `None` means `KernelExecutorSettings::default()`.
    kernel_executor_settings: Option<KernelExecutorSettings>,
    /// STUN server addresses; when set, validation requires exactly three entries.
    stun_servers: Option<Vec<String>>,
    /// Settings that only apply when this node acts as a server (declared pre-shared key,
    /// declared header obfuscation).
    local_only_server_settings: Option<ServerOnlySessionInitSettings>,
    /// Records the ratchet type parameter without storing a ratchet.
    _ratchet: PhantomData<R>,
}
67
/// A [`NodeBuilder`] parameterised with [`StackedRatchet`], which is also the default
/// ratchet type parameter of `NodeBuilder`.
pub type DefaultNodeBuilder = NodeBuilder<StackedRatchet>;

/// A [`NodeBuilder`] parameterised with [`MonoRatchet`] (presumably the lighter-weight
/// ratchet, as the alias name suggests — see `MonoRatchet` for details).
pub type LightweightNodeBuilder = NodeBuilder<MonoRatchet>;
72
73impl<R: Ratchet> Default for NodeBuilder<R> {
74 fn default() -> Self {
75 Self {
76 hypernode_type: None,
77 underlying_protocol: None,
78 backend_type: None,
79 server_argon_settings: None,
80 #[cfg(feature = "google-services")]
81 services: None,
82 server_misc_settings: None,
83 client_tls_config: None,
84 kernel_executor_settings: None,
85 stun_servers: None,
86 local_only_server_settings: None,
87 _ratchet: Default::default(),
88 }
89 }
90}
91
/// Future returned by `NodeBuilder::build`.
///
/// When awaited, it creates the account manager and kernel executor, then resolves to the
/// result of executing the kernel (or to the first error encountered along the way).
pub struct NodeFuture<'a, K> {
    /// The boxed setup-and-execute future produced by `build`.
    inner: Pin<Box<dyn FutureContextRequirements<'a, Result<K, NetworkError>>>>,
    /// Ties the output type `K` to this future without owning a `K`; `fn() -> K` keeps the
    /// marker free of any ownership or auto-trait implications for `K`.
    _pd: PhantomData<fn() -> K>,
}
97
98#[cfg(feature = "multi-threaded")]
99trait FutureContextRequirements<'a, Output>:
100 Future<Output = Output> + Send + LocalContextRequirements<'a>
101{
102}
103#[cfg(feature = "multi-threaded")]
104impl<'a, T: Future<Output = Output> + Send + LocalContextRequirements<'a>, Output>
105 FutureContextRequirements<'a, Output> for T
106{
107}
108
109#[cfg(not(feature = "multi-threaded"))]
110trait FutureContextRequirements<'a, Output>:
111 Future<Output = Output> + LocalContextRequirements<'a>
112{
113}
114#[cfg(not(feature = "multi-threaded"))]
115impl<'a, T: Future<Output = Output> + LocalContextRequirements<'a>, Output>
116 crate::builder::node_builder::FutureContextRequirements<'a, Output> for T
117{
118}
119
120impl<K> Debug for NodeFuture<'_, K> {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 write!(f, "NodeFuture")
123 }
124}
125
126impl<K> Future for NodeFuture<'_, K> {
127 type Output = Result<K, NetworkError>;
128
129 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130 self.inner.as_mut().poll(cx)
131 }
132}
133
134impl<R: Ratchet + ContextRequirements> NodeBuilder<R> {
135 pub fn build<'a, 'b: 'a, K: NetKernel<R> + 'b>(
137 &'a mut self,
138 kernel: K,
139 ) -> anyhow::Result<NodeFuture<'b, K>> {
140 self.check()?;
141 let hypernode_type = self.hypernode_type.take().unwrap_or_default();
142 let backend_type = self.backend_type.take().unwrap_or_else(|| {
143 if cfg!(feature = "filesystem") {
144 let mut home_dir = dirs2::home_dir().unwrap();
146 home_dir.push(format!(".citadel/{}", uuid::Uuid::new_v4().as_u128()));
147 return BackendType::Filesystem(home_dir.to_str().unwrap().to_string());
148 }
149
150 BackendType::InMemory
151 });
152 let server_argon_settings = self.server_argon_settings.take();
153 #[cfg(feature = "google-services")]
154 let server_services_cfg = self.services.take();
155 #[cfg(not(feature = "google-services"))]
156 let server_services_cfg = None;
157 let server_misc_settings = self.server_misc_settings.take();
158 let client_config = self.client_tls_config.take().map(Arc::new);
159 let kernel_executor_settings = self.kernel_executor_settings.take().unwrap_or_default();
160 let stun_servers = self.stun_servers.take();
161
162 let underlying_proto = if let Some(proto) = self.underlying_protocol.take() {
163 proto
164 } else {
165 ServerUnderlyingProtocol::new_tls_self_signed()
167 .map_err(|err| anyhow::Error::msg(err.into_string()))?
168 };
169
170 if matches!(underlying_proto, ServerUnderlyingProtocol::Tcp(..)) {
171 citadel_logging::warn!(target: "citadel", "⚠️ WARNING ⚠️ TCP is discouraged for production use until The Citadel Protocol has been reviewed. Use TLS automatically by not changing the underlying protocol");
172 }
173
174 let server_only_session_init_settings = self.local_only_server_settings.take();
175
176 Ok(NodeFuture {
177 _pd: Default::default(),
178 inner: Box::pin(async move {
179 log::trace!(target: "citadel", "[NodeBuilder] Checking Tokio runtime ...");
180 let rt = citadel_io::tokio::runtime::Handle::try_current()
181 .map_err(|err| NetworkError::Generic(err.to_string()))?;
182 log::trace!(target: "citadel", "[NodeBuilder] Creating account manager ...");
183 let account_manager = AccountManager::new(
184 backend_type,
185 server_argon_settings,
186 server_services_cfg,
187 server_misc_settings,
188 )
189 .await?;
190
191 let args = KernelExecutorArguments {
192 rt,
193 hypernode_type,
194 account_manager,
195 kernel,
196 underlying_proto,
197 client_config,
198 kernel_executor_settings,
199 stun_servers,
200 server_only_session_init_settings,
201 };
202
203 log::trace!(target: "citadel", "[NodeBuilder] Creating KernelExecutor ...");
204 let kernel_executor = KernelExecutor::<_, R>::new(args).await?;
205 log::trace!(target: "citadel", "[NodeBuilder] Executing kernel");
206 kernel_executor.execute().await
207 }),
208 })
209 }
210
211 pub fn with_node_type(&mut self, node_type: NodeType) -> &mut Self {
219 self.hypernode_type = Some(node_type);
220 self
221 }
222
223 pub fn with_backend(&mut self, backend_type: BackendType) -> &mut Self {
227 self.backend_type = Some(backend_type);
228 self
229 }
230
231 pub fn with_kernel_executor_settings(
233 &mut self,
234 kernel_executor_settings: KernelExecutorSettings,
235 ) -> &mut Self {
236 self.kernel_executor_settings = Some(kernel_executor_settings);
237 self
238 }
239
240 pub fn with_server_argon_settings(
242 &mut self,
243 settings: ArgonDefaultServerSettings,
244 ) -> &mut Self {
245 self.server_argon_settings = Some(settings);
246 self
247 }
248
249 #[cfg(feature = "google-services")]
251 pub fn with_google_services_json_path<T: Into<String>>(&mut self, path: T) -> &mut Self {
252 let cfg = self.get_or_create_services();
253 cfg.google_services_json_path = Some(path.into());
254 self
255 }
256
257 pub fn with_server_misc_settings(&mut self, misc_settings: ServerMiscSettings) -> &mut Self {
259 self.server_misc_settings = Some(misc_settings);
260 self
261 }
262
263 #[cfg(feature = "google-services")]
266 pub fn with_google_realtime_database_config<T: Into<String>, R: Into<String>>(
267 &mut self,
268 url: T,
269 api_key: R,
270 ) -> &mut Self {
271 let cfg = self.get_or_create_services();
272 cfg.google_rtdb = Some(RtdbConfig {
273 url: url.into(),
274 api_key: api_key.into(),
275 });
276 self
277 }
278
279 pub fn with_underlying_protocol(&mut self, proto: ServerUnderlyingProtocol) -> &mut Self {
282 self.underlying_protocol = Some(proto);
283 self
284 }
285
286 #[cfg(feature = "google-services")]
287 fn get_or_create_services(&mut self) -> &mut ServicesConfig {
288 if self.services.is_some() {
289 self.services.as_mut().unwrap()
290 } else {
291 let cfg = ServicesConfig::default();
292 self.services = Some(cfg);
293 self.services.as_mut().unwrap()
294 }
295 }
296
297 pub async fn with_native_certs(&mut self) -> anyhow::Result<&mut Self> {
301 let certs = citadel_proto::re_imports::load_native_certs_async().await?;
302 self.client_tls_config = Some(citadel_proto::re_imports::cert_vec_to_secure_client_config(
303 &certs,
304 )?);
305 Ok(self)
306 }
307
308 pub fn with_insecure_skip_cert_verification(&mut self) -> &mut Self {
311 self.client_tls_config = Some(citadel_proto::re_imports::insecure::rustls_client_config());
312 self
313 }
314
315 pub fn with_custom_certs<T: AsRef<[u8]>>(
318 &mut self,
319 custom_certs: &[T],
320 ) -> anyhow::Result<&mut Self> {
321 let cfg = citadel_proto::re_imports::create_rustls_client_config(custom_certs)?;
322 self.client_tls_config = Some(cfg);
323 Ok(self)
324 }
325
326 #[cfg(feature = "std")]
328 pub async fn with_pem_file<P: AsRef<Path>>(&mut self, path: P) -> anyhow::Result<&mut Self> {
329 let mut der = std::io::Cursor::new(citadel_io::tokio::fs::read(path).await?);
330 let certs = citadel_proto::re_imports::rustls_pemfile::certs(&mut der).collect::<Vec<_>>();
331 let mut filtered_certs = Vec::new();
333 for cert in certs {
334 filtered_certs.push(cert?);
335 }
336 self.client_tls_config = Some(citadel_proto::re_imports::create_rustls_client_config(
337 &filtered_certs,
338 )?);
339 Ok(self)
340 }
341
342 pub fn with_stun_servers<T: Into<String>, S: Into<Vec<T>>>(&mut self, servers: S) -> &mut Self {
344 self.stun_servers = Some(servers.into().into_iter().map(|t| t.into()).collect());
345 self
346 }
347
348 pub fn with_server_password<T: Into<PreSharedKey>>(&mut self, password: T) -> &mut Self {
353 let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
354 server_only_settings.declared_pre_shared_key = Some(password.into());
355 self.local_only_server_settings = Some(server_only_settings);
356 self
357 }
358
359 pub fn with_server_declared_header_obfuscation<T: Into<HeaderObfuscatorSettings>>(
361 &mut self,
362 header_obfuscator_settings: T,
363 ) -> &mut Self {
364 let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
365 server_only_settings.declared_header_obfuscation_setting =
366 header_obfuscator_settings.into();
367 self.local_only_server_settings = Some(server_only_settings);
368 self
369 }
370
371 fn check(&self) -> anyhow::Result<()> {
372 #[cfg(feature = "google-services")]
373 if let Some(svc) = self.services.as_ref() {
374 if svc.google_rtdb.is_some() && svc.google_services_json_path.is_none() {
375 return Err(anyhow::Error::msg(
376 "Google realtime database is enabled, yet, a services path is not provided",
377 ));
378 }
379 }
380
381 if let Some(stun_servers) = self.stun_servers.as_ref() {
382 if stun_servers.len() != 3 {
383 return Err(anyhow::Error::msg(
384 "There must be exactly 3 specified STUN servers",
385 ));
386 }
387 }
388
389 Ok(())
390 }
391}
392
#[cfg(test)]
mod tests {
    use crate::builder::node_builder::DefaultNodeBuilder;
    use crate::prefabs::server::empty::EmptyKernel;
    use crate::prelude::{BackendType, NodeType};
    use citadel_io::tokio;
    use citadel_proto::prelude::{KernelExecutorSettings, ServerUnderlyingProtocol};
    use rstest::rstest;
    use std::str::FromStr;

    /// A realtime database config accompanied by a services JSON path passes validation.
    #[test]
    #[cfg(feature = "google-services")]
    fn okay_config() {
        let _ = DefaultNodeBuilder::default()
            .with_google_realtime_database_config("123", "456")
            .with_google_services_json_path("abc")
            .build(EmptyKernel::default())
            .unwrap();
    }

    /// A realtime database config without a services JSON path is rejected by `build`.
    #[test]
    #[cfg(feature = "google-services")]
    fn bad_config() {
        assert!(DefaultNodeBuilder::default()
            .with_google_realtime_database_config("123", "456")
            .build(EmptyKernel::default())
            .is_err());
    }

    /// Fewer than three STUN servers is rejected by `build`.
    #[test]
    fn bad_config2() {
        assert!(DefaultNodeBuilder::default()
            .with_stun_servers(["dummy1", "dummy2"])
            .build(EmptyKernel::default())
            .is_err());
    }

    /// More than three STUN servers is rejected as well: the count must be exactly three.
    #[test]
    fn bad_config_too_many_stun_servers() {
        assert!(DefaultNodeBuilder::default()
            .with_stun_servers(["dummy1", "dummy2", "dummy3", "dummy4"])
            .build(EmptyKernel::default())
            .is_err());
    }

    /// Every option set on the builder is stored as given, and `build` succeeds for each
    /// combination of protocol, node type, executor settings and backend.
    #[rstest]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(60))]
    // The `&mut Self` returned by the builder chain is intentionally discarded via `let _`.
    #[allow(clippy::let_underscore_must_use)]
    async fn test_options(
        #[values(
            ServerUnderlyingProtocol::new_quic_self_signed(),
            ServerUnderlyingProtocol::new_tls_self_signed().unwrap()
        )]
        underlying_protocol: ServerUnderlyingProtocol,
        #[values(
            NodeType::Peer,
            NodeType::Server(std::net::SocketAddr::from_str("127.0.0.1:9999").unwrap())
        )]
        node_type: NodeType,
        #[values(
            KernelExecutorSettings::default(),
            KernelExecutorSettings::default().with_max_concurrency(2)
        )]
        kernel_settings: KernelExecutorSettings,
        #[values(BackendType::InMemory, BackendType::new("file:/hello_world/path/").unwrap())]
        backend_type: BackendType,
    ) {
        let mut builder = DefaultNodeBuilder::default();
        let _ = builder
            .with_underlying_protocol(underlying_protocol.clone())
            .with_backend(backend_type.clone())
            .with_node_type(node_type)
            .with_kernel_executor_settings(kernel_settings.clone())
            .with_insecure_skip_cert_verification()
            .with_stun_servers(["dummy1", "dummy2", "dummy3"])
            .with_native_certs()
            .await
            .unwrap();

        assert!(builder.underlying_protocol.is_some());
        assert_eq!(backend_type, builder.backend_type.clone().unwrap());
        assert_eq!(node_type, builder.hypernode_type.unwrap());
        assert_eq!(
            kernel_settings,
            builder.kernel_executor_settings.clone().unwrap()
        );

        drop(builder.build(EmptyKernel::default()).unwrap());
    }
}