diff --git a/src/lib.rs b/src/lib.rs index a458ffc..f7d36e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1093,7 +1093,7 @@ pub mod rpc { use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc}; use n0_future::{future::Boxed as BoxFuture, task::JoinSet}; - use quinn::ConnectionError; + use quinn::{ConnectionError, Endpoint}; use serde::{de::DeserializeOwned, Serialize}; use smallvec::SmallVec; use tracing::{trace, trace_span, warn, Instrument}; @@ -1470,6 +1470,89 @@ pub mod rpc { request_id += 1; } } + + type MultiHandler = Arc< + dyn Fn( + &[u8], + quinn::RecvStream, + quinn::SendStream, + ) -> std::result::Result< + BoxFuture>, + (quinn::RecvStream, quinn::SendStream), + > + Send + + Sync + + 'static, + >; + + pub struct Listener { + handlers: Vec, + } + + impl Listener { + pub fn add_handler(mut self, handler: Handler) -> Self { + self.handlers.push(Arc::new( + move |buf, recv, send| match postcard::from_bytes::(buf) { + Err(_) => Err((recv, send)), + Ok(msg) => Ok(handler(msg, recv, send)), + }, + )); + self + } + + pub async fn listen(self, endpoint: Endpoint) { + let mut request_id = 0u64; + let mut tasks = JoinSet::new(); + while let Some(incoming) = endpoint.accept().await { + let handlers = self.handlers.clone(); + let fut = async move { + let connection = match incoming.await { + Ok(connection) => connection, + Err(cause) => { + warn!("failed to accept connection {cause:?}"); + return io::Result::Ok(()); + } + }; + loop { + let (mut send, mut recv) = match connection.accept_bi().await { + Ok((s, r)) => (s, r), + Err(ConnectionError::ApplicationClosed(cause)) + if cause.error_code.into_inner() == 0 => + { + trace!("remote side closed connection {cause:?}"); + return Ok(()); + } + Err(cause) => { + warn!("failed to accept bi stream {cause:?}"); + return Err(cause.into()); + } + }; + let size = recv.read_varint_u64().await?.ok_or_else(|| { + io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size") + })?; + let mut buf = vec![0; size as usize]; + recv.read_exact(&mut buf) + .await + .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; + for handler in &handlers { + match handler(&buf, recv, send) { + Ok(fut) => { + fut.await?; + break; + } + Err((recv_ret, send_ret)) => { + recv = recv_ret; + send = send_ret; + } + } + } + } + }; + let span = trace_span!("rpc", id = request_id); + tasks.spawn(fut.instrument(span)); + request_id += 1; + } + } + } } /// A request to a service. This can be either local or remote.