Skip to content

Commit

Permalink
make conv2d work on stable
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom committed Jul 22, 2024
1 parent 78f727f commit 0d10967
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
61 changes: 61 additions & 0 deletions dfdx-core/src/tensor_ops/conv2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ pub trait TryConv2D<Stride, Padding, Dilation, Groups>: Sized {
) -> Result<Self::Convolved, Error>;
}

#[cfg(feature = "nightly")]
impl<
const KERNEL: usize,
const STRIDE: usize,
Expand All @@ -140,6 +141,66 @@ where
}
}

macro_rules! const_try_conv {
($Dim:expr, $Kernel:expr, $Stride:expr, $Padding:expr, $Dilation:expr, out=$Out_dim:expr) => {
#[cfg(not(feature = "nightly"))]
impl<Groups: Dim> TryConv2D<Const<$Stride>, Const<$Padding>, Const<$Dilation>, Groups>
for (Const<$Dim>, Const<$Kernel>)
{
// ($Dim + 2 * $Padding - $Dilation * ($Kernel - 1) - 1) / $Stride + 1
// def compute_output_size(dim, kernel_size, stride, padding, dilation):
// output_size = int(int(dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
// return output_size
type Convolved = Const<$Out_dim>;

fn try_conv2d(
self,
_: Const<$Stride>,
_: Const<$Padding>,
_: Const<$Dilation>,
_: Groups,
) -> Result<Self::Convolved, Error> {
Ok(Const)
}
}
};
}

const_try_conv!(1, 2, 1, 0, 1, out = 0);
const_try_conv!(2, 2, 1, 0, 1, out = 1);
const_try_conv!(3, 2, 1, 0, 1, out = 2);

const_try_conv!(1, 2, 1, 2, 1, out = 4);

const_try_conv!(1, 1, 1, 1, 1, out = 3);
const_try_conv!(1, 2, 1, 1, 1, out = 2);
const_try_conv!(2, 2, 1, 1, 1, out = 3);
const_try_conv!(1, 3, 1, 1, 1, out = 1);
const_try_conv!(2, 3, 1, 1, 1, out = 2);
const_try_conv!(3, 2, 1, 1, 1, out = 4);

const_try_conv!(5, 3, 1, 0, 1, out = 3);

const_try_conv!(2, 2, 2, 0, 1, out = 1);
const_try_conv!(3, 2, 2, 0, 1, out = 1);
const_try_conv!(4, 2, 2, 0, 1, out = 2);

const_try_conv!(4, 2, 1, 0, 2, out = 2);
const_try_conv!(5, 2, 1, 0, 2, out = 3);

const_try_conv!(2, 3, 3, 4, 1, out = 3);
const_try_conv!(4, 3, 3, 4, 1, out = 4);

const_try_conv!(6, 2, 4, 3, 1, out = 3);
const_try_conv!(7, 2, 4, 3, 1, out = 3);

const_try_conv!(14, 3, 1, 0, 1, out = 12);
const_try_conv!(28, 6, 3, 2, 1, out = 9);

const_try_conv!(3, 3, 1, 0, 1, out = 1);
const_try_conv!(3, 3, 1, 1, 1, out = 3);
const_try_conv!(5, 2, 2, 1, 2, out = 3);

impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
TryConv2D<Stride, Padding, Dilation, Groups> for (usize, Kernel)
{
Expand Down
2 changes: 0 additions & 2 deletions dfdx-core/src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,7 @@ pub use var_to::VarTo;
mod conv1d;
pub use conv1d::TryConv1D;

#[cfg(feature = "nightly")]
mod conv2d;
#[cfg(feature = "nightly")]
pub use conv2d::TryConv2D;

#[cfg(feature = "nightly")]
Expand Down

0 comments on commit 0d10967

Please sign in to comment.