Skip to content

Commit

Permalink
feat: add Access::try_get_many_mut
Browse files Browse the repository at this point in the history
  • Loading branch information
SOF3 committed Dec 10, 2023
1 parent eceaed3 commit b2d0aee
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@
#![feature(never_type)]
#![feature(sync_unsafe_cell)]
#![feature(slice_take)]
#![feature(get_many_mut)]
#![feature(array_try_from_fn, array_try_map)]

/// Internal re-exports used in macros.
#[doc(hidden)]
Expand Down
15 changes: 15 additions & 0 deletions src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ pub trait Partition<'t>: Access + Send + Sync + Sized + 't {

/// Same as [`get_mut`](Access::get_mut), but returns a reference with lifetime `'t`.
fn into_mut(self, entity: Self::RawEntity) -> Option<&'t mut Self::Comp>;

/// Same as [`get_many_mut`](Access::get_many_mut), but returns a reference with lifetime `'t`.
fn into_many_mut<const N: usize>(
self,
entities: [Self::RawEntity; N],
) -> Option<[&'t mut Self::Comp; N]>;
}

/// Mutable access functions for a storage, generalizing [`Storage`] and [`Partition`].
Expand All @@ -118,6 +124,15 @@ pub trait Access {
/// Gets a mutable reference to the component for a specific entity if it is present.
fn get_mut(&mut self, entity: Self::RawEntity) -> Option<&mut Self::Comp>;

/// Gets mutable references to the components for specific entities if they are present.
///
/// Returns `None` if any entity is uninitialized
/// or if any entity appeared in `entities` more than once.
fn get_many_mut<const N: usize>(
&mut self,
entities: [Self::RawEntity; N],
) -> Option<[&mut Self::Comp; N]>;

/// Return value of [`iter_mut`](Self::iter_mut).
type IterMut<'u>: Iterator<Item = (Self::RawEntity, &'u mut Self::Comp)> + 'u
where
Expand Down
72 changes: 70 additions & 2 deletions src/storage/tree.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::cell::SyncUnsafeCell;
use std::collections::BTreeMap;
use std::slice;
use std::ptr::NonNull;
use std::{array, slice};

use super::{Access, ChunkMut, ChunkRef, Partition, Storage};
use crate::entity;
use crate::{entity, util};

/// A storage based on [`BTreeMap`].
pub struct Tree<RawT: entity::Raw, C> {
Expand All @@ -26,6 +27,36 @@ impl<RawT: entity::Raw, C: Send + Sync + 'static> Access for Tree<RawT, C> {
self.data.get_mut(&id).map(|cell| cell.get_mut())
}

fn get_many_mut<const N: usize>(
&mut self,
entities: [Self::RawEntity; N],
) -> Option<[&mut Self::Comp; N]> {
let ptrs = entities.map(|entity| {
let datum = self.data.get(&entity)?;
let ptr = datum.get();
NonNull::new(ptr)
});

if !util::is_all_distinct_quadtime(&ptrs) {
return None;
}

if ptrs.iter().any(|ptr| ptr.is_none()) {
return None;
}

Some(ptrs.map(|ptr| {
let mut ptr = ptr.expect("checked all are not none");

unsafe {
// All pointers originated from a `&mut self`, so all possible aliases are in locals.
// We have checked that all `ptrs` are distinct,
// and since they come from UnsafeCell, they cannot overlap.
ptr.as_mut()
}
}))
}

type IterMut<'t> = impl Iterator<Item = (Self::RawEntity, &'t mut Self::Comp)> + 't;
fn iter_mut(&mut self) -> Self::IterMut<'_> {
Box::new(self.data.iter_mut().map(|(&entity, cell)| (entity, cell.get_mut())))
Expand Down Expand Up @@ -111,6 +142,13 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Access for StoragePartitio
}
}

fn get_many_mut<const N: usize>(
&mut self,
entities: [RawT; N],
) -> Option<[&mut Self::Comp; N]> {
self.by_ref().into_many_mut(entities)
}

type IterMut<'u> = impl Iterator<Item = (Self::RawEntity, &'u mut Self::Comp)> + 'u where Self: 'u;
fn iter_mut(&mut self) -> Self::IterMut<'_> { self.by_ref().into_iter_mut() }
}
Expand Down Expand Up @@ -158,6 +196,36 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Partition<'t>
}
}

