1use crate::prelude::results::PeerConnectSuccess;
68use crate::prelude::*;
69use crate::test_common::wait_for_peers;
70use citadel_io::tokio::sync::mpsc::{Receiver, UnboundedSender};
71use citadel_io::{tokio, Mutex};
72use citadel_proto::prelude::async_trait;
73use citadel_user::hypernode_account::UserIdentifierExt;
74use futures::stream::FuturesUnordered;
75use futures::TryStreamExt;
76use std::collections::HashMap;
77use std::future::Future;
78use std::marker::PhantomData;
79use std::sync::Arc;
80use uuid::Uuid;
81
/// A [`NetKernel`] wrapper that, in `on_c2s_channel_received`, connects to a configured
/// set of peers and hands the per-peer results to a user-supplied closure.
///
/// Inbound object-transfer handles are routed to the matching peer's
/// [`FileTransferHandleRx`]; every other node event is forwarded to `inner_kernel`.
pub struct PeerConnectionKernel<'a, F, Fut, R: Ratchet> {
    /// The wrapped kernel; receives every event this kernel does not handle itself.
    inner_kernel: Box<dyn NetKernel<R> + 'a>,
    /// State shared with the peer-connection task (handed out by `get_shared_bundle`).
    shared: Shared,
    // `fn() -> (F, Fut)` records the type parameters without storing values of them,
    // so the kernel's auto traits (Send/Sync) do not depend on `F` or `Fut`.
    _pd: PhantomData<fn() -> (F, Fut)>,
}
90
/// State shared between the kernel's event handler and the peer-connection task.
///
/// The map lives behind an `Arc`, so every clone observes the same entries.
#[derive(Clone)]
#[doc(hidden)]
pub struct Shared {
    /// Registered peer connections keyed by connection type. An entry is inserted before
    /// each peer connect attempt, re-keyed if the established connection reports a
    /// different type, and removed if the attempt fails. Inbound object-transfer handles
    /// are routed by looking up their connection type here.
    active_peer_conns: Arc<Mutex<HashMap<PeerConnectionType, PeerContext>>>,
}
96
/// Per-peer routing information stored in `Shared::active_peer_conns`.
struct PeerContext {
    // Stored alongside the sender but not read in the code shown here, hence the allow.
    #[allow(dead_code)]
    conn_type: PeerConnectionType,
    /// Sender whose receiving half is exposed to the user as that peer's `FileTransferHandleRx`.
    send_file_transfer_tx: UnboundedSender<ObjectTransferHandler>,
}
102
/// Receiver for inbound object/file transfers arriving on one peer connection.
///
/// Dereferences to the underlying Tokio `UnboundedReceiver`, so `recv()` can be called
/// directly; alternatively [`FileTransferHandleRx::accept_all`] drains every handle in
/// the background.
#[derive(Debug)]
pub struct FileTransferHandleRx {
    /// Channel fed by the kernel whenever an `ObjectTransferHandle` event is routed to this connection.
    pub inner: citadel_io::tokio::sync::mpsc::UnboundedReceiver<ObjectTransferHandler>,
    /// The connection this receiver belongs to.
    pub conn_type: VirtualTargetType,
}
108
109impl FileTransferHandleRx {
110 pub fn accept_all(mut self) {
112 let task = tokio::task::spawn(async move {
113 let rx = &mut self.inner;
114 while let Some(mut handle) = rx.recv().await {
115 let task = tokio::task::spawn(async move {
116 if let Err(err) = handle.exhaust_stream().await {
117 let orientation = handle.orientation;
118 log::warn!(target: "citadel", "Error background handling of file transfer for {orientation:?}: {err:?}");
119 }
120 });
121
122 drop(task);
123 }
124 });
125
126 drop(task);
127 }
128}
129
130impl std::ops::Deref for FileTransferHandleRx {
131 type Target = citadel_io::tokio::sync::mpsc::UnboundedReceiver<ObjectTransferHandler>;
132
133 fn deref(&self) -> &Self::Target {
134 &self.inner
135 }
136}
137
138impl std::ops::DerefMut for FileTransferHandleRx {
139 fn deref_mut(&mut self) -> &mut Self::Target {
140 &mut self.inner
141 }
142}
143
144impl Drop for FileTransferHandleRx {
145 fn drop(&mut self) {
146 log::trace!(target: "citadel", "Dropping file transfer handle receiver {:?}", self.conn_type);
147 }
148}
149
150#[async_trait]
151impl<F, Fut, R: Ratchet> NetKernel<R> for PeerConnectionKernel<'_, F, Fut, R> {
152 fn load_remote(&mut self, server_remote: NodeRemote<R>) -> Result<(), NetworkError> {
153 self.inner_kernel.load_remote(server_remote)
154 }
155
156 async fn on_start(&self) -> Result<(), NetworkError> {
157 self.inner_kernel.on_start().await
158 }
159
160 #[allow(clippy::collapsible_else_if)]
161 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
162 match message {
163 NodeResult::ObjectTransferHandle(ObjectTransferHandle {
164 ticket: _,
165 handle,
166 session_cid,
167 }) => {
168 let is_revfs = matches!(
169 handle.metadata.transfer_type,
170 TransferType::RemoteEncryptedVirtualFilesystem { .. }
171 );
172 let active_peers = self.shared.active_peer_conns.lock();
173 let v_conn = if is_revfs {
174 let peer_cid = if session_cid != handle.source {
175 handle.source
176 } else {
177 handle.receiver
178 };
179 PeerConnectionType::LocalGroupPeer {
180 session_cid,
181 peer_cid,
182 }
183 } else {
184 if matches!(
185 handle.orientation,
186 ObjectTransferOrientation::Receiver { .. }
187 ) {
188 PeerConnectionType::LocalGroupPeer {
189 session_cid,
190 peer_cid: handle.source,
191 }
192 } else {
193 PeerConnectionType::LocalGroupPeer {
194 session_cid,
195 peer_cid: handle.receiver,
196 }
197 }
198 };
199
200 if let Some(peer_ctx) = active_peers.get(&v_conn) {
201 if let Err(err) = peer_ctx.send_file_transfer_tx.send(handle) {
202 log::warn!(target: "citadel", "Error forwarding file transfer handle: {:?}", err.to_string());
203 }
204 } else {
205 log::warn!(target: "citadel", "Unable to find key for inbound file transfer handle: {:?}\n Active Peers: {:?} \n handle_source = {}, handle_receiver = {}", v_conn, active_peers.keys().cloned().collect::<Vec<_>>(), handle.source, handle.receiver);
206 }
207
208 Ok(())
209 }
210
211 unprocessed @ NodeResult::Disconnect(..) | unprocessed => {
215 self.inner_kernel.on_node_event_received(unprocessed).await
217 }
218 }
219 }
220
221 async fn on_stop(&mut self) -> Result<(), NetworkError> {
222 self.inner_kernel.on_stop().await
223 }
224}
225
/// Builder that collects the peers to connect to, each with its own settings.
///
/// Add peers with [`with_peer`](Self::with_peer) (all defaults) or
/// [`with_peer_custom`](Self::with_peer_custom) (configurable, finished with `add()`).
#[derive(Debug, Default, Clone)]
pub struct PeerConnectionSetupAggregator {
    /// Per-peer settings in insertion order.
    inner: Vec<PeerConnectionSettings>,
}
232
/// Fully-resolved connection settings for a single peer.
#[derive(Debug, Clone)]
struct PeerConnectionSettings {
    /// Which peer to connect to.
    id: UserIdentifier,
    /// Security settings passed to `connect_to_peer_custom`.
    session_security_settings: SessionSecuritySettings,
    /// UDP mode passed to `connect_to_peer_custom`.
    udp_mode: UdpMode,
    /// When the peer is not yet mutually registered, first poll `is_peer_registered`
    /// (every 200 ms) before calling `register_to_peer`.
    ensure_registered: bool,
    /// Optional pre-shared key passed to `connect_to_peer_custom`.
    peer_session_password: Option<PreSharedKey>,
}
241
/// A peer being configured via [`PeerConnectionSetupAggregator::with_peer_custom`].
///
/// Call [`add`](Self::add) to append it to the aggregator. Security settings and UDP
/// mode left unset fall back to their `Default` values.
pub struct AddedPeer {
    list: PeerConnectionSetupAggregator,
    id: UserIdentifier,
    session_security_settings: Option<SessionSecuritySettings>,
    ensure_registered: bool,
    udp_mode: Option<UdpMode>,
    peer_session_password: Option<PreSharedKey>,
}
250
251impl AddedPeer {
252 pub fn add(mut self) -> PeerConnectionSetupAggregator {
254 let new = PeerConnectionSettings {
255 id: self.id,
256 session_security_settings: self.session_security_settings.unwrap_or_default(),
257 udp_mode: self.udp_mode.unwrap_or_default(),
258 ensure_registered: self.ensure_registered,
259 peer_session_password: self.peer_session_password,
260 };
261
262 self.list.inner.push(new);
263 self.list
264 }
265
266 pub fn with_udp_mode(mut self, udp_mode: UdpMode) -> Self {
268 self.udp_mode = Some(udp_mode);
269 self
270 }
271
272 pub fn disable_udp(self) -> Self {
274 self.with_udp_mode(UdpMode::Disabled)
275 }
276
277 pub fn enable_udp(self) -> Self {
279 self.with_udp_mode(UdpMode::Enabled)
280 }
281
282 pub fn with_session_security_settings(
284 mut self,
285 session_security_settings: SessionSecuritySettings,
286 ) -> Self {
287 self.session_security_settings = Some(session_security_settings);
288 self
289 }
290
291 pub fn ensure_registered(mut self) -> Self {
293 self.ensure_registered = true;
294 self
295 }
296
297 pub fn with_session_password<T: Into<PreSharedKey>>(mut self, password: T) -> Self {
300 self.peer_session_password = Some(password.into());
301 self
302 }
303}
304
305impl PeerConnectionSetupAggregator {
306 pub fn with_peer<T: Into<UserIdentifier>>(self, peer: T) -> PeerConnectionSetupAggregator {
315 self.with_peer_custom(peer).add()
316 }
317
318 pub fn with_peer_custom<T: Into<UserIdentifier>>(self, peer: T) -> AddedPeer {
334 AddedPeer {
335 list: self,
336 id: peer.into(),
337 ensure_registered: false,
338 session_security_settings: None,
339 udp_mode: None,
340 peer_session_password: None,
341 }
342 }
343}
344
345impl From<PeerConnectionSetupAggregator> for Vec<PeerConnectionSettings> {
346 fn from(this: PeerConnectionSetupAggregator) -> Self {
347 this.inner
348 }
349}
350
351impl From<Vec<UserIdentifier>> for PeerConnectionSetupAggregator {
352 fn from(ids: Vec<UserIdentifier>) -> Self {
353 let mut this = PeerConnectionSetupAggregator::default();
354 for peer in ids {
355 this = this.with_peer(peer);
356 }
357
358 this
359 }
360}
361
362impl From<UserIdentifier> for PeerConnectionSetupAggregator {
363 fn from(this: UserIdentifier) -> Self {
364 Self::from(vec![this])
365 }
366}
367
368impl From<Uuid> for PeerConnectionSetupAggregator {
369 fn from(user: Uuid) -> Self {
370 let user_identifier: UserIdentifier = user.into();
371 user_identifier.into()
372 }
373}
374
375impl From<String> for PeerConnectionSetupAggregator {
376 fn from(this: String) -> Self {
377 let user_identifier: UserIdentifier = this.into();
378 user_identifier.into()
379 }
380}
381
382impl From<&str> for PeerConnectionSetupAggregator {
383 fn from(this: &str) -> Self {
384 let user_identifier: UserIdentifier = this.into();
385 user_identifier.into()
386 }
387}
388
389impl From<u64> for PeerConnectionSetupAggregator {
390 fn from(this: u64) -> Self {
391 let user_identifier: UserIdentifier = this.into();
392 user_identifier.into()
393 }
394}
395
396#[async_trait]
397impl<'a, F, Fut, T: Into<PeerConnectionSetupAggregator> + Send + 'a, R: Ratchet>
398 PrefabFunctions<'a, T, R> for PeerConnectionKernel<'a, F, Fut, R>
399where
400 F: FnOnce(
401 Receiver<Result<PeerConnectSuccess<R>, NetworkError>>,
402 CitadelClientServerConnection<R>,
403 ) -> Fut
404 + Send
405 + 'a,
406 Fut: Future<Output = Result<(), NetworkError>> + Send + 'a,
407{
408 type UserLevelInputFunction = F;
409 type SharedBundle = Shared;
410
411 fn get_shared_bundle(&self) -> Self::SharedBundle {
412 self.shared.clone()
413 }
414
415 #[allow(clippy::blocks_in_conditions)]
416 #[cfg_attr(
417 feature = "localhost-testing",
418 tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug))
419 )]
420 async fn on_c2s_channel_received(
421 connect_success: CitadelClientServerConnection<R>,
422 peers_to_connect: T,
423 f: Self::UserLevelInputFunction,
424 shared: Shared,
425 ) -> Result<(), NetworkError> {
426 let shared = &shared;
427 let session_cid = connect_success.cid;
428 let mut peers_already_registered = vec![];
429
430 wait_for_peers().await;
431 let peers_to_connect = peers_to_connect.into().inner;
432
433 for peer in &peers_to_connect {
434 peers_already_registered.push(
436 peer.id
437 .search_peer(session_cid, connect_success.account_manager())
438 .await?,
439 )
440 }
441
442 let remote = connect_success.clone();
443 let (tx, rx) = citadel_io::tokio::sync::mpsc::channel(peers_to_connect.len());
444 let requests = FuturesUnordered::new();
445
446 for (mutually_registered, peer_to_connect) in
447 peers_already_registered.into_iter().zip(peers_to_connect)
448 {
449 let remote = remote.clone();
452 let tx = tx.clone();
453 let PeerConnectionSettings {
454 id,
455 session_security_settings,
456 udp_mode,
457 ensure_registered,
458 peer_session_password,
459 } = peer_to_connect;
460
461 let task = async move {
462 let inner_task = async move {
463 let (file_transfer_tx, file_transfer_rx) =
464 citadel_io::tokio::sync::mpsc::unbounded_channel();
465
466 let peer_cid = if let Some(mutual_peer) = &mutually_registered {
468 mutual_peer.cid
469 } else {
470 id.get_cid()
471 };
472
473 let handle = if let Some(_already_registered) = mutually_registered {
474 remote.find_target(session_cid, id).await?
475 } else {
476 log::info!(target: "citadel", "{session_cid} proposing target {id:?} to central node");
478 let handle = remote.propose_target(session_cid, id.clone()).await?;
479 if ensure_registered {
482 loop {
483 if handle.is_peer_registered().await? {
484 break;
485 }
486 citadel_io::time::sleep(std::time::Duration::from_millis(200))
487 .await;
488 }
489 }
490
491 log::info!(target: "citadel", "{session_cid} registering to peer {id:?}");
492 let _reg_success = handle.register_to_peer().await?;
493 log::info!(target: "citadel", "{session_cid} registered to peer {id:?} registered || success -> now connecting");
494 handle
495 };
496
497 let peer_conn = PeerConnectionType::LocalGroupPeer {
500 session_cid,
501 peer_cid,
502 };
503 let peer_context = PeerContext {
504 conn_type: peer_conn,
505 send_file_transfer_tx: file_transfer_tx.clone(),
506 };
507 log::debug!(target: "citadel", "Early registering peer connection: {peer_conn:?}");
508 let _ = shared
509 .active_peer_conns
510 .lock()
511 .insert(peer_conn, peer_context);
512
513 handle
514 .connect_to_peer_custom(
515 session_security_settings,
516 udp_mode,
517 peer_session_password,
518 )
519 .await
520 .map(|mut success| {
521 let actual_peer_conn = success.channel.get_peer_conn_type().unwrap();
522
523 if actual_peer_conn != peer_conn {
526 log::debug!(target: "citadel", "Updating peer connection registration from {peer_conn:?} to {actual_peer_conn:?}");
527 let mut active_peers = shared.active_peer_conns.lock();
528 if let Some(peer_ctx) = active_peers.remove(&peer_conn) {
529 let _ = active_peers.insert(actual_peer_conn, peer_ctx);
530 }
531 }
532 success.incoming_object_transfer_handles = Some(FileTransferHandleRx {
534 inner: file_transfer_rx,
535 conn_type: actual_peer_conn.as_virtual_connection(),
536 });
537 success
538 })
539 .inspect_err(|_err| {
540 let _ = shared.active_peer_conns.lock().remove(&peer_conn);
542 })
543 };
544
545 tx.send(inner_task.await)
546 .await
547 .map_err(|err| NetworkError::Generic(err.to_string()))
548 };
549
550 requests.push(Box::pin(task))
551 }
552
553 drop(tx);
556
557 let (collection_result, user_result) =
560 citadel_io::tokio::join!(requests.try_collect::<()>(), f(rx, connect_success));
561
562 collection_result?;
564 user_result
565 }
566
567 fn construct(kernel: Box<dyn NetKernel<R> + 'a>) -> Self {
568 Self {
569 inner_kernel: kernel,
570 shared: Shared {
571 active_peer_conns: Arc::new(Mutex::new(Default::default())),
572 },
573 _pd: Default::default(),
574 }
575 }
576}
577
578#[cfg(all(test, feature = "localhost-testing"))]
579mod tests {
580 use crate::prefabs::client::peer_connection::PeerConnectionKernel;
581 use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
582 use crate::prelude::*;
583 use crate::remote_ext::results::PeerConnectSuccess;
584 use crate::test_common::{server_info, wait_for_peers, TestBarrier};
585 use citadel_io::tokio;
586 use citadel_io::tokio::sync::mpsc::{Receiver, UnboundedSender};
587 use citadel_user::prelude::UserIdentifierExt;
588 use futures::stream::FuturesUnordered;
589 use futures::TryStreamExt;
590 use rstest::rstest;
591 use std::collections::HashMap;
592 use std::future::Future;
593 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
594 use std::time::Duration;
595 use uuid::Uuid;
596
597 lazy_static::lazy_static! {
598 pub static ref PEERS: Vec<(String, String, String)> = {
599 ["alpha", "beta", "charlie", "echo", "delta", "epsilon", "foxtrot"]
600 .iter().map(|base| (format!("{base}.username"), format!("{base}.password"), format!("{base}.full_name")))
601 .collect()
602 };
603 }
604
605 #[rstest]
606 #[case(2, UdpMode::Enabled)]
607 #[case(3, UdpMode::Disabled)]
608 #[timeout(Duration::from_secs(90))]
609 #[tokio::test(flavor = "multi_thread")]
610 async fn peer_to_peer_connect(#[case] peer_count: usize, #[case] udp_mode: UdpMode) {
611 assert!(peer_count > 1);
612 citadel_logging::setup_log();
613 TestBarrier::setup(peer_count);
614
615 let client_success = &AtomicUsize::new(0);
616 let (server, server_addr) = server_info::<StackedRatchet>();
617
618 let client_kernels = FuturesUnordered::new();
619 let total_peers = (0..peer_count)
620 .map(|idx| PEERS.get(idx).unwrap().0.clone())
621 .collect::<Vec<String>>();
622
623 for idx in 0..peer_count {
624 let (username, password, full_name) = PEERS.get(idx).unwrap();
625 let peers = total_peers
626 .clone()
627 .into_iter()
628 .filter(|r| r != username)
629 .map(UserIdentifier::Username)
630 .collect::<Vec<UserIdentifier>>();
631
632 let mut agg = PeerConnectionSetupAggregator::default();
633
634 for peer in peers {
635 agg = agg
636 .with_peer_custom(peer)
637 .ensure_registered()
638 .with_udp_mode(udp_mode)
639 .with_session_security_settings(SessionSecuritySettings::default())
640 .add();
641 }
642
643 let server_connection_settings =
644 DefaultServerConnectionSettingsBuilder::credentialed_registration(
645 server_addr,
646 username,
647 full_name,
648 password.as_str(),
649 )
650 .build()
651 .unwrap();
652
653 let username = username.clone();
654
655 let client_kernel = PeerConnectionKernel::new(
656 server_connection_settings,
657 agg.clone(),
658 move |results, connection| async move {
659 log::info!(target: "citadel", "***PEER {username} CONNECTED ***");
660 let session_cid = connection.conn_type.get_session_cid();
661 let check = move |conn: PeerConnectSuccess<_>| async move {
662 let session_cid = conn.channel.get_session_cid();
663 let _mutual_peers = conn
664 .remote
665 .remote()
666 .get_local_group_mutual_peers(session_cid)
667 .await
668 .unwrap();
669 conn
670 };
671 let p2p_remotes = handle_peer_connect_successes(
672 results,
673 session_cid,
674 peer_count,
675 udp_mode,
676 check,
677 )
678 .await
679 .into_iter()
680 .map(|r| (r.channel.get_peer_cid(), r.remote))
681 .collect::<HashMap<_, _>>();
682
683 let network_peers = connection.get_peers(None).await.unwrap();
687 for user in agg.inner {
688 let peer_cid = user.id.get_cid();
689 assert!(network_peers.iter().any(|r| r.cid == peer_cid))
690 }
691
692 let session_cid = connection.conn_type.get_session_cid();
694 let mutual_peers = connection
695 .get_local_group_mutual_peers(session_cid)
696 .await
697 .unwrap();
698 for (peer_cid, _) in p2p_remotes {
699 assert!(mutual_peers.iter().any(|r| r.cid == peer_cid))
700 }
701
702 log::info!(target: "citadel", "***PEER {username} finished all checks***");
703 let _ = client_success.fetch_add(1, Ordering::Relaxed);
704 wait_for_peers().await;
705 connection.shutdown_kernel().await
706 },
707 );
708
709 let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
710 client_kernels.push(async move { client.await.map(|_| ()) });
711 }
712
713 let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });
714
715 assert!(futures::future::try_select(server, clients).await.is_ok());
716
717 assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
718 }
719
720 #[rstest]
721 #[case(2, HeaderObfuscatorSettings::default())]
722 #[case(2, HeaderObfuscatorSettings::Enabled)]
723 #[case(2, HeaderObfuscatorSettings::EnabledWithKey(12345))]
724 #[case(3, HeaderObfuscatorSettings::default())]
725 #[timeout(Duration::from_secs(90))]
726 #[tokio::test(flavor = "multi_thread")]
727 async fn peer_to_peer_connect_transient(
728 #[case] peer_count: usize,
729 #[case] header_obfuscator_settings: HeaderObfuscatorSettings,
730 ) -> Result<(), Box<dyn std::error::Error>> {
731 assert!(peer_count > 1);
732 citadel_logging::setup_log();
733 TestBarrier::setup(peer_count);
734 let udp_mode = UdpMode::Enabled;
735
736 let do_deregister = peer_count == 2;
737
738 let client_success = &AtomicUsize::new(0);
739 let (server, server_addr) = server_info::<StackedRatchet>();
740
741 let client_kernels = FuturesUnordered::new();
742 let total_peers = (0..peer_count)
743 .map(|_| Uuid::new_v4())
744 .collect::<Vec<Uuid>>();
745
746 for idx in 0..peer_count {
747 let uuid = total_peers.get(idx).cloned().unwrap();
748 let peers = total_peers
749 .clone()
750 .into_iter()
751 .filter(|r| r != &uuid)
752 .map(UserIdentifier::from)
753 .collect::<Vec<UserIdentifier>>();
754
755 let mut agg = PeerConnectionSetupAggregator::default();
756
757 for peer in peers {
758 let security_settings = SessionSecuritySettings {
759 header_obfuscator_settings,
760 ..Default::default()
761 };
762 agg = agg
763 .with_peer_custom(peer)
764 .with_udp_mode(udp_mode)
765 .ensure_registered()
766 .with_session_security_settings(security_settings)
767 .add();
768 }
769
770 let server_connection_settings =
771 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
772 .build()
773 .unwrap();
774
775 let client_kernel = PeerConnectionKernel::new(
776 server_connection_settings,
777 agg,
778 move |results, remote| async move {
779 log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
780 let session_cid = remote.conn_type.get_session_cid();
781
782 let check = move |conn: PeerConnectSuccess<_>| async move {
783 if do_deregister {
784 conn.remote
785 .deregister()
786 .await
787 .expect("Deregistration failed");
788 assert!(!conn
789 .remote
790 .inner
791 .account_manager()
792 .get_persistence_handler()
793 .hyperlan_peer_exists(session_cid, conn.channel.get_peer_cid())
794 .await
795 .unwrap());
796 }
797 conn
798 };
799
800 let _ = handle_peer_connect_successes(
801 results,
802 session_cid,
803 peer_count,
804 udp_mode,
805 check,
806 )
807 .await;
808
809 log::info!(target: "citadel", "***PEER {uuid} finished all checks***");
810 let _ = client_success.fetch_add(1, Ordering::Relaxed);
811 wait_for_peers().await;
812 remote.shutdown_kernel().await
813 },
814 );
815
816 let client = DefaultNodeBuilder::default().build(client_kernel)?;
817 client_kernels.push(async move { client.await.map(|_| ()) });
818 }
819
820 let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });
821
822 if let Err(err) = futures::future::try_select(server, clients).await {
823 return match err {
824 futures::future::Either::Left(res) => Err(res.0.into_string().into()),
825 futures::future::Either::Right(res) => Err(res.0.into_string().into()),
826 };
827 }
828
829 assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
830 Ok(())
831 }
832
833 #[rstest]
834 #[case(2)]
835 #[case(3)]
836 #[timeout(std::time::Duration::from_secs(180))]
837 #[tokio::test(flavor = "multi_thread")]
838 async fn test_peer_to_peer_file_transfer(
839 #[case] peer_count: usize,
840 ) -> Result<(), Box<dyn std::error::Error>> {
841 assert!(peer_count > 1);
842 citadel_logging::setup_log();
843 TestBarrier::setup(peer_count);
844 let udp_mode = UdpMode::Enabled;
845
846 let sender_success = &AtomicBool::new(false);
847 let receiver_success = &AtomicBool::new(false);
848
849 let (server, server_addr) = server_info::<StackedRatchet>();
850
851 let client_kernels = FuturesUnordered::new();
852 let total_peers = (0..peer_count)
853 .map(|_| Uuid::new_v4())
854 .collect::<Vec<Uuid>>();
855
856 let sender_uuid = total_peers[0];
857
858 for idx in 0..peer_count {
859 let uuid = total_peers.get(idx).cloned().unwrap();
860 let mut peers = total_peers
861 .clone()
862 .into_iter()
863 .filter(|r| r != &uuid)
864 .map(UserIdentifier::from)
865 .collect::<Vec<UserIdentifier>>();
866 if idx != 0 {
872 peers = vec![sender_uuid.into()];
873 }
874
875 let mut agg = PeerConnectionSetupAggregator::default();
876
877 for peer in peers {
878 agg = agg
879 .with_peer_custom(peer)
880 .ensure_registered()
881 .with_udp_mode(udp_mode)
882 .with_session_security_settings(SessionSecuritySettings::default())
883 .add();
884 }
885
886 let server_connection_settings =
887 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
888 .build()
889 .unwrap();
890
891 let client_kernel = PeerConnectionKernel::new(
892 server_connection_settings,
893 agg,
894 move |results, remote| async move {
895 log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
896 wait_for_peers().await;
897 let session_cid = remote.conn_type.get_session_cid();
898 let is_sender = idx == 0; let check = move |mut conn: PeerConnectSuccess<_>| async move {
900 if is_sender {
901 conn.remote
902 .send_file_with_custom_opts(
903 "../resources/TheBridge.pdf",
904 32 * 1024,
905 TransferType::FileTransfer,
906 )
907 .await
908 .expect("Failed to send file");
909 } else {
910 let mut handle = conn
912 .incoming_object_transfer_handles
913 .take()
914 .unwrap()
915 .recv()
916 .await
917 .unwrap();
918 handle.accept().unwrap();
919
920 use citadel_types::proto::ObjectTransferStatus;
921 use futures::StreamExt;
922 let mut path = None;
923 while let Some(status) = handle.next().await {
924 match status {
925 ObjectTransferStatus::ReceptionComplete => {
926 let cmp =
927 include_bytes!("../../../../resources/TheBridge.pdf");
928 let streamed_data =
929 citadel_io::tokio::fs::read(path.clone().unwrap())
930 .await
931 .unwrap();
932 assert_eq!(
933 cmp,
934 streamed_data.as_slice(),
935 "Original data and streamed data does not match"
936 );
937
938 log::info!(target: "citadel", "Peer has finished receiving and verifying the file!");
939 break;
940 }
941
942 ObjectTransferStatus::ReceptionBeginning(file_path, vfm) => {
943 path = Some(file_path);
944 assert_eq!(vfm.name, "TheBridge.pdf")
945 }
946
947 _ => {}
948 }
949 }
950 }
951
952 conn
953 };
954 let peer_count = if idx == 0 { peer_count } else { 2 };
957 let _ = handle_peer_connect_successes(
958 results,
959 session_cid,
960 peer_count,
961 udp_mode,
962 check,
963 )
964 .await;
965
966 if is_sender {
967 sender_success.store(true, Ordering::Relaxed);
968 } else {
969 receiver_success.store(true, Ordering::Relaxed);
970 }
971
972 log::info!(target: "citadel", "***PEER {uuid} (is_sender: {is_sender}) finished all checks***");
973 wait_for_peers().await;
974 log::info!(target: "citadel", "***PEER {uuid} (is_sender: {is_sender}) shutting down***");
975 remote.shutdown_kernel().await
976 },
977 );
978
979 let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
980 client_kernels.push(async move { client.await.map(|_| ()) });
981 }
982
983 let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });
984
985 if let Err(err) = futures::future::try_select(server, clients).await {
986 return match err {
987 futures::future::Either::Left(res) => Err(res.0.into_string().into()),
988 futures::future::Either::Right(res) => Err(res.0.into_string().into()),
989 };
990 }
991
992 assert!(sender_success.load(Ordering::Relaxed));
993 assert!(receiver_success.load(Ordering::Relaxed));
994 Ok(())
995 }
996
997 #[rstest]
998 #[case(2)]
999 #[timeout(std::time::Duration::from_secs(90))]
1000 #[tokio::test(flavor = "multi_thread")]
1001 async fn test_peer_to_peer_rekey(
1002 #[case] peer_count: usize,
1003 ) -> Result<(), Box<dyn std::error::Error>> {
1004 assert!(peer_count > 1);
1005 citadel_logging::setup_log();
1006 TestBarrier::setup(peer_count);
1007 let udp_mode = UdpMode::Enabled;
1008
1009 let client_success = &AtomicUsize::new(0);
1010 let (server, server_addr) = server_info::<StackedRatchet>();
1011
1012 let client_kernels = FuturesUnordered::new();
1013 let total_peers = (0..peer_count)
1014 .map(|_| Uuid::new_v4())
1015 .collect::<Vec<Uuid>>();
1016
1017 for idx in 0..peer_count {
1018 let uuid = total_peers.get(idx).cloned().unwrap();
1019 let peers = total_peers
1020 .clone()
1021 .into_iter()
1022 .filter(|r| r != &uuid)
1023 .map(UserIdentifier::from)
1024 .collect::<Vec<UserIdentifier>>();
1025
1026 let mut agg = PeerConnectionSetupAggregator::default();
1027
1028 for peer in peers {
1029 agg = agg
1030 .with_peer_custom(peer)
1031 .ensure_registered()
1032 .with_udp_mode(udp_mode)
1033 .with_session_security_settings(SessionSecuritySettings::default())
1034 .add();
1035 }
1036
1037 let server_connection_settings =
1038 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
1039 .build()
1040 .unwrap();
1041
1042 let client_kernel = PeerConnectionKernel::new(
1043 server_connection_settings,
1044 agg,
1045 move |results, remote| async move {
1046 log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
1047 let session_cid = remote.conn_type.get_session_cid();
1048
1049 let check = move |conn: PeerConnectSuccess<_>| async move {
1050 if idx == 0 {
1051 for x in 1..10 {
1052 assert_eq!(
1053 conn.remote.rekey().await.expect("Failed to rekey"),
1054 Some(x)
1055 );
1056 }
1057 }
1058
1059 conn
1060 };
1061
1062 let results = handle_peer_connect_successes(
1063 results,
1064 session_cid,
1065 peer_count,
1066 udp_mode,
1067 check,
1068 )
1069 .await;
1070
1071 log::info!(target: "citadel", "***PEER {uuid} finished all check (count: {})s***", results.len());
1072 let _ = client_success.fetch_add(1, Ordering::Relaxed);
1073 wait_for_peers().await;
1074 remote.shutdown_kernel().await
1075 },
1076 );
1077
1078 let client = DefaultNodeBuilder::default().build(client_kernel)?;
1079 client_kernels.push(async move { client.await.map(|_| ()) });
1080 }
1081
1082 let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });
1083
1084 if let Err(err) = futures::future::try_select(server, clients).await {
1085 return match err {
1086 futures::future::Either::Left(res) => Err(res.0.into_string().into()),
1087 futures::future::Either::Right(res) => Err(res.0.into_string().into()),
1088 };
1089 }
1090
1091 assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
1092 Ok(())
1093 }
1094
1095 #[rstest]
1096 #[case(2)]
1097 #[timeout(std::time::Duration::from_secs(90))]
1098 #[tokio::test(flavor = "multi_thread")]
1099 async fn test_peer_to_peer_disconnect(
1100 #[case] peer_count: usize,
1101 ) -> Result<(), Box<dyn std::error::Error>> {
1102 assert!(peer_count > 1);
1103 citadel_logging::setup_log();
1104 TestBarrier::setup(peer_count);
1105 let udp_mode = UdpMode::Enabled;
1106
1107 let client_success = &AtomicUsize::new(0);
1108 let (server, server_addr) = server_info::<StackedRatchet>();
1109
1110 let client_kernels = FuturesUnordered::new();
1111 let total_peers = (0..peer_count)
1112 .map(|_| Uuid::new_v4())
1113 .collect::<Vec<Uuid>>();
1114
1115 for idx in 0..peer_count {
1116 let uuid = total_peers.get(idx).cloned().unwrap();
1117 let peers = total_peers
1118 .clone()
1119 .into_iter()
1120 .filter(|r| r != &uuid)
1121 .map(UserIdentifier::from)
1122 .collect::<Vec<UserIdentifier>>();
1123
1124 let mut agg = PeerConnectionSetupAggregator::default();
1125
1126 for peer in peers {
1127 agg = agg
1128 .with_peer_custom(peer)
1129 .ensure_registered()
1130 .with_udp_mode(udp_mode)
1131 .with_session_security_settings(SessionSecuritySettings::default())
1132 .add();
1133 }
1134
1135 let server_connection_settings =
1136 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
1137 .build()
1138 .unwrap();
1139
1140 let client_kernel = PeerConnectionKernel::new(
1141 server_connection_settings,
1142 agg,
1143 move |results, remote| async move {
1144 log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
1145 wait_for_peers().await;
1146 let session_cid = remote.conn_type.get_session_cid();
1147
1148 let check = move |conn: PeerConnectSuccess<_>| async move {
1149 conn.remote
1150 .disconnect()
1151 .await
1152 .expect("Failed to p2p disconnect");
1153 conn
1154 };
1155 let _ = handle_peer_connect_successes(
1156 results,
1157 session_cid,
1158 peer_count,
1159 udp_mode,
1160 check,
1161 )
1162 .await;
1163 log::info!(target: "citadel", "***PEER {uuid} finished all checks***");
1164
1165 let _ = client_success.fetch_add(1, Ordering::Relaxed);
1166 wait_for_peers().await;
1167 remote.shutdown_kernel().await
1168 },
1169 );
1170
1171 let client = DefaultNodeBuilder::default().build(client_kernel)?;
1172 client_kernels.push(async move { client.await.map(|_| ()) });
1173 }
1174
1175 let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });
1176
1177 if let Err(err) = futures::future::try_select(server, clients).await {
1178 return match err {
1179 futures::future::Either::Left(res) => Err(res.0.into_string().into()),
1180 futures::future::Either::Right(res) => Err(res.0.into_string().into()),
1181 };
1182 }
1183
1184 assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
1185 Ok(())
1186 }
1187
1188 #[rstest]
1189 #[case(SecrecyMode::BestEffort, Some("test-p2p-password"))]
1190 #[timeout(std::time::Duration::from_secs(240))]
1191 #[citadel_io::tokio::test(flavor = "multi_thread")]
1192 async fn test_p2p_wrong_session_password(
1193 #[case] secrecy_mode: SecrecyMode,
1194 #[case] p2p_password: Option<&'static str>,
1195 #[values(KemAlgorithm::MlKem)] kem: KemAlgorithm,
1196 #[values(EncryptionAlgorithm::AES_GCM_256)] enx: EncryptionAlgorithm,
1197 ) {
1198 citadel_logging::setup_log_no_panic_hook();
1199 TestBarrier::setup(2);
1200 let (server, server_addr) = server_info::<StackedRatchet>();
1201 let peer_0_error_received = &AtomicBool::new(false);
1202 let peer_1_error_received = &AtomicBool::new(false);
1203
1204 let uuid0 = Uuid::new_v4();
1205 let uuid1 = Uuid::new_v4();
1206 let session_security = SessionSecuritySettingsBuilder::default()
1207 .with_secrecy_mode(secrecy_mode)
1208 .with_crypto_params(kem + enx)
1209 .build()
1210 .unwrap();
1211
1212 let mut peer0_agg = PeerConnectionSetupAggregator::default()
1213 .with_peer_custom(uuid1)
1214 .ensure_registered()
1215 .with_session_security_settings(session_security);
1216
1217 if let Some(password) = p2p_password {
1218 peer0_agg = peer0_agg.with_session_password(password);
1219 }
1220
1221 let peer0_connection = peer0_agg.add();
1222
1223 let mut peer1_agg = PeerConnectionSetupAggregator::default()
1224 .with_peer_custom(uuid0)
1225 .ensure_registered()
1226 .with_session_security_settings(session_security);
1227
1228 if let Some(_password) = p2p_password {
1229 peer1_agg = peer1_agg.with_session_password("wrong password");
1230 }
1231
1232 let peer1_connection = peer1_agg.add();
1233
1234 let server_connection_settings0 =
1235 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid0)
1236 .with_udp_mode(UdpMode::Enabled)
1237 .with_session_security_settings(session_security)
1238 .build()
1239 .unwrap();
1240
1241 let server_connection_settings1 =
1242 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid1)
1243 .with_udp_mode(UdpMode::Enabled)
1244 .with_session_security_settings(session_security)
1245 .build()
1246 .unwrap();
1247
1248 let client_kernel0 = PeerConnectionKernel::new(
1249 server_connection_settings0,
1250 peer0_connection,
1251 move |mut connection, remote| async move {
1252 wait_for_peers().await;
1253 let conn = connection.recv().await.unwrap();
1254 log::trace!(target: "citadel", "Peer 0 {} received: {:?}", remote.conn_type.get_session_cid(), conn);
1255 if conn.is_ok() {
1256 peer_0_error_received.store(true, Ordering::SeqCst);
1257 }
1258 wait_for_peers().await;
1259 remote.shutdown_kernel().await
1260 },
1261 );
1262
1263 let client_kernel1 = PeerConnectionKernel::new(
1264 server_connection_settings1,
1265 peer1_connection,
1266 move |mut connection, remote| async move {
1267 wait_for_peers().await;
1268 let conn = connection.recv().await.unwrap();
1269 log::trace!(target: "citadel", "Peer 1 {} received: {:?}", remote.conn_type.get_session_cid(), conn);
1270 if conn.is_ok() {
1271 peer_1_error_received.store(true, Ordering::SeqCst);
1272 }
1273 wait_for_peers().await;
1274 remote.shutdown_kernel().await
1275 },
1276 );
1277
1278 let client0 = DefaultNodeBuilder::default().build(client_kernel0).unwrap();
1279 let client1 = DefaultNodeBuilder::default().build(client_kernel1).unwrap();
1280 let clients = futures::future::try_join(client0, client1);
1281
1282 let task = async move {
1283 tokio::select! {
1284 server_res = server => Err(NetworkError::msg(format!("Server ended prematurely: {:?}", server_res.map(|_| ())))),
1285 client_res = clients => client_res.map(|_| ())
1286 }
1287 };
1288
1289 tokio::time::timeout(Duration::from_secs(120), task)
1290 .await
1291 .unwrap()
1292 .unwrap();
1293
1294 assert!(!peer_0_error_received.load(Ordering::SeqCst));
1295 assert!(!peer_1_error_received.load(Ordering::SeqCst));
1296 }
1297
1298 async fn handle_peer_connect_successes<F, Fut, R: Ratchet>(
1299 mut conn_rx: Receiver<Result<PeerConnectSuccess<R>, NetworkError>>,
1300 session_cid: u64,
1301 peer_count: usize,
1302 udp_mode: UdpMode,
1303 checks: F,
1304 ) -> Vec<PeerConnectSuccess<R>>
1305 where
1306 F: Fn(PeerConnectSuccess<R>) -> Fut + Send + Clone + 'static,
1307 Fut: Future<Output = PeerConnectSuccess<R>> + Send,
1308 {
1309 let (finished_tx, finished_rx) = tokio::sync::oneshot::channel();
1310
1311 let task = async move {
1312 let (done_tx, mut done_rx) = tokio::sync::mpsc::unbounded_channel();
1313 let mut conns = vec![];
1314 while let Some(conn) = conn_rx.recv().await {
1315 conns.push(conn);
1316 if conns.len() == peer_count - 1 {
1317 break;
1318 }
1319 }
1320
1321 log::info!(target: "citadel", "~~~*** Peer {session_cid} has {} connections to other peers ***~~~", conns.len());
1322
1323 for conn in conns {
1324 let conn = conn.expect("Error receiving peer connection");
1325 handle_peer_connect_success(
1326 conn,
1327 done_tx.clone(),
1328 session_cid,
1329 udp_mode,
1330 checks.clone(),
1331 );
1332 }
1333
1334 let mut ret = vec![];
1336 while let Some(done) = done_rx.recv().await {
1337 ret.push(done);
1338 if ret.len() == peer_count - 1 {
1339 break;
1340 }
1341 }
1342
1343 finished_tx
1344 .send(ret)
1345 .expect("Error sending finished signal in handle_peer_connect_successes");
1346 };
1347
1348 drop(tokio::task::spawn(task));
1349 let ret = finished_rx
1350 .await
1351 .expect("Error receiving finished signal in handle_peer_connect_successes");
1352
1353 assert_eq!(ret.len(), peer_count - 1);
1354 ret
1355 }
1356
1357 fn handle_peer_connect_success<F, Fut, R: Ratchet>(
1358 mut conn: PeerConnectSuccess<R>,
1359 done_tx: UnboundedSender<PeerConnectSuccess<R>>,
1360 session_cid: u64,
1361 udp_mode: UdpMode,
1362 checks: F,
1363 ) where
1364 F: Fn(PeerConnectSuccess<R>) -> Fut + Send + Clone + 'static,
1365 Fut: Future<Output = PeerConnectSuccess<R>> + Send,
1366 {
1367 let task = async move {
1368 let chan = conn.udp_channel_rx.take();
1369 crate::test_common::p2p_assertions(session_cid, &conn).await;
1370 crate::test_common::udp_mode_assertions(udp_mode, chan).await;
1371 let conn = checks(conn).await;
1372 done_tx
1373 .send(conn)
1374 .expect("Error sending done signal in handle_peer_connect_success");
1375 };
1376
1377 drop(tokio::task::spawn(task));
1378 }
1379}