1use crate::prefabs::client::peer_connection::FileTransferHandleRx;
58use crate::prefabs::client::ServerConnectionSettings;
59use crate::prefabs::ClientServerRemote;
60use crate::remote_ext::CitadelClientServerConnection;
61use crate::remote_ext::ProtocolRemoteExt;
62use citadel_io::Mutex;
63use citadel_proto::prelude::*;
64use citadel_types::crypto::PreSharedKey;
65use futures::Future;
66use std::marker::PhantomData;
67use std::net::SocketAddr;
68use uuid::Uuid;
69
/// A [`NetKernel`] that establishes exactly one client-to-server connection and
/// then hands the resulting [`CitadelClientServerConnection`] to a user closure.
///
/// Build it with [`SingleClientServerConnectionKernel::new`]; the registration
/// and/or connection itself happens in `on_start`.
pub struct SingleClientServerConnectionKernel<F, Fut, R: Ratchet> {
    // User handler; `on_start` takes it out of the `Option`, so it runs at most once.
    handler: Mutex<Option<F>>,
    // UDP mode passed to `connect`.
    udp_mode: UdpMode,
    // Register / connect / transient strategy; taken (consumed) by `on_start`.
    auth_info: Mutex<Option<ConnectionType>>,
    // Security settings used for `register`, `connect` and the connection's remote.
    session_security_settings: SessionSecuritySettings,
    // Filled in by `on_start`; `on_node_event_received` forwards every event other
    // than object-transfer handles through this sender while it is `Some`.
    unprocessed_signal_filter_tx:
        Mutex<Option<citadel_io::tokio::sync::mpsc::UnboundedSender<NodeResult<R>>>>,
    // Set by `load_remote`.
    remote: Option<NodeRemote<R>>,
    // Optional pre-shared key passed to both `register` and `connect`.
    server_password: Option<PreSharedKey>,
    // Receiving half of the object-transfer channel; `on_start` moves it into the
    // connection's `ClientServerRemote`.
    rx_incoming_object_transfer_handle: Mutex<Option<FileTransferHandleRx>>,
    // Sending half; `on_node_event_received` pushes incoming object-transfer
    // handles here.
    tx_incoming_object_transfer_handle:
        citadel_io::tokio::sync::mpsc::UnboundedSender<ObjectTransferHandler>,
    // Marks the `Fut` type parameter without storing a future. `fn() -> Fut` is
    // always `Send + Sync`, so this marker does not constrain the kernel's auto traits.
    _pd: PhantomData<fn() -> Fut>,
}
88
/// Authentication strategy derived from a [`ServerConnectionSettings`] in
/// `SingleClientServerConnectionKernel::new` and consumed once by `on_start`.
// NOTE(review): the derived `Debug` includes the `password` fields; confirm that
// `SecBuffer`'s `Debug` impl redacts its contents before logging this type.
#[derive(Debug)]
pub(crate) enum ConnectionType {
    /// Register the account (skipped when the username already exists in the
    /// account manager's persistence handler), then log in with the same credentials.
    Register {
        server_addr: SocketAddr,
        username: String,
        password: SecBuffer,
        // Populated from the settings' `alias`.
        full_name: String,
    },
    /// Log in with the credentials of an existing account.
    Connect {
        username: String,
        password: SecBuffer,
    },
    /// Connect via `AuthenticationRequest::transient`, identified by `uuid`.
    Transient {
        uuid: Uuid,
        server_addr: SocketAddr,
    },
}
106
107impl<F, Fut, R: Ratchet> SingleClientServerConnectionKernel<F, Fut, R>
108where
109 F: FnOnce(CitadelClientServerConnection<R>) -> Fut + Send,
110 Fut: Future<Output = Result<(), NetworkError>> + Send,
111{
112 fn generate_object_transfer_handle() -> (
113 citadel_io::tokio::sync::mpsc::UnboundedSender<ObjectTransferHandler>,
114 Mutex<Option<FileTransferHandleRx>>,
115 ) {
116 let (tx, rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
117 let rx = FileTransferHandleRx {
118 inner: rx,
119 conn_type: VirtualTargetType::LocalGroupServer { session_cid: 0 },
120 };
121 (tx, Mutex::new(Some(rx)))
122 }
123
124 pub fn new(settings: ServerConnectionSettings<R>, on_channel_received: F) -> Self {
127 let (udp_mode, session_security_settings) =
128 (settings.udp_mode(), settings.session_security_settings());
129 let server_password = settings.pre_shared_key().cloned();
130 let (tx_incoming_object_transfer_handle, rx_incoming_object_transfer_handle) =
131 Self::generate_object_transfer_handle();
132
133 let connection_type = match settings {
134 ServerConnectionSettings::CredentialedConnect {
135 username, password, ..
136 } => ConnectionType::Connect { username, password },
137
138 ServerConnectionSettings::Transient {
139 server_addr: address,
140 uuid,
141 ..
142 } => ConnectionType::Transient {
143 uuid,
144 server_addr: address,
145 },
146
147 ServerConnectionSettings::CredentialedRegister {
148 alias,
149 username,
150 password,
151 address,
152 ..
153 } => ConnectionType::Register {
154 full_name: alias,
155 server_addr: address,
156 username,
157 password,
158 },
159 };
160
161 Self {
162 handler: Mutex::new(Some(on_channel_received)),
163 udp_mode,
164 auth_info: Mutex::new(Some(connection_type)),
165 session_security_settings,
166 unprocessed_signal_filter_tx: Default::default(),
167 rx_incoming_object_transfer_handle,
168 tx_incoming_object_transfer_handle,
169 server_password,
170 remote: None,
171 _pd: Default::default(),
172 }
173 }
174}
175
176#[async_trait]
177impl<F, Fut, R: Ratchet> NetKernel<R> for SingleClientServerConnectionKernel<F, Fut, R>
178where
179 F: FnOnce(CitadelClientServerConnection<R>) -> Fut + Send,
180 Fut: Future<Output = Result<(), NetworkError>> + Send,
181{
182 fn load_remote(&mut self, server_remote: NodeRemote<R>) -> Result<(), NetworkError> {
183 self.remote = Some(server_remote);
184 Ok(())
185 }
186
187 #[allow(clippy::blocks_in_conditions)]
188 #[cfg_attr(
189 feature = "localhost-testing",
190 tracing::instrument(level = "trace", target = "citadel", skip_all, err(Debug))
191 )]
192 async fn on_start(&self) -> Result<(), NetworkError> {
193 let session_security_settings = self.session_security_settings;
194 let remote = self.remote.clone().unwrap();
195 let (auth_info, handler) = {
196 (
197 self.auth_info.lock().take().unwrap(),
198 self.handler.lock().take().unwrap(),
199 )
200 };
201
202 let auth = match auth_info {
203 ConnectionType::Register {
204 full_name,
205 server_addr,
206 username,
207 password,
208 } => {
209 if !remote
210 .account_manager()
211 .get_persistence_handler()
212 .username_exists(&username)
213 .await?
214 {
215 let _reg_success = remote
216 .register(
217 server_addr,
218 full_name.as_str(),
219 username.as_str(),
220 password.clone(),
221 self.session_security_settings,
222 self.server_password.clone(),
223 )
224 .await?;
225 }
226
227 AuthenticationRequest::credentialed(username, password)
228 }
229
230 ConnectionType::Connect { username, password } => {
231 AuthenticationRequest::credentialed(username, password)
232 }
233
234 ConnectionType::Transient { uuid, server_addr } => {
235 AuthenticationRequest::transient(uuid, server_addr)
236 }
237 };
238
239 let mut connect_success = remote
240 .connect(
241 auth,
242 Default::default(),
243 self.udp_mode,
244 None,
245 self.session_security_settings,
246 self.server_password.clone(),
247 )
248 .await?;
249
250 let conn_type = VirtualTargetType::LocalGroupServer {
251 session_cid: connect_success.cid,
252 };
253
254 let mut handle = {
255 let mut lock = self.rx_incoming_object_transfer_handle.lock();
256 lock.take().expect("Should not have been called before")
257 };
258
259 handle.conn_type.set_session_cid(connect_success.cid);
260
261 let (reroute_tx, reroute_rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
262 *self.unprocessed_signal_filter_tx.lock() = Some(reroute_tx);
263 connect_success.remote = ClientServerRemote::new(
264 conn_type,
265 remote,
266 session_security_settings,
267 Some(reroute_rx),
268 Some(handle),
269 );
270
271 handler(connect_success).await
272 }
273
274 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
275 match message {
276 NodeResult::ObjectTransferHandle(handle) => {
277 if let Err(err) = self.tx_incoming_object_transfer_handle.send(handle.handle) {
278 log::warn!(target: "citadel", "failed to send unprocessed NodeResult: {err:?}")
279 }
280 }
281
282 message => {
283 if let Some(val) = self.unprocessed_signal_filter_tx.lock().as_ref() {
284 log::trace!(target: "citadel", "Will forward message {val:?}");
285 if let Err(err) = val.send(message) {
286 log::warn!(target: "citadel", "failed to send unprocessed NodeResult: {err:?}")
287 }
288 }
289 }
290 }
291
292 Ok(())
293 }
294
295 async fn on_stop(&mut self) -> Result<(), NetworkError> {
296 Ok(())
297 }
298}
299
// Localhost integration tests for `SingleClientServerConnectionKernel`.
//
// Each success-path test starts a server with `server_info_reactive` and a client
// node running the kernel, calls `TestBarrier::setup(2)`, and lines the two sides
// up with `wait_for_peers` (presumably a barrier across those two participants;
// see `test_common`).
#[cfg(all(test, feature = "localhost-testing"))]
mod tests {
    use crate::prefabs::client::single_connection::SingleClientServerConnectionKernel;
    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
    use crate::prelude::*;
    use crate::test_common::{server_info_reactive, wait_for_peers, TestBarrier};
    // `tokio::...` paths below (`#[tokio::test]`, `tokio::select!`) go through this
    // re-export from `citadel_io`.
    use citadel_io::tokio;
    use rstest::rstest;
    use std::sync::atomic::{AtomicBool, Ordering};
    use uuid::Uuid;