fn into_many_mut<const N: usize>(
self,
entities: [Self::RawEntity; N],
) -> Option<[&'t mut Self::Comp; N]> {
for entity in entities {
self.assert_bounds(entity);
}

let ptrs = entities.map(|entity| {
let datum = self.data.get(&entity)?;
let ptr = datum.get();
NonNull::new(ptr)
});

if !util::is_all_distinct_quadtime(&ptrs) {
return None;
}

array::try_from_fn(|i| {
let mut ptr = ptrs[i]?;

unsafe {
// All pointers originated from a `&mut self`, so all possible aliases are in locals.
// We have checked that all `ptrs` are distinct,
// and since they come from UnsafeCell, they cannot overlap.
Some(ptr.as_mut())
}
})
}

fn split_out(&mut self, entity: RawT) -> Self {
self.assert_bounds(entity);

Expand Down
45 changes: 45 additions & 0 deletions src/storage/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,24 @@ impl<RawT: entity::Raw, C: Send + Sync + 'static> Access for VecStorage<RawT, C>
}
}

fn get_many_mut<const N: usize>(
&mut self,
entities: [RawT; N],
) -> Option<[&mut Self::Comp; N]> {
let indices = entities.map(|id| id.to_primitive());

if !indices.iter().all(|&index| self.bit(index)) {
return None;
}

let values = self.data.get_many_mut(indices).ok()?;

Some(values.map(|value| {
// Safety: values correspond to indices checked above.
unsafe { value.assume_init_mut() }
}))
}

type IterMut<'t> = impl Iterator<Item = (RawT, &'t mut C)> + 't;
fn iter_mut(&mut self) -> Self::IterMut<'_> { iter_mut(0, &self.bits, &mut self.data) }
}
Expand Down Expand Up @@ -186,6 +204,13 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Access for StoragePartitio

fn get_mut(&mut self, entity: RawT) -> Option<&mut C> { self.by_ref().into_mut(entity) }

fn get_many_mut<const N: usize>(
&mut self,
entities: [RawT; N],
) -> Option<[&mut Self::Comp; N]> {
self.by_ref().into_many_mut(entities)
}

type IterMut<'u> = impl Iterator<Item = (RawT, &'u mut C)> + 'u where Self: 'u;
fn iter_mut(&mut self) -> Self::IterMut<'_> { self.by_ref().into_iter_mut() }
}
Expand Down Expand Up @@ -220,6 +245,26 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Partition<'t>
}
}

fn into_many_mut<const N: usize>(
self,
entities: [Self::RawEntity; N],
) -> Option<[&'t mut Self::Comp; N]> {
let indices: [usize; N] =
entities.try_map(|entity| match entity.to_primitive().checked_sub(self.offset) {
Some(index) => match self.bits.get(index) {
Some(bit) if *bit => Some(index),
_ => None,
},
None => panic!("Entity {entity:?} is not in the partition {:?}..", self.offset),
})?;
let values = self.data.get_many_mut(indices).ok()?;
Some(values.map(move |value| {
// Safety: all indices have been checked to be initialized
// before getting mapped into `indices`
unsafe { value.assume_init_mut() }
}))
}

fn split_out(&mut self, entity: RawT) -> Self {
let index =
entity.to_primitive().checked_sub(self.offset).expect("parameter out of bounds");
Expand Down
11 changes: 11 additions & 0 deletions src/system/access/single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,17 @@ where
self.storage.get_mut(entity.id())
}

/// Returns mutable references to the components for the specified entities.
///
/// Returns `None` if any component is not present in the entity
/// or if the same entity is passed multiple times.
pub fn try_get_many_mut<const N: usize>(
&mut self,
entities: [impl entity::Ref<Archetype = A>; N],
) -> Option<[&mut C; N]> {
self.storage.get_many_mut(entities.map(|entity| entity.id()))
}

/// Iterates over mutable references to all initialized components in this storage.
pub fn iter_mut<'t>(
&'t mut self,
Expand Down
15 changes: 14 additions & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,22 @@ unsafe impl UnsafeEqOrd for usize {}
/// Transforms a value behind a mutable reference with a function that moves it.
///
/// The placeholder value will be left at the position of `ref_` if the transform function panics.
pub fn transform_mut<T, R>(ref_: &mut T, placeholder: T, transform: impl FnOnce(T) -> (T, R)) -> R {
pub(crate) fn transform_mut<T, R>(
ref_: &mut T,
placeholder: T,
transform: impl FnOnce(T) -> (T, R),
) -> R {
let old = mem::replace(ref_, placeholder);
let (new, ret) = transform(old);
*ref_ = new;
ret
}

pub(crate) fn is_all_distinct_quadtime<T: PartialEq>(slice: &[T]) -> bool {
for (i, item) in slice.iter().enumerate() {
if !slice[(i + 1)..].iter().all(|other| item == other) {
return false;
}
}
true
}

0 comments on commit b2d0aee

Please sign in to comment.