From da2d0241d905ecbde2aff69e44c90b7c152b3c5e Mon Sep 17 00:00:00 2001 From: Tom Date: Mon, 22 Jul 2024 17:51:20 -0700 Subject: [PATCH] make pool2d work on stable --- dfdx-core/src/tensor_ops/mod.rs | 2 -- dfdx-core/src/tensor_ops/pool2d/mod.rs | 45 ++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index c009f9db..22c6bf59 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -292,7 +292,5 @@ mod convtrans2d; #[cfg(feature = "nightly")] pub use convtrans2d::TryConvTrans2D; -#[cfg(feature = "nightly")] mod pool2d; -#[cfg(feature = "nightly")] pub use pool2d::{Pool2DKind, TryPool2D}; diff --git a/dfdx-core/src/tensor_ops/pool2d/mod.rs b/dfdx-core/src/tensor_ops/pool2d/mod.rs index 150525c7..53997e73 100644 --- a/dfdx-core/src/tensor_ops/pool2d/mod.rs +++ b/dfdx-core/src/tensor_ops/pool2d/mod.rs @@ -77,6 +77,7 @@ pub trait TryPool2D: Sized { ) -> Result; } +#[cfg(feature = "nightly")] impl< const KERNEL: usize, const STRIDE: usize, @@ -100,6 +101,50 @@ where } } +macro_rules! const_try_pool { + ($Dim:expr, $Kernel:expr, $Stride:expr, $Padding:expr, $Dilation:expr, out=$Out_dim:expr) => { + #[cfg(not(feature = "nightly"))] + impl TryPool2D, Const<$Stride>, Const<$Padding>, Const<$Dilation>> + for Const<$Dim> + { + // ($Dim + 2 * $Padding - $Dilation * ($Kernel - 1) - 1) / $Stride + 1 + // def compute_output_size(dim, kernel_size, stride, padding, dilation): + // output_size = int(int(dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1) + // return output_size + type Pooled = Const<$Out_dim>; + + fn try_pool2d( + self, + _: Pool2DKind, + _: Const<$Kernel>, + _: Const<$Stride>, + _: Const<$Padding>, + _: Const<$Dilation>, + ) -> Result { + Ok(Const) + } + } + }; +} + +const_try_pool!(1, 2, 1, 0, 1, out = 0); +const_try_pool!(2, 2, 1, 0, 1, out = 1); +const_try_pool!(3, 2, 1, 0, 1, out = 2); +const_try_pool!(4, 2, 1, 0, 1, out = 3); + +const_try_pool!(1, 2, 2, 0, 1, out = 0); +const_try_pool!(2, 2, 2, 0, 1, out = 1); +const_try_pool!(3, 2, 2, 0, 1, out = 1); +const_try_pool!(4, 2, 2, 0, 1, out = 2); + +const_try_pool!(1, 1, 2, 0, 1, out = 1); +const_try_pool!(2, 1, 2, 0, 1, out = 1); +const_try_pool!(3, 1, 2, 0, 1, out = 2); +const_try_pool!(4, 1, 2, 0, 1, out = 2); + +const_try_pool!(4, 2, 1, 0, 2, out = 2); +const_try_pool!(5, 2, 1, 0, 2, out = 3); + impl TryPool2D for usize {