diff --git a/Cargo.toml b/Cargo.toml index b009fd5..00cbc01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "meesign-crypto" -version = "0.2.0" +version = "0.3.0" edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/proto/meesign.proto b/proto/meesign.proto index a08365a..7314a17 100644 --- a/proto/meesign.proto +++ b/proto/meesign.proto @@ -21,7 +21,14 @@ message ProtocolInit { bytes data = 4; } -message ProtocolMessage { +message ClientMessage { ProtocolType protocol_type = 1; - repeated bytes message = 2; + map unicasts = 2; + optional bytes broadcast = 3; +} + +message ServerMessage { + ProtocolType protocol_type = 1; + map unicasts = 2; + map broadcasts = 3; } diff --git a/src/protocol/elgamal.rs b/src/protocol/elgamal.rs index 5a42b19..0b91638 100644 --- a/src/protocol/elgamal.rs +++ b/src/protocol/elgamal.rs @@ -1,4 +1,4 @@ -use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType}; +use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType, ServerMessage}; use crate::protocol::*; use curve25519_dalek::{ ristretto::{CompressedRistretto, RistrettoPoint}, @@ -19,6 +19,8 @@ use aes_gcm::{ use prost::Message; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + #[derive(Serialize, Deserialize)] pub(crate) struct KeygenContext { round: KeygenRound, @@ -27,9 +29,9 @@ pub(crate) struct KeygenContext { #[derive(Serialize, Deserialize)] enum KeygenRound { R0, - R1(ParticipantCollectingCommitments, u16), - R2(ParticipantCollectingPolynomials, u16), - R3(ParticipantExchangingSecrets, u16), + R1(ParticipantCollectingCommitments), + R2(ParticipantCollectingPolynomials), + R3(ParticipantExchangingSecrets), Done(ActiveParticipant), } @@ -49,83 +51,73 @@ impl KeygenContext { let dkg = ParticipantCollectingCommitments::::new(params, index.into(), &mut OsRng); let c = dkg.commitment(); - let ser = serialize_bcast(&c, msg.parties as usize - 1)?; - self.round = KeygenRound::R1(dkg, index); - Ok(pack(ser, ProtocolType::Elgamal)) + let msg = serialize_bcast(&c, ProtocolType::Elgamal)?; + self.round = KeygenRound::R1(dkg); + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = unpack(data)?; - let n = msgs.len(); + let msgs = ServerMessage::decode(data)?; - let (c, ser) = match &self.round { + let (c, msg) = match &self.round { KeygenRound::R0 => return Err("protocol not initialized".into()), - KeygenRound::R1(dkg, idx) => { + KeygenRound::R1(dkg) => { let mut dkg = dkg.clone(); - let data = deserialize_vec(&msgs)?; - for (mut i, msg) in data.into_iter().enumerate() { - if i >= *idx as usize { - i += 1; - } - dkg.insert_commitment(i, msg); + let data = deserialize_map(&msgs.broadcasts)?; + for (i, msg) in data { + dkg.insert_commitment(i as usize, msg); } if dkg.missing_commitments().next().is_some() { return Err("not enough commitments".into()); } let dkg = dkg.finish_commitment_phase(); let public_info = dkg.public_info(); - let ser = serialize_bcast(&public_info, n)?; + let msg = serialize_bcast(&public_info, ProtocolType::Elgamal)?; - (KeygenRound::R2(dkg, *idx), ser) + (KeygenRound::R2(dkg), msg) } - KeygenRound::R2(dkg, idx) => { + KeygenRound::R2(dkg) => { let mut dkg = dkg.clone(); - let data = deserialize_vec(&msgs)?; - for (mut i, msg) in data.into_iter().enumerate() { - if i >= *idx as usize { - i += 1; - } - dkg.insert_public_polynomial(i, msg)? + let data = deserialize_map(&msgs.broadcasts)?; + for (i, msg) in data { + dkg.insert_public_polynomial(i as usize, msg)? } if dkg.missing_public_polynomials().next().is_some() { return Err("not enough polynomials".into()); } let dkg = dkg.finish_polynomials_phase(); - let mut shares = Vec::new(); - for mut i in 0..n { - if i >= *idx as usize { - i += 1; - } - let secret_share = dkg.secret_share_for_participant(i); - shares.push(secret_share); - } - let ser = serialize_uni(shares)?; + let shares = msgs + .broadcasts + .into_keys() + .map(|i| (i, dkg.secret_share_for_participant(i as usize))); + + let msg = serialize_uni(shares, ProtocolType::Elgamal)?; - (KeygenRound::R3(dkg, *idx), ser) + (KeygenRound::R3(dkg), msg) } - KeygenRound::R3(dkg, idx) => { + KeygenRound::R3(dkg) => { let mut dkg = dkg.clone(); - let data = deserialize_vec(&msgs)?; - for (mut i, msg) in data.into_iter().enumerate() { - if i >= *idx as usize { - i += 1; - } - dkg.insert_secret_share(i, msg)?; + let data = deserialize_map(&msgs.unicasts)?; + for (i, msg) in data { + dkg.insert_secret_share(i as usize, msg)?; } if dkg.missing_shares().next().is_some() { return Err("not enough shares".into()); } let dkg = dkg.complete()?; - let ser = inflate(dkg.key_set().shared_key().as_bytes().to_vec(), n); - (KeygenRound::Done(dkg), ser) + let msg = encode_raw_bcast( + dkg.key_set().shared_key().as_bytes().to_vec(), + ProtocolType::Elgamal, + ); + (KeygenRound::Done(dkg), msg) } KeygenRound::Done(_) => return Err("protocol already finished".into()), }; self.round = c; - Ok(pack(ser, ProtocolType::Elgamal)) + Ok(msg) } } @@ -160,7 +152,6 @@ pub(crate) struct DecryptContext { ctx: ActiveParticipant, encrypted_key: Ciphertext, data: (Vec, Vec, Vec), - indices: Vec, shares: Vec<(usize, VerifiableDecryption)>, result: Option>, } @@ -173,21 +164,20 @@ impl DecryptContext { return Err("wrong protocol type".into()); } - self.indices = msg.indices.clone().into_iter().map(|i| i as u16).collect(); self.data = serde_json::from_slice(&msg.data)?; self.encrypted_key = serde_json::from_slice(&self.data.0)?; let (share, proof) = self.ctx.decrypt_share(self.encrypted_key, &mut OsRng); - let ser = serialize_bcast( + let msg = serialize_bcast( &serde_json::to_string(&(share, proof))?.as_bytes(), - self.indices.len() - 1, + ProtocolType::Elgamal, )?; let share = (self.ctx.index(), share); self.shares.push(share); - Ok(pack(ser, ProtocolType::Elgamal)) + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { @@ -198,32 +188,17 @@ impl DecryptContext { return Err("protocol already finished".into()); } - let msgs = unpack(data)?; + let msgs = ServerMessage::decode(data)?; - let data: Vec> = deserialize_vec(&msgs)?; - let local_index = self - .indices - .iter() - .position(|x| *x as usize == self.ctx.index()) - .ok_or("participant index not included")?; - assert_eq!(self.ctx.index(), self.indices[local_index] as usize); - - for (mut i, msg) in data.into_iter().enumerate() { - if i >= local_index { - i += 1; - } + let data: HashMap> = deserialize_map(&msgs.broadcasts)?; + for (i, msg) in data { let msg: (VerifiableDecryption, LogEqualityProof) = serde_json::from_slice(&msg)?; self.ctx .key_set() - .verify_share( - msg.0.into(), - self.encrypted_key, - self.indices[i].into(), - &msg.1, - ) + .verify_share(msg.0.into(), self.encrypted_key, i as usize, &msg.1) .unwrap(); - self.shares.push((self.indices[i].into(), msg.0)); + self.shares.push((i as usize, msg.0)); } let mut key = [0u8; 16]; @@ -254,8 +229,8 @@ impl DecryptContext { self.result = Some(msg.clone()); - let ser = inflate(msg, self.indices.len() - 1); - Ok(pack(ser, ProtocolType::Elgamal)) + let msg = encode_raw_bcast(msg, ProtocolType::Elgamal); + Ok(msg) } } @@ -284,7 +259,6 @@ impl ThresholdProtocol for DecryptContext { ctx: serde_json::from_slice(group).expect("could not deserialize group context"), encrypted_key: Ciphertext::zero(), data: (Vec::new(), Vec::new(), Vec::new()), - indices: Vec::new(), shares: Vec::new(), result: None, } @@ -374,6 +348,8 @@ mod tests { let (pks, _) = ::run(threshold as u32, parties as u32); + let pks: Vec<_> = pks.into_values().collect(); + for i in 1..parties { assert_eq!(pks[0], pks[i]) } @@ -387,13 +363,17 @@ mod tests { for parties in threshold..6 { let (pks, ctxs) = ::run(threshold as u32, parties as u32); + let pks: Vec<_> = pks.into_values().collect(); let msg = b"hello"; let ct = encrypt(msg, &pks[0]).unwrap(); - let mut indices = (0..parties as u16).choose_multiple(&mut OsRng, threshold); - indices.sort(); + let ctxs = ctxs + .into_iter() + .choose_multiple(&mut OsRng, threshold) + .into_iter() + .collect(); let results = - ::run(ctxs, indices, ct.to_vec()); + ::run(ctxs, ct.to_vec()); for result in results { assert_eq!(&msg.to_vec(), &result); diff --git a/src/protocol/frost.rs b/src/protocol/frost.rs index 68928fd..ff9f98f 100644 --- a/src/protocol/frost.rs +++ b/src/protocol/frost.rs @@ -1,4 +1,4 @@ -use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType}; +use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType, ServerMessage}; use crate::protocol::*; use frost::keys::dkg::{self, round1, round2}; @@ -10,7 +10,6 @@ use prost::Message; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::convert::{TryFrom, TryInto}; -use std::iter::FromIterator; use frost_secp256k1 as frost; use rand::rngs::OsRng; @@ -22,23 +21,11 @@ struct Setup { index: u16, } -fn map_share_vec(vec: Vec, indices: &[u16], index: u16) -> Result -where - C: FromIterator<(Identifier, T)>, -{ - let pos = indices - .iter() - .position(|&x| x == index) - .ok_or("missing index")?; - let collection = vec - .into_iter() - .enumerate() - .map(move |(i, item)| { - let index = indices[if i < pos { i } else { i + 1 }]; - (Identifier::try_from(index).unwrap(), item) - }) - .collect(); - Ok(collection) +/// Helper intended for use in `iterator.map` +fn index_to_identifier((i, x): (u32, T)) -> (Identifier, T) { + assert!(i > 0); + assert!(i <= u16::MAX as u32); + (Identifier::try_from(i as u16).unwrap(), x) } #[derive(Serialize, Deserialize)] @@ -87,38 +74,43 @@ impl KeygenContext { OsRng, )?; - let msgs = serialize_bcast(&public_package, (setup.parties - 1) as usize)?; + let msg = serialize_bcast(&public_package, ProtocolType::Frost)?; self.round = KeygenRound::R1(setup, secret_package); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } fn update(&mut self, data: &[u8]) -> Result<(Vec, Recipient)> { let (c, data, rec) = match &self.round { KeygenRound::R0 => return Err("protocol not initialized".into()), KeygenRound::R1(setup, secret) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; - let round1 = map_share_vec(data, &Vec::from_iter(1..=setup.parties), setup.index)?; + let data = ServerMessage::decode(data)?.broadcasts; + let round1 = deserialize_map(&data)?; + let indices: Vec<_> = round1.keys().cloned().collect(); + let round1 = round1.into_iter().map(index_to_identifier).collect(); let (secret, round2) = dkg::part2(secret.clone(), &round1)?; - let mut round2: Vec<_> = round2.into_iter().collect(); - round2.sort_by_key(|(i, _)| *i); - let round2: Vec<_> = round2.into_iter().map(|(_, p)| p).collect(); + + let round2 = indices.into_iter().map(|i| { + let id = Identifier::try_from(i as u16).unwrap(); + (i, round2.get(&id)) + }); ( KeygenRound::R2(*setup, secret, round1), - pack(serialize_uni(round2)?, ProtocolType::Frost), + serialize_uni(round2, ProtocolType::Frost)?, Recipient::Server, ) } KeygenRound::R2(setup, secret, round1) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; - let round2 = map_share_vec(data, &Vec::from_iter(1..=setup.parties), setup.index)?; + let data = ServerMessage::decode(data)?.unicasts; + let round2 = deserialize_map(&data)?; + let round2 = round2.into_iter().map(index_to_identifier).collect(); let (key, pubkey) = frost::keys::dkg::part3(secret, round1, &round2)?; if !self.with_card { - let msgs = inflate(serde_json::to_vec(&pubkey.verifying_key())?, round2.len()); + let msgs = serialize_bcast(&pubkey.verifying_key(), ProtocolType::Frost)?; ( KeygenRound::Done(*setup, Some(key), pubkey), - pack(msgs, ProtocolType::Frost), + msgs, Recipient::Server, ) } else { @@ -138,13 +130,10 @@ impl KeygenContext { } KeygenRound::R21AwaitSetupResp(setup, pubkey) => { jc::response::setup(data)?; - let msgs = inflate( - serde_json::to_vec(&pubkey.verifying_key())?, - (setup.parties - 1) as usize, - ); + let msg = serialize_bcast(&pubkey.verifying_key(), ProtocolType::Frost)?; ( KeygenRound::Done(*setup, None, pubkey.clone()), - pack(msgs, ProtocolType::Frost), + msg, Recipient::Server, ) } @@ -222,9 +211,9 @@ impl SignContext { if let Some(key) = &self.key { let (nonces, commitments) = frost::round1::commit(key.signing_share(), &mut OsRng); - let msgs = serialize_bcast(&commitments, self.participants() - 1)?; + let msg = serialize_bcast(&commitments, ProtocolType::Frost)?; self.round = SignRound::R1(Some(nonces), commitments); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } else { self.round = SignRound::R01AwaitCommitResp; Ok((jc::command::commit(), Recipient::Card)) @@ -236,15 +225,17 @@ impl SignContext { SignRound::R0 => Err("protocol not initialized".into()), SignRound::R01AwaitCommitResp => { let commitments = jc::response::commit(data)?; - let msgs = serialize_bcast(&commitments, self.participants() - 1)?; + let msg = serialize_bcast(&commitments, ProtocolType::Frost)?; self.round = SignRound::R1(None, commitments); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } SignRound::R1(nonces, commitments) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; - - let mut commitments_map: BTreeMap = - map_share_vec(data, self.indices.as_deref().unwrap(), self.setup.index)?; + let data = ServerMessage::decode(data)?.broadcasts; + let commitments_map = deserialize_map(&data)?; + let mut commitments_map: BTreeMap = commitments_map + .into_iter() + .map(index_to_identifier) + .collect(); let identifier = Identifier::try_from(self.setup.index).unwrap(); commitments_map.insert(identifier, *commitments); @@ -254,9 +245,9 @@ impl SignContext { if let Some(key) = &self.key { let share = frost::round2::sign(&signing_package, nonces.as_ref().unwrap(), key)?; - let msgs = serialize_bcast(&share, self.participants() - 1)?; + let msg = serialize_bcast(&share, ProtocolType::Frost)?; self.round = SignRound::R2(signing_package, share); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } else { let index = self.indices.as_deref().unwrap()[0]; let command = jc::command::commitment( @@ -294,23 +285,23 @@ impl SignContext { } SignRound::R12AwaitSignResp(signing_package) => { let share = jc::response::sign(data)?; - let msgs = serialize_bcast(&share, self.participants() - 1)?; + let msg = serialize_bcast(&share, ProtocolType::Frost)?; self.round = SignRound::R2(signing_package.clone(), share); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } SignRound::R2(signing_package, share) => { - let data: Vec = deserialize_vec(&unpack(data)?)?; - + let data = ServerMessage::decode(data)?.broadcasts; + let shares = deserialize_map(&data)?; let mut shares: BTreeMap = - map_share_vec(data, self.indices.as_deref().unwrap(), self.setup.index)?; + shares.into_iter().map(index_to_identifier).collect(); let identifier = Identifier::try_from(self.setup.index).unwrap(); shares.insert(identifier, *share); let signature = frost::aggregate(signing_package, &shares, &self.pubkey)?; - let msgs = serialize_bcast(&signature, self.participants() - 1)?; + let msg = serialize_bcast(&signature, ProtocolType::Frost)?; self.round = SignRound::Done(signature); - Ok((pack(msgs, ProtocolType::Frost), Recipient::Server)) + Ok((msg, Recipient::Server)) } SignRound::Done(_) => Err("protocol already finished".into()), } @@ -377,7 +368,7 @@ mod tests { let pks: Vec = pks .iter() - .map(|x| serde_json::from_slice(&x).unwrap()) + .map(|(_, x)| serde_json::from_slice(&x).unwrap()) .collect(); for i in 1..parties { @@ -394,12 +385,16 @@ mod tests { let (pks, ctxs) = ::run(threshold as u32, parties as u32); let msg = b"hello"; - let pk: VerifyingKey = serde_json::from_slice(&pks[0]).unwrap(); + let (_, pk) = pks.iter().take(1).collect::>()[0]; + let pk: VerifyingKey = serde_json::from_slice(&pk).unwrap(); - let mut indices = (0..parties as u16).choose_multiple(&mut OsRng, threshold); - indices.sort(); + let ctxs = ctxs + .into_iter() + .choose_multiple(&mut OsRng, threshold) + .into_iter() + .collect(); let results = - ::run(ctxs, indices, msg.to_vec()); + ::run(ctxs, msg.to_vec()); let signature: Signature = serde_json::from_slice(&results[0]).unwrap(); diff --git a/src/protocol/gg18.rs b/src/protocol/gg18.rs index 7d8287e..6638d8a 100644 --- a/src/protocol/gg18.rs +++ b/src/protocol/gg18.rs @@ -1,4 +1,4 @@ -use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType}; +use crate::proto::{ProtocolGroupInit, ProtocolInit, ProtocolType, ServerMessage}; use crate::protocol::*; use mpecdsa::{gg18_key_gen::*, gg18_sign::*}; use prost::Message; @@ -21,6 +21,13 @@ enum KeygenRound { Done(GG18SignContext), } +/// Collects a hashmap's values sorted by their respective keys +fn map_to_sorted_vec(map: HashMap) -> Vec { + let mut vec: Vec<_> = map.into_iter().collect(); + vec.sort_by_key(|(i, _)| *i); + vec.into_iter().map(|(_, x)| x).collect() +} + impl KeygenContext { fn init(&mut self, data: &[u8]) -> Result> { let msg = ProtocolGroupInit::decode(data)?; @@ -29,47 +36,60 @@ impl KeygenContext { (msg.parties as u16, msg.threshold as u16, msg.index as u16); let (out, c1) = gg18_key_gen_1(parties, threshold, index)?; - let ser = serialize_bcast(&out, msg.parties as usize - 1)?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; self.round = KeygenRound::R1(c1); - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = unpack(data)?; - let n = msgs.len(); + let data = ServerMessage::decode(data)?; - let (c, ser) = match &self.round { + let (c, msg) = match &self.round { KeygenRound::R0 => unreachable!(), KeygenRound::R1(c1) => { - let (out, c2) = gg18_key_gen_2(deserialize_vec(&msgs)?, c1.clone())?; - let ser = serialize_bcast(&out, n)?; - (KeygenRound::R2(c2), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c2) = gg18_key_gen_2(msgs, c1.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (KeygenRound::R2(c2), msg) } KeygenRound::R2(c2) => { - let (outs, c3) = gg18_key_gen_3(deserialize_vec(&msgs)?, c2.clone())?; - let ser = serialize_uni(outs)?; - (KeygenRound::R3(c3), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (outs, c3) = gg18_key_gen_3(msgs, c2.clone())?; + + let mut indices: Vec<_> = data.broadcasts.into_keys().collect(); + indices.sort(); + let outs = indices.into_iter().zip(outs.into_iter()); + let msg = serialize_uni(outs, ProtocolType::Gg18)?; + (KeygenRound::R3(c3), msg) } KeygenRound::R3(c3) => { - let (out, c4) = gg18_key_gen_4(deserialize_vec(&msgs)?, c3.clone())?; - let ser = serialize_bcast(&out, n)?; - (KeygenRound::R4(c4), ser) + let msgs = deserialize_map(&data.unicasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c4) = gg18_key_gen_4(msgs, c3.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (KeygenRound::R4(c4), msg) } KeygenRound::R4(c4) => { - let (out, c5) = gg18_key_gen_5(deserialize_vec(&msgs)?, c4.clone())?; - let ser = serialize_bcast(&out, n)?; - (KeygenRound::R5(c5), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c5) = gg18_key_gen_5(msgs, c4.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (KeygenRound::R5(c5), msg) } KeygenRound::R5(c5) => { - let c = gg18_key_gen_6(deserialize_vec(&msgs)?, c5.clone())?; - let ser = inflate(c.pk.to_bytes(false).to_vec(), n); - (KeygenRound::Done(c), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let c = gg18_key_gen_6(msgs, c5.clone())?; + let msg = encode_raw_bcast(c.pk.to_bytes(false).to_vec(), ProtocolType::Gg18); + (KeygenRound::Done(c), msg) } KeygenRound::Done(_) => todo!(), }; self.round = c; - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } } @@ -124,7 +144,6 @@ impl SignContext { let msg = ProtocolInit::decode(data)?; let indices: Vec = msg.indices.clone().into_iter().map(|i| i as u16).collect(); - let parties = indices.len(); let local_index = indices.iter().position(|&i| i == msg.index as u16).unwrap(); let c0 = match &self.round { @@ -133,67 +152,88 @@ impl SignContext { }; let (out, c1) = gg18_sign1(c0, indices, local_index, msg.data)?; - let ser = serialize_bcast(&out, parties - 1)?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; self.round = SignRound::R1(c1); - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } fn update(&mut self, data: &[u8]) -> Result> { - let msgs = unpack(data)?; - let n = msgs.len(); + let data = ServerMessage::decode(data)?; - let (c, ser) = match &self.round { + let (c, msg) = match &self.round { SignRound::R0(_) => unreachable!(), SignRound::R1(c1) => { - let (outs, c2) = gg18_sign2(deserialize_vec(&msgs)?, c1.clone())?; - let ser = serialize_uni(outs)?; - (SignRound::R2(c2), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (outs, c2) = gg18_sign2(msgs, c1.clone())?; + + let mut indices: Vec<_> = data.broadcasts.into_keys().collect(); + indices.sort(); + let outs = indices.into_iter().zip(outs.into_iter()); + let msg = serialize_uni(outs, ProtocolType::Gg18)?; + (SignRound::R2(c2), msg) } SignRound::R2(c2) => { - let (out, c3) = gg18_sign3(deserialize_vec(&msgs)?, c2.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R3(c3), ser) + let msgs = deserialize_map(&data.unicasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c3) = gg18_sign3(msgs, c2.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R3(c3), msg) } SignRound::R3(c3) => { - let (out, c4) = gg18_sign4(deserialize_vec(&msgs)?, c3.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R4(c4), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c4) = gg18_sign4(msgs, c3.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R4(c4), msg) } SignRound::R4(c4) => { - let (out, c5) = gg18_sign5(deserialize_vec(&msgs)?, c4.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R5(c5), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c5) = gg18_sign5(msgs, c4.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R5(c5), msg) } SignRound::R5(c5) => { - let (out, c6) = gg18_sign6(deserialize_vec(&msgs)?, c5.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R6(c6), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c6) = gg18_sign6(msgs, c5.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R6(c6), msg) } SignRound::R6(c6) => { - let (out, c7) = gg18_sign7(deserialize_vec(&msgs)?, c6.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R7(c7), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c7) = gg18_sign7(msgs, c6.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R7(c7), msg) } SignRound::R7(c7) => { - let (out, c8) = gg18_sign8(deserialize_vec(&msgs)?, c7.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R8(c8), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c8) = gg18_sign8(msgs, c7.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R8(c8), msg) } SignRound::R8(c8) => { - let (out, c9) = gg18_sign9(deserialize_vec(&msgs)?, c8.clone())?; - let ser = serialize_bcast(&out, n)?; - (SignRound::R9(c9), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let (out, c9) = gg18_sign9(msgs, c8.clone())?; + let msg = serialize_bcast(&out, ProtocolType::Gg18)?; + (SignRound::R9(c9), msg) } SignRound::R9(c9) => { - let sig = gg18_sign10(deserialize_vec(&msgs)?, c9.clone())?; - let ser = inflate(sig.clone(), n); - (SignRound::Done(sig), ser) + let msgs = deserialize_map(&data.broadcasts)?; + let msgs = map_to_sorted_vec(msgs); + let sig = gg18_sign10(msgs, c9.clone())?; + let msg = encode_raw_bcast(sig.clone(), ProtocolType::Gg18); + (SignRound::Done(sig), msg) } SignRound::Done(_) => todo!(), }; self.round = c; - Ok(pack(ser, ProtocolType::Gg18)) + Ok(msg) } } @@ -249,6 +289,8 @@ mod tests { let (pks, _) = ::run(threshold as u32, parties as u32); + let pks: Vec<_> = pks.into_values().collect(); + for i in 1..parties { assert_eq!(pks[0], pks[i]) } @@ -265,12 +307,16 @@ mod tests { let msg = b"hello"; let dgst = sha2::Sha256::digest(msg); + let pks: Vec<_> = pks.into_values().collect(); let pk = VerifyingKey::from_sec1_bytes(&pks[0]).unwrap(); - let mut indices = (0..parties as u16).choose_multiple(&mut OsRng, threshold); - indices.sort(); + let ctxs = ctxs + .into_iter() + .choose_multiple(&mut OsRng, threshold) + .into_iter() + .collect(); let results = - ::run(ctxs, indices, dgst.to_vec()); + ::run(ctxs, dgst.to_vec()); let signature = results[0].clone(); diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 29e25d4..bf553d9 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -9,9 +9,10 @@ mod apdu; pub type Result = std::result::Result>; -use crate::proto::{ProtocolMessage, ProtocolType}; +use crate::proto::{ClientMessage, ProtocolType}; use prost::Message; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; pub enum Recipient { Card, @@ -36,43 +37,61 @@ pub trait ThresholdProtocol: Protocol { Self: Sized; } -fn deserialize_vec<'de, T: Deserialize<'de>>(vec: &'de [Vec]) -> serde_json::Result> { - vec.iter() - .map(|item| serde_json::from_slice::(item)) +/// Deserializes values in a `HashMap` +fn deserialize_map<'de, T: Deserialize<'de>>( + map: &'de HashMap>, +) -> serde_json::Result> { + map.iter() + .map(|(k, v)| Ok((*k, serde_json::from_slice::(v.as_slice())?))) .collect() } -fn inflate(value: T, n: usize) -> Vec { - std::iter::repeat(value).take(n).collect() -} - -/// Serialize value and repeat the result n times, -/// as the current server always expects one message for each party -fn serialize_bcast(value: &T, n: usize) -> serde_json::Result>> { - let ser = serde_json::to_vec(value)?; - Ok(inflate(ser, n)) -} - -/// Serialize vector of unicast messages -fn serialize_uni(vec: Vec) -> serde_json::Result>> { - vec.iter().map(|item| serde_json::to_vec(item)).collect() +/// Encode a broadcast message to protobuf format +fn encode_raw_bcast(message: Vec, protocol_type: ProtocolType) -> Vec { + ClientMessage { + protocol_type: protocol_type.into(), + unicasts: HashMap::new(), + broadcast: Some(message), + } + .encode_to_vec() } -/// Decode a protobuf message from the server -fn unpack(data: &[u8]) -> std::result::Result>, prost::DecodeError> { - let msgs = ProtocolMessage::decode(data)?.message; - Ok(msgs) +/// Serialize and encode a broadcast message to protobuf format +fn serialize_bcast( + value: &T, + protocol_type: ProtocolType, +) -> serde_json::Result> { + let message = serde_json::to_vec(value)?; + Ok(encode_raw_bcast(message, protocol_type)) } -/// Encode msgs as a protobuf message for the server -fn pack(msgs: Vec>, protocol_type: ProtocolType) -> Vec { - ProtocolMessage { +/// Encode unicast messages to protobuf format +/// +/// Each message is associated with an index as used by a respective protocol +fn encode_raw_uni(messages: HashMap>, protocol_type: ProtocolType) -> Vec { + ClientMessage { protocol_type: protocol_type.into(), - message: msgs, + unicasts: messages, + broadcast: None, } .encode_to_vec() } +/// Serialize and encode unicast messages to protobuf format +/// +/// Each message is associated with an index as used by a respective protocol +fn serialize_uni(kvs: I, protocol_type: ProtocolType) -> serde_json::Result> +where + I: IntoIterator, + T: Serialize, +{ + let messages = kvs + .into_iter() + .map(|(k, v)| Ok((k, serde_json::to_vec(&v)?))) + .collect::>()?; + Ok(encode_raw_uni(messages, protocol_type)) +} + #[cfg(test)] mod tests { use super::*; @@ -80,7 +99,7 @@ mod tests { use prost::bytes::Bytes; use crate::{ - proto::{ProtocolGroupInit, ProtocolInit}, + proto::{ProtocolGroupInit, ProtocolInit, ServerMessage}, protocol::{KeygenProtocol, ThresholdProtocol}, }; @@ -90,20 +109,22 @@ mod tests { const ROUNDS: usize; const INDEX_OFFSET: u32 = 0; - fn run(threshold: u32, parties: u32) -> (Vec>, Vec>) { + fn run(threshold: u32, parties: u32) -> (HashMap>, HashMap>) { assert!(threshold <= parties); // initialize - let mut ctxs: Vec = (0..parties).map(|_| Self::new()).collect(); - let mut messages: Vec<_> = ctxs + let mut ctxs: HashMap = (0..parties) + .map(|i| (i as u32 + Self::INDEX_OFFSET, Self::new())) + .collect(); + + let mut messages: HashMap = ctxs .iter_mut() - .enumerate() - .map(|(idx, ctx)| { - ProtocolMessage::decode::( + .map(|(&index, ctx)| { + let msg = ClientMessage::decode::( ctx.advance( &(ProtocolGroupInit { protocol_type: Self::PROTOCOL_TYPE as i32, - index: idx as u32 + Self::INDEX_OFFSET, + index, parties, threshold, }) @@ -113,8 +134,8 @@ mod tests { .0 .into(), ) - .unwrap() - .message + .unwrap(); + (index, msg) }) .collect(); @@ -122,29 +143,28 @@ mod tests { for _ in 0..(Self::ROUNDS - 1) { messages = ctxs .iter_mut() - .enumerate() - .map(|(idx, ctx)| { - let relay = messages - .iter() - .enumerate() - .map(|(sender, msg)| { - if sender < idx { - Some(msg[idx - 1].clone()) - } else if sender > idx { - Some(msg[idx].clone()) - } else { - None - } - }) - .filter(Option::is_some) - .map(Option::unwrap) - .collect(); + .map(|(&idx, ctx)| { + let mut unicasts = HashMap::new(); + let mut broadcasts = HashMap::new(); - ProtocolMessage::decode::( + for (&sender, msg) in &messages { + if sender == idx { + continue; + } + if let Some(broadcast) = &msg.broadcast { + broadcasts.insert(sender, broadcast.clone()); + } + if let Some(unicast) = msg.unicasts.get(&idx) { + unicasts.insert(sender, unicast.clone()); + } + } + + let msg = ClientMessage::decode::( ctx.advance( - &(ProtocolMessage { + &(ServerMessage { protocol_type: Self::PROTOCOL_TYPE as i32, - message: relay, + unicasts, + broadcasts, }) .encode_to_vec(), ) @@ -152,17 +172,20 @@ mod tests { .0 .into(), ) - .unwrap() - .message + .unwrap(); + (idx, msg) }) .collect(); } - let pks: Vec<_> = messages.iter().map(|x| x[0].clone()).collect(); + let pks = messages + .into_iter() + .map(|(i, msgs)| (i, msgs.broadcast.unwrap())) + .collect(); let results = ctxs .into_iter() - .map(|ctx| Box::new(ctx).finish().unwrap()) + .map(|(i, ctx)| (i, Box::new(ctx).finish().unwrap())) .collect(); (pks, results) @@ -175,27 +198,25 @@ mod tests { const ROUNDS: usize; const INDEX_OFFSET: u32 = 0; - fn run(ctxs: Vec>, indices: Vec, data: Vec) -> Vec> { + fn run(ctxs: HashMap>, data: Vec) -> Vec> { // initialize - let mut ctxs: Vec = ctxs - .iter() - .enumerate() - .filter(|(idx, _)| indices.contains(&(*idx as u16))) - .map(|(_, ctx)| Self::new(&ctx)) + let mut ctxs: HashMap = ctxs + .into_iter() + .map(|(i, ctx)| (i, Self::new(&ctx))) .collect(); - let mut messages: Vec<_> = indices - .iter() - .zip(ctxs.iter_mut()) - .map(|(idx, ctx)| { - ProtocolMessage::decode::( + + let mut indices: Vec<_> = ctxs.keys().cloned().collect(); + indices.sort(); + + let mut messages: HashMap = ctxs + .iter_mut() + .map(|(&index, ctx)| { + let msg = ClientMessage::decode::( ctx.advance( &(ProtocolInit { protocol_type: Self::PROTOCOL_TYPE as i32, - indices: indices - .iter() - .map(|x| *x as u32 + Self::INDEX_OFFSET) - .collect(), - index: *idx as u32 + Self::INDEX_OFFSET, + indices: indices.clone(), + index, data: data.clone(), }) .encode_to_vec(), @@ -204,8 +225,8 @@ mod tests { .0 .into(), ) - .unwrap() - .message + .unwrap(); + (index, msg) }) .collect(); @@ -213,29 +234,28 @@ mod tests { for _ in 0..(Self::ROUNDS - 1) { messages = ctxs .iter_mut() - .enumerate() - .map(|(idx, ctx)| { - let relay = messages - .iter() - .enumerate() - .map(|(sender, msg)| { - if sender < idx { - Some(msg[idx - 1].clone()) - } else if sender > idx { - Some(msg[idx].clone()) - } else { - None - } - }) - .filter(Option::is_some) - .map(Option::unwrap) - .collect(); + .map(|(&idx, ctx)| { + let mut unicasts = HashMap::new(); + let mut broadcasts = HashMap::new(); - ProtocolMessage::decode::( + for (&sender, msg) in &messages { + if sender == idx { + continue; + } + if let Some(broadcast) = &msg.broadcast { + broadcasts.insert(sender, broadcast.clone()); + } + if let Some(unicast) = msg.unicasts.get(&idx) { + unicasts.insert(sender, unicast.clone()); + } + } + + let msg = ClientMessage::decode::( ctx.advance( - &(ProtocolMessage { + &(ServerMessage { protocol_type: Self::PROTOCOL_TYPE as i32, - message: relay, + unicasts, + broadcasts, }) .encode_to_vec(), ) @@ -243,14 +263,14 @@ mod tests { .0 .into(), ) - .unwrap() - .message + .unwrap(); + (idx, msg) }) .collect(); } ctxs.into_iter() - .map(|ctx| Box::new(ctx).finish().unwrap()) + .map(|(_, ctx)| Box::new(ctx).finish().unwrap()) .collect() } }