1use crate::prelude::*;
66use crate::test_common::wait_for_peers;
67use citadel_io::tokio::sync::Mutex;
68use citadel_user::prelude::UserIdentifierExt;
69use futures::{Future, StreamExt};
70use std::marker::PhantomData;
71use std::pin::Pin;
72use std::sync::atomic::{AtomicBool, Ordering};
73use std::sync::Arc;
74use uuid::Uuid;
75
76pub struct BroadcastKernel<'a, F, Fut, R: Ratchet> {
82 inner_kernel: Box<dyn NetKernel<R> + 'a>,
83 shared: Arc<BroadcastShared>,
84 _pd: PhantomData<fn() -> (F, Fut)>,
85}
86
87pub struct BroadcastShared {
88 route_registers: AtomicBool,
89 register_rx:
90 citadel_io::Mutex<Option<citadel_io::tokio::sync::mpsc::UnboundedReceiver<PeerSignal>>>,
91 register_tx: citadel_io::tokio::sync::mpsc::UnboundedSender<PeerSignal>,
92}
93
94pub enum GroupInitRequestType {
102 Create {
107 local_user: UserIdentifier,
108 invite_list: Vec<UserIdentifier>,
109 group_id: Uuid,
110 accept_registrations: bool,
111 },
112 Join {
120 local_user: UserIdentifier,
121 owner: UserIdentifier,
122 group_id: Uuid,
123 do_peer_register: bool,
124 },
125}
126
127#[async_trait]
128impl<'a, F, Fut, R: Ratchet> PrefabFunctions<'a, GroupInitRequestType, R>
129 for BroadcastKernel<'a, F, Fut, R>
130where
131 F: FnOnce(GroupChannel, CitadelClientServerConnection<R>) -> Fut + Send + 'a,
132 Fut: Future<Output = Result<(), NetworkError>> + Send + 'a,
133{
134 type UserLevelInputFunction = F;
135 type SharedBundle = Arc<BroadcastShared>;
136
137 fn get_shared_bundle(&self) -> Self::SharedBundle {
138 self.shared.clone()
139 }
140
141 #[allow(unreachable_code, clippy::blocks_in_conditions)]
142 #[cfg_attr(
143 feature = "localhost-testing",
144 tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug))
145 )]
146 async fn on_c2s_channel_received(
147 connect_success: CitadelClientServerConnection<R>,
148 arg: GroupInitRequestType,
149 fx: Self::UserLevelInputFunction,
150 shared: Arc<BroadcastShared>,
151 ) -> Result<(), NetworkError> {
152 let session_cid = connect_success.cid;
153 wait_for_peers().await;
154 let mut creator_only_accept_inbound_registers = false;
155
156 let mut is_owner = false;
157 let request = match arg {
158 GroupInitRequestType::Create {
159 local_user,
160 invite_list,
161 group_id,
162 accept_registrations,
163 } => {
164 is_owner = true;
165 let mut peers_registered = vec![];
167
168 for peer in &invite_list {
169 let peer = peer
170 .search_peer(session_cid, connect_success.account_manager())
171 .await?
172 .ok_or_else(|| {
173 citadel_io::error!(
174 citadel_io::ErrorCode::BroadcastCreateUserNotRegistered,
175 format!("{peer:?}"),
176 format!("{local_user:?}")
177 )
178 })?;
179
180 peers_registered.push(peer.cid)
181 }
182
183 creator_only_accept_inbound_registers = accept_registrations;
184
185 GroupBroadcast::Create {
186 initial_invitees: peers_registered,
187 options: MessageGroupOptions {
188 group_type: GroupType::Public,
189 id: group_id.as_u128(),
190 ..Default::default()
191 },
192 }
193 }
194
195 GroupInitRequestType::Join {
196 local_user,
197 owner,
198 group_id,
199 do_peer_register,
200 } => {
201 let owner_orig = owner;
203 let owner_find = owner_orig
204 .search_peer(session_cid, connect_success.account_manager())
205 .await?;
206
207 let owner = if let Some(owner) = owner_find {
208 Some(owner)
209 } else if do_peer_register {
210 let handle = connect_success
211 .propose_target(local_user.clone(), owner_orig.clone())
212 .await?;
213 let _ = handle.register_to_peer().await?;
214 owner_orig
216 .search_peer(session_cid, connect_success.account_manager())
217 .await?
218 } else {
219 None
220 };
221
222 let owner = owner.ok_or_else(|| {
223 citadel_io::error!(
224 citadel_io::ErrorCode::BroadcastJoinUserNotRegistered,
225 format!("{owner_orig:?}"),
226 format!("{local_user:?}")
227 )
228 })?;
229
230 let expected_message_group_key = MessageGroupKey {
231 cid: owner.cid,
232 mgid: group_id.as_u128(),
233 };
234
235 let mut retries = 0;
237 let group_owner_handle = connect_success
238 .propose_target(local_user.clone(), owner.cid)
239 .await?;
240 loop {
241 let owned_groups = group_owner_handle.list_owned_groups().await?;
242 if owned_groups.contains(&expected_message_group_key) {
243 break;
244 } else {
245 citadel_io::time::sleep(std::time::Duration::from_secs(2u64.pow(retries)))
246 .await;
247
248 retries += 1;
249 if retries > 4 {
250 return Err(citadel_io::error!(
251 citadel_io::ErrorCode::BroadcastOwnerGroupMissing,
252 citadel_io::Dbg(owner),
253 citadel_io::Dbg(group_id)
254 ));
255 }
256 }
257 }
258
259 GroupBroadcast::RequestJoin {
260 sender: local_user.get_cid(),
261 key: expected_message_group_key,
262 }
263 }
264 };
265
266 let request = NodeRequest::GroupBroadcastCommand(GroupBroadcastCommand {
267 session_cid,
268 command: request,
269 });
270
271 let subscription = &Mutex::new(Some(
272 connect_success.send_callback_subscription(request).await?,
273 ));
274
275 log::trace!(target: "citadel", "Peer {session_cid} is attempting to join group");
276 let acceptor_task = if creator_only_accept_inbound_registers {
277 shared.route_registers.store(true, Ordering::Relaxed);
278 let mut reg_rx = shared.register_rx.lock().take().unwrap();
279 let remote = connect_success.remote_ref().clone();
280 Box::pin(async move {
281 let mut subscription = subscription.lock().await.take().unwrap();
282 let mut count_registered = 0;
284 loop {
285 let post_register = citadel_io::tokio::select! {
286 reg_request = reg_rx.recv() => {
287 reg_request.ok_or_else(|| citadel_io::error!(citadel_io::ErrorCode::BroadcastStreamEndedUnexpectedly, "reg_rx"))?
288 },
289
290 reg_request2 = subscription.next() => {
291 let signal = reg_request2.ok_or_else(|| citadel_io::error!(citadel_io::ErrorCode::BroadcastStreamEndedUnexpectedly, "subscription"))?;
292 if let NodeResult::PeerEvent(PeerEvent { event: sig @ PeerSignal::PostRegister { .. }, .. }) = &signal {
293 sig.clone()
294 } else {
295 continue;
296 }
297 }
298 };
299
300 log::trace!(target: "citadel", "ACCEPTOR {session_cid} RECV reg_request: {post_register:?}");
301 if let PeerSignal::PostRegister {
302 peer_conn_type: peer_conn,
303 inviter_username: _,
304 invitee_username: _,
305 ticket_opt: _,
306 invitee_response: None,
307 } = &post_register
308 {
309 let cid = peer_conn.get_original_target_cid();
310 if cid != session_cid {
311 log::warn!(target: "citadel", "Received the wrong CID. Will not accept request");
312 continue;
313 }
314
315 let _ = responses::peer_register(post_register, true, &remote).await?;
316 if cfg!(feature = "localhost-testing") {
317 count_registered += 1;
318 if count_registered == crate::test_common::num_local_test_peers() - 1 {
319 break;
321 }
322 }
323 }
324 }
325
326 Ok::<_, NetworkError>(())
327 })
328 as Pin<
329 Box<
330 dyn futures::Future<
331 Output = Result<(), citadel_proto::prelude::NetworkError>,
332 > + Send,
333 >,
334 >
335 } else {
336 Box::pin(async move { Ok::<_, NetworkError>(()) })
337 as Pin<
338 Box<
339 dyn futures::Future<
340 Output = Result<(), citadel_proto::prelude::NetworkError>,
341 > + Send,
342 >,
343 >
344 };
345
346 let mut lock = subscription.lock().await;
347 let subscription = lock.as_mut().unwrap();
348 while let Some(event) = subscription.next().await {
349 match event.into_result()? {
350 NodeResult::PeerEvent(PeerEvent {
351 event: ref ps @ PeerSignal::PostRegister { .. },
352 ticket: _,
353 ..
354 }) => {
355 shared
356 .register_tx
357 .send(ps.clone())
358 .map_err(|err| NetworkError::generic(err.to_string()))?;
359 }
360 NodeResult::GroupChannelCreated(GroupChannelCreated {
361 ticket: _,
362 channel,
363 session_cid: _,
364 }) => {
365 drop(lock);
368 return if is_owner {
369 citadel_io::tokio::try_join!(fx(channel, connect_success), acceptor_task)
370 .map(|_| ())
371 } else {
372 fx(channel, connect_success).await.map(|_| ())
373 };
374 }
375
376 NodeResult::GroupEvent(GroupEvent {
377 session_cid: _,
378 ticket: _,
379 event: GroupBroadcast::CreateResponse { key: None },
380 }) => {
381 return Err(citadel_io::error!(
382 citadel_io::ErrorCode::BroadcastCreateGroupFailed
383 ))
384 }
385
386 _ => {}
387 }
388 }
389
390 Ok(())
391 }
392
393 fn construct(kernel: Box<dyn NetKernel<R> + 'a>) -> Self {
394 let (tx, rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
395 Self {
396 shared: Arc::new(BroadcastShared {
397 route_registers: AtomicBool::new(false),
398 register_rx: citadel_io::Mutex::new(Some(rx)),
399 register_tx: tx,
400 }),
401 inner_kernel: kernel,
402 _pd: Default::default(),
403 }
404 }
405}
406
407#[async_trait]
408impl<F, Fut, R: Ratchet> NetKernel<R> for BroadcastKernel<'_, F, Fut, R> {
409 fn load_remote(&mut self, node_remote: NodeRemote<R>) -> Result<(), NetworkError> {
410 self.inner_kernel.load_remote(node_remote)
411 }
412
413 async fn on_start(&self) -> Result<(), NetworkError> {
414 self.inner_kernel.on_start().await
415 }
416
417 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
418 if let NodeResult::PeerEvent(PeerEvent {
419 event: ps @ PeerSignal::PostRegister { .. },
420 ticket: _,
421 ..
422 }) = &message
423 {
424 if self.shared.route_registers.load(Ordering::Relaxed) {
425 return self
426 .shared
427 .register_tx
428 .send(ps.clone())
429 .map_err(|err| NetworkError::generic(err.to_string()));
430 }
431 }
432
433 self.inner_kernel.on_node_event_received(message).await
434 }
435
436 async fn on_stop(&mut self) -> Result<(), NetworkError> {
437 self.inner_kernel.on_stop().await
438 }
439}
440
// Integration tests that spin up a local server plus several client nodes and
// exercise group creation/joining. They synchronize via `TestBarrier` /
// `wait_for_peers`, so the number of peers must match `TestBarrier::setup`.
#[cfg(all(test, feature = "localhost-testing"))]
mod tests {
    use crate::prefabs::client::broadcast::{BroadcastKernel, GroupInitRequestType};
    use crate::prefabs::client::peer_connection::PeerConnectionKernel;
    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
    use crate::prelude::*;
    use crate::test_common::{server_info, wait_for_peers, TestBarrier};
    use citadel_io::tokio;
    use futures::prelude::stream::FuturesUnordered;
    use futures::TryStreamExt;
    use rstest::rstest;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use uuid::Uuid;

