1use crate::prelude::results::PeerConnectSuccess;
68use crate::prelude::*;
69use crate::test_common::wait_for_peers;
70use citadel_io::tokio::sync::mpsc::{Receiver, UnboundedSender};
71use citadel_io::{tokio, Mutex};
72use citadel_proto::re_imports::async_trait;
73use citadel_user::hypernode_account::UserIdentifierExt;
74use futures::stream::FuturesUnordered;
75use futures::TryStreamExt;
76use std::collections::HashMap;
77use std::future::Future;
78use std::marker::PhantomData;
79use std::sync::Arc;
80use uuid::Uuid;
81
/// Prefab kernel that, once the client/server connection is up, connects to a set of
/// peers and routes inbound object-transfer handles to the matching peer connection.
///
/// Lifecycle calls and any node event this kernel does not intercept are delegated to
/// `inner_kernel`.
pub struct PeerConnectionKernel<'a, F, Fut, R: Ratchet> {
    /// Wrapped kernel that receives `load_remote`, `on_start`, `on_stop`, and every node
    /// event not handled directly by this kernel.
    inner_kernel: Box<dyn NetKernel<R> + 'a>,
    /// Peer routing table; a clone is handed out through `get_shared_bundle`.
    shared: Shared,
    /// Records the user callback type `F` and its future type `Fut` without storing them.
    /// The `fn() -> _` form keeps this marker from affecting the kernel's auto traits.
    _pd: PhantomData<fn() -> (F, Fut)>,
}
90
/// State shared between the kernel's event handler and the peer-connection routine.
#[derive(Clone)]
#[doc(hidden)]
pub struct Shared {
    /// Maps each peer connection to the context used to forward inbound object-transfer
    /// handles. Entries are inserted while a peer connection is being set up, re-keyed if
    /// the established connection type differs, and removed on connection failure or on a
    /// peer disconnect event. Because of the `Arc`, every clone shares the same table.
    active_peer_conns: Arc<Mutex<HashMap<PeerConnectionType, PeerContext>>>,
}
96
/// Per-peer routing data stored in the shared peer table.
struct PeerContext {
    /// Connection type the entry was first registered under. It is not updated if the
    /// entry is later re-keyed, and it is currently never read, hence the `dead_code`
    /// allowance.
    #[allow(dead_code)]
    conn_type: PeerConnectionType,
    /// Sender paired with the receiver placed in the peer's `FileTransferHandleRx`.
    /// Inbound transfer handles for this peer are pushed here.
    send_file_transfer_tx: UnboundedSender<ObjectTransferHandler>,
}
102
/// Receiver for inbound object-transfer handles belonging to one peer connection.
#[derive(Debug)]
pub struct FileTransferHandleRx {
    /// Channel fed by the kernel's event handler with handles routed to this peer.
    pub inner: citadel_io::tokio::sync::mpsc::UnboundedReceiver<ObjectTransferHandler>,
    /// Connection this receiver serves; included in the trace log emitted on drop.
    pub conn_type: VirtualTargetType,
}
108
109impl FileTransferHandleRx {
110 pub fn accept_all(mut self) {
112 let task = tokio::task::spawn(async move {
113 let rx = &mut self.inner;
114 while let Some(mut handle) = rx.recv().await {
115 let task = tokio::task::spawn(async move {
116 if let Err(err) = handle.exhaust_stream().await {
117 let orientation = handle.orientation;
118 log::warn!(target: "citadel", "Error background handling of file transfer for {orientation:?}: {err:?}");
119 }
120 });
121
122 drop(task);
123 }
124 });
125
126 drop(task);
127 }
128}
129
130impl std::ops::Deref for FileTransferHandleRx {
131 type Target = citadel_io::tokio::sync::mpsc::UnboundedReceiver<ObjectTransferHandler>;
132
133 fn deref(&self) -> &Self::Target {
134 &self.inner
135 }
136}
137
138impl std::ops::DerefMut for FileTransferHandleRx {
139 fn deref_mut(&mut self) -> &mut Self::Target {
140 &mut self.inner
141 }
142}
143
144impl Drop for FileTransferHandleRx {
145 fn drop(&mut self) {
146 log::trace!(target: "citadel", "Dropping file transfer handle receiver {:?}", self.conn_type);
147 }
148}
149
150#[async_trait]
151impl<F, Fut, R: Ratchet> NetKernel<R> for PeerConnectionKernel<'_, F, Fut, R> {
152 fn load_remote(&mut self, server_remote: NodeRemote<R>) -> Result<(), NetworkError> {
153 self.inner_kernel.load_remote(server_remote)
154 }
155
156 async fn on_start(&self) -> Result<(), NetworkError> {
157 self.inner_kernel.on_start().await
158 }
159
160 #[allow(clippy::collapsible_else_if)]
161 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
162 match message {
163 NodeResult::ObjectTransferHandle(ObjectTransferHandle {
164 ticket: _,
165 handle,
166 session_cid,
167 }) => {
168 let is_revfs = matches!(
169 handle.metadata.transfer_type,
170 TransferType::RemoteEncryptedVirtualFilesystem { .. }
171 );
172 let active_peers = self.shared.active_peer_conns.lock();
173 let v_conn = if is_revfs {
174 let peer_cid = if session_cid != handle.source {
175 handle.source
176 } else {
177 handle.receiver
178 };
179 PeerConnectionType::LocalGroupPeer {
180 session_cid,
181 peer_cid,
182 }
183 } else {
184 if matches!(
185 handle.orientation,
186 ObjectTransferOrientation::Receiver { .. }
187 ) {
188 PeerConnectionType::LocalGroupPeer {
189 session_cid,
190 peer_cid: handle.source,
191 }
192 } else {
193 PeerConnectionType::LocalGroupPeer {
194 session_cid,
195 peer_cid: handle.receiver,
196 }
197 }
198 };
199
200 if let Some(peer_ctx) = active_peers.get(&v_conn) {
201 if let Err(err) = peer_ctx.send_file_transfer_tx.send(handle) {
202 log::warn!(target: "citadel", "Error forwarding file transfer handle: {:?}", err.to_string());
203 }
204 } else {
205 log::warn!(target: "citadel", "Unable to find key for inbound file transfer handle: {:?}\n Active Peers: {:?} \n handle_source = {}, handle_receiver = {}", v_conn, active_peers.keys().cloned().collect::<Vec<_>>(), handle.source, handle.receiver);
206 }
207
208 Ok(())
209 }
210
211 NodeResult::Disconnect(Disconnect {
212 ticket: _,
213 cid_opt: _,
214 success: _,
215 v_conn_type: Some(v_conn),
216 ..
217 }) => {
218 if let Some(v_conn) = v_conn.try_as_peer_connection() {
219 let mut active_peers = self.shared.active_peer_conns.lock();
220 let _ = active_peers.remove(&v_conn);
221 }
222
223 Ok(())
224 }
225
226 unprocessed => {
227 self.inner_kernel.on_node_event_received(unprocessed).await
229 }
230 }
231 }
232
233 async fn on_stop(&mut self) -> Result<(), NetworkError> {
234 self.inner_kernel.on_stop().await
235 }
236}
237
/// Builder that collects the peers, with their per-peer settings, to connect to.
///
/// Populate it with `with_peer` / `with_peer_custom`, or convert from a `UserIdentifier`,
/// `Vec<UserIdentifier>`, `Uuid`, `String`, `&str` or `u64`. The default value is empty.
#[derive(Debug, Default, Clone)]
pub struct PeerConnectionSetupAggregator {
    inner: Vec<PeerConnectionSettings>,
}
244
/// Fully resolved settings for connecting to one peer.
#[derive(Debug, Clone)]
struct PeerConnectionSettings {
    /// Peer to connect to.
    id: UserIdentifier,
    /// Security settings passed to `connect_to_peer_custom`.
    session_security_settings: SessionSecuritySettings,
    /// UDP mode passed to `connect_to_peer_custom`.
    udp_mode: UdpMode,
    /// If the peer is not already registered, poll until the peer side reports
    /// registration before calling `register_to_peer`.
    ensure_registered: bool,
    /// Optional pre-shared key passed to `connect_to_peer_custom`.
    peer_session_password: Option<PreSharedKey>,
}
253
/// A peer being configured before it is appended to a `PeerConnectionSetupAggregator`.
///
/// Call `add()` to finalize it and get the aggregator back.
pub struct AddedPeer {
    /// Aggregator this peer will be appended to.
    list: PeerConnectionSetupAggregator,
    /// Peer to connect to.
    id: UserIdentifier,
    /// `None` falls back to `SessionSecuritySettings::default()` in `add()`.
    session_security_settings: Option<SessionSecuritySettings>,
    /// See `PeerConnectionSettings::ensure_registered`.
    ensure_registered: bool,
    /// `None` falls back to `UdpMode::default()` in `add()`.
    udp_mode: Option<UdpMode>,
    /// Optional pre-shared key for the peer session.
    peer_session_password: Option<PreSharedKey>,
}
262
263impl AddedPeer {
264 pub fn add(mut self) -> PeerConnectionSetupAggregator {
266 let new = PeerConnectionSettings {
267 id: self.id,
268 session_security_settings: self.session_security_settings.unwrap_or_default(),
269 udp_mode: self.udp_mode.unwrap_or_default(),
270 ensure_registered: self.ensure_registered,
271 peer_session_password: self.peer_session_password,
272 };
273
274 self.list.inner.push(new);
275 self.list
276 }
277
278 pub fn with_udp_mode(mut self, udp_mode: UdpMode) -> Self {
280 self.udp_mode = Some(udp_mode);
281 self
282 }
283
284 pub fn disable_udp(self) -> Self {
286 self.with_udp_mode(UdpMode::Disabled)
287 }
288
289 pub fn enable_udp(self) -> Self {
291 self.with_udp_mode(UdpMode::Enabled)
292 }
293
294 pub fn with_session_security_settings(
296 mut self,
297 session_security_settings: SessionSecuritySettings,
298 ) -> Self {
299 self.session_security_settings = Some(session_security_settings);
300 self
301 }
302
303 pub fn ensure_registered(mut self) -> Self {
305 self.ensure_registered = true;
306 self
307 }
308
309 pub fn with_session_password<T: Into<PreSharedKey>>(mut self, password: T) -> Self {
312 self.peer_session_password = Some(password.into());
313 self
314 }
315}
316
317impl PeerConnectionSetupAggregator {
318 pub fn with_peer<T: Into<UserIdentifier>>(self, peer: T) -> PeerConnectionSetupAggregator {
327 self.with_peer_custom(peer).add()
328 }
329
330 pub fn with_peer_custom<T: Into<UserIdentifier>>(self, peer: T) -> AddedPeer {
346 AddedPeer {
347 list: self,
348 id: peer.into(),
349 ensure_registered: false,
350 session_security_settings: None,
351 udp_mode: None,
352 peer_session_password: None,
353 }
354 }
355}
356
357impl From<PeerConnectionSetupAggregator> for Vec<PeerConnectionSettings> {
358 fn from(this: PeerConnectionSetupAggregator) -> Self {
359 this.inner
360 }
361}
362
363impl From<Vec<UserIdentifier>> for PeerConnectionSetupAggregator {
364 fn from(ids: Vec<UserIdentifier>) -> Self {
365 let mut this = PeerConnectionSetupAggregator::default();
366 for peer in ids {
367 this = this.with_peer(peer);
368 }
369
370 this
371 }
372}
373
374impl From<UserIdentifier> for PeerConnectionSetupAggregator {
375 fn from(this: UserIdentifier) -> Self {
376 Self::from(vec![this])
377 }
378}
379
380impl From<Uuid> for PeerConnectionSetupAggregator {
381 fn from(user: Uuid) -> Self {
382 let user_identifier: UserIdentifier = user.into();
383 user_identifier.into()
384 }
385}
386
387impl From<String> for PeerConnectionSetupAggregator {
388 fn from(this: String) -> Self {
389 let user_identifier: UserIdentifier = this.into();
390 user_identifier.into()
391 }
392}
393
394impl From<&str> for PeerConnectionSetupAggregator {
395 fn from(this: &str) -> Self {
396 let user_identifier: UserIdentifier = this.into();
397 user_identifier.into()
398 }
399}
400
401impl From<u64> for PeerConnectionSetupAggregator {
402 fn from(this: u64) -> Self {
403 let user_identifier: UserIdentifier = this.into();
404 user_identifier.into()
405 }
406}
407
408#[async_trait]
409impl<'a, F, Fut, T: Into<PeerConnectionSetupAggregator> + Send + 'a, R: Ratchet>
410 PrefabFunctions<'a, T, R> for PeerConnectionKernel<'a, F, Fut, R>
411where
412 F: FnOnce(
413 Receiver<Result<PeerConnectSuccess<R>, NetworkError>>,
414 CitadelClientServerConnection<R>,
415 ) -> Fut
416 + Send
417 + 'a,
418 Fut: Future<Output = Result<(), NetworkError>> + Send + 'a,
419{
420 type UserLevelInputFunction = F;
421 type SharedBundle = Shared;
422
423 fn get_shared_bundle(&self) -> Self::SharedBundle {
424 self.shared.clone()
425 }
426
427 #[allow(clippy::blocks_in_conditions)]
428 #[cfg_attr(
429 feature = "localhost-testing",
430 tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug))
431 )]
432 async fn on_c2s_channel_received(
433 connect_success: CitadelClientServerConnection<R>,
434 peers_to_connect: T,
435 f: Self::UserLevelInputFunction,
436 shared: Shared,
437 ) -> Result<(), NetworkError> {
438 let shared = &shared;
439 let session_cid = connect_success.cid;
440 let mut peers_already_registered = vec![];
441
442 wait_for_peers().await;
443 let peers_to_connect = peers_to_connect.into().inner;
444
445 for peer in &peers_to_connect {
446 peers_already_registered.push(
448 peer.id
449 .search_peer(session_cid, connect_success.account_manager())
450 .await?,
451 )
452 }
453
454 let remote = connect_success.clone();
455 let (ref tx, rx) = citadel_io::tokio::sync::mpsc::channel(peers_to_connect.len());
456 let requests = FuturesUnordered::new();
457
458 for (mutually_registered, peer_to_connect) in
459 peers_already_registered.into_iter().zip(peers_to_connect)
460 {
461 let remote = remote.clone();
464 let PeerConnectionSettings {
465 id,
466 session_security_settings,
467 udp_mode,
468 ensure_registered,
469 peer_session_password,
470 } = peer_to_connect;
471
472 let task = async move {
473 let inner_task = async move {
474 let (file_transfer_tx, file_transfer_rx) =
475 citadel_io::tokio::sync::mpsc::unbounded_channel();
476
477 let peer_cid = if let Some(mutual_peer) = &mutually_registered {
479 mutual_peer.cid
480 } else {
481 id.get_cid()
482 };
483
484 let handle = if let Some(_already_registered) = mutually_registered {
485 remote.find_target(session_cid, id).await?
486 } else {
487 log::info!(target: "citadel", "{session_cid} proposing target {id:?} to central node");
489 let handle = remote.propose_target(session_cid, id.clone()).await?;
490 if ensure_registered {
493 loop {
494 if handle.is_peer_registered().await? {
495 break;
496 }
497 citadel_io::tokio::time::sleep(std::time::Duration::from_millis(
498 200,
499 ))
500 .await;
501 }
502 }
503
504 log::info!(target: "citadel", "{session_cid} registering to peer {id:?}");
505 let _reg_success = handle.register_to_peer().await?;
506 log::info!(target: "citadel", "{session_cid} registered to peer {id:?} registered || success -> now connecting");
507 handle
508 };
509
510 let peer_conn = PeerConnectionType::LocalGroupPeer {
513 session_cid,
514 peer_cid,
515 };
516 let peer_context = PeerContext {
517 conn_type: peer_conn,
518 send_file_transfer_tx: file_transfer_tx.clone(),
519 };
520 log::debug!(target: "citadel", "Early registering peer connection: {peer_conn:?}");
521 let _ = shared
522 .active_peer_conns
523 .lock()
524 .insert(peer_conn, peer_context);
525
526 handle
527 .connect_to_peer_custom(
528 session_security_settings,
529 udp_mode,
530 peer_session_password,
531 )
532 .await
533 .map(|mut success| {
534 let actual_peer_conn = success.channel.get_peer_conn_type().unwrap();
535
536 if actual_peer_conn != peer_conn {
539 log::debug!(target: "citadel", "Updating peer connection registration from {peer_conn:?} to {actual_peer_conn:?}");
540 let mut active_peers = shared.active_peer_conns.lock();
541 if let Some(peer_ctx) = active_peers.remove(&peer_conn) {
542 let _ = active_peers.insert(actual_peer_conn, peer_ctx);
543 }
544 }
545 success.incoming_object_transfer_handles = Some(FileTransferHandleRx {
547 inner: file_transfer_rx,
548 conn_type: actual_peer_conn.as_virtual_connection(),
549 });
550 success
551 })
552 .inspect_err(|_err| {
553 let _ = shared.active_peer_conns.lock().remove(&peer_conn);
555 })
556 };
557
558 tx.send(inner_task.await)
559 .await
560 .map_err(|err| NetworkError::Generic(err.to_string()))
561 };
562
563 requests.push(Box::pin(task))
564 }
565
566 let collection_task = async move { requests.try_collect::<()>().await };
568
569 citadel_io::tokio::try_join!(collection_task, f(rx, connect_success)).map(|_| ())
570 }
571
572 fn construct(kernel: Box<dyn NetKernel<R> + 'a>) -> Self {
573 Self {
574 inner_kernel: kernel,
575 shared: Shared {
576 active_peer_conns: Arc::new(Mutex::new(Default::default())),
577 },
578 _pd: Default::default(),
579 }
580 }
581}
582
583#[cfg(test)]
584mod tests {
585 use crate::prefabs::client::peer_connection::PeerConnectionKernel;
586 use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
587 use crate::prelude::*;
588 use crate::remote_ext::results::PeerConnectSuccess;
589 use crate::test_common::{server_info, wait_for_peers, TestBarrier};
590 use citadel_io::tokio;
591 use citadel_io::tokio::sync::mpsc::{Receiver, UnboundedSender};
592 use citadel_user::prelude::UserIdentifierExt;
593 use futures::stream::FuturesUnordered;
594 use futures::TryStreamExt;
595 use rstest::rstest;
596 use std::collections::HashMap;
597 use std::future::Future;
598 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
599 use std::time::Duration;
600 use uuid::Uuid;
601
602 lazy_static::lazy_static! {
603 pub static ref PEERS: Vec<(String, String, String)> = {
604 ["alpha", "beta", "charlie", "echo", "delta", "epsilon", "foxtrot"]
605 .iter().map(|base| (format!("{base}.username"), format!("{base}.password"), format!("{base}.full_name")))
606 .collect()
607 };
608 }
609
    /// Runs `peer_count` credentialed clients against one in-process server. Each client
    /// requests a connection to every other client, then checks that every requested peer
    /// appears in `get_peers` and every connected peer appears among the mutual peers.
    /// The test passes only if every client reaches the end of its checks.
    #[rstest]
    #[case(2, UdpMode::Enabled)]
    #[case(3, UdpMode::Disabled)]
    #[timeout(Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_to_peer_connect(#[case] peer_count: usize, #[case] udp_mode: UdpMode) {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        // Incremented by each client that finishes all of its checks.
        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        // Usernames of the first `peer_count` entries in PEERS (so peer_count <= 7).
        let total_peers = (0..peer_count)
            .map(|idx| PEERS.get(idx).unwrap().0.clone())
            .collect::<Vec<String>>();

        for idx in 0..peer_count {
            let (username, password, full_name) = PEERS.get(idx).unwrap();
            // Every other test user, addressed by username.
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != username)
                .map(UserIdentifier::Username)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::credentialed_registration(
                    server_addr,
                    username,
                    full_name,
                    password.as_str(),
                )
                .build()
                .unwrap();

            let username = username.clone();

            // `agg` is cloned here because the closure below also consumes it to verify
            // that every requested peer is visible.
            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg.clone(),
                move |results, connection| async move {
                    log::info!(target: "citadel", "***PEER {username} CONNECTED ***");
                    let session_cid = connection.conn_type.get_session_cid();
                    // Per-peer check: querying mutual peers through the peer remote must succeed.
                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        let session_cid = conn.channel.get_session_cid();
                        let _mutual_peers = conn
                            .remote
                            .remote()
                            .get_local_group_mutual_peers(session_cid)
                            .await
                            .unwrap();
                        conn
                    };
                    // NOTE(review): handle_peer_connect_successes is defined outside this
                    // view; presumably it drains `results`, runs `check` on each success,
                    // and returns the checked connections.
                    let p2p_remotes = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await
                    .into_iter()
                    .map(|r| (r.channel.get_peer_cid(), r.remote))
                    .collect::<HashMap<_, _>>();

                    // Every requested peer must be known to the server.
                    let network_peers = connection.get_peers(None).await.unwrap();
                    for user in agg.inner {
                        let peer_cid = user.id.get_cid();
                        assert!(network_peers.iter().any(|r| r.cid == peer_cid))
                    }

                    // Every connected peer must be a mutual (registered) peer.
                    let session_cid = connection.conn_type.get_session_cid();
                    let mutual_peers = connection
                        .get_local_group_mutual_peers(session_cid)
                        .await
                        .unwrap();
                    for (peer_cid, _) in p2p_remotes {
                        assert!(mutual_peers.iter().any(|r| r.cid == peer_cid))
                    }

                    log::info!(target: "citadel", "***PEER {username} finished all checks***");
                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    // Wait for the other clients before shutting down.
                    wait_for_peers().await;
                    connection.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        // Whichever of server/clients finishes first must not have errored.
        assert!(futures::future::try_select(server, clients).await.is_ok());

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
    }
