1use crate::prelude::results::PeerConnectSuccess;
68use crate::prelude::*;
69use crate::test_common::wait_for_peers;
70use citadel_io::tokio::sync::mpsc::{Receiver, UnboundedSender};
71use citadel_io::{tokio, Mutex};
72use citadel_proto::re_imports::async_trait;
73use citadel_user::hypernode_account::UserIdentifierExt;
74use futures::stream::FuturesUnordered;
75use futures::TryStreamExt;
76use std::collections::HashMap;
77use std::future::Future;
78use std::marker::PhantomData;
79use std::sync::Arc;
80use uuid::Uuid;
81
/// A kernel that, once connected to the server, establishes peer-to-peer connections to every
/// peer described by a [`PeerConnectionSetupAggregator`] and hands the stream of results, along
/// with the client-server connection, to a user-supplied function.
///
/// Inbound object-transfer handles are routed to the matching peer connection; every other
/// event is delegated to `inner_kernel`.
pub struct PeerConnectionKernel<'a, F, Fut, R: Ratchet> {
    // The wrapped kernel that receives every event this kernel does not handle itself.
    inner_kernel: Box<dyn NetKernel<R> + 'a>,
    // State shared with the peer-connection setup tasks (the active peer connections map).
    shared: Shared,
    // Ties the user function type `F` and its future type `Fut` to the kernel without storing
    // values of those types; the `fn() -> ...` form keeps auto traits from depending on them.
    _pd: PhantomData<fn() -> (F, Fut)>,
}
90
/// State shared between the kernel's event handler and the peer-connection setup tasks.
///
/// Cloning is cheap: clones share the same underlying map.
#[derive(Clone)]
#[doc(hidden)]
pub struct Shared {
    // Peer connections that have been (at least early-)registered, keyed by connection type.
    // Used to forward inbound object-transfer handles to the right per-peer receiver.
    active_peer_conns: Arc<Mutex<HashMap<PeerConnectionType, PeerContext>>>,
}
96
/// Per-peer bookkeeping stored in the shared active-connections map.
struct PeerContext {
    // Not read within this file (hence the allow); retained alongside the map key.
    #[allow(dead_code)]
    conn_type: PeerConnectionType,
    // Sender half of the channel whose receiver is handed to the user as a
    // `FileTransferHandleRx`; inbound transfer handles for this peer are pushed here.
    send_file_transfer_tx: UnboundedSender<ObjectTransferHandler>,
}
102
/// Receiver of inbound object/file-transfer handles for a single peer connection.
///
/// Dereferences to the underlying unbounded receiver, so methods such as `recv()` can be
/// called on it directly.
#[derive(Debug)]
pub struct FileTransferHandleRx {
    /// Channel on which transfer handles for this connection arrive.
    pub inner: citadel_io::tokio::sync::mpsc::UnboundedReceiver<ObjectTransferHandler>,
    /// The virtual connection this receiver belongs to.
    pub conn_type: VirtualTargetType,
}
108
109impl FileTransferHandleRx {
110 pub fn accept_all(mut self) {
112 let task = tokio::task::spawn(async move {
113 let rx = &mut self.inner;
114 while let Some(mut handle) = rx.recv().await {
115 let task = tokio::task::spawn(async move {
116 if let Err(err) = handle.exhaust_stream().await {
117 let orientation = handle.orientation;
118 log::warn!(target: "citadel", "Error background handling of file transfer for {orientation:?}: {err:?}");
119 }
120 });
121
122 drop(task);
123 }
124 });
125
126 drop(task);
127 }
128}
129
130impl std::ops::Deref for FileTransferHandleRx {
131 type Target = citadel_io::tokio::sync::mpsc::UnboundedReceiver<ObjectTransferHandler>;
132
133 fn deref(&self) -> &Self::Target {
134 &self.inner
135 }
136}
137
138impl std::ops::DerefMut for FileTransferHandleRx {
139 fn deref_mut(&mut self) -> &mut Self::Target {
140 &mut self.inner
141 }
142}
143
144impl Drop for FileTransferHandleRx {
145 fn drop(&mut self) {
146 log::trace!(target: "citadel", "Dropping file transfer handle receiver {:?}", self.conn_type);
147 }
148}
149
150#[async_trait]
151impl<F, Fut, R: Ratchet> NetKernel<R> for PeerConnectionKernel<'_, F, Fut, R> {
152 fn load_remote(&mut self, server_remote: NodeRemote<R>) -> Result<(), NetworkError> {
153 self.inner_kernel.load_remote(server_remote)
154 }
155
156 async fn on_start(&self) -> Result<(), NetworkError> {
157 self.inner_kernel.on_start().await
158 }
159
160 #[allow(clippy::collapsible_else_if)]
161 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
162 match message {
163 NodeResult::ObjectTransferHandle(ObjectTransferHandle {
164 ticket: _,
165 handle,
166 session_cid,
167 }) => {
168 let is_revfs = matches!(
169 handle.metadata.transfer_type,
170 TransferType::RemoteEncryptedVirtualFilesystem { .. }
171 );
172 let active_peers = self.shared.active_peer_conns.lock();
173 let v_conn = if is_revfs {
174 let peer_cid = if session_cid != handle.source {
175 handle.source
176 } else {
177 handle.receiver
178 };
179 PeerConnectionType::LocalGroupPeer {
180 session_cid,
181 peer_cid,
182 }
183 } else {
184 if matches!(
185 handle.orientation,
186 ObjectTransferOrientation::Receiver { .. }
187 ) {
188 PeerConnectionType::LocalGroupPeer {
189 session_cid,
190 peer_cid: handle.source,
191 }
192 } else {
193 PeerConnectionType::LocalGroupPeer {
194 session_cid,
195 peer_cid: handle.receiver,
196 }
197 }
198 };
199
200 if let Some(peer_ctx) = active_peers.get(&v_conn) {
201 if let Err(err) = peer_ctx.send_file_transfer_tx.send(handle) {
202 log::warn!(target: "citadel", "Error forwarding file transfer handle: {:?}", err.to_string());
203 }
204 } else {
205 log::warn!(target: "citadel", "Unable to find key for inbound file transfer handle: {:?}\n Active Peers: {:?} \n handle_source = {}, handle_receiver = {}", v_conn, active_peers.keys().cloned().collect::<Vec<_>>(), handle.source, handle.receiver);
206 }
207
208 Ok(())
209 }
210
211 unprocessed @ NodeResult::Disconnect(..) | unprocessed => {
215 self.inner_kernel.on_node_event_received(unprocessed).await
217 }
218 }
219 }
220
221 async fn on_stop(&mut self) -> Result<(), NetworkError> {
222 self.inner_kernel.on_stop().await
223 }
224}
225
/// Builder that collects the peers a [`PeerConnectionKernel`] should connect to, together with
/// per-peer connection settings.
///
/// It can also be built from a single identifier (`Uuid`, `String`, `&str`, `u64`,
/// `UserIdentifier`) or from a `Vec<UserIdentifier>`; those peers get default settings.
#[derive(Debug, Default, Clone)]
pub struct PeerConnectionSetupAggregator {
    // One entry per added peer, in insertion order.
    inner: Vec<PeerConnectionSettings>,
}
232
/// Fully resolved connection settings for a single peer.
#[derive(Debug, Clone)]
struct PeerConnectionSettings {
    // The peer to connect to.
    id: UserIdentifier,
    // Security settings passed to `connect_to_peer_custom`.
    session_security_settings: SessionSecuritySettings,
    // UDP mode passed to `connect_to_peer_custom`.
    udp_mode: UdpMode,
    // When the peer is not already mutually registered: if true, poll (every 200ms) until the
    // peer reports as registered before calling `register_to_peer`.
    ensure_registered: bool,
    // Optional pre-shared key passed to `connect_to_peer_custom`.
    peer_session_password: Option<PreSharedKey>,
}
241
/// A peer being configured before it is appended to a [`PeerConnectionSetupAggregator`]
/// with [`AddedPeer::add`].
pub struct AddedPeer {
    // The aggregator this peer will be appended to.
    list: PeerConnectionSetupAggregator,
    id: UserIdentifier,
    // `None` falls back to `SessionSecuritySettings::default()` in `add`.
    session_security_settings: Option<SessionSecuritySettings>,
    ensure_registered: bool,
    // `None` falls back to `UdpMode::default()` in `add`.
    udp_mode: Option<UdpMode>,
    peer_session_password: Option<PreSharedKey>,
}
250
251impl AddedPeer {
252 pub fn add(mut self) -> PeerConnectionSetupAggregator {
254 let new = PeerConnectionSettings {
255 id: self.id,
256 session_security_settings: self.session_security_settings.unwrap_or_default(),
257 udp_mode: self.udp_mode.unwrap_or_default(),
258 ensure_registered: self.ensure_registered,
259 peer_session_password: self.peer_session_password,
260 };
261
262 self.list.inner.push(new);
263 self.list
264 }
265
266 pub fn with_udp_mode(mut self, udp_mode: UdpMode) -> Self {
268 self.udp_mode = Some(udp_mode);
269 self
270 }
271
272 pub fn disable_udp(self) -> Self {
274 self.with_udp_mode(UdpMode::Disabled)
275 }
276
277 pub fn enable_udp(self) -> Self {
279 self.with_udp_mode(UdpMode::Enabled)
280 }
281
282 pub fn with_session_security_settings(
284 mut self,
285 session_security_settings: SessionSecuritySettings,
286 ) -> Self {
287 self.session_security_settings = Some(session_security_settings);
288 self
289 }
290
291 pub fn ensure_registered(mut self) -> Self {
293 self.ensure_registered = true;
294 self
295 }
296
297 pub fn with_session_password<T: Into<PreSharedKey>>(mut self, password: T) -> Self {
300 self.peer_session_password = Some(password.into());
301 self
302 }
303}
304
305impl PeerConnectionSetupAggregator {
306 pub fn with_peer<T: Into<UserIdentifier>>(self, peer: T) -> PeerConnectionSetupAggregator {
315 self.with_peer_custom(peer).add()
316 }
317
318 pub fn with_peer_custom<T: Into<UserIdentifier>>(self, peer: T) -> AddedPeer {
334 AddedPeer {
335 list: self,
336 id: peer.into(),
337 ensure_registered: false,
338 session_security_settings: None,
339 udp_mode: None,
340 peer_session_password: None,
341 }
342 }
343}
344
345impl From<PeerConnectionSetupAggregator> for Vec<PeerConnectionSettings> {
346 fn from(this: PeerConnectionSetupAggregator) -> Self {
347 this.inner
348 }
349}
350
351impl From<Vec<UserIdentifier>> for PeerConnectionSetupAggregator {
352 fn from(ids: Vec<UserIdentifier>) -> Self {
353 let mut this = PeerConnectionSetupAggregator::default();
354 for peer in ids {
355 this = this.with_peer(peer);
356 }
357
358 this
359 }
360}
361
362impl From<UserIdentifier> for PeerConnectionSetupAggregator {
363 fn from(this: UserIdentifier) -> Self {
364 Self::from(vec![this])
365 }
366}
367
368impl From<Uuid> for PeerConnectionSetupAggregator {
369 fn from(user: Uuid) -> Self {
370 let user_identifier: UserIdentifier = user.into();
371 user_identifier.into()
372 }
373}
374
375impl From<String> for PeerConnectionSetupAggregator {
376 fn from(this: String) -> Self {
377 let user_identifier: UserIdentifier = this.into();
378 user_identifier.into()
379 }
380}
381
382impl From<&str> for PeerConnectionSetupAggregator {
383 fn from(this: &str) -> Self {
384 let user_identifier: UserIdentifier = this.into();
385 user_identifier.into()
386 }
387}
388
389impl From<u64> for PeerConnectionSetupAggregator {
390 fn from(this: u64) -> Self {
391 let user_identifier: UserIdentifier = this.into();
392 user_identifier.into()
393 }
394}
395
396#[async_trait]
397impl<'a, F, Fut, T: Into<PeerConnectionSetupAggregator> + Send + 'a, R: Ratchet>
398 PrefabFunctions<'a, T, R> for PeerConnectionKernel<'a, F, Fut, R>
399where
400 F: FnOnce(
401 Receiver<Result<PeerConnectSuccess<R>, NetworkError>>,
402 CitadelClientServerConnection<R>,
403 ) -> Fut
404 + Send
405 + 'a,
406 Fut: Future<Output = Result<(), NetworkError>> + Send + 'a,
407{
408 type UserLevelInputFunction = F;
409 type SharedBundle = Shared;
410
411 fn get_shared_bundle(&self) -> Self::SharedBundle {
412 self.shared.clone()
413 }
414
415 #[allow(clippy::blocks_in_conditions)]
416 #[cfg_attr(
417 feature = "localhost-testing",
418 tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug))
419 )]
420 async fn on_c2s_channel_received(
421 connect_success: CitadelClientServerConnection<R>,
422 peers_to_connect: T,
423 f: Self::UserLevelInputFunction,
424 shared: Shared,
425 ) -> Result<(), NetworkError> {
426 let shared = &shared;
427 let session_cid = connect_success.cid;
428 let mut peers_already_registered = vec![];
429
430 wait_for_peers().await;
431 let peers_to_connect = peers_to_connect.into().inner;
432
433 for peer in &peers_to_connect {
434 peers_already_registered.push(
436 peer.id
437 .search_peer(session_cid, connect_success.account_manager())
438 .await?,
439 )
440 }
441
442 let remote = connect_success.clone();
443 let (tx, rx) = citadel_io::tokio::sync::mpsc::channel(peers_to_connect.len());
444 let requests = FuturesUnordered::new();
445
446 for (mutually_registered, peer_to_connect) in
447 peers_already_registered.into_iter().zip(peers_to_connect)
448 {
449 let remote = remote.clone();
452 let tx = tx.clone();
453 let PeerConnectionSettings {
454 id,
455 session_security_settings,
456 udp_mode,
457 ensure_registered,
458 peer_session_password,
459 } = peer_to_connect;
460
461 let task = async move {
462 let inner_task = async move {
463 let (file_transfer_tx, file_transfer_rx) =
464 citadel_io::tokio::sync::mpsc::unbounded_channel();
465
466 let peer_cid = if let Some(mutual_peer) = &mutually_registered {
468 mutual_peer.cid
469 } else {
470 id.get_cid()
471 };
472
473 let handle = if let Some(_already_registered) = mutually_registered {
474 remote.find_target(session_cid, id).await?
475 } else {
476 log::info!(target: "citadel", "{session_cid} proposing target {id:?} to central node");
478 let handle = remote.propose_target(session_cid, id.clone()).await?;
479 if ensure_registered {
482 loop {
483 if handle.is_peer_registered().await? {
484 break;
485 }
486 citadel_io::tokio::time::sleep(std::time::Duration::from_millis(
487 200,
488 ))
489 .await;
490 }
491 }
492
493 log::info!(target: "citadel", "{session_cid} registering to peer {id:?}");
494 let _reg_success = handle.register_to_peer().await?;
495 log::info!(target: "citadel", "{session_cid} registered to peer {id:?} registered || success -> now connecting");
496 handle
497 };
498
499 let peer_conn = PeerConnectionType::LocalGroupPeer {
502 session_cid,
503 peer_cid,
504 };
505 let peer_context = PeerContext {
506 conn_type: peer_conn,
507 send_file_transfer_tx: file_transfer_tx.clone(),
508 };
509 log::debug!(target: "citadel", "Early registering peer connection: {peer_conn:?}");
510 let _ = shared
511 .active_peer_conns
512 .lock()
513 .insert(peer_conn, peer_context);
514
515 handle
516 .connect_to_peer_custom(
517 session_security_settings,
518 udp_mode,
519 peer_session_password,
520 )
521 .await
522 .map(|mut success| {
523 let actual_peer_conn = success.channel.get_peer_conn_type().unwrap();
524
525 if actual_peer_conn != peer_conn {
528 log::debug!(target: "citadel", "Updating peer connection registration from {peer_conn:?} to {actual_peer_conn:?}");
529 let mut active_peers = shared.active_peer_conns.lock();
530 if let Some(peer_ctx) = active_peers.remove(&peer_conn) {
531 let _ = active_peers.insert(actual_peer_conn, peer_ctx);
532 }
533 }
534 success.incoming_object_transfer_handles = Some(FileTransferHandleRx {
536 inner: file_transfer_rx,
537 conn_type: actual_peer_conn.as_virtual_connection(),
538 });
539 success
540 })
541 .inspect_err(|_err| {
542 let _ = shared.active_peer_conns.lock().remove(&peer_conn);
544 })
545 };
546
547 tx.send(inner_task.await)
548 .await
549 .map_err(|err| NetworkError::Generic(err.to_string()))
550 };
551
552 requests.push(Box::pin(task))
553 }
554
555 drop(tx);
558
559 let (collection_result, user_result) =
562 citadel_io::tokio::join!(requests.try_collect::<()>(), f(rx, connect_success));
563
564 collection_result?;
566 user_result
567 }
568
569 fn construct(kernel: Box<dyn NetKernel<R> + 'a>) -> Self {
570 Self {
571 inner_kernel: kernel,
572 shared: Shared {
573 active_peer_conns: Arc::new(Mutex::new(Default::default())),
574 },
575 _pd: Default::default(),
576 }
577 }
578}
579
580#[cfg(all(test, feature = "localhost-testing"))]
581mod tests {
582 use crate::prefabs::client::peer_connection::PeerConnectionKernel;
583 use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
584 use crate::prelude::*;
585 use crate::remote_ext::results::PeerConnectSuccess;
586 use crate::test_common::{server_info, wait_for_peers, TestBarrier};
587 use citadel_io::tokio;
588 use citadel_io::tokio::sync::mpsc::{Receiver, UnboundedSender};
589 use citadel_user::prelude::UserIdentifierExt;
590 use futures::stream::FuturesUnordered;
591 use futures::TryStreamExt;
592 use rstest::rstest;
593 use std::collections::HashMap;
594 use std::future::Future;
595 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
596 use std::time::Duration;
597 use uuid::Uuid;
598
    lazy_static::lazy_static! {
        // Fixed (username, password, full_name) triples for up to seven credentialed test
        // peers, derived from the base names below (e.g. "alpha.username").
        pub static ref PEERS: Vec<(String, String, String)> = {
            ["alpha", "beta", "charlie", "echo", "delta", "epsilon", "foxtrot"]
                .iter().map(|base| (format!("{base}.username"), format!("{base}.password"), format!("{base}.full_name")))
                .collect()
        };
    }
606
    /// Registers `peer_count` credentialed (username/password) clients with a local test
    /// server. Every client connects to every other client with `ensure_registered` and the
    /// given `udp_mode`.
    ///
    /// Each client then checks that every requested peer appears in `get_peers`, and that every
    /// P2P-connected peer appears among its local-group mutual peers.
    #[rstest]
    #[case(2, UdpMode::Enabled)]
    #[case(3, UdpMode::Disabled)]
    #[timeout(Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_to_peer_connect(#[case] peer_count: usize, #[case] udp_mode: UdpMode) {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        // Counts clients whose user function ran to completion.
        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|idx| PEERS.get(idx).unwrap().0.clone())
            .collect::<Vec<String>>();

        for idx in 0..peer_count {
            let (username, password, full_name) = PEERS.get(idx).unwrap();
            // Every other test user, addressed by username.
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != username)
                .map(UserIdentifier::Username)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::credentialed_registration(
                    server_addr,
                    username,
                    full_name,
                    password.as_str(),
                )
                .build()
                .unwrap();

            let username = username.clone();

            // `agg` is cloned for the kernel; the original moves into the closure for later checks.
            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg.clone(),
                move |results, connection| async move {
                    log::info!(target: "citadel", "***PEER {username} CONNECTED ***");
                    let session_cid = connection.conn_type.get_session_cid();
                    // Per-connection check: querying mutual peers over the P2P remote must succeed.
                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        let session_cid = conn.channel.get_session_cid();
                        let _mutual_peers = conn
                            .remote
                            .remote()
                            .get_local_group_mutual_peers(session_cid)
                            .await
                            .unwrap();
                        conn
                    };
                    // `handle_peer_connect_successes` is defined elsewhere in this module;
                    // presumably it drains `results`, applies `check`, and returns the connections.
                    let p2p_remotes = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await
                    .into_iter()
                    .map(|r| (r.channel.get_peer_cid(), r.remote))
                    .collect::<HashMap<_, _>>();

                    // Every requested peer must be visible on the network.
                    let network_peers = connection.get_peers(None).await.unwrap();
                    for user in agg.inner {
                        let peer_cid = user.id.get_cid();
                        assert!(network_peers.iter().any(|r| r.cid == peer_cid))
                    }

                    // Every connected peer must be a mutual peer of this session.
                    let session_cid = connection.conn_type.get_session_cid();
                    let mutual_peers = connection
                        .get_local_group_mutual_peers(session_cid)
                        .await
                        .unwrap();
                    for (peer_cid, _) in p2p_remotes {
                        assert!(mutual_peers.iter().any(|r| r.cid == peer_cid))
                    }

                    log::info!(target: "citadel", "***PEER {username} finished all checks***");
                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    // Keep this client alive until every peer has finished its checks.
                    wait_for_peers().await;
                    connection.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        // Race the server against the clients; the clients are expected to finish cleanly.
        assert!(futures::future::try_select(server, clients).await.is_ok());

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
    }
