Skip to content

Commit

Permalink
Merge pull request #37 from elftausend/lazy-graph-comb
Browse files Browse the repository at this point in the history
Lazy graph comb
  • Loading branch information
elftausend authored Jan 31, 2024
2 parents efa548a + 42505e6 commit 2934c21
Show file tree
Hide file tree
Showing 35 changed files with 563 additions and 224 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ jobs:
run: cargo test --verbose --no-default-features --features cpu,lazy
- name: Run tests,cached,fork
run: cargo test --verbose --no-default-features --features cpu,cached,fork
- name: Run graph
run: cargo test --verbose --no-default-features --features cpu,graph
- name: Run cached
run: cargo test --verbose --no-default-features --features cpu,cached

test-cached:

Expand Down
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ custos-macro = {git = "https://github.com/elftausend/custos-macro", optional=tru
libm = { version="0.2.6", optional = true }

ash = { version = "0.37", optional = true }
naga = { version = "0.14", features = ["wgsl-in", "spv-out"], optional = true }
naga = { version = "0.19", features = ["wgsl-in", "spv-out"], optional = true }

half = {version = "2.3", default-features = false, optional = true}

Expand All @@ -49,8 +49,8 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
# default = ["cpu", "autograd", "macro"]
default = ["cpu", "lazy", "static-api", "graph", "autograd", "fork", "serde", "json"]
default = ["cpu", "graph", "cached"]
# default = ["cpu", "lazy", "static-api", "graph", "autograd", "fork", "serde", "json"]

cpu = []
opencl = ["dep:min-cl", "cpu", "cached"]
Expand Down
2 changes: 1 addition & 1 deletion android-nnapi-ci/nnapitestlib/src/using_custos_nnapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use nnapi::{nnapi_sys::OperationCode, Operand};
use crate::log;

pub fn run_custos_model() -> custos::Result<String> {
let device = NnapiDevice::<i32, Lazy<Base>>::new()?;
let device = NnapiDevice::<i32, Base>::new()?;

let lhs = Buffer::with(&device, [1, 2, 3, 4, 5, 6, 7, 8, 9, 11]);
let rhs = Buffer::with(&device, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
Expand Down
3 changes: 3 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ fn has_device_unified_mem() -> bool {
.unified_mem
}

#[cfg(feature = "cuda")]
use std::path::{Path, PathBuf};

// https://github.com/coreylowman/cudarc/blob/main/build.rs
#[cfg(feature = "cuda")]
fn link_cuda() {
Expand Down
12 changes: 12 additions & 0 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,18 @@ impl<'a, Mods: OnDropBuffer, T> Buffer<'a, T, crate::CUDA<Mods>> {
}
}

// `ShallowCopy` for `Buffer`, available when the device's data type for `T`/`S`
// is itself `ShallowCopy`.
//
// NOTE(review): the body calls `self.shallow()`. This presumably resolves to an
// inherent `Buffer::shallow` defined elsewhere in this file (inherent methods are
// preferred over trait methods during method lookup). If no such inherent method
// exists, the call would resolve to this trait method and recurse forever - confirm.
impl<'a, T, D, S> ShallowCopy for Buffer<'a, T, D, S>
where
D: Device,
D::Data<T, S>: ShallowCopy,
S: Shape,
{
/// Returns the result of `self.shallow()`.
///
/// # Safety
/// NOTE(review): the exact contract is the one of the `shallow` method this
/// delegates to; it is not visible in this hunk - confirm before relying on it.
#[inline]
unsafe fn shallow(&self) -> Self {
self.shallow()
}
}

impl<'a, T, D, S> Clone for Buffer<'a, T, D, S>
where
T: Clone,
Expand Down
1 change: 1 addition & 0 deletions src/buffer/impl_autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ where
pub fn grad(&self) -> &'a Self
where
D: MayTapeActions + Alloc<T>,
// D::Data<T, S>: crate::ShallowCopy,
{
unsafe {
self.device()
Expand Down
9 changes: 8 additions & 1 deletion src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::{

use crate::{
flag::AllocFlag, Alloc, Buffer, CloneBuf, CommonPtrs, Device, HasId, OnDropBuffer, PtrType,
WrappedData,
ShallowCopy, WrappedData,
};

#[derive(Debug, Default)]
Expand Down Expand Up @@ -57,6 +57,13 @@ impl<T> From<T> for Num<T> {
}
}

// `Num<T>` only implements `ShallowCopy` to satisfy trait bounds: the method
// body is a stub and must not be reached at runtime.
impl<T> ShallowCopy for Num<T> {
/// Not supported for `Num`.
///
/// # Panics
/// Always panics via `unimplemented!()`.
#[inline]
unsafe fn shallow(&self) -> Self {
unimplemented!()
}
}

impl Device for () {
type Data<T, S: crate::Shape> = Self::Base<T, S>;
type Base<T, S: crate::Shape> = Num<T>;
Expand Down
6 changes: 4 additions & 2 deletions src/cache/borrow_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ impl Display for CachingError {

impl std::error::Error for CachingError {}

#[derive(Debug, Default)]
pub(crate) type AnyBuffers = HashMap<UniqueId, Box<dyn Any>, BuildHasherDefault<NoHasher>>;

#[derive(Default)]
pub struct BorrowCache {
pub(crate) cache: HashMap<UniqueId, Box<dyn Any>, BuildHasherDefault<NoHasher>>,
pub(crate) cache: AnyBuffers,
}

// TODO: make BorrowCache unusable without a device (=> static get methods with D: CacheReturn)
Expand Down
12 changes: 7 additions & 5 deletions src/cache/owned_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ impl Cache {
}
}

/// # Safety
/// Lifetime of data must be at least as long as the lifetime of the cache (usually the device).
#[track_caller]
#[inline]
pub fn get<T, S, D>(&mut self, device: &D, len: usize, callback: fn()) -> D::Base<T, S>
pub unsafe fn get<T, S, D>(&mut self, device: &D, len: usize, callback: fn()) -> D::Base<T, S>
where
D: Alloc<T> + 'static,
D::Base<T, S>: ShallowCopy + 'static,
Expand Down Expand Up @@ -80,7 +82,7 @@ mod tests {
assert_eq!(cache.nodes.len(), 1);
assert_eq!(out.len, 10);

let out1 = cache.get::<f32, (), _>(&device, 10, || ());
let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
assert_ne!(out.ptr, out1.ptr);
}

Expand All @@ -91,10 +93,10 @@ mod tests {

assert_eq!(cache.nodes.len(), 0);

let out1 = cache.get::<f32, (), _>(&device, 10, || ());
let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
assert_eq!(cache.nodes.len(), 1);

let out2 = cache.get::<f32, (), _>(&device, 10, || ());
let out2 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };

assert_ne!(out1.ptr, out2.ptr);
assert_eq!(cache.nodes.len(), 2);
Expand All @@ -108,7 +110,7 @@ mod tests {

let mut prev = None;
for _ in 0..1000 {
let out3 = cache.get::<f32, (), _>(&device, 10, || ());
let out3 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
if prev.is_none() {
prev = Some(out3.ptr);
}
Expand Down
6 changes: 3 additions & 3 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ macro_rules! impl_device_traits {
$crate::pass_down_grad_fn!($device);
$crate::pass_down_tape_actions!($device);

$crate::pass_down_replace_buf!($device);
$crate::pass_down_replace_buf_dev!($device);
};
}

Expand All @@ -128,9 +128,9 @@ macro_rules! impl_retriever {
len: usize,
parents: impl $crate::Parents<NUM_PARENTS>,
) -> Buffer<T, Self, S> {
let data = self
let data = unsafe { self
.modules
.retrieve::<NUM_PARENTS>(self, len, parents);
.retrieve::<NUM_PARENTS>(self, len, parents) };
let buf = Buffer {
data,
device: Some(self),
Expand Down
4 changes: 2 additions & 2 deletions src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,12 @@ unsafe impl<Mods: OnDropBuffer> IsShapeIndep for CPU<Mods> {}

#[cfg(test)]
mod tests {
use crate::{Base, CPU};

#[cfg(feature = "fork")]
#[cfg(feature = "cached")]
#[test]
fn test_add_layer_cpu() {
use crate::{Base, CPU};

let cpu = CPU::<Base>::new();
let cpu = cpu.add_layer::<crate::Cached<()>>();
let cpu = cpu.add_layer::<crate::Fork<()>>();
Expand Down
9 changes: 5 additions & 4 deletions src/devices/nnapi/nnapi_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl<U, Mods: Retrieve<Self, T, S>, T: AsOperandCode, S: Shape> Retriever<T, S>
len: usize,
parents: impl crate::Parents<NUM_PARENTS>,
) -> Buffer<T, Self, S> {
let data = self.modules.retrieve::<NUM_PARENTS>(self, len, parents);
let data = unsafe { self.modules.retrieve::<NUM_PARENTS>(self, len, parents) };
let buf = Buffer {
data,
device: Some(self),
Expand Down Expand Up @@ -158,10 +158,11 @@ impl<U, T: AsOperandCode, Mods: OnDropBuffer> Alloc<T> for NnapiDevice<U, Mods>

impl<T, SimpleMods> NnapiDevice<T, SimpleMods> {
/// Creates a new [`NnapiDevice`].
pub fn new<NewMods>() -> crate::Result<NnapiDevice<T, Lazy<NewMods>>>
pub fn new<NewMods>() -> crate::Result<NnapiDevice<T, NewMods>>
// TODO keep in mind that lazy module requirement would make sense here
where
SimpleMods: Module<NnapiDevice<T>, Module = Lazy<NewMods>>,
Lazy<NewMods>: Setup<NnapiDevice<T, Lazy<NewMods>>>,
SimpleMods: Module<NnapiDevice<T>, Module = NewMods>,
NewMods: Setup<NnapiDevice<T, NewMods>>,
{
let mut device = NnapiDevice {
modules: SimpleMods::new(),
Expand Down
38 changes: 27 additions & 11 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ use core::{fmt::Debug, ops::RangeBounds};

use crate::{HasId, Parents, Shape, UniqueId, UpdateArgs, CPU};

#[cfg(feature = "graph")]
use crate::TranslatedCacheTrace;

#[cfg(feature = "cached")]
use crate::{Base, CachedModule};

Expand All @@ -22,7 +19,7 @@ pub trait Feature: OnDropBuffer {}
pub trait Retrieve<D, T, S: Shape = ()>: OnDropBuffer {
// "generator"
#[track_caller]
fn retrieve<const NUM_PARENTS: usize>(
unsafe fn retrieve<const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
Expand Down Expand Up @@ -188,7 +185,7 @@ pub trait ReplaceBuf<T, D: Device, S: Shape>: OnDropBuffer {
}

#[macro_export]
macro_rules! pass_down_replace_buf {
macro_rules! pass_down_replace_buf_dev {
($device:ident) => {
impl<T, S: Shape, Mods: $crate::ReplaceBuf<T, Self, S>> $crate::ReplaceBuf<T, Self, S>
for $device<Mods>
Expand All @@ -204,6 +201,23 @@ macro_rules! pass_down_replace_buf {
};
}

/// Implements `ReplaceBuf` for a wrapper module by forwarding to its inner modules.
///
/// `$module` must name a type with exactly one generic parameter (`Mods`) and a
/// field called `modules` holding that `Mods` value. The generated
/// `replace_buf` simply calls `self.modules.replace_buf(buffer)`.
///
/// All crate items (`ReplaceBuf`, `Shape`, `Device`, `Buffer`) are referenced
/// through `$crate::`, so the macro expands correctly even when the invocation
/// site has not imported them.
#[macro_export]
macro_rules! pass_down_replace_buf_module {
    ($module:ident) => {
        impl<T, S, Mods, D> $crate::ReplaceBuf<T, D, S> for $module<Mods>
        where
            S: $crate::Shape,
            Mods: $crate::ReplaceBuf<T, D, S>,
            D: $crate::Device,
        {
            #[inline]
            fn replace_buf<'a, 'c>(
                &'c self,
                buffer: &'c $crate::Buffer<'a, T, D, S>,
            ) -> &'c $crate::Buffer<'a, T, D, S> {
                // Pure pass-through: the wrapped modules decide whether to replace.
                self.modules.replace_buf(buffer)
            }
        }
    };
}

pub trait AddOperation {
#[track_caller]
fn add_op<Args: Parents<N> + UpdateArgs, const N: usize>(
Expand Down Expand Up @@ -375,21 +389,23 @@ pub trait UseGpuOrCpu {

#[cfg(feature = "graph")]
pub trait OptimizeMemGraph {
fn optimize_mem_graph(
fn optimize_mem_graph<D: 'static>(
&self,
cache_traces: Option<&[TranslatedCacheTrace]>,
device: &D,
graph_translator: Option<&crate::modules::GraphTranslator>,
) -> crate::Result<()>;
}

/// Implements `OptimizeMemGraph` for a wrapper type by forwarding to its inner modules.
///
/// `$to_impl` must name a type with exactly one generic parameter (`Mods`) and a
/// field called `modules` holding that `Mods` value. The generated
/// `optimize_mem_graph` passes `device` and `graph_translator` on unchanged and
/// returns whatever the inner call returns.
///
/// Every crate item is referenced through `$crate::` (including the return
/// type), so the expansion does not depend on which crate invokes the macro.
#[macro_export]
macro_rules! pass_down_optimize_mem_graph {
    ($to_impl:ident) => {
        impl<Mods: $crate::OptimizeMemGraph> $crate::OptimizeMemGraph for $to_impl<Mods> {
            fn optimize_mem_graph<D: 'static>(
                &self,
                device: &D,
                graph_translator: Option<&$crate::modules::GraphTranslator>,
            ) -> $crate::Result<()> {
                self.modules.optimize_mem_graph(device, graph_translator)
            }
        }
    };
}
Expand Down
38 changes: 14 additions & 24 deletions src/id.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use core::{
any::Any,
ops::{Deref, DerefMut},
};
use core::ops::{Deref, DerefMut};

use crate::{Buffer, Device, DeviceError, Shape, UniqueId, UpdateArg};

Expand Down Expand Up @@ -101,34 +98,31 @@ impl<T: Into<NoId<T>>> AsNoId for T {
impl<T> UpdateArg for NoId<T> {
#[inline]
#[cfg(not(feature = "no-std"))]
fn update_arg(
fn update_arg<B>(
_to_update: &mut Self,
_id: Option<UniqueId>,
_buffers: &mut std::collections::HashMap<
crate::UniqueId,
Box<dyn core::any::Any>,
core::hash::BuildHasherDefault<crate::NoHasher>,
>,
_buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
Ok(())
}
}

impl<'a, T: 'static, D: Device + 'static, S: Shape + 'static> UpdateArg for &Buffer<'a, T, D, S> {
#[cfg(not(feature = "no-std"))]
fn update_arg(
fn update_arg<B: crate::AsAny>(
to_update: &mut Self,
id: Option<UniqueId>,
buffers: &mut std::collections::HashMap<
crate::UniqueId,
Box<dyn core::any::Any>,
core::hash::BuildHasherDefault<crate::NoHasher>,
>,
buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
// use crate::ShallowCopyable;

let buf = buffers
.get(&id.unwrap())
.ok_or(DeviceError::InvalidLazyBuf)?;
*to_update = unsafe { &*(&**buf as *const dyn Any as *const Buffer<T, D, S>) };
// let any = buf.as_any();
// let _to_update = buf.as_any().downcast_ref::<Buffer<T, D, S>>().unwrap();
// todo!();
*to_update = unsafe { &*(buf.as_any() as *const Buffer<T, D, S>) };
// *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap();
Ok(())
}
Expand All @@ -138,19 +132,15 @@ impl<'a, T: 'static, D: Device + 'static, S: Shape + 'static> UpdateArg
for &mut Buffer<'a, T, D, S>
{
#[cfg(not(feature = "no-std"))]
fn update_arg(
fn update_arg<B: crate::AsAny>(
to_update: &mut Self,
id: Option<UniqueId>,
buffers: &mut std::collections::HashMap<
crate::UniqueId,
Box<dyn core::any::Any>,
core::hash::BuildHasherDefault<crate::NoHasher>,
>,
buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
let buf = buffers
.get_mut(&id.unwrap())
.ok_or(DeviceError::InvalidLazyBuf)?;
*to_update = unsafe { &mut *(&mut **buf as *mut dyn Any as *mut Buffer<T, D, S>) };
*to_update = unsafe { &mut *(buf.as_any_mut() as *mut Buffer<T, D, S>) };
Ok(())
// *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap();
}
Expand Down
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ impl_buffer_hook_traits!(CPU);
#[cfg(not(feature = "cpu"))]
crate::impl_wrapped_data!(CPU);

/// Crate-internal map from a [`UniqueId`] to a stored value of type `B`.
///
/// Uses `BuildHasherDefault<NoHasher>` as the hasher and is only compiled when
/// the `no-std` feature is disabled, since it is backed by `std::collections::HashMap`.
#[cfg(not(feature = "no-std"))]
pub(crate) type Buffers<B> =
std::collections::HashMap<UniqueId, B, std::hash::BuildHasherDefault<NoHasher>>;

pub mod prelude {
//! Typical imports for using custos.
Expand Down
Loading

0 comments on commit 2934c21

Please sign in to comment.