724
    /// Like `peer_to_peer_connect`, but with transient (UUID-identified) clients and
    /// several header-obfuscator settings. In the two-peer cases each side also
    /// deregisters from the other and checks that the peer no longer exists locally.
    #[rstest]
    #[case(2, HeaderObfuscatorSettings::default())]
    #[case(2, HeaderObfuscatorSettings::Enabled)]
    #[case(2, HeaderObfuscatorSettings::EnabledWithKey(12345))]
    #[case(3, HeaderObfuscatorSettings::default())]
    #[timeout(Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_to_peer_connect_transient(
        #[case] peer_count: usize,
        #[case] header_obfuscator_settings: HeaderObfuscatorSettings,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        // Deregistration is only exercised in the two-peer cases.
        let do_deregister = peer_count == 2;

        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            // Every other client, addressed by UUID.
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                let security_settings = SessionSecuritySettings {
                    header_obfuscator_settings,
                    ..Default::default()
                };
                agg = agg
                    .with_peer_custom(peer)
                    .with_udp_mode(udp_mode)
                    .ensure_registered()
                    .with_session_security_settings(security_settings)
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    let session_cid = remote.conn_type.get_session_cid();

                    // Per-peer check: optionally deregister, then confirm the peer is no
                    // longer recorded by the local persistence handler.
                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        if do_deregister {
                            conn.remote
                                .deregister()
                                .await
                                .expect("Deregistration failed");
                            assert!(!conn
                                .remote
                                .inner
                                .account_manager()
                                .get_persistence_handler()
                                .hyperlan_peer_exists(session_cid, conn.channel.get_peer_cid())
                                .await
                                .unwrap());
                        }
                        conn
                    };

                    let _ = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;

                    log::info!(target: "citadel", "***PEER {uuid} finished all checks***");
                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    wait_for_peers().await;
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel)?;
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        // Surface whichever side (server or clients) failed first as the test error.
        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }
837
    /// Client 0 is the sender: it connects to every other client and sends
    /// `TheBridge.pdf` over each peer connection. Every other client connects only to the
    /// sender, accepts the inbound transfer, and compares the received bytes with the
    /// original file.
    #[rstest]
    #[case(2)]
    #[case(3)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_peer_to_peer_file_transfer(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        // NOTE(review): receiver_success is a single flag shared by all receivers, so with
        // peer_count > 2 it only proves that at least one receiver finished.
        let sender_success = &AtomicBool::new(false);
        let receiver_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        let sender_uuid = total_peers[0];

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            let mut peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();
            if idx != 0 {
                // Receivers connect only to the sender, not to each other.
                peers = vec![sender_uuid.into()];
            }

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    wait_for_peers().await;
                    let session_cid = remote.conn_type.get_session_cid();
                    let is_sender = idx == 0;
                    let check = move |mut conn: PeerConnectSuccess<_>| async move {
                        if is_sender {
                            // NOTE(review): relative path; assumes the test runs with the
                            // crate directory as CWD.
                            conn.remote
                                .send_file_with_custom_opts(
                                    "../resources/TheBridge.pdf",
                                    32 * 1024,
                                    TransferType::FileTransfer,
                                )
                                .await
                                .expect("Failed to send file");
                        } else {
                            // Take the first inbound transfer handle routed to this peer.
                            let mut handle = conn
                                .incoming_object_transfer_handles
                                .take()
                                .unwrap()
                                .recv()
                                .await
                                .unwrap();
                            handle.accept().unwrap();

                            use citadel_types::proto::ObjectTransferStatus;
                            use futures::StreamExt;
                            // Set by ReceptionBeginning; read back once reception completes.
                            let mut path = None;
                            while let Some(status) = handle.next().await {
                                match status {
                                    ObjectTransferStatus::ReceptionComplete => {
                                        let cmp =
                                            include_bytes!("../../../../resources/TheBridge.pdf");
                                        let streamed_data =
                                            citadel_io::tokio::fs::read(path.clone().unwrap())
                                                .await
                                                .unwrap();
                                        assert_eq!(
                                            cmp,
                                            streamed_data.as_slice(),
                                            "Original data and streamed data does not match"
                                        );

                                        log::info!(target: "citadel", "Peer has finished receiving and verifying the file!");
                                        break;
                                    }

                                    ObjectTransferStatus::ReceptionBeginning(file_path, vfm) => {
                                        path = Some(file_path);
                                        assert_eq!(vfm.name, "TheBridge.pdf")
                                    }

                                    _ => {}
                                }
                            }
                        }

                        conn
                    };
                    // Receivers are configured with exactly one peer (the sender). The `2`
                    // presumably counts the receiver itself plus the sender, matching how
                    // `peer_count` counts all clients for the sender; confirm against
                    // handle_peer_connect_successes.
                    let peer_count = if idx == 0 { peer_count } else { 2 };
                    let _ = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;

                    if is_sender {
                        sender_success.store(true, Ordering::Relaxed);
                    } else {
                        receiver_success.store(true, Ordering::Relaxed);
                    }

                    log::info!(target: "citadel", "***PEER {uuid} (is_sender: {is_sender}) finished all checks***");
                    wait_for_peers().await;
                    log::info!(target: "citadel", "***PEER {uuid} (is_sender: {is_sender}) shutting down***");
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert!(sender_success.load(Ordering::Relaxed));
        assert!(receiver_success.load(Ordering::Relaxed));
        Ok(())
    }
