Skip to content

Commit

Permalink
Fix nnapidevice (compilation)
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Dec 17, 2023
1 parent c509389 commit 8824482
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ min-cl = { path="../min-cl", optional=true }

[features]
# default = ["cpu", "autograd", "macro"]
default = ["cpu", "static-api", "lazy", "autograd", "graph", "fork", "opencl", "vulkan", "stack"]
default = ["cpu", "static-api", "lazy", "autograd", "graph", "fork", "opencl", "nnapi", "stack"]

cpu = []
opencl = ["dep:min-cl", "cpu", "cached"]
Expand Down
98 changes: 67 additions & 31 deletions src/devices/nnapi/nnapi_device.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
cpu::CPUPtr, Alloc, AsOperandCode, Base, Buffer, Device, Lazy, LazyRun, LazySetup, Module,
OnDropBuffer, PtrConv, Retrieve, Retriever, Setup, Shape, CPU,
cpu::CPUPtr, Alloc, AsOperandCode, Base, Buffer, ConvPtr, Device, Lazy,
LazyRun, LazySetup, Module, OnDropBuffer, Retrieve, Retriever, Setup, Shape, WrappedData, HasId, PtrType, IsShapeIndep,
};

use super::NnapiPtr;
Expand All @@ -23,17 +23,43 @@ pub struct NnapiDevice<T, Mods = Base> {
out: Cell<Vec<T>>,
}

impl<T, Mods: OnDropBuffer> Device for NnapiDevice<T, Mods> {
type Data<U, S: crate::Shape> = NnapiPtr;
impl<U, Mods: OnDropBuffer> Device for NnapiDevice<U, Mods> {
type Data<T, S: crate::Shape> = Mods::Wrap<T, NnapiPtr>;
type Base<T, S: Shape> = NnapiPtr;
type Error = crate::Error;

#[inline]
fn new() -> crate::Result<Self> {
// NnapiDevice::new()
todo!()
}

fn base_to_data<T, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
self.wrap_in_base(base)
}

#[inline]
fn wrap_to_data<T, S: Shape>(&self, wrap: Self::Wrap<T, Self::Base<T, S>>) -> Self::Data<T, S> {
wrap
}

#[inline]
fn data_as_wrap<'a, T, S: Shape>(
data: &'a Self::Data<T, S>,
) -> &'a Self::Wrap<T, Self::Base<T, S>> {
data
}

#[inline]
fn data_as_wrap_mut<'a, T, S: Shape>(
data: &'a mut Self::Data<T, S>,
) -> &'a mut Self::Wrap<T, Self::Base<T, S>> {
data
}
}

unsafe impl<U, Mods: OnDropBuffer> IsShapeIndep for NnapiDevice<U, Mods> {}

impl<U, T, D: Device, S: Shape, Mods: crate::OnNewBuffer<T, D, S>> crate::OnNewBuffer<T, D, S>
for NnapiDevice<U, Mods>
{
Expand All @@ -42,23 +68,47 @@ impl<U, T, D: Device, S: Shape, Mods: crate::OnNewBuffer<T, D, S>> crate::OnNewB
self.modules.on_new_buffer(device, new_buf)
}
}
impl<U, Mods: WrappedData> WrappedData for NnapiDevice<U, Mods> {
type Wrap<T, Base: HasId + PtrType> = Mods::Wrap<T, Base>;

#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(
&self,
base: Base,
) -> Self::Wrap<T, Base> {
self.modules.wrap_in_base(base)
}

#[inline]
fn wrapped_as_base<'a, T, Base: HasId + PtrType>(
wrap: &'a Self::Wrap<T, Base>,
) -> &'a Base {
Mods::wrapped_as_base(wrap)
}

#[inline]
fn wrapped_as_base_mut<'a, T, Base: HasId + PtrType>(
wrap: &'a mut Self::Wrap<T, Base>,
) -> &'a mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}