    /// Peer 0 creates a group via `BroadcastKernel` (accepting inbound
    /// registrations); every other peer joins it, registering with the owner if
    /// needed. Each peer checks that only the creator lists an owned group.
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn group_connect_list_members() -> Result<(), Box<dyn std::error::Error>> {
        let peer_count = 3;
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        // Incremented by each peer whose closure completes its checks.
        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();
        let group_id = Uuid::new_v4();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();

            // Peer 0 owns the group; everyone else joins the group owned by peer 0.
            let request = if idx == 0 {
                GroupInitRequestType::Create {
                    local_user: UserIdentifier::from(uuid),
                    invite_list: vec![],
                    group_id,
                    accept_registrations: true,
                }
            } else {
                GroupInitRequestType::Join {
                    local_user: UserIdentifier::from(uuid),
                    owner: total_peers.first().cloned().unwrap().into(),
                    group_id,
                    do_peer_register: true,
                }
            };

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = BroadcastKernel::new(
                server_connection_settings,
                request,
                move |channel, connection| async move {
                    wait_for_peers().await;
                    log::trace!(target: "citadel", "***GROUP PEER {}={}={} CONNECT SUCCESS***", idx, uuid, connection.conn_type.get_session_cid());

                    let owned_groups = connection.list_owned_groups().await.unwrap();

                    // Only the creator owns a group.
                    if idx == 0 {
                        assert_eq!(owned_groups.len(), 1);
                    } else {
                        assert_eq!(owned_groups.len(), 0);
                    }

                    log::trace!(target: "citadel", "Peer {idx}={} is COMPLETE!", connection.conn_type.get_session_cid());

                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    // Keep the channel alive until every peer has finished its checks.
                    wait_for_peers().await;
                    drop(channel);
                    connection.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        // Whichever side errors first (server or any client) fails the test.
        let res = futures::future::try_select(server, clients).await;
        if let Err(err) = res {
            return match err {
                futures::future::Either::Left(left) => Err(left.0.into_string().into()),
                futures::future::Either::Right(right) => Err(right.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }

    /// Peers first connect pairwise via `PeerConnectionKernel`; peer 0 then
    /// creates a group with `create_group` inviting its connected peer, and the
    /// other peer accepts the invitation from its unprocessed-signals stream.
    #[rstest]
    #[case(2)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_manual_group_connect(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        // Set by the group creator (peer 0) and by an invitee respectively.
        let client_success = &AtomicBool::new(false);
        let receiver_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            // Every peer connects to every other peer.
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                peers,
                move |mut results, remote| async move {
                    let _sender = remote.conn_type.get_session_cid();
                    // Group invitations and channel-created events arrive here.
                    let mut signals = remote.get_unprocessed_signals_receiver().unwrap();

                    wait_for_peers().await;
                    let conn = results.recv().await.unwrap()?;
                    log::trace!(target: "citadel", "User {uuid} received {conn:?}");

                    if idx == 0 {
                        // Creator: invite the peer we are connected to.
                        let _channel = remote
                            .create_group(Some(vec![conn.channel.get_peer_cid().into()]))
                            .await?;
                        log::info!(target: "citadel", "The designated node has finished creating a group");

                        wait_for_peers().await;
                        client_success.store(true, Ordering::Relaxed);
                        return remote.shutdown_kernel().await;
                    } else {
                        // Invitee: accept the invitation, then wait for the channel.
                        while let Some(evt) = signals.recv().await {
                            log::info!(target: "citadel", "Received unprocessed signal: {evt:?}");
                            match evt {
                                NodeResult::GroupEvent(GroupEvent {
                                    session_cid: _,
                                    ticket: _,
                                    event:
                                        GroupBroadcast::Invitation {
                                            sender: _,
                                            key: _key,
                                        },
                                }) => {
                                    let _ =
                                        crate::responses::group_invite(evt, true, &remote.inner)
                                            .await?;
                                }

                                NodeResult::GroupChannelCreated(GroupChannelCreated {
                                    ticket: _,
                                    channel: _chan,
                                    session_cid: _,
                                }) => {
                                    receiver_success.store(true, Ordering::Relaxed);
                                    log::trace!(target: "citadel", "***PEER {uuid} CONNECT***");
                                    wait_for_peers().await;
                                    return remote.shutdown_kernel().await;
                                }

                                val => {
                                    log::warn!(target: "citadel", "Unhandled response: {val:?}")
                                }
                            }
                        }
                    }

                    // The signal stream closed before the group channel arrived.
                    Err(citadel_io::error!(
                        citadel_io::ErrorCode::BroadcastStreamEndedUnexpectedly,
                        "signals_recv"
                    ))
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert!(client_success.load(Ordering::Relaxed));
        assert!(receiver_success.load(Ordering::Relaxed));
        Ok(())
    }

    /// In a `CommandHierarchy` group with `ReadPolicy::SuperiorOnly`, the owner
    /// (root) ranks the other peer at `/alpha` and must receive the message that
    /// subordinate sends on the group channel.
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn group_command_hierarchy_superior_reads_subordinate(
    ) -> Result<(), Box<dyn std::error::Error>> {
        use crate::prelude::GroupBroadcastPayload;
        use citadel_types::crypto::SecBuffer;
        use citadel_types::proto::{
            CommandPath, GroupHierarchyMode, MessageGroupOptions, ReadPolicy,
        };
        use std::collections::HashMap;

        let peer_count = 2;
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        // Set by the owner once it has read the subordinate's message.
        let owner_read = &AtomicBool::new(false);
        let (server, server_addr) = server_info::<StackedRatchet>();
        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();
            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                peers,
                move |mut results, remote| async move {
                    let mut signals = remote.get_unprocessed_signals_receiver().unwrap();
                    wait_for_peers().await;
                    let conn = results.recv().await.unwrap()?;

                    if idx == 0 {
                        // Owner: place the connected peer at rank `/alpha` under the root.
                        let sub_cid: u64 = conn.channel.get_peer_cid();
                        let mut ranks = HashMap::new();
                        let _ = ranks.insert(sub_cid, CommandPath::parse("/alpha"));
                        let options = MessageGroupOptions {
                            hierarchy: GroupHierarchyMode::CommandHierarchy {
                                read_policy: ReadPolicy::SuperiorOnly,
                                ranks,
                            },
                            ..Default::default()
                        };
                        let mut channel = remote
                            .create_group_with_options(Some(vec![sub_cid.into()]), options)
                            .await?;

                        // Skip non-message payloads until the subordinate's message arrives
                        // (or the channel closes, which leaves `owner_read` false).
                        loop {
                            match channel.recv().await {
                                Some(GroupBroadcastPayload::Message { payload, sender: _ }) => {
                                    assert_eq!(payload.as_ref(), b"sitrep from subordinate");
                                    owner_read.store(true, Ordering::Relaxed);
                                    break;
                                }
                                Some(_) => continue,
                                None => break,
                            }
                        }
                        wait_for_peers().await;
                        return remote.shutdown_kernel().await;
                    }

                    // Subordinate: accept the invite, then send one message on the channel.
                    while let Some(evt) = signals.recv().await {
                        match evt {
                            NodeResult::GroupEvent(GroupEvent {
                                event: GroupBroadcast::Invitation { .. },
                                ..
                            }) => {
                                let _ = crate::responses::group_invite(evt, true, &remote.inner)
                                    .await?;
                            }
                            NodeResult::GroupChannelCreated(GroupChannelCreated {
                                channel,
                                ..
                            }) => {
                                channel
                                    .send_message(SecBuffer::from(
                                        b"sitrep from subordinate".to_vec(),
                                    ))
                                    .await?;
                                wait_for_peers().await;
                                return remote.shutdown_kernel().await;
                            }
                            _ => {}
                        }
                    }

                    Err(citadel_io::error!(
                        citadel_io::ErrorCode::BroadcastStreamEndedUnexpectedly,
                        "signals"
                    ))
                },
            );
            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });
        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert!(
            owner_read.load(Ordering::Relaxed),
            "owner (hierarchy root) must read the subordinate's DHE message"
        );
        Ok(())
    }
}