Skip to content

Commit

Permalink
Merge branch 'main' into lazy-graph-comb
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Jan 29, 2024
2 parents 620bfd7 + efa548a commit 82b9561
Show file tree
Hide file tree
Showing 26 changed files with 72 additions and 87 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }

[features]
# default = ["cpu", "autograd", "macro"]
default = ["cpu", "lazy", "static-api", "opencl", "graph", "autograd", "fork", "stack"]
default = ["cpu", "lazy", "static-api", "graph", "autograd", "fork", "serde", "json"]

cpu = []
opencl = ["dep:min-cl", "cpu", "cached"]
Expand Down
2 changes: 0 additions & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ fn has_device_unified_mem() -> bool {
.unified_mem
}

use std::path::{Path, PathBuf};

// https://github.com/coreylowman/cudarc/blob/main/build.rs
#[cfg(feature = "cuda")]
fn link_cuda() {
Expand Down
18 changes: 8 additions & 10 deletions src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ impl Device for () {
}

#[inline(always)]
fn data_as_wrap<'a, T, S: crate::Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<T, S: crate::Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
data
}

fn data_as_wrap_mut<'a, T, S: crate::Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<T, S: crate::Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
data
}
}
Expand Down Expand Up @@ -125,14 +125,12 @@ impl WrappedData for () {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
wrap
}
}
Expand Down
12 changes: 5 additions & 7 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,10 @@ pub trait Device: OnDropBuffer + Sized {
// FIXME: probably a better way to realize these
fn base_to_data<T, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S>;
fn wrap_to_data<T, S: Shape>(&self, wrap: Self::Wrap<T, Self::Base<T, S>>) -> Self::Data<T, S>;
fn data_as_wrap<'a, T, S: Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>>;
fn data_as_wrap_mut<'a, T, S: Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>>;
fn data_as_wrap<T, S: Shape>(data: &Self::Data<T, S>) -> &Self::Wrap<T, Self::Base<T, S>>;
fn data_as_wrap_mut<T, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>>;

/// Creates a new [`Buffer`] using `A`.
///
Expand Down Expand Up @@ -111,7 +109,7 @@ macro_rules! impl_device_traits {
$crate::impl_wrapped_data!($device);

#[cfg(feature = "graph")]
crate::pass_down_optimize_mem_graph!($device);
$crate::pass_down_optimize_mem_graph!($device);

$crate::pass_down_grad_fn!($device);
$crate::pass_down_tape_actions!($device);
Expand Down
10 changes: 4 additions & 6 deletions src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ impl<Mods: OnDropBuffer> Device for CPU<Mods> {
}

#[inline(always)]
fn data_as_wrap<'a, T, S: Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<T, S: Shape>(data: &Self::Data<T, S>) -> &Self::Wrap<T, Self::Base<T, S>> {
data
}

#[inline(always)]
fn data_as_wrap_mut<'a, T, S: Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<T, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
data
}

Expand Down
1 change: 0 additions & 1 deletion src/devices/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@ mod cpu_ptr;
mod ops;

pub use cpu_ptr::*;
pub use ops::*;
8 changes: 8 additions & 0 deletions src/devices/opencl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ impl<T> HasId for CLPtr<T> {
}
}

impl<T> CLPtr<T> {
pub fn len(&self) -> usize {
self.len
}
}

impl<T> ShallowCopy for CLPtr<T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down Expand Up @@ -116,6 +122,7 @@ impl<T> HostPtr<T> for CLPtr<T> {
}
}

#[cfg(unified_cl)]
impl<T> Deref for CLPtr<T> {
type Target = [T];

Expand All @@ -125,6 +132,7 @@ impl<T> Deref for CLPtr<T> {
}
}

#[cfg(unified_cl)]
impl<T> DerefMut for CLPtr<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
Expand Down
5 changes: 4 additions & 1 deletion src/devices/opencl/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,10 @@ where
Ok(())
}
#[cfg(not(unified_cl))]
try_cl_apply_fn_mut(dev, buf, out, f);
{
try_cl_apply_fn_mut(dev, buf, out, **f)?;
Ok(())
}
})
.unwrap();

