Skip to content

Commit

Permalink
Merge branch 'main' into webgl
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend authored Jan 31, 2024
2 parents 1d86bc6 + 2934c21 commit 8e036d6
Show file tree
Hide file tree
Showing 39 changed files with 610 additions and 296 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ jobs:
run: cargo test --verbose --no-default-features --features cpu,lazy
- name: Run tests,cached,fork
run: cargo test --verbose --no-default-features --features cpu,cached,fork
- name: Run graph
run: cargo test --verbose --no-default-features --features cpu,graph
- name: Run cached
run: cargo test --verbose --no-default-features --features cpu,cached

test-cached:

Expand Down
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ custos-macro = {git = "https://github.com/elftausend/custos-macro", optional=tru
libm = { version="0.2.6", optional = true }

ash = { version = "0.37", optional = true }
naga = { version = "0.14", features = ["wgsl-in"], optional = true }
naga = { version = "0.19", features = ["wgsl-in"], optional = true }

half = {version = "2.3", default-features = false, optional = true}

Expand Down Expand Up @@ -65,6 +65,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# default = ["cpu", "autograd", "macro"]
default = ["cpu", "lazy", "static-api", "webgl", "vulkan"]


cpu = []
opencl = ["dep:min-cl", "cpu", "cached"]
#network = ["cuwanto-client"]
Expand Down
2 changes: 1 addition & 1 deletion android-nnapi-ci/nnapitestlib/src/using_custos_nnapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use nnapi::{nnapi_sys::OperationCode, Operand};
use crate::log;

pub fn run_custos_model() -> custos::Result<String> {
let device = NnapiDevice::<i32, Lazy<Base>>::new()?;
let device = NnapiDevice::<i32, Base>::new()?;

let lhs = Buffer::with(&device, [1, 2, 3, 4, 5, 6, 7, 8, 9, 11]);
let rhs = Buffer::with(&device, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
Expand Down
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ fn has_device_unified_mem() -> bool {
.unified_mem
}

#[cfg(feature = "cuda")]
use std::path::{Path, PathBuf};

// https://github.com/coreylowman/cudarc/blob/main/build.rs
Expand Down
12 changes: 12 additions & 0 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,18 @@ impl<'a, Mods: OnDropBuffer, T> Buffer<'a, T, crate::CUDA<Mods>> {
}
}

/// [`ShallowCopy`] for [`Buffer`]: available whenever the device's data
/// representation `D::Data<T, S>` is itself [`ShallowCopy`].
impl<'a, T, D, S> ShallowCopy for Buffer<'a, T, D, S>
where
D: Device,
D::Data<T, S>: ShallowCopy,
S: Shape,
{
/// Returns a shallow copy of this buffer.
///
/// # Safety
/// NOTE(review): the exact contract is not visible in this hunk; presumably
/// the copy aliases the same underlying data as `self`, so the caller must
/// uphold whatever aliasing/lifetime rules the delegated call requires — confirm.
#[inline]
unsafe fn shallow(&self) -> Self {
// NOTE(review): this is expected to resolve to an inherent `Buffer::shallow`
// (inherent methods take precedence over trait methods). If no such inherent
// method exists, this call would recurse unconditionally — confirm.
self.shallow()
}
}

impl<'a, T, D, S> Clone for Buffer<'a, T, D, S>
where
T: Clone,
Expand Down
1 change: 1 addition & 0 deletions src/buffer/impl_autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ where
pub fn grad(&self) -> &'a Self
where
D: MayTapeActions + Alloc<T>,
// D::Data<T, S>: crate::ShallowCopy,
{
unsafe {
self.device()
Expand Down
27 changes: 16 additions & 11 deletions src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::{

use crate::{
flag::AllocFlag, Alloc, Buffer, CloneBuf, CommonPtrs, Device, HasId, OnDropBuffer, PtrType,
WrappedData,
ShallowCopy, WrappedData,
};

#[derive(Debug, Default)]
Expand Down Expand Up @@ -57,6 +57,13 @@ impl<T> From<T> for Num<T> {
}
}

/// Placeholder [`ShallowCopy`] implementation for [`Num`].
///
/// It exists so that `Num<T>` satisfies `ShallowCopy` bounds; the operation
/// itself is not implemented.
impl<T> ShallowCopy for Num<T> {
/// Not implemented for [`Num`].
///
/// # Panics
/// Always panics via `unimplemented!()`.
#[inline]
unsafe fn shallow(&self) -> Self {
unimplemented!()
}
}

impl Device for () {
type Data<T, S: crate::Shape> = Self::Base<T, S>;
type Base<T, S: crate::Shape> = Num<T>;
Expand All @@ -81,15 +88,15 @@ impl Device for () {
}

#[inline(always)]
fn data_as_wrap<'a, T, S: crate::Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<T, S: crate::Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
data
}

fn data_as_wrap_mut<'a, T, S: crate::Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<T, S: crate::Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
data
}
}
Expand Down Expand Up @@ -118,14 +125,12 @@ impl WrappedData for () {
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(wrap: &'a Self::Wrap<T, Base>) -> &'a Base {
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
wrap
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/cache/borrow_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ impl Display for CachingError {

impl std::error::Error for CachingError {}

#[derive(Debug, Default)]
pub(crate) type AnyBuffers = HashMap<UniqueId, Box<dyn Any>, BuildHasherDefault<NoHasher>>;

#[derive(Default)]
pub struct BorrowCache {
pub(crate) cache: HashMap<UniqueId, Box<dyn Any>, BuildHasherDefault<NoHasher>>,
pub(crate) cache: AnyBuffers,
}

// TODO: make BorrowCache unusable without a device (=> static get methods with D: CacheReturn)
Expand Down
12 changes: 7 additions & 5 deletions src/cache/owned_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ impl Cache {
}
}

/// # Safety
/// Lifetime of data must be at least as long as the lifetime of the cache (usually the device).
#[track_caller]
#[inline]
pub fn get<T, S, D>(&mut self, device: &D, len: usize, callback: fn()) -> D::Base<T, S>
pub unsafe fn get<T, S, D>(&mut self, device: &D, len: usize, callback: fn()) -> D::Base<T, S>
where
D: Alloc<T> + 'static,
D::Base<T, S>: ShallowCopy + 'static,
Expand Down Expand Up @@ -80,7 +82,7 @@ mod tests {
assert_eq!(cache.nodes.len(), 1);
assert_eq!(out.len, 10);

let out1 = cache.get::<f32, (), _>(&device, 10, || ());
let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
assert_ne!(out.ptr, out1.ptr);
}

Expand All @@ -91,10 +93,10 @@ mod tests {

assert_eq!(cache.nodes.len(), 0);

let out1 = cache.get::<f32, (), _>(&device, 10, || ());
let out1 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
assert_eq!(cache.nodes.len(), 1);

let out2 = cache.get::<f32, (), _>(&device, 10, || ());
let out2 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };

assert_ne!(out1.ptr, out2.ptr);
assert_eq!(cache.nodes.len(), 2);
Expand All @@ -108,7 +110,7 @@ mod tests {

let mut prev = None;
for _ in 0..1000 {
let out3 = cache.get::<f32, (), _>(&device, 10, || ());
let out3 = unsafe { cache.get::<f32, (), _>(&device, 10, || ()) };
if prev.is_none() {
prev = Some(out3.ptr);
}
Expand Down
18 changes: 8 additions & 10 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,10 @@ pub trait Device: OnDropBuffer + Sized {
// FIXME: probably a better way to realize these
fn base_to_data<T, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S>;
fn wrap_to_data<T, S: Shape>(&self, wrap: Self::Wrap<T, Self::Base<T, S>>) -> Self::Data<T, S>;
fn data_as_wrap<'a, T, S: Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>>;
fn data_as_wrap_mut<'a, T, S: Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>>;
fn data_as_wrap<T, S: Shape>(data: &Self::Data<T, S>) -> &Self::Wrap<T, Self::Base<T, S>>;
fn data_as_wrap_mut<T, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>>;

/// Creates a new [`Buffer`] using `A`.
///
Expand Down Expand Up @@ -114,12 +112,12 @@ macro_rules! impl_device_traits {
$crate::impl_wrapped_data!($device);

#[cfg(feature = "graph")]
crate::pass_down_optimize_mem_graph!($device);
$crate::pass_down_optimize_mem_graph!($device);

$crate::pass_down_grad_fn!($device);
$crate::pass_down_tape_actions!($device);

$crate::pass_down_replace_buf!($device);
$crate::pass_down_replace_buf_dev!($device);
};
}

Expand All @@ -133,9 +131,9 @@ macro_rules! impl_retriever {
len: usize,
parents: impl $crate::Parents<NUM_PARENTS>,
) -> Buffer<T, Self, S> {
let data = self
let data = unsafe { self
.modules
.retrieve::<NUM_PARENTS>(self, len, parents);
.retrieve::<NUM_PARENTS>(self, len, parents) };
let buf = Buffer {
data,
device: Some(self),
Expand Down
14 changes: 6 additions & 8 deletions src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ impl<Mods: OnDropBuffer> Device for CPU<Mods> {
}

#[inline(always)]
fn data_as_wrap<'a, T, S: Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<T, S: Shape>(data: &Self::Data<T, S>) -> &Self::Wrap<T, Self::Base<T, S>> {
data
}

#[inline(always)]
fn data_as_wrap_mut<'a, T, S: Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<T, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
data
}

Expand Down Expand Up @@ -190,12 +188,12 @@ unsafe impl<Mods: OnDropBuffer> IsShapeIndep for CPU<Mods> {}

#[cfg(test)]
mod tests {
use crate::{Base, CPU};

#[cfg(feature = "fork")]
#[cfg(feature = "cached")]
#[test]
fn test_add_layer_cpu() {
use crate::{Base, CPU};

let cpu = CPU::<Base>::new();
let cpu = cpu.add_layer::<crate::Cached<()>>();
let cpu = cpu.add_layer::<crate::Fork<()>>();
Expand Down
1 change: 0 additions & 1 deletion src/devices/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@ mod cpu_ptr;
mod ops;

pub use cpu_ptr::*;
pub use ops::*;
9 changes: 5 additions & 4 deletions src/devices/nnapi/nnapi_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl<U, Mods: Retrieve<Self, T, S>, T: AsOperandCode, S: Shape> Retriever<T, S>
len: usize,
parents: impl crate::Parents<NUM_PARENTS>,
) -> Buffer<T, Self, S> {
let data = self.modules.retrieve::<NUM_PARENTS>(self, len, parents);
let data = unsafe { self.modules.retrieve::<NUM_PARENTS>(self, len, parents) };
let buf = Buffer {
data,
device: Some(self),
Expand Down Expand Up @@ -158,10 +158,11 @@ impl<U, T: AsOperandCode, Mods: OnDropBuffer> Alloc<T> for NnapiDevice<U, Mods>

impl<T, SimpleMods> NnapiDevice<T, SimpleMods> {
/// Creates a new [`NnapiDevice`].
pub fn new<NewMods>() -> crate::Result<NnapiDevice<T, Lazy<NewMods>>>
pub fn new<NewMods>() -> crate::Result<NnapiDevice<T, NewMods>>
// TODO keep in mind that lazy module requirement would make sense here
where
SimpleMods: Module<NnapiDevice<T>, Module = Lazy<NewMods>>,
Lazy<NewMods>: Setup<NnapiDevice<T, Lazy<NewMods>>>,
SimpleMods: Module<NnapiDevice<T>, Module = NewMods>,
NewMods: Setup<NnapiDevice<T, NewMods>>,
{
let mut device = NnapiDevice {
modules: SimpleMods::new(),
Expand Down
40 changes: 28 additions & 12 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ use core::{fmt::Debug, ops::RangeBounds};

use crate::{HasId, Parents, Shape, UniqueId, UpdateArgs, CPU};

#[cfg(feature = "graph")]
use crate::TranslatedCacheTrace;

#[cfg(feature = "cached")]
use crate::{Base, CachedModule};

Expand All @@ -22,7 +19,7 @@ pub trait Feature: OnDropBuffer {}
pub trait Retrieve<D, T, S: Shape = ()>: OnDropBuffer {
// "generator"
#[track_caller]
fn retrieve<const NUM_PARENTS: usize>(
unsafe fn retrieve<const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
Expand Down Expand Up @@ -102,7 +99,7 @@ macro_rules! pass_down_grad_fn {
fn add_grad_fn<Args: $crate::Parents<N> + $crate::UpdateArgs, const N: usize>(
&self,
args: Args,
op: fn(&mut Args) -> crate::Result<()>,
op: fn(&mut Args) -> $crate::Result<()>,
) {
self.modules.add_grad_fn(args, op)
}
Expand Down Expand Up @@ -188,7 +185,7 @@ pub trait ReplaceBuf<T, D: Device, S: Shape>: OnDropBuffer {
}

#[macro_export]
macro_rules! pass_down_replace_buf {
macro_rules! pass_down_replace_buf_dev {
($device:ident) => {
impl<T, S: Shape, Mods: $crate::ReplaceBuf<T, Self, S>> $crate::ReplaceBuf<T, Self, S>
for $device<Mods>
Expand All @@ -204,6 +201,23 @@ macro_rules! pass_down_replace_buf {
};
}

/// Implements `ReplaceBuf<T, D, S>` for a module wrapper type `$module<Mods>`
/// by forwarding `replace_buf` to the wrapped `modules` field.
///
/// Requirements on `$module`:
/// - it is generic over exactly one type parameter (the wrapped modules), and
/// - it has a field named `modules` whose type implements `ReplaceBuf<T, D, S>`.
///
/// Every crate item is referenced through `$crate`, so the macro expands
/// correctly even when the call site has not imported `Shape`, `Buffer`,
/// `Device` or `ReplaceBuf`.
#[macro_export]
macro_rules! pass_down_replace_buf_module {
    ($module:ident) => {
        impl<T, S: $crate::Shape, Mods: $crate::ReplaceBuf<T, D, S>, D: $crate::Device>
            $crate::ReplaceBuf<T, D, S> for $module<Mods>
        {
            #[inline]
            fn replace_buf<'a, 'c>(
                &'c self,
                buffer: &'c $crate::Buffer<'a, T, D, S>,
            ) -> &'c $crate::Buffer<'a, T, D, S> {
                // Pure pass-through: the wrapped module decides which buffer is returned.
                self.modules.replace_buf(buffer)
            }
        }
    };
}

pub trait AddOperation {
#[track_caller]
fn add_op<Args: Parents<N> + UpdateArgs, const N: usize>(
Expand Down Expand Up @@ -236,7 +250,7 @@ macro_rules! pass_down_add_operation {
fn add_op<Args: $crate::Parents<N> + $crate::UpdateArgs, const N: usize>(
&self,
args: Args,
operation: fn(&mut Args) -> crate::Result<()>,
operation: fn(&mut Args) -> $crate::Result<()>,
) -> $crate::Result<()> {
self.modules.add_op(args, operation)
}
Expand Down Expand Up @@ -375,21 +389,23 @@ pub trait UseGpuOrCpu {

#[cfg(feature = "graph")]
pub trait OptimizeMemGraph {
fn optimize_mem_graph(
fn optimize_mem_graph<D: 'static>(
&self,
cache_traces: Option<&[TranslatedCacheTrace]>,
device: &D,
graph_translator: Option<&crate::modules::GraphTranslator>,
) -> crate::Result<()>;
}

#[macro_export]
macro_rules! pass_down_optimize_mem_graph {
($to_impl:ident) => {
impl<Mods: $crate::OptimizeMemGraph> $crate::OptimizeMemGraph for $to_impl<Mods> {
fn optimize_mem_graph(
fn optimize_mem_graph<D: 'static>(
&self,
cache_traces: Option<&[$crate::TranslatedCacheTrace]>,
device: &D,
graph_translator: Option<&$crate::modules::GraphTranslator>,
) -> crate::Result<()> {
self.modules.optimize_mem_graph(cache_traces)
self.modules.optimize_mem_graph(device, graph_translator)
}
}
};
Expand Down
Loading

0 comments on commit 8e036d6

Please sign in to comment.