Commit

Merge branch 'main' into lazy-autograd
elftausend committed Feb 8, 2024
2 parents e1f1024 + 3583d48 commit 49fa3cb
Showing 3 changed files with 22 additions and 5 deletions.
20 changes: 18 additions & 2 deletions src/buffer.rs
@@ -102,18 +102,34 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
     }
 
     #[inline]
-    pub fn require_grad(self) -> Buffer<'a, T, D, S>
+    pub fn set_require_grad(self, require_grad: bool) -> Buffer<'a, T, D, S>
     where
         D: OnNewBuffer<T, D, S>,
     {
         if let Some(device) = self.device {
             device.on_drop_buffer(device, &self);
         }
         let mut buf = self;
-        buf.set_requires_grad(true);
+        buf.set_requires_grad(require_grad);
         buf.device().on_new_buffer(buf.device(), &buf);
         buf
     }
+
+    #[inline]
+    pub fn require_grad(self) -> Buffer<'a, T, D, S>
+    where
+        D: OnNewBuffer<T, D, S>,
+    {
+        self.set_require_grad(true)
+    }
+
+    #[inline]
+    pub fn no_grad(self) -> Buffer<'a, T, D, S>
+    where
+        D: OnNewBuffer<T, D, S>,
+    {
+        self.set_require_grad(false)
+    }
 }
 
 // DO NOT implement!
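The new `set_require_grad` consolidates the old `require_grad` into a single toggle, with `require_grad`/`no_grad` as thin convenience wrappers. A minimal usage sketch, assuming a CPU device composed with the Autograd module; the `CPU::<Autograd<Base>>::new()` setup and the prelude import are assumptions, not part of this diff:

use custos::prelude::*;

fn main() {
    // Assumed device setup: CPU with the Autograd module enabled.
    let device = CPU::<Autograd<Base>>::new();

    // Opt a buffer into gradient tracking via the existing helper ...
    let lhs = device.buffer([1i32, 2, 3, 4]).require_grad();
    assert!(lhs.requires_grad());

    // ... opt out with the new `no_grad` helper ...
    let rhs = device.buffer([1i32, 2, 3, 4]).no_grad();
    assert!(!rhs.requires_grad());

    // ... or pass the flag directly through the new `set_require_grad`.
    let toggled = rhs.set_require_grad(true);
    assert!(toggled.requires_grad());
}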
4 changes: 2 additions & 2 deletions src/modules/autograd.rs
@@ -490,8 +490,8 @@ mod tests {
         let lhs = device.buffer([1i32, 2, 3, 4]).require_grad();
         assert!(lhs.requires_grad());
 
-        let no_grad = device.buffer([1i32, 2, 3, 4]);
-        let rhs = device.buffer([1i32, 2, 3, 4]);
+        let no_grad = device.buffer([1i32, 2, 3, 4]).no_grad();
+        let rhs = device.buffer([1i32, 2, 3, 4]).no_grad();
         assert!(!rhs.requires_grad());
 
         let out: Buffer<i32, _> = device.retrieve(rhs.len(), (&lhs, &rhs));
3 changes: 2 additions & 1 deletion src/modules/autograd/wrapper.rs
@@ -17,7 +17,8 @@ impl<Mods: WrappedData> WrappedData for Autograd<Mods> {
         base: Base,
     ) -> Self::Wrap<T, Base> {
         ReqGradWrapper {
-            requires_grad: false,
+            // default: true -> if a lazy layer is (accidentally) placed before autograd, all gradients are computed instead of none; subject to change
+            requires_grad: true,
             data: self.modules.wrap_in_base(base),
             _pd: PhantomData,
         }
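The default flip is the substantive change here: a freshly wrapped buffer now starts with gradient tracking enabled, so a module stack where a lazy layer accidentally sits in front of Autograd computes every gradient rather than silently computing none. A simplified, standalone sketch of that wrapper pattern; `ReqGradWrapperSketch` and its construction are illustrative only, not the crate's actual types:

use std::marker::PhantomData;

// Illustrative stand-in for the crate's ReqGradWrapper.
struct ReqGradWrapperSketch<T, Base> {
    requires_grad: bool,
    data: Base,
    _pd: PhantomData<T>,
}

impl<T, Base> ReqGradWrapperSketch<T, Base> {
    fn wrap_in_base(base: Base) -> Self {
        ReqGradWrapperSketch {
            // Err on the side of tracking: a misordered module stack then
            // computes too many gradients instead of none at all.
            requires_grad: true,
            data: base,
            _pd: PhantomData,
        }
    }
}

fn main() {
    let wrapped = ReqGradWrapperSketch::<f32, Vec<f32>>::wrap_in_base(vec![1.0, 2.0]);
    assert!(wrapped.requires_grad);
    assert_eq!(wrapped.data.len(), 2);
}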