diff --git a/steganogan/encoders.py b/steganogan/encoders.py index d1a93c8..3f08a8b 100644 --- a/steganogan/encoders.py +++ b/steganogan/encoders.py @@ -3,6 +3,15 @@ import torch from torch import nn +import torch.onnx +from torchvision.ops.deform_conv import DeformConv2d + +input = torch.rand(4, 3, 10, 10) +kh, kw = 3, 3 +weight = torch.rand(5, 3, kh, kw) +offset = torch.rand(4, 2 * kh * kw, 8, 8) +mask = torch.rand(4, kh * kw, 8, 8) + class BasicEncoder(nn.Module): """ @@ -16,11 +25,11 @@ class BasicEncoder(nn.Module): add_image = False def _conv2d(self, in_channels, out_channels): - return nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1 + return DeformConv2d( + input=input + offset=offset, + weight=weight, + mask=mask ) def _build_models(self):