From 0d109674f758851c0fd5a7261c13ebaa291a2505 Mon Sep 17 00:00:00 2001 From: Tom Date: Sun, 21 Jul 2024 19:56:21 -0700 Subject: [PATCH] make conv2d work on stable --- dfdx-core/src/tensor_ops/conv2d/mod.rs | 61 ++++++++++++++++++++++++++ dfdx-core/src/tensor_ops/mod.rs | 2 - 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/dfdx-core/src/tensor_ops/conv2d/mod.rs b/dfdx-core/src/tensor_ops/conv2d/mod.rs index c5be9694..ebb0de91 100644 --- a/dfdx-core/src/tensor_ops/conv2d/mod.rs +++ b/dfdx-core/src/tensor_ops/conv2d/mod.rs @@ -116,6 +116,7 @@ pub trait TryConv2D: Sized { ) -> Result; } +#[cfg(feature = "nightly")] impl< const KERNEL: usize, const STRIDE: usize, @@ -140,6 +141,66 @@ where } } +macro_rules! const_try_conv { + ($Dim:expr, $Kernel:expr, $Stride:expr, $Padding:expr, $Dilation:expr, out=$Out_dim:expr) => { + #[cfg(not(feature = "nightly"))] + impl TryConv2D, Const<$Padding>, Const<$Dilation>, Groups> + for (Const<$Dim>, Const<$Kernel>) + { + // ($Dim + 2 * $Padding - $Dilation * ($Kernel - 1) - 1) / $Stride + 1 + // def compute_output_size(dim, kernel_size, stride, padding, dilation): + // output_size = int(int(dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1) + // return output_size + type Convolved = Const<$Out_dim>; + + fn try_conv2d( + self, + _: Const<$Stride>, + _: Const<$Padding>, + _: Const<$Dilation>, + _: Groups, + ) -> Result { + Ok(Const) + } + } + }; +} + +const_try_conv!(1, 2, 1, 0, 1, out = 0); +const_try_conv!(2, 2, 1, 0, 1, out = 1); +const_try_conv!(3, 2, 1, 0, 1, out = 2); + +const_try_conv!(1, 2, 1, 2, 1, out = 4); + +const_try_conv!(1, 1, 1, 1, 1, out = 3); +const_try_conv!(1, 2, 1, 1, 1, out = 2); +const_try_conv!(2, 2, 1, 1, 1, out = 3); +const_try_conv!(1, 3, 1, 1, 1, out = 1); +const_try_conv!(2, 3, 1, 1, 1, out = 2); +const_try_conv!(3, 2, 1, 1, 1, out = 4); + +const_try_conv!(5, 3, 1, 0, 1, out = 3); + +const_try_conv!(2, 2, 2, 0, 1, out = 1); +const_try_conv!(3, 2, 2, 0, 1, out = 1); +const_try_conv!(4, 2, 2, 0, 1, out = 2); + +const_try_conv!(4, 2, 1, 0, 2, out = 2); +const_try_conv!(5, 2, 1, 0, 2, out = 3); + +const_try_conv!(2, 3, 3, 4, 1, out = 3); +const_try_conv!(4, 3, 3, 4, 1, out = 4); + +const_try_conv!(6, 2, 4, 3, 1, out = 3); +const_try_conv!(7, 2, 4, 3, 1, out = 3); + +const_try_conv!(14, 3, 1, 0, 1, out = 12); +const_try_conv!(28, 6, 3, 2, 1, out = 9); + +const_try_conv!(3, 3, 1, 0, 1, out = 1); +const_try_conv!(3, 3, 1, 1, 1, out = 3); +const_try_conv!(5, 2, 2, 1, 2, out = 3); + impl TryConv2D for (usize, Kernel) { diff --git a/dfdx-core/src/tensor_ops/mod.rs b/dfdx-core/src/tensor_ops/mod.rs index 453457f4..c009f9db 100644 --- a/dfdx-core/src/tensor_ops/mod.rs +++ b/dfdx-core/src/tensor_ops/mod.rs @@ -284,9 +284,7 @@ pub use var_to::VarTo; mod conv1d; pub use conv1d::TryConv1D; -#[cfg(feature = "nightly")] mod conv2d; -#[cfg(feature = "nightly")] pub use conv2d::TryConv2D; #[cfg(feature = "nightly")]