1use crate::prelude::*;
66use crate::test_common::wait_for_peers;
67use citadel_io::tokio::sync::Mutex;
68use citadel_user::prelude::UserIdentifierExt;
69use futures::{Future, StreamExt};
70use std::marker::PhantomData;
71use std::pin::Pin;
72use std::sync::atomic::{AtomicBool, Ordering};
73use std::sync::Arc;
74use uuid::Uuid;
75
76pub struct BroadcastKernel<'a, F, Fut, R: Ratchet> {
82 inner_kernel: Box<dyn NetKernel<R> + 'a>,
83 shared: Arc<BroadcastShared>,
84 _pd: PhantomData<fn() -> (F, Fut)>,
85}
86
87pub struct BroadcastShared {
88 route_registers: AtomicBool,
89 register_rx:
90 citadel_io::Mutex<Option<citadel_io::tokio::sync::mpsc::UnboundedReceiver<PeerSignal>>>,
91 register_tx: citadel_io::tokio::sync::mpsc::UnboundedSender<PeerSignal>,
92}
93
94pub enum GroupInitRequestType {
102 Create {
107 local_user: UserIdentifier,
108 invite_list: Vec<UserIdentifier>,
109 group_id: Uuid,
110 accept_registrations: bool,
111 },
112 Join {
120 local_user: UserIdentifier,
121 owner: UserIdentifier,
122 group_id: Uuid,
123 do_peer_register: bool,
124 },
125}
126
127#[async_trait]
128impl<'a, F, Fut, R: Ratchet> PrefabFunctions<'a, GroupInitRequestType, R>
129 for BroadcastKernel<'a, F, Fut, R>
130where
131 F: FnOnce(GroupChannel, CitadelClientServerConnection<R>) -> Fut + Send + 'a,
132 Fut: Future<Output = Result<(), NetworkError>> + Send + 'a,
133{
134 type UserLevelInputFunction = F;
135 type SharedBundle = Arc<BroadcastShared>;
136
137 fn get_shared_bundle(&self) -> Self::SharedBundle {
138 self.shared.clone()
139 }
140
141 #[allow(unreachable_code, clippy::blocks_in_conditions)]
142 #[cfg_attr(
143 feature = "localhost-testing",
144 tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug))
145 )]
146 async fn on_c2s_channel_received(
147 connect_success: CitadelClientServerConnection<R>,
148 arg: GroupInitRequestType,
149 fx: Self::UserLevelInputFunction,
150 shared: Arc<BroadcastShared>,
151 ) -> Result<(), NetworkError> {
152 let session_cid = connect_success.cid;
153 wait_for_peers().await;
154 let mut creator_only_accept_inbound_registers = false;
155
156 let mut is_owner = false;
157 let request = match arg {
158 GroupInitRequestType::Create {
159 local_user,
160 invite_list,
161 group_id,
162 accept_registrations,
163 } => {
164 is_owner = true;
165 let mut peers_registered = vec![];
167
168 for peer in &invite_list {
169 let peer = peer
170 .search_peer(session_cid, connect_success.account_manager())
171 .await?
172 .ok_or_else(|| {
173 NetworkError::msg(format!(
174 "[create] User {:?} is not registered to {:?}",
175 peer, &local_user
176 ))
177 })?;
178
179 peers_registered.push(peer.cid)
180 }
181
182 creator_only_accept_inbound_registers = accept_registrations;
183
184 GroupBroadcast::Create {
185 initial_invitees: peers_registered,
186 options: MessageGroupOptions {
187 group_type: GroupType::Public,
188 id: group_id.as_u128(),
189 },
190 }
191 }
192
193 GroupInitRequestType::Join {
194 local_user,
195 owner,
196 group_id,
197 do_peer_register,
198 } => {
199 let owner_orig = owner;
201 let owner_find = owner_orig
202 .search_peer(session_cid, connect_success.account_manager())
203 .await?;
204
205 let owner = if let Some(owner) = owner_find {
206 Some(owner)
207 } else if do_peer_register {
208 let handle = connect_success
209 .propose_target(local_user.clone(), owner_orig.clone())
210 .await?;
211 let _ = handle.register_to_peer().await?;
212 owner_orig
214 .search_peer(session_cid, connect_success.account_manager())
215 .await?
216 } else {
217 None
218 };
219
220 let owner = owner.ok_or_else(|| {
221 NetworkError::msg(format!(
222 "User {:?} is not registered to {:?}",
223 owner_orig, &local_user
224 ))
225 })?;
226
227 let expected_message_group_key = MessageGroupKey {
228 cid: owner.cid,
229 mgid: group_id.as_u128(),
230 };
231
232 let mut retries = 0;
234 let group_owner_handle = connect_success
235 .propose_target(local_user.clone(), owner.cid)
236 .await?;
237 loop {
238 let owned_groups = group_owner_handle.list_owned_groups().await?;
239 if owned_groups.contains(&expected_message_group_key) {
240 break;
241 } else {
242 citadel_io::time::sleep(std::time::Duration::from_secs(2u64.pow(retries)))
243 .await;
244
245 retries += 1;
246 if retries > 4 {
247 return Err(NetworkError::Generic(format!(
248 "Owner {owner:?} has not created group {group_id:?}"
249 )));
250 }
251 }
252 }
253
254 GroupBroadcast::RequestJoin {
255 sender: local_user.get_cid(),
256 key: expected_message_group_key,
257 }
258 }
259 };
260
261 let request = NodeRequest::GroupBroadcastCommand(GroupBroadcastCommand {
262 session_cid,
263 command: request,
264 });
265
266 let subscription = &Mutex::new(Some(
267 connect_success.send_callback_subscription(request).await?,
268 ));
269
270 log::trace!(target: "citadel", "Peer {session_cid} is attempting to join group");
271 let acceptor_task = if creator_only_accept_inbound_registers {
272 shared.route_registers.store(true, Ordering::Relaxed);
273 let mut reg_rx = shared.register_rx.lock().take().unwrap();
274 let remote = connect_success.remote_ref().clone();
275 Box::pin(async move {
276 let mut subscription = subscription.lock().await.take().unwrap();
277 let mut count_registered = 0;
279 loop {
280 let post_register = citadel_io::tokio::select! {
281 reg_request = reg_rx.recv() => {
282 reg_request.ok_or_else(|| NetworkError::InternalError("reg_rx ended unexpectedly"))?
283 },
284
285 reg_request2 = subscription.next() => {
286 let signal = reg_request2.ok_or_else(|| NetworkError::InternalError("subscription ended unexpectedly"))?;
287 if let NodeResult::PeerEvent(PeerEvent { event: sig @ PeerSignal::PostRegister { .. }, .. }) = &signal {
288 sig.clone()
289 } else {
290 continue;
291 }
292 }
293 };
294
295 log::trace!(target: "citadel", "ACCEPTOR {session_cid} RECV reg_request: {post_register:?}");
296 if let PeerSignal::PostRegister {
297 peer_conn_type: peer_conn,
298 inviter_username: _,
299 invitee_username: _,
300 ticket_opt: _,
301 invitee_response: None,
302 } = &post_register
303 {
304 let cid = peer_conn.get_original_target_cid();
305 if cid != session_cid {
306 log::warn!(target: "citadel", "Received the wrong CID. Will not accept request");
307 continue;
308 }
309
310 let _ = responses::peer_register(post_register, true, &remote).await?;
311 if cfg!(feature = "localhost-testing") {
312 count_registered += 1;
313 if count_registered == crate::test_common::num_local_test_peers() - 1 {
314 break;
316 }
317 }
318 }
319 }
320
321 Ok::<_, NetworkError>(())
322 })
323 as Pin<
324 Box<
325 dyn futures::Future<
326 Output = Result<(), citadel_proto::prelude::NetworkError>,
327 > + Send,
328 >,
329 >
330 } else {
331 Box::pin(async move { Ok::<_, NetworkError>(()) })
332 as Pin<
333 Box<
334 dyn futures::Future<
335 Output = Result<(), citadel_proto::prelude::NetworkError>,
336 > + Send,
337 >,
338 >
339 };
340
341 let mut lock = subscription.lock().await;
342 let subscription = lock.as_mut().unwrap();
343 while let Some(event) = subscription.next().await {
344 match event.into_result()? {
345 NodeResult::PeerEvent(PeerEvent {
346 event: ref ps @ PeerSignal::PostRegister { .. },
347 ticket: _,
348 ..
349 }) => {
350 shared
351 .register_tx
352 .send(ps.clone())
353 .map_err(|err| NetworkError::Generic(err.to_string()))?;
354 }
355 NodeResult::GroupChannelCreated(GroupChannelCreated {
356 ticket: _,
357 channel,
358 session_cid: _,
359 }) => {
360 drop(lock);
363 return if is_owner {
364 citadel_io::tokio::try_join!(fx(channel, connect_success), acceptor_task)
365 .map(|_| ())
366 } else {
367 fx(channel, connect_success).await.map(|_| ())
368 };
369 }
370
371 NodeResult::GroupEvent(GroupEvent {
372 session_cid: _,
373 ticket: _,
374 event: GroupBroadcast::CreateResponse { key: None },
375 }) => {
376 return Err(NetworkError::InternalError(
377 "Unable to create a message group",
378 ))
379 }
380
381 _ => {}
382 }
383 }
384
385 Ok(())
386 }
387
388 fn construct(kernel: Box<dyn NetKernel<R> + 'a>) -> Self {
389 let (tx, rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
390 Self {
391 shared: Arc::new(BroadcastShared {
392 route_registers: AtomicBool::new(false),
393 register_rx: citadel_io::Mutex::new(Some(rx)),
394 register_tx: tx,
395 }),
396 inner_kernel: kernel,
397 _pd: Default::default(),
398 }
399 }
400}
401
402#[async_trait]
403impl<F, Fut, R: Ratchet> NetKernel<R> for BroadcastKernel<'_, F, Fut, R> {
404 fn load_remote(&mut self, node_remote: NodeRemote<R>) -> Result<(), NetworkError> {
405 self.inner_kernel.load_remote(node_remote)
406 }
407
408 async fn on_start(&self) -> Result<(), NetworkError> {
409 self.inner_kernel.on_start().await
410 }
411
412 async fn on_node_event_received(&self, message: NodeResult<R>) -> Result<(), NetworkError> {
413 if let NodeResult::PeerEvent(PeerEvent {
414 event: ps @ PeerSignal::PostRegister { .. },
415 ticket: _,
416 ..
417 }) = &message
418 {
419 if self.shared.route_registers.load(Ordering::Relaxed) {
420 return self
421 .shared
422 .register_tx
423 .send(ps.clone())
424 .map_err(|err| NetworkError::Generic(err.to_string()));
425 }
426 }
427
428 self.inner_kernel.on_node_event_received(message).await
429 }
430
431 async fn on_stop(&mut self) -> Result<(), NetworkError> {
432 self.inner_kernel.on_stop().await
433 }
434}
435
#[cfg(all(test, feature = "localhost-testing"))]
mod tests {
    use crate::prefabs::client::broadcast::{BroadcastKernel, GroupInitRequestType};
    use crate::prefabs::client::peer_connection::PeerConnectionKernel;
    use crate::prefabs::client::DefaultServerConnectionSettingsBuilder;
    use crate::prelude::*;
    use crate::test_common::{server_info, wait_for_peers, TestBarrier};
    use citadel_io::tokio;
    use futures::prelude::stream::FuturesUnordered;
    use futures::TryStreamExt;
    use rstest::rstest;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use uuid::Uuid;

