Skip to content

Commit

Permalink
Remove UpdateArgs
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Jul 24, 2024
1 parent 0d81aa7 commit c1d3191
Show file tree
Hide file tree
Showing 11 changed files with 22 additions and 259 deletions.
29 changes: 1 addition & 28 deletions src/boxed_shallow_copy.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{AnyBuffer, AsAny, Downcast, ShallowCopy};
use core::any::Any;
use crate::{AnyBuffer, Downcast, ShallowCopy};

pub trait BoxedShallowCopy: AnyBuffer {
fn shallow_copy(&self) -> Box<dyn BoxedShallowCopy>;
Expand Down Expand Up @@ -77,29 +76,3 @@ impl<I: Downcast + ?Sized> Downcast for Box<I> {
(**self).is::<T>()
}
}

impl AsAny for Box<dyn BoxedShallowCopy> {
#[inline]
fn as_any(&self) -> *const () {
let data = &**self;
data as *const _ as *const ()
}

#[inline]
fn as_any_mut(&mut self) -> *mut () {
let data = &mut **self;
data as *mut _ as *mut ()
}
}

impl AsAny for Box<dyn Any> {
#[inline]
fn as_any(&self) -> *const () {
(&**self) as *const _ as *const ()
}

#[inline]
fn as_any_mut(&mut self) -> *mut () {
(&mut **self) as *mut _ as *mut ()
}
}
2 changes: 1 addition & 1 deletion src/devices/opencl/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use min_cl::{

use crate::{
bounds_to_range, cpu_stack_ops::clear_slice, location, op_hint::unary, pass_down_add_operation,
pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, AsNoId, BufAsNoId, Buffer,
pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, Buffer,
CDatatype, ClearBuf, CopySlice, OnDropBuffer, OpenCL, Read, Resolve, Retrieve, Retriever,
SetOpHint, Shape, ToCLSource, ToMarker, TwoWay, UnaryGrad, Unit, UseGpuOrCpu, WriteBuf,
ZeroGrad,
Expand Down
9 changes: 4 additions & 5 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::{cell::RefMut, fmt::Debug, ops::RangeBounds};
use crate::{
op_hint::OpHint,
range::{AsRange, CursorRange},
AnyOp, HasId, Parents, Shape, UniqueId, Unit, UpdateArgs, ZeroGrad, CPU,
AnyOp, HasId, Parents, Shape, UniqueId, Unit, ZeroGrad, CPU,
};

#[cfg(feature = "cached")]
Expand Down Expand Up @@ -157,16 +157,15 @@ pub trait AddGradFn {
op: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static,
);

fn add_grad_and_forward_fn<Args: Parents<N> + UpdateArgs + AnyOp + Clone, const N: usize>(
fn add_grad_and_forward_fn<Args: Parents<N> + AnyOp + Clone, const N: usize>(
&self,
args: Args,
forward_fn: fn(&mut Args) -> crate::Result<()>,
forward_fn: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static,
grad_fn: impl for<'b> Fn(Args::Replicated<'b>) -> crate::Result<()> + 'static,
) where
Self: AddOperation,
{
todo!();
// self.add_op(args.clone(), forward_fn).unwrap();
self.add_op(args.clone(), forward_fn).unwrap();
self.add_grad_fn(args, grad_fn)
}

Expand Down
55 changes: 1 addition & 54 deletions src/id.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::ops::{Deref, DerefMut};

use crate::{Buffer, Device, DeviceError, Shape, UniqueId, Unit, UpdateArg};
use crate::{Buffer, Device, Shape, Unit};

pub trait HasId {
const HAS_NO_ID: bool = false;
Expand Down Expand Up @@ -105,59 +105,6 @@ impl<T: Into<NoId<T>>> AsNoId for T {
}
}

impl<T> UpdateArg for NoId<T> {
#[inline]
#[cfg(feature = "std")]
fn update_arg<B>(
_to_update: &mut Self,
_id: Option<UniqueId>,
_buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
Ok(())
}
}

impl<'a, T: Unit + 'static, D: Device + 'static, S: Shape + 'static> UpdateArg
for &Buffer<'a, T, D, S>
{
#[cfg(feature = "std")]
fn update_arg<B: crate::AsAny>(
to_update: &mut Self,
id: Option<UniqueId>,
buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
// use crate::ShallowCopyable;

let buf = buffers
.get(&id.unwrap())
.ok_or(DeviceError::InvalidLazyBuf)?;
// let any = buf.as_any();
// let _to_update = buf.as_any().downcast_ref::<Buffer<T, D, S>>().unwrap();
// todo!();
*to_update = unsafe { &*(buf.as_any() as *const Buffer<T, D, S>) };
// *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap();
Ok(())
}
}

impl<'a, T: Unit + 'static, D: Device + 'static, S: Shape + 'static> UpdateArg
for &mut Buffer<'a, T, D, S>
{
#[cfg(feature = "std")]
fn update_arg<B: crate::AsAny>(
to_update: &mut Self,
id: Option<UniqueId>,
buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
let buf = buffers
.get_mut(&id.unwrap())
.ok_or(DeviceError::InvalidLazyBuf)?;
*to_update = unsafe { &mut *(buf.as_any_mut() as *mut Buffer<T, D, S>) };
Ok(())
// *self = buffers.get(&self.id()).unwrap().downcast_ref().unwrap();
}
}

pub trait BufAsNoId: Sized {
fn buf_no_id(self) -> NoId<Self>;
}
Expand Down
2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ mod range;
mod shape;
mod two_way_ops;
mod unary;
mod update_args;
mod wrapper;

pub use any_op::*;
Expand All @@ -130,7 +129,6 @@ pub use number::*;
pub use parents::*;
pub use ptr_conv::*;
pub use range::*;
pub use update_args::*;
pub use wrapper::*;

#[cfg(not(feature = "cpu"))]
Expand Down
4 changes: 2 additions & 2 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
op_hint::OpHint, register_buf_copyable, unregister_buf_copyable, AddLayer, AddOperation, Alloc,
AnyOp, BoxedShallowCopy, Buffer, CachedBuffers, Cursor, Device, ExecNow, HasId, HasModules, Id,
IsShapeIndep, Module, NoHasher, OnDropBuffer, OnNewBuffer, Parents, ReplaceBuf, Retrieve,
RunModule, SetOpHint, Setup, ShallowCopy, Shape, UniqueId, Unit, UpdateArgs, UseGpuOrCpu,
RunModule, SetOpHint, Setup, ShallowCopy, Shape, UniqueId, Unit, UseGpuOrCpu,
};

#[cfg(feature = "graph")]
Expand Down Expand Up @@ -757,7 +757,7 @@ mod tests {
// #[ignore = "causes UB"]
#[test]
fn test_lazy_exec_ub_testing() {
use crate::{AsNoId, Run};
use crate::Run;

let device = CPU::<Lazy<Base>>::new();

Expand Down
49 changes: 5 additions & 44 deletions src/modules/lazy/exec_iter.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use crate::{Buffers, Operation2, UniqueId, UpdateArgsDynable};
use crate::{Buffers, Operation2};

use super::lazy_graph::Operation;

pub struct ExecIter2<'a, 'b, B, T> {
pub struct ExecIter<'a, 'b, B, T> {
pub(super) operations: std::slice::Iter<'b, Operation2<'a, B, T>>,
pub(super) buffers: &'b mut Buffers<B>,
}

impl<'a, 'b, B, T> Iterator for ExecIter2<'a, 'b, B, T> {
impl<'a, 'b, B, T> Iterator for ExecIter<'a, 'b, B, T> {
type Item = crate::Result<()>;

fn next(&mut self) -> Option<Self::Item> {
Expand All @@ -16,51 +14,14 @@ impl<'a, 'b, B, T> Iterator for ExecIter2<'a, 'b, B, T> {
}
}

impl<'a, 'b, B, T> DoubleEndedIterator for ExecIter2<'a, 'b, B, T> {
impl<'a, 'b, B, T> DoubleEndedIterator for ExecIter<'a, 'b, B, T> {
fn next_back(&mut self) -> Option<Self::Item> {
let op = self.operations.next_back()?;
Some((op.op)(self.buffers))
}
}

impl<'a, 'b, B, T> ExactSizeIterator for ExecIter2<'a, 'b, B, T> {
fn len(&self) -> usize {
self.operations.len()
}
}
pub struct ExecIter<'a, B, T> {
pub(super) operations: std::slice::IterMut<'a, Operation<B, T>>,
pub(super) buffers: &'a mut Buffers<B>,
}

pub fn exec_op<B>(
args: &mut Box<dyn UpdateArgsDynable<B>>,
op: &fn(*mut ()) -> crate::Result<()>,
ids_to_check: &[Option<UniqueId>],
buffers: &mut Buffers<B>,
) -> crate::Result<()> {
args.update_args_dynable(ids_to_check, buffers)?;
let args = core::ptr::addr_of_mut!(**args) as *mut ();
op(args)
}

impl<'a, B, T> Iterator for ExecIter<'a, B, T> {
type Item = crate::Result<()>;

fn next(&mut self) -> Option<Self::Item> {
let op = self.operations.next()?;
Some(exec_op(&mut op.args, &op.op, &op.arg_ids, self.buffers))
}
}

impl<'a, B, T> DoubleEndedIterator for ExecIter<'a, B, T> {
fn next_back(&mut self) -> Option<Self::Item> {
let op = self.operations.next_back()?;
Some(exec_op(&mut op.args, &op.op, &op.arg_ids, self.buffers))
}
}

impl<'a, B, T> ExactSizeIterator for ExecIter<'a, B, T> {
impl<'a, 'b, B, T> ExactSizeIterator for ExecIter<'a, 'b, B, T> {
fn len(&self) -> usize {
self.operations.len()
}
Expand Down
33 changes: 6 additions & 27 deletions src/modules/lazy/lazy_graph.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
use crate::{
bounds_to_range, modules::lazy::exec_iter::ExecIter2, op_hint::OpHint, AnyOp, AsAny,
BoxedShallowCopy, Buffers, Device, Downcast, Parents, UniqueId, UpdateArgs, UpdateArgsDynable,
bounds_to_range, modules::lazy::exec_iter::ExecIter, op_hint::OpHint, AnyOp,
BoxedShallowCopy, Buffers, Device, Downcast, Parents,
};
use core::{mem::transmute, ops::RangeBounds};
use core::ops::RangeBounds;
use std::collections::HashSet;

use super::exec_iter::{exec_op, ExecIter};

pub struct Operation2<'a, B, T> {
pub op: Box<dyn Fn(&mut Buffers<B>) -> crate::Result<()> + 'a>,
pub op_hint: OpHint<T>,
}

pub struct Operation<B, T> {
pub op_hint: OpHint<T>,
pub arg_ids: Vec<Option<UniqueId>>,
pub op: fn(*mut ()) -> crate::Result<()>,
pub args: Box<dyn UpdateArgsDynable<B>>,
}

impl<B: AsAny, T> Operation<B, T> {
pub fn no_op() -> Self {
Self {
op_hint: OpHint::None,
arg_ids: vec![None],
op: |_: *mut ()| Ok(()),
args: Box::new(()),
}
}
}


pub struct LazyGraph<'a, B = Box<dyn BoxedShallowCopy>, T = ()> {
pub(crate) operations: Vec<Operation2<'a, B, T>>,
}
Expand All @@ -50,8 +29,8 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> {
&'b mut self,
// device: &'a D,
buffers: &'b mut Buffers<B>,
) -> ExecIter2<'a, 'b, B, T> {
ExecIter2 {
) -> ExecIter<'a, 'b, B, T> {
ExecIter {
operations: self.operations.iter(),
buffers,
}
Expand Down Expand Up @@ -196,7 +175,7 @@ impl<'a, B: Downcast, T> LazyGraph<'a, B, T> {
#[cfg(test)]
mod tests {
use crate::{
register_buf_any, register_buf_copyable, AnyBuffer, AsNoId, Base, BoxedShallowCopy, Buffer,
register_buf_any, register_buf_copyable, AnyBuffer, Base, Buffer,
CloneBuf, Device, HasId, LazyGraph, Retriever, Shape, UniqueId, CPU,
};
use core::cell::Cell;
Expand Down
28 changes: 1 addition & 27 deletions src/parents.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{HasId, Id, UpdateArg};
use crate::{HasId, Id};

pub trait Parents<const N: usize>: AllParents {
fn ids(&self) -> [Id; N];
Expand All @@ -22,17 +22,6 @@ impl Parents<0> for () {

impl AllParents for () {}

impl UpdateArg for () {
#[cfg(feature = "std")]
fn update_arg<B>(
_to_update: &mut Self,
_id: Option<crate::UniqueId>,
_buffers: &mut crate::Buffers<B>,
) -> crate::Result<()> {
Ok(())
}
}

impl<T: HasId> Parents<1> for T {
#[inline]
fn ids(&self) -> [Id; 1] {
Expand Down Expand Up @@ -78,21 +67,6 @@ macro_rules! impl_parents {
}
impl<$($to_impl: $crate::HasId, )+> AllParents for ($($to_impl,)+) {}

impl<$($to_impl: $crate::UpdateArg + $crate::HasId, )+> $crate::UpdateArgs for ($($to_impl,)+) {
#[cfg(feature = "std")]
fn update_args<B: $crate::AsAny>(&mut self,
ids: &[Option<$crate::UniqueId>],
buffers: &mut $crate::Buffers<B>)
-> crate::Result<()>
{
let mut ids = ids.iter();
#[allow(non_snake_case)]
let ($($to_impl,)+) = self;
$($to_impl::update_arg($to_impl, *ids.next().unwrap(), buffers)?;)*
Ok(())
}
}

impl<'own, 'dev, $($to_impl: $crate::Replicate2<'own, 'dev> + $crate::HasId, )+> $crate::AnyOp2<'own, 'dev> for ($($to_impl,)+) {
type Replicated<'a, 'b> = ($($to_impl::Replication<'a, 'b>,)+) where 'b: 'a;

Expand Down
2 changes: 1 addition & 1 deletion src/unary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
AddGradFn, AddOperation, Alloc, AsNoId, Buffer, Device, Eval, HasId, MayGradActions,
AddGradFn, AddOperation, Alloc, Buffer, Device, Eval, HasId, MayGradActions,
MayToCLSource, Resolve, Shape, TwoWay, Unit, ZeroGrad,
};

Expand Down
Loading

0 comments on commit c1d3191

Please sign in to comment.