Skip to content

Commit

Permalink
Make Protocol and ProtocolError generic over Id
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Dec 25, 2024
1 parent fc9beee commit 578c4b3
Show file tree
Hide file tree
Showing 16 changed files with 159 additions and 190 deletions.
5 changes: 3 additions & 2 deletions examples/src/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub enum SimpleProtocolError {
Round2InvalidPosition,
}

impl ProtocolError for SimpleProtocolError {
impl<Id> ProtocolError<Id> for SimpleProtocolError {
fn description(&self) -> String {
format!("{:?}", self)
}
Expand Down Expand Up @@ -53,6 +53,7 @@ impl ProtocolError for SimpleProtocolError {
fn verify_messages_constitute_error(
&self,
deserializer: &Deserializer,
_guilty_party: &Id,
_shared_randomness: &[u8],
_echo_broadcast: &EchoBroadcast,
_normal_broadcast: &NormalBroadcast,
Expand Down Expand Up @@ -87,7 +88,7 @@ impl ProtocolError for SimpleProtocolError {
}
}

impl Protocol for SimpleProtocol {
impl<Id> Protocol<Id> for SimpleProtocol {
type Result = u8;
type ProtocolError = SimpleProtocolError;

Expand Down
17 changes: 7 additions & 10 deletions examples/src/simple_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use alloc::collections::BTreeSet;
use core::fmt::Debug;

use manul::{
combinators::{
chain::{Chain, ChainedJoin, ChainedProtocol, ChainedSplit},
CombinatorEntryPoint,
},
combinators::chain::{ChainedJoin, ChainedMarker, ChainedProtocol, ChainedSplit},
protocol::{PartyId, Protocol},
};

Expand All @@ -16,7 +13,9 @@ use super::simple::{SimpleProtocol, SimpleProtocolEntryPoint};
#[derive(Debug)]
pub struct DoubleSimpleProtocol;

impl ChainedProtocol for DoubleSimpleProtocol {
impl ChainedMarker for DoubleSimpleProtocol {}

impl<Id> ChainedProtocol<Id> for DoubleSimpleProtocol {
type Protocol1 = SimpleProtocol;
type Protocol2 = SimpleProtocol;
}
Expand All @@ -25,16 +24,14 @@ pub struct DoubleSimpleEntryPoint<Id> {
all_ids: BTreeSet<Id>,
}

impl<Id> ChainedMarker for DoubleSimpleEntryPoint<Id> {}

impl<Id: PartyId> DoubleSimpleEntryPoint<Id> {
pub fn new(all_ids: BTreeSet<Id>) -> Self {
Self { all_ids }
}
}

impl<Id> CombinatorEntryPoint for DoubleSimpleEntryPoint<Id> {
type Combinator = Chain;
}

impl<Id> ChainedSplit<Id> for DoubleSimpleEntryPoint<Id>
where
Id: PartyId,
Expand All @@ -60,7 +57,7 @@ where
{
type Protocol = DoubleSimpleProtocol;
type EntryPoint = SimpleProtocolEntryPoint<Id>;
fn make_entry_point2(self, _result: <SimpleProtocol as Protocol>::Result) -> Self::EntryPoint {
fn make_entry_point2(self, _result: <SimpleProtocol as Protocol<Id>>::Result) -> Self::EntryPoint {
SimpleProtocolEntryPoint::new(self.all_ids)
}
}
Expand Down
4 changes: 2 additions & 2 deletions examples/tests/async_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn run_session<P, SP>(
session: Session<P, SP>,
) -> Result<SessionReport<P, SP>, LocalError>
where
P: Protocol,
P: Protocol<SP::Verifier>,
SP: SessionParameters,
{
let rng = &mut OsRng;
Expand Down Expand Up @@ -194,7 +194,7 @@ async fn message_dispatcher<SP>(

async fn run_nodes<P, SP>(sessions: Vec<Session<P, SP>>) -> Vec<SessionReport<P, SP>>
where
P: Protocol + Send,
P: Protocol<SP::Verifier> + Send,
SP: SessionParameters,
P::Result: Send,
SP::Signer: Send,
Expand Down
2 changes: 1 addition & 1 deletion manul/benches/empty_rounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize};
#[derive(Debug)]
pub struct EmptyProtocol;

impl Protocol for EmptyProtocol {
impl<Id> Protocol<Id> for EmptyProtocol {
type Result = ();
type ProtocolError = ();
}
Expand Down
3 changes: 0 additions & 3 deletions manul/src/combinators.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
//! Combinators operating on protocols.
pub mod chain;
mod markers;
pub mod misbehave;

pub use markers::{Combinator, CombinatorEntryPoint};
82 changes: 43 additions & 39 deletions manul/src/combinators/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,19 @@ Usage:
1. Implement [`ChainedProtocol`] for a type of your choice. Usually it will be a ZST.
You will have to specify the two protocol types you want to chain.
This type will then automatically implement [`Protocol`](`crate::protocol::Protocol`).
2. Define an entry point type for the new joined protocol.
2. Implement the marker trait [`ChainedMarker`] for this type. This will activate the blanket implementation
of [`Protocol`](`crate::protocol::Protocol`) for it.
The marker trait is needed to disambiguate different generic blanket implementations.
3. Define an entry point type for the new joined protocol.
Most likely it will contain a union between the required data for the entry point
of the first and the second protocol.
3. Implement [`ChainedSplit`] and [`ChainedJoin`] for the new entry point.
4. Implement [`ChainedSplit`] and [`ChainedJoin`] for the new entry point.
4. Mark the new entry point with the [`CombinatorEntryPoint`] trait using [`Chain`] for `Type`.
5. Implement the marker trait [`ChainedMarker`] for this type.
Same as with the protocol, this is needed to disambiguate different generic blanket implementations.
*/

use alloc::{
Expand All @@ -54,65 +58,61 @@ use core::fmt::Debug;
use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};

use super::markers::{Combinator, CombinatorEntryPoint};
use crate::protocol::{
Artifact, BoxedRng, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EchoRoundParticipation, EntryPoint,
FinalizeOutcome, LocalError, NormalBroadcast, ObjectSafeRound, PartyId, Payload, Protocol, ProtocolError,
ProtocolValidationError, ReceiveError, RoundId, Serializer,
};

/// A marker for the `chain` combinator.
#[derive(Debug, Clone, Copy)]
pub struct Chain;

impl Combinator for Chain {}
/// A marker trait that is used to disambiguate blanket trait implementations for [`Protocol`] and [`EntryPoint`].
pub trait ChainedMarker {}

/// A trait defining two protocols executed sequentially.
pub trait ChainedProtocol: 'static + Debug {
pub trait ChainedProtocol<Id>: 'static + Debug {
/// The protcol that is executed first.
type Protocol1: Protocol;
type Protocol1: Protocol<Id>;

/// The protcol that is executed second.
type Protocol2: Protocol;
type Protocol2: Protocol<Id>;
}

/// The protocol error type for the chained protocol.
#[derive_where::derive_where(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "
<C::Protocol1 as Protocol>::ProtocolError: Serialize,
<C::Protocol2 as Protocol>::ProtocolError: Serialize,
<C::Protocol1 as Protocol<Id>>::ProtocolError: Serialize,
<C::Protocol2 as Protocol<Id>>::ProtocolError: Serialize,
"))]
#[serde(bound(deserialize = "
<C::Protocol1 as Protocol>::ProtocolError: for<'x> Deserialize<'x>,
<C::Protocol2 as Protocol>::ProtocolError: for<'x> Deserialize<'x>,
<C::Protocol1 as Protocol<Id>>::ProtocolError: for<'x> Deserialize<'x>,
<C::Protocol2 as Protocol<Id>>::ProtocolError: for<'x> Deserialize<'x>,
"))]
pub enum ChainedProtocolError<C>
pub enum ChainedProtocolError<Id, C>
where
C: ChainedProtocol,
C: ChainedProtocol<Id>,
{
/// A protocol error from the first protocol.
Protocol1(<C::Protocol1 as Protocol>::ProtocolError),
Protocol1(<C::Protocol1 as Protocol<Id>>::ProtocolError),
/// A protocol error from the second protocol.
Protocol2(<C::Protocol2 as Protocol>::ProtocolError),
Protocol2(<C::Protocol2 as Protocol<Id>>::ProtocolError),
}

impl<C> ChainedProtocolError<C>
impl<Id, C> ChainedProtocolError<Id, C>
where
C: ChainedProtocol,
C: ChainedProtocol<Id>,
{
fn from_protocol1(err: <C::Protocol1 as Protocol>::ProtocolError) -> Self {
fn from_protocol1(err: <C::Protocol1 as Protocol<Id>>::ProtocolError) -> Self {
Self::Protocol1(err)
}

fn from_protocol2(err: <C::Protocol2 as Protocol>::ProtocolError) -> Self {
fn from_protocol2(err: <C::Protocol2 as Protocol<Id>>::ProtocolError) -> Self {
Self::Protocol2(err)
}
}

impl<C> ProtocolError for ChainedProtocolError<C>
impl<Id, C> ProtocolError<Id> for ChainedProtocolError<Id, C>
where
C: ChainedProtocol,
C: ChainedProtocol<Id>,
{
fn description(&self) -> String {
match self {
Expand Down Expand Up @@ -169,6 +169,7 @@ where
fn verify_messages_constitute_error(
&self,
deserializer: &Deserializer,
guilty_party: &Id,
shared_randomness: &[u8],
echo_broadcast: &EchoBroadcast,
normal_broadcast: &NormalBroadcast,
Expand Down Expand Up @@ -204,6 +205,7 @@ where
match self {
Self::Protocol1(err) => err.verify_messages_constitute_error(
deserializer,
guilty_party,
shared_randomness,
echo_broadcast,
normal_broadcast,
Expand All @@ -215,6 +217,7 @@ where
),
Self::Protocol2(err) => err.verify_messages_constitute_error(
deserializer,
guilty_party,
shared_randomness,
echo_broadcast,
normal_broadcast,
Expand All @@ -228,23 +231,24 @@ where
}
}

impl<C> Protocol for C
impl<Id, C> Protocol<Id> for C
where
C: ChainedProtocol,
Id: 'static,
C: ChainedProtocol<Id> + ChainedMarker,
{
type Result = <C::Protocol2 as Protocol>::Result;
type ProtocolError = ChainedProtocolError<C>;
type Result = <C::Protocol2 as Protocol<Id>>::Result;
type ProtocolError = ChainedProtocolError<Id, C>;
}

/// A trait defining how the entry point for the whole chained protocol
/// will be split into the entry point for the first protocol, and a piece of data
/// that, along with the first protocol's result, will be used to create the entry point for the second protocol.
pub trait ChainedSplit<Id: PartyId> {
/// The chained protocol this trait belongs to.
type Protocol: ChainedProtocol;
type Protocol: ChainedProtocol<Id> + ChainedMarker;

/// The first protocol's entry point.
type EntryPoint: EntryPoint<Id, Protocol = <Self::Protocol as ChainedProtocol>::Protocol1>;
type EntryPoint: EntryPoint<Id, Protocol = <Self::Protocol as ChainedProtocol<Id>>::Protocol1>;

/// Creates the first protocol's entry point and the data for creating the second entry point.
fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin<Id, Protocol = Self::Protocol>);
Expand All @@ -254,22 +258,22 @@ pub trait ChainedSplit<Id: PartyId> {
/// will be joined with the result of the first protocol to create an entry point for the second protocol.
pub trait ChainedJoin<Id: PartyId>: 'static + Debug + Send + Sync {
/// The chained protocol this trait belongs to.
type Protocol: ChainedProtocol;
type Protocol: ChainedProtocol<Id> + ChainedMarker;

/// The second protocol's entry point.
type EntryPoint: EntryPoint<Id, Protocol = <Self::Protocol as ChainedProtocol>::Protocol2>;
type EntryPoint: EntryPoint<Id, Protocol = <Self::Protocol as ChainedProtocol<Id>>::Protocol2>;

/// Creates the second protocol's entry point using the first protocol's result.
fn make_entry_point2(
self,
result: <<Self::Protocol as ChainedProtocol>::Protocol1 as Protocol>::Result,
result: <<Self::Protocol as ChainedProtocol<Id>>::Protocol1 as Protocol<Id>>::Result,
) -> Self::EntryPoint;
}

impl<Id, T> EntryPoint<Id> for T
where
Id: PartyId,
T: ChainedSplit<Id> + CombinatorEntryPoint<Combinator = Chain>,
T: ChainedSplit<Id> + ChainedMarker,
{
type Protocol = T::Protocol;

Expand Down Expand Up @@ -314,11 +318,11 @@ where
{
Protocol1 {
id: Id,
round: BoxedRound<Id, <T::Protocol as ChainedProtocol>::Protocol1>,
round: BoxedRound<Id, <T::Protocol as ChainedProtocol<Id>>::Protocol1>,
shared_randomness: Box<[u8]>,
transition: T,
},
Protocol2(BoxedRound<Id, <T::Protocol as ChainedProtocol>::Protocol2>),
Protocol2(BoxedRound<Id, <T::Protocol as ChainedProtocol<Id>>::Protocol2>),
}

impl<Id, T> ObjectSafeRound<Id> for ChainedRound<Id, T>
Expand Down
9 changes: 0 additions & 9 deletions manul/src/combinators/markers.rs

This file was deleted.

8 changes: 4 additions & 4 deletions manul/src/dev/run_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
},
};

enum State<P: Protocol, SP: SessionParameters> {
enum State<P: Protocol<SP::Verifier>, SP: SessionParameters> {
InProgress {
session: Session<P, SP>,
accum: RoundAccumulator<P, SP>,
Expand All @@ -34,7 +34,7 @@ fn propagate<P, SP>(
accum: RoundAccumulator<P, SP>,
) -> Result<(State<P, SP>, Vec<RoundMessage<SP>>), LocalError>
where
P: Protocol,
P: Protocol<SP::Verifier>,
SP: SessionParameters,
{
let mut messages = Vec::new();
Expand Down Expand Up @@ -168,12 +168,12 @@ where

/// The result of a protocol execution on a set of nodes.
#[derive(Debug)]
pub struct ExecutionResult<P: Protocol, SP: SessionParameters> {
pub struct ExecutionResult<P: Protocol<SP::Verifier>, SP: SessionParameters> {
pub reports: BTreeMap<SP::Verifier, SessionReport<P, SP>>,
}
impl<P, SP> ExecutionResult<P, SP>
where
P: Protocol,
P: Protocol<SP::Verifier>,
SP: SessionParameters,
{
pub fn results(self) -> Result<BTreeMap<SP::Verifier, P::Result>, String> {
Expand Down
Loading

0 comments on commit 578c4b3

Please sign in to comment.