    /// Peer 0 creates a group (auto-accepting inbound registrations). Every other peer
    /// registers to peer 0 and joins the group. Each peer then checks that only the
    /// creator owns a group.
    #[rstest]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn group_connect_list_members() -> Result<(), Box<dyn std::error::Error>> {
        let peer_count = 3;
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        let client_success = &AtomicUsize::new(0);
        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();
        let group_id = Uuid::new_v4();

        for (idx, &uuid) in total_peers.iter().enumerate() {
            let request = if idx == 0 {
                // The designated owner creates the group
                GroupInitRequestType::Create {
                    local_user: UserIdentifier::from(uuid),
                    invite_list: vec![],
                    group_id,
                    accept_registrations: true,
                }
            } else {
                GroupInitRequestType::Join {
                    local_user: UserIdentifier::from(uuid),
                    owner: total_peers[0].into(),
                    group_id,
                    do_peer_register: true,
                }
            };

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = BroadcastKernel::new(
                server_connection_settings,
                request,
                move |channel, connection| async move {
                    wait_for_peers().await;
                    log::trace!(target: "citadel", "***GROUP PEER {}={}={} CONNECT SUCCESS***", idx, uuid, connection.conn_type.get_session_cid());

                    let owned_groups = connection.list_owned_groups().await.unwrap();

                    if idx == 0 {
                        assert_eq!(owned_groups.len(), 1);
                    } else {
                        assert_eq!(owned_groups.len(), 0);
                    }

                    log::trace!(target: "citadel", "Peer {idx}={} is COMPLETE!", connection.conn_type.get_session_cid());

                    let _ = client_success.fetch_add(1, Ordering::Relaxed);
                    wait_for_peers().await;
                    drop(channel);
                    connection.shutdown_kernel().await
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();

            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        let res = futures::future::try_select(server, clients).await;
        if let Err(err) = res {
            return match err {
                futures::future::Either::Left(left) => Err(left.0.into_string().into()),
                futures::future::Either::Right(right) => Err(right.0.into_string().into()),
            };
        }

        assert_eq!(client_success.load(Ordering::Relaxed), peer_count);
        Ok(())
    }

    /// Every peer first establishes P2P connections to the others. Peer 0 then manually
    /// creates a group that invites its connected peer. Every other peer accepts the
    /// invitation from its unprocessed-signal stream and succeeds once the group channel
    /// is created.
    #[rstest]
    #[case(2)]
    #[timeout(std::time::Duration::from_secs(90))]
    #[citadel_io::tokio::test(flavor = "multi_thread")]
    async fn test_manual_group_connect(
        #[case] peer_count: usize,
    ) -> Result<(), Box<dyn std::error::Error>> {
        assert!(peer_count > 1);
        citadel_logging::setup_log();
        TestBarrier::setup(peer_count);

        let client_success = &AtomicBool::new(false);
        let receiver_success = &AtomicBool::new(false);

        let (server, server_addr) = server_info::<StackedRatchet>();

        let client_kernels = FuturesUnordered::new();
        let total_peers = (0..peer_count)
            .map(|_| Uuid::new_v4())
            .collect::<Vec<Uuid>>();

        for (idx, &uuid) in total_peers.iter().enumerate() {
            // Each peer connects to every other peer
            let peers = total_peers
                .iter()
                .copied()
                .filter(|r| r != &uuid)
                .map(UserIdentifier::from)
                .collect::<Vec<UserIdentifier>>();

            let server_connection_settings =
                DefaultServerConnectionSettingsBuilder::transient_with_id(server_addr, uuid)
                    .build()
                    .unwrap();

            let client_kernel = PeerConnectionKernel::new(
                server_connection_settings,
                peers,
                move |mut results, remote| async move {
                    let mut signals = remote.get_unprocessed_signals_receiver().unwrap();

                    wait_for_peers().await;
                    let conn = results.recv().await.unwrap()?;
                    log::trace!(target: "citadel", "User {uuid} received {conn:?}");

                    if idx == 0 {
                        // The designated node creates the group, inviting its peer
                        let _channel = remote
                            .create_group(Some(vec![conn.channel.get_peer_cid().into()]))
                            .await?;
                        log::info!(target: "citadel", "The designated node has finished creating a group");

                        wait_for_peers().await;
                        client_success.store(true, Ordering::Relaxed);
                        return remote.shutdown_kernel().await;
                    } else {
                        // Invitations arrive as unprocessed signals; accept them and wait
                        // for the group channel
                        while let Some(evt) = signals.recv().await {
                            log::info!(target: "citadel", "Received unprocessed signal: {evt:?}");
                            match evt {
                                NodeResult::GroupEvent(GroupEvent {
                                    session_cid: _,
                                    ticket: _,
                                    event:
                                        GroupBroadcast::Invitation {
                                            sender: _,
                                            key: _key,
                                        },
                                }) => {
                                    let _ =
                                        crate::responses::group_invite(evt, true, &remote.inner)
                                            .await?;
                                }

                                NodeResult::GroupChannelCreated(GroupChannelCreated {
                                    ticket: _,
                                    channel: _chan,
                                    session_cid: _,
                                }) => {
                                    receiver_success.store(true, Ordering::Relaxed);
                                    log::trace!(target: "citadel", "***PEER {uuid} CONNECT***");
                                    wait_for_peers().await;
                                    return remote.shutdown_kernel().await;
                                }

                                val => {
                                    log::warn!(target: "citadel", "Unhandled response: {val:?}")
                                }
                            }
                        }
                    }

                    Err(NetworkError::InternalError(
                        "signals_recv ended unexpectedly",
                    ))
                },
            );

            let client = DefaultNodeBuilder::default().build(client_kernel).unwrap();
            client_kernels.push(async move { client.await.map(|_| ()) });
        }

        let clients = Box::pin(async move { client_kernels.try_collect::<()>().await.map(|_| ()) });

        if let Err(err) = futures::future::try_select(server, clients).await {
            return match err {
                futures::future::Either::Left(res) => Err(res.0.into_string().into()),
                futures::future::Either::Right(res) => Err(res.0.into_string().into()),
            };
        }

        assert!(client_success.load(Ordering::Relaxed));
        assert!(receiver_success.load(Ordering::Relaxed));
        Ok(())
    }
}