1001
    /// Connects `peer_count` transient clients to each other. Client 0 then calls
    /// `rekey()` nine times on each of its peer connections and expects the results
    /// `Some(1)` through `Some(9)` in order.
    #[rstest]
    #[case(2)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_peer_to_peer_rekey(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    let session_cid = remote.conn_type.get_session_cid();

                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        // Only client 0 initiates rekeys. Each call is expected to return
                        // the next number in sequence (presumably the new ratchet
                        // version; confirm).
                        if idx == 0 {
                            for x in 1..10 {
                                assert_eq!(
                                    conn.remote.rekey().await.expect("Failed to rekey"),
                                    Some(x)
                                );
                            }
                        }

                        conn
                    };

                    let results = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;

                    log::info!(target: "citadel", "***PEER {uuid} finished all check (count: {})s***", results.len());
                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    wait_for_peers().await;
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel)?;
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }
1099
    /// Connects `peer_count` transient clients to each other. Every client then calls
    /// `disconnect()` on each of its peer connections, which must succeed.
    #[rstest]
    #[case(2)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_peer_to_peer_disconnect(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);
        let udp_mode = UdpMode::Enabled;

        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let mut agg = PeerConnectionSetupAggregator::default();

            for peer in peers {
                agg = agg
                    .with_peer_custom(peer)
                    .ensure_registered()
                    .with_udp_mode(udp_mode)
                    .with_session_security_settings(SessionSecuritySettings::default())
                    .add();
            }

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                agg,
                move |results, remote| async move {
                    log::info!(target: "citadel", "***PEER {uuid} CONNECTED***");
                    wait_for_peers().await;
                    let session_cid = remote.conn_type.get_session_cid();

                    // Per-peer check: the p2p disconnect itself must succeed.
                    let check = move |conn: PeerConnectSuccess<_>| async move {
                        conn.remote
                            .disconnect()
                            .await
                            .expect("Failed to p2p disconnect");
                        conn
                    };
                    let _ = handle_peer_connect_successes(
                        results,
                        session_cid,
                        peer_count,
                        udp_mode,
                        check,
                    )
                    .await;
                    log::info!(target: "citadel", "***PEER {uuid} finished all checks***");

                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    wait_for_peers().await;
                    remote.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel)?;
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }
1192
1193 #[rstest]
1194 #[case(SecrecyMode::BestEffort, Some("test-p2p-password"))]
1195 #[timeout(std::time::Duration::from_secs(240))]
1196 #[citadel_io::tokio::test(flavor = "multi_thread")]
1197 async fn test_p2p_wrong_session_password(
1198 #[case] secrecy_mode: SecrecyMode,
1199 #[case] p2p_password: Option<&'static str>,
1200 #[values(KemAlgorithm::Kyber)] kem: KemAlgorithm,
1201 #[values(EncryptionAlgorithm::AES_GCM_256)] enx: EncryptionAlgorithm,
1202 ) {
1203 citadel_logging::setup_log_no_panic_hook();
1204 TestBarrier::setup(2);
1205 let (server, server_addr) = server_info::<StackedRatchet>();
1206 let peer_0_error_received = &AtomicBool::new(false);
1207 let peer_1_error_received = &AtomicBool::new(false);
1208
1209 let uuid0 = Uuid::new_v4();
1210 let uuid1 = Uuid::new_v4();
1211 let session_security = SessionSecuritySettingsBuilder::default()
1212 .with_secrecy_mode(secrecy_mode)
1213 .with_crypto_params(kem + enx)
1214 .build()
1215 .unwrap();
1216
1217 let mut peer0_agg = PeerConnectionSetupAggregator::default()
1218 .with_peer_custom(uuid1)
1219 .ensure_registered()
1220 .with_session_security_settings(session_security);
1221
1222 if let Some(password) = p2p_password {
1223 peer0_agg = peer0_agg.with_session_password(password);
1224 }
1225
1226 let peer0_connection = peer0_agg.add();
1227
1228 let mut peer1_agg = PeerConnectionSetupAggregator::default()
1229 .with_peer_custom(uuid0)
1230 .ensure_registered()
1231 .with_session_security_settings(session_security);
1232
1233 if let Some(_password) = p2p_password {
1234 peer1_agg = peer1_agg.with_session_password("wrong password");
1235 }
1236
1237 let peer1_connection = peer1_agg.add();
1238
1239 let server_connection_settings0 =
1240 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid0)
1241 .with_udp_mode(UdpMode::Enabled)
1242 .with_session_security_settings(session_security)
1243 .build()
1244 .unwrap();
1245
1246 let server_connection_settings1 =
1247 DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid1)
1248 .with_udp_mode(UdpMode::Enabled)
1249 .with_session_security_settings(session_security)
1250 .build()
1251 .unwrap();
1252
1253 let client_kernel0 = PeerConnectionKernel::new(
1254 server_connection_settings0,
1255 peer0_connection,
1256 move |mut connection, remote| async move {
1257 wait_for_peers().await;
1258 let conn = connection.recv().await.unwrap();
1259 log::trace!(target: "citadel", "Peer 0 {} received: {:?}", remote.conn_type.get_session_cid(), conn);
1260 if conn.is_ok() {
1261 peer_0_error_received.store(true, Ordering::SeqCst);
1262 }
1263 wait_for_peers().await;
1264 remote.shutdown_kernel().await
1265 },
1266 );
1267
1268 let client_kernel1 = PeerConnectionKernel::new(
1269 server_connection_settings1,
1270 peer1_connection,
1271 move |mut connection, remote| async move {
1272 wait_for_peers().await;
1273 let conn = connection.recv().await.unwrap();
1274 log::trace!(target: "citadel", "Peer 1 {} received: {:?}", remote.conn_type.get_session_cid(), conn);
1275 if conn.is_ok() {
1276 peer_1_error_received.store(true, Ordering::SeqCst);
1277 }
1278 wait_for_peers().await;
1279 remote.shutdown_kernel().await
1280 },
1281 );
1282
1283 let client0 = DefaultNodeBuilder::default().build(client_kernel0).unwrap();
1284 let client1 = DefaultNodeBuilder::default().build(client_kernel1).unwrap();
1285 let clients = futures::future::try_join(client0, client1);
1286
1287 let task = async move {
1288 tokio::select! {
1289 server_res = server => Err(NetworkError::msg(format!("Server ended prematurely: {:?}", server_res.map(|_| ())))),
1290 client_res = clients => client_res.map(|_| ())
1291 }
1292 };
1293
1294 tokio::time::timeout(Duration::from_secs(120), task)
1295 .await
1296 .unwrap()
1297 .unwrap();
1298
1299 assert!(!peer_0_error_received.load(Ordering::SeqCst));
1300 assert!(!peer_1_error_received.load(Ordering::SeqCst));
1301 }
1302
1303 async fn handle_peer_connect_successes<F, Fut, R: Ratchet>(
1304 mut conn_rx: Receiver<Result<PeerConnectSuccess<R>, NetworkError>>,
1305 session_cid: u64,
1306 peer_count: usize,
1307 udp_mode: UdpMode,
1308 checks: F,
1309 ) -> Vec<PeerConnectSuccess<R>>
1310 where
1311 F: Fn(PeerConnectSuccess<R>) -> Fut + Send + Clone + 'static,
1312 Fut: Future<Output = PeerConnectSuccess<R>> + Send,
1313 {
1314 let (finished_tx, finished_rx) = tokio::sync::oneshot::channel();
1315
1316 let task = async move {
1317 let (done_tx, mut done_rx) = tokio::sync::mpsc::unbounded_channel();
1318 let mut conns = vec![];
1319 while let Some(conn) = conn_rx.recv().await {
1320 conns.push(conn);
1321 if conns.len() == peer_count - 1 {
1322 break;
1323 }
1324 }
1325
1326 log::info!(target: "citadel", "~~~*** Peer {session_cid} has {} connections to other peers ***~~~", conns.len());
1327
1328 for conn in conns {
1329 let conn = conn.expect("Error receiving peer connection");
1330 handle_peer_connect_success(
1331 conn,
1332 done_tx.clone(),
1333 session_cid,
1334 udp_mode,
1335 checks.clone(),
1336 );
1337 }
1338
1339 let mut ret = vec![];
1341 while let Some(done) = done_rx.recv().await {
1342 ret.push(done);
1343 if ret.len() == peer_count - 1 {
1344 break;
1345 }
1346 }
1347
1348 finished_tx
1349 .send(ret)
1350 .expect("Error sending finished signal in handle_peer_connect_successes");
1351 };
1352
1353 drop(tokio::task::spawn(task));
1354 let ret = finished_rx
1355 .await
1356 .expect("Error receiving finished signal in handle_peer_connect_successes");
1357
1358 assert_eq!(ret.len(), peer_count - 1);
1359 ret
1360 }
1361
1362 fn handle_peer_connect_success<F, Fut, R: Ratchet>(
1363 mut conn: PeerConnectSuccess<R>,
1364 done_tx: UnboundedSender<PeerConnectSuccess<R>>,
1365 session_cid: u64,
1366 udp_mode: UdpMode,
1367 checks: F,
1368 ) where
1369 F: Fn(PeerConnectSuccess<R>) -> Fut + Send + Clone + 'static,
1370 Fut: Future<Output = PeerConnectSuccess<R>> + Send,
1371 {
1372 let task = async move {
1373 let chan = conn.udp_channel_rx.take();
1374 crate::test_common::p2p_assertions(session_cid, &conn).await;
1375 crate::test_common::udp_mode_assertions(udp_mode, chan).await;
1376 let conn = checks(conn).await;
1377 done_tx
1378 .send(conn)
1379 .expect("Error sending done signal in handle_peer_connect_success");
1380 };
1381
1382 drop(tokio::task::spawn(task));
1383 }
1384}