Skip to content

Commit

Permalink
Add assert eq for check for mismatching buf types in lazy replace buf…
Browse files Browse the repository at this point in the history
…, add ignoring of mismatching buffer types in cache + graph optimization step
  • Loading branch information
elftausend committed Feb 9, 2024
1 parent 3583d48 commit 9f84d8b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 6 deletions.
12 changes: 12 additions & 0 deletions src/modules/cached.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use core::{
any::Any,
cell::{Cell, RefCell},
marker::PhantomData,
ops::Deref,
};

use crate::{
Expand Down Expand Up @@ -265,6 +267,16 @@ impl<Mods: OptimizeMemGraph, SD: Device> OptimizeMemGraph for CachedModule<Mods,
.clone();

for to_replace in &cache_trace.use_cache_idxs {
if cache
.nodes
.get(&(*to_replace as UniqueId))
.unwrap()
.deref()
.type_id()
!= used_to_replace.deref().type_id()
{
continue;
}
cache
.nodes
.insert(*to_replace as UniqueId, used_to_replace.clone());
Expand Down
84 changes: 79 additions & 5 deletions src/modules/graph/opt_graph/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ impl OptGraph {
/// let mut graph = OptGraph::default();
/// let a = graph.add_leaf(10);
/// let b = graph.add_leaf(10);
///
///
/// let c = graph.add_node(10, vec![a, b]);
///
///
/// let d = graph.add_node(10, vec![c, c]);
///
///
/// let _u = graph.add_node(10, vec![d, a]);
///
///
/// let _e = graph.add_node(10, vec![d, b]);
///
///
/// assert!(graph.is_path_optimizable(graph.node(c)));
/// assert!(!graph.is_path_optimizable(graph.node(d)));
/// ```
Expand Down Expand Up @@ -597,6 +597,80 @@ mod tests {
}
}

#[cfg(feature = "cpu")]
#[test]
fn test_mismatched_optimized_types_cached() {
use crate::{
Base, Buffer, Cached, Cursor, Device, Graph, HasId, OptimizeMemGraph, Retriever, CPU,
};

let device = CPU::<Graph<Cached<Base>>>::new();

// idx: 0, deps: []
let x: Buffer<f32, _> = device.buffer([1.; 1000]);
// idx: 1, deps: []
let b: Buffer<f32, _> = device.buffer([1.1; 1000]);

for i in device.range(2) {
// idx: 2, deps: [0, 0]
let squared: Buffer<f32, _> = device.retrieve::<2>(1000, (&x, &x));
// idx: 3, deps: [1, 0]
let add: Buffer<f32, _> = device.retrieve::<2>(1000, (&b, &x));
// idx: 4, deps: [3, 1]
let mul_b: Buffer<u8, _> = device.retrieve::<2>(1000, (&add, &b));
// idx: 5, deps: [2, 0]
let mul: Buffer<f32, _> = device.retrieve::<2>(1000, (&squared, &x));
// idx: 6, deps: [5, 4]
let out: Buffer<f32, _> = device.retrieve::<2>(1000, (&mul, &mul_b));

if i == 0 {
assert_ne!(squared.id(), mul.id());
}

if i == 1 {
assert_eq!(squared.id(), mul.id());
assert_eq!(squared.id(), out.id());

break;
}
device.optimize_mem_graph(&device, None).unwrap();
}
}

#[cfg(feature = "cpu")]
#[should_panic]
#[test]
fn test_mismatched_optimized_types_lazy() {
use crate::{
Base, Buffer, Cursor, Device, Graph, HasId, Lazy, OptimizeMemGraph, Retriever, Run, CPU,
};

let device = CPU::<Graph<Lazy<Base>>>::new();

// idx: 0, deps: []
let x: Buffer<f32, _> = device.buffer([1.; 1000]);
// idx: 1, deps: []
let b: Buffer<f32, _> = device.buffer([1.1; 1000]);
// idx: 2, deps: [0, 0]
let squared: Buffer<f32, _> = device.retrieve::<2>(1000, (&x, &x));
// idx: 3, deps: [1, 0]
let add: Buffer<f32, _> = device.retrieve::<2>(1000, (&b, &x));
// idx: 4, deps: [3, 1]
let mul_b: Buffer<u8, _> = device.retrieve::<2>(1000, (&add, &b));
// idx: 5, deps: [2, 0]
let mul: Buffer<f32, _> = device.retrieve::<2>(1000, (&squared, &x));
// idx: 6, deps: [5, 4]
let out: Buffer<f32, _> = device.retrieve::<2>(1000, (&mul, &mul_b));

device.optimize_mem_graph(&device, None).unwrap();
let _err = unsafe { device.run() };

assert_eq!(squared.replace().id(), mul.replace().id());
assert_eq!(squared.replace().id(), out.replace().id());

assert_eq!(add.replace().id(), mul_b.replace().id());
}

#[test]
fn test_no_cache_trace_in_graph() {
let mut graph = OptGraph::default();
Expand Down
9 changes: 8 additions & 1 deletion src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
use crate::DeviceError;

use core::{
any::Any,
any::{Any, TypeId},
cell::{Cell, RefCell},
fmt::Debug,
};
Expand Down Expand Up @@ -156,6 +156,8 @@ impl<Mods> Lazy<Mods> {
alloc_fn(&mut self.buffers.borrow_mut(), *id, device);
let buf = self.buffers.borrow().get(&id.id).unwrap().shallow_copy();

// TODO: add type check - lower assert_eq to debug in lazy replace buf

for use_id_as_well in &cache_trace.use_cache_idxs {
let use_id_as_well_id = graph_trans
.idx_to_buf_id
Expand Down Expand Up @@ -329,6 +331,11 @@ impl<T: 'static, D: Device + 'static, S: Shape, Mods: OnDropBuffer> ReplaceBuf<T
match self.buffers.borrow().get(&buffer.id()) {
Some(buf) => {
let buf = &**buf;
assert_eq!(
buf.as_any().type_id(),
TypeId::of::<Buffer<T, D, S>>(),
"Type data does not match! e.g. optimized graph with different types"
);
unsafe { &*(buf as *const _ as *const Buffer<T, D, S>) }
}
None => buffer,
Expand Down
6 changes: 6 additions & 0 deletions src/modules/lazy/generic_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@ use crate::{AsAny, ShallowCopy};

pub trait BoxedShallowCopy: 'static {
fn shallow_copy(&self) -> Box<dyn BoxedShallowCopy>;
fn as_any(&self) -> &dyn Any;
}

impl<T: ShallowCopy + 'static> BoxedShallowCopy for T {
#[inline]
fn shallow_copy(&self) -> Box<dyn BoxedShallowCopy> {
Box::new(unsafe { self.shallow() })
}

#[inline]
fn as_any(&self) -> &dyn Any {
self
}
}

impl AsAny for Box<dyn BoxedShallowCopy> {
Expand Down

0 comments on commit 9f84d8b

Please sign in to comment.