diff --git a/keras/src/applications/convnext.py b/keras/src/applications/convnext.py index c99280ff9b08..721a3e742019 100644 --- a/keras/src/applications/convnext.py +++ b/keras/src/applications/convnext.py @@ -522,6 +522,30 @@ def ConvNeXt( model = Functional(inputs=inputs, outputs=x, name=name) + # Validate weights before requesting them from the API + if weights == "imagenet": + expected_config = MODEL_CONFIGS[weights_name.split("convnext_")[-1]] + if ( + depths != expected_config["depths"] + or projection_dims != expected_config["projection_dims"] + ): + raise ValueError( + f"Architecture configuration does not match {weights_name} " + f"variant. When using pre-trained weights, the model " + f"architecture must match the pre-trained configuration " + f"exactly. Expected depths: {expected_config['depths']}, " + f"got: {depths}. Expected projection_dims: " + f"{expected_config['projection_dims']}, got: {projection_dims}." + ) + + if weights_name not in name: + raise ValueError( + f'Model name "{name}" does not match weights variant ' + f'"{weights_name}". When using imagenet weights, model name ' + f'must contain the weights variant (e.g., "convnext_' + f'{weights_name.split("convnext_")[-1]}").' + ) + # Load weights. if weights == "imagenet": if include_top: diff --git a/keras/src/applications/efficientnet_v2.py b/keras/src/applications/efficientnet_v2.py index e0e4c0b9be83..b0b59470b349 100644 --- a/keras/src/applications/efficientnet_v2.py +++ b/keras/src/applications/efficientnet_v2.py @@ -935,9 +935,17 @@ def EfficientNetV2( num_channels = input_shape[bn_axis - 1] if name.split("-")[-1].startswith("b") and num_channels == 3: x = layers.Rescaling(scale=1.0 / 255)(x) + if backend.image_data_format() == "channels_first": + mean = [[[[0.485]], [[0.456]], [[0.406]]]] # shape [1,3,1,1] + variance = [ + [[[0.229**2]], [[0.224**2]], [[0.225**2]]] + ] # shape [1,3,1,1] + else: + mean = [0.485, 0.456, 0.406] + variance = [0.229**2, 0.224**2, 0.225**2] x = layers.Normalization( - mean=[0.485, 0.456, 0.406], - variance=[0.229**2, 0.224**2, 0.225**2], + mean=mean, + variance=variance, axis=bn_axis, )(x) else: