From 233761df6184e87e1ec5eb512e5c87534a33d141 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 14 Nov 2023 15:30:20 -0800 Subject: [PATCH] Manual set seed in torch gamma --- keras/backend/torch/random.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/keras/backend/torch/random.py b/keras/backend/torch/random.py index 608b0df6b47..83a21fcee13 100644 --- a/keras/backend/torch/random.py +++ b/keras/backend/torch/random.py @@ -196,6 +196,10 @@ def gamma(shape, alpha, dtype=None, seed=None): dtype = to_torch_dtype(dtype) alpha = torch.ones(shape) * torch.tensor(alpha) beta = torch.ones(shape) - # TODO: seed / generator not supported + prev_rng_state = torch.random.get_rng_state() + first_seed, second_seed = draw_seed(seed) + torch.manual_seed(first_seed + second_seed) gamma_distribution = torch.distributions.gamma.Gamma(alpha, beta) - return gamma_distribution.sample().type(dtype) + sample = gamma_distribution.sample().type(dtype) + torch.random.set_rng_state(prev_rng_state) + return sample