Skip to content

Commit

Permalink
Add unified memory check
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Feb 7, 2024
1 parent ad1b561 commit c183bbd
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
default = ["cpu", "graph", "cached", "autograd", "lazy"]
default = ["cpu", "graph", "cached", "autograd", "lazy", "opencl"]
# default = ["cpu", "lazy", "static-api", "graph", "autograd", "fork", "serde", "json"]

std = []
Expand Down
4 changes: 2 additions & 2 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
}

#[inline]
pub fn require_grad(self) -> Buffer<'a, T, D, S>
pub fn require_grad(self) -> Buffer<'a, T, D, S>
where
D: OnNewBuffer<T, D, S>,
D: OnNewBuffer<T, D, S>,
{
if let Some(device) = self.device {
device.on_drop_buffer(device, &self);
Expand Down
8 changes: 4 additions & 4 deletions src/buffer/impl_autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ where
D: ZeroGrad<T> + MayTapeActions + Alloc<T>,
// D::Data<T, S>: crate::ShallowCopy,
{
// TODO: consider activating this check ->
// e.g. binary grad ops are computed in a single function where differentiating between
// TODO: consider activating this check ->
// e.g. binary grad ops are computed in a single function where differentiating between
// req grad and no req grad is not possible/ difficult
// assert!(self.requires_grad(), "Buffer does not require gradient.");
unsafe {
Expand Down Expand Up @@ -101,8 +101,8 @@ where
where
D: MayTapeActions + Alloc<T> + ZeroGrad<T>,
{
// TODO: consider activating this check ->
// e.g. binary grad ops are computed in a single function where differentiating between
// TODO: consider activating this check ->
// e.g. binary grad ops are computed in a single function where differentiating between
// req grad and no req grad is not possible/ difficult
// assert!(self.requires_grad(), "Buffer does not require gradient.");
unsafe {
Expand Down
15 changes: 15 additions & 0 deletions src/devices/opencl/cl_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<SimpleMods> OpenCL<SimpleMods> {
device: CLDevice::new(device_idx)?,
cpu: CPU::<Cached<Base>>::new(),
};
opencl.unified_mem_check();
NewMods::setup(&mut opencl)?;
Ok(opencl)
}
Expand All @@ -101,12 +102,26 @@ impl<SimpleMods> OpenCL<SimpleMods> {
device: CLDevice::fastest()?,
cpu: CPU::<Cached<Base>>::new(),
};
opencl.unified_mem_check();

NewMods::setup(&mut opencl)?;
Ok(opencl)
}

}

impl<Mods> OpenCL<Mods> {
/// Verifies at runtime that the selected OpenCL device actually supports
/// unified (host-shared) memory when the build was configured for it.
///
/// The check only exists when the `unified_cl` cfg flag is set (presumably
/// emitted by a build script when unified memory was requested — TODO confirm);
/// without that cfg this function compiles to a no-op.
///
/// # Panics
/// Panics under `#[cfg(unified_cl)]` if `self.unified_mem()` reports that the
/// device lacks unified memory. The message suggests either disabling unified
/// memory via the `CUSTOS_USE_UNIFIED=false` environment variable or selecting
/// a different device via `CUSTOS_CL_DEVICE_IDX`.
pub fn unified_mem_check(&self) {
// Entire body is compiled out unless the `unified_cl` cfg is active.
#[cfg(unified_cl)]
if !self.unified_mem() {
panic!("
Your selected compute device does not support unified memory!
You are probably using a laptop.
Launch with environment variable `CUSTOS_USE_UNIFIED=false` or change `CUSTOS_CL_DEVICE_IDX=<idx:default=0>`
")
}
}

/// Sets the values of the attributes cache, kernel cache, graph and CPU to their default.
/// This cleans up any accumulated allocations.
pub fn reset(&'static mut self) {
Expand Down
5 changes: 4 additions & 1 deletion src/devices/stack/stack_device.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use core::convert::Infallible;

use crate::{
flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, impl_wrapped_data, pass_down_cursor, pass_down_grad_fn, pass_down_optimize_mem_graph, pass_down_tape_actions, pass_down_use_gpu_or_cpu, shape::Shape, Alloc, Base, Buffer, CloneBuf, Device, DevicelessAble, OnDropBuffer, Read, StackArray, WrappedData, WriteBuf
flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, impl_wrapped_data, pass_down_cursor,
pass_down_grad_fn, pass_down_optimize_mem_graph, pass_down_tape_actions,
pass_down_use_gpu_or_cpu, shape::Shape, Alloc, Base, Buffer, CloneBuf, Device, DevicelessAble,
OnDropBuffer, Read, StackArray, WrappedData, WriteBuf,
};

/// A device that allocates memory on the stack.
Expand Down
11 changes: 5 additions & 6 deletions src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ where
where
D: Alloc<T>,
{
let requires_grad = parents.requires_grads().iter().any(|&x| x);
let requires_grad = parents.requires_grads().iter().any(|&x| x);
let data = self.modules.retrieve(device, len, parents);

ReqGradWrapper {
Expand Down Expand Up @@ -484,22 +484,21 @@ mod tests {

let lhs = device.buffer([1i32, 2, 3, 4]).require_grad();
assert!(lhs.requires_grad());

let no_grad = device.buffer([1i32, 2, 3, 4]);
let rhs = device.buffer([1i32, 2, 3, 4]);
assert!(!rhs.requires_grad());

let out: Buffer<i32, _> = device.retrieve(rhs.len(), (&lhs, &rhs));
assert!(out.requires_grad());

let out: Buffer<i32, _> = device.retrieve(rhs.len(), &lhs);
assert!(out.requires_grad());

let out: Buffer<i32, _> = device.retrieve(rhs.len(), &rhs);
assert!(!out.requires_grad());

let out: Buffer<i32, _> = device.retrieve(rhs.len(), (&no_grad, &rhs));
assert!(!out.requires_grad());

}
}
5 changes: 4 additions & 1 deletion src/modules/autograd/gradients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ impl Gradients {
// self.grads_pool.cache.clear();
}

pub fn add_zero_grad_cb<T: 'static, D: Device + ZeroGrad<T> + 'static, S: Shape>(&mut self, id: &Id) {
pub fn add_zero_grad_cb<T: 'static, D: Device + ZeroGrad<T> + 'static, S: Shape>(
&mut self,
id: &Id,
) {
self.zero_grad_cbs.push((*id, |grad_buf, buf| {
let grad_buf = grad_buf.downcast_mut::<Buffer<T, D, S>>().unwrap();
let buf = buf.downcast_ref::<Buffer<T, D, S>>().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/parents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl<T: HasId> Parents<1> for T {

#[inline]
fn requires_grads(&self) -> [bool; 1] {
[self.requires_grad()]
[self.requires_grad()]
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/unary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
AddGradFn, Alloc, AsNoId, Buffer, Device, Eval, HasId, MayTapeActions, MayToCLSource, Resolve, Shape, ZeroGrad
AddGradFn, Alloc, AsNoId, Buffer, Device, Eval, HasId, MayTapeActions, MayToCLSource, Resolve,
Shape, ZeroGrad,
};

/// Applies a function to a buffer and returns a new buffer.
Expand Down

0 comments on commit c183bbd

Please sign in to comment.