Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Round dyn-safe #94

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed `Send` bound on `ProtocolError`. ([#92])
- Merged `Round::id()`, `possible_next_rounds()` and `may_produce_result()` into `transition_info()`. ([#93])
- Merged `Round::message_destinations()`, `expecting_messages_from()` and `echo_round_participation()` into `communication_info()`. ([#93])
- Renamed `Payload::try_to_typed()` and `Artifact::try_to_typed()` to `downcast()`. ([#94])
- `Round` methods take `dyn CryptoRngCore` instead of `impl CryptoRngCore`. ([#94])
- `Round` is now dyn-safe. `serializer` and `deserializer` arguments to `Round` and related traits are merged into `format`. `Round::finalize()` takes `self: Box<Self>`. ([#94])


### Added
Expand Down Expand Up @@ -64,6 +67,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#91]: https://github.com/entropyxyz/manul/pull/91
[#92]: https://github.com/entropyxyz/manul/pull/92
[#93]: https://github.com/entropyxyz/manul/pull/93
[#94]: https://github.com/entropyxyz/manul/pull/94


## [0.1.0] - 2024-11-19
Expand Down
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

76 changes: 38 additions & 38 deletions examples/src/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use alloc::collections::{BTreeMap, BTreeSet};
use core::fmt::Debug;

use manul::protocol::{
Artifact, BoxedRound, CommunicationInfo, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome,
Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome,
LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage,
ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, Round, RoundId,
Serializer, TransitionInfo,
TransitionInfo,
};
use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -39,7 +39,7 @@ impl<Id> ProtocolError<Id> for SimpleProtocolError {

fn verify_messages_constitute_error(
&self,
deserializer: &Deserializer,
format: &BoxedFormat,
_guilty_party: &Id,
_shared_randomness: &[u8],
_associated_data: &Self::AssociatedData,
Expand All @@ -49,20 +49,20 @@ impl<Id> ProtocolError<Id> for SimpleProtocolError {
) -> Result<(), ProtocolValidationError> {
match self {
SimpleProtocolError::Round1InvalidPosition => {
let _message = message.direct_message.deserialize::<Round1Message>(deserializer)?;
let _message = message.direct_message.deserialize::<Round1Message>(format)?;
// Message contents would be checked here
Ok(())
}
SimpleProtocolError::Round2InvalidPosition => {
let _r1_message = message.direct_message.deserialize::<Round1Message>(deserializer)?;
let _r1_message = message.direct_message.deserialize::<Round1Message>(format)?;
let r1_echos_serialized = combined_echos
.get(&1.into())
.ok_or_else(|| LocalError::new("Could not find combined echos for Round 1"))?;

// Deserialize the echos
let _r1_echos = r1_echos_serialized
.iter()
.map(|(_id, echo)| echo.deserialize::<Round1Echo>(deserializer))
.map(|(_id, echo)| echo.deserialize::<Round1Echo>(format))
.collect::<Result<Vec<_>, _>>()?;

// Message contents would be checked here
Expand All @@ -77,36 +77,36 @@ impl<Id> Protocol<Id> for SimpleProtocol {
type ProtocolError = SimpleProtocolError;

fn verify_direct_message_is_invalid(
deserializer: &Deserializer,
format: &BoxedFormat,
round_id: &RoundId,
message: &DirectMessage,
) -> Result<(), MessageValidationError> {
match round_id {
r if r == &1 => message.verify_is_not::<Round1Message>(deserializer),
r if r == &2 => message.verify_is_not::<Round2Message>(deserializer),
r if r == &1 => message.verify_is_not::<Round1Message>(format),
r if r == &2 => message.verify_is_not::<Round2Message>(format),
_ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())),
}
}

fn verify_echo_broadcast_is_invalid(
deserializer: &Deserializer,
format: &BoxedFormat,
round_id: &RoundId,
message: &EchoBroadcast,
) -> Result<(), MessageValidationError> {
match round_id {
r if r == &1 => message.verify_is_not::<Round1Echo>(deserializer),
r if r == &1 => message.verify_is_not::<Round1Echo>(format),
r if r == &2 => message.verify_is_some(),
_ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())),
}
}

fn verify_normal_broadcast_is_invalid(
deserializer: &Deserializer,
format: &BoxedFormat,
round_id: &RoundId,
message: &NormalBroadcast,
) -> Result<(), MessageValidationError> {
match round_id {
r if r == &1 => message.verify_is_not::<Round1Broadcast>(deserializer),
r if r == &1 => message.verify_is_not::<Round1Broadcast>(format),
r if r == &2 => message.verify_is_some(),
_ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())),
}
Expand Down Expand Up @@ -166,7 +166,7 @@ impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {

fn make_round(
self,
_rng: &mut impl CryptoRngCore,
_rng: &mut dyn CryptoRngCore,
_shared_randomness: &[u8],
id: &Id,
) -> Result<BoxedRound<Id, Self::Protocol>, LocalError> {
Expand Down Expand Up @@ -205,8 +205,8 @@ impl<Id: PartyId> Round<Id> for Round1<Id> {

fn make_normal_broadcast(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
_rng: &mut dyn CryptoRngCore,
format: &BoxedFormat,
) -> Result<NormalBroadcast, LocalError> {
debug!("{:?}: making normal broadcast", self.context.id);

Expand All @@ -215,27 +215,27 @@ impl<Id: PartyId> Round<Id> for Round1<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
};

NormalBroadcast::new(serializer, message)
NormalBroadcast::new(format, message)
}

fn make_echo_broadcast(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
_rng: &mut dyn CryptoRngCore,
format: &BoxedFormat,
) -> Result<EchoBroadcast, LocalError> {
debug!("{:?}: making echo broadcast", self.context.id);

let message = Round1Echo {
my_position: self.context.ids_to_positions[&self.context.id],
};

EchoBroadcast::new(serializer, message)
EchoBroadcast::new(format, message)
}

fn make_direct_message(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
_rng: &mut dyn CryptoRngCore,
format: &BoxedFormat,
destination: &Id,
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
Expand All @@ -244,21 +244,21 @@ impl<Id: PartyId> Round<Id> for Round1<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
your_position: self.context.ids_to_positions[destination],
};
let dm = DirectMessage::new(serializer, message)?;
let dm = DirectMessage::new(format, message)?;
Ok((dm, None))
}

fn receive_message(
&self,
deserializer: &Deserializer,
format: &BoxedFormat,
from: &Id,
message: ProtocolMessage,
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: receiving message from {:?}", self.context.id, from);

let _echo = message.echo_broadcast.deserialize::<Round1Echo>(deserializer)?;
let _normal = message.normal_broadcast.deserialize::<Round1Broadcast>(deserializer)?;
let message = message.direct_message.deserialize::<Round1Message>(deserializer)?;
let _echo = message.echo_broadcast.deserialize::<Round1Echo>(format)?;
let _normal = message.normal_broadcast.deserialize::<Round1Broadcast>(format)?;
let message = message.direct_message.deserialize::<Round1Message>(format)?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand All @@ -270,8 +270,8 @@ impl<Id: PartyId> Round<Id> for Round1<Id> {
}

fn finalize(
self,
_rng: &mut impl CryptoRngCore,
self: Box<Self>,
_rng: &mut dyn CryptoRngCore,
payloads: BTreeMap<Id, Payload>,
_artifacts: BTreeMap<Id, Artifact>,
) -> Result<FinalizeOutcome<Id, Self::Protocol>, LocalError> {
Expand All @@ -283,7 +283,7 @@ impl<Id: PartyId> Round<Id> for Round1<Id> {

let typed_payloads = payloads
.into_values()
.map(|payload| payload.try_to_typed::<Round1Payload>())
.map(|payload| payload.downcast::<Round1Payload>())
.collect::<Result<Vec<_>, _>>()?;
let sum = self.context.ids_to_positions[&self.context.id]
+ typed_payloads.iter().map(|payload| payload.x).sum::<u8>();
Expand Down Expand Up @@ -321,8 +321,8 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {

fn make_direct_message(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
_rng: &mut dyn CryptoRngCore,
format: &BoxedFormat,
destination: &Id,
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
Expand All @@ -331,13 +331,13 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
your_position: self.context.ids_to_positions[destination],
};
let dm = DirectMessage::new(serializer, message)?;
let dm = DirectMessage::new(format, message)?;
Ok((dm, None))
}

fn receive_message(
&self,
deserializer: &Deserializer,
format: &BoxedFormat,
from: &Id,
message: ProtocolMessage,
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
Expand All @@ -346,7 +346,7 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {
message.echo_broadcast.assert_is_none()?;
message.normal_broadcast.assert_is_none()?;

let message = message.direct_message.deserialize::<Round1Message>(deserializer)?;
let message = message.direct_message.deserialize::<Round1Message>(format)?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand All @@ -358,8 +358,8 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {
}

fn finalize(
self,
_rng: &mut impl CryptoRngCore,
self: Box<Self>,
_rng: &mut dyn CryptoRngCore,
payloads: BTreeMap<Id, Payload>,
_artifacts: BTreeMap<Id, Artifact>,
) -> Result<FinalizeOutcome<Id, Self::Protocol>, LocalError> {
Expand All @@ -371,7 +371,7 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {

let typed_payloads = payloads
.into_values()
.map(|payload| payload.try_to_typed::<Round1Payload>())
.map(|payload| payload.downcast::<Round1Payload>())
.collect::<Result<Vec<_>, _>>()?;
let sum = self.context.ids_to_positions[&self.context.id]
+ typed_payloads.iter().map(|payload| payload.x).sum::<u8>();
Expand Down
14 changes: 6 additions & 8 deletions examples/src/simple_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use manul::{
combinators::misbehave::{Misbehaving, MisbehavingEntryPoint},
dev::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
protocol::{
Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart,
Serializer,
Artifact, BoxedFormat, BoxedRound, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart,
},
signature::Keypair,
};
Expand All @@ -28,25 +27,24 @@ impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
type EntryPoint = SimpleProtocolEntryPoint<Id>;

fn modify_direct_message(
_rng: &mut impl CryptoRngCore,
_rng: &mut dyn CryptoRngCore,
round: &BoxedRound<Id, <Self::EntryPoint as EntryPoint<Id>>::Protocol>,
behavior: &Behavior,
serializer: &Serializer,
_deserializer: &Deserializer,
format: &BoxedFormat,
_destination: &Id,
direct_message: DirectMessage,
artifact: Option<Artifact>,
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
let dm = if round.id() == 1 {
match behavior {
Behavior::SerializedGarbage => DirectMessage::new(serializer, [99u8])?,
Behavior::SerializedGarbage => DirectMessage::new(format, [99u8])?,
Behavior::AttributableFailure => {
let round1 = round.downcast_ref::<Round1<Id>>()?;
let message = Round1Message {
my_position: round1.context.ids_to_positions[&round1.context.id],
your_position: round1.context.ids_to_positions[&round1.context.id],
};
DirectMessage::new(serializer, message)?
DirectMessage::new(format, message)?
}
_ => direct_message,
}
Expand All @@ -58,7 +56,7 @@ impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
my_position: round2.context.ids_to_positions[&round2.context.id],
your_position: round2.context.ids_to_positions[&round2.context.id],
};
DirectMessage::new(serializer, message)?
DirectMessage::new(format, message)?
}
_ => direct_message,
}
Expand Down
2 changes: 1 addition & 1 deletion manul/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ categories = ["cryptography", "no-std"]
[dependencies]
serde = { version = "1", default-features = false, features = ["alloc", "serde_derive"] }
erased-serde = { version = "0.4", default-features = false, features = ["alloc"] }
serde-encoded-bytes = { version = "0.1", default-features = false, features = ["hex", "base64"] }
serde-encoded-bytes = { version = "0.2", default-features = false, features = ["hex", "base64"] }
digest = { version = "0.10", default-features = false }
signature = { version = "2", default-features = false, features = ["digest", "rand_core"] }
rand_core = { version = "0.6.4", default-features = false }
Expand Down
Loading
Loading