From 4c05e0cd2dfd385fb9db204885ddf51635b9e9b7 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Sat, 14 Dec 2024 09:17:31 -0800 Subject: [PATCH] Fix issues with randomgrayscale layer --- .../preprocessing/image_preprocessing/random_grayscale.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py index 418f2982777..804e9323a0f 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -45,7 +45,7 @@ class RandomGrayscale(BaseImagePreprocessingLayer): will have the same value. """ - def __init__(self, factor=0.5, data_format=None, **kwargs): + def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs): super().__init__(**kwargs) if factor < 0 or factor > 1: raise ValueError( @@ -54,7 +54,8 @@ def __init__(self, factor=0.5, data_format=None, **kwargs): ) self.factor = factor self.data_format = backend.standardize_data_format(data_format) - self.generator = self.backend.random.SeedGenerator() + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) def get_random_transformation(self, images, training=True, seed=None): if seed is None: