1use crate::prefabs::client::peer_connection::FileTransferHandleRx;
58use crate::prefabs::client::ServerConnectionSettings;
59use crate::prefabs::ClientServerRemote;
60use crate::remote_ext::CitadelClientServerConnection;
61use crate::remote_ext::ProtocolRemoteExt;
62use citadel_io::Mutex;
63use citadel_proto::prelude::*;
64use citadel_types::crypto::PreSharedKey;
65use futures::Future;
66use std::marker::PhantomData;
67use std::net::SocketAddr;
68use uuid::Uuid;
69
/// A [`NetKernel`] that establishes a single client-to-server connection and then
/// hands the resulting [`CitadelClientServerConnection`] to a user-supplied closure.
///
/// Generic parameters:
/// - `F`: the user closure, invoked once with the established connection.
/// - `Fut`: the future returned by `F`.
/// - `R`: the ratchet type used by the underlying protocol.
pub struct SingleClientServerConnectionKernel<F, Fut, R: Ratchet> {
    // User callback; wrapped in `Option` so `on_start` can `take()` it exactly once.
    handler: Mutex<Option<F>>,
    // UDP mode requested when connecting.
    udp_mode: UdpMode,
    // How to authenticate (register / connect / transient); taken once by `on_start`.
    auth_info: Mutex<Option<ConnectionType>>,
    // Security settings used for both registration and connection.
    session_security_settings: SessionSecuritySettings,
    // Forwards node events that the kernel does not handle itself to the
    // connection's remote. `None` until `on_start` installs the sender.
    unprocessed_signal_filter_tx:
        Mutex<Option<citadel_io::tokio::sync::mpsc::UnboundedSender<NodeResult<R>>>>,
    // Node remote supplied through `load_remote`.
    remote: Option<NodeRemote<R>>,
    // Optional pre-shared key passed to the server on register/connect.
    server_password: Option<PreSharedKey>,
    // Receiving half of the object-transfer channel; moved into the connection on start.
    rx_incoming_object_transfer_handle: Mutex<Option<FileTransferHandleRx>>,
    // Sending half of the object-transfer channel; fed by `on_node_event_received`.
    tx_incoming_object_transfer_handle:
        citadel_io::tokio::sync::mpsc::UnboundedSender<ObjectTransferHandler>,
    // Ties `Fut` to the type. `fn() -> Fut` is always `Send + Sync`, so this marker
    // imposes no auto-trait requirements on `Fut` itself.
    _pd: PhantomData<fn() -> Fut>,
}
88
89#[derive(Debug)]
90pub(crate) enum ConnectionType {
91 Register {
92 server_addr: SocketAddr,
93 username: String,
94 password: SecBuffer,
95 full_name: String,
96 },
97 Connect {
98 username: String,
99 password: SecBuffer,
100 },
101 Transient {
102 uuid: Uuid,
103 server_addr: SocketAddr,
104 },
105}
106
107impl<F, Fut, R: Ratchet> SingleClientServerConnectionKernel<F, Fut, R>
108where
109 F: FnOnce(CitadelClientServerConnection<R>) -> Fut + Send,
110 Fut: Future<Output = Result<(), NetworkError>> + Send,
111{
112 fn generate_object_transfer_handle() -> (
113 citadel_io::tokio::sync::mpsc::UnboundedSender<ObjectTransferHandler>,
114 Mutex<Option<FileTransferHandleRx>>,
115 ) {
116 let (tx, rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
117 let rx = FileTransferHandleRx {
118 inner: rx,
119 conn_type: VirtualTargetType::LocalGroupServer { session_cid: 0 },
120 };
121 (tx, Mutex::new(Some(rx)))
122 }
123
124 pub fn new(settings: ServerConnectionSettings<R>, on_channel_received: F) -> Self {
127 let (udp_mode, session_security_settings) =
128 (settings.udp_mode(), settings.session_security_settings());
129 let server_password = settings.pre_shared_key().cloned();
130 let (tx_incoming_object_transfer_handle, rx_incoming_object_transfer_handle) =
131 Self::generate_object_transfer_handle();
132
133 let connection_type = match settings {
134 ServerConnectionSettings::CredentialedConnect {
135 username, password, ..
136 } => ConnectionType::Connect { username, password },
137
138 ServerConnectionSettings::Transient {
139 server_addr: address,
140 uuid,
141 ..
142 } => ConnectionType::Transient {
143 uuid,
144 server_addr: address,
145 },
146
147 ServerConnectionSettings::CredentialedRegister {
148 alias,
149 username,
150 password,
151 address,
152 ..
153 } => ConnectionType::Register {
154 full_name: alias,
155 server_addr: address,
156 username,
157 password,
158 },
159 };
160
161 Self {
162 handler: Mutex::new(Some(on_channel_received)),
163 udp_mode,
164 auth_info: Mutex::new(Some(connection_type)),
165 session_security_settings,
166 unprocessed_signal_filter_tx: Default::default(),
167 rx_incoming_object_transfer_handle,
168 tx_incoming_object_transfer_handle,
169 server_password,
170 remote: None,
171 _pd: Default::default(),
172 }
173 }
174}
175
176#[async_trait]
177impl<F, Fut, R: Ratchet> NetKernel<R> for SingleClientServerConnectionKernel<F, Fut, R>
178where
179 F: FnOnce(CitadelClientServerConnection<R>) -> Fut + Send,
180 Fut: Future<Output = Result<(), NetworkError>> + Send,
181{
182 fn load_remote(&mut self, server_remote: NodeRemote<R>) -> Result<(), NetworkError> {
183 self.remote = Some(server_remote);
184 Ok(())
185 }
186
187 #[allow(clippy::blocks_in_conditions)]
188 #[cfg_attr(
189 feature = "localhost-testing",
190 tracing::instrument(level = "trace", target = "citadel", skip_all, err(Debug))
191 )]
192 async fn on_start(&self) -> Result<(), NetworkError> {
193 let session_security_settings = self.session_security_settings;
194 let remote = self.remote.clone().unwrap();
195 let (auth_info, handler) = {
196 (
197 self.auth_info.lock().take().unwrap(),
198 self.handler.lock().take().unwrap(),
199 )
200 };
201
202 let auth = match auth_info {
203 ConnectionType::Register {
204 full_name,
205 server_addr,
206 username,
207 password,
208 } => {
209 if !remote
210 .account_manager()
211 .get_persistence_handler()
212 .username_exists(&username)
213 .await?
214 {
215 let _reg_success = remote
216 .register(
217 server_addr,
218 full_name.as_str(),
219 username.as_str(),
220 password.clone(),
221 self.session_security_settings,
222 self.server_password.clone(),
223 )
224 .await?;
225 }
226
227 AuthenticationRequest::credentialed(username, password)
228 }
229
230 ConnectionType::Connect { username, password } => {
231 AuthenticationRequest::credentialed(username, password)
232 }
233
234 ConnectionType::Transient { uuid, server_addr } => {
235 AuthenticationRequest::transient(uuid, server_addr)
236 }
237 };
238
239 let mut connect_success = remote
240 .connect(
241 auth,
242 Default::default(),
243 self.udp_mode,
244 None,
245 self.session_security_settings,
246 self.server_password.clone(),
247 )
248 .await?;
249
250 let conn_type = VirtualTargetType::LocalGroupServer {
251 session_cid: connect_success.cid,
252 };
253
254 let mut handle = {
255 let mut lock = self.rx_incoming_object_transfer_handle.lock();
256 lock.take().expect("Should not have been called before")
257 };
258
259 handle.conn_type.set_session_cid(connect_success.cid);
260
261 let (reroute_tx, reroute_rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
262 *self.unprocessed_signal_filter_tx.lock() = Some(reroute_tx);
263 connect_success.remote = ClientServerRemote::new(
264 conn_type,
265 remote,
266 session_security_settings,
267 Some(reroute_rx),
268 Some(handle),
269 );
270
271 handler(connect_success).await
272 }
273
274 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
275 match message {
276 NodeResult::ObjectTransferHandle(handle) => {
277 if let Err(err) = self.tx_incoming_object_transfer_handle.send(handle.handle) {
278 log::warn!(target: "citadel", "failed to send unprocessed NodeResult: {err:?}")
279 }
280 }
281
282 message => {
283 if let Some(val) = self.unprocessed_signal_filter_tx.lock().as_ref() {
284 log::trace!(target: "citadel", "Will forward message {val:?}");
285 if let Err(err) = val.send(message) {
286 log::warn!(target: "citadel", "failed to send unprocessed NodeResult: {err:?}")
287 }
288 }
289 }
290 }
291
292 Ok(())
293 }
294
295 async fn on_stop(&mut self) -> Result<(), NetworkError> {
296 Ok(())
297 }
298}
299
#[cfg(test)]
mod tests {
    use crate::prefabs::client::single_connection::SingleClientServerConnectionKernel;
    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
    use crate::prelude::*;
    use crate::test_common::{server_info_reactive, wait_for_peers, TestBarrier};
    use citadel_io::tokio;
    use rstest::rstest;
    use std::sync::atomic::{AtomicBool, Ordering};
    use uuid::Uuid;

