Skip to content

Commit

Permalink
Bus Interaction Enum (#2465)
Browse files Browse the repository at this point in the history
Ready for review. Solves #2460. Depends on #2469. Prior bug fixed (due
to an error in the `bus.asm` code update when unwrapping the new
`BusInteraction` enum type.
  • Loading branch information
qwang98 authored Feb 14, 2025
1 parent a1df755 commit 7671334
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 91 deletions.
20 changes: 15 additions & 5 deletions linker/src/bus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,20 @@ impl LinkerBackend for BusLinker {
}
.into();

self.bus_multi_send_args.items.push(Expression::Tuple(
SourceRef::unknown(),
vec![(interaction_id as u32).into(), tuple, selector],
));
self.bus_multi_send_args
.items
.push(Expression::FunctionCall(
SourceRef::unknown(),
FunctionCall {
function: Box::new(Expression::Reference(
SourceRef::unknown(),
SymbolPath::from_str("std::protocols::bus::BusInteraction::Send")
.unwrap()
.into(),
)),
arguments: vec![(interaction_id as u32).into(), tuple, selector],
},
));
}

fn process_object(&mut self, location: &Location, objects: &BTreeMap<Location, Object>) {
Expand Down Expand Up @@ -345,7 +355,7 @@ mod test {
pc' = (1 - first_step') * pc_update;
pol commit call_selectors[0];
std::array::map(call_selectors, std::utils::force_bool);
std::protocols::bus::bus_multi_send([(12064, [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return], 1)]);
std::protocols::bus::bus_multi_send([std::protocols::bus::BusInteraction::Send(12064, [0, pc, instr__jump_to_operation, instr__reset, instr__loop, instr_return], 1)]);
namespace main__rom(4);
pol constant p_line = [0, 1, 2] + [2]*;
pol constant p_instr__jump_to_operation = [0, 1, 0] + [0]*;
Expand Down
95 changes: 42 additions & 53 deletions std/protocols/bus.asm
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ use std::field::known_field;
use std::field::KnownField;
use std::check::panic;

enum BusInteraction {
// id, payload, multiplicity
// For bus sends, the multiplicity always equals the latch
Send(expr, expr[], expr),
// id, payload, multiplicity, latch
Receive(expr, expr[], expr, expr)
}

/// Helper function.
/// Materialized as a witness column for two reasons:
/// - It makes sure the constraint degree is independent of the input payload.
Expand Down Expand Up @@ -228,31 +236,21 @@ let compute_next_z: expr, expr, expr[], expr, Ext<expr>, Ext<expr>, Ext<expr> ->
/// Transpose user interface friendly bus send input format `(expr, expr[], expr)[]`
/// to constraint building friendly bus send input format `expr[], expr[][], expr[]`, i.e. id, payload, multiplicity.
/// This is because Rust-style tuple indexing, e.g. tuple.0, isn't supported yet.
let transpose_bus_send_inputs: (expr, expr[], expr)[] -> (expr[], expr[][], expr[]) = |bus_inputs| {
let ids: expr[] = array::map(bus_inputs,
|bus_input| {
let (id, _, _) = bus_input;
id
}
);
let payloads: expr[][] = array::map(bus_inputs,
|bus_input| {
let (_, payload, _) = bus_input;
payload
let transpose_bus_send_inputs: BusInteraction[] -> (expr[], expr[][], expr[]) = |bus_inputs| {
array::fold(
bus_inputs, ([], [], []),
|(ids, payloads, multiplicities), bus_input| {
match bus_input {
BusInteraction::Send(id, payload, multiplicity) => (ids + [id], payloads + [payload], multiplicities + [multiplicity]),
_ => std::check::panic("Requires BusInteraction::Send")
}
}
);
let multiplicities: expr[] = array::map(bus_inputs,
|bus_input| {
let (_, _, multiplicity) = bus_input;
multiplicity
}
);
(ids, payloads, multiplicities)
)
};

/// Convenience function for batching multiple bus sends.
/// Transposes user inputs and then calls the key logic for batch building bus interactions.
let bus_multi_send: (expr, expr[], expr)[] -> () = constr |bus_inputs| {
let bus_multi_send: BusInteraction[] -> () = constr |bus_inputs| {
let (ids, payloads, multiplicities) = transpose_bus_send_inputs(bus_inputs);
// For bus sends, the multiplicity always equals the latch
bus_multi_interaction(ids, payloads, multiplicities, multiplicities);
Expand All @@ -262,39 +260,23 @@ let bus_multi_send: (expr, expr[], expr)[] -> () = constr |bus_inputs| {
/// Transpose user interface friendly bus send input format `(expr, expr[], expr, expr)[]`
/// to constraint building friendly bus send input format `expr[], expr[][], expr[], expr[]`, i.e. id, payload, multiplicity, latch.
/// This is because Rust-style tuple indexing, e.g. tuple.0, isn't supported yet.
let transpose_bus_receive_inputs: (expr, expr[], expr, expr)[] -> (expr[], expr[][], expr[], expr[]) = |bus_inputs| {
let ids: expr[] = array::map(bus_inputs,
|bus_input| {
let (id, _, _, _) = bus_input;
id
let transpose_bus_receive_inputs: BusInteraction[] -> (expr[], expr[][], expr[], expr[]) = |bus_inputs| {
array::fold(
bus_inputs, ([], [], [], []),
|(ids, payloads, multiplicities, latches), bus_input| {
match bus_input {
BusInteraction::Receive(id, payload, multiplicity, latch) => (ids + [id], payloads + [payload], multiplicities + [-multiplicity], latches + [latch]),
_ => std::check::panic("Requires BusInteraction::Receive")
}
}
);
let payloads: expr[][] = array::map(bus_inputs,
|bus_input| {
let (_, payload, _, _) = bus_input;
payload
}
);
let negated_multiplicities: expr[] = array::map(bus_inputs,
|bus_input| {
let (_, _, multiplicity, _) = bus_input;
-multiplicity
}
);
let latches: expr[] = array::map(bus_inputs,
|bus_input| {
let (_, _, _, latch) = bus_input;
latch
}
);
(ids, payloads, negated_multiplicities, latches)
)
};

/// Convenience function for batching multiple bus receives.
/// Transposes user inputs and then calls the key logic for batch building bus interactions.
/// In practice, can also batch bus send and bus receive, but requires knowledge of this function and careful configuration of input parameters.
/// E.g. sending negative multiplicity and multiplicity for "multiplicity" and "latch" parameters for bus sends.
let bus_multi_receive: (expr, expr[], expr, expr)[] -> () = constr |bus_inputs| {
let bus_multi_receive: BusInteraction[] -> () = constr |bus_inputs| {
let (ids, payloads, negated_multiplicities, latches) = transpose_bus_receive_inputs(bus_inputs);
bus_multi_interaction(ids, payloads, negated_multiplicities, latches);
};
Expand All @@ -304,7 +286,7 @@ let bus_multi_receive: (expr, expr[], expr, expr)[] -> () = constr |bus_inputs|
let bus_multi_receive_batch_lookup_permutation: (expr, expr, expr[], int)[] -> () = constr |inputs| {
// Lookup requires adding a multiplicity column and constraining it to zero if selector is zero.
// Permutation passes the selector to both multiplicity and latch fields as well.
let inputs_inner: (expr, expr[], expr, expr)[] = array::fold(inputs, [], constr |acc, input| {
let inputs_inner = array::fold(inputs, [], constr |acc, input| {
// Converted to input format for the inner function `bus_multi_receive`:
// For lookup, format is id, payload, multiplicity, selector
// For permutation, format is id, payload, selector, selector
Expand All @@ -316,7 +298,7 @@ let bus_multi_receive_batch_lookup_permutation: (expr, expr, expr[], int)[] -> (
(1 - selector) * multiplicity = 0;
multiplicity
};
acc + [(id, payload, multiplicity, selector)]
acc + [BusInteraction::Receive(id, payload, multiplicity, selector)]
});
bus_multi_receive(inputs_inner);
};
Expand All @@ -328,12 +310,19 @@ let bus_interaction: expr, expr[], expr, expr -> () = constr |id, payload, multi
};

/// Convenience function for single bus interaction to send columns
let bus_send: expr, expr[], expr -> () = constr |id, payload, multiplicity| {
// For bus sends, the multiplicity always equals the latch
bus_interaction(id, payload, multiplicity, multiplicity);
let bus_send: BusInteraction -> () = constr |bus_input| {
match bus_input {
// For bus sends, the multiplicity always equals the latch
BusInteraction::Send(id, payload, multiplicity) => bus_interaction(id, payload, multiplicity, multiplicity),
_ => std::check::panic("Requires BusInteraction::Send.")
}
};

/// Convenience function for single bus interaction to receive columns
let bus_receive: expr, expr[], expr, expr -> () = constr |id, payload, multiplicity, latch| {
bus_interaction(id, payload, -multiplicity, latch);
let bus_receive: BusInteraction -> () = constr |bus_input| {
match bus_input {
BusInteraction::Receive(id, payload, multiplicity, latch) => bus_interaction(id, payload, -multiplicity, latch),
_ => std::check::panic("Requires BusInteraction::Receive.")
}
;
};
8 changes: 5 additions & 3 deletions std/protocols/lookup_via_bus.asm
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::protocols::bus::bus_receive;
use std::protocols::bus::BusInteraction;
use std::protocols::bus::bus_multi_receive;
use std::array;

Expand All @@ -8,17 +9,18 @@ let lookup_receive: expr, expr, expr[] -> () = constr |id, selector, tuple| {
let multiplicities;
(1 - selector) * multiplicities = 0;

bus_receive(id, tuple, multiplicities, selector);
bus_receive(BusInteraction::Receive(id, tuple, multiplicities, selector));
};

/// Batched version of `lookup_receive` that uses the more column-saving `bus_multi_receive`.
/// Ideally, should use `bus_multi_receive` to batch both lookup and permutation receives.
/// Note that we cannot input BusInteraction::Receive, which is defined differently.
let lookup_multi_receive: (expr, expr, expr[])[] -> () = constr |inputs| {
let inputs_inner: (expr, expr[], expr, expr)[] = array::fold(inputs, [], constr |acc, input| {
let inputs_inner = array::fold(inputs, [], constr |acc, input| {
let (id, selector, tuple) = input;
let multiplicity;
(1 - selector) * multiplicity = 0;
acc + [(id, tuple, multiplicity, selector)]
acc + [BusInteraction::Receive(id, tuple, multiplicity, selector)]
});
bus_multi_receive(inputs_inner);
};
8 changes: 5 additions & 3 deletions std/protocols/permutation_via_bus.asm
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use std::protocols::bus::bus_receive;
use std::protocols::bus::bus_multi_receive;
use std::protocols::bus::BusInteraction;
use std::array;

/// Given an ID, selector, and tuple, receives (ID, ...tuple) tuple from the bus
/// with multiplicity 1 if the selector is 1.
let permutation_receive: expr, expr, expr[] -> () = constr |id, selector, tuple| {
bus_receive(id, tuple, selector, selector);
bus_receive(BusInteraction::Receive(id, tuple, selector, selector));
};

/// Batched version of `permutation_receive` that uses the more column-saving `bus_multi_receive`.
/// Ideally, should use `bus_multi_receive` to batch both lookup and permutation receives.
/// Note that we cannot input BusInteraction::Receive, which is defined differently.
let permutation_multi_receive: (expr, expr, expr[])[] -> () = constr |inputs| {
let inputs_inner: (expr, expr[], expr, expr)[] = array::fold(inputs, [], |acc, input| {
let inputs_inner = array::fold(inputs, [], |acc, input| {
let (id, selector, tuple) = input;
acc + [(id, tuple, selector, selector)]
acc + [BusInteraction::Receive(id, tuple, selector, selector)]
});
bus_multi_receive(inputs_inner);
};
5 changes: 3 additions & 2 deletions test_data/asm/block_to_block_with_bus.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::protocols::bus::bus_receive;
use std::protocols::bus::bus_send;
use std::protocols::bus::BusInteraction;

// Like block_to_block.asm, but also adds a bus to both machines.
// This is still flawed currently, because:
Expand All @@ -22,7 +23,7 @@ machine Arith with

col witness bus_selector;
std::utils::force_bool(bus_selector);
bus_receive(ARITH_INTERACTION_ID, [0, x, y, z], latch * bus_selector, latch * bus_selector);
bus_receive(BusInteraction::Receive(ARITH_INTERACTION_ID, [0, x, y, z], latch * bus_selector, latch * bus_selector));

// TODO: Expose final value of acc as public.

Expand Down Expand Up @@ -51,7 +52,7 @@ machine Main with
// Need a constraint so that it's not optimized away
dummy = dummy';

bus_send(ARITH_INTERACTION_ID, [0, x, y, z], instr_add);
bus_send(BusInteraction::Send(ARITH_INTERACTION_ID, [0, x, y, z], instr_add));

// TODO: Expose final value of acc as public.

Expand Down
5 changes: 3 additions & 2 deletions test_data/asm/block_to_block_with_bus_different_sizes.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::protocols::bus::bus_receive;
use std::protocols::bus::bus_send;
use std::protocols::bus::BusInteraction;
use std::prelude::Query;
use std::prover::challenge;

Expand All @@ -20,7 +21,7 @@ machine Arith with

col witness bus_selector;
std::utils::force_bool(bus_selector);
bus_receive(ARITH_INTERACTION_ID, [0, x, y, z], latch * bus_selector, latch * bus_selector);
bus_receive(BusInteraction::Receive(ARITH_INTERACTION_ID, [0, x, y, z], latch * bus_selector, latch * bus_selector));

// TODO: Expose final value of acc as public.

Expand Down Expand Up @@ -52,7 +53,7 @@ machine Main with
// Need a constraint so that it's not optimized away
dummy = dummy';

bus_send(ARITH_INTERACTION_ID, [0, x, y, z], instr_add);
bus_send(BusInteraction::Send(ARITH_INTERACTION_ID, [0, x, y, z], instr_add));

// TODO: Expose final value of acc as public.

Expand Down
7 changes: 4 additions & 3 deletions test_data/asm/dynamic_bus.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::protocols::bus::bus_receive;
use std::protocols::bus::bus_send;
use std::protocols::bus::BusInteraction;

let ADD_BUS_ID = 123;
let MUL_BUS_ID = 456;
Expand All @@ -17,13 +18,13 @@ machine Main with
col witness add_a, add_b, add_c, add_sel;
std::utils::force_bool(add_sel);
add_c = add_a + add_b;
bus_receive(ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel);
bus_receive(BusInteraction::Receive(ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel));

// Mul block machine
col witness mul_a, mul_b, mul_c, mul_sel;
std::utils::force_bool(mul_sel);
mul_c = mul_a * mul_b;
bus_receive(MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel);
bus_receive(BusInteraction::Receive(MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel));

// Main machine
col fixed is_mul = [0, 1]*;
Expand All @@ -33,5 +34,5 @@ machine Main with

// Because we're doing exactly one of the two operations at any given time,
// we only need to do one send, choosing the bus to send to at runtime.
bus_send(is_mul * MUL_BUS_ID + (1 - is_mul) * ADD_BUS_ID, [x, y, z], 1);
bus_send(BusInteraction::Send(is_mul * MUL_BUS_ID + (1 - is_mul) * ADD_BUS_ID, [x, y, z], 1));
}
9 changes: 5 additions & 4 deletions test_data/asm/static_bus.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::protocols::bus::bus_receive;
use std::protocols::bus::bus_send;
use std::protocols::bus::BusInteraction;

let ADD_BUS_ID = 123;
let MUL_BUS_ID = 456;
Expand All @@ -17,13 +18,13 @@ machine Main with
col witness add_a, add_b, add_c, add_sel;
std::utils::force_bool(add_sel);
add_c = add_a + add_b;
bus_receive(ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel);
bus_receive(BusInteraction::Receive(ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel));

// Mul block machine
col witness mul_a, mul_b, mul_c, mul_sel;
std::utils::force_bool(mul_sel);
mul_c = mul_a * mul_b;
bus_receive(MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel);
bus_receive(BusInteraction::Receive(MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel));

// Main machine
col fixed is_mul = [0, 1]*;
Expand All @@ -34,6 +35,6 @@ machine Main with
// Because the bus ID needs to be known at compile time, we have to do
// a bus send for each receiver, even though at most one send will be
// active in each row.
bus_send(MUL_BUS_ID, [x, y, z], is_mul);
bus_send(ADD_BUS_ID, [x, y, z], 1 - is_mul);
bus_send(BusInteraction::Send(MUL_BUS_ID, [x, y, z], is_mul));
bus_send(BusInteraction::Send(ADD_BUS_ID, [x, y, z], 1 - is_mul));
}
21 changes: 11 additions & 10 deletions test_data/asm/static_bus_multi.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::protocols::bus::bus_multi_receive;
use std::protocols::bus::bus_multi_send;
use std::protocols::bus::BusInteraction;

let ADD_BUS_ID = 123;
let MUL_BUS_ID = 456;
Expand Down Expand Up @@ -44,11 +45,11 @@ machine Main with
// Multi bus receive
bus_multi_receive(
[
(ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel),
(MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel),
(SUB_BUS_ID, [sub_a, sub_b, sub_c], sub_sel, sub_sel),
(DOUBLE_BUS_ID, [double_a, double_b, double_c], double_sel, double_sel),
(TRIPLE_BUS_ID, [triple_a, triple_b, triple_c], triple_sel, triple_sel)
BusInteraction::Receive(ADD_BUS_ID, [add_a, add_b, add_c], add_sel, add_sel),
BusInteraction::Receive(MUL_BUS_ID, [mul_a, mul_b, mul_c], mul_sel, mul_sel),
BusInteraction::Receive(SUB_BUS_ID, [sub_a, sub_b, sub_c], sub_sel, sub_sel),
BusInteraction::Receive(DOUBLE_BUS_ID, [double_a, double_b, double_c], double_sel, double_sel),
BusInteraction::Receive(TRIPLE_BUS_ID, [triple_a, triple_b, triple_c], triple_sel, triple_sel)
]
);

Expand All @@ -67,11 +68,11 @@ machine Main with
// active in each row.
bus_multi_send(
[
(MUL_BUS_ID, [x, y, z], is_mul),
(ADD_BUS_ID, [x, y, z], is_add),
(DOUBLE_BUS_ID, [x, y, z], is_double),
(SUB_BUS_ID, [x, y, z], is_sub),
(TRIPLE_BUS_ID, [x, y, z], is_triple)
BusInteraction::Send(MUL_BUS_ID, [x, y, z], is_mul),
BusInteraction::Send(ADD_BUS_ID, [x, y, z], is_add),
BusInteraction::Send(DOUBLE_BUS_ID, [x, y, z], is_double),
BusInteraction::Send(SUB_BUS_ID, [x, y, z], is_sub),
BusInteraction::Send(TRIPLE_BUS_ID, [x, y, z], is_triple)
]
);
}
Loading

0 comments on commit 7671334

Please sign in to comment.