Skip to content

Commit

Permalink
Manual set seed in torch gamma
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 14, 2023
1 parent b9cca1a commit 233761d
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions keras/backend/torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ def gamma(shape, alpha, dtype=None, seed=None):
dtype = to_torch_dtype(dtype)
alpha = torch.ones(shape) * torch.tensor(alpha)
beta = torch.ones(shape)
# TODO: seed / generator not supported
prev_rng_state = torch.random.get_rng_state()
first_seed, second_seed = draw_seed(seed)
torch.manual_seed(first_seed + second_seed)
gamma_distribution = torch.distributions.gamma.Gamma(alpha, beta)
return gamma_distribution.sample().type(dtype)
sample = gamma_distribution.sample().type(dtype)
torch.random.set_rng_state(prev_rng_state)
return sample

0 comments on commit 233761d

Please sign in to comment.