    /// Runs the shared UDP-mode assertions on the server side of a connection.
    ///
    /// Takes the connection's UDP receiver (leaving `None` behind).
    #[cfg_attr(
        feature = "localhost-testing",
        tracing::instrument(level = "trace", target = "citadel", skip_all, err(Debug))
    )]
    async fn on_server_received_conn<R: Ratchet>(
        udp_mode: UdpMode,
        conn: &mut CitadelClientServerConnection<R>,
    ) -> Result<(), NetworkError> {
        crate::test_common::udp_mode_assertions(udp_mode, conn.udp_channel_rx.take()).await;
        Ok(())
    }

    /// Server-side handler shared by the tests in this module.
    ///
    /// Calls `wait_for_peers` before the UDP assertions and again before shutting
    /// the kernel down. This presumably rendezvouses with the client, which makes
    /// the same two calls (the barrier is sized by `TestBarrier::setup(2)`; confirm
    /// in `test_common`). Sets `server_success` once the assertions have run.
    #[cfg_attr(
        feature = "localhost-testing",
        tracing::instrument(level = "trace", target = "citadel", skip_all, err(Debug))
    )]
    async fn default_server_harness<R: Ratchet>(
        udp_mode: UdpMode,
        mut connection: CitadelClientServerConnection<R>,
        server_success: &AtomicBool,
    ) -> Result<(), NetworkError> {
        wait_for_peers().await;
        on_server_received_conn(udp_mode, &mut connection).await?;
        server_success.store(true, Ordering::SeqCst);
        log::warn!(target: "citadel", "Server awaiting peer ...");
        wait_for_peers().await;
        connection.shutdown_kernel().await
    }

    /// Registers a new account, connects with it, and checks UDP behavior over both
    /// QUIC and TLS, with UDP enabled and disabled.
    #[rstest]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_registered(
        #[values(UdpMode::Enabled, UdpMode::Disabled)] udp_mode: UdpMode,
        #[values(ServerUnderlyingProtocol::new_quic_self_signed(), ServerUnderlyingProtocol::new_tls_self_signed().unwrap()
        )]
        underlying_protocol: ServerUnderlyingProtocol,
    ) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        if matches!(underlying_protocol, ServerUnderlyingProtocol::Tls(..)) && cfg!(windows) {
            citadel_logging::warn!(target: "citadel", "Will skip test since self-signed certs may not necessarily work on windows runner");
            return;
        }

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            move |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |builder| {
                let _ = builder.with_underlying_protocol(underlying_protocol);
            },
        );

        let client_settings = DefaultServerConnectionSettingsBuilder::credentialed_registration(
            server_addr,
            "nologik",
            "Some Alias",
            "password",
        )
        .with_udp_mode(udp_mode)
        .build()
        .unwrap();

        // Client mirrors the server harness: rendezvous, assert UDP, mark success,
        // rendezvous again, then shut down.
        let client_kernel =
            SingleClientServerConnectionKernel::new(client_settings, |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                let chan = connection.udp_channel_rx.take();
                wait_for_peers().await;
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;
                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            });

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.unwrap();

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    /// Transient connections across header-obfuscation settings, with and without
    /// a server password. The client disconnects explicitly before shutting down.
    #[rstest]
    #[case(UdpMode::Enabled, None, HeaderObfuscatorSettings::Disabled)]
    #[case(UdpMode::Enabled, None, HeaderObfuscatorSettings::Enabled)]
    #[case(
        UdpMode::Enabled,
        None,
        HeaderObfuscatorSettings::EnabledWithKey(12345)
    )]
    #[case(
        UdpMode::Enabled,
        Some("test-password"),
        HeaderObfuscatorSettings::Disabled
    )]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_transient(
        #[case] udp_mode: UdpMode,
        #[case] server_password: Option<&'static str>,
        #[case] header_obfuscator_settings: HeaderObfuscatorSettings,
    ) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);
        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |opts| {
                // NOTE(review): server-declared header obfuscation is only set when a
                // password is present, so the obfuscation-only cases configure it on
                // the client alone. Confirm this is intentional.
                if let Some(password) = server_password {
                    let _ = opts
                        .with_server_password(password)
                        .with_server_declared_header_obfuscation(header_obfuscator_settings);
                }
            },
        );

        let uuid = Uuid::new_v4();

        let mut server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .with_session_security_settings(SessionSecuritySettings {
                    security_level: Default::default(),
                    secrecy_mode: Default::default(),
                    crypto_params: Default::default(),
                    header_obfuscator_settings,
                });

        if let Some(server_password) = server_password {
            server_connection_settings =
                server_connection_settings.with_session_password(server_password);
        }

        let server_connection_settings = server_connection_settings.build().unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                let chan = connection.udp_channel_rx.take();
                wait_for_peers().await;
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;
                connection.disconnect().await?;
                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.unwrap();

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    /// A client using the wrong session password must fail with an
    /// `EncryptionFailure` error, and neither side's handler may run.
    #[rstest]
    #[case(UdpMode::Enabled, Some("test-password"))]
    #[timeout(std::time::Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_transient_wrong_password(
        #[case] udp_mode: UdpMode,
        #[case] server_password: Option<&'static str>,
    ) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |_conn| async move { panic!("Server should not have connected") },
            |opts| {
                if let Some(password) = server_password {
                    let _ = opts.with_server_password(password);
                }
            },
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .with_session_password("wrong-password")
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |_connection| async move { panic!("Client should not have connected") },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        // The server never completes on its own; the client must finish first, with an error.
        tokio::select! {
            _res0 = server => {
                panic!("Server should never finish")
            },

            result = client => {
                if let Err(error) = result {
                    assert!(error.into_string().contains("EncryptionFailure"));
                } else {
                    panic!("Client should not have connected")
                }
            }
        }
    }

    /// Transient connection followed by an explicit `deregister`.
    #[rstest]
    #[case(UdpMode::Disabled)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_transient_deregister(#[case] udp_mode: UdpMode) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |_| (),
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                let chan = connection.udp_channel_rx.take();
                wait_for_peers().await;
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;
                connection.deregister().await?;
                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.unwrap();

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    /// Exercises the client-to-server backend key/value API.
    ///
    /// Covers `set`, `get`, `get_all`, `remove` and `remove_all`, including return
    /// values for missing keys and an emptied store.
    #[rstest]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_backend_store_c2s() {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let udp_mode = UdpMode::Disabled;

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);
        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |_| (),
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                wait_for_peers().await;

                const KEY: &str = "HELLO_WORLD";
                const KEY2: &str = "HELLO_WORLD2";
                let value: Vec<u8> = Vec::from("Hello, world!");
                let value2: Vec<u8> = Vec::from("Hello, world!2");

                // `set` on a fresh key returns no previous value.
                assert_eq!(connection.set(KEY, value.clone()).await?.as_deref(), None);
                assert_eq!(
                    connection.get(KEY).await?.as_deref(),
                    Some(value.as_slice())
                );

                assert_eq!(connection.set(KEY2, value2.clone()).await?.as_deref(), None);
                assert_eq!(
                    connection.get(KEY2).await?.as_deref(),
                    Some(value2.as_slice())
                );

                let map = connection.get_all().await?;
                assert_eq!(map.get(KEY), Some(&value));
                assert_eq!(map.get(KEY2), Some(&value2));

                // `remove` returns the removed value the first time and `None` afterwards.
                assert_eq!(
                    connection.remove(KEY2).await?.as_deref(),
                    Some(value2.as_slice())
                );

                assert_eq!(connection.remove(KEY2).await?.as_deref(), None);

                // `remove_all` returns what was left (only KEY).
                let map = connection.remove_all().await?;
                assert_eq!(map.get(KEY), Some(&value));
                assert_eq!(map.get(KEY2), None);

                assert_eq!(connection.get_all().await?.len(), 0);
                assert_eq!(connection.remove_all().await?.len(), 0);

                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.unwrap();

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    /// Repeated client-initiated rekeys.
    ///
    /// Each successful `rekey` is expected to return the next sequential value,
    /// starting at 1.
    #[rstest]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_rekey_c2s() {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let udp_mode = UdpMode::Disabled;

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);
        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |_| (),
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS***");
                wait_for_peers().await;
                let chan = connection.udp_channel_rx.take();
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;

                for x in 1..10 {
                    assert_eq!(connection.remote.rekey().await?, Some(x));
                }

                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;

                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
        let joined = futures::future::try_join(server, client);

        let _ = joined.await.unwrap();

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }
}