Skip to content

Commit

Permalink
Implement an easier way to get Send Serve and Stub
Browse files Browse the repository at this point in the history
  • Loading branch information
stevefan1999-personal committed Oct 9, 2024
1 parent 240f94f commit 4e7ff9e
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 17 deletions.
12 changes: 6 additions & 6 deletions plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,15 @@ impl<'a> ServiceGenerator<'a> {
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future<Output = #output> + ::core::marker::Send;
}
},
);

let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! {
#( #attrs )*
#vis trait #service_ident: ::core::marker::Sized {
#vis trait #service_ident: ::core::marker::Sized + ::core::marker::Send {
#( #rpc_fns )*

/// Returns a serving function to use with
Expand All @@ -578,11 +578,11 @@ impl<'a> ServiceGenerator<'a> {
}

#[doc = #stub_doc]
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
#vis trait #client_stub_ident: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident> {
}

impl<S> #client_stub_ident for S
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
where S: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident>
{
}
}
Expand Down Expand Up @@ -616,7 +616,7 @@ impl<'a> ServiceGenerator<'a> {
} = self;

quote! {
impl<S> ::tarpc::server::Serve for #server_ident<S>
impl<S> ::tarpc::server::SendServe for #server_ident<S>
where S: #service_ident
{
type Req = #request_ident;
Expand Down Expand Up @@ -780,7 +780,7 @@ impl<'a> ServiceGenerator<'a> {

quote! {
impl<Stub> #client_ident<Stub>
where Stub: ::tarpc::client::stub::Stub<
where Stub: ::tarpc::client::stub::SendStub<
Req = #request_ident,
Resp = #response_ident>
{
Expand Down
55 changes: 51 additions & 4 deletions tarpc/src/client/stub.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Provides a Stub trait, implemented by types that can call remote services.
use std::future::Future;

use crate::{
client::{Channel, RpcError},
context,
server::Serve,
server::{SendServe, Serve},
RequestName,
};

Expand All @@ -15,7 +17,6 @@ mod mock;

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
/// The service request type.
type Req: RequestName;
Expand All @@ -24,8 +25,28 @@ pub trait Stub {
type Resp;

/// Calls a remote service.
async fn call(&self, ctx: context::Context, request: Self::Req)
-> Result<Self::Resp, RpcError>;
fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> impl Future<Output = Result<Self::Resp, RpcError>>;
}

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
pub trait SendStub: Send {
/// The service request type.
type Req: RequestName;

/// The service response type.
type Resp;

/// Calls a remote service.
fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> impl Future<Output = Result<Self::Resp, RpcError>> + Send;
}

impl<Req, Resp> Stub for Channel<Req, Resp>
Expand All @@ -40,6 +61,19 @@ where
}
}

impl<Req, Resp> SendStub for Channel<Req, Resp>
where
Req: RequestName + Send,
Resp: Send,
{
type Req = Req;
type Resp = Resp;

async fn call(&self, ctx: context::Context, request: Req) -> Result<Self::Resp, RpcError> {
Self::call(self, ctx, request).await
}
}

impl<S> Stub for S
where
S: Serve + Clone,
Expand All @@ -50,3 +84,16 @@ where
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
}
}

impl<S> SendStub for S
where
S: SendServe + Clone + Sync,
S::Req: Send + Sync,
S::Resp: Send,
{
type Req = S::Req;
type Resp = S::Resp;
async fn call(&self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, RpcError> {
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
}
}
18 changes: 18 additions & 0 deletions tarpc/src/client/stub/load_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ mod round_robin {
}
}

impl<Stub> stub::SendStub for RoundRobin<Stub>
where
Stub: stub::SendStub + Send + Sync,
Stub::Req: Send,
{
type Req = Stub::Req;
type Resp = Stub::Resp;

async fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request).await
}
}

/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct RoundRobin<Stub> {
Expand Down
27 changes: 26 additions & 1 deletion tarpc/src/client/stub/mock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{
client::{stub::Stub, RpcError},
client::{
stub::{SendStub, Stub},
RpcError,
},
context, RequestName, ServerError,
};
use std::{collections::HashMap, hash::Hash, io};
Expand Down Expand Up @@ -42,3 +45,25 @@ where
})
}
}

impl<Req, Resp> SendStub for Mock<Req, Resp>
where
Req: Eq + Hash + RequestName + Send + Sync,
Resp: Clone + Send + Sync,
{
type Req = Req;
type Resp = Resp;

async fn call(&self, _: context::Context, request: Self::Req) -> Result<Resp, RpcError> {
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
})
}
}
27 changes: 27 additions & 0 deletions tarpc/src/client/stub/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,33 @@ where
}
}

impl<Stub, Req, F> stub::SendStub for Retry<F, Stub>
where
Req: RequestName + Send + Sync,
Stub: stub::SendStub<Req = Arc<Req>> + Send + Sync,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool + Send + Sync,
{
type Req = Req;
type Resp = Stub::Resp;

async fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let request = Arc::new(request);
for i in 1.. {
let result = self.stub.call(ctx, Arc::clone(&request)).await;
if (self.should_retry)(&result, i) {
tracing::trace!("Retrying on attempt {i}");
continue;
}
return result;
}
unreachable!("Wow, that was a lot of attempts!");
}
}

/// A Stub that retries requests based on response contents.
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
#[derive(Clone, Debug)]
Expand Down
37 changes: 31 additions & 6 deletions tarpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl Config {
}

/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
#[allow(async_fn_in_trait)]
pub trait Serve {
/// Type of request.
type Req: RequestName;
Expand All @@ -76,7 +75,33 @@ pub trait Serve {
type Resp;

/// Responds to a single request.
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
fn serve(
self,
ctx: context::Context,
req: Self::Req,
) -> impl Future<Output = Result<Self::Resp, ServerError>>;
}

/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
pub trait SendServe: Send {
/// Type of request.
type Req: RequestName;
/// Type of response.
type Resp;
/// Responds to a single request.
fn serve(
self,
ctx: context::Context,
req: Self::Req,
) -> impl Future<Output = Result<Self::Resp, ServerError>> + Send;
}

impl<S: SendServe> Serve for S {
type Req = <Self as SendServe>::Req;
type Resp = <Self as SendServe>::Resp;
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError> {
<Self as SendServe>::serve(self, ctx, req).await
}
}

/// A Serve wrapper around a Fn.
Expand Down Expand Up @@ -113,11 +138,11 @@ where
}
}

impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
impl<Req, Resp, Fut, F> SendServe for ServeFn<Req, Resp, F>
where
Req: RequestName,
F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Result<Resp, ServerError>>,
Req: RequestName + Send,
F: FnOnce(context::Context, Req) -> Fut + Send,
Fut: Future<Output = Result<Resp, ServerError>> + Send,
{
type Req = Req;
type Resp = Resp;
Expand Down

0 comments on commit 4e7ff9e

Please sign in to comment.