Skip to content

Commit

Permalink
enforce ordered barriers
Browse files Browse the repository at this point in the history
  • Loading branch information
rdfriese committed Jul 31, 2024
1 parent 28cf0f9 commit acd2bd3
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 539 deletions.
2 changes: 2 additions & 0 deletions examples/kernels/dft_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,9 @@ fn main() {
.for_each(|elem| *elem = 0.0)
.block();
}
println!("here 0");
full_spectrum_array.wait_all();
println!("here 1");
full_spectrum_array.barrier();
times[ti].push(dft_lamellar_array_opt_test(
full_signal_array.clone(),
Expand Down
4 changes: 2 additions & 2 deletions examples/kernels/safe_parallel_blocked_array_gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn main() {
let a_init = a
.dist_iter_mut()
.enumerate()
.for_each(|(i, x)| *x = i as f32);
.for_each(move |(i, x)| *x = i as f32);
let b_init = b.dist_iter_mut().enumerate().for_each(move |(i, x)| {
//identity matrix
let row = i / dim;
Expand All @@ -47,7 +47,7 @@ fn main() {
*x = 0 as f32;
}
});
let c_init = c.dist_iter_mut().for_each(|x| *x = 0.0);
let c_init = c.dist_iter_mut().for_each(move |x| *x = 0.0);
world.block_on_all([a_init, b_init, c_init]);
let a = a.into_read_only();
let b = b.into_read_only();
Expand Down
4 changes: 2 additions & 2 deletions src/active_messaging/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl LamellarRequestAddResult for AmHandleInner {
waker.wake();
}
}
fn update_counters(&self) {
fn update_counters(&self, _sub_id: usize) {
let _team_reqs = self.team_outstanding_reqs.fetch_sub(1, Ordering::SeqCst);
let _world_req = self.world_outstanding_reqs.fetch_sub(1, Ordering::SeqCst);
if let Some(tg_outstanding_reqs) = self.tg_outstanding_reqs.clone() {
Expand Down Expand Up @@ -344,7 +344,7 @@ impl LamellarRequestAddResult for MultiAmHandleInner {
}
}
}
fn update_counters(&self) {
fn update_counters(&self, _sub_id: usize) {
let _team_reqs = self.team_outstanding_reqs.fetch_sub(1, Ordering::SeqCst);
let _world_req = self.world_outstanding_reqs.fetch_sub(1, Ordering::SeqCst);
if let Some(tg_outstanding_reqs) = self.tg_outstanding_reqs.clone() {
Expand Down
1 change: 1 addition & 0 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ crate::inventory::collect!(ReduceKey);
// impl Dist for bool {}
// lamellar_impl::generate_reductions_for_type_rt!(true, u8, usize);
// lamellar_impl::generate_ops_for_type_rt!(true, true, true, u8, usize);

// lamellar_impl::generate_reductions_for_type_rt!(false, f32);
// lamellar_impl::generate_ops_for_type_rt!(false, false, false, f32);
// lamellar_impl::generate_reductions_for_type_rt!(false, u128);
Expand Down
2 changes: 1 addition & 1 deletion src/array/global_lock_atomic/iteration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl<T: Dist> LamellarArrayMutIterators<T> for GlobalLockArray<T> {
self.array
.block_on(async move { lock.collective_write().await }),
);
self.barrier();
// self.barrier();
// println!("dist_iter thread {:?} got lock",std::thread::current().id());
GlobalLockDistIterMut {
data: self.clone(),
Expand Down
51 changes: 41 additions & 10 deletions src/array/iterator/distributed_iterator/consumer/for_each.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,14 @@ impl DistIterForEachHandle {
/// This function returns a handle that can be used to wait for the operation to complete
#[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
pub fn spawn(self) -> LamellarTask<()> {
// match self.state {
// State::Barrier(ref barrier, _) => {
// println!("spawning task barrier id {:?}", barrier.barrier_id);
// }
// State::Reqs(_, barrier_id) => {
// println!("spawning task not sure I can be here {:?}", barrier_id);
// }
// }
self.team.clone().scheduler.spawn_task(self)
}
}
Expand All @@ -236,7 +244,7 @@ enum State {
#[pin] BarrierHandle,
Pin<Box<dyn Future<Output = InnerDistIterForEachHandle> + Send>>,
),
Reqs(#[pin] InnerDistIterForEachHandle),
Reqs(#[pin] InnerDistIterForEachHandle, usize),
}

impl Future for DistIterForEachHandle {
Expand All @@ -245,19 +253,42 @@ impl Future for DistIterForEachHandle {
let mut this = self.project();
match this.state.as_mut().project() {
StateProj::Barrier(barrier, inner) => {
let barrier_id = barrier.barrier_id;
// println!("in task barrier {:?}", barrier_id);
ready!(barrier.poll(cx));
let mut inner = ready!(Future::poll(inner.as_mut(), cx));
// println!("past barrier {:?}", barrier_id);
let mut inner: InnerDistIterForEachHandle =
ready!(Future::poll(inner.as_mut(), cx));

match Pin::new(&mut inner).poll(cx) {
Poll::Ready(()) => Poll::Ready(()),
Poll::Ready(()) => {
// println!("past reqs barrier_id {:?}", barrier_id);
Poll::Ready(())
}
Poll::Pending => {
*this.state = State::Reqs(inner);
// println!(
// "reqs remaining {:?} barrier_id {:?}",
// inner.reqs.len(),
// barrier_id
// );
*this.state = State::Reqs(inner, barrier_id);
Poll::Pending
}
}
}
StateProj::Reqs(inner) => {
ready!(inner.poll(cx));
Poll::Ready(())
StateProj::Reqs(inner, barrier_id) => {
// println!(
// "reqs remaining {:?} barrier_id {:?}",
// inner.reqs.len(),
// barrier_id
// );
match inner.poll(cx) {
Poll::Ready(()) => {
// println!("past reqs barrier_id {:?}", barrier_id);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
}
}
Expand All @@ -271,7 +302,7 @@ impl LamellarRequest for DistIterForEachHandle {
barrier.blocking_wait();
self.team.block_on(reqs).blocking_wait();
}
State::Reqs(inner) => {
State::Reqs(inner, _) => {
inner.blocking_wait();
}
}
Expand All @@ -285,15 +316,15 @@ impl LamellarRequest for DistIterForEachHandle {
waker.wake_by_ref();
false
}
State::Reqs(inner) => inner.ready_or_set_waker(waker),
State::Reqs(inner, _) => inner.ready_or_set_waker(waker),
}
}
fn val(&self) -> Self::Output {
match &self.state {
State::Barrier(_barrier, _reqs) => {
unreachable!("should never be in barrier state when val is called");
}
State::Reqs(inner) => inner.val(),
State::Reqs(inner, _) => inner.val(),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/array/local_lock_atomic/iteration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl<T: Dist> LamellarArrayMutIterators<T> for LocalLockArray<T> {
fn dist_iter_mut(&self) -> Self::DistIter {
let lock: LocalRwDarc<()> = self.lock.clone();
let lock = Arc::new(self.array.block_on(async move { lock.write().await }));
self.barrier();
// self.barrier();
// println!("dist_iter thread {:?} got lock",std::thread::current().id());
LocalLockDistIterMut {
data: self.clone(),
Expand Down
6 changes: 6 additions & 0 deletions src/array/unsafe/iteration/distributed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@ macro_rules! consumer_impl {
self.data.team.world_counters.add_send_req(1);
self.data.task_group.counters.add_send_req(1);

// self.data.team.scheduler.print_status();
let barrier = self.barrier_handle();
// let barrier_id = barrier.barrier_id;
// println!("barrier_id {:?} creating dist iter handle",barrier_id);
let inner = self.clone();
let reqs_future = Box::pin(async move{

// println!("barrier id {:?} entering dist iter sched {:?} {:?} {:?}",barrier_id, inner.data.team.team_counters.outstanding_reqs.load(Ordering::SeqCst), inner.data.team.world_counters.outstanding_reqs.load(Ordering::SeqCst), inner.data.task_group.counters.outstanding_reqs.load(Ordering::SeqCst));
let reqs = match sched {
Schedule::Static => inner.sched_static(am),
Schedule::Dynamic => inner.sched_dynamic(am),
Expand All @@ -68,6 +73,7 @@ macro_rules! consumer_impl {
inner.data.team.team_counters.outstanding_reqs.fetch_sub(1,Ordering::SeqCst);
inner.data.team.world_counters.outstanding_reqs.fetch_sub(1,Ordering::SeqCst);
inner.data.task_group.counters.outstanding_reqs.fetch_sub(1,Ordering::SeqCst);
// println!("barrier id {:?} done with dist iter sched {:?} {:?} {:?}",barrier_id,inner.data.team.team_counters.outstanding_reqs.load(Ordering::SeqCst), inner.data.team.world_counters.outstanding_reqs.load(Ordering::SeqCst), inner.data.task_group.counters.outstanding_reqs.load(Ordering::SeqCst));
reqs
});
$return_type::new(barrier,reqs_future,self)
Expand Down
Loading

0 comments on commit acd2bd3

Please sign in to comment.