Expand Down
4 changes: 2 additions & 2 deletions src/exec_on_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ macro_rules! to_cpu {
macro_rules! to_raw_host {
($cpu:expr, $($t:ident),*) => {
$(
let $t = &unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.data.host_ptr, $t.len()) };
let $t = &unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.base().host_ptr, $t.len()) };
)*
};
}
Expand All @@ -233,7 +233,7 @@ macro_rules! to_raw_host_mut {
($cpu:expr, $($t:ident, $cpu_name:ident),*) => {
$(
#[allow(unused_mut)]
let mut $cpu_name = &mut unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.data.host_ptr, $t.len()) };
let mut $cpu_name = &mut unsafe { $crate::Buffer::<_, _, ()>::from_raw_host_device($cpu, $t.base().host_ptr, $t.len()) };
)*
};
}
Expand Down
2 changes: 1 addition & 1 deletion src/exec_on_cpu/cl_may_unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ macro_rules! cl_cpu_exec_unified {
let cpu = CPU::<Base>::new();
if $device.unified_mem() {

$crate::to_raw_host!($crate::CPU::<$crate::CachedModule<$crate::Base, $crate::CPU>>, $($t),*);
$crate::to_raw_host!(&$device.cpu, $($t),*);

#[cfg(not(feature = "realloc"))]
{
Expand Down
4 changes: 2 additions & 2 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ macro_rules! pass_down_grad_fn {
fn add_grad_fn<Args: $crate::Parents<N> + $crate::UpdateArgs, const N: usize>(
&self,
args: Args,
op: fn(&mut Args) -> crate::Result<()>,
op: fn(&mut Args) -> $crate::Result<()>,
) {
self.modules.add_grad_fn(args, op)
}
Expand Down Expand Up @@ -253,7 +253,7 @@ macro_rules! pass_down_add_operation {
fn add_op<Args: $crate::Parents<N> + $crate::UpdateArgs, const N: usize>(
&self,
args: Args,
operation: fn(&mut Args) -> crate::Result<()>,
operation: fn(&mut Args) -> $crate::Result<()>,
) -> $crate::Result<()> {
self.modules.add_op(args, operation)
}
Expand Down
2 changes: 1 addition & 1 deletion src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub trait OnDropBuffer: WrappedData {
fn on_drop_buffer<T, D: Device, S: Shape>(&self, _device: &D, _buf: &Buffer<T, D, S>) {}
}

pub trait OnNewBuffer<T, D: Device, S: Shape> {
pub trait OnNewBuffer<T, D: Device, S: Shape = ()> {
#[track_caller]
fn on_new_buffer(&self, _device: &D, _new_buf: &Buffer<T, D, S>) {}
}
12 changes: 6 additions & 6 deletions src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<T> UpdateArg for NoId<T> {
#[inline]
#[cfg(not(feature = "no-std"))]
fn update_arg<B>(
&mut self,
_to_update: &mut Self,
_id: Option<UniqueId>,
_buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
Expand All @@ -113,7 +113,7 @@ impl<T> UpdateArg for NoId<T> {
impl<'a, T: 'static, D: Device + 'static, S: Shape + 'static> UpdateArg for &Buffer<'a, T, D, S> {
#[cfg(not(feature = "no-std"))]
fn update_arg<B: crate::AsAny>(
&mut self,
to_update: &mut Self,
id: Option<UniqueId>,
buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
Expand All @@ -123,7 +123,7 @@ impl<'a, T: 'static, D: Device + 'static, S: Shape + 'static> UpdateArg for &Buf
.get(&id.unwrap())
.ok_or(DeviceError::InvalidLazyBuf)?;
// let any = buf.as_any();
*self = unsafe { &*(buf.as_any() as *const dyn Any as *const Buffer<T, D, S>) };
*to_update = unsafe { &*(buf.as_any() as *const dyn Any as *const Buffer<T, D, S>) };
// *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap();
Ok(())
}
Expand All @@ -133,8 +133,8 @@ impl<'a, T: 'static, D: Device + 'static, S: Shape + 'static> UpdateArg
for &mut Buffer<'a, T, D, S>
{
#[cfg(not(feature = "no-std"))]
fn update_arg<B>(
&mut self,
fn update_arg<B: crate::AsAny>(
to_update: &mut Self,
id: Option<UniqueId>,
buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
Expand All @@ -143,7 +143,7 @@ impl<'a, T: 'static, D: Device + 'static, S: Shape + 'static> UpdateArg
let buf = buffers
.get_mut(&id.unwrap())
.ok_or(DeviceError::InvalidLazyBuf)?;
// *self = unsafe { &mut *(&mut **buf as *mut dyn ShallowCopyable as *mut Buffer<T, D, S>) };
// *to_update = unsafe { &mut *(&mut **buf as *mut dyn Any as *mut Buffer<T, D, S>) };
Ok(())
// *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap();
}
Expand Down
6 changes: 2 additions & 4 deletions src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@ impl<Mods: WrappedData> WrappedData for Autograd<Mods> {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
Mods::wrapped_as_base(wrap)
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}
Expand Down
6 changes: 2 additions & 4 deletions src/modules/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ impl WrappedData for Base {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
wrap
}
}
Expand Down
6 changes: 2 additions & 4 deletions src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@ impl<Mods: WrappedData, SD: Device> WrappedData for CachedModule<Mods, SD> {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
Mods::wrapped_as_base(wrap)
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}
Expand Down
6 changes: 2 additions & 4 deletions src/modules/fork.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@ impl<Mods: WrappedData> WrappedData for Fork<Mods> {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
Mods::wrapped_as_base(wrap)
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}
Expand Down
6 changes: 2 additions & 4 deletions src/modules/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,12 @@ impl<Mods: WrappedData> WrappedData for Graph<Mods> {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
Mods::wrapped_as_base(wrap)
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ impl<Mods: AddOperation> AddOperation for Lazy<Mods> {
args: Args,
operation: fn(&mut Args) -> crate::Result<()>,
) -> crate::Result<()> {
Ok(self.graph.try_borrow_mut()
.expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?")
.add_operation(args, operation))
self.graph.try_borrow_mut()
.expect("already borrowed: BorrowMutError - is the inner operation trying to add an operation as well?")
.add_operation(args, operation);
Ok(())
}
}

Expand Down
8 changes: 3 additions & 5 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@ impl<Mods: WrappedData> WrappedData for Lazy<Mods> {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
Mods::wrapped_as_base(wrap.data.as_ref().expect(MISSING_DATA))
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
Mods::wrapped_as_base_mut(wrap.data.as_mut().expect(MISSING_DATA))
}
}
Expand Down Expand Up @@ -71,7 +69,7 @@ impl<Data: PtrType, T> PtrType for LazyWrapper<Data, T> {
}
}

const MISSING_DATA: &'static str =
const MISSING_DATA: &str =
"This lazy buffer does not contain any data. Try with a buffer.replace() call.";

impl<Data: Deref<Target = [T]>, T> Deref for LazyWrapper<Data, T> {
Expand Down
4 changes: 2 additions & 2 deletions src/parents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl AllParents for () {}
impl UpdateArg for () {
#[cfg(not(feature = "no-std"))]
fn update_arg<B>(
&mut self,
_to_update: &mut Self,
_id: Option<crate::UniqueId>,
_buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
Expand Down Expand Up @@ -73,7 +73,7 @@ macro_rules! impl_parents {
let mut ids = ids.iter();
#[allow(non_snake_case)]
let ($($to_impl,)+) = self;
$($to_impl.update_arg(*ids.next().unwrap(), buffers)?;)*
$($to_impl::update_arg($to_impl, *ids.next().unwrap(), buffers)?;)*
Ok(())
}
}
Expand Down
Loading

0 comments on commit 82b9561

Please sign in to comment.