721
    /// Connects `peer_count` transient (UUID-identified) clients to each other with UDP enabled
    /// and the given header-obfuscator settings.
    ///
    /// With exactly two peers, each connection is also deregistered. The test then asserts the
    /// peer is no longer present in the local persistence handler.
    #[rstest]
    #[case(2, HeaderObfuscatorSettings::default())]
    #[case(2, HeaderObfuscatorSettings::Enabled)]
    #[case(2, HeaderObfuscatorSettings::EnabledWithKey(12345))]
    #[case(3, HeaderObfuscatorSettings::default())]
    #[timeout(Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_to_peer_connect_transient(
        #[case] peer_count: usize,
        #[case] header_obfuscator_settings: HeaderObfuscatorSettings,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        // Deregistration is only exercised in the two-peer cases.
        let do_deregister = peer_count == 2;

        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            // Every other peer's UUID.
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                let security_settings = SessionSecuritySettings {
                    header_obfuscator_settings,
                    ..Default::default()
                };
                agg = agg
                    .with_peer_custom(peer)
                    .with_udp_mode(udp_mode)
                    .ensure_registered()
                    .with_session_security_settings(security_settings)
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    let session_cid = remote.conn_type.get_session_cid();

                    // Per-connection check: optionally deregister and verify the peer is gone.
                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        if do_deregister {
                            conn.remote
                                .deregister()
                                .await
                                .expect("Deregistration failed");
                            assert!(!conn
                                .remote
                                .inner
                                .account_manager()
                                .get_persistence_handler()
                                .hyperlan_peer_exists(session_cid, conn.channel.get_peer_cid())
                                .await
                                .unwrap());
                        }
                        conn
                    };

                    let _ = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;

                    log::info!(target: "citadel", "***PEER {uuid} finished all checks***");
                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    wait_for_peers().await;
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel)?;
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        // Convert whichever side failed first into the test's error.
        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }
834
    /// Peer 0 (the sender) connects to every other peer, and each other peer connects only to
    /// the sender. The sender sends `TheBridge.pdf` over each connection. Each receiver accepts
    /// the first inbound transfer handle, streams it to completion, and compares the written
    /// file byte-for-byte with the original.
    #[rstest]
    #[case(2)]
    #[case(3)]
    #[timeout(std::time::Duration::from_secs(180))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_peer_to_peer_file_transfer(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        let sender_success = &AtomicBool::new(false);
        let receiver_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        let sender_uuid = total_peers[0];

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            let mut peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();
            // Receivers connect only to the sender, forming a star around peer 0.
            if idx != 0 {
                peers = vec![sender_uuid.into()];
            }

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    wait_for_peers().await;
                    let session_cid = remote.conn_type.get_session_cid();
                    let is_sender = idx == 0;
                    let check = move |mut conn: PeerConnectSuccess<_>| async move {
                        if is_sender {
                            // 32 * 1024 is presumably the transfer chunk size — TODO confirm.
                            conn.remote
                                .send_file_with_custom_opts(
                                    "../resources/TheBridge.pdf",
                                    32 * 1024,
                                    TransferType::FileTransfer,
                                )
                                .await
                                .expect("Failed to send file");
                        } else {
                            // The receiver taken here is a temporary dropped at the end of this
                            // statement, so only the first inbound transfer handle is observed.
                            let mut handle = conn
                                .incoming_object_transfer_handles
                                .take()
                                .unwrap()
                                .recv()
                                .await
                                .unwrap();
                            handle.accept().unwrap();

                            use citadel_types::proto::ObjectTransferStatus;
                            use futures::StreamExt;
                            // Set when reception begins; read once reception completes.
                            let mut path = None;
                            while let Some(status) = handle.next().await {
                                match status {
                                    ObjectTransferStatus::ReceptionComplete => {
                                        let cmp =
                                            include_bytes!("../../../../resources/TheBridge.pdf");
                                        let streamed_data =
                                            citadel_io::tokio::fs::read(path.clone().unwrap())
                                                .await
                                                .unwrap();
                                        assert_eq!(
                                            cmp,
                                            streamed_data.as_slice(),
                                            "Original data and streamed data does not match"
                                        );

                                        log::info!(target: "citadel", "Peer has finished receiving and verifying the file!");
                                        break;
                                    }

                                    ObjectTransferStatus::ReceptionBeginning(file_path, vfm) => {
                                        path = Some(file_path);
                                        assert_eq!(vfm.name, "TheBridge.pdf")
                                    }

                                    _ => {}
                                }
                            }
                        }

                        conn
                    };
                    // Receivers connect to exactly one peer (the sender); the value 2 is what
                    // `handle_peer_connect_successes` receives for that case (that helper is
                    // defined elsewhere — presumably it expects `peer_count - 1` results).
                    let peer_count = if idx == 0 { peer_count } else { 2 };
                    let _ = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;

                    if is_sender {
                        sender_success.store(true, Ordering::Relaxed);
                    } else {
                        receiver_success.store(true, Ordering::Relaxed);
                    }

                    log::info!(target: "citadel", "***PEER {uuid} (is_sender: {is_sender}) finished all checks***");
                    wait_for_peers().await;
                    log::info!(target: "citadel", "***PEER {uuid} (is_sender: {is_sender}) shutting down***");
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert!(sender_success.load(Ordering::Relaxed));
        assert!(receiver_success.load(Ordering::Relaxed));
        Ok(())
    }
