Skip to content

Commit

Permalink
fix: ideal functionality synchronization (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
sinui0 authored Jan 8, 2025
1 parent 9e264e8 commit d155962
Show file tree
Hide file tree
Showing 17 changed files with 246 additions and 254 deletions.
2 changes: 1 addition & 1 deletion crates/mpz-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ executor = ["cpu", "dep:uid-mux"]
sync = ["tokio/sync"]
future = []
test-utils = ["dep:uid-mux", "uid-mux/test-utils"]
ideal = []
ideal = ["tokio/sync"]
rayon = ["dep:rayon"]
force-st = []

Expand Down
199 changes: 33 additions & 166 deletions crates/mpz-common/src/ideal.rs
Original file line number Diff line number Diff line change
@@ -1,191 +1,58 @@
//! Ideal functionality utilities.
use futures::channel::oneshot;
use std::{
any::Any,
collections::HashMap,
sync::{Arc, Mutex, MutexGuard},
};
use std::sync::Arc;
use tokio::sync::Barrier;

use crate::{Context, ThreadId};

type BoxAny = Box<dyn Any + Send + 'static>;

#[derive(Debug, Default)]
struct Buffer {
alice: HashMap<ThreadId, (BoxAny, oneshot::Sender<BoxAny>)>,
bob: HashMap<ThreadId, (BoxAny, oneshot::Sender<BoxAny>)>,
}

/// The ideal functionality from the perspective of Alice.
#[derive(Debug)]
pub struct Alice<F> {
f: Arc<Mutex<F>>,
buffer: Arc<Mutex<Buffer>>,
}

impl<F> Clone for Alice<F> {
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
buffer: self.buffer.clone(),
}
}
}

impl<F> Alice<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

/// Calls the ideal functionality.
pub async fn call<Ctx, C, IA, IB, OA, OB>(&mut self, ctx: &mut Ctx, input: IA, call: C) -> OA
where
Ctx: Context,
C: FnOnce(&mut F, IA, IB) -> (OA, OB),
IA: Send + 'static,
IB: Send + 'static,
OA: Send + 'static,
OB: Send + 'static,
{
let receiver = {
let mut buffer = self.buffer.lock().unwrap();
if let Some((input_bob, ret_bob)) = buffer.bob.remove(ctx.id()) {
let input_bob = *input_bob
.downcast()
.expect("alice received correct input type for bob");

let (output_alice, output_bob) =
call(&mut self.f.lock().unwrap(), input, input_bob);

_ = ret_bob.send(Box::new(output_bob));

return output_alice;
}

let (sender, receiver) = oneshot::channel();
buffer
.alice
.insert(ctx.id().clone(), (Box::new(input), sender));
receiver
};

let output_alice = receiver.await.expect("bob did not drop the channel");
*output_alice
.downcast()
.expect("bob sent correct output type for alice")
}
/// Creates a new call synchronizer between two parties.
pub fn call_sync() -> (CallSync, CallSync) {
let barrier = Arc::new(Barrier::new(2));
(
CallSync {
barrier: Arc::clone(&barrier),
},
CallSync { barrier },
)
}

/// The ideal functionality from the perspective of Bob.
/// Synchronizes function calls between two parties.
#[derive(Debug)]
pub struct Bob<F> {
f: Arc<Mutex<F>>,
buffer: Arc<Mutex<Buffer>>,
}

impl<F> Clone for Bob<F> {
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
buffer: self.buffer.clone(),
}
}
pub struct CallSync {
barrier: Arc<Barrier>,
}

impl<F> Bob<F> {
/// Returns a lock to the ideal functionality.
pub fn get_mut(&mut self) -> MutexGuard<'_, F> {
self.f.lock().unwrap()
}

