diff --git a/backbones/hrnet.py b/backbones/hrnet.py index 305f3e3..c8f5c5c 100644 --- a/backbones/hrnet.py +++ b/backbones/hrnet.py @@ -41,7 +41,7 @@ def call(self, inputs, training=None): residual = tf.identity(inputs, name="residual") if self.downsample is not None: - residual = self.downsample(residual) + residual = self.downsample(residual, training=training) x = self.conv1(inputs) x = self.bn1(x, training=training) @@ -85,7 +85,7 @@ def call(self, inputs, training=None): residual = tf.identity(inputs, name="residual") if self.downsample is not None: - residual = self.downsample(residual) + residual = self.downsample(residual, training=training) x = self.conv1(inputs) x = self.bn1(x, training=training)