citadel_sdk/prefabs/shared/
internal_service.rs1use crate::prelude::{CitadelClientServerConnection, TargetLockedRemote};
46use bytes::Bytes;
47use citadel_io::tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
48use citadel_proto::prelude::NetworkError;
49use citadel_proto::prelude::*;
50use citadel_proto::re_imports::{StreamReader, UnboundedReceiverStream};
51use citadel_types::crypto::SecBuffer;
52use futures::StreamExt;
53use std::future::Future;
54use std::pin::Pin;
55use std::task::{Context, Poll};
56
57pub async fn internal_service<F, Fut, R: Ratchet>(
58 connection: CitadelClientServerConnection<R>,
59 service: F,
60) -> Result<(), NetworkError>
61where
62 F: Send + Copy + Sync + FnOnce(InternalServerCommunicator) -> Fut,
63 Fut: Send + Sync + Future<Output = Result<(), NetworkError>>,
64{
65 let remote = connection.remote.clone();
66 let (tx_to_service, rx_from_kernel) = citadel_io::tokio::sync::mpsc::unbounded_channel();
67 let (tx_to_kernel, mut rx_from_service) = citadel_io::tokio::sync::mpsc::unbounded_channel();
68
69 let internal_server_communicator = InternalServerCommunicator {
70 tx_to_kernel,
71 rx_from_kernel: StreamReader::new(rx_from_kernel.into()),
72 };
73
74 let internal_server = service(internal_server_communicator);
75
76 let (mut sink, mut stream) = connection.split();
78 let from_proto = async move {
80 while let Some(packet) = stream.next().await {
81 tx_to_service.send(Ok(packet.into_buffer().freeze()))?;
84 }
85
86 Ok(())
87 };
88
89 let from_webserver = async move {
91 while let Some(packet) = rx_from_service.recv().await {
92 sink.send(packet).await?;
93 }
94
95 Ok(())
96 };
97
98 let res = citadel_io::tokio::select! {
99 res0 = from_proto => {
100 res0
101 },
102 res1 = from_webserver => {
103 res1
104 },
105 res2 = internal_server => {
106 res2
107 }
108 };
109
110 citadel_logging::warn!(target: "citadel", "Internal Server Stopped: {res:?}");
111
112 remote.remote().shutdown().await?;
113 res
114}
115
116pub struct InternalServerCommunicator {
117 pub(crate) tx_to_kernel: citadel_io::tokio::sync::mpsc::UnboundedSender<SecBuffer>,
118 pub(crate) rx_from_kernel:
119 StreamReader<UnboundedReceiverStream<Result<Bytes, std::io::Error>>, Bytes>,
120}
121
122impl AsyncWrite for InternalServerCommunicator {
123 fn poll_write(
124 self: Pin<&mut Self>,
125 _cx: &mut Context<'_>,
126 buf: &[u8],
127 ) -> Poll<std::io::Result<usize>> {
128 let len = buf.len();
129 match self.tx_to_kernel.send(buf.into()) {
130 Ok(_) => Poll::Ready(Ok(len)),
131 Err(err) => Poll::Ready(Err(std::io::Error::other(err.to_string()))),
132 }
133 }
134
135 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
136 Poll::Ready(Ok(()))
137 }
138
139 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
140 Poll::Ready(Ok(()))
141 }
142}
143
144impl AsyncRead for InternalServerCommunicator {
145 fn poll_read(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 buf: &mut ReadBuf<'_>,
149 ) -> Poll<std::io::Result<()>> {
150 Pin::new(&mut self.rx_from_kernel).poll_read(cx, buf)
151 }
152}
153
154impl Unpin for InternalServerCommunicator {}