Skip to content

Commit

Permalink
make pool2d work on stable
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom committed Jul 23, 2024
1 parent 0d10967 commit da2d024
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
2 changes: 0 additions & 2 deletions dfdx-core/src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,5 @@ mod convtrans2d;
#[cfg(feature = "nightly")]
pub use convtrans2d::TryConvTrans2D;

#[cfg(feature = "nightly")]
mod pool2d;
#[cfg(feature = "nightly")]
pub use pool2d::{Pool2DKind, TryPool2D};
45 changes: 45 additions & 0 deletions dfdx-core/src/tensor_ops/pool2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub trait TryPool2D<Kernel, Stride, Padding, Dilation>: Sized {
) -> Result<Self::Pooled, Error>;
}

#[cfg(feature = "nightly")]
impl<
const KERNEL: usize,
const STRIDE: usize,
Expand All @@ -100,6 +101,50 @@ where
}
}

macro_rules! const_try_pool {
($Dim:expr, $Kernel:expr, $Stride:expr, $Padding:expr, $Dilation:expr, out=$Out_dim:expr) => {
#[cfg(not(feature = "nightly"))]
impl TryPool2D<Const<$Kernel>, Const<$Stride>, Const<$Padding>, Const<$Dilation>>
for Const<$Dim>
{
// ($Dim + 2 * $Padding - $Dilation * ($Kernel - 1) - 1) / $Stride + 1
// def compute_output_size(dim, kernel_size, stride, padding, dilation):
// output_size = int(int(dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
// return output_size
type Pooled = Const<$Out_dim>;

fn try_pool2d(
self,
_: Pool2DKind,
_: Const<$Kernel>,
_: Const<$Stride>,
_: Const<$Padding>,
_: Const<$Dilation>,
) -> Result<Self::Pooled, Error> {
Ok(Const)
}
}
};
}

const_try_pool!(1, 2, 1, 0, 1, out = 0);
const_try_pool!(2, 2, 1, 0, 1, out = 1);
const_try_pool!(3, 2, 1, 0, 1, out = 2);
const_try_pool!(4, 2, 1, 0, 1, out = 3);

const_try_pool!(1, 2, 2, 0, 1, out = 0);
const_try_pool!(2, 2, 2, 0, 1, out = 1);
const_try_pool!(3, 2, 2, 0, 1, out = 1);
const_try_pool!(4, 2, 2, 0, 1, out = 2);

const_try_pool!(1, 1, 2, 0, 1, out = 1);
const_try_pool!(2, 1, 2, 0, 1, out = 1);
const_try_pool!(3, 1, 2, 0, 1, out = 2);
const_try_pool!(4, 1, 2, 0, 1, out = 2);

const_try_pool!(4, 2, 1, 0, 2, out = 2);
const_try_pool!(5, 2, 1, 0, 2, out = 3);

impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim>
TryPool2D<Kernel, Stride, Padding, Dilation> for usize
{
Expand Down

0 comments on commit da2d024

Please sign in to comment.