1use crate::prelude::*;
66use crate::test_common::wait_for_peers;
67use citadel_io::tokio::sync::Mutex;
68use citadel_user::prelude::UserIdentifierExt;
69use futures::{Future, StreamExt};
70use std::marker::PhantomData;
71use std::pin::Pin;
72use std::sync::atomic::{AtomicBool, Ordering};
73use std::sync::Arc;
74use uuid::Uuid;
75
/// A kernel that enters a message group (creating or joining it) and then hands the resulting
/// `GroupChannel` to a user-supplied function.
///
/// Node events are delegated to `inner_kernel`, except `PeerSignal::PostRegister` signals, which
/// are diverted into `shared` while registration routing is enabled.
pub struct BroadcastKernel<'a, F, Fut, R: Ratchet> {
    /// The wrapped kernel that receives every event not diverted by this type.
    inner_kernel: Box<dyn NetKernel<R> + 'a>,
    /// State shared with the group-session callback (`on_c2s_channel_received`).
    shared: Arc<BroadcastShared>,
    /// Records the user function type `F` and its future type `Fut` without storing either.
    _pd: PhantomData<fn() -> (F, Fut)>,
}
86
/// State shared between a `BroadcastKernel`'s event handler and its group-session callback.
pub struct BroadcastShared {
    /// When `true`, `on_node_event_received` sends incoming `PostRegister` signals into
    /// `register_tx` instead of passing them to the inner kernel. Set by a group creator that
    /// accepts registrations.
    route_registers: AtomicBool,
    /// Receiving half of the register channel; taken (leaving `None`) by the group creator's
    /// acceptor task.
    register_rx:
        citadel_io::Mutex<Option<citadel_io::tokio::sync::mpsc::UnboundedReceiver<PeerSignal>>>,
    /// Sending half of the register channel.
    register_tx: citadel_io::tokio::sync::mpsc::UnboundedSender<PeerSignal>,
}
93
94pub enum GroupInitRequestType {
102 Create {
107 local_user: UserIdentifier,
108 invite_list: Vec<UserIdentifier>,
109 group_id: Uuid,
110 accept_registrations: bool,
111 },
112 Join {
120 local_user: UserIdentifier,
121 owner: UserIdentifier,
122 group_id: Uuid,
123 do_peer_register: bool,
124 },
125}
126
127#[async_trait]
128impl<'a, F, Fut, R: Ratchet> PrefabFunctions<'a, GroupInitRequestType, R>
129 for BroadcastKernel<'a, F, Fut, R>
130where
131 F: FnOnce(GroupChannel, CitadelClientServerConnection<R>) -> Fut + Send + 'a,
132 Fut: Future<Output = Result<(), NetworkError>> + Send + 'a,
133{
134 type UserLevelInputFunction = F;
135 type SharedBundle = Arc<BroadcastShared>;
136
137 fn get_shared_bundle(&self) -> Self::SharedBundle {
138 self.shared.clone()
139 }
140
141 #[allow(unreachable_code, clippy::blocks_in_conditions)]
142 #[cfg_attr(
143 feature = "localhost-testing",
144 tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug))
145 )]
146 async fn on_c2s_channel_received(
147 connect_success: CitadelClientServerConnection<R>,
148 arg: GroupInitRequestType,
149 fx: Self::UserLevelInputFunction,
150 shared: Arc<BroadcastShared>,
151 ) -> Result<(), NetworkError> {
152 let session_cid = connect_success.cid;
153 wait_for_peers().await;
154 let mut creator_only_accept_inbound_registers = false;
155
156 let mut is_owner = false;
157 let request = match arg {
158 GroupInitRequestType::Create {
159 local_user,
160 invite_list,
161 group_id,
162 accept_registrations,
163 } => {
164 is_owner = true;
165 let mut peers_registered = vec![];
167
168 for peer in &invite_list {
169 let peer = peer
170 .search_peer(session_cid, connect_success.account_manager())
171 .await?
172 .ok_or_else(|| {
173 NetworkError::msg(format!(
174 "[create] User {:?} is not registered to {:?}",
175 peer, &local_user
176 ))
177 })?;
178
179 peers_registered.push(peer.cid)
180 }
181
182 creator_only_accept_inbound_registers = accept_registrations;
183
184 GroupBroadcast::Create {
185 initial_invitees: peers_registered,
186 options: MessageGroupOptions {
187 group_type: GroupType::Public,
188 id: group_id.as_u128(),
189 },
190 }
191 }
192
193 GroupInitRequestType::Join {
194 local_user,
195 owner,
196 group_id,
197 do_peer_register,
198 } => {
199 let owner_orig = owner;
201 let owner_find = owner_orig
202 .search_peer(session_cid, connect_success.account_manager())
203 .await?;
204
205 let owner = if let Some(owner) = owner_find {
206 Some(owner)
207 } else if do_peer_register {
208 let handle = connect_success
209 .propose_target(local_user.clone(), owner_orig.clone())
210 .await?;
211 let _ = handle.register_to_peer().await?;
212 owner_orig
214 .search_peer(session_cid, connect_success.account_manager())
215 .await?
216 } else {
217 None
218 };
219
220 let owner = owner.ok_or_else(|| {
221 NetworkError::msg(format!(
222 "User {:?} is not registered to {:?}",
223 owner_orig, &local_user
224 ))
225 })?;
226
227 let expected_message_group_key = MessageGroupKey {
228 cid: owner.cid,
229 mgid: group_id.as_u128(),
230 };
231
232 let mut retries = 0;
234 let group_owner_handle = connect_success
235 .propose_target(local_user.clone(), owner.cid)
236 .await?;
237 loop {
238 let owned_groups = group_owner_handle.list_owned_groups().await?;
239 if owned_groups.contains(&expected_message_group_key) {
240 break;
241 } else {
242 citadel_io::tokio::time::sleep(std::time::Duration::from_secs(
243 2u64.pow(retries),
244 ))
245 .await;
246
247 retries += 1;
248 if retries > 4 {
249 return Err(NetworkError::Generic(format!(
250 "Owner {owner:?} has not created group {group_id:?}"
251 )));
252 }
253 }
254 }
255
256 GroupBroadcast::RequestJoin {
257 sender: local_user.get_cid(),
258 key: expected_message_group_key,
259 }
260 }
261 };
262
263 let request = NodeRequest::GroupBroadcastCommand(GroupBroadcastCommand {
264 session_cid,
265 command: request,
266 });
267
268 let subscription = &Mutex::new(Some(
269 connect_success.send_callback_subscription(request).await?,
270 ));
271
272 log::trace!(target: "citadel", "Peer {session_cid} is attempting to join group");
273 let acceptor_task = if creator_only_accept_inbound_registers {
274 shared.route_registers.store(true, Ordering::Relaxed);
275 let mut reg_rx = shared.register_rx.lock().take().unwrap();
276 let remote = connect_success.remote_ref().clone();
277 Box::pin(async move {
278 let mut subscription = subscription.lock().await.take().unwrap();
279 let mut count_registered = 0;
281 loop {
282 let post_register = citadel_io::tokio::select! {
283 reg_request = reg_rx.recv() => {
284 reg_request.ok_or_else(|| NetworkError::InternalError("reg_rx ended unexpectedly"))?
285 },
286
287 reg_request2 = subscription.next() => {
288 let signal = reg_request2.ok_or_else(|| NetworkError::InternalError("subscription ended unexpectedly"))?;
289 if let NodeResult::PeerEvent(PeerEvent { event: sig @ PeerSignal::PostRegister { .. }, .. }) = &signal {
290 sig.clone()
291 } else {
292 continue;
293 }
294 }
295 };
296
297 log::trace!(target: "citadel", "ACCEPTOR {session_cid} RECV reg_request: {post_register:?}");
298 if let PeerSignal::PostRegister {
299 peer_conn_type: peer_conn,
300 inviter_username: _,
301 invitee_username: _,
302 ticket_opt: _,
303 invitee_response: None,
304 } = &post_register
305 {
306 let cid = peer_conn.get_original_target_cid();
307 if cid != session_cid {
308 log::warn!(target: "citadel", "Received the wrong CID. Will not accept request");
309 continue;
310 }
311
312 let _ = responses::peer_register(post_register, true, &remote).await?;
313 if cfg!(feature = "localhost-testing") {
314 count_registered += 1;
315 if count_registered == crate::test_common::num_local_test_peers() - 1 {
316 break;
318 }
319 }
320 }
321 }
322
323 Ok::<_, NetworkError>(())
324 })
325 as Pin<
326 Box<
327 dyn futures::Future<
328 Output = Result<(), citadel_proto::prelude::NetworkError>,
329 > + Send,
330 >,
331 >
332 } else {
333 Box::pin(async move { Ok::<_, NetworkError>(()) })
334 as Pin<
335 Box<
336 dyn futures::Future<
337 Output = Result<(), citadel_proto::prelude::NetworkError>,
338 > + Send,
339 >,
340 >
341 };
342
343 let mut lock = subscription.lock().await;
344 let subscription = lock.as_mut().unwrap();
345 while let Some(event) = subscription.next().await {
346 match map_errors(event)? {
347 NodeResult::PeerEvent(PeerEvent {
348 event: ref ps @ PeerSignal::PostRegister { .. },
349 ticket: _,
350 ..
351 }) => {
352 shared
353 .register_tx
354 .send(ps.clone())
355 .map_err(|err| NetworkError::Generic(err.to_string()))?;
356 }
357 NodeResult::GroupChannelCreated(GroupChannelCreated {
358 ticket: _,
359 channel,
360 session_cid: _,
361 }) => {
362 drop(lock);
365 return if is_owner {
366 citadel_io::tokio::try_join!(fx(channel, connect_success), acceptor_task)
367 .map(|_| ())
368 } else {
369 fx(channel, connect_success).await.map(|_| ())
370 };
371 }
372
373 NodeResult::GroupEvent(GroupEvent {
374 session_cid: _,
375 ticket: _,
376 event: GroupBroadcast::CreateResponse { key: None },
377 }) => {
378 return Err(NetworkError::InternalError(
379 "Unable to create a message group",
380 ))
381 }
382
383 _ => {}
384 }
385 }
386
387 Ok(())
388 }
389
390 fn construct(kernel: Box<dyn NetKernel<R> + 'a>) -> Self {
391 let (tx, rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
392 Self {
393 shared: Arc::new(BroadcastShared {
394 route_registers: AtomicBool::new(false),
395 register_rx: citadel_io::Mutex::new(Some(rx)),
396 register_tx: tx,
397 }),
398 inner_kernel: kernel,
399 _pd: Default::default(),
400 }
401 }
402}
403
404#[async_trait]
405impl<F, Fut, R: Ratchet> NetKernel<R> for BroadcastKernel<'_, F, Fut, R> {
406 fn load_remote(&mut self, node_remote: NodeRemote<R>) -> Result<(), NetworkError> {
407 self.inner_kernel.load_remote(node_remote)
408 }
409
410 async fn on_start(&self) -> Result<(), NetworkError> {
411 self.inner_kernel.on_start().await
412 }
413
414 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
415 if let NodeResult::PeerEvent(PeerEvent {
416 event: ps @ PeerSignal::PostRegister { .. },
417 ticket: _,
418 ..
419 }) = &message
420 {
421 if self.shared.route_registers.load(Ordering::Relaxed) {
422 return self
423 .shared
424 .register_tx
425 .send(ps.clone())
426 .map_err(|err| NetworkError::Generic(err.to_string()));
427 }
428 }
429
430 self.inner_kernel.on_node_event_received(message).await
431 }
432
433 async fn on_stop(&mut self) -> Result<(), NetworkError> {
434 self.inner_kernel.on_stop().await
435 }
436}
437
#[cfg(test)]
mod tests {
    use crate::prefabs::client::broadcast::{BroadcastKernel, GroupInitRequestType};
    use crate::prefabs::client::peer_connection::PeerConnectionKernel;
    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
    use crate::prelude::*;
    use crate::test_common::{server_info, wait_for_peers, TestBarrier};
    use citadel_io::tokio;
    use futures::prelude::stream::FuturesUnordered;
    use futures::TryStreamExt;
    use rstest::rstest;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use uuid::Uuid;

