Skip to content

Commit

Permalink
Merge pull request #23 from encryptogroup/optimize-base-iter
Browse files Browse the repository at this point in the history
Refactored base_circuit iter
  • Loading branch information
robinhundt authored May 11, 2024
2 parents f5bfb91 + ae07bf2 commit 9ade732
Showing 1 changed file with 47 additions and 59 deletions.
106 changes: 47 additions & 59 deletions crates/seec/src/circuit/base_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use bytemuck::{Pod, Zeroable};
use petgraph::dot::{Config, Dot};
use petgraph::graph::NodeIndex;
use petgraph::visit::IntoNodeIdentifiers;
use petgraph::visit::{VisitMap, Visitable};
use petgraph::{Directed, Direction, Graph};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, instrument, trace};
Expand Down Expand Up @@ -508,10 +507,9 @@ impl<T: Share, D: Dimension> Gate for BaseGate<T, D> {
pub struct BaseLayerIter<'a, G, Idx: GateIdx, W> {
circuit: &'a BaseCircuit<G, Idx, W>,
inputs_needed_cnt: Vec<u32>,
prev_interactive: VecDeque<NodeIndex<Idx>>,
to_visit: VecDeque<NodeIndex<Idx>>,
next_layer: VecDeque<NodeIndex<Idx>>,
visited: <CircuitGraph<G, Idx, W> as Visitable>::Map,
added_to_next: <CircuitGraph<G, Idx, W> as Visitable>::Map,
// only used for SIMD circuits
// TODO remove entries from hashmap when count recheas 0
inputs_left_to_provide: HashMap<NodeIndex<Idx>, u32>,
Expand Down Expand Up @@ -549,15 +547,12 @@ impl<'a, Idx: GateIdx, G: Gate, W: Wire> BaseLayerIter<'a, G, Idx, W> {
.collect();
let to_visit = VecDeque::new();
let next_layer = circuit.constant_gates.iter().map(|&g| g.into()).collect();
let visited = circuit.graph.visit_map();
let added_to_next = circuit.graph.visit_map();
Self {
circuit,
inputs_needed_cnt,
prev_interactive: VecDeque::new(),
to_visit,
next_layer,
visited,
added_to_next,
inputs_left_to_provide: Default::default(),
last_layer_size: (0, 0),
gates_produced: 0,
Expand All @@ -581,10 +576,7 @@ impl<'a, Idx: GateIdx, G: Gate, W: Wire> BaseLayerIter<'a, G, Idx, W> {

/// Adds idx to the next layer if it has not been visited
pub fn add_to_next_layer(&mut self, idx: NodeIndex<Idx>) {
if !self.added_to_next.is_visited(&idx) {
self.next_layer.push_back(idx);
self.added_to_next.visit(idx);
}
self.next_layer.push_back(idx);
}

pub fn is_exhausted(&self) -> bool {
Expand Down Expand Up @@ -699,71 +691,67 @@ impl<'a, G: Gate, Idx: GateIdx, W: Wire> Iterator for BaseLayerIter<'a, G, Idx,

#[tracing::instrument(level = "trace", skip(self), ret)]
fn next(&mut self) -> Option<Self::Item> {
// TODO this current implementation is confusing -> Refactor
let graph = self.circuit.as_graph();
let mut layer = CircuitLayer::with_capacity(self.last_layer_size);
std::mem::swap(&mut self.to_visit, &mut self.next_layer);

while let Some(node_idx) = self.to_visit.pop_front() {
// This case handles the interactive gates at the front of to_visit that
// are here because they were `add_to_next_layer` but whose neighbours have not
// had their counts decreased
if self.visited.is_visited(&node_idx) {
let mut neigh_cnt = 0;
for neigh in graph.neighbors(node_idx) {
neigh_cnt += 1;
{
let count = self.inputs_needed_cnt[neigh.index()];
trace!("Node: {node_idx:?} -> Neigh {neigh:?}: count {count}")
}
self.inputs_needed_cnt[neigh.index()] -= 1;
let inputs_needed = self.inputs_needed_cnt[neigh.index()];
if inputs_needed == 0 {
self.add_to_visit(neigh);
}
// Unfortunately this is hard to factor into a function on self due to borrowing issues :/
let update_queue = |node,
inputs_needed: &mut [u32],
inputs_left_to_provide: &mut HashMap<_, _>,
queue: &mut VecDeque<_>| {
let mut neigh_cnt = 0;
for neigh in graph.neighbors(node) {
neigh_cnt += 1;
let inputs_needed = &mut inputs_needed[neigh.index()];
*inputs_needed -= 1;
if *inputs_needed == 0 {
queue.push_back(neigh);
}
if self.circuit.is_simd() {
self.inputs_left_to_provide
.entry(node_idx)
.or_insert(neigh_cnt);
}
continue;
}
self.visited.visit(node_idx);

if self.circuit.is_simd() {
for neigh in graph.neighbors_directed(node_idx, Direction::Incoming) {
let cnt = self
.inputs_left_to_provide
.get_mut(&neigh)
.expect("inputs_left_to_provide must be initialize");
*cnt -= 1;
if *cnt == 0 {
layer.freeable_gates.push(neigh.into());
}
}
inputs_left_to_provide.entry(node).or_insert(neigh_cnt);
}

neigh_cnt
};

while let Some(node_idx) = self.prev_interactive.pop_front() {
update_queue(
node_idx,
&mut self.inputs_needed_cnt,
&mut self.inputs_left_to_provide,
&mut self.to_visit,
);
}

while let Some(node_idx) = self.to_visit.pop_front() {
let gate = graph[node_idx].clone();
if gate.is_interactive() {
self.add_to_next_layer(node_idx);
layer.push_interactive((gate.clone(), node_idx.into()));
self.prev_interactive.push_back(node_idx);
} else {
layer.push_non_interactive((gate.clone(), node_idx.into()));
let mut neigh_cnt = 0;
for neigh in graph.neighbors(node_idx) {
neigh_cnt += 1;
self.inputs_needed_cnt[neigh.index()] -= 1;
let inputs_needed = self.inputs_needed_cnt[neigh.index()];
if inputs_needed == 0 {
self.add_to_visit(neigh)
update_queue(
node_idx,
&mut self.inputs_needed_cnt,
&mut self.inputs_left_to_provide,
&mut self.to_visit,
);
}

if self.circuit.is_simd() {
for neigh in graph.neighbors_directed(node_idx, Direction::Incoming) {
let cnt = self
.inputs_left_to_provide
.get_mut(&neigh)
.expect("inputs_left_to_provide is initialized because of topo order");
*cnt -= 1;
if *cnt == 0 {
layer.freeable_gates.push(neigh.into());
}
}
if self.circuit.is_simd() {
self.inputs_left_to_provide
.entry(node_idx)
.or_insert(neigh_cnt);
}
}
}
if layer.is_empty() {
Expand Down

0 comments on commit 9ade732

Please sign in to comment.