From 0fbd004ebafce8d7cff033dea9741b5f5c80aef5 Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Mon, 29 Jul 2024 10:00:49 +0200 Subject: [PATCH 1/8] feat: distinguish broadcast and unicast messages --- proto/meesign.proto | 8 ++++- src/protocol/elgamal.rs | 34 +++++++++--------- src/protocol/frost.rs | 45 +++++++++++------------ src/protocol/gg18.rs | 79 ++++++++++++++++++++--------------------- src/protocol/mod.rs | 78 +++++++++++++++++++++++++--------------- 5 files changed, 133 insertions(+), 111 deletions(-) diff --git a/proto/meesign.proto b/proto/meesign.proto index a08365a..cf7a824 100644 --- a/proto/meesign.proto +++ b/proto/meesign.proto @@ -21,7 +21,13 @@ message ProtocolInit { bytes data = 4; } +enum MessageType { + UNICAST = 0; + BROADCAST = 1; +} + message ProtocolMessage { ProtocolType protocol_type = 1; - repeated bytes message = 2; + MessageType message_type = 2; + repeated bytes messages = 3; } diff --git a/src/protocol/elgamal.rs b/src/protocol/elgamal.rs index 5a42b19..cbd8c7b 100644 --- a/src/protocol/elgamal.rs +++ b/src/protocol/elgamal.rs @@ -49,16 +49,16 @@ impl KeygenContext { let dkg = ParticipantCollectingCommitments::::new(params, index.into(), &mut OsRng); let c = dkg.commitment(); - let ser = serialize_bcast(&c, msg.parties as usize - 1)?; + let msg = serialize_bcast(&c, ProtocolType::Elgamal)?; self.round = KeygenRound::R1(dkg, index); - Ok(pack(ser, ProtocolType::Elgamal)) + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = unpack(data)?; + let msgs = crate::protocol::decode(data)?; let n = msgs.len(); - let (c, ser) = match &self.round { + let (c, msg) = match &self.round { KeygenRound::R0 => return Err("protocol not initialized".into()), KeygenRound::R1(dkg, idx) => { let mut dkg = dkg.clone(); @@ -74,9 +74,9 @@ impl KeygenContext { } let dkg = dkg.finish_commitment_phase(); let public_info = dkg.public_info(); - let ser = serialize_bcast(&public_info, n)?; + let msg = serialize_bcast(&public_info, ProtocolType::Elgamal)?; - (KeygenRound::R2(dkg, *idx), ser) + (KeygenRound::R2(dkg, *idx), msg) } KeygenRound::R2(dkg, idx) => { let mut dkg = dkg.clone(); @@ -100,9 +100,9 @@ impl KeygenContext { let secret_share = dkg.secret_share_for_participant(i); shares.push(secret_share); } - let ser = serialize_uni(shares)?; + let msg = serialize_uni(shares, ProtocolType::Elgamal)?; - (KeygenRound::R3(dkg, *idx), ser) + (KeygenRound::R3(dkg, *idx), msg) } KeygenRound::R3(dkg, idx) => { let mut dkg = dkg.clone(); @@ -117,15 +117,15 @@ impl KeygenContext { return Err("not enough shares".into()); } let dkg = dkg.complete()?; - let ser = inflate(dkg.key_set().shared_key().as_bytes().to_vec(), n); - (KeygenRound::Done(dkg), ser) + let msg = encode_raw_bcast(dkg.key_set().shared_key().as_bytes().to_vec(), ProtocolType::Elgamal); + (KeygenRound::Done(dkg), msg) } KeygenRound::Done(_) => return Err("protocol already finished".into()), }; self.round = c; - Ok(pack(ser, ProtocolType::Elgamal)) + Ok(msg) } } @@ -179,15 +179,15 @@ impl DecryptContext { let (share, proof) = self.ctx.decrypt_share(self.encrypted_key, &mut OsRng); - let ser = serialize_bcast( + let msg = serialize_bcast( &serde_json::to_string(&(share, proof))?.as_bytes(), - self.indices.len() - 1, + ProtocolType::Elgamal, )?; let share = (self.ctx.index(), share); self.shares.push(share); - Ok(pack(ser, ProtocolType::Elgamal)) + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { @@ -198,7 +198,7 @@ impl DecryptContext { return Err("protocol already finished".into()); } - let msgs = unpack(data)?; + let msgs = crate::protocol::decode(data)?; let data: Vec> = deserialize_vec(&msgs)?; let local_index = self @@ -254,8 +254,8 @@ impl DecryptContext { self.result = Some(msg.clone()); - let ser = inflate(msg, self.indices.len() - 1); - Ok(pack(ser, ProtocolType::Elgamal)) + let msg = encode_raw_bcast(msg, ProtocolType::Elgamal); + Ok(msg) } } diff --git a/src/protocol/frost.rs b/src/protocol/frost.rs index 68928fd..ce47f6e 100644 --- a/src/protocol/frost.rs +++ b/src/protocol/frost.rs @@ -87,16 +87,16 @@ impl KeygenContext { OsRng, )?; - let msgs = serialize_bcast(&public_package, (setup.parties - 1) as usize)?; + let msg = serialize_bcast(&public_package, ProtocolType::Frost)?; self.round = KeygenRound::R1(setup, secret_package); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } fn update(&mut self, data: &[u8]) -> Result<(Vec, Recipient)> { let (c, data, rec) = match &self.round { KeygenRound::R0 => return Err("protocol not initialized".into()), KeygenRound::R1(setup, secret) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; + let data: Vec = deserialize_vec(&decode(data)?)?; let round1 = map_share_vec(data, &Vec::from_iter(1..=setup.parties), setup.index)?; let (secret, round2) = dkg::part2(secret.clone(), &round1)?; let mut round2: Vec<_> = round2.into_iter().collect(); @@ -105,20 +105,20 @@ impl KeygenContext { ( KeygenRound::R2(*setup, secret, round1), - pack(serialize_uni(round2)?, ProtocolType::Frost), + serialize_uni(round2, ProtocolType::Frost)?, Recipient::Server, ) } KeygenRound::R2(setup, secret, round1) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; + let data: Vec = deserialize_vec(&decode(data)?)?; let round2 = map_share_vec(data, &Vec::from_iter(1..=setup.parties), setup.index)?; let (key, pubkey) = frost::keys::dkg::part3(secret, round1, &round2)?; if !self.with_card { - let msgs = inflate(serde_json::to_vec(&pubkey.verifying_key())?, round2.len()); + let msgs = serialize_bcast(&pubkey.verifying_key(), ProtocolType::Frost)?; ( KeygenRound::Done(*setup, Some(key), pubkey), - pack(msgs, ProtocolType::Frost), + msgs, Recipient::Server, ) } else { @@ -138,13 +138,10 @@ impl KeygenContext { } KeygenRound::R21AwaitSetupResp(setup, pubkey) => { jc::response::setup(data)?; - let msgs = inflate( - serde_json::to_vec(&pubkey.verifying_key())?, - (setup.parties - 1) as usize, - ); + let msg = serialize_bcast(&pubkey.verifying_key(), ProtocolType::Frost)?; ( KeygenRound::Done(*setup, None, pubkey.clone()), - pack(msgs, ProtocolType::Frost), + msg, Recipient::Server, ) } @@ -222,9 +219,9 @@ impl SignContext { if let Some(key) = &self.key { let (nonces, commitments) = frost::round1::commit(key.signing_share(), &mut OsRng); - let msgs = serialize_bcast(&commitments, self.participants() - 1)?; + let msg = serialize_bcast(&commitments, ProtocolType::Frost)?; self.round = SignRound::R1(Some(nonces), commitments); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } else { self.round = SignRound::R01AwaitCommitResp; Ok((jc::command::commit(), Recipient::Card)) @@ -236,12 +233,12 @@ impl SignContext { SignRound::R0 => Err("protocol not initialized".into()), SignRound::R01AwaitCommitResp => { let commitments = jc::response::commit(data)?; - let msgs = serialize_bcast(&commitments, self.participants() - 1)?; + let msg = serialize_bcast(&commitments, ProtocolType::Frost)?; self.round = SignRound::R1(None, commitments); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } SignRound::R1(nonces, commitments) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; + let data: Vec = deserialize_vec(&decode(data)?)?; let mut commitments_map: BTreeMap = map_share_vec(data, self.indices.as_deref().unwrap(), self.setup.index)?; @@ -254,9 +251,9 @@ impl SignContext { if let Some(key) = &self.key { let share = frost::round2::sign(&signing_package, nonces.as_ref().unwrap(), key)?; - let msgs = serialize_bcast(&share, self.participants() - 1)?; + let msg = serialize_bcast(&share, ProtocolType::Frost)?; self.round = SignRound::R2(signing_package, share); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } else { let index = self.indices.as_deref().unwrap()[0]; let command = jc::command::commitment( @@ -294,12 +291,12 @@ impl SignContext { } SignRound::R12AwaitSignResp(signing_package) => { let share = jc::response::sign(data)?; - let msgs = serialize_bcast(&share, self.participants() - 1)?; + let msg = serialize_bcast(&share, ProtocolType::Frost)?; self.round = SignRound::R2(signing_package.clone(), share); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } SignRound::R2(signing_package, share) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; + let data: Vec = deserialize_vec(&decode(data)?)?; let mut shares: BTreeMap = map_share_vec(data, self.indices.as_deref().unwrap(), self.setup.index)?; @@ -308,9 +305,9 @@ impl SignContext { let signature = frost::aggregate(signing_package, &shares, &self.pubkey)?; - let msgs = serialize_bcast(&signature, self.participants() - 1)?; + let msg = serialize_bcast(&signature, ProtocolType::Frost)?; self.round = SignRound::Done(signature); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } SignRound::Done(_) => Err("protocol already finished".into()), } diff --git a/src/protocol/gg18.rs b/src/protocol/gg18.rs index 7d8287e..0c570ed 100644 --- a/src/protocol/gg18.rs +++ b/src/protocol/gg18.rs @@ -29,47 +29,46 @@ impl KeygenContext { (msg.parties as u16, msg.threshold as u16, msg.index as u16); let (out, c1) = gg18_key_gen_1(parties, threshold, index)?; - let ser = serialize_bcast(&out, msg.parties as usize - 1)?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; self.round = KeygenRound::R1(c1); - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = unpack(data)?; - let n = msgs.len(); + let msgs = decode(data)?; - let (c, ser) = match &self.round { + let (c, msg) = match &self.round { KeygenRound::R0 => unreachable!(), KeygenRound::R1(c1) => { let (out, c2) = gg18_key_gen_2(deserialize_vec(&msgs)?, c1.clone())?; - let ser = serialize_bcast(&out, n)?; - (KeygenRound::R2(c2), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (KeygenRound::R2(c2), msg) } KeygenRound::R2(c2) => { let (outs, c3) = gg18_key_gen_3(deserialize_vec(&msgs)?, c2.clone())?; - let ser = serialize_uni(outs)?; - (KeygenRound::R3(c3), ser) + let msg = serialize_uni(outs, ProtocolType::Gg18)?; + (KeygenRound::R3(c3), msg) } KeygenRound::R3(c3) => { let (out, c4) = gg18_key_gen_4(deserialize_vec(&msgs)?, c3.clone())?; - let ser = serialize_bcast(&out, n)?; - (KeygenRound::R4(c4), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (KeygenRound::R4(c4), msg) } KeygenRound::R4(c4) => { let (out, c5) = gg18_key_gen_5(deserialize_vec(&msgs)?, c4.clone())?; - let ser = serialize_bcast(&out, n)?; - (KeygenRound::R5(c5), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (KeygenRound::R5(c5), msg) } KeygenRound::R5(c5) => { let c = gg18_key_gen_6(deserialize_vec(&msgs)?, c5.clone())?; - let ser = inflate(c.pk.to_bytes(false).to_vec(), n); - (KeygenRound::Done(c), ser) + let msg = encode_raw_bcast(c.pk.to_bytes(false).to_vec(), ProtocolType::Gg18); + (KeygenRound::Done(c), msg) } KeygenRound::Done(_) => todo!(), }; self.round = c; - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } } @@ -124,7 +123,6 @@ impl SignContext { let msg = ProtocolInit::decode(data)?; let indices: Vec = msg.indices.clone().into_iter().map(|i| i as u16).collect(); - let parties = indices.len(); let local_index = indices.iter().position(|&i| i == msg.index as u16).unwrap(); let c0 = match &self.round { @@ -133,67 +131,66 @@ impl SignContext { }; let (out, c1) = gg18_sign1(c0, indices, local_index, msg.data)?; - let ser = serialize_bcast(&out, parties - 1)?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; self.round = SignRound::R1(c1); - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = unpack(data)?; - let n = msgs.len(); + let msgs = decode(data)?; - let (c, ser) = match &self.round { + let (c, msg) = match &self.round { SignRound::R0(_) => unreachable!(), SignRound::R1(c1) => { let (outs, c2) = gg18_sign2(deserialize_vec(&msgs)?, c1.clone())?; - let ser = serialize_uni(outs)?; - (SignRound::R2(c2), ser) + let msg = serialize_uni(outs, ProtocolType::Gg18)?; + (SignRound::R2(c2), msg) } SignRound::R2(c2) => { let (out, c3) = gg18_sign3(deserialize_vec(&msgs)?, c2.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R3(c3), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R3(c3), msg) } SignRound::R3(c3) => { let (out, c4) = gg18_sign4(deserialize_vec(&msgs)?, c3.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R4(c4), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R4(c4), msg) } SignRound::R4(c4) => { let (out, c5) = gg18_sign5(deserialize_vec(&msgs)?, c4.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R5(c5), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R5(c5), msg) } SignRound::R5(c5) => { let (out, c6) = gg18_sign6(deserialize_vec(&msgs)?, c5.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R6(c6), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R6(c6), msg) } SignRound::R6(c6) => { let (out, c7) = gg18_sign7(deserialize_vec(&msgs)?, c6.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R7(c7), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R7(c7), msg) } SignRound::R7(c7) => { let (out, c8) = gg18_sign8(deserialize_vec(&msgs)?, c7.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R8(c8), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R8(c8), msg) } SignRound::R8(c8) => { let (out, c9) = gg18_sign9(deserialize_vec(&msgs)?, c8.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R9(c9), ser) + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R9(c9), msg) } SignRound::R9(c9) => { let sig = gg18_sign10(deserialize_vec(&msgs)?, c9.clone())?; - let ser = inflate(sig.clone(), n); - (SignRound::Done(sig), ser) + let msg = encode_raw_bcast(sig.clone(), ProtocolType::Gg18); + (SignRound::Done(sig), msg) } SignRound::Done(_) => todo!(), }; self.round = c; - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 29e25d4..f80f30c 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -9,7 +9,7 @@ mod apdu; pub type Result = std::result::Result>; -use crate::proto::{ProtocolMessage, ProtocolType}; +use crate::proto::{MessageType, ProtocolMessage, ProtocolType}; use prost::Message; use serde::{Deserialize, Serialize}; @@ -42,35 +42,39 @@ fn deserialize_vec<'de, T: Deserialize<'de>>(vec: &'de [Vec]) -> serde_json: .collect() } -fn inflate(value: T, n: usize) -> Vec { - std::iter::repeat(value).take(n).collect() +/// Encode a broadcast message +fn encode_raw_bcast(message: Vec, protocol_type: ProtocolType) -> Vec { + ProtocolMessage { + protocol_type: protocol_type.into(), + message_type: MessageType::Broadcast.into(), + messages: vec![message], + }.encode_to_vec() } -/// Serialize value and repeat the result n times, -/// as the current server always expects one message for each party -fn serialize_bcast(value: &T, n: usize) -> serde_json::Result>> { - let ser = serde_json::to_vec(value)?; - Ok(inflate(ser, n)) +/// Serialize and encode a broadcast message +fn serialize_bcast(value: &T, protocol_type: ProtocolType) -> serde_json::Result> { + let message = serde_json::to_vec(value)?; + Ok(encode_raw_bcast(message, protocol_type)) } -/// Serialize vector of unicast messages -fn serialize_uni(vec: Vec) -> serde_json::Result>> { - vec.iter().map(|item| serde_json::to_vec(item)).collect() +/// Encode a Vec of unicast messages +fn encode_raw_uni(messages: Vec>, protocol_type: ProtocolType) -> Vec { + ProtocolMessage { + protocol_type: protocol_type.into(), + message_type: MessageType::Unicast.into(), + messages, + }.encode_to_vec() } -/// Decode a protobuf message from the server -fn unpack(data: &[u8]) -> std::result::Result>, prost::DecodeError> { - let msgs = ProtocolMessage::decode(data)?.message; - Ok(msgs) +/// Serialize and encode a Vec of unicast messages +fn serialize_uni(vec: Vec, protocol_type: ProtocolType) -> serde_json::Result> { + let messages: serde_json::Result> = vec.iter().map(serde_json::to_vec).collect(); + Ok(encode_raw_uni(messages?, protocol_type)) } -/// Encode msgs as a protobuf message for the server -fn pack(msgs: Vec>, protocol_type: ProtocolType) -> Vec { - ProtocolMessage { - protocol_type: protocol_type.into(), - message: msgs, - } - .encode_to_vec() +/// Decode a protobuf message from the server +fn decode(data: &[u8]) -> std::result::Result>, prost::DecodeError> { + Ok(ProtocolMessage::decode(data)?.messages) } #[cfg(test)] @@ -84,6 +88,22 @@ mod tests { protocol::{KeygenProtocol, ThresholdProtocol}, }; + /// Translate a message from a client to a Vec of messages for every other client + fn distribute_client_message(message: ProtocolMessage, parties: u32) -> Vec> { + match message.message_type() { + MessageType::Broadcast => { + let messages = message.messages; + assert_eq!(messages.len(), 1); + std::iter::repeat(messages[0].clone()).take(parties as usize).collect() + }, + MessageType::Unicast => { + let messages = message.messages; + assert_eq!(messages.len(), parties as usize); + messages + }, + } + } + pub(super) trait KeygenProtocolTest: KeygenProtocol + Sized { // Cannot be added in Protocol (yet) due to typetag Trait limitations const PROTOCOL_TYPE: ProtocolType; @@ -114,8 +134,8 @@ mod tests { .into(), ) .unwrap() - .message }) + .map(|msg| distribute_client_message(msg, parties - 1)) .collect(); // protocol rounds @@ -144,7 +164,8 @@ mod tests { ctx.advance( &(ProtocolMessage { protocol_type: Self::PROTOCOL_TYPE as i32, - message: relay, + message_type: MessageType::Unicast.into(), + messages: relay, }) .encode_to_vec(), ) @@ -153,8 +174,8 @@ mod tests { .into(), ) .unwrap() - .message }) + .map(|msg| distribute_client_message(msg, parties - 1)) .collect(); } @@ -205,8 +226,8 @@ mod tests { .into(), ) .unwrap() - .message }) + .map(|msg| distribute_client_message(msg, indices.len() as u32 - 1)) .collect(); // protocol rounds @@ -235,7 +256,8 @@ mod tests { ctx.advance( &(ProtocolMessage { protocol_type: Self::PROTOCOL_TYPE as i32, - message: relay, + message_type: MessageType::Unicast.into(), + messages: relay, }) .encode_to_vec(), ) @@ -244,8 +266,8 @@ mod tests { .into(), ) .unwrap() - .message }) + .map(|msg| distribute_client_message(msg, indices.len() as u32 - 1)) .collect(); } From c7e64fae955c8c0d5ad4717616d4e29f439c1fb8 Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Mon, 29 Jul 2024 10:12:06 +0200 Subject: [PATCH 2/8] chore: bump version to 0.3.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b009fd5..00cbc01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "meesign-crypto" -version = "0.2.0" +version = "0.3.0" edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html From 9d4d0ee77f987620aeb4cc9edfb03de08819f66e Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Sat, 17 Aug 2024 09:17:23 +0200 Subject: [PATCH 3/8] refactor: sensible error handling --- src/protocol/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index f80f30c..5c5f935 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -68,8 +68,10 @@ fn encode_raw_uni(messages: Vec>, protocol_type: ProtocolType) -> Vec(vec: Vec, protocol_type: ProtocolType) -> serde_json::Result> { - let messages: serde_json::Result> = vec.iter().map(serde_json::to_vec).collect(); - Ok(encode_raw_uni(messages?, protocol_type)) + let messages = vec.iter() + .map(serde_json::to_vec) + .collect::>>()?; + Ok(encode_raw_uni(messages, protocol_type)) } /// Decode a protobuf message from the server From 6f84a6cb4c081aac88d959b161235827a5c86346 Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Tue, 27 Aug 2024 00:04:26 +0200 Subject: [PATCH 4/8] feat: hashmaps in protocol messages --- proto/meesign.proto | 13 +-- src/protocol/elgamal.rs | 93 ++++++------------ src/protocol/frost.rs | 59 ++++++------ src/protocol/gg18.rs | 76 +++++++++++---- src/protocol/mod.rs | 209 +++++++++++++++++++--------------------- 5 files changed, 224 insertions(+), 226 deletions(-) diff --git a/proto/meesign.proto b/proto/meesign.proto index cf7a824..7314a17 100644 --- a/proto/meesign.proto +++ b/proto/meesign.proto @@ -21,13 +21,14 @@ message ProtocolInit { bytes data = 4; } -enum MessageType { - UNICAST = 0; - BROADCAST = 1; +message ClientMessage { + ProtocolType protocol_type = 1; + map unicasts = 2; + optional bytes broadcast = 3; } -message ProtocolMessage { +message ServerMessage { ProtocolType protocol_type = 1; - MessageType message_type = 2; - repeated bytes messages = 3; + map unicasts = 2; + map broadcasts = 3; } diff --git a/src/protocol/elgamal.rs b/src/protocol/elgamal.rs index cbd8c7b..a34a7ec 100644 --- a/src/protocol/elgamal.rs +++ b/src/protocol/elgamal.rs @@ -1,4 +1,4 @@ -use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType}; +use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType, ServerMessage}; use crate::protocol::*; use curve25519_dalek::{ ristretto::{CompressedRistretto, RistrettoPoint}, @@ -19,6 +19,8 @@ use aes_gcm::{ use prost::Message; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + #[derive(Serialize, Deserialize)] pub(crate) struct KeygenContext { round: KeygenRound, @@ -27,9 +29,9 @@ pub(crate) struct KeygenContext { #[derive(Serialize, Deserialize)] enum KeygenRound { R0, - R1(ParticipantCollectingCommitments, u16), - R2(ParticipantCollectingPolynomials, u16), - R3(ParticipantExchangingSecrets, u16), + R1(ParticipantCollectingCommitments), + R2(ParticipantCollectingPolynomials), + R3(ParticipantExchangingSecrets), Done(ActiveParticipant), } @@ -50,24 +52,20 @@ impl KeygenContext { ParticipantCollectingCommitments::::new(params, index.into(), &mut OsRng); let c = dkg.commitment(); let msg = serialize_bcast(&c, ProtocolType::Elgamal)?; - self.round = KeygenRound::R1(dkg, index); + self.round = KeygenRound::R1(dkg); Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = crate::protocol::decode(data)?; - let n = msgs.len(); + let msgs = ServerMessage::decode(data)?; let (c, msg) = match &self.round { KeygenRound::R0 => return Err("protocol not initialized".into()), - KeygenRound::R1(dkg, idx) => { + KeygenRound::R1(dkg) => { let mut dkg = dkg.clone(); - let data = deserialize_vec(&msgs)?; - for (mut i, msg) in data.into_iter().enumerate() { - if i >= *idx as usize { - i += 1; - } - dkg.insert_commitment(i, msg); + let data = deserialize_map(&msgs.broadcasts)?; + for (i, msg) in data { + dkg.insert_commitment(i as usize, msg); } if dkg.missing_commitments().next().is_some() { return Err("not enough commitments".into()); @@ -76,42 +74,33 @@ impl KeygenContext { let public_info = dkg.public_info(); let msg = serialize_bcast(&public_info, ProtocolType::Elgamal)?; - (KeygenRound::R2(dkg, *idx), msg) + (KeygenRound::R2(dkg), msg) } - KeygenRound::R2(dkg, idx) => { + KeygenRound::R2(dkg) => { let mut dkg = dkg.clone(); - let data = deserialize_vec(&msgs)?; - for (mut i, msg) in data.into_iter().enumerate() { - if i >= *idx as usize { - i += 1; - } - dkg.insert_public_polynomial(i, msg)? + let data = deserialize_map(&msgs.broadcasts)?; + for (i, msg) in data { + dkg.insert_public_polynomial(i as usize, msg)? } if dkg.missing_public_polynomials().next().is_some() { return Err("not enough polynomials".into()); } let dkg = dkg.finish_polynomials_phase(); - let mut shares = Vec::new(); - for mut i in 0..n { - if i >= *idx as usize { - i += 1; - } - let secret_share = dkg.secret_share_for_participant(i); - shares.push(secret_share); - } + let shares = msgs + .broadcasts + .into_keys() + .map(|i| (i, dkg.secret_share_for_participant(i as usize))); + let msg = serialize_uni(shares, ProtocolType::Elgamal)?; - (KeygenRound::R3(dkg, *idx), msg) + (KeygenRound::R3(dkg), msg) } - KeygenRound::R3(dkg, idx) => { + KeygenRound::R3(dkg) => { let mut dkg = dkg.clone(); - let data = deserialize_vec(&msgs)?; - for (mut i, msg) in data.into_iter().enumerate() { - if i >= *idx as usize { - i += 1; - } - dkg.insert_secret_share(i, msg)?; + let data = deserialize_map(&msgs.unicasts)?; + for (i, msg) in data { + dkg.insert_secret_share(i as usize, msg)?; } if dkg.missing_shares().next().is_some() { return Err("not enough shares".into()); @@ -160,7 +149,6 @@ pub(crate) struct DecryptContext { ctx: ActiveParticipant, encrypted_key: Ciphertext, data: (Vec, Vec, Vec), - indices: Vec, shares: Vec<(usize, VerifiableDecryption)>, result: Option>, } @@ -173,7 +161,6 @@ impl DecryptContext { return Err("wrong protocol type".into()); } - self.indices = msg.indices.clone().into_iter().map(|i| i as u16).collect(); self.data = serde_json::from_slice(&msg.data)?; self.encrypted_key = serde_json::from_slice(&self.data.0)?; @@ -198,32 +185,17 @@ impl DecryptContext { return Err("protocol already finished".into()); } - let msgs = crate::protocol::decode(data)?; + let msgs = ServerMessage::decode(data)?; - let data: Vec> = deserialize_vec(&msgs)?; - let local_index = self - .indices - .iter() - .position(|x| *x as usize == self.ctx.index()) - .ok_or("participant index not included")?; - assert_eq!(self.ctx.index(), self.indices[local_index] as usize); - - for (mut i, msg) in data.into_iter().enumerate() { - if i >= local_index { - i += 1; - } + let data: HashMap> = deserialize_map(&msgs.broadcasts)?; + for (i, msg) in data { let msg: (VerifiableDecryption, LogEqualityProof) = serde_json::from_slice(&msg)?; self.ctx .key_set() - .verify_share( - msg.0.into(), - self.encrypted_key, - self.indices[i].into(), - &msg.1, - ) + .verify_share(msg.0.into(), self.encrypted_key, i as usize, &msg.1) .unwrap(); - self.shares.push((self.indices[i].into(), msg.0)); + self.shares.push((i as usize, msg.0)); } let mut key = [0u8; 16]; @@ -284,7 +256,6 @@ impl ThresholdProtocol for DecryptContext { ctx: serde_json::from_slice(group).expect("could not deserialize group context"), encrypted_key: Ciphertext::zero(), data: (Vec::new(), Vec::new(), Vec::new()), - indices: Vec::new(), shares: Vec::new(), result: None, } diff --git a/src/protocol/frost.rs b/src/protocol/frost.rs index ce47f6e..658f6b6 100644 --- a/src/protocol/frost.rs +++ b/src/protocol/frost.rs @@ -1,4 +1,4 @@ -use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType}; +use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType, ServerMessage}; use crate::protocol::*; use frost::keys::dkg::{self, round1, round2}; @@ -10,7 +10,6 @@ use prost::Message; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; -use std::iter::FromIterator; use frost_secp256k1 as frost; use rand::rngs::OsRng; @@ -22,23 +21,14 @@ struct Setup { index: u16, } -fn map_share_vec(vec: Vec, indices: &[u16], index: u16) -> Result -where - C: FromIterator<(Identifier, T)>, -{ - let pos = indices - .iter() - .position(|&x| x == index) - .ok_or("missing index")?; - let collection = vec - .into_iter() - .enumerate() - .map(move |(i, item)| { - let index = indices[if i < pos { i } else { i + 1 }]; - (Identifier::try_from(index).unwrap(), item) - }) - .collect(); - Ok(collection) +fn map_index_identifier( + kvs: impl Iterator, +) -> impl Iterator { + kvs.map(|(i, x)| { + assert!(i > 0); + assert!(i <= u16::MAX as u32); + (Identifier::try_from(i as u16).unwrap(), x) + }) } #[derive(Serialize, Deserialize)] @@ -96,12 +86,16 @@ impl KeygenContext { let (c, data, rec) = match &self.round { KeygenRound::R0 => return Err("protocol not initialized".into()), KeygenRound::R1(setup, secret) => { - let data: Vec = deserialize_vec(&decode(data)?)?; - let round1 = map_share_vec(data, &Vec::from_iter(1..=setup.parties), setup.index)?; + let data = ServerMessage::decode(data)?.broadcasts; + let round1 = deserialize_map(&data)?; + let indices: Vec<_> = round1.keys().cloned().collect(); + let round1 = map_index_identifier(round1.into_iter()).collect(); let (secret, round2) = dkg::part2(secret.clone(), &round1)?; - let mut round2: Vec<_> = round2.into_iter().collect(); - round2.sort_by_key(|(i, _)| *i); - let round2: Vec<_> = round2.into_iter().map(|(_, p)| p).collect(); + + let round2 = indices.into_iter().map(|i| { + let id = Identifier::try_from(i as u16).unwrap(); + (i, round2.get(&id)) + }); ( KeygenRound::R2(*setup, secret, round1), @@ -110,8 +104,9 @@ impl KeygenContext { ) } KeygenRound::R2(setup, secret, round1) => { - let data: Vec = deserialize_vec(&decode(data)?)?; - let round2 = map_share_vec(data, &Vec::from_iter(1..=setup.parties), setup.index)?; + let data = ServerMessage::decode(data)?.unicasts; + let round2 = deserialize_map(&data)?; + let round2 = map_index_identifier(round2.into_iter()).collect(); let (key, pubkey) = frost::keys::dkg::part3(secret, round1, &round2)?; if !self.with_card { @@ -238,10 +233,10 @@ impl SignContext { Ok((msg, Recipient::Server)) } SignRound::R1(nonces, commitments) => { - let data: Vec = deserialize_vec(&decode(data)?)?; - + let data = ServerMessage::decode(data)?.broadcasts; + let commitments_map = deserialize_map(&data)?; let mut commitments_map: BTreeMap = - map_share_vec(data, self.indices.as_deref().unwrap(), self.setup.index)?; + map_index_identifier(commitments_map.into_iter()).collect(); let identifier = Identifier::try_from(self.setup.index).unwrap(); commitments_map.insert(identifier, *commitments); @@ -296,10 +291,10 @@ impl SignContext { Ok((msg, Recipient::Server)) } SignRound::R2(signing_package, share) => { - let data: Vec = deserialize_vec(&decode(data)?)?; - + let data = ServerMessage::decode(data)?.broadcasts; + let shares = deserialize_map(&data)?; let mut shares: BTreeMap = - map_share_vec(data, self.indices.as_deref().unwrap(), self.setup.index)?; + map_index_identifier(shares.into_iter()).collect(); let identifier = Identifier::try_from(self.setup.index).unwrap(); shares.insert(identifier, *share); diff --git a/src/protocol/gg18.rs b/src/protocol/gg18.rs index 0c570ed..630aa89 100644 --- a/src/protocol/gg18.rs +++ b/src/protocol/gg18.rs @@ -1,4 +1,4 @@ -use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType}; +use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType, ServerMessage}; use crate::protocol::*; use mpecdsa::{gg18_key_gen::*, gg18_sign::*}; use prost::Message; @@ -21,6 +21,12 @@ enum KeygenRound { Done(GG18SignContext), } +fn map_to_sorted_vec(map: HashMap) -> Vec { + let mut vec: Vec<_> = map.into_iter().collect(); + vec.sort_by_key(|(i, _)| *i); + vec.into_iter().map(|(_, x)| x).collect() +} + impl KeygenContext { fn init(&mut self, data: &[u8]) -> Result> { let msg = ProtocolGroupInit::decode(data)?; @@ -36,32 +42,46 @@ impl KeygenContext { } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = decode(data)?; + let data = ServerMessage::decode(data)?; let (c, msg) = match &self.round { KeygenRound::R0 => unreachable!(), KeygenRound::R1(c1) => { - let (out, c2) = gg18_key_gen_2(deserialize_vec(&msgs)?, c1.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c2) = gg18_key_gen_2(msgs, c1.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (KeygenRound::R2(c2), msg) } KeygenRound::R2(c2) => { - let (outs, c3) = gg18_key_gen_3(deserialize_vec(&msgs)?, c2.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (outs, c3) = gg18_key_gen_3(msgs, c2.clone())?; + + let mut indices: Vec<_> = data.broadcasts.into_keys().collect(); + indices.sort(); + let outs = indices.into_iter().zip(outs.into_iter()); let msg = serialize_uni(outs, ProtocolType::Gg18)?; (KeygenRound::R3(c3), msg) } KeygenRound::R3(c3) => { - let (out, c4) = gg18_key_gen_4(deserialize_vec(&msgs)?, c3.clone())?; + let msgs = deserialize_map(&data.unicasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c4) = gg18_key_gen_4(msgs, c3.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (KeygenRound::R4(c4), msg) } KeygenRound::R4(c4) => { - let (out, c5) = gg18_key_gen_5(deserialize_vec(&msgs)?, c4.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c5) = gg18_key_gen_5(msgs, c4.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (KeygenRound::R5(c5), msg) } KeygenRound::R5(c5) => { - let c = gg18_key_gen_6(deserialize_vec(&msgs)?, c5.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let c = gg18_key_gen_6(msgs, c5.clone())?; let msg = encode_raw_bcast(c.pk.to_bytes(false).to_vec(), ProtocolType::Gg18); (KeygenRound::Done(c), msg) } @@ -137,52 +157,74 @@ impl SignContext { } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = decode(data)?; + let data = ServerMessage::decode(data)?; let (c, msg) = match &self.round { SignRound::R0(_) => unreachable!(), SignRound::R1(c1) => { - let (outs, c2) = gg18_sign2(deserialize_vec(&msgs)?, c1.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (outs, c2) = gg18_sign2(msgs, c1.clone())?; + + let mut indices: Vec<_> = data.broadcasts.into_keys().collect(); + indices.sort(); + let outs = indices.into_iter().zip(outs.into_iter()); let msg = serialize_uni(outs, ProtocolType::Gg18)?; (SignRound::R2(c2), msg) } SignRound::R2(c2) => { - let (out, c3) = gg18_sign3(deserialize_vec(&msgs)?, c2.clone())?; + let msgs = deserialize_map(&data.unicasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c3) = gg18_sign3(msgs, c2.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (SignRound::R3(c3), msg) } SignRound::R3(c3) => { - let (out, c4) = gg18_sign4(deserialize_vec(&msgs)?, c3.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c4) = gg18_sign4(msgs, c3.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (SignRound::R4(c4), msg) } SignRound::R4(c4) => { - let (out, c5) = gg18_sign5(deserialize_vec(&msgs)?, c4.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c5) = gg18_sign5(msgs, c4.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (SignRound::R5(c5), msg) } SignRound::R5(c5) => { - let (out, c6) = gg18_sign6(deserialize_vec(&msgs)?, c5.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c6) = gg18_sign6(msgs, c5.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (SignRound::R6(c6), msg) } SignRound::R6(c6) => { - let (out, c7) = gg18_sign7(deserialize_vec(&msgs)?, c6.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c7) = gg18_sign7(msgs, c6.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (SignRound::R7(c7), msg) } SignRound::R7(c7) => { - let (out, c8) = gg18_sign8(deserialize_vec(&msgs)?, c7.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c8) = gg18_sign8(msgs, c7.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (SignRound::R8(c8), msg) } SignRound::R8(c8) => { - let (out, c9) = gg18_sign9(deserialize_vec(&msgs)?, c8.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c9) = gg18_sign9(msgs, c8.clone())?; let msg = serialize_bcast(&out, ProtocolType::Gg18)?; (SignRound::R9(c9), msg) } SignRound::R9(c9) => { - let sig = gg18_sign10(deserialize_vec(&msgs)?, c9.clone())?; + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let sig = gg18_sign10(msgs, c9.clone())?; let msg = encode_raw_bcast(sig.clone(), ProtocolType::Gg18); (SignRound::Done(sig), msg) } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 5c5f935..10f4503 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -9,9 +9,10 @@ mod apdu; pub type Result = std::result::Result>; -use crate::proto::{MessageType, ProtocolMessage, ProtocolType}; +use crate::proto::{ClientMessage, ProtocolType}; use prost::Message; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; pub enum Recipient { Card, @@ -36,19 +37,22 @@ pub trait ThresholdProtocol: Protocol { Self: Sized; } -fn deserialize_vec<'de, T: Deserialize<'de>>(vec: &'de [Vec]) -> serde_json::Result> { - vec.iter() - .map(|item| serde_json::from_slice::(item)) +fn deserialize_map<'de, T: Deserialize<'de>>( + map: &'de HashMap>, +) -> serde_json::Result> { + map.iter() + .map(|(k, v)| Ok((*k, serde_json::from_slice::(v.as_slice())?))) .collect() } /// Encode a broadcast message fn encode_raw_bcast(message: Vec, protocol_type: ProtocolType) -> Vec { - ProtocolMessage { + ClientMessage { protocol_type: protocol_type.into(), - message_type: MessageType::Broadcast.into(), - messages: vec![message], - }.encode_to_vec() + unicasts: HashMap::new(), + broadcast: Some(message), + } + .encode_to_vec() } /// Serialize and encode a broadcast message @@ -58,27 +62,27 @@ fn serialize_bcast(value: &T, protocol_type: ProtocolType) -> serd } /// Encode a Vec of unicast messages -fn encode_raw_uni(messages: Vec>, protocol_type: ProtocolType) -> Vec { - ProtocolMessage { +fn encode_raw_uni(messages: HashMap>, protocol_type: ProtocolType) -> Vec { + ClientMessage { protocol_type: protocol_type.into(), - message_type: MessageType::Unicast.into(), - messages, - }.encode_to_vec() + unicasts: messages, + broadcast: None, + } + .encode_to_vec() } -/// Serialize and encode a Vec of unicast messages -fn serialize_uni(vec: Vec, protocol_type: ProtocolType) -> serde_json::Result> { - let messages = vec.iter() - .map(serde_json::to_vec) - .collect::>>()?; +/// Serialize and encode a map of unicast messages +fn serialize_uni(kvs: I, protocol_type: ProtocolType) -> serde_json::Result> +where + I: Iterator, + T: Serialize, +{ + let messages = kvs + .map(|(k, v)| Ok((k, serde_json::to_vec(&v)?))) + .collect::>()?; Ok(encode_raw_uni(messages, protocol_type)) } -/// Decode a protobuf message from the server -fn decode(data: &[u8]) -> std::result::Result>, prost::DecodeError> { - Ok(ProtocolMessage::decode(data)?.messages) -} - #[cfg(test)] mod tests { use super::*; @@ -86,26 +90,10 @@ mod tests { use prost::bytes::Bytes; use crate::{ - proto::{ProtocolGroupInit, ProtocolInit}, + proto::{ProtocolGroupInit, ProtocolInit, ServerMessage}, protocol::{KeygenProtocol, ThresholdProtocol}, }; - /// Translate a message from a client to a Vec of messages for every other client - fn distribute_client_message(message: ProtocolMessage, parties: u32) -> Vec> { - match message.message_type() { - MessageType::Broadcast => { - let messages = message.messages; - assert_eq!(messages.len(), 1); - std::iter::repeat(messages[0].clone()).take(parties as usize).collect() - }, - MessageType::Unicast => { - let messages = message.messages; - assert_eq!(messages.len(), parties as usize); - messages - }, - } - } - pub(super) trait KeygenProtocolTest: KeygenProtocol + Sized { // Cannot be added in Protocol (yet) due to typetag Trait limitations const PROTOCOL_TYPE: ProtocolType; @@ -116,16 +104,18 @@ mod tests { assert!(threshold <= parties); // initialize - let mut ctxs: Vec = (0..parties).map(|_| Self::new()).collect(); - let mut messages: Vec<_> = ctxs + let mut ctxs: HashMap = (0..parties) + .map(|i| (i as u32 + Self::INDEX_OFFSET, Self::new())) + .collect(); + + let mut messages: HashMap = ctxs .iter_mut() - .enumerate() - .map(|(idx, ctx)| { - ProtocolMessage::decode::( + .map(|(&index, ctx)| { + let msg = ClientMessage::decode::( ctx.advance( &(ProtocolGroupInit { protocol_type: Self::PROTOCOL_TYPE as i32, - index: idx as u32 + Self::INDEX_OFFSET, + index, parties, threshold, }) @@ -135,39 +125,37 @@ mod tests { .0 .into(), ) - .unwrap() + .unwrap(); + (index, msg) }) - .map(|msg| distribute_client_message(msg, parties - 1)) .collect(); // protocol rounds for _ in 0..(Self::ROUNDS - 1) { messages = ctxs .iter_mut() - .enumerate() - .map(|(idx, ctx)| { - let relay = messages - .iter() - .enumerate() - .map(|(sender, msg)| { - if sender < idx { - Some(msg[idx - 1].clone()) - } else if sender > idx { - Some(msg[idx].clone()) - } else { - None - } - }) - .filter(Option::is_some) - .map(Option::unwrap) - .collect(); + .map(|(&idx, ctx)| { + let mut unicasts = HashMap::new(); + let mut broadcasts = HashMap::new(); - ProtocolMessage::decode::( + for (&sender, msg) in &messages { + if sender == idx { + continue; + } + if let Some(broadcast) = &msg.broadcast { + broadcasts.insert(sender, broadcast.clone()); + } + if let Some(unicast) = msg.unicasts.get(&idx) { + unicasts.insert(sender, unicast.clone()); + } + } + + let msg = ClientMessage::decode::( ctx.advance( - &(ProtocolMessage { + &(ServerMessage { protocol_type: Self::PROTOCOL_TYPE as i32, - message_type: MessageType::Unicast.into(), - messages: relay, + unicasts, + broadcasts, }) .encode_to_vec(), ) @@ -175,18 +163,23 @@ mod tests { .0 .into(), ) - .unwrap() + .unwrap(); + (idx, msg) }) - .map(|msg| distribute_client_message(msg, parties - 1)) .collect(); } - let pks: Vec<_> = messages.iter().map(|x| x[0].clone()).collect(); + let pks: Vec<_> = messages + .iter() + .map(|(_, msgs)| msgs.broadcast.as_ref().unwrap().clone()) + .collect(); - let results = ctxs + let mut results: Vec<_> = ctxs .into_iter() - .map(|ctx| Box::new(ctx).finish().unwrap()) + .map(|(i, ctx)| (i, Box::new(ctx).finish().unwrap())) .collect(); + results.sort_by_key(|(i, _)| *i); + let results = results.into_iter().map(|(_, ctx)| ctx).collect(); (pks, results) } @@ -200,17 +193,15 @@ mod tests { fn run(ctxs: Vec>, indices: Vec, data: Vec) -> Vec> { // initialize - let mut ctxs: Vec = ctxs + let mut ctxs: HashMap = indices .iter() - .enumerate() - .filter(|(idx, _)| indices.contains(&(*idx as u16))) - .map(|(_, ctx)| Self::new(&ctx)) + .map(|&i| (i as u32 + Self::INDEX_OFFSET, Self::new(&ctxs[i as usize]))) .collect(); - let mut messages: Vec<_> = indices - .iter() - .zip(ctxs.iter_mut()) - .map(|(idx, ctx)| { - ProtocolMessage::decode::( + + let mut messages: HashMap = ctxs + .iter_mut() + .map(|(&index, ctx)| { + let msg = ClientMessage::decode::( ctx.advance( &(ProtocolInit { protocol_type: Self::PROTOCOL_TYPE as i32, @@ -218,7 +209,7 @@ mod tests { .iter() .map(|x| *x as u32 + Self::INDEX_OFFSET) .collect(), - index: *idx as u32 + Self::INDEX_OFFSET, + index, data: data.clone(), }) .encode_to_vec(), @@ -227,39 +218,37 @@ mod tests { .0 .into(), ) - .unwrap() + .unwrap(); + (index, msg) }) - .map(|msg| distribute_client_message(msg, indices.len() as u32 - 1)) .collect(); // protocol rounds for _ in 0..(Self::ROUNDS - 1) { messages = ctxs .iter_mut() - .enumerate() - .map(|(idx, ctx)| { - let relay = messages - .iter() - .enumerate() - .map(|(sender, msg)| { - if sender < idx { - Some(msg[idx - 1].clone()) - } else if sender > idx { - Some(msg[idx].clone()) - } else { - None - } - }) - .filter(Option::is_some) - .map(Option::unwrap) - .collect(); + .map(|(&idx, ctx)| { + let mut unicasts = HashMap::new(); + let mut broadcasts = HashMap::new(); - ProtocolMessage::decode::( + for (&sender, msg) in &messages { + if sender == idx { + continue; + } + if let Some(broadcast) = &msg.broadcast { + broadcasts.insert(sender, broadcast.clone()); + } + if let Some(unicast) = msg.unicasts.get(&idx) { + unicasts.insert(sender, unicast.clone()); + } + } + + let msg = ClientMessage::decode::( ctx.advance( - &(ProtocolMessage { + &(ServerMessage { protocol_type: Self::PROTOCOL_TYPE as i32, - message_type: MessageType::Unicast.into(), - messages: relay, + unicasts, + broadcasts, }) .encode_to_vec(), ) @@ -267,14 +256,14 @@ mod tests { .0 .into(), ) - .unwrap() + .unwrap(); + (idx, msg) }) - .map(|msg| distribute_client_message(msg, indices.len() as u32 - 1)) .collect(); } ctxs.into_iter() - .map(|ctx| Box::new(ctx).finish().unwrap()) + .map(|(_, ctx)| Box::new(ctx).finish().unwrap()) .collect() } } From 4ba2f79ebf2b199ce91c7c52e97344f51d9f7e5d Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Tue, 27 Aug 2024 00:08:39 +0200 Subject: [PATCH 5/8] refactor: follow rustfmt --- src/protocol/elgamal.rs | 5 ++++- src/protocol/mod.rs | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/protocol/elgamal.rs b/src/protocol/elgamal.rs index a34a7ec..7439841 100644 --- a/src/protocol/elgamal.rs +++ b/src/protocol/elgamal.rs @@ -107,7 +107,10 @@ impl KeygenContext { } let dkg = dkg.complete()?; - let msg = encode_raw_bcast(dkg.key_set().shared_key().as_bytes().to_vec(), ProtocolType::Elgamal); + let msg = encode_raw_bcast( + dkg.key_set().shared_key().as_bytes().to_vec(), + ProtocolType::Elgamal, + ); (KeygenRound::Done(dkg), msg) } KeygenRound::Done(_) => return Err("protocol already finished".into()), diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 10f4503..6d0bbc9 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -56,7 +56,10 @@ fn encode_raw_bcast(message: Vec, protocol_type: ProtocolType) -> Vec { } /// Serialize and encode a broadcast message -fn serialize_bcast(value: &T, protocol_type: ProtocolType) -> serde_json::Result> { +fn serialize_bcast( + value: &T, + protocol_type: ProtocolType, +) -> serde_json::Result> { let message = serde_json::to_vec(value)?; Ok(encode_raw_bcast(message, protocol_type)) } From 20ab5f1c0ed706a499cf9956b6fd29e24122fbe2 Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Tue, 27 Aug 2024 01:06:58 +0200 Subject: [PATCH 6/8] test: replace indices vector with hashmap --- src/protocol/elgamal.rs | 12 +++++++++--- src/protocol/frost.rs | 14 +++++++++----- src/protocol/gg18.rs | 12 +++++++++--- src/protocol/mod.rs | 28 +++++++++++++--------------- 4 files changed, 40 insertions(+), 26 deletions(-) diff --git a/src/protocol/elgamal.rs b/src/protocol/elgamal.rs index 7439841..0b91638 100644 --- a/src/protocol/elgamal.rs +++ b/src/protocol/elgamal.rs @@ -348,6 +348,8 @@ mod tests { let (pks, _) = ::run(threshold as u32, parties as u32); + let pks: Vec<_> = pks.into_values().collect(); + for i in 1..parties { assert_eq!(pks[0], pks[i]) } @@ -361,13 +363,17 @@ mod tests { for parties in threshold..6 { let (pks, ctxs) = ::run(threshold as u32, parties as u32); + let pks: Vec<_> = pks.into_values().collect(); let msg = b"hello"; let ct = encrypt(msg, &pks[0]).unwrap(); - let mut indices = (0..parties as u16).choose_multiple(&mut OsRng, threshold); - indices.sort(); + let ctxs = ctxs + .into_iter() + .choose_multiple(&mut OsRng, threshold) + .into_iter() + .collect(); let results = - ::run(ctxs, indices, ct.to_vec()); + ::run(ctxs, ct.to_vec()); for result in results { assert_eq!(&msg.to_vec(), &result); diff --git a/src/protocol/frost.rs b/src/protocol/frost.rs index 658f6b6..ebb9961 100644 --- a/src/protocol/frost.rs +++ b/src/protocol/frost.rs @@ -369,7 +369,7 @@ mod tests { let pks: Vec = pks .iter() - .map(|x| serde_json::from_slice(&x).unwrap()) + .map(|(_, x)| serde_json::from_slice(&x).unwrap()) .collect(); for i in 1..parties { @@ -386,12 +386,16 @@ mod tests { let (pks, ctxs) = ::run(threshold as u32, parties as u32); let msg = b"hello"; - let pk: VerifyingKey = serde_json::from_slice(&pks[0]).unwrap(); + let (_, pk) = pks.iter().take(1).collect::>()[0]; + let pk: VerifyingKey = serde_json::from_slice(&pk).unwrap(); - let mut indices = (0..parties as u16).choose_multiple(&mut OsRng, threshold); - indices.sort(); + let ctxs = ctxs + .into_iter() + .choose_multiple(&mut OsRng, threshold) + .into_iter() + .collect(); let results = - ::run(ctxs, indices, msg.to_vec()); + ::run(ctxs, msg.to_vec()); let signature: Signature = serde_json::from_slice(&results[0]).unwrap(); diff --git a/src/protocol/gg18.rs b/src/protocol/gg18.rs index 630aa89..30706bf 100644 --- a/src/protocol/gg18.rs +++ b/src/protocol/gg18.rs @@ -288,6 +288,8 @@ mod tests { let (pks, _) = ::run(threshold as u32, parties as u32); + let pks: Vec<_> = pks.into_values().collect(); + for i in 1..parties { assert_eq!(pks[0], pks[i]) } @@ -304,12 +306,16 @@ mod tests { let msg = b"hello"; let dgst = sha2::Sha256::digest(msg); + let pks: Vec<_> = pks.into_values().collect(); let pk = VerifyingKey::from_sec1_bytes(&pks[0]).unwrap(); - let mut indices = (0..parties as u16).choose_multiple(&mut OsRng, threshold); - indices.sort(); + let ctxs = ctxs + .into_iter() + .choose_multiple(&mut OsRng, threshold) + .into_iter() + .collect(); let results = - ::run(ctxs, indices, dgst.to_vec()); + ::run(ctxs, dgst.to_vec()); let signature = results[0].clone(); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 6d0bbc9..8faa861 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -103,7 +103,7 @@ mod tests { const ROUNDS: usize; const INDEX_OFFSET: u32 = 0; - fn run(threshold: u32, parties: u32) -> (Vec>, Vec>) { + fn run(threshold: u32, parties: u32) -> (HashMap>, HashMap>) { assert!(threshold <= parties); // initialize @@ -172,17 +172,15 @@ mod tests { .collect(); } - let pks: Vec<_> = messages - .iter() - .map(|(_, msgs)| msgs.broadcast.as_ref().unwrap().clone()) + let pks = messages + .into_iter() + .map(|(i, msgs)| (i, msgs.broadcast.unwrap())) .collect(); - let mut results: Vec<_> = ctxs + let results = ctxs .into_iter() .map(|(i, ctx)| (i, Box::new(ctx).finish().unwrap())) .collect(); - results.sort_by_key(|(i, _)| *i); - let results = results.into_iter().map(|(_, ctx)| ctx).collect(); (pks, results) } @@ -194,13 +192,16 @@ mod tests { const ROUNDS: usize; const INDEX_OFFSET: u32 = 0; - fn run(ctxs: Vec>, indices: Vec, data: Vec) -> Vec> { + fn run(ctxs: HashMap>, data: Vec) -> Vec> { // initialize - let mut ctxs: HashMap = indices - .iter() - .map(|&i| (i as u32 + Self::INDEX_OFFSET, Self::new(&ctxs[i as usize]))) + let mut ctxs: HashMap = ctxs + .into_iter() + .map(|(i, ctx)| (i, Self::new(&ctx))) .collect(); + let mut indices: Vec<_> = ctxs.keys().cloned().collect(); + indices.sort(); + let mut messages: HashMap = ctxs .iter_mut() .map(|(&index, ctx)| { @@ -208,10 +209,7 @@ mod tests { ctx.advance( &(ProtocolInit { protocol_type: Self::PROTOCOL_TYPE as i32, - indices: indices - .iter() - .map(|x| *x as u32 + Self::INDEX_OFFSET) - .collect(), + indices: indices.clone(), index, data: data.clone(), }) From 404cf6afa6d7c4bbd6d4085cec0cc22b6667d184 Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Thu, 5 Sep 2024 00:22:32 +0200 Subject: [PATCH 7/8] refactor: more flexible iterator usage --- src/protocol/frost.rs | 24 +++++++++++------------- src/protocol/mod.rs | 3 ++- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/protocol/frost.rs b/src/protocol/frost.rs index ebb9961..fdcae92 100644 --- a/src/protocol/frost.rs +++ b/src/protocol/frost.rs @@ -21,14 +21,10 @@ struct Setup { index: u16, } -fn map_index_identifier( - kvs: impl Iterator, -) -> impl Iterator { - kvs.map(|(i, x)| { - assert!(i > 0); - assert!(i <= u16::MAX as u32); - (Identifier::try_from(i as u16).unwrap(), x) - }) +fn index_to_identifier((i, x): (u32, T)) -> (Identifier, T) { + assert!(i > 0); + assert!(i <= u16::MAX as u32); + (Identifier::try_from(i as u16).unwrap(), x) } #[derive(Serialize, Deserialize)] @@ -89,7 +85,7 @@ impl KeygenContext { let data = ServerMessage::decode(data)?.broadcasts; let round1 = deserialize_map(&data)?; let indices: Vec<_> = round1.keys().cloned().collect(); - let round1 = map_index_identifier(round1.into_iter()).collect(); + let round1 = round1.into_iter().map(index_to_identifier).collect(); let (secret, round2) = dkg::part2(secret.clone(), &round1)?; let round2 = indices.into_iter().map(|i| { @@ -106,7 +102,7 @@ impl KeygenContext { KeygenRound::R2(setup, secret, round1) => { let data = ServerMessage::decode(data)?.unicasts; let round2 = deserialize_map(&data)?; - let round2 = map_index_identifier(round2.into_iter()).collect(); + let round2 = round2.into_iter().map(index_to_identifier).collect(); let (key, pubkey) = frost::keys::dkg::part3(secret, round1, &round2)?; if !self.with_card { @@ -235,8 +231,10 @@ impl SignContext { SignRound::R1(nonces, commitments) => { let data = ServerMessage::decode(data)?.broadcasts; let commitments_map = deserialize_map(&data)?; - let mut commitments_map: BTreeMap = - map_index_identifier(commitments_map.into_iter()).collect(); + let mut commitments_map: BTreeMap = commitments_map + .into_iter() + .map(index_to_identifier) + .collect(); let identifier = Identifier::try_from(self.setup.index).unwrap(); commitments_map.insert(identifier, *commitments); @@ -294,7 +292,7 @@ impl SignContext { let data = ServerMessage::decode(data)?.broadcasts; let shares = deserialize_map(&data)?; let mut shares: BTreeMap = - map_index_identifier(shares.into_iter()).collect(); + shares.into_iter().map(index_to_identifier).collect(); let identifier = Identifier::try_from(self.setup.index).unwrap(); shares.insert(identifier, *share); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8faa861..2b4f799 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -77,10 +77,11 @@ fn encode_raw_uni(messages: HashMap>, protocol_type: ProtocolType) /// Serialize and encode a map of unicast messages fn serialize_uni(kvs: I, protocol_type: ProtocolType) -> serde_json::Result> where - I: Iterator, + I: IntoIterator, T: Serialize, { let messages = kvs + .into_iter() .map(|(k, v)| Ok((k, serde_json::to_vec(&v)?))) .collect::>()?; Ok(encode_raw_uni(messages, protocol_type)) From 8e3e61d53ef49aaaa3148a14e2b680e330b983b7 Mon Sep 17 00:00:00 2001 From: Marek Mracna Date: Thu, 5 Sep 2024 00:24:40 +0200 Subject: [PATCH 8/8] docs: reword and annotate a few doccomments --- src/protocol/frost.rs | 1 + src/protocol/gg18.rs | 1 + src/protocol/mod.rs | 13 +++++++++---- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/protocol/frost.rs b/src/protocol/frost.rs index fdcae92..ff9f98f 100644 --- a/src/protocol/frost.rs +++ b/src/protocol/frost.rs @@ -21,6 +21,7 @@ struct Setup { index: u16, } +/// Helper intended for use in `iterator.map` fn index_to_identifier((i, x): (u32, T)) -> (Identifier, T) { assert!(i > 0); assert!(i <= u16::MAX as u32); diff --git a/src/protocol/gg18.rs b/src/protocol/gg18.rs index 30706bf..6638d8a 100644 --- a/src/protocol/gg18.rs +++ b/src/protocol/gg18.rs @@ -21,6 +21,7 @@ enum KeygenRound { Done(GG18SignContext), } +/// Collects a hashmap's values sorted by their respective keys fn map_to_sorted_vec(map: HashMap) -> Vec { let mut vec: Vec<_> = map.into_iter().collect(); vec.sort_by_key(|(i, _)| *i); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 2b4f799..bf553d9 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -37,6 +37,7 @@ pub trait ThresholdProtocol: Protocol { Self: Sized; } +/// Deserializes values in a `HashMap` fn deserialize_map<'de, T: Deserialize<'de>>( map: &'de HashMap>, ) -> serde_json::Result> { @@ -45,7 +46,7 @@ fn deserialize_map<'de, T: Deserialize<'de>>( .collect() } -/// Encode a broadcast message +/// Encode a broadcast message to protobuf format fn encode_raw_bcast(message: Vec, protocol_type: ProtocolType) -> Vec { ClientMessage { protocol_type: protocol_type.into(), @@ -55,7 +56,7 @@ fn encode_raw_bcast(message: Vec, protocol_type: ProtocolType) -> Vec { .encode_to_vec() } -/// Serialize and encode a broadcast message +/// Serialize and encode a broadcast message to protobuf format fn serialize_bcast( value: &T, protocol_type: ProtocolType, @@ -64,7 +65,9 @@ fn serialize_bcast( Ok(encode_raw_bcast(message, protocol_type)) } -/// Encode a Vec of unicast messages +/// Encode unicast messages to protobuf format +/// +/// Each message is associated with an index as used by a respective protocol fn encode_raw_uni(messages: HashMap>, protocol_type: ProtocolType) -> Vec { ClientMessage { protocol_type: protocol_type.into(), @@ -74,7 +77,9 @@ fn encode_raw_uni(messages: HashMap>, protocol_type: ProtocolType) .encode_to_vec() } -/// Serialize and encode a map of unicast messages +/// Serialize and encode unicast messages to protobuf format +/// +/// Each message is associated with an index as used by a respective protocol fn serialize_uni(kvs: I, protocol_type: ProtocolType) -> serde_json::Result> where I: IntoIterator,