    /// Spins up a server plus `peer_count` clients. Peer 0 creates a group and the others join it
    /// through `GroupInitRequestType::Join`. Each peer then checks `list_owned_groups` (the
    /// creator owns exactly one group, joiners own none), and the test asserts that every peer
    /// got that far.
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn group_connect_list_members() -> Result<(), Box<dyn std::error::Error>> {
        let peer_count = 3;
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        // Incremented by each peer whose group callback completes its checks.
        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();
        let group_id = Uuid::new_v4();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();

            // Peer 0 creates the group (no initial invitees, auto-accepting registrations).
            // Every other peer joins, registering with the owner first if needed.
            let request = if idx == 0 {
                GroupInitRequestType::Create {
                    local_user: UserIdentifier::from(uuid),
                    invite_list: vec![],
                    group_id,
                    accept_registrations: true,
                }
            } else {
                GroupInitRequestType::Join {
                    local_user: UserIdentifier::from(uuid),
                    owner: total_peers.first().cloned().unwrap().into(),
                    group_id,
                    do_peer_register: true,
                }
            };

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = BroadcastKernel::new(
                server_connection_settings,
                request,
                move |channel, connection| async move {
                    wait_for_peers().await;
                    log::trace!(target: "citadel", "***GROUP PEER {}={}={} CONNECT SUCCESS***", idx, uuid, connection.conn_type.get_session_cid());

                    let owned_groups = connection.list_owned_groups().await.unwrap();

                    // Only the creator (peer 0) should own a group.
                    if idx == 0 {
                        assert_eq!(owned_groups.len(), 1);
                    } else {
                        assert_eq!(owned_groups.len(), 0);
                    }

                    log::trace!(target: "citadel", "Peer {idx}={} is COMPLETE!", connection.conn_type.get_session_cid());

                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    // `wait_for_peers` appears to be a barrier across the `peer_count` peers
                    // (see `TestBarrier::setup`), so the channel is dropped only once every
                    // peer reaches this point.
                    wait_for_peers().await;
                    drop(channel);
                    connection.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        // Whichever of the server / client futures finishes first decides the outcome; an error
        // from it fails the test.
        let res = futures::future::try_select(server, clients).await;
        if let Err(err) = res {
            return match err {
                futures::future::Either::Left(left) => Err(left.0.into_string().into()),
                futures::future::Either::Right(right) => Err(right.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }

    /// Peers connect to each other through `PeerConnectionKernel`. Peer 0 then creates a group
    /// manually with `create_group`, inviting the peer it connected to. The other peers accept
    /// the invitation through `responses::group_invite` and wait for `GroupChannelCreated`.
    #[rstest]
    #[case(2)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_manual_group_connect(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        // Set by the group creator and by an invitee, respectively, once their side succeeds.
        let client_success = &AtomicBool::new(false);
        let receiver_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for idx in 0..peer_count {
            let uuid = total_peers.get(idx).cloned().unwrap();
            // Each peer targets every other peer.
            let peers = total_peers
                .clone()
                .into_iter()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                peers,
                move |mut results, remote| async move {
                    let _sender = remote.conn_type.get_session_cid();
                    // The group events handled below are read from this receiver.
                    let mut signals = remote.get_unprocessed_signals_receiver().unwrap();

                    wait_for_peers().await;
                    let conn = results.recv().await.unwrap()?;
                    log::trace!(target: "citadel", "User {uuid} received {conn:?}");

                    if idx == 0 {
                        // Peer 0 creates the group, inviting the peer it just connected to.
                        let _channel = remote
                            .create_group(Some(vec![conn.channel.get_peer_cid().into()]))
                            .await?;
                        log::info!(target: "citadel", "The designated node has finished creating a group");

                        wait_for_peers().await;
                        client_success.store(true, Ordering::Relaxed);
                        return remote.shutdown_kernel().await;
                    } else {
                        while let Some(evt) = signals.recv().await {
                            log::info!(target: "citadel", "Received unprocessed signal: {evt:?}");
                            match evt {
                                // Accept every group invitation.
                                NodeResult::GroupEvent(GroupEvent {
                                    session_cid: _,
                                    ticket: _,
                                    event:
                                        GroupBroadcast::Invitation {
                                            sender: _,
                                            key: _key,
                                        },
                                }) => {
                                    let _ =
                                        crate::responses::group_invite(evt, true, &remote.inner)
                                            .await?;
                                }

                                NodeResult::GroupChannelCreated(GroupChannelCreated {
                                    ticket: _,
                                    channel: _chan,
                                    session_cid: _,
                                }) => {
                                    receiver_success.store(true, Ordering::Relaxed);
                                    log::trace!(target: "citadel", "***PEER {uuid} CONNECT***");
                                    wait_for_peers().await;
                                    return remote.shutdown_kernel().await;
                                }

                                val => {
                                    log::warn!(target: "citadel", "Unhandled response: {val:?}")
                                }
                            }
                        }
                    }

                    // Reached only if the signal stream closes before `GroupChannelCreated`.
                    Err(NetworkError::InternalError(
                        "signals_recv ended unexpectedly",
                    ))
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        // Whichever of the server / client futures finishes first decides the outcome; an error
        // from it fails the test.
        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert!(client_success.load(Ordering::Relaxed));
        assert!(receiver_success.load(Ordering::Relaxed));
        Ok(())
    }
}