/// Calls the ideal functionality.
pub async fn call<Ctx, C, IA, IB, OA, OB>(&mut self, ctx: &mut Ctx, input: IB, call: C) -> OB
where
Ctx: Context,
C: FnOnce(&mut F, IA, IB) -> (OA, OB),
IA: Send + 'static,
IB: Send + 'static,
OA: Send + 'static,
OB: Send + 'static,
{
let receiver = {
let mut buffer = self.buffer.lock().unwrap();
if let Some((input_alice, ret_alice)) = buffer.alice.remove(ctx.id()) {
let input_alice = *input_alice
.downcast()
.expect("bob received correct input type for alice");
impl CallSync {
/// Synchronizes a call.
pub async fn call<F: FnMut() -> R, R>(&mut self, mut f: F) -> Option<R> {
// Wait for both parties to call.
let is_leader = self.barrier.wait().await.is_leader();

let (output_alice, output_bob) =
call(&mut self.f.lock().unwrap(), input_alice, input);
let ret = if is_leader { Some(f()) } else { None };

_ = ret_alice.send(Box::new(output_alice));
// Wait for the call to return.
self.barrier.wait().await;

return output_bob;
}

let (sender, receiver) = oneshot::channel();
buffer
.bob
.insert(ctx.id().clone(), (Box::new(input), sender));
receiver
};

let output_bob = receiver.await.expect("alice did not drop the channel");
*output_bob
.downcast()
.expect("alice sent correct output type for bob")
ret
}
}

/// Creates an ideal functionality, returning the perspectives of Alice and Bob.
pub fn ideal_f2p<F>(f: F) -> (Alice<F>, Bob<F>) {
let f = Arc::new(Mutex::new(f));
let buffer = Arc::new(Mutex::new(Buffer::default()));

(
Alice {
f: f.clone(),
buffer: buffer.clone(),
},
Bob { f, buffer },
)
}

#[cfg(test)]
mod test {
use crate::executor::test_st_executor;
use std::sync::Mutex;

use super::*;

#[test]
fn test_ideal() {
let (mut alice, mut bob) = ideal_f2p(());
let (mut ctx_a, mut ctx_b) = test_st_executor(8);
#[tokio::test]
async fn test_call_sync() {
let x = Arc::new(Mutex::new(0));

let (output_a, output_b) = futures::executor::block_on(async {
futures::join!(
alice.call(&mut ctx_a, 1u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
bob.call(&mut ctx_b, 2u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
)
});
let (mut sync_0, mut sync_1) = call_sync();

assert_eq!(output_a, 3);
assert_eq!(output_b, 3);
}
let add_one = || {
*x.lock().unwrap() += 1;
};

#[test]
#[should_panic]
fn test_ideal_wrong_input_type() {
let (mut alice, mut bob) = ideal_f2p(());
let (mut ctx_a, mut ctx_b) = test_st_executor(8);
futures::join!(sync_0.call(add_one.clone()), sync_1.call(add_one));

futures::executor::block_on(async {
futures::join!(
alice.call(&mut ctx_a, 1u16, |&mut (), a: u16, b: u16| (a + b, a + b)),
bob.call(&mut ctx_b, 2u8, |&mut (), a: u8, b: u8| (a + b, a + b)),
)
});
assert_eq!(*x.lock().unwrap(), 1);
}
}
2 changes: 1 addition & 1 deletion crates/mpz-ole-core/src/ideal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ where
let sender_count = this.sender_state.alloc;
let receiver_count = this.receiver_state.alloc;

sender_count > 0 || receiver_count > 0 && sender_count == receiver_count
sender_count > 0 || receiver_count > 0
}

/// Flushes the functionality.
Expand Down
85 changes: 61 additions & 24 deletions crates/mpz-ole/src/ideal.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,40 @@
//! Ideal ROLE.
use async_trait::async_trait;
use mpz_common::{Context, Flush};
use mpz_core::Block;
use rand::{rngs::StdRng, Rng, SeedableRng};

use mpz_common::{
ideal::{call_sync, CallSync},
Context, Flush,
};
use mpz_fields::Field;
use mpz_ole_core::{
ideal::{IdealROLE as Core, IdealROLEError},
ROLEReceiver, ROLESender, ROLESenderOutput,
};

/// Ideal ROLE.
#[derive(Debug, Clone)]
pub struct IdealROLE<F> {
core: Core<F>,
/// Returns a new ideal ROLE sender and receiver.
pub fn ideal_role<F: Field>() -> (IdealROLESender<F>, IdealROLEReceiver<F>) {
let mut rng = StdRng::seed_from_u64(0);
let core = Core::new(rng.gen());
let (sync_0, sync_1) = call_sync();
(
IdealROLESender {
core: core.clone(),
sync: sync_0,
},
IdealROLEReceiver { core, sync: sync_1 },
)
}

impl<F> IdealROLE<F>
where
F: Field,
{
/// Create a new ideal ROLE.
///
/// # Arguments
///
/// * `seed` - PRG seed.
pub fn new(seed: Block) -> Self {
Self {
core: Core::new(seed),
}
}
/// Ideal ROLE sender.
#[derive(Debug)]
pub struct IdealROLESender<F> {
core: Core<F>,
sync: CallSync,
}

impl<F> ROLESender<F> for IdealROLE<F>
impl<F> ROLESender<F> for IdealROLESender<F>
where
F: Field,
{
Expand All @@ -55,7 +58,38 @@ where
}
}

impl<F> ROLEReceiver<F> for IdealROLE<F>
#[async_trait]
impl<Ctx, F> Flush<Ctx> for IdealROLESender<F>
where
Ctx: Context,
F: Field,
{
type Error = IdealROLEError;

fn wants_flush(&self) -> bool {
self.core.wants_flush()
}

async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> {
if self.core.wants_flush() {
self.sync
.call(|| self.core.flush().map_err(IdealROLEError::from))
.await
.transpose()?;
}

Ok(())
}
}

/// Ideal ROLE Receiver.
#[derive(Debug)]
pub struct IdealROLEReceiver<F> {
core: Core<F>,
sync: CallSync,
}

impl<F> ROLEReceiver<F> for IdealROLEReceiver<F>
where
F: Field,
{
Expand Down Expand Up @@ -83,7 +117,7 @@ where
}

#[async_trait]
impl<Ctx, F> Flush<Ctx> for IdealROLE<F>
impl<Ctx, F> Flush<Ctx> for IdealROLEReceiver<F>
where
Ctx: Context,
F: Field,
Expand All @@ -96,7 +130,10 @@ where

async fn flush(&mut self, _ctx: &mut Ctx) -> Result<(), Self::Error> {
if self.core.wants_flush() {
self.core.flush()?;
self.sync
.call(|| self.core.flush().map_err(IdealROLEError::from))
.await
.transpose()?;
}

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion crates/mpz-ot-core/src/ideal/cot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl IdealCOT {
let sender_queue = this.sender_state.queue.len();
let receiver_queue = this.receiver_state.queue.len();

sender_queue > 0 && receiver_queue > 0 && sender_queue == receiver_queue
sender_queue > 0 || receiver_queue > 0
}

/// Flushes the functionality.
Expand Down
2 changes: 1 addition & 1 deletion crates/mpz-ot-core/src/ideal/ot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl IdealOT {
let sender_queue = this.sender_state.queue.len();
let receiver_queue = this.receiver_state.queue.len();

sender_queue > 0 && receiver_queue > 0 && sender_queue == receiver_queue
sender_queue > 0 || receiver_queue > 0
}

/// Flushes the functionality.
Expand Down
2 changes: 1 addition & 1 deletion crates/mpz-ot-core/src/ideal/rcot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl IdealRCOT {
let sender_count = this.sender_state.alloc;
let receiver_count = this.receiver_state.alloc;

sender_count > 0 && receiver_count > 0 && sender_count == receiver_count
sender_count > 0 || receiver_count > 0
}

/// Flushes pending operations.
Expand Down
Loading

0 comments on commit d155962

Please sign in to comment.