1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
use crate::prefabs::ClientServerRemote;
use crate::prelude::*;
use citadel_proto::prelude::async_trait;
use futures::Future;
use std::marker::PhantomData;

/// A kernel that executes a user-provided function each time
/// a client makes a connection.
///
/// `F` is the callback type and `Fut` is the future that callback returns;
/// the bounds on both are declared on the `impl` blocks rather than here.
pub struct ClientConnectListenerKernel<F, Fut> {
    /// Callback invoked with the connection details and a remote handle
    /// whenever a `NodeResult::ConnectSuccess` event is received.
    on_channel_received: F,
    /// Handle to the node; `None` until `load_remote` stores one.
    node_remote: Option<NodeRemote>,
    /// Zero-sized marker that ties the otherwise-unused `Fut` type
    /// parameter to this struct.
    _pd: PhantomData<Fut>,
}

impl<F, Fut> ClientConnectListenerKernel<F, Fut>
where
    F: Fn(ConnectionSuccess, ClientServerRemote) -> Fut + Send + Sync,
    Fut: Future<Output = Result<(), NetworkError>> + Send + Sync,
{
    pub fn new(on_channel_received: F) -> Self {
        Self {
            on_channel_received,
            node_remote: None,
            _pd: Default::default(),
        }
    }
}

#[async_trait]
impl<F, Fut> NetKernel for ClientConnectListenerKernel<F, Fut>
where
    F: Fn(ConnectionSuccess, ClientServerRemote) -> Fut + Send + Sync,
    Fut: Future<Output = Result<(), NetworkError>> + Send + Sync,
{
    fn load_remote(&mut self, server_remote: NodeRemote) -> Result<(), NetworkError> {
        self.node_remote = Some(server_remote);
        Ok(())
    }

    async fn on_start(&self) -> Result<(), NetworkError> {
        Ok(())
    }

    async fn on_node_event_received(&self, message: NodeResult) -> Result<(), NetworkError> {
        match message {
            NodeResult::ConnectSuccess(ConnectSuccess {
                ticket: _,
                implicated_cid: cid,
                remote_addr: _,
                is_personal: _,
                v_conn_type: conn_type,
                services,
                welcome_message: _,
                channel,
                udp_rx_opt: udp_channel_rx,
                session_security_settings,
            }) => {
                let client_server_remote = ClientServerRemote::new(
                    conn_type,
                    self.node_remote.clone().unwrap(),
                    session_security_settings,
                );
                (self.on_channel_received)(
                    ConnectionSuccess {
                        channel,
                        udp_channel_rx,
                        services,
                        cid,
                        session_security_settings,
                    },
                    client_server_remote,
                )
                .await
            }

            other => {
                log::trace!(target: "citadel", "Unhandled server signal: {:?}", other);
                Ok(())
            }
        }
    }

    async fn on_stop(&mut self) -> Result<(), NetworkError> {
        Ok(())
    }
}