diff --git a/keras_cv/models/segmentation/segment_anything/sam_layers.py b/keras_cv/models/segmentation/segment_anything/sam_layers.py index 127db266c4..577031c63c 100644 --- a/keras_cv/models/segmentation/segment_anything/sam_layers.py +++ b/keras_cv/models/segmentation/segment_anything/sam_layers.py @@ -275,15 +275,12 @@ def __init__(self, num_positional_features, scale, **kwargs): super().__init__(**kwargs) self.num_positional_features = num_positional_features self.scale = scale - init_func = lambda *a, **kw: self.scale * ops.random.normal( - shape=(2, self.num_positional_features), dtype=self.dtype - ) self.positional_encoding_gaussian_matrix = self.add_weight( name="positional_encoding_gaussian_matrix", shape=(2, self.num_positional_features), dtype=self.dtype, trainable=False, - initializer=init_func, + initializer=keras.initializers.get("normal"), ) def build(self, input_shape=None):