Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server trait usability fixes #484

Draft
wants to merge 6 commits into
base: initial-nsec3-generation
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/query-routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async fn main() {
.ok();

// Start building the query router plus upstreams.
let mut qr: QnameRouter<Vec<u8>, Vec<u8>, ReplyMessage> =
let mut qr: QnameRouter<Vec<u8>, Vec<u8>, (), ReplyMessage> =
QnameRouter::new();

// Queries to the root go to 2606:4700:4700::1111 and 1.1.1.1.
Expand All @@ -57,8 +57,8 @@ async fn main() {
let conn_service = ClientTransportToSingleService::new(redun);
qr.add(Name::<Vec<u8>>::from_str("nl").unwrap(), conn_service);

let srv = SingleServiceToService::new(qr);
let srv = MandatoryMiddlewareSvc::<Vec<u8>, _, _>::new(srv);
let srv = SingleServiceToService::<_, _, _, _>::new(qr);
let srv = MandatoryMiddlewareSvc::new(srv);
let my_svc = Arc::new(srv);

let udpsocket = UdpSocket::bind("[::1]:8053").await.unwrap();
Expand Down
27 changes: 17 additions & 10 deletions examples/serve-zone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,32 +118,35 @@ async fn main() {
let svc = service_fn(my_service, zones.clone());

#[cfg(feature = "siphasher")]
let svc = CookiesMiddlewareSvc::<Vec<u8>, _, _>::with_random_secret(svc);
let svc = EdnsMiddlewareSvc::<Vec<u8>, _, _>::new(svc);
let svc = XfrMiddlewareSvc::<Vec<u8>, _, _, _>::new(
svc,
zones_and_diffs.clone(),
1,
);
let svc = NotifyMiddlewareSvc::new(svc, DemoNotifyTarget);
let svc = TsigMiddlewareSvc::<_, _, _, ()>::new(svc, key_store);
let svc = CookiesMiddlewareSvc::<Vec<u8>, _, _>::with_random_secret(svc);
let svc = EdnsMiddlewareSvc::<Vec<u8>, _, _>::new(svc);
let svc = MandatoryMiddlewareSvc::<Vec<u8>, _, _>::new(svc);
let svc = TsigMiddlewareSvc::new(svc, key_store);
let svc = Arc::new(svc);

let sock = UdpSocket::bind(&addr).await.unwrap();
let sock = Arc::new(sock);
let mut udp_metrics = vec![];
let num_cores = std::thread::available_parallelism().unwrap().get();
for _i in 0..num_cores {
let udp_srv =
DgramServer::new(sock.clone(), VecBufSource, svc.clone());
let udp_srv = DgramServer::<_, _, _>::new(
sock.clone(),
VecBufSource,
svc.clone(),
);
let metrics = udp_srv.metrics();
udp_metrics.push(metrics);
tokio::spawn(async move { udp_srv.run().await });
}

let sock = TcpListener::bind(addr).await.unwrap();
let tcp_srv = StreamServer::new(sock, VecBufSource, svc);
let tcp_srv = StreamServer::<_, _, _>::new(sock, VecBufSource, svc);
let tcp_metrics = tcp_srv.metrics();

tokio::spawn(async move { tcp_srv.run().await });
Expand Down Expand Up @@ -240,8 +243,8 @@ async fn main() {
}

#[allow(clippy::type_complexity)]
fn my_service(
request: Request<Vec<u8>>,
fn my_service<RequestMeta>(
request: Request<Vec<u8>, RequestMeta>,
zones: Arc<ZoneTree>,
) -> ServiceResult<Vec<u8>> {
let question = request.message().sole_question().unwrap();
Expand Down Expand Up @@ -317,12 +320,12 @@ impl ZoneTreeWithDiffs {
}
}

impl<RequestMeta> XfrDataProvider<RequestMeta> for ZoneTreeWithDiffs {
impl XfrDataProvider<Option<Key>> for ZoneTreeWithDiffs {
type Diff = InMemoryZoneDiff;

fn request<Octs>(
&self,
req: &Request<Octs, RequestMeta>,
req: &Request<Octs, Option<Key>>,
diff_from: Option<Serial>,
) -> Pin<
Box<
Expand All @@ -338,6 +341,10 @@ impl<RequestMeta> XfrDataProvider<RequestMeta> for ZoneTreeWithDiffs {
where
Octs: Octets + Send + Sync,
{
if req.metadata().is_none() {
eprintln!("Rejecting");
return Box::pin(ready(Err(XfrDataProviderError::Refused)));
}
let res = req
.message()
.sole_question()
Expand Down
65 changes: 34 additions & 31 deletions examples/server-transports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use core::time::Duration;
use std::fs::File;
use std::io;
use std::io::BufReader;
use std::marker::Unpin;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
Expand Down Expand Up @@ -50,7 +51,7 @@ use domain::rdata::{Soa, A};

// Helper fn to create a dummy response to send back to the client
fn mk_answer<Target>(
msg: &Request<Vec<u8>>,
msg: &Request<Vec<u8>, ()>,
builder: MessageBuilder<StreamTarget<Target>>,
) -> Result<AdditionalBuilder<StreamTarget<Target>>, PushError>
where
Expand All @@ -69,7 +70,7 @@ where
}

fn mk_soa_answer<Target>(
msg: &Request<Vec<u8>>,
msg: &Request<Vec<u8>, ()>,
builder: MessageBuilder<StreamTarget<Target>>,
) -> Result<AdditionalBuilder<StreamTarget<Target>>, PushError>
where
Expand Down Expand Up @@ -100,6 +101,7 @@ where

//--- MySingleResultService

#[derive(Clone)]
struct MySingleResultService;

/// This example shows how to implement the [`Service`] trait directly.
Expand All @@ -116,12 +118,12 @@ struct MySingleResultService;
///
/// See [`query`] and [`name_to_ip`] for ways of implementing the [`Service`]
/// trait for a function instead of a struct.
impl Service<Vec<u8>> for MySingleResultService {
impl Service<Vec<u8>, ()> for MySingleResultService {
type Target = Vec<u8>;
type Stream = Once<Ready<ServiceResult<Self::Target>>>;
type Future = Ready<Self::Stream>;

fn call(&self, request: Request<Vec<u8>>) -> Self::Future {
fn call(&self, request: Request<Vec<u8>, ()>) -> Self::Future {
let builder = mk_builder_for_target();
let additional = mk_answer(&request, builder).unwrap();
let item = Ok(CallResult::new(additional));
Expand All @@ -131,6 +133,7 @@ impl Service<Vec<u8>> for MySingleResultService {

//--- MyAsyncStreamingService

#[derive(Clone)]
struct MyAsyncStreamingService;

/// This example also shows how to implement the [`Service`] trait directly.
Expand All @@ -147,13 +150,13 @@ struct MyAsyncStreamingService;
/// and/or Stream implementations that actually wait and/or stream, e.g.
/// making the Stream type be UnboundedReceiver instead of Pin<Box<dyn
/// Stream...>>.
impl Service<Vec<u8>> for MyAsyncStreamingService {
impl Service<Vec<u8>, ()> for MyAsyncStreamingService {
type Target = Vec<u8>;
type Stream =
Pin<Box<dyn Stream<Item = ServiceResult<Self::Target>> + Send>>;
type Future = Pin<Box<dyn Future<Output = Self::Stream> + Send>>;

fn call(&self, request: Request<Vec<u8>>) -> Self::Future {
fn call(&self, request: Request<Vec<u8>, ()>) -> Self::Future {
Box::pin(async move {
if !matches!(
request
Expand Down Expand Up @@ -209,7 +212,10 @@ impl Service<Vec<u8>> for MyAsyncStreamingService {
/// The function signature is slightly more complex than when using
/// [`service_fn`] (see the [`query`] example below).
#[allow(clippy::type_complexity)]
fn name_to_ip(request: Request<Vec<u8>>, _: ()) -> ServiceResult<Vec<u8>> {
fn name_to_ip(
request: Request<Vec<u8>, ()>,
_: (),
) -> ServiceResult<Vec<u8>> {
let mut out_answer = None;
if let Ok(question) = request.message().sole_question() {
let qname = question.qname();
Expand Down Expand Up @@ -257,7 +263,7 @@ fn name_to_ip(request: Request<Vec<u8>>, _: ()) -> ServiceResult<Vec<u8>> {
/// [`service_fn`] and supports passing in meta data without any extra
/// boilerplate.
fn query(
request: Request<Vec<u8>>,
request: Request<Vec<u8>, ()>,
count: Arc<AtomicU8>,
) -> ServiceResult<Vec<u8>> {
let cnt = count
Expand Down Expand Up @@ -455,6 +461,7 @@ impl std::fmt::Display for Stats {
}
}

#[derive(Clone)]
pub struct StatsMiddlewareSvc<Svc> {
svc: Svc,
stats: Arc<RwLock<Stats>>,
Expand All @@ -467,7 +474,7 @@ impl<Svc> StatsMiddlewareSvc<Svc> {
Self { svc, stats }
}

fn preprocess<RequestOctets>(&self, request: &Request<RequestOctets>)
fn preprocess<RequestOctets>(&self, request: &Request<RequestOctets, ()>)
where
RequestOctets: Octets + Send + Sync + Unpin,
{
Expand All @@ -488,12 +495,12 @@ impl<Svc> StatsMiddlewareSvc<Svc> {
}

fn postprocess<RequestOctets>(
request: &Request<RequestOctets>,
request: &Request<RequestOctets, ()>,
response: &AdditionalBuilder<StreamTarget<Svc::Target>>,
stats: &RwLock<Stats>,
) where
RequestOctets: Octets + Send + Sync + Unpin,
Svc: Service<RequestOctets>,
Svc: Service<RequestOctets, ()>,
Svc::Target: AsRef<[u8]>,
{
let duration = Instant::now().duration_since(request.received_at());
Expand All @@ -510,13 +517,13 @@ impl<Svc> StatsMiddlewareSvc<Svc> {
}

fn map_stream_item<RequestOctets>(
request: Request<RequestOctets>,
request: Request<RequestOctets, ()>,
stream_item: ServiceResult<Svc::Target>,
stats: &mut Arc<RwLock<Stats>>,
) -> ServiceResult<Svc::Target>
where
RequestOctets: Octets + Send + Sync + Unpin,
Svc: Service<RequestOctets>,
Svc: Service<RequestOctets, ()>,
Svc::Target: AsRef<[u8]>,
{
if let Ok(cr) = &stream_item {
Expand All @@ -528,10 +535,11 @@ impl<Svc> StatsMiddlewareSvc<Svc> {
}
}

impl<RequestOctets, Svc> Service<RequestOctets> for StatsMiddlewareSvc<Svc>
impl<RequestOctets, Svc> Service<RequestOctets, ()>
for StatsMiddlewareSvc<Svc>
where
RequestOctets: Octets + Send + Sync + 'static + Unpin,
Svc: Service<RequestOctets>,
Svc: Service<RequestOctets, ()>,
Svc::Target: AsRef<[u8]>,
Svc::Future: Unpin,
{
Expand All @@ -551,7 +559,7 @@ where
>;
type Future = Ready<Self::Stream>;

fn call(&self, request: Request<RequestOctets>) -> Self::Future {
fn call(&self, request: Request<RequestOctets, ()>) -> Self::Future {
self.preprocess(&request);
let svc_call_fut = self.svc.call(request.clone());
let map = PostprocessingStream::new(
Expand All @@ -567,24 +575,19 @@ where
//------------ build_middleware_chain() --------------------------------------

#[allow(clippy::type_complexity)]
fn build_middleware_chain<Svc>(
fn build_middleware_chain<Svc, Octs>(
svc: Svc,
stats: Arc<RwLock<Stats>>,
) -> StatsMiddlewareSvc<
MandatoryMiddlewareSvc<
Vec<u8>,
EdnsMiddlewareSvc<
Vec<u8>,
CookiesMiddlewareSvc<Vec<u8>, Svc, ()>,
(),
>,
(),
>,
> {
) -> impl Service<Octs, ()>
where
Octs: Octets + Send + Sync + Clone + Unpin + 'static,
Svc: Service<Octs, ()>,
<Svc as Service<Octs, ()>>::Future: Unpin,
{
#[cfg(feature = "siphasher")]
let svc = CookiesMiddlewareSvc::<Vec<u8>, _, _>::with_random_secret(svc);
let svc = EdnsMiddlewareSvc::<Vec<u8>, _, _>::new(svc);
let svc = MandatoryMiddlewareSvc::<Vec<u8>, _, _>::new(svc);
let svc = CookiesMiddlewareSvc::<Octs, _, ()>::with_random_secret(svc);
let svc = EdnsMiddlewareSvc::new(svc);
let svc = MandatoryMiddlewareSvc::new(svc);
StatsMiddlewareSvc::new(svc, stats.clone())
}

Expand Down
43 changes: 32 additions & 11 deletions src/net/server/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,30 @@ use std::string::ToString;
use std::vec::Vec;

/// Provide a [Service] trait for an object that implements [SingleService].
pub struct SingleServiceToService<RequestOcts, SVC, CR> {
pub struct SingleServiceToService<RequestOcts, SVC, CR, RequestMeta>
where
RequestMeta: Clone + Default,
RequestOcts: Octets + Send + Sync,
SVC: SingleService<RequestOcts, RequestMeta, CR>,
CR: ComposeReply + 'static,
Self: Send + Sync + 'static,
{
/// Service that is wrapped by this object.
service: SVC,

/// Phantom field for RequestOcts and CR.
_phantom: PhantomData<(RequestOcts, CR)>,
_phantom: PhantomData<(RequestOcts, CR, RequestMeta)>,
}

impl<RequestOcts, SVC, CR> SingleServiceToService<RequestOcts, SVC, CR> {
impl<RequestOcts, SVC, CR, RequestMeta>
SingleServiceToService<RequestOcts, SVC, CR, RequestMeta>
where
RequestMeta: Clone + Default,
RequestOcts: Octets + Send + Sync,
SVC: SingleService<RequestOcts, RequestMeta, CR>,
CR: ComposeReply + 'static,
Self: Send + Sync + 'static,
{
/// Create a new [SingleServiceToService] object.
pub fn new(service: SVC) -> Self {
Self {
Expand All @@ -49,18 +64,23 @@ impl<RequestOcts, SVC, CR> SingleServiceToService<RequestOcts, SVC, CR> {
}
}

impl<RequestOcts, SVC, CR> Service<RequestOcts>
for SingleServiceToService<RequestOcts, SVC, CR>
impl<RequestOcts, SVC, CR, RequestMeta> Service<RequestOcts, RequestMeta>
for SingleServiceToService<RequestOcts, SVC, CR, RequestMeta>
where
RequestMeta: Clone + Default,
RequestOcts: Octets + Send + Sync,
SVC: SingleService<RequestOcts, CR>,
SVC: SingleService<RequestOcts, RequestMeta, CR>,
CR: ComposeReply + 'static,
Self: Send + Sync + 'static,
{
type Target = Vec<u8>;
type Stream = Once<Ready<ServiceResult<Self::Target>>>;
type Future = Pin<Box<dyn Future<Output = Self::Stream> + Send>>;

fn call(&self, request: Request<RequestOcts>) -> Self::Future {
fn call(
&self,
request: Request<RequestOcts, RequestMeta>,
) -> Self::Future {
let fut = self.service.call(request);
let fut = async move {
let reply = match fut.await {
Expand Down Expand Up @@ -114,7 +134,8 @@ where
}
}

impl<SR, RequestOcts, CR> SingleService<RequestOcts, CR>
impl<SR, RequestOcts, RequestMeta, CR>
SingleService<RequestOcts, RequestMeta, CR>
for ClientTransportToSingleService<SR, RequestOcts>
where
RequestOcts: AsRef<[u8]> + Clone + Debug + Octets + Send + Sync,
Expand All @@ -123,7 +144,7 @@ where
{
fn call(
&self,
request: Request<RequestOcts>,
request: Request<RequestOcts, RequestMeta>,
) -> Pin<Box<dyn Future<Output = Result<CR, ServiceError>> + Send + Sync>>
where
RequestOcts: AsRef<[u8]>,
Expand Down Expand Up @@ -194,15 +215,15 @@ where
}
}

impl<RequestOcts, CR> SingleService<RequestOcts, CR>
impl<RequestOcts, RequestMeta, CR> SingleService<RequestOcts, RequestMeta, CR>
for BoxClientTransportToSingleService<RequestOcts>
where
RequestOcts: AsRef<[u8]> + Clone + Debug + Octets + Send + Sync,
CR: ComposeReply + Send + Sync + 'static,
{
fn call(
&self,
request: Request<RequestOcts>,
request: Request<RequestOcts, RequestMeta>,
) -> Pin<Box<dyn Future<Output = Result<CR, ServiceError>> + Send + Sync>>
where
RequestOcts: AsRef<[u8]>,
Expand Down
Loading
Loading