impl<U, Mods: crate::OnDropBuffer> crate::OnDropBuffer for NnapiDevice<U, Mods> {
impl<U, Mods: OnDropBuffer> OnDropBuffer for NnapiDevice<U, Mods> {
#[inline]
fn on_drop_buffer<T, D: Device, S: Shape>(&self, device: &D, buf: &Buffer<T, D, S>) {
self.modules.on_drop_buffer(device, buf)
}
}
impl<U, Mods: Retrieve<Self, T>, T: AsOperandCode> Retriever<T> for NnapiDevice<U, Mods> {
fn retrieve<S, const NUM_PARENTS: usize>(

impl<U, Mods: Retrieve<Self, T, S>, T: AsOperandCode, S: Shape> Retriever<T, S> for NnapiDevice<U, Mods> {
fn retrieve<const NUM_PARENTS: usize>(
&self,
len: usize,
parents: impl crate::Parents<NUM_PARENTS>,
) -> Buffer<T, Self, S>
where
S: Shape,
{
let data = self.modules.retrieve::<S, NUM_PARENTS>(self, len, parents);
let data = self.modules.retrieve::<NUM_PARENTS>(self, len, parents);
let buf = Buffer {
data,
device: Some(self),
Expand Down Expand Up @@ -87,7 +137,7 @@ impl<U, T: AsOperandCode, Mods: OnDropBuffer> Alloc<T> for NnapiDevice<U, Mods>
&self,
_len: usize,
flag: crate::flag::AllocFlag,
) -> <Self as Device>::Data<T, S> {
) -> Self::Base<T, S> {
let dtype = dtype_from_shape::<T, S>();
let idx = self.add_operand(&dtype).unwrap();
let nnapi_ptr = NnapiPtr { dtype, idx, flag };
Expand All @@ -97,7 +147,7 @@ impl<U, T: AsOperandCode, Mods: OnDropBuffer> Alloc<T> for NnapiDevice<U, Mods>
nnapi_ptr
}

fn alloc_from_slice<S: Shape>(&self, data: &[T]) -> <Self as Device>::Data<T, S>
fn alloc_from_slice<S: Shape>(&self, data: &[T]) ->Self::Base<T, S>
where
T: Clone,
{
Expand All @@ -106,8 +156,7 @@ impl<U, T: AsOperandCode, Mods: OnDropBuffer> Alloc<T> for NnapiDevice<U, Mods>
let mut ptr = unsafe { CPUPtr::<T>::new(data.len(), crate::flag::AllocFlag::Wrapper) };
ptr.clone_from_slice(data);

let ptr =
unsafe { <CPU<Base> as PtrConv>::convert::<T, S, u8, S>(&ptr, crate::AllocFlag::None) };
let ptr = unsafe { ConvPtr::<_, ()>::convert(&ptr, crate::flag::AllocFlag::None) };

self.input_ptrs.borrow_mut().push((nnapi_ptr.idx, ptr));
nnapi_ptr
Expand Down Expand Up @@ -236,19 +285,6 @@ where
}
}

impl<U, Mods: OnDropBuffer> PtrConv for NnapiDevice<U, Mods> {
unsafe fn convert<T, IS: Shape, Conv, OS: Shape>(
ptr: &Self::Data<T, IS>,
flag: crate::flag::AllocFlag,
) -> Self::Data<Conv, OS> {
NnapiPtr {
dtype: ptr.dtype.clone(),
idx: ptr.idx,
flag,
}
}
}

#[cfg(test)]
mod tests {
use nnapi::{nnapi_sys::OperationCode, Operand};
Expand All @@ -273,8 +309,8 @@ mod tests {
.unwrap();
model.add_operation(
OperationCode::ANEURALNETWORKS_ADD,
&[lhs.data.idx, rhs.data.idx, activation_idx],
&[out.data.idx],
&[lhs.base().idx, rhs.base().idx, activation_idx],
&[out.base().idx],
)?;

let out2 = Buffer::<f32, _, Dim1<10>>::new(&device, 0);
Expand All @@ -284,8 +320,8 @@ mod tests {
.unwrap();
model.add_operation(
OperationCode::ANEURALNETWORKS_MUL,
&[lhs.data.idx, out.data.idx, activation_idx],
&[out2.data.idx],
&[lhs.base().idx, out.base().idx, activation_idx],
&[out2.base().idx],
)?;

device.run()?;
Expand Down
17 changes: 16 additions & 1 deletion src/devices/nnapi/nnapi_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{flag::AllocFlag, HasId, PtrType};
use crate::{flag::AllocFlag, HasId, PtrType, ShallowCopy};
use nnapi::Operand;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -41,4 +41,19 @@ impl PtrType for NnapiPtr {
fn flag(&self) -> crate::flag::AllocFlag {
self.flag
}

#[inline]
unsafe fn set_flag(&mut self, flag: AllocFlag) {
self.flag = flag;
}
}

impl ShallowCopy for NnapiPtr {
unsafe fn shallow(&self) -> Self {
NnapiPtr {
dtype: self.dtype.clone(),
idx: self.idx,
flag: crate::flag::AllocFlag::Wrapper,
}
}
}
10 changes: 4 additions & 6 deletions src/devices/stack/stack_device.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use core::convert::Infallible;

use crate::{
flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, shape::Shape, Alloc, Base, Buffer,
CloneBuf, Device, DevicelessAble, OnDropBuffer, Read, StackArray, WriteBuf, impl_wrapped_data, WrappedData,
flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, impl_wrapped_data, shape::Shape,
Alloc, Base, Buffer, CloneBuf, Device, DevicelessAble, OnDropBuffer, Read, StackArray,
WrappedData, WriteBuf,
};

/// A device that allocates memory on the stack.
Expand Down Expand Up @@ -73,10 +74,7 @@ impl<Mods: OnDropBuffer, T: Copy + Default> Alloc<T> for Stack<Mods> {
}

#[inline]
fn alloc_from_array<S: Shape>(
&self,
array: <S as Shape>::ARR<T>,
) -> Self::Base<T, S>
fn alloc_from_array<S: Shape>(&self, array: <S as Shape>::ARR<T>) -> Self::Base<T, S>
where
T: Clone,
{
Expand Down
1 change: 0 additions & 1 deletion src/devices/vulkan/vulkan_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ impl<Mods: OnDropBuffer> Device for Vulkan<Mods> {

type Error = ();

#[inline]
fn base_to_data<T, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
self.wrap_in_base(base)
}
Expand Down

0 comments on commit 8824482

Please sign in to comment.