1use citadel_proto::prelude::*;
39
40use citadel_proto::kernel::KernelExecutorArguments;
41use citadel_proto::macros::{ContextRequirements, LocalContextRequirements};
42use citadel_proto::re_imports::RustlsClientConfig;
43use citadel_types::crypto::{HeaderObfuscatorSettings, PreSharedKey};
44use futures::Future;
45use std::fmt::{Debug, Formatter};
46use std::marker::PhantomData;
47use std::path::Path;
48use std::pin::Pin;
49use std::sync::Arc;
50use std::task::{Context, Poll};
51
52pub struct NodeBuilder<R: Ratchet = StackedRatchet> {
54 hypernode_type: Option<NodeType>,
55 underlying_protocol: Option<ServerUnderlyingProtocol>,
56 backend_type: Option<BackendType>,
57 server_argon_settings: Option<ArgonDefaultServerSettings>,
58 #[cfg(feature = "google-services")]
59 services: Option<ServicesConfig>,
60 server_misc_settings: Option<ServerMiscSettings>,
61 client_tls_config: Option<RustlsClientConfig>,
62 kernel_executor_settings: Option<KernelExecutorSettings>,
63 stun_servers: Option<Vec<String>>,
64 local_only_server_settings: Option<ServerOnlySessionInitSettings>,
65 _ratchet: PhantomData<R>,
66}
67
68pub type DefaultNodeBuilder = NodeBuilder<StackedRatchet>;
70
71pub type LightweightNodeBuilder = NodeBuilder<MonoRatchet>;
72
73impl<R: Ratchet> Default for NodeBuilder<R> {
74 fn default() -> Self {
75 Self {
76 hypernode_type: None,
77 underlying_protocol: None,
78 backend_type: None,
79 server_argon_settings: None,
80 #[cfg(feature = "google-services")]
81 services: None,
82 server_misc_settings: None,
83 client_tls_config: None,
84 kernel_executor_settings: None,
85 stun_servers: None,
86 local_only_server_settings: None,
87 _ratchet: Default::default(),
88 }
89 }
90}
91
92pub struct NodeFuture<'a, K> {
94 inner: Pin<Box<dyn FutureContextRequirements<'a, Result<K, NetworkError>>>>,
95 _pd: PhantomData<fn() -> K>,
96}
97
98#[cfg(feature = "multi-threaded")]
99trait FutureContextRequirements<'a, Output>:
100 Future<Output = Output> + Send + LocalContextRequirements<'a>
101{
102}
103#[cfg(feature = "multi-threaded")]
104impl<'a, T: Future<Output = Output> + Send + LocalContextRequirements<'a>, Output>
105 FutureContextRequirements<'a, Output> for T
106{
107}
108
109#[cfg(not(feature = "multi-threaded"))]
110trait FutureContextRequirements<'a, Output>:
111 Future<Output = Output> + LocalContextRequirements<'a>
112{
113}
114#[cfg(not(feature = "multi-threaded"))]
115impl<'a, T: Future<Output = Output> + LocalContextRequirements<'a>, Output>
116 crate::builder::node_builder::FutureContextRequirements<'a, Output> for T
117{
118}
119
120impl<K> Debug for NodeFuture<'_, K> {
121 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122 write!(f, "NodeFuture")
123 }
124}
125
126impl<K> Future for NodeFuture<'_, K> {
127 type Output = Result<K, NetworkError>;
128
129 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130 self.inner.as_mut().poll(cx)
131 }
132}
133
134impl<R: Ratchet + ContextRequirements> NodeBuilder<R> {
135 pub fn build<'a, 'b: 'a, K: NetKernel<R> + 'b>(
137 &'a mut self,
138 kernel: K,
139 ) -> anyhow::Result<NodeFuture<'b, K>> {
140 self.check()?;
141 let hypernode_type = self.hypernode_type.take().unwrap_or_default();
142 let backend_type = self.backend_type.take().unwrap_or_else(|| {
143 if cfg!(feature = "filesystem") {
144 let mut home_dir = dirs2::home_dir().unwrap();
146 home_dir.push(format!(".citadel/{}", uuid::Uuid::new_v4().as_u128()));
147 return BackendType::Filesystem(home_dir.to_str().unwrap().to_string());
148 }
149
150 BackendType::InMemory
151 });
152 let server_argon_settings = self.server_argon_settings.take();
153 #[cfg(feature = "google-services")]
154 let server_services_cfg = self.services.take();
155 #[cfg(not(feature = "google-services"))]
156 let server_services_cfg = None;
157 let server_misc_settings = self.server_misc_settings.take();
158 let client_config = self.client_tls_config.take().map(Arc::new);
159 let kernel_executor_settings = self.kernel_executor_settings.take().unwrap_or_default();
160 let stun_servers = self.stun_servers.take();
161
162 let underlying_proto = if let Some(proto) = self.underlying_protocol.take() {
163 proto
164 } else {
165 ServerUnderlyingProtocol::new_tls_self_signed()
167 .map_err(|err| anyhow::Error::msg(err.into_string()))?
168 };
169
170 if matches!(underlying_proto, ServerUnderlyingProtocol::Tcp(..)) {
171 citadel_logging::warn!(target: "citadel", "⚠️ WARNING ⚠️ TCP is discouraged for production use until The Citadel Protocol has been reviewed. Use TLS automatically by not changing the underlying protocol");
172 }
173
174 let server_only_session_init_settings = self.local_only_server_settings.take();
175
176 Ok(NodeFuture {
177 _pd: Default::default(),
178 inner: Box::pin(async move {
179 log::trace!(target: "citadel", "[NodeBuilder] Checking Tokio runtime ...");
180 let rt = citadel_io::tokio::runtime::Handle::try_current()
181 .map_err(|err| NetworkError::Generic(err.to_string()))?;
182 log::trace!(target: "citadel", "[NodeBuilder] Creating account manager ...");
183 let account_manager = AccountManager::new(
184 backend_type,
185 server_argon_settings,
186 server_services_cfg,
187 server_misc_settings,
188 )
189 .await?;
190
191 let args = KernelExecutorArguments {
192 rt,
193 hypernode_type,
194 account_manager,
195 kernel,
196 underlying_proto,
197 client_config,
198 kernel_executor_settings,
199 stun_servers,
200 server_only_session_init_settings,
201 };
202
203 log::trace!(target: "citadel", "[NodeBuilder] Creating KernelExecutor ...");
204 let kernel_executor = KernelExecutor::<_, R>::new(args).await?;
205 log::trace!(target: "citadel", "[NodeBuilder] Executing kernel");
206 kernel_executor.execute().await
207 }),
208 })
209 }
210
211 pub fn with_node_type(&mut self, node_type: NodeType) -> &mut Self {
219 self.hypernode_type = Some(node_type);
220 self
221 }
222
223 pub fn with_backend(&mut self, backend_type: BackendType) -> &mut Self {
227 self.backend_type = Some(backend_type);
228 self
229 }
230
231 pub fn with_kernel_executor_settings(
233 &mut self,
234 kernel_executor_settings: KernelExecutorSettings,
235 ) -> &mut Self {
236 self.kernel_executor_settings = Some(kernel_executor_settings);
237 self
238 }
239
240 pub fn with_server_argon_settings(
242 &mut self,
243 settings: ArgonDefaultServerSettings,
244 ) -> &mut Self {
245 self.server_argon_settings = Some(settings);
246 self
247 }
248
249 #[cfg(feature = "google-services")]
251 pub fn with_google_services_json_path<T: Into<String>>(&mut self, path: T) -> &mut Self {
252 let cfg = self.get_or_create_services();
253 cfg.google_services_json_path = Some(path.into());
254 self
255 }
256
257 pub fn with_server_misc_settings(&mut self, misc_settings: ServerMiscSettings) -> &mut Self {
259 self.server_misc_settings = Some(misc_settings);
260 self
261 }
262
263 #[cfg(feature = "google-services")]
266 pub fn with_google_realtime_database_config<T: Into<String>, R: Into<String>>(
267 &mut self,
268 url: T,
269 api_key: R,
270 ) -> &mut Self {
271 let cfg = self.get_or_create_services();
272 cfg.google_rtdb = Some(RtdbConfig {
273 url: url.into(),
274 api_key: api_key.into(),
275 });
276 self
277 }
278
279 pub fn with_underlying_protocol(&mut self, proto: ServerUnderlyingProtocol) -> &mut Self {
282 self.underlying_protocol = Some(proto);
283 self
284 }
285
286 #[cfg(feature = "google-services")]
287 fn get_or_create_services(&mut self) -> &mut ServicesConfig {
288 if self.services.is_some() {
289 self.services.as_mut().unwrap()
290 } else {
291 let cfg = ServicesConfig::default();
292 self.services = Some(cfg);
293 self.services.as_mut().unwrap()
294 }
295 }
296
297 pub async fn with_native_certs(&mut self) -> anyhow::Result<&mut Self> {
301 let certs = citadel_proto::re_imports::load_native_certs_async().await?;
302 self.client_tls_config = Some(citadel_proto::re_imports::cert_vec_to_secure_client_config(
303 &certs,
304 )?);
305 Ok(self)
306 }
307
308 pub fn with_insecure_skip_cert_verification(&mut self) -> &mut Self {
311 self.client_tls_config = Some(citadel_proto::re_imports::insecure::rustls_client_config());
312 self
313 }
314
315 pub fn with_custom_certs<T: AsRef<[u8]>>(
318 &mut self,
319 custom_certs: &[T],
320 ) -> anyhow::Result<&mut Self> {
321 let cfg = citadel_proto::re_imports::create_rustls_client_config(custom_certs)?;
322 self.client_tls_config = Some(cfg);
323 Ok(self)
324 }
325
326 #[cfg(feature = "std")]
328 pub async fn with_pem_file<P: AsRef<Path>>(&mut self, path: P) -> anyhow::Result<&mut Self> {
329 use citadel_wire::exports::{Certificate, PemObject};
330 let mut der = std::io::Cursor::new(citadel_io::tokio::fs::read(path).await?);
331 let certs: Vec<Certificate<'static>> =
332 Certificate::pem_reader_iter(&mut der).collect::<Result<Vec<_>, _>>()?;
333 self.client_tls_config = Some(citadel_proto::re_imports::create_rustls_client_config(
334 &certs,
335 )?);
336 Ok(self)
337 }
338
339 pub fn with_stun_servers<T: Into<String>, S: Into<Vec<T>>>(&mut self, servers: S) -> &mut Self {
341 self.stun_servers = Some(servers.into().into_iter().map(|t| t.into()).collect());
342 self
343 }
344
345 pub fn with_server_password<T: Into<PreSharedKey>>(&mut self, password: T) -> &mut Self {
350 let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
351 server_only_settings.declared_pre_shared_key = Some(password.into());
352 self.local_only_server_settings = Some(server_only_settings);
353 self
354 }
355
356 pub fn with_server_declared_header_obfuscation<T: Into<HeaderObfuscatorSettings>>(
358 &mut self,
359 header_obfuscator_settings: T,
360 ) -> &mut Self {
361 let mut server_only_settings = self.local_only_server_settings.clone().unwrap_or_default();
362 server_only_settings.declared_header_obfuscation_setting =
363 header_obfuscator_settings.into();
364 self.local_only_server_settings = Some(server_only_settings);
365 self
366 }
367
368 fn check(&self) -> anyhow::Result<()> {
369 #[cfg(feature = "google-services")]
370 if let Some(svc) = self.services.as_ref() {
371 if svc.google_rtdb.is_some() && svc.google_services_json_path.is_none() {
372 return Err(anyhow::Error::msg(
373 "Google realtime database is enabled, yet, a services path is not provided",
374 ));
375 }
376 }
377
378 if let Some(stun_servers) = self.stun_servers.as_ref() {
379 if stun_servers.len() != 3 {
380 return Err(anyhow::Error::msg(
381 "There must be exactly 3 specified STUN servers",
382 ));
383 }
384 }
385
386 Ok(())
387 }
388}
389
#[cfg(test)]
mod tests {
    use crate::builder::node_builder::DefaultNodeBuilder;
    use crate::prefabs::server::empty::EmptyKernel;
    use crate::prelude::{BackendType, NodeType};
    use citadel_io::tokio;
    use citadel_proto::prelude::{KernelExecutorSettings, ServerUnderlyingProtocol};
    use rstest::rstest;
    use std::str::FromStr;