    /// Server-side check: takes the connection's UDP receiver and runs
    /// `udp_mode_assertions` on it for the expected `udp_mode`.
    #[cfg_attr(
        feature = "localhost-testing",
        tracing::instrument(level = "trace", target = "citadel", skip_all, err(Debug))
    )]
    async fn on_server_received_conn<R: Ratchet>(
        udp_mode: UdpMode,
        conn: &mut CitadelClientServerConnection<R>,
    ) -> Result<(), NetworkError> {
        crate::test_common::udp_mode_assertions(udp_mode, conn.udp_channel_rx.take()).await;
        Ok(())
    }

    /// Shared server handler for the success-path tests.
    ///
    /// Sequence: barrier, UDP assertions, set `server_success`, barrier, shut the
    /// kernel down. The two barriers presumably pair with the two `wait_for_peers`
    /// calls made by each client closure below.
    #[cfg_attr(
        feature = "localhost-testing",
        tracing::instrument(level = "trace", target = "citadel", skip_all, err(Debug))
    )]
    async fn default_server_harness<R: Ratchet>(
        udp_mode: UdpMode,
        mut connection: CitadelClientServerConnection<R>,
        server_success: &AtomicBool,
    ) -> Result<(), NetworkError> {
        wait_for_peers().await;
        on_server_received_conn(udp_mode, &mut connection).await?;
        server_success.store(true, Ordering::SeqCst);
        log::warn!(target: "citadel", "Server awaiting peer ...");
        wait_for_peers().await;
        connection.shutdown_kernel().await
    }

    // Register-then-connect flow over self-signed QUIC and TLS, with UDP on and off.
    #[rstest]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_registered(
        #[values(UdpMode::Enabled, UdpMode::Disabled)] udp_mode: UdpMode,
        #[values(
            ServerUnderlyingProtocol::new_quic_self_signed(),
            ServerUnderlyingProtocol::new_tls_self_signed().unwrap()
        )]
        underlying_protocol: ServerUnderlyingProtocol,
    ) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        if cfg!(windows) {
            // TLS and QUIC are both skipped on Windows (reasons are in the log
            // messages); any other protocol variant still runs.
            match &underlying_protocol {
                ServerUnderlyingProtocol::Tls(..) => {
                    citadel_logging::warn!(target: "citadel", "Skipping TLS test on Windows - self-signed certs may not work");
                    return;
                }
                ServerUnderlyingProtocol::Quic(..) => {
                    citadel_logging::warn!(target: "citadel", "Skipping QUIC test on Windows - socket binding may fail with error 10013");
                    return;
                }
                _ => {}
            }
        }

        // Set from inside the server/client closures; checked once both have finished.
        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            move |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |builder| {
                let _ = builder.with_underlying_protocol(underlying_protocol);
            },
        );

        let client_settings = DefaultServerConnectionSettingsBuilder::credentialed_registration(
            server_addr,
            "nologik",
            "Some Alias",
            "password",
        )
        .with_udp_mode(udp_mode)
        .build()
        .unwrap();

        let client_kernel =
            SingleClientServerConnectionKernel::new(client_settings, |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                let chan = connection.udp_channel_rx.take();
                wait_for_peers().await;
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;
                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            });

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.expect("Failed to join server and client - possible port binding issue on Windows (error 10013)");

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    // Transient connections with several header-obfuscation settings, optionally
    // protected by a server password.
    #[rstest]
    #[case(UdpMode::Enabled, None, HeaderObfuscatorSettings::Disabled)]
    #[case(UdpMode::Enabled, None, HeaderObfuscatorSettings::Enabled)]
    #[case(
        UdpMode::Enabled,
        None,
        HeaderObfuscatorSettings::EnabledWithKey(12345)
    )]
    #[case(
        UdpMode::Enabled,
        Some("test-password"),
        HeaderObfuscatorSettings::Disabled
    )]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_transient(
        #[case] udp_mode: UdpMode,
        #[case] server_password: Option<&'static str>,
        #[case] header_obfuscator_settings: HeaderObfuscatorSettings,
    ) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);
        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |opts| {
                // NOTE(review): the server-declared header obfuscation is only applied
                // when a password is set, so the three `None`-password cases never
                // declare obfuscation on the server side (only the client's session
                // settings carry it). Confirm whether this is intentional.
                if let Some(password) = server_password {
                    let _ = opts
                        .with_server_password(password)
                        .with_server_declared_header_obfuscation(header_obfuscator_settings);
                }
            },
        );

        let uuid = Uuid::new_v4();

        let mut server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .with_session_security_settings(SessionSecuritySettings {
                    security_level: Default::default(),
                    secrecy_mode: Default::default(),
                    crypto_params: Default::default(),
                    header_obfuscator_settings,
                });

        // The client uses the same password as the server in the password case.
        if let Some(server_password) = server_password {
            server_connection_settings =
                server_connection_settings.with_session_password(server_password);
        }

        let server_connection_settings = server_connection_settings.build().unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                let chan = connection.udp_channel_rx.take();
                wait_for_peers().await;
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;
                let sessions = connection.remote.sessions().await?;
                // At least this connection's session must be listed.
                assert!(!sessions.sessions.is_empty());
                connection.disconnect().await?;
                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.expect("Failed to join server and client - possible port binding issue on Windows (error 10013)");

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    // A client presenting the wrong password must fail with an error whose text
    // contains "EncryptionFailure"; neither side's connection handler may run.
    #[rstest]
    #[case(UdpMode::Enabled, Some("test-password"))]
    #[timeout(std::time::Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_transient_wrong_password(
        #[case] udp_mode: UdpMode,
        #[case] server_password: Option<&'static str>,
    ) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |_conn| async move { panic!("Server should not have connected") },
            |opts| {
                if let Some(password) = server_password {
                    let _ = opts.with_server_password(password);
                }
            },
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .with_session_password("wrong-password")
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |_connection| async move { panic!("Client should not have connected") },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        // The server runs indefinitely, so the client's failure must win the race.
        tokio::select! {
            _res0 = server => {
                panic!("Server should never finish")
            },

            result = client => {
                if let Err(error) = result {
                    assert!(error.into_string().contains("EncryptionFailure"));
                } else {
                    panic!("Client should not have connected")
                }
            }
        }
    }

    // A transient client connects, deregisters, and then shuts down cleanly.
    #[rstest]
    #[case(UdpMode::Disabled)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_single_connection_transient_deregister(#[case] udp_mode: UdpMode) {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |_| (),
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                let chan = connection.udp_channel_rx.take();
                wait_for_peers().await;
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;
                connection.deregister().await?;
                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.expect("Failed to join server and client - possible port binding issue on Windows (error 10013)");

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    // Exercises the client-to-server backend key/value store. The assertions expect
    // `set`/`remove` to return the previous value (`None` when absent) and
    // `remove_all` to return the map that was removed.
    #[rstest]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_backend_store_c2s() {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let udp_mode = UdpMode::Disabled;

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);
        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |_| (),
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |connection| async move {
                log::trace!(target: "citadel", "***CLIENT TEST SUCCESS***");
                wait_for_peers().await;

                const KEY: &str = "HELLO_WORLD";
                const KEY2: &str = "HELLO_WORLD2";
                let value: Vec<u8> = Vec::from("Hello, world!");
                let value2: Vec<u8> = Vec::from("Hello, world!2");

                // Fresh keys: `set` returns no previous value, `get` reads it back.
                assert_eq!(connection.set(KEY, value.clone()).await?.as_deref(), None);
                assert_eq!(
                    connection.get(KEY).await?.as_deref(),
                    Some(value.as_slice())
                );

                assert_eq!(connection.set(KEY2, value2.clone()).await?.as_deref(), None);
                assert_eq!(
                    connection.get(KEY2).await?.as_deref(),
                    Some(value2.as_slice())
                );

                let map = connection.get_all().await?;
                assert_eq!(map.get(KEY), Some(&value));
                assert_eq!(map.get(KEY2), Some(&value2));

                // The first removal yields the stored value; the second finds nothing.
                assert_eq!(
                    connection.remove(KEY2).await?.as_deref(),
                    Some(value2.as_slice())
                );

                assert_eq!(connection.remove(KEY2).await?.as_deref(), None);

                let map = connection.remove_all().await?;
                assert_eq!(map.get(KEY), Some(&value));
                assert_eq!(map.get(KEY2), None);

                // The store is now empty, and clearing it again removes nothing.
                assert_eq!(connection.get_all().await?.len(), 0);
                assert_eq!(connection.remove_all().await?.len(), 0);

                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;
                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

        let joined = futures::future::try_join(server, client);

        let _ = joined.await.expect("Failed to join server and client - possible port binding issue on Windows (error 10013)");

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }

    // Repeated client-initiated rekeys. Per the assertions, the x-th rekey is
    // expected to return `Some(x)`, for x in 1..10.
    #[rstest]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_rekey_c2s() {
        citadel_logging::setup_log();
        TestBarrier::setup(2);

        let udp_mode = UdpMode::Disabled;

        let client_success = &AtomicBool::new(false);
        let server_success = &AtomicBool::new(false);
        let (server, server_addr) = server_info_reactive::<_, _, StackedRatchet>(
            |connection| async move {
                default_server_harness(udp_mode, connection, server_success).await
            },
            |_| (),
        );

        let uuid = Uuid::new_v4();

        let server_connection_settings =
            DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                .with_udp_mode(udp_mode)
                .build()
                .unwrap();

        let client_kernel = SingleClientServerConnectionKernel::new(
            server_connection_settings,
            |mut connection| async move {
                log::trace!(target: "citadel", "***CLIENT LOGIN SUCCESS***");
                wait_for_peers().await;
                let chan = connection.udp_channel_rx.take();
                crate::test_common::udp_mode_assertions(udp_mode, chan).await;

                for x in 1..10 {
                    assert_eq!(connection.remote.rekey().await?, Some(x));
                }

                client_success.store(true, Ordering::Relaxed);
                wait_for_peers().await;

                connection.shutdown_kernel().await
            },
        );

        let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
        let joined = futures::future::try_join(server, client);

        let _ = joined.await.expect("Failed to join server and client - possible port binding issue on Windows (error 10013)");

        assert!(client_success.load(Ordering::Relaxed));
        assert!(server_success.load(Ordering::Relaxed));
    }
}