diff --git a/src/garage/torch/distributions/tanh_normal.py b/src/garage/torch/distributions/tanh_normal.py index 6d1ca10d34..51a835031c 100644 --- a/src/garage/torch/distributions/tanh_normal.py +++ b/src/garage/torch/distributions/tanh_normal.py @@ -19,7 +19,7 @@ class TanhNormal(torch.distributions.Distribution): """ # noqa: 501 def __init__(self, loc, scale): - self._normal = Independent(Normal(loc, scale), 1) + self._normal = Independent(Normal(loc, scale), 1, validate_args=False) super().__init__() def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):