    /// A realtime-database config paired with a services JSON path passes validation.
    #[test]
    #[cfg(feature = "google-services")]
    fn okay_config() {
        let mut builder = DefaultNodeBuilder::default();
        builder
            .with_google_realtime_database_config("123", "456")
            .with_google_services_json_path("abc");
        let _ = builder.build(EmptyKernel::default()).unwrap();
    }

    /// A realtime-database config without a services JSON path is rejected.
    #[test]
    #[cfg(feature = "google-services")]
    fn bad_config() {
        let mut builder = DefaultNodeBuilder::default();
        builder.with_google_realtime_database_config("123", "456");
        let outcome = builder.build(EmptyKernel::default());
        assert!(outcome.is_err());
    }

    /// Supplying a STUN server count other than three is rejected.
    #[test]
    fn bad_config2() {
        let mut builder = DefaultNodeBuilder::default();
        builder.with_stun_servers(["dummy1", "dummy2"]);
        let outcome = builder.build(EmptyKernel::default());
        assert!(outcome.is_err());
    }

    /// Every combination of options is stored verbatim and still builds successfully.
    #[rstest]
    #[tokio::test]
    #[timeout(std::time::Duration::from_secs(60))]
    #[allow(clippy::let_underscore_must_use)]
    async fn test_options(
        #[values(
            ServerUnderlyingProtocol::new_quic_self_signed(),
            ServerUnderlyingProtocol::new_tls_self_signed().unwrap()
        )]
        underlying_protocol: ServerUnderlyingProtocol,
        #[values(
            NodeType::Peer,
            NodeType::Server(std::net::SocketAddr::from_str("127.0.0.1:9999").unwrap())
        )]
        node_type: NodeType,
        #[values(
            KernelExecutorSettings::default(),
            KernelExecutorSettings::default().with_max_concurrency(2)
        )]
        kernel_settings: KernelExecutorSettings,
        #[values(BackendType::InMemory, BackendType::new("file:/hello_world/path/").unwrap())]
        backend_type: BackendType,
    ) {
        let mut builder = DefaultNodeBuilder::default();
        let _ = builder
            .with_underlying_protocol(underlying_protocol)
            .with_backend(backend_type.clone())
            .with_node_type(node_type)
            .with_kernel_executor_settings(kernel_settings.clone())
            .with_insecure_skip_cert_verification()
            .with_stun_servers(["dummy1", "dummy1", "dummy3"])
            .with_native_certs()
            .await
            .unwrap();

        assert!(builder.underlying_protocol.is_some());
        assert_eq!(builder.backend_type.as_ref(), Some(&backend_type));
        assert_eq!(builder.hypernode_type, Some(node_type));
        assert_eq!(
            builder.kernel_executor_settings.as_ref(),
            Some(&kernel_settings)
        );

        let node = builder.build(EmptyKernel::default()).unwrap();
        drop(node);
    }
}