diff --git a/keras/src/layers/normalization/spectral_normalization.py b/keras/src/layers/normalization/spectral_normalization.py index 727d6bb58db..fc11844fc92 100644 --- a/keras/src/layers/normalization/spectral_normalization.py +++ b/keras/src/layers/normalization/spectral_normalization.py @@ -105,8 +105,8 @@ def normalized_weights(self): ops.matmul(vector_u, ops.transpose(weights)), axis=None ) vector_u = normalize(ops.matmul(vector_v, weights), axis=None) - # vector_u = tf.stop_gradient(vector_u) - # vector_v = tf.stop_gradient(vector_v) + vector_u = ops.stop_gradient(vector_u) + vector_v = ops.stop_gradient(vector_v) sigma = ops.matmul( ops.matmul(vector_v, weights), ops.transpose(vector_u) )