Skip to content

Commit

Permalink
Add continue_deallocation to AllocFlag
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Dec 14, 2023
1 parent 4427112 commit 3be19d5
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/devices/cpu/cpu_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl<T> Default for CPUPtr<T> {

impl<T> Drop for CPUPtr<T> {
fn drop(&mut self) {
if !matches!(self.flag, AllocFlag::None | AllocFlag::BorrowedCache) {
if !self.flag.continue_deallocation() {
return;
}

Expand Down
2 changes: 1 addition & 1 deletion src/devices/cuda/cuda_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl<T> Default for CUDAPtr<T> {

impl<T> Drop for CUDAPtr<T> {
fn drop(&mut self) {
if !matches!(self.flag, AllocFlag::None | AllocFlag::BorrowedCache) {
if !self.flag.continue_deallocation() {
return;
}

Expand Down
2 changes: 1 addition & 1 deletion src/devices/opencl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl<T> DerefMut for CLPtr<T> {

impl<T> Drop for CLPtr<T> {
fn drop(&mut self) {
if !matches!(self.flag, AllocFlag::None | AllocFlag::BorrowedCache) {
if !self.flag.continue_deallocation() {
return;
}

Expand Down
2 changes: 1 addition & 1 deletion src/devices/vulkan/vk_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl<T> VkArray<T> {
impl<T> Drop for VkArray<T> {
#[inline]
fn drop(&mut self) {
if self.flag != AllocFlag::None {
if !self.flag.continue_deallocation() {
return;
}
unsafe {
Expand Down
7 changes: 7 additions & 0 deletions src/flag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,10 @@ impl PartialEq for AllocFlag {
core::mem::discriminant(self) == core::mem::discriminant(other)
}
}

impl AllocFlag {
#[inline]
pub fn continue_deallocation(&self) -> bool {
matches!(self, AllocFlag::None | AllocFlag::BorrowedCache | AllocFlag::Lazy)
}
}
16 changes: 8 additions & 8 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ mod tests {

{
let buf = Buffer::<i32, _>::new(&device, 10);
let out = device.apply_fn(&buf, |x| x.add(3));
assert_eq!(out.read(), &[0; 10]);
let _out = device.apply_fn(&buf, |x| x.add(3));
// assert_eq!(out.replace().read(), &[0; 10]);
}

if DeviceError::InvalidLazyBuf
Expand All @@ -351,9 +351,9 @@ mod tests {
let buf = Buffer::<i32, _>::new(&device, 10);
let out = device.apply_fn(&buf, |x| x.add(3));

assert_eq!(out.read(), &[0; 10]);
// assert_eq!(out.read(), &[0; 10]);
unsafe { device.run().unwrap() };
assert_eq!(out.read(), &[3; 10]);
assert_eq!(out.replace().read(), &[3; 10]);
}

#[test]
Expand All @@ -380,18 +380,18 @@ mod tests {
let buf = Buffer::<i32, _>::new(&device, 10);
let lhs = device.apply_fn(&buf, |x| x.add(3));

assert_eq!(lhs.read(), &[0; 10]);
// assert_eq!(lhs.read(), &[0; 10]);
let rhs = Buffer::<_, _>::from((&device, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]));

assert_eq!(rhs.read(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);

let out = device.add(&lhs, &rhs);
assert_eq!(out.read(), &[0; 10]);
// assert_eq!(out.read(), &[0; 10]);

unsafe { device.run().unwrap() };
assert_eq!(lhs.read(), &[3; 10]);
assert_eq!(lhs.replace().read(), &[3; 10]);

assert_eq!(out.read(), [4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
assert_eq!(out.replace().read(), [4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
}

#[test]
Expand Down
6 changes: 4 additions & 2 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,21 @@ impl<Data: PtrType, T> PtrType for LazyWrapper<Data, T> {
}
}

const MISSING_DATA: &'static str = "This lazy buffer does not contain any data. Try with a buffer.replace() call.";

impl<Data: Deref<Target = [T]>, T> Deref for LazyWrapper<Data, T> {
type Target = [T];

#[inline]
fn deref(&self) -> &Self::Target {
self.data.as_ref().unwrap()
self.data.as_ref().expect(MISSING_DATA)
}
}

impl<Data: DerefMut<Target = [T]>, T> DerefMut for LazyWrapper<Data, T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
self.data.as_mut().unwrap()
self.data.as_mut().expect(MISSING_DATA)
}
}

Expand Down

0 comments on commit 3be19d5

Please sign in to comment.