998
    /// Connects transient peers to each other. On each of its P2P connections, peer 0 calls
    /// `rekey` nine times and expects the return values `Some(1)` through `Some(9)` in order.
    #[rstest]
    #[case(2)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_peer_to_peer_rekey(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    let session_cid = remote.conn_type.get_session_cid();

                    // Only peer 0 initiates rekeys; the other peers just accept the connection.
                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        if idx == 0 {
                            for x in 1..10 {
                                assert_eq!(
                                    conn.remote.rekey().await.expect("Failed to rekey"),
                                    Some(x)
                                );
                            }
                        }

                        conn
                    };

                    let results = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;

                    log::info!(target: "citadel", "***PEER {uuid} finished all check (count: {})s***", results.len());
                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    wait_for_peers().await;
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel)?;
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }
1096
    /// Connects transient peers to each other, then has every peer call `disconnect` on each
    /// of its P2P connections. The disconnect must succeed.
    #[rstest]
    #[case(2)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_peer_to_peer_disconnect(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    wait_for_peers().await;
                    let session_cid = remote.conn_type.get_session_cid();

                    // Per-connection check: tear the P2P connection down.
                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        conn.remote
                            .disconnect()
                            .await
                            .expect("Failed to p2p disconnect");
                        conn
                    };
                    let _ = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;
                    log::info!(target: "citadel", "***PEER {uuid} finished all checks***");

                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    wait_for_peers().await;
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel)?;
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }
1189
1190 #[rstest]
1191 #[case(SecrecyMode::BestEffort, Some("test-p2p-password"))]
1192 #[timeout(std::time::Duration::from_secs(240))]
1193 #[citadel_io::tokio::test(flavor = "multi_thread")]
1194 async fn test_p2p_wrong_session_password(
1195 #[case] secrecy_mode: SecrecyMode,
1196 #[case] p2p_password: Option<&'static str>,
1197 #[values(KemAlgorithm::Kyber)] kem: KemAlgorithm,
1198 #[values(EncryptionAlgorithm::AES_GCM_256)] enx: EncryptionAlgorithm,
1199 ) {
1200 citadel_logging::setup_log_no_panic_hook();
1201 TestBarrier::setup(2);
1202 let (server, server_addr) = server_info::<StackedRatchet>();
1203 let peer_0_error_received = &AtomicBool::new(false);
1204 let peer_1_error_received = &AtomicBool::new(false);
1205
1206 let uuid0 = Uuid::new_v4();
1207 let uuid1 = Uuid::new_v4();
1208 let session_security = SessionSecuritySettingsBuilder::default()
1209 .with_secrecy_mode(secrecy_mode)
1210 .with_crypto_params(kem + enx)
1211 .build()
1212 .unwrap();
1213
1214 let mut peer0_agg = PeerConnectionSetupAggregator::default()
1215 .with_peer_custom(uuid1)
1216 .ensure_registered()
1217 .with_session_security_settings(session_security);
1218
1219 if let Some(password) = p2p_password {
1220 peer0_agg = peer0_agg.with_session_password(password);
1221 }
1222
1223 let peer0_connection = peer0_agg.add();
1224
1225 let mut peer1_agg = PeerConnectionSetupAggregator::default()
1226 .with_peer_custom(uuid0)
1227 .ensure_registered()
1228 .with_session_security_settings(session_security);
1229
1230 if let Some(_password) = p2p_password {
1231 peer1_agg = peer1_agg.with_session_password("wrong password");
1232 }
1233
1234 let peer1_connection = peer1_agg.add();
1235
1236 let server_connection_settings0 =
1237 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid0)
1238 .with_udp_mode(UdpMode::Enabled)
1239 .with_session_security_settings(session_security)
1240 .build()
1241 .unwrap();
1242
1243 let server_connection_settings1 =
1244 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid1)
1245 .with_udp_mode(UdpMode::Enabled)
1246 .with_session_security_settings(session_security)
1247 .build()
1248 .unwrap();
1249
1250 let client_kernel0 = PeerConnectionKernel::new(
1251 server_connection_settings0,
1252 peer0_connection,
1253 move |mut connection, remote| async move {
1254 wait_for_peers().await;
1255 let conn = connection.recv().await.unwrap();
1256 log::trace!(target: "citadel", "Peer 0 {} received: {:?}", remote.conn_type.get_session_cid(), conn);
1257 if conn.is_ok() {
1258 peer_0_error_received.store(true, Ordering::SeqCst);
1259 }
1260 wait_for_peers().await;
1261 remote.shutdown_kernel().await
1262 },
1263 );
1264
1265 let client_kernel1 = PeerConnectionKernel::new(
1266 server_connection_settings1,
1267 peer1_connection,
1268 move |mut connection, remote| async move {
1269 wait_for_peers().await;
1270 let conn = connection.recv().await.unwrap();
1271 log::trace!(target: "citadel", "Peer 1 {} received: {:?}", remote.conn_type.get_session_cid(), conn);
1272 if conn.is_ok() {
1273 peer_1_error_received.store(true, Ordering::SeqCst);
1274 }
1275 wait_for_peers().await;
1276 remote.shutdown_kernel().await
1277 },
1278 );
1279
1280 let client0 = DefaultNodeBuilder::default().build(client_kernel0).unwrap();
1281 let client1 = DefaultNodeBuilder::default().build(client_kernel1).unwrap();
1282 let clients = futures::future::try_join(client0, client1);
1283
1284 let task = async move {
1285 tokio::select! {
1286 server_res = server => Err(NetworkError::msg(format!("Server ended prematurely: {:?}", server_res.map(|_| ())))),
1287 client_res = clients => client_res.map(|_| ())
1288 }
1289 };
1290
1291 tokio::time::timeout(Duration::from_secs(120), task)
1292 .await
1293 .unwrap()
1294 .unwrap();
1295
1296 assert!(!peer_0_error_received.load(Ordering::SeqCst));
1297 assert!(!peer_1_error_received.load(Ordering::SeqCst));
1298 }
1299
1300 async fn handle_peer_connect_successes<F, Fut, R: Ratchet>(
1301 mut conn_rx: Receiver<Result<PeerConnectSuccess<R>, NetworkError>>,
1302 session_cid: u64,
1303 peer_count: usize,
1304 udp_mode: UdpMode,
1305 checks: F,
1306 ) -> Vec<PeerConnectSuccess<R>>
1307 where
1308 F: Fn(PeerConnectSuccess<R>) -> Fut + Send + Clone + 'static,
1309 Fut: Future<Output = PeerConnectSuccess<R>> + Send,
1310 {
1311 let (finished_tx, finished_rx) = tokio::sync::oneshot::channel();
1312
1313 let task = async move {
1314 let (done_tx, mut done_rx) = tokio::sync::mpsc::unbounded_channel();
1315 let mut conns = vec![];
1316 while let Some(conn) = conn_rx.recv().await {
1317 conns.push(conn);
1318 if conns.len() == peer_count - 1 {
1319 break;
1320 }
1321 }
1322
1323 log::info!(target: "citadel", "~~~*** Peer {session_cid} has {} connections to other peers ***~~~", conns.len());
1324
1325 for conn in conns {
1326 let conn = conn.expect("Error receiving peer connection");
1327 handle_peer_connect_success(
1328 conn,
1329 done_tx.clone(),
1330 session_cid,
1331 udp_mode,
1332 checks.clone(),
1333 );
1334 }
1335
1336 let mut ret = vec![];
1338 while let Some(done) = done_rx.recv().await {
1339 ret.push(done);
1340 if ret.len() == peer_count - 1 {
1341 break;
1342 }
1343 }
1344
1345 finished_tx
1346 .send(ret)
1347 .expect("Error sending finished signal in handle_peer_connect_successes");
1348 };
1349
1350 drop(tokio::task::spawn(task));
1351 let ret = finished_rx
1352 .await
1353 .expect("Error receiving finished signal in handle_peer_connect_successes");
1354
1355 assert_eq!(ret.len(), peer_count - 1);
1356 ret
1357 }
1358
1359 fn handle_peer_connect_success<F, Fut, R: Ratchet>(
1360 mut conn: PeerConnectSuccess<R>,
1361 done_tx: UnboundedSender<PeerConnectSuccess<R>>,
1362 session_cid: u64,
1363 udp_mode: UdpMode,
1364 checks: F,
1365 ) where
1366 F: Fn(PeerConnectSuccess<R>) -> Fut + Send + Clone + 'static,
1367 Fut: Future<Output = PeerConnectSuccess<R>> + Send,
1368 {
1369 let task = async move {
1370 let chan = conn.udp_channel_rx.take();
1371 crate::test_common::p2p_assertions(session_cid, &conn).await;
1372 crate::test_common::udp_mode_assertions(udp_mode, chan).await;
1373 let conn = checks(conn).await;
1374 done_tx
1375 .send(conn)
1376 .expect("Error sending done signal in handle_peer_connect_success");
1377 };
1378
1379 drop(tokio::task::spawn(task));
1380 }
1381}