diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 9f7a85c0c93..2992023a2c3 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -65,6 +65,21 @@ class BinaryAccuracy(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -205,6 +220,21 @@ class MulticlassAccuracy(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
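A minimal usage sketch of the new ``input_format`` argument (illustrative values, not part of this diff), assuming the detection semantics documented above; passing ``'logits'`` explicitly skips the range-based auto-detection:

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> preds = tensor([-1.0, 2.0, 0.5, -3.0])  # raw logits; sigmoid + threshold at 0.5 -> [0, 1, 1, 0]
>>> target = tensor([0, 1, 0, 0])
>>> metric = BinaryAccuracy(input_format="logits")
>>> metric(preds, target)
tensor(0.7500)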
Example (preds is int tensor): >>> from torch import tensor @@ -356,6 +386,21 @@ class MultilabelAccuracy(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -497,19 +542,21 @@ def __new__( # type: ignore[misc] top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) - - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) if task == ClassificationTask.BINARY: - return BinaryAccuracy(threshold, **kwargs) + return BinaryAccuracy(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError( @@ -517,11 +564,11 @@ def __new__( # type: ignore[misc] ) if not isinstance(top_k, int): raise ValueError(f"Optional arg `top_k` must be type `int` when task is {task}. Got {type(top_k)}") - return MulticlassAccuracy(num_classes, top_k, average, **kwargs) + return MulticlassAccuracy(num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError( f"Optional arg `num_labels` must be type `int` when task is {task}. Got {type(num_labels)}" ) - return MultilabelAccuracy(num_labels, threshold, average, **kwargs) + return MultilabelAccuracy(num_labels, threshold, average, **kwargs_extra) raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 65e9493b14c..a171a7d9431 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -83,6 +83,18 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations.
+ input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -111,11 +123,14 @@ def __init__( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: - super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs) + super().__init__( + thresholds=thresholds, ignore_index=ignore_index, validate_args=False, input_format=input_format, **kwargs + ) if validate_args: - _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index) + _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index, input_format) self.max_fpr = max_fpr def compute(self) -> Tensor: # type: ignore[override] @@ -221,6 +236,18 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
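A hedged sketch for the AUROC variants, which expose no ``'labels'`` option since a ranking metric needs scores rather than hard predictions (illustrative values, assuming the semantics documented above):

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAUROC
>>> preds = tensor([0.1, 0.4, 0.35, 0.8])  # already probabilities
>>> target = tensor([0, 0, 1, 1])
>>> auroc = BinaryAUROC(input_format="probs")  # values are additionally checked to lie in [0,1]
>>> auroc(preds, target)
tensor(0.7500)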
Example: @@ -260,13 +287,19 @@ def __init__( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( - num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + num_classes=num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=False, + input_format=input_format, + **kwargs, ) if validate_args: - _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) + _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index, input_format) self.average = average # type: ignore[assignment] self.validate_args = validate_args @@ -373,6 +406,18 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
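For the multilabel variant, ``input_format`` is passed the same way; a sketch with random data (shapes assumed from the existing API, no fixed output shown):

>>> import torch
>>> from torchmetrics.classification import MultilabelAUROC
>>> ml_auroc = MultilabelAUROC(num_labels=3, input_format="probs")
>>> preds = torch.rand(10, 3)           # per-label probabilities
>>> target = torch.randint(2, (10, 3))  # per-label binary targets
>>> ml_auroc.update(preds, target)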
Example: @@ -415,13 +460,19 @@ def __init__( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( - num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + num_labels=num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=False, + input_format=input_format, + **kwargs, ) if validate_args: - _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index) + _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index, input_format) self.average = average self.validate_args = validate_args @@ -516,31 +567,26 @@ def __new__( # type: ignore[misc] max_fpr: Optional[float] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) - kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) + kwargs_extra = kwargs.copy() + kwargs_extra.update({ + "thresholds": thresholds, + "ignore_index": ignore_index, + "validate_args": validate_args, + "input_format": input_format, + }) if task == ClassificationTask.BINARY: - return BinaryAUROC(max_fpr, **kwargs) + return BinaryAUROC(max_fpr, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return MulticlassAUROC(num_classes, average, **kwargs) + return MulticlassAUROC(num_classes, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelAUROC(num_labels, average, **kwargs) + return MultilabelAUROC(num_labels, average, **kwargs_extra) raise ValueError(f"Task {task} not supported!") - - def update(self, *args: Any, **kwargs: Any) -> None: - """Update metric state.""" - raise NotImplementedError( - f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric." - ) - - def compute(self) -> None: - """Compute metric.""" - raise NotImplementedError( - f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric." - ) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 9d36774938c..6f26ff404fd 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -90,6 +90,18 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values.
+ - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -219,6 +231,18 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -258,13 +282,19 @@ def __init__( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( - num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + num_classes=num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=validate_args, + input_format=input_format, + **kwargs, ) if validate_args: - _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) + _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index, input_format) self.average = average # type: ignore[assignment] self.validate_args = validate_args @@ -376,6 +406,18 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities.
Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -418,13 +460,19 @@ def __init__( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( - num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs + num_labels=num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=validate_args, + input_format=input_format, + **kwargs, ) if validate_args: - _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index) + _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index, input_format) self.average = average self.validate_args = validate_args @@ -525,19 +573,26 @@ def __new__( # type: ignore[misc] average: Optional[Literal["macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) - kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) + kwargs_extra = kwargs.copy() + kwargs_extra.update({ + "thresholds": thresholds, + "ignore_index": ignore_index, + "validate_args": validate_args, + "input_format": input_format, + }) if task == ClassificationTask.BINARY: - return BinaryAveragePrecision(**kwargs) + return BinaryAveragePrecision(**kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return MulticlassAveragePrecision(num_classes, average, **kwargs) + return MulticlassAveragePrecision(num_classes, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelAveragePrecision(num_labels, average, **kwargs) + return MultilabelAveragePrecision(num_labels, average, **kwargs_extra) raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index d3f65faea9e..2f03982d443 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -87,6 +87,20 @@ class BinaryConfusionMatrix(Metric): - ``'all'``: normalization over the whole matrix validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor.
If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): @@ -121,23 +135,27 @@ def __init__( ignore_index: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize, input_format) self.threshold = threshold self.ignore_index = ignore_index self.normalize = normalize self.validate_args = validate_args + self.input_format = input_format self.add_state("confmat", torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" if self.validate_args: - _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) - preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) + _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index, self.input_format) + preds, target = _binary_confusion_matrix_format( + preds, target, self.threshold, self.ignore_index, self.input_format + ) confmat = _binary_confusion_matrix_update(preds, target) self.confmat += confmat @@ -222,6 +240,20 @@ class MulticlassConfusionMatrix(Metric): - ``'all'``: normalization over the whole matrix validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric.
+ - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): @@ -262,23 +294,29 @@ def __init__( ignore_index: Optional[int] = None, normalize: Optional[Literal["none", "true", "pred", "all"]] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize, input_format) self.num_classes = num_classes self.ignore_index = ignore_index self.normalize = normalize self.validate_args = validate_args + self.input_format = input_format self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" if self.validate_args: - _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) - preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) + _multiclass_confusion_matrix_tensor_validation( + preds, target, self.num_classes, self.ignore_index, self.input_format + ) + preds, target = _multiclass_confusion_matrix_format( + preds, target, self.ignore_index, input_format=self.input_format + ) confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes) self.confmat += confmat @@ -365,6 +403,20 @@ class MultilabelConfusionMatrix(Metric): - ``'all'``: normalization over the whole matrix validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
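With ``input_format='labels'`` the preds tensor is taken as hard predictions, so neither sigmoid nor thresholding is applied; a sketch assuming the semantics documented above (illustrative values, not part of this diff):

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> preds = tensor([0, 1, 0, 1])   # already hard labels
>>> target = tensor([0, 1, 1, 0])
>>> bcm = BinaryConfusionMatrix(input_format="labels")
>>> bcm(preds, target)
tensor([[1, 1],
        [1, 1]])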
Example (preds is int tensor): @@ -403,25 +455,29 @@ def __init__( ignore_index: Optional[int] = None, normalize: Optional[Literal["none", "true", "pred", "all"]] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize, input_format) self.num_labels = num_labels self.threshold = threshold self.ignore_index = ignore_index self.normalize = normalize self.validate_args = validate_args + self.input_format = input_format self.add_state("confmat", torch.zeros(num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" if self.validate_args: - _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) + _multilabel_confusion_matrix_tensor_validation( + preds, target, self.num_labels, self.ignore_index, self.input_format + ) preds, target = _multilabel_confusion_matrix_format( - preds, target, self.num_labels, self.threshold, self.ignore_index + preds, target, self.num_labels, self.threshold, self.ignore_index, self.input_format ) confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) self.confmat += confmat @@ -516,19 +572,26 @@ def __new__( # type: ignore[misc] normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) - kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args}) + kwargs_extra = kwargs.copy() + kwargs_extra.update({ + "normalize": normalize, + "ignore_index": ignore_index, + "validate_args": validate_args, + "input_format": input_format, + }) if task == ClassificationTask.BINARY: - return BinaryConfusionMatrix(threshold, **kwargs) + return BinaryConfusionMatrix(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return MulticlassConfusionMatrix(num_classes, **kwargs) + return MulticlassConfusionMatrix(num_classes, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) + return MultilabelConfusionMatrix(num_labels, threshold, **kwargs_extra) raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index eec4c33bd8b..c8f500f2743 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -83,6 +83,21 @@ class BinaryFBetaScore(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. 
+ input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -125,6 +140,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( @@ -132,10 +148,11 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=False, + input_format=input_format, **kwargs, ) if validate_args: - _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index) + _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, input_format) self.validate_args = validate_args self.beta = beta @@ -249,6 +266,22 @@ class MulticlassFBetaScore(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
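An illustrative sketch for the F-beta family (values not from this diff); with ``'probs'`` only thresholding is applied before the stat scores are computed:

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryFBetaScore
>>> preds = tensor([0.2, 0.8, 0.6, 0.3])  # probabilities; threshold 0.5 -> [0, 1, 1, 0]
>>> target = tensor([0, 1, 0, 1])
>>> f2 = BinaryFBetaScore(beta=2.0, input_format="probs")
>>> f2(preds, target)
tensor(0.5000)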
+ Example (preds is int tensor): >>> from torch import tensor @@ -306,6 +339,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( @@ -315,10 +349,13 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=False, + input_format=input_format, **kwargs, ) if validate_args: - _multiclass_fbeta_score_arg_validation(beta, num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_fbeta_score_arg_validation( + beta, num_classes, top_k, average, multidim_average, ignore_index, input_format + ) self.validate_args = validate_args self.beta = beta @@ -430,6 +467,21 @@ class MultilabelFBetaScore(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -485,6 +537,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( @@ -494,10 +547,13 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=False, + input_format=input_format, **kwargs, ) if validate_args: - _multilabel_fbeta_score_arg_validation(beta, num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_fbeta_score_arg_validation( + beta, num_labels, threshold, average, multidim_average, ignore_index, input_format + ) self.validate_args = validate_args self.beta = beta @@ -592,6 +648,21 @@ class BinaryF1Score(BinaryFBetaScore): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor.
Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -633,6 +704,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( @@ -641,6 +713,7 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + input_format=input_format, **kwargs, ) @@ -748,6 +821,21 @@ class MulticlassF1Score(MulticlassFBetaScore): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
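Per the docstring, ``'none'`` is the fastest but least safe path; a hedged sketch assuming preds are already correctly formatted hard predictions (illustrative, not part of this diff):

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryF1Score
>>> preds = tensor([0, 1, 1, 0])   # must already be valid hard predictions: 'none' applies no formatting
>>> target = tensor([0, 1, 0, 1])
>>> f1 = BinaryF1Score(input_format="none", validate_args=False)  # skips formatting and validation
>>> f1(preds, target)
tensor(0.5000)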
Example (preds is int tensor): >>> from torch import tensor @@ -804,6 +892,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( @@ -814,6 +903,7 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + input_format=input_format, **kwargs, ) @@ -919,6 +1009,21 @@ class MultilabelF1Score(MultilabelFBetaScore): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
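A multiclass sketch with ``'labels'`` (illustrative values, not part of this diff; micro averaging is chosen so the result is simply the fraction of correct predictions):

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassF1Score
>>> preds = tensor([2, 1, 0, 1])   # integer class labels; no formatting applied
>>> target = tensor([2, 1, 0, 0])
>>> f1 = MulticlassF1Score(num_classes=3, average="micro", input_format="labels")
>>> f1(preds, target)
tensor(0.7500)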
Example (preds is int tensor): >>> from torch import tensor @@ -973,6 +1078,7 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__( @@ -983,6 +1089,7 @@ def __init__( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + input_format=input_format, **kwargs, ) @@ -1070,28 +1177,31 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) if task == ClassificationTask.BINARY: - return BinaryFBetaScore(beta, threshold, **kwargs) + return BinaryFBetaScore(beta, threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") - return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) + return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) + return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs_extra) raise ValueError(f"Task {task} not supported!") @@ -1133,26 +1243,29 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) if task == ClassificationTask.BINARY: - return BinaryF1Score(threshold, **kwargs) + return BinaryF1Score(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") - return MulticlassF1Score(num_classes, top_k, average, **kwargs) + return MulticlassF1Score(num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelF1Score(num_labels, threshold, average, **kwargs) + return MultilabelF1Score(num_labels, threshold, average, **kwargs_extra) raise ValueError(f"Task 
{task} not supported!") diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index bd0bfa733c6..216adb22db9 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -74,6 +74,21 @@ class BinaryHammingDistance(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -215,6 +230,21 @@ class MulticlassHammingDistance(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
Example (preds is int tensor): >>> from torch import tensor @@ -369,6 +399,21 @@ class MultilabelHammingDistance(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -503,26 +548,29 @@ def __new__( # type: ignore[misc] top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) if task == ClassificationTask.BINARY: - return BinaryHammingDistance(threshold, **kwargs) + return BinaryHammingDistance(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") - return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) + return MulticlassHammingDistance(num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) + return MultilabelHammingDistance(num_labels, threshold, average, **kwargs_extra) raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 64aeffc59df..3c2f7620656 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -73,6 +73,21 @@ class BinaryPrecision(BinaryStatScores): Specifies a target value
that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -217,6 +232,21 @@ class MulticlassPrecision(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -373,6 +403,21 @@ class MultilabelPrecision(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor.
If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -507,6 +552,20 @@ class BinaryRecall(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -650,6 +709,21 @@ class MulticlassRecall(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. 
Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -805,6 +879,21 @@ class MultilabelRecall(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
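To make the new knob concrete, here is a minimal usage sketch against a build that includes this branch (the class name and `input_format` values come from the hunks above; the tensors are invented):

```python
import torch
from torchmetrics.classification import MultilabelRecall

# Raw logits (values outside [0, 1]); declaring 'logits' skips the value
# inspection that 'auto' would perform and goes straight to sigmoid + threshold.
preds = torch.tensor([[2.3, -1.2, 0.7], [-0.4, 1.8, -2.0]])
target = torch.tensor([[1, 0, 1], [0, 1, 0]])

metric = MultilabelRecall(num_labels=3, input_format="logits")
print(metric(preds, target))
```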
Example (preds is int tensor): >>> from torch import tensor @@ -941,28 +1030,31 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" assert multidim_average is not None # noqa: S101 # needed for mypy - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: - return BinaryPrecision(threshold, **kwargs) + return BinaryPrecision(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") - return MulticlassPrecision(num_classes, top_k, average, **kwargs) + return MulticlassPrecision(num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelPrecision(num_labels, threshold, average, **kwargs) + return MultilabelPrecision(num_labels, threshold, average, **kwargs_extra) raise ValueError(f"Task {task} not supported!") @@ -1006,26 +1098,29 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) if task == ClassificationTask.BINARY: - return BinaryRecall(threshold, **kwargs) + return BinaryRecall(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") - return MulticlassRecall(num_classes, top_k, average, **kwargs) + return MulticlassRecall(num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelRecall(num_labels, threshold, average, **kwargs) + return MultilabelRecall(num_labels, threshold, average, **kwargs_extra) return None diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 46b874a74b2..1d87c8c5c22 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -107,6 +107,19 @@ class BinaryPrecisionRecallCurve(Metric): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and 
tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -139,14 +152,15 @@ def __init__( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) - + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index, input_format) self.ignore_index = ignore_index self.validate_args = validate_args + self.input_format = input_format thresholds = _adjust_threshold_arg(thresholds) if thresholds is None: @@ -162,8 +176,10 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: """Update metric states.""" if self.validate_args: - _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index) - preds, target, _ = _binary_precision_recall_curve_format(preds, target, self.thresholds, self.ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index, self.input_format) + preds, target, _ = _binary_precision_recall_curve_format( + preds, target, self.thresholds, self.ignore_index, self.input_format + ) state = _binary_precision_recall_curve_update(preds, target, self.thresholds) if isinstance(state, Tensor): self.confmat += state @@ -283,6 +299,19 @@ class MulticlassPrecisionRecallCurve(Metric): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range.
+ - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -333,16 +362,19 @@ def __init__( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) - + _multiclass_precision_recall_curve_arg_validation( + num_classes, thresholds, ignore_index, average, input_format + ) self.num_classes = num_classes self.average = average self.ignore_index = ignore_index self.validate_args = validate_args + self.input_format = input_format thresholds = _adjust_threshold_arg(thresholds) if thresholds is None: @@ -360,9 +392,11 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: """Update metric states.""" if self.validate_args: - _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index) + _multiclass_precision_recall_curve_tensor_validation( + preds, target, self.num_classes, self.ignore_index, self.input_format + ) preds, target, _ = _multiclass_precision_recall_curve_format( - preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average + preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average, self.input_format ) state = _multiclass_precision_recall_curve_update( preds, target, self.num_classes, self.thresholds, self.average @@ -482,6 +516,17 @@ class MultilabelPrecisionRecallCurve(Metric): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor before calculating the metric. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. No + transformation will be applied to the tensor, but values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
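A quick sketch of the curve class with an explicit format, again assuming this branch is installed; `thresholds=5` buckets the curve so `forward` returns plain tensors (data invented):

```python
import torch
from torchmetrics.classification import MultilabelPrecisionRecallCurve

# Values already in [0, 1]: 'probs' skips auto-detection; inputs are only range-checked.
preds = torch.tensor([[0.8, 0.1, 0.6], [0.3, 0.7, 0.2], [0.5, 0.4, 0.9]])
target = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])

metric = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=5, input_format="probs")
precision, recall, thresholds = metric(preds, target)
```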
+ Example: >>> from torchmetrics.classification import MultilabelPrecisionRecallCurve @@ -529,15 +574,16 @@ def __init__( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) - + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index, input_format) self.num_labels = num_labels self.ignore_index = ignore_index self.validate_args = validate_args + self.input_format = input_format thresholds = _adjust_threshold_arg(thresholds) if thresholds is None: @@ -555,9 +601,11 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: """Update metric states.""" if self.validate_args: - _multilabel_precision_recall_curve_tensor_validation(preds, target, self.num_labels, self.ignore_index) + _multilabel_precision_recall_curve_tensor_validation( + preds, target, self.num_labels, self.ignore_index, self.input_format + ) preds, target, _ = _multilabel_precision_recall_curve_format( - preds, target, self.num_labels, self.thresholds, self.ignore_index + preds, target, self.num_labels, self.thresholds, self.ignore_index, self.input_format ) state = _multilabel_precision_recall_curve_update(preds, target, self.num_labels, self.thresholds) if isinstance(state, Tensor): @@ -667,19 +715,26 @@ def __new__( # type: ignore[misc] num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) - kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) + kwargs_extra = kwargs.copy() + kwargs_extra.update({ + "thresholds": thresholds, + "ignore_index": ignore_index, + "validate_args": validate_args, + "input_format": input_format, + }) if task == ClassificationTask.BINARY: - return BinaryPrecisionRecallCurve(**kwargs) + return BinaryPrecisionRecallCurve(**kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return MulticlassPrecisionRecallCurve(num_classes, **kwargs) + return MulticlassPrecisionRecallCurve(num_classes, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelPrecisionRecallCurve(num_labels, **kwargs) + return MultilabelPrecisionRecallCurve(num_labels, **kwargs_extra) raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 7f1479a1ae6..72fccfada5c 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -91,6 +91,18 @@ class BinaryROC(BinaryPrecisionRecallCurve): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. 
Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -243,6 +255,18 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -398,6 +422,18 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric.
+ - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -569,19 +605,26 @@ def __new__( num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) - kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) + kwargs_extra = kwargs.copy() + kwargs_extra.update({ + "thresholds": thresholds, + "ignore_index": ignore_index, + "validate_args": validate_args, + "input_format": input_format, + }) if task == ClassificationTask.BINARY: - return BinaryROC(**kwargs) + return BinaryROC(**kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return MulticlassROC(num_classes, **kwargs) + return MulticlassROC(num_classes, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelROC(num_labels, **kwargs) + return MultilabelROC(num_labels, **kwargs_extra) raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index a297cfe9ca5..2923f7d3da8 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -66,6 +66,21 @@ class BinarySpecificity(BinaryStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -205,6 +220,21 @@ class MulticlassSpecificity(MulticlassStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torch import tensor @@ -354,6 +384,21 @@ class MultilabelSpecificity(MultilabelStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
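For illustration, a hypothetical call that opts out of auto-detection entirely by declaring hard labels (class name and argument values as added above; data invented):

```python
import torch
from torchmetrics.classification import MulticlassSpecificity

# Integer class predictions: 'labels' says no formatting is needed at all.
preds = torch.tensor([2, 1, 0, 1])
target = torch.tensor([2, 1, 0, 0])

metric = MulticlassSpecificity(num_classes=3, input_format="labels")
print(metric(preds, target))
```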
Example (preds is int tensor): >>> from torch import tensor @@ -487,26 +532,29 @@ def __new__( # type: ignore[misc] top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) if task == ClassificationTask.BINARY: - return BinarySpecificity(threshold, **kwargs) + return BinarySpecificity(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") - return MulticlassSpecificity(num_classes, top_k, average, **kwargs) + return MulticlassSpecificity(num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelSpecificity(num_labels, threshold, average, **kwargs) + return MultilabelSpecificity(num_labels, threshold, average, **kwargs_extra) raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index b70ee9ddeac..d6878fc732f 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -126,6 +126,20 @@ class BinaryStatScores(_AbstractStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
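The task wrapper forwards `input_format` through the copied kwargs dict, so the same knob is reachable from the top-level class as well; a small sketch under this branch (data invented):

```python
import torch
from torchmetrics import Specificity

# Probabilities in [0, 1]; the wrapper injects input_format via kwargs_extra
# and dispatches to BinarySpecificity, as in the hunk above.
preds = torch.tensor([0.1, 0.9, 0.4, 0.7])
target = torch.tensor([0, 1, 0, 1])

metric = Specificity(task="binary", threshold=0.5, input_format="probs")
print(metric(preds, target))
```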
Example (preds is int tensor): @@ -167,23 +181,27 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format) self.threshold = threshold self.multidim_average = multidim_average self.ignore_index = ignore_index self.validate_args = validate_args + self.input_format = input_format self._create_state(size=1, multidim_average=multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" if self.validate_args: - _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) - preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) + _binary_stat_scores_tensor_validation( + preds, target, self.multidim_average, self.ignore_index, self.input_format + ) + preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index, self.input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, self.multidim_average) self._update_state(tp, fp, tn, fn) @@ -205,7 +223,6 @@ class MulticlassStatScores(_AbstractStatScores): probabilities/logits into an int tensor. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mcss`` (:class:`~torch.Tensor`): A tensor of shape ``(..., 5)``, where the last dimension corresponds @@ -248,6 +265,20 @@ class MulticlassStatScores(_AbstractStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
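A minimal sketch of the base stat-scores class with an explicit `'probs'` declaration, behaving as documented above (invented data; assumes this branch):

```python
import torch
from torchmetrics.classification import BinaryStatScores

# Already probabilities: 'probs' applies thresholding only, while 'auto'
# would first scan the values to guess the format.
preds = torch.tensor([0.2, 0.8, 0.6, 0.1])
target = torch.tensor([0, 1, 1, 0])

metric = BinaryStatScores(input_format="probs")
print(metric(preds, target))  # -> tensor with [tp, fp, tn, fn, support]
```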
Example (preds is int tensor): @@ -311,17 +342,21 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format + ) self.num_classes = num_classes self.top_k = top_k self.average = average self.multidim_average = multidim_average self.ignore_index = ignore_index self.validate_args = validate_args + self.input_format = input_format self._create_state( size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average @@ -331,9 +366,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" if self.validate_args: _multiclass_stat_scores_tensor_validation( - preds, target, self.num_classes, self.multidim_average, self.ignore_index + preds, target, self.num_classes, self.multidim_average, self.ignore_index, self.input_format ) - preds, target = _multiclass_stat_scores_format(preds, target, self.top_k) + preds, target = _multiclass_stat_scores_format(preds, target, self.top_k, self.input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, self.num_classes, self.top_k, self.average, self.multidim_average, self.ignore_index ) @@ -398,6 +433,20 @@ class MultilabelStatScores(_AbstractStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
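And the multiclass variant, again as a sketch assuming this branch, with rows that are already probability distributions (invented data):

```python
import torch
from torchmetrics.classification import MulticlassStatScores

preds = torch.tensor([[0.7, 0.2, 0.1],   # each row already sums to one
                      [0.1, 0.8, 0.1],
                      [0.2, 0.2, 0.6]])
target = torch.tensor([0, 1, 2])

metric = MulticlassStatScores(num_classes=3, input_format="probs")
print(metric(preds, target))
```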
Example (preds is int tensor): @@ -459,17 +508,21 @@ def __init__( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> None: super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format + ) self.num_labels = num_labels self.threshold = threshold self.average = average self.multidim_average = multidim_average self.ignore_index = ignore_index self.validate_args = validate_args + self.input_format = input_format self._create_state(size=num_labels, multidim_average=multidim_average) @@ -477,10 +530,10 @@ def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" if self.validate_args: _multilabel_stat_scores_tensor_validation( - preds, target, self.num_labels, self.multidim_average, self.ignore_index + preds, target, self.num_labels, self.multidim_average, self.ignore_index, self.input_format ) preds, target = _multilabel_stat_scores_format( - preds, target, self.num_labels, self.threshold, self.ignore_index + preds, target, self.num_labels, self.threshold, self.ignore_index, self.input_format ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, self.multidim_average) self._update_state(tp, fp, tn, fn) @@ -526,26 +579,29 @@ def __new__( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", **kwargs: Any, ) -> Metric: """Initialize task metric.""" task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy - kwargs.update({ + kwargs_extra = kwargs.copy() + kwargs_extra.update({ "multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args, + "input_format": input_format, }) if task == ClassificationTask.BINARY: - return BinaryStatScores(threshold, **kwargs) + return BinaryStatScores(threshold, **kwargs_extra) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") - return MulticlassStatScores(num_classes, top_k, average, **kwargs) + return MulticlassStatScores(num_classes, top_k, average, **kwargs_extra) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return MultilabelStatScores(num_labels, threshold, average, **kwargs) + return MultilabelStatScores(num_labels, threshold, average, **kwargs_extra) raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 5413604b7c4..f0cf5548b5e 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -93,6 +93,7 @@ def binary_accuracy( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = 
True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Accuracy`_ for binary tasks. @@ -124,6 +125,19 @@ def binary_accuracy( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -154,9 +168,9 @@ """ if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format=input_format) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index, input_format=input_format) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index, input_format=input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _accuracy_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) @@ -170,6 +184,7 @@ def multiclass_accuracy( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Accuracy`_ for multiclass tasks. @@ -212,6 +227,19 @@ Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values.
+ - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -260,9 +288,13 @@ """ if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multiclass_stat_scores_format(preds, target, top_k, input_format=input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) @@ -278,6 +310,7 @@ def multilabel_accuracy( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Accuracy`_ for multilabel tasks. @@ -318,6 +351,19 @@ Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
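The functional surface mirrors the class-based one; a short sketch, assuming this branch, with preds already given as class labels (invented data):

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy

# 'labels' documents the format explicitly and skips auto-detection.
preds = torch.tensor([2, 0, 1, 1])
target = torch.tensor([2, 0, 1, 0])

acc = multiclass_accuracy(preds, target, num_classes=3, input_format="labels")
print(acc)
```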
Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -364,9 +410,15 @@ def multilabel_accuracy( """ if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) + _multilabel_stat_scores_tensor_validation( + preds, target, num_labels, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multilabel_stat_scores_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) @@ -383,6 +435,7 @@ def accuracy( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Accuracy`_. @@ -414,7 +467,9 @@ def accuracy( task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: - return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_accuracy( + preds, target, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError( @@ -423,7 +478,15 @@ def accuracy( if not isinstance(top_k, int): raise ValueError(f"Optional arg `top_k` must be type `int` when task is {task}. Got {type(top_k)}") return multiclass_accuracy( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): @@ -431,6 +494,14 @@ def accuracy( f"Optional arg `num_labels` must be type `int` when task is {task}. 
Got {type(num_labels)}" ) return multilabel_accuracy( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index acd94f4050e..e21baced5ca 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -73,8 +73,9 @@ def _binary_auroc_arg_validation( max_fpr: Optional[float] = None, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: - _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index, input_format=input_format) if max_fpr is not None and not (isinstance(max_fpr, float) and 0 < max_fpr <= 1): raise ValueError(f"Argument `max_fpr` should be a float in range (0, 1], but got: {max_fpr}") @@ -113,6 +114,7 @@ def binary_auroc( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tensor: r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks. @@ -155,6 +157,17 @@ Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
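A usage sketch for the functional AUROC under this branch (made-up logits; note sigmoid is monotonic, so the declared format affects preprocessing, not the ranking-based score):

```python
import torch
from torchmetrics.functional.classification import binary_auroc

# Declaring 'logits' avoids the per-call value scan that 'auto' performs.
preds = torch.tensor([-1.2, 0.3, 2.5, -0.4])
target = torch.tensor([0, 0, 1, 1])

score = binary_auroc(preds, target, input_format="logits")
print(score)
```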
Returns: A single scalar with the auroc score @@ -170,9 +183,11 @@ """ if validate_args: - _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index) - _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) - preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + _binary_auroc_arg_validation(max_fpr, thresholds, ignore_index, input_format=input_format) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index, input_format=input_format) + preds, target, thresholds = _binary_precision_recall_curve_format( + preds, target, thresholds, ignore_index, input_format=input_format + ) state = _binary_precision_recall_curve_update(preds, target, thresholds) return _binary_auroc_compute(state, thresholds, max_fpr) @@ -182,8 +197,9 @@ def _multiclass_auroc_arg_validation( average: Optional[Literal["macro", "weighted", "none"]] = "macro", thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, input_format=input_format) allowed_average = ("macro", "weighted", "none", None) if average not in allowed_average: raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") @@ -212,6 +228,7 @@ def multiclass_auroc( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tensor: r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. @@ -260,6 +277,17 @@ Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class.
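Similarly for the multiclass variant, sketched with invented probability rows that cover every class:

```python
import torch
from torchmetrics.functional.classification import multiclass_auroc

preds = torch.tensor([[0.7, 0.2, 0.1],   # rows already sum to one
                      [0.2, 0.6, 0.2],
                      [0.1, 0.2, 0.7],
                      [0.4, 0.4, 0.2]])
target = torch.tensor([0, 1, 2, 1])

score = multiclass_auroc(preds, target, num_classes=3, input_format="probs")
print(score)
```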
@@ -283,10 +311,12 @@ def multiclass_auroc( """ if validate_args: - _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) - _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index, input_format=input_format) + _multiclass_precision_recall_curve_tensor_validation( + preds, target, num_classes, ignore_index, input_format=input_format + ) preds, target, thresholds = _multiclass_precision_recall_curve_format( - preds, target, num_classes, thresholds, ignore_index + preds, target, num_classes, thresholds, ignore_index, input_format=input_format ) state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) return _multiclass_auroc_compute(state, num_classes, average, thresholds) @@ -297,8 +327,9 @@ def _multilabel_auroc_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]], thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: - _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index, input_format=input_format) allowed_average = ("micro", "macro", "weighted", "none", None) if average not in allowed_average: raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") @@ -340,6 +371,7 @@ def multilabel_auroc( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tensor: r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multilabel tasks. @@ -389,6 +421,17 @@ def multilabel_auroc( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class.
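Note that for the curve-based metrics the allowed values are only ``'auto'``, ``'probs'``, ``'logits'`` and ``'none'``; a ``'labels'`` option is deliberately absent because AUROC needs continuous scores. A hypothetical multilabel call with raw logits (values invented):

>>> import torch
>>> from torchmetrics.functional.classification import multilabel_auroc
>>> logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5], [1.0, 0.3]])
>>> target = torch.tensor([[1, 0], [0, 1], [1, 1]])
>>> score = multilabel_auroc(logits, target, num_labels=2, input_format="logits")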
@@ -415,10 +458,17 @@ def multilabel_auroc( """ if validate_args: - _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index) - _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + _multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index, input_format=input_format) + _multilabel_precision_recall_curve_tensor_validation( + preds, target, num_labels, ignore_index, input_format=input_format + ) preds, target, thresholds = _multilabel_precision_recall_curve_format( - preds, target, num_labels, thresholds, ignore_index + preds, + target, + num_labels, + thresholds, + ignore_index, + input_format=input_format, ) state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) return _multilabel_auroc_compute(state, num_labels, average, thresholds, ignore_index) @@ -435,6 +485,7 @@ def auroc( max_fpr: Optional[float] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Optional[Tensor]: r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). @@ -467,13 +518,17 @@ def auroc( """ task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: - return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) + return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args, input_format=input_format) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) + return multiclass_auroc( + preds, target, num_classes, average, thresholds, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) + return multilabel_auroc( + preds, target, num_labels, average, thresholds, ignore_index, validate_args, input_format=input_format + ) return None diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 93002bb6d2b..a8c2af5768f 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -81,6 +81,7 @@ def binary_average_precision( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tensor: r"""Compute the average precision (AP) score for binary tasks. @@ -127,6 +128,17 @@ def binary_average_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values.
+ If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: A single scalar with the average precision score @@ -142,9 +154,11 @@ def binary_average_precision( """ if validate_args: - _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) - _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) - preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index, input_format=input_format) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index, input_format=input_format) + preds, target, thresholds = _binary_precision_recall_curve_format( + preds, target, thresholds, ignore_index, input_format=input_format + ) state = _binary_precision_recall_curve_update(preds, target, thresholds) return _binary_average_precision_compute(state, thresholds) @@ -154,8 +168,9 @@ def _multiclass_average_precision_arg_validation( average: Optional[Literal["macro", "weighted", "none"]] = "macro", thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, input_format=input_format) allowed_average = ("macro", "weighted", "none", None) if average not in allowed_average: raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") @@ -184,6 +199,7 @@ def multiclass_average_precision( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tensor: r"""Compute the average precision (AP) score for multiclass tasks. @@ -237,6 +253,17 @@ def multiclass_average_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities.
Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. @@ -260,10 +287,14 @@ def multiclass_average_precision( """ if validate_args: - _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) - _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + _multiclass_average_precision_arg_validation( + num_classes, average, thresholds, ignore_index, input_format=input_format + ) + _multiclass_precision_recall_curve_tensor_validation( + preds, target, num_classes, ignore_index, input_format=input_format + ) preds, target, thresholds = _multiclass_precision_recall_curve_format( - preds, target, num_classes, thresholds, ignore_index + preds, target, num_classes, thresholds, ignore_index, input_format=input_format ) state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) return _multiclass_average_precision_compute(state, num_classes, average, thresholds) @@ -274,8 +305,9 @@ def _multilabel_average_precision_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]], thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: - _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index, input_format=input_format) allowed_average = ("micro", "macro", "weighted", "none", None) if average not in allowed_average: raise ValueError(f"Expected argument `average` to be one of {allowed_average} but got {average}") @@ -317,6 +349,7 @@ def multilabel_average_precision( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tensor: r"""Compute the average precision (AP) score for multilabel tasks. @@ -371,6 +404,17 @@ def multilabel_average_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits.
We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. @@ -397,10 +441,14 @@ def multilabel_average_precision( """ if validate_args: - _multilabel_average_precision_arg_validation(num_labels, average, thresholds, ignore_index) - _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + _multilabel_average_precision_arg_validation( + num_labels, average, thresholds, ignore_index, input_format=input_format + ) + _multilabel_precision_recall_curve_tensor_validation( + preds, target, num_labels, ignore_index, input_format=input_format + ) preds, target, thresholds = _multilabel_precision_recall_curve_format( - preds, target, num_labels, thresholds, ignore_index + preds, target, num_labels, thresholds, ignore_index, input_format=input_format ) state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) return _multilabel_average_precision_compute(state, num_labels, average, thresholds, ignore_index) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index c51770ae7d6..e37ba083067 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -17,7 +17,7 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape, _check_valid_input_format_type from torchmetrics.utilities.data import _bincount from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -63,6 +63,7 @@ def _binary_confusion_matrix_arg_validation( threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -78,10 +79,14 @@ def _binary_confusion_matrix_arg_validation( allowed_normalize = ("true", "pred", "all", "none", None) if normalize not in allowed_normalize: raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + _check_valid_input_format_type(input_format) def _binary_confusion_matrix_tensor_validation( - preds: Tensor, target: Tensor, ignore_index: Optional[int] = None + preds: Tensor, + target: Tensor, + ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate tensor input. @@ -106,13 +111,18 @@ def _binary_confusion_matrix_tensor_validation( ) # If preds is label tensor, also check that it only contains {0,1} values - if not preds.is_floating_point(): + if not preds.is_floating_point() or input_format == "labels": unique_values = torch.unique(preds) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( f"Detected the following values in `preds`: {unique_values} but expected only" " the following values [0,1] since preds is a label tensor." 
) + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) def _binary_confusion_matrix_format( @@ -121,6 +131,7 @@ def _binary_confusion_matrix_format( threshold: float = 0.5, ignore_index: Optional[int] = None, convert_to_labels: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. @@ -136,12 +147,12 @@ def _binary_confusion_matrix_format( preds = preds[idx] target = target[idx] - if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - # preds is logits, convert with sigmoid - preds = preds.sigmoid() - if convert_to_labels: - preds = preds > threshold + if input_format == "logits": + preds = preds.sigmoid() + if preds.is_floating_point() and input_format == "auto" and not torch.all((preds >= 0) * (preds <= 1)): + preds = preds.sigmoid() + if convert_to_labels and input_format not in ("labels", "none"): + preds = preds > threshold return preds, target @@ -171,6 +182,7 @@ def binary_confusion_matrix( normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the `confusion matrix`_ for binary tasks. @@ -197,6 +209,19 @@ def binary_confusion_matrix( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. 
Returns: A ``[2, 2]`` tensor @@ -220,9 +245,9 @@ def binary_confusion_matrix( """ if validate_args: - _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) - _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) - preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize, input_format=input_format) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index, input_format=input_format) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index, input_format=input_format) confmat = _binary_confusion_matrix_update(preds, target) return _binary_confusion_matrix_compute(confmat, normalize) @@ -231,6 +256,7 @@ def _multiclass_confusion_matrix_arg_validation( num_classes: int, ignore_index: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -246,10 +272,15 @@ def _multiclass_confusion_matrix_arg_validation( allowed_normalize = ("true", "pred", "all", "none", None) if normalize not in allowed_normalize: raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + _check_valid_input_format_type(input_format) def _multiclass_confusion_matrix_tensor_validation( - preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None + preds: Tensor, + target: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate tensor input. @@ -294,7 +325,7 @@ def _multiclass_confusion_matrix_tensor_validation( f"{num_unique_values} in `target`." ) - if not preds.is_floating_point(): + if not preds.is_floating_point() or input_format == "labels": num_unique_values = len(torch.unique(preds)) if num_unique_values > num_classes: raise RuntimeError( @@ -302,12 +333,19 @@ def _multiclass_confusion_matrix_tensor_validation( f"{num_classes} but found {num_unique_values} in `preds`." ) + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) + def _multiclass_confusion_matrix_format( preds: Tensor, target: Tensor, ignore_index: Optional[int] = None, convert_to_labels: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. @@ -316,7 +354,10 @@ def _multiclass_confusion_matrix_format( """ # Apply argmax if we have one more dimension - if preds.ndim == target.ndim + 1 and convert_to_labels: + if ( + input_format in ("logits", "probs") + or (input_format == "auto" and preds.ndim == target.ndim + 1) + ) and convert_to_labels: preds = preds.argmax(dim=1) preds = preds.flatten() if convert_to_labels else torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1]) @@ -355,6 +398,7 @@ def multiclass_confusion_matrix( normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the `confusion matrix`_ for multiclass tasks.
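The parenthesization in the formatting condition above matters: argmax should fire when the declared (or auto-detected) format implies a class dimension, and only when labels are requested; without the explicit grouping, ``and`` would bind tighter than ``or`` and argmax would run for ``'logits'`` even when ``convert_to_labels`` is ``False``. A standalone sketch of the intended rule (illustrative only, not library code; the helper name is hypothetical):

# Sketch of the intended multiclass formatting rule, written out explicitly.
def should_argmax(input_format: str, preds_ndim: int, target_ndim: int, convert_to_labels: bool) -> bool:
    # preds carry a class dimension either because the caller said so
    # ("logits"/"probs") or because "auto" detected one extra dimension.
    has_class_dim = input_format in ("logits", "probs") or (
        input_format == "auto" and preds_ndim == target_ndim + 1
    )
    # Reduce to labels only when the caller actually wants labels.
    return has_class_dim and convert_to_labels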
@@ -381,6 +425,19 @@ def multiclass_confusion_matrix( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: A ``[num_classes, num_classes]`` tensor @@ -409,9 +466,11 @@ def multiclass_confusion_matrix( """ if validate_args: - _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) - _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) - preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize, input_format=input_format) + _multiclass_confusion_matrix_tensor_validation( + preds, target, num_classes, ignore_index, input_format=input_format + ) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, input_format=input_format) confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) return _multiclass_confusion_matrix_compute(confmat, normalize) @@ -421,6 +480,7 @@ def _multilabel_confusion_matrix_arg_validation( threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -439,10 +499,15 @@ def _multilabel_confusion_matrix_arg_validation( allowed_normalize = ("true", "pred", "all", "none", None) if normalize not in allowed_normalize: raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + _check_valid_input_format_type(input_format) def _multilabel_confusion_matrix_tensor_validation( - preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None + preds: Tensor, + target: Tensor, + num_labels: int, + ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate tensor input.
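A hypothetical end-to-end multiclass call (values invented): class probabilities are argmax-reduced to labels before the matrix is filled, so a perfect prediction yields the identity:

>>> import torch
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]])
>>> target = torch.tensor([0, 1, 2])
>>> multiclass_confusion_matrix(probs, target, num_classes=3, input_format="probs")
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])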
@@ -474,7 +539,7 @@ def _multilabel_confusion_matrix_tensor_validation( ) # If preds is label tensor, also check that it only contains [0,1] values - if not preds.is_floating_point(): + if not preds.is_floating_point() or input_format == "labels": unique_values = torch.unique(preds) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( @@ -482,6 +547,12 @@ " the following values [0,1] since preds is a label tensor." ) + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) + def _multilabel_confusion_matrix_format( preds: Tensor, @@ -490,6 +561,7 @@ threshold: float = 0.5, ignore_index: Optional[int] = None, should_threshold: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. @@ -498,11 +570,12 @@ - Mask all elements that should be ignored with negative numbers for later filtration """ - if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.sigmoid() - if should_threshold: - preds = preds > threshold + if input_format == "logits": + preds = preds.sigmoid() + if preds.is_floating_point() and input_format == "auto" and not torch.all((preds >= 0) * (preds <= 1)): + preds = preds.sigmoid() + if should_threshold and input_format not in ("labels", "none"): + preds = preds > threshold preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels) target = torch.movedim(target, 1, -1).reshape(-1, num_labels) @@ -545,6 +618,7 @@ def multilabel_confusion_matrix( normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the `confusion matrix`_ for multilabel tasks. @@ -572,6 +646,19 @@ def multilabel_confusion_matrix( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
Returns: A ``[num_labels, 2, 2]`` tensor @@ -597,9 +684,15 @@ def multilabel_confusion_matrix( """ if validate_args: - _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) - _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) - preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_confusion_matrix_arg_validation( + num_labels, threshold, ignore_index, normalize, input_format=input_format + ) + _multilabel_confusion_matrix_tensor_validation( + preds, target, num_labels, ignore_index, input_format=input_format + ) + preds, target = _multilabel_confusion_matrix_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) return _multilabel_confusion_matrix_compute(confmat, normalize) @@ -614,6 +707,7 @@ def confusion_matrix( normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the `confusion matrix`_. @@ -653,13 +747,19 @@ def confusion_matrix( """ task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: - return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) + return binary_confusion_matrix( + preds, target, threshold, normalize, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) + return multiclass_confusion_matrix( + preds, target, num_classes, normalize, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return multilabel_confusion_matrix(preds, target, num_labels, threshold, normalize, ignore_index, validate_args) + return multilabel_confusion_matrix( + preds, target, num_labels, threshold, normalize, ignore_index, validate_args, input_format=input_format + ) raise ValueError(f"Task {task} not supported.") diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index 0f0e883266c..39d20c1efa6 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -62,10 +62,11 @@ def _binary_fbeta_score_arg_validation( threshold: float = 0.5, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: if not (isinstance(beta, float) and beta > 0): raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format=input_format) def binary_fbeta_score( @@ -76,6 +77,7 @@ def binary_fbeta_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool 
= True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `F-score`_ metric for binary tasks. @@ -106,6 +108,19 @@ def binary_fbeta_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -136,9 +151,9 @@ def binary_fbeta_score( """ if validate_args: - _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, input_format=input_format) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index, input_format=input_format) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index, input_format=input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _fbeta_reduce(tp, fp, tn, fn, beta, average="binary", multidim_average=multidim_average) @@ -150,10 +165,13 @@ def _multiclass_fbeta_score_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: if not (isinstance(beta, float) and beta > 0): raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) def multiclass_fbeta_score( @@ -166,6 +184,7 @@ multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `F-score`_ metric for multiclass tasks.
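For example (hypothetical numbers, assuming the patched signature): raw logits can be fed straight to ``binary_fbeta_score``, where sigmoid plus the default 0.5 threshold turns them into labels before the stat scores are computed:

>>> import torch
>>> from torchmetrics.functional.classification import binary_fbeta_score
>>> logits = torch.tensor([-1.0, 0.5, 2.0, -0.2])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_fbeta_score(logits, target, beta=2.0, input_format="logits")
tensor(1.)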
@@ -206,6 +225,19 @@ def multiclass_fbeta_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -254,9 +286,13 @@ def multiclass_fbeta_score( """ if validate_args: - _multiclass_fbeta_score_arg_validation(beta, num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) + _multiclass_fbeta_score_arg_validation( + beta, num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multiclass_stat_scores_format(preds, target, top_k, input_format=input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) @@ -270,10 +306,13 @@ def _multilabel_fbeta_score_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: if not (isinstance(beta, float) and beta > 0): raise ValueError(f"Expected argument `beta` to be a float larger than 0, but got {beta}.") - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) def multilabel_fbeta_score( @@ -286,6 +325,7 @@ multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `F-score`_ metric for multilabel tasks.
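And a multiclass counterpart (again hypothetical values): with ``'probs'`` the class dimension is reduced by argmax, so a perfect prediction gives a score of 1 regardless of ``beta``:

>>> import torch
>>> from torchmetrics.functional.classification import multiclass_fbeta_score
>>> probs = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
>>> target = torch.tensor([0, 1, 0])
>>> multiclass_fbeta_score(probs, target, beta=0.5, num_classes=2, input_format="probs")
tensor(1.)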
@@ -325,6 +365,19 @@ def multilabel_fbeta_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -371,9 +424,15 @@ def multilabel_fbeta_score( """ if validate_args: - _multilabel_fbeta_score_arg_validation(beta, num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_fbeta_score_arg_validation( + beta, num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) + _multilabel_stat_scores_tensor_validation( + preds, target, num_labels, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multilabel_stat_scores_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _fbeta_reduce(tp, fp, tn, fn, beta, average=average, multidim_average=multidim_average, multilabel=True) @@ -385,6 +444,7 @@ def binary_f1_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute F-1 score for binary tasks. @@ -413,6 +473,19 @@ def binary_f1_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values.
+ - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -450,6 +523,7 @@ def binary_f1_score( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + input_format=input_format, ) @@ -462,6 +536,7 @@ def multiclass_f1_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute F-1 score for multiclass tasks. @@ -500,6 +575,19 @@ def multiclass_f1_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -557,6 +645,7 @@ def multiclass_f1_score( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + input_format=input_format, ) @@ -569,6 +658,7 @@ def multilabel_f1_score( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute F-1 score for multilabel tasks. @@ -606,6 +696,19 @@ def multilabel_f1_score( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor.
Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -661,6 +764,7 @@ def multilabel_f1_score( multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args, + input_format=input_format, ) @@ -677,6 +781,7 @@ def fbeta_score( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `F-score`_ metric. @@ -702,20 +807,40 @@ def fbeta_score( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) + return binary_fbeta_score( + preds, target, beta, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_fbeta_score( - preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + beta, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_fbeta_score( - preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + beta, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError(f"Unsupported task `{task}` passed.") @@ -732,6 +857,7 @@ def f1_score( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute F-1 score.
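A hypothetical check that the task wrapper forwards the flag (it routes to ``binary_f1_score`` internally, as the dispatch in the next hunk shows; values invented):

>>> import torch
>>> from torchmetrics.functional import f1_score
>>> logits = torch.tensor([-0.3, 1.2, 0.8, -2.0])
>>> target = torch.tensor([0, 1, 1, 0])
>>> f1_score(logits, target, task="binary", input_format="logits")
tensor(1.)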
@@ -756,19 +882,37 @@ def f1_score( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_f1_score( + preds, target, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_f1_score( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_f1_score( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError(f"Unsupported task `{task}` passed.") diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index ed47ce56982..2f3a01fdba4 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -90,6 +90,7 @@ def binary_hamming_distance( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the average `Hamming distance`_ (also known as Hamming loss) for binary tasks. @@ -122,6 +123,19 @@ def binary_hamming_distance( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -152,9 +166,9 @@ def binary_hamming_distance( """ if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format=input_format) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index, input_format=input_format) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index, input_format=input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) @@ -168,6 +182,7 @@ def multiclass_hamming_distance( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the average `Hamming distance`_ (also known as Hamming loss) for multiclass tasks. @@ -211,6 +226,19 @@ def multiclass_hamming_distance( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -259,9 +287,13 @@ def multiclass_hamming_distance( """ if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multiclass_stat_scores_format(preds, target, top_k, input_format=input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) @@ -277,6 +309,7 @@ def multilabel_hamming_distance( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the average `Hamming distance`_ (also known as Hamming loss) for multilabel tasks. @@ -318,6 +351,19 @@ def multilabel_hamming_distance( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
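+ If predictions are already hard 0/1 decisions, ``'labels'`` is meant to skip any re-formatting; a sketch assuming such inputs: + + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multilabel_hamming_distance + >>> preds = tensor([[0, 1, 1], [1, 0, 1]]) # already hard 0/1 labels + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> multilabel_hamming_distance(preds, target, num_labels=3, input_format="labels") + tensor(0.1667)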
Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -364,9 +410,15 @@ def multilabel_hamming_distance( """ if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) + _multilabel_stat_scores_tensor_validation( + preds, target, num_labels, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multilabel_stat_scores_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _hamming_distance_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) @@ -383,6 +435,7 @@ def hamming_distance( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the average `Hamming distance`_ (also known as Hamming loss). @@ -411,19 +464,37 @@ def hamming_distance( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_hamming_distance( + preds, target, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_hamming_distance( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_hamming_distance( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index ac94ce35365..d5ef1cce14d 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -64,6 +64,7 @@ def binary_precision( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Precision`_ for binary tasks. 
@@ -94,6 +95,19 @@ def binary_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -124,9 +138,9 @@ def binary_precision( """ if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format=input_format) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index, input_format=input_format) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index, input_format=input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _precision_recall_reduce("precision", tp, fp, tn, fn, average="binary", multidim_average=multidim_average) @@ -140,6 +154,7 @@ def multiclass_precision( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Precision`_ for multiclass tasks. @@ -181,6 +196,19 @@ def multiclass_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -229,9 +257,13 @@ def multiclass_precision( """ if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multiclass_stat_scores_format(preds, target, top_k, input_format=input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) @@ -247,6 +279,7 @@ def multilabel_precision( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Precision`_ for multilabel tasks. @@ -286,6 +319,19 @@ def multilabel_precision( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
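+ For instance, probabilities can be declared explicitly so that only thresholding is applied (values assumed for illustration): + + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multilabel_precision + >>> preds = tensor([[0.2, 0.9, 0.6], [0.7, 0.1, 0.4]]) # probabilities, thresholded at 0.5 + >>> target = tensor([[0, 1, 1], [1, 0, 0]]) + >>> multilabel_precision(preds, target, num_labels=3, input_format="probs") + tensor(1.)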
Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -332,9 +378,15 @@ def multilabel_precision( """ if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) + _multilabel_stat_scores_tensor_validation( + preds, target, num_labels, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multilabel_stat_scores_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _precision_recall_reduce( "precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True @@ -348,6 +400,7 @@ def binary_recall( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Recall`_ for binary tasks. @@ -378,6 +431,19 @@ def binary_recall( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value.
If ``multidim_average`` @@ -408,9 +474,9 @@ def binary_recall( """ if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format=input_format) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index, input_format=input_format) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index, input_format=input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _precision_recall_reduce("recall", tp, fp, tn, fn, average="binary", multidim_average=multidim_average) @@ -424,6 +490,7 @@ def multiclass_recall( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Recall`_ for multiclass tasks. @@ -465,6 +532,19 @@ def multiclass_recall( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
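+ For example, predicted class indices can be passed directly (an illustrative sketch with assumed values): + + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multiclass_recall + >>> preds = tensor([2, 1, 0, 1]) # predicted class indices + >>> target = tensor([2, 1, 0, 0]) + >>> multiclass_recall(preds, target, num_classes=3, input_format="labels") + tensor(0.8333)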
Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -513,9 +593,13 @@ def multiclass_recall( """ if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multiclass_stat_scores_format(preds, target, top_k, input_format=input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) @@ -531,6 +615,7 @@ def multilabel_recall( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Recall`_ for multilabel tasks. @@ -570,6 +655,19 @@ def multilabel_recall( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -616,9 +714,15 @@ def multilabel_recall( """ if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) + _multilabel_stat_scores_tensor_validation( + preds, target, num_labels, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multilabel_stat_scores_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _precision_recall_reduce( "recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True @@ -637,6 +741,7 @@ def precision( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Precision`_. @@ -664,20 +769,38 @@ def precision( """ assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_precision( + preds, target, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_precision( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_precision( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError( f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" @@ -696,6 +819,7 @@ def recall( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Recall`_. 
@@ -724,19 +848,37 @@ def recall( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_recall( + preds, target, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_recall( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_recall( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 64958267737..4df95c2b70b 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import List, Optional, Sequence, Tuple, Union import torch @@ -19,7 +18,7 @@ from torch.nn import functional as F # noqa: N812 from typing_extensions import Literal -from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape, _check_valid_input_format_type from torchmetrics.utilities.compute import _safe_divide, interp from torchmetrics.utilities.data import _bincount, _cumsum from torchmetrics.utilities.enums import ClassificationTask @@ -94,6 +93,7 @@ def _adjust_threshold_arg( def _binary_precision_recall_curve_arg_validation( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -121,9 +121,14 @@ def _binary_precision_recall_curve_arg_validation( if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + _check_valid_input_format_type(input_format, options=("auto", "probs", "logits", "none")) + def _binary_precision_recall_curve_tensor_validation( - preds: Tensor, target: Tensor, ignore_index: Optional[int] = None + preds: Tensor, + target: Tensor, + ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: """Validate tensor input. 
@@ -146,6 +151,12 @@ def _binary_precision_recall_curve_tensor_validation( f" but got tensor with dtype {preds.dtype}" ) + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) + # Check that target only contains {0,1} values or value in ignore_index unique_values = torch.unique(target) if ignore_index is None: @@ -164,6 +175,7 @@ def _binary_precision_recall_curve_format( target: Tensor, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. @@ -180,8 +192,9 @@ def _binary_precision_recall_curve_format( preds = preds[idx] target = target[idx] - if not torch.all((preds >= 0) * (preds <= 1)): + if input_format == "logits" or (input_format == "auto" and not torch.all((preds >= 0) * (preds <= 1))): preds = preds.sigmoid() + target = target.long() thresholds = _adjust_threshold_arg(thresholds, preds.device) return preds, target, thresholds @@ -289,6 +302,7 @@ def binary_precision_recall_curve( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tuple[Tensor, Tensor, Tensor]: r"""Compute the precision-recall curve for binary tasks. @@ -329,6 +343,17 @@ def binary_precision_recall_curve( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and do nothing. Else we consider + the tensor to be logits and will apply sigmoid to the tensor before calculating the metric. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. No + transformation will be applied to the tensor, but values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
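+ A minimal sketch of the explicit-logits path (tensor values assumed for illustration): + + >>> from torch import tensor + >>> from torchmetrics.functional.classification import binary_precision_recall_curve + >>> preds = tensor([-2.0, -0.5, 1.0, 3.0]) # raw logits, mapped through sigmoid internally + >>> target = tensor([0, 0, 1, 1]) + >>> precision, recall, thresholds = binary_precision_recall_curve(preds, target, input_format="logits")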
Returns: (tuple): a tuple of 3 tensors containing: @@ -352,9 +377,11 @@ def binary_precision_recall_curve( """ if validate_args: - _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) - _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) - preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index, input_format=input_format) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index, input_format=input_format) + preds, target, thresholds = _binary_precision_recall_curve_format( + preds, target, thresholds, ignore_index, input_format=input_format + ) state = _binary_precision_recall_curve_update(preds, target, thresholds) return _binary_precision_recall_curve_compute(state, thresholds) @@ -364,6 +391,7 @@ def _multiclass_precision_recall_curve_arg_validation( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -376,11 +404,15 @@ def _multiclass_precision_recall_curve_arg_validation( raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") if average not in (None, "micro", "macro"): raise ValueError(f"Expected argument `average` to be one of None, 'micro' or 'macro', but got {average}") - _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index, input_format=input_format) def _multiclass_precision_recall_curve_tensor_validation( - preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None + preds: Tensor, + target: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: """Validate tensor input. @@ -399,6 +431,13 @@ def _multiclass_precision_recall_curve_tensor_validation( ) if not preds.is_floating_point(): raise ValueError(f"Expected `preds` to be a float tensor, but got {preds.dtype}") + + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) + if preds.shape[1] != num_classes: raise ValueError( "Expected `preds.shape[1]` to be equal to the number of classes but" @@ -427,6 +466,7 @@ def _multiclass_precision_recall_curve_format( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro"]] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. 
@@ -444,7 +484,7 @@ def _multiclass_precision_recall_curve_format( preds = preds[idx] target = target[idx] - if not torch.all((preds >= 0) * (preds <= 1)): + if input_format == "logits" or (input_format == "auto" and not torch.all((preds >= 0) * (preds <= 1))): preds = preds.softmax(1) if average == "micro": @@ -591,6 +631,7 @@ def multiclass_precision_recall_curve( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve for multiclass tasks. @@ -639,6 +680,18 @@ def multiclass_precision_recall_curve( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and do nothing. Else we consider + the tensor to be logits and will apply softmax to the tensor before calculating the metric. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. No + transformation will be applied to the tensor, but values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply softmax to the tensor before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. + Returns: (tuple): a tuple of either 3 tensors or 3 lists containing @@ -688,8 +741,12 @@ def multiclass_precision_recall_curve( """ if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) - _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + _multiclass_precision_recall_curve_arg_validation( + num_classes, thresholds, ignore_index, average, input_format=input_format + ) + _multiclass_precision_recall_curve_tensor_validation( + preds, target, num_classes, ignore_index, input_format=input_format + ) preds, target, thresholds = _multiclass_precision_recall_curve_format( preds, target, @@ -697,6 +754,7 @@ def multiclass_precision_recall_curve( thresholds, ignore_index, average, + input_format, ) state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average) return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds, average) @@ -706,6 +764,7 @@ def _multilabel_precision_recall_curve_arg_validation( num_labels: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: """Validate non tensor input.
@@ -714,11 +773,15 @@ def _multilabel_precision_recall_curve_arg_validation( - ``ignore_index`` has to be None or int """ - _multiclass_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index, input_format=input_format) def _multilabel_precision_recall_curve_tensor_validation( - preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None + preds: Tensor, + target: Tensor, + num_labels: int, + ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> None: """Validate tensor input. @@ -728,7 +791,7 @@ def _multilabel_precision_recall_curve_tensor_validation( - that the pred tensor is floating point """ - _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index, input_format=input_format) if preds.shape[1] != num_labels: raise ValueError( "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels" @@ -742,6 +805,7 @@ def _multilabel_precision_recall_curve_format( num_labels: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. @@ -753,8 +817,10 @@ def _multilabel_precision_recall_curve_format( preds = preds.transpose(0, 1).reshape(num_labels, -1).T target = target.transpose(0, 1).reshape(num_labels, -1).T - if not torch.all((preds >= 0) * (preds <= 1)): + + if input_format == "logits" or (input_format == "auto" and not torch.all((preds >= 0) * (preds <= 1))): preds = preds.sigmoid() + target = target.long() thresholds = _adjust_threshold_arg(thresholds, preds.device) if ignore_index is not None and thresholds is not None: @@ -837,6 +903,7 @@ def multilabel_precision_recall_curve( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve for multilabel tasks. @@ -878,6 +945,16 @@ def multilabel_precision_recall_curve( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and do nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor before calculating the metric. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. No + transformation will be applied to the tensor, but values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
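+ For example, with scores already in the [0,1] range, ``'probs'`` keeps the values as-is while still validating the range (an assumed illustration): + + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multilabel_precision_recall_curve + >>> preds = tensor([[0.1, 0.9], [0.8, 0.4]]) # probabilities; out-of-range values would raise a ValueError + >>> target = tensor([[0, 1], [1, 0]]) + >>> precision, recall, thresholds = multilabel_precision_recall_curve( + ... preds, target, num_labels=2, thresholds=5, input_format="probs" + ... )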
Returns: (tuple): a tuple of either 3 tensors or 3 lists containing @@ -926,10 +1003,14 @@ def multilabel_precision_recall_curve( """ if validate_args: - _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) - _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + _multilabel_precision_recall_curve_arg_validation( + num_labels, thresholds, ignore_index, input_format=input_format + ) + _multilabel_precision_recall_curve_tensor_validation( + preds, target, num_labels, ignore_index, input_format=input_format + ) preds, target, thresholds = _multilabel_precision_recall_curve_format( - preds, target, num_labels, thresholds, ignore_index + preds, target, num_labels, thresholds, ignore_index, input_format=input_format ) state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) return _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index) @@ -945,6 +1026,7 @@ def precision_recall_curve( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the precision-recall curve. @@ -987,15 +1069,19 @@ def precision_recall_curve( """ task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: - return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) + return binary_precision_recall_curve( + preds, target, thresholds, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") return multiclass_precision_recall_curve( - preds, target, num_classes, thresholds, average, ignore_index, validate_args + preds, target, num_classes, thresholds, average, ignore_index, validate_args, input_format=input_format ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) + return multilabel_precision_recall_curve( + preds, target, num_labels, thresholds, ignore_index, validate_args, input_format=input_format + ) raise ValueError(f"Task {task} not supported.") diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index d61b920aa9b..917cafcdf1d 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -86,6 +86,7 @@ def binary_roc( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Tuple[Tensor, Tensor, Tensor]: r"""Compute the Receiver Operating Characteristic (ROC) for binary tasks. @@ -129,6 +130,17 @@ def binary_roc( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. 
Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and do nothing. Else we consider + the tensor to be logits and will apply sigmoid to the tensor before calculating the metric. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. No + transformation will be applied to the tensor, but values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: (tuple): a tuple of 3 tensors containing: @@ -152,9 +164,11 @@ def binary_roc( """ if validate_args: - _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) - _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index) - preds, target, thresholds = _binary_precision_recall_curve_format(preds, target, thresholds, ignore_index) + _binary_precision_recall_curve_arg_validation(thresholds, ignore_index, input_format=input_format) + _binary_precision_recall_curve_tensor_validation(preds, target, ignore_index, input_format=input_format) + preds, target, thresholds = _binary_precision_recall_curve_format( + preds, target, thresholds, ignore_index, input_format=input_format + ) state = _binary_precision_recall_curve_update(preds, target, thresholds) return _binary_roc_compute(state, thresholds) @@ -212,6 +226,7 @@ def multiclass_roc( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multiclass tasks. @@ -263,6 +278,17 @@ def multiclass_roc( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and do nothing. Else we consider + the tensor to be logits and will apply softmax to the tensor before calculating the metric. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. No + transformation will be applied to the tensor, but values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply softmax to the tensor before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
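+ For instance, softmax scores can be declared as probabilities so no re-normalization is attempted (values assumed for illustration): + + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multiclass_roc + >>> preds = tensor([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]]) # already softmax-normalized scores + >>> target = tensor([0, 1]) + >>> fpr, tpr, thresholds = multiclass_roc(preds, target, num_classes=3, input_format="probs")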
Returns: (tuple): a tuple of either 3 tensors or 3 lists containing @@ -312,15 +338,14 @@ def multiclass_roc( """ if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) - _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) + _multiclass_precision_recall_curve_arg_validation( + num_classes, thresholds, ignore_index, average, input_format=input_format + ) + _multiclass_precision_recall_curve_tensor_validation( + preds, target, num_classes, ignore_index, input_format=input_format + ) preds, target, thresholds = _multiclass_precision_recall_curve_format( - preds, - target, - num_classes, - thresholds, - ignore_index, - average, + preds, target, num_classes, thresholds, ignore_index, average, input_format=input_format ) state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average) return _multiclass_roc_compute(state, num_classes, thresholds, average) @@ -363,6 +388,7 @@ def multilabel_roc( thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC) for multilabel tasks. @@ -407,6 +433,17 @@ def multilabel_roc( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and do nothing. Else we consider + the tensor to be logits and will apply sigmoid to the tensor before calculating the metric. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. No + transformation will be applied to the tensor, but values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor before calculating the metric. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
Returns: (tuple): a tuple of either 3 tensors or 3 lists containing @@ -459,10 +496,14 @@ def multilabel_roc( """ if validate_args: - _multilabel_precision_recall_curve_arg_validation(num_labels, thresholds, ignore_index) - _multilabel_precision_recall_curve_tensor_validation(preds, target, num_labels, ignore_index) + _multilabel_precision_recall_curve_arg_validation( + num_labels, thresholds, ignore_index, input_format=input_format + ) + _multilabel_precision_recall_curve_tensor_validation( + preds, target, num_labels, ignore_index, input_format=input_format + ) preds, target, thresholds = _multilabel_precision_recall_curve_format( - preds, target, num_labels, thresholds, ignore_index + preds, target, num_labels, thresholds, ignore_index, input_format=input_format ) state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) return _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) @@ -478,6 +519,7 @@ def roc( average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "none"] = "auto", ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: r"""Compute the Receiver Operating Characteristic (ROC). @@ -538,13 +580,17 @@ def roc( """ task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: - return binary_roc(preds, target, thresholds, ignore_index, validate_args) + return binary_roc(preds, target, thresholds, ignore_index, validate_args, input_format=input_format) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_roc(preds, target, num_classes, thresholds, average, ignore_index, validate_args) + return multiclass_roc( + preds, target, num_classes, thresholds, average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") - return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) + return multilabel_roc( + preds, target, num_labels, thresholds, ignore_index, validate_args, input_format=input_format + ) raise ValueError(f"Task {task} not supported, expected one of {ClassificationTask}.") diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 112a7b96204..df632390711 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -61,6 +61,7 @@ def binary_specificity( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Specificity`_ for binary tasks. @@ -91,6 +92,19 @@ def binary_specificity( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. 
If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` @@ -121,9 +135,9 @@ def binary_specificity( """ if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format=input_format) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index, input_format=input_format) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index, input_format=input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _specificity_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) @@ -137,6 +151,7 @@ def multiclass_specificity( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Specificity`_ for multiclass tasks. @@ -178,6 +193,19 @@ def multiclass_specificity( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting.
This is the fastest option but also the least safe. Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -226,9 +254,13 @@ def multiclass_specificity( """ if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multiclass_stat_scores_format(preds, target, top_k, input_format=input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) @@ -244,6 +276,7 @@ def multilabel_specificity( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Specificity`_ for multilabel tasks. @@ -283,6 +316,19 @@ def multilabel_specificity( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise, we consider + the tensor to be logits and will apply sigmoid to the tensor before thresholding the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor, and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
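+ The fastest path combines ``input_format='none'`` with ``validate_args=False``; a sketch assuming preds are already formatted as hard 0/1 labels: + + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multilabel_specificity + >>> preds = tensor([[1, 0, 1], [0, 0, 1]]) # assumed to be pre-formatted 0/1 labels + >>> target = tensor([[1, 0, 0], [0, 1, 1]]) + >>> multilabel_specificity(preds, target, num_labels=3, input_format="none", validate_args=False) + tensor(0.6667)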
Returns: The returned shape depends on the ``average`` and ``multidim_average`` arguments: @@ -329,9 +375,15 @@ def multilabel_specificity( """ if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) + _multilabel_stat_scores_tensor_validation( + preds, target, num_labels, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multilabel_stat_scores_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) @@ -348,6 +400,7 @@ def specificity( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute `Specificity`_. @@ -376,19 +429,37 @@ def specificity( task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_specificity( + preds, target, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_specificity( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_specificity( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 5153554253b..775e3b4dbf1 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -17,7 +17,11 @@ from torch import Tensor, tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.checks import ( + _check_same_shape, + _check_valid_input_format_type, + _input_format_classification, +) from torchmetrics.utilities.data import _bincount, select_topk from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, 
MDMCAverageMethod @@ -26,6 +30,7 @@ def _binary_stat_scores_arg_validation( threshold: float = 0.5, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -44,12 +49,15 @@ def _binary_stat_scores_arg_validation( if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + _check_valid_input_format_type(input_format) + def _binary_stat_scores_tensor_validation( preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate tensor input. @@ -74,8 +82,11 @@ def _binary_stat_scores_tensor_validation( f" the following values {[0, 1] if ignore_index is None else [ignore_index]}." ) + if multidim_average != "global" and preds.ndim < 2: + raise ValueError("Expected input to be at least 2D when multidim_average is set to `samplewise`") + # If preds is label tensor, also check that it only contains [0,1] values - if not preds.is_floating_point(): + if not preds.is_floating_point() or input_format == "labels": unique_values = torch.unique(preds) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( @@ -83,8 +94,11 @@ def _binary_stat_scores_tensor_validation( " the following values [0,1] since `preds` is a label tensor." ) - if multidim_average != "global" and preds.ndim < 2: - raise ValueError("Expected input to be at least 2D when multidim_average is set to `samplewise`") + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) def _binary_stat_scores_format( @@ -92,6 +106,7 @@ def _binary_stat_scores_format( target: Tensor, threshold: float = 0.5, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. @@ -100,10 +115,11 @@ def _binary_stat_scores_format( - Mask all datapoints that should be ignored with negative values """ - if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - # preds is logits, convert with sigmoid - preds = preds.sigmoid() + if input_format == "logits": + preds = preds.sigmoid() + if preds.is_floating_point() and input_format == "auto" and not torch.all((preds >= 0) * (preds <= 1)): + preds = preds.sigmoid() + if input_format not in ("labels", "none"): preds = preds > threshold preds = preds.reshape(preds.shape[0], -1) @@ -145,6 +161,7 @@ def binary_stat_scores( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the true positives, false positives, true negatives, false negatives, support for binary tasks. @@ -172,6 +189,19 @@ def binary_stat_scores( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. 
+ input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. Returns: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds @@ -207,9 +237,9 @@ def binary_stat_scores( """ if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, input_format=input_format) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index, input_format=input_format) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index, input_format=input_format) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average) @@ -220,6 +250,7 @@ def _multiclass_stat_scores_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -248,6 +279,7 @@ def _multiclass_stat_scores_arg_validation( ) if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + _check_valid_input_format_type(input_format) def _multiclass_stat_scores_tensor_validation( @@ -256,6 +288,7 @@ def _multiclass_stat_scores_tensor_validation( num_classes: int, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate tensor input. @@ -313,19 +346,26 @@ def _multiclass_stat_scores_tensor_validation( f" {num_unique_values} in `target`." ) - if not preds.is_floating_point(): - unique_values = torch.unique(preds) - if len(unique_values) > num_classes: + if not preds.is_floating_point() or input_format == "labels": + num_unique_values = len(torch.unique(preds)) + if num_unique_values > num_classes: raise RuntimeError( - "Detected more unique values in `preds` than `num_classes`. Expected only" - f" {num_classes} but found {len(unique_values)} in `preds`."
+ "Detected more unique values in `preds` than `num_classes`. Expected only " + f"{num_classes} but found {num_unique_values} in `preds`." ) + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) + def _multiclass_stat_scores_format( preds: Tensor, target: Tensor, top_k: int = 1, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tuple[Tensor, Tensor]: """Convert all input to label format except if ``top_k`` is not 1. @@ -334,7 +374,12 @@ def _multiclass_stat_scores_format( """ # Apply argmax if we have one more dimension - if preds.ndim == target.ndim + 1 and top_k == 1: + if ( + input_format == "logits" + or input_format == "probs" + or (input_format == "auto" and preds.ndim == target.ndim + 1) + and top_k == 1 + ): preds = preds.argmax(dim=1) preds = preds.reshape(*preds.shape[:2], -1) if top_k != 1 else preds.reshape(preds.shape[0], -1) target = target.reshape(target.shape[0], -1) @@ -457,6 +502,7 @@ def multiclass_stat_scores( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the true positives, false positives, true negatives, false negatives and support for multiclass tasks. @@ -494,6 +540,19 @@ def multiclass_stat_scores( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only thresholds the values. + If all values are non-float we consider the tensor to be labels and does nothing. Else we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe. 
Returns: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds @@ -553,9 +612,13 @@ def multiclass_stat_scores( """ if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, top_k) + _multiclass_stat_scores_arg_validation( + num_classes, top_k, average, multidim_average, ignore_index, input_format=input_format + ) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multiclass_stat_scores_format(preds, target, top_k, input_format=input_format) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes, top_k, average, multidim_average, ignore_index ) @@ -568,6 +631,7 @@ def _multilabel_stat_scores_arg_validation( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate non tensor input. @@ -592,6 +656,7 @@ def _multilabel_stat_scores_arg_validation( ) if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + _check_valid_input_format_type(input_format) def _multilabel_stat_scores_tensor_validation( @@ -600,6 +665,7 @@ def _multilabel_stat_scores_tensor_validation( num_labels: int, multidim_average: str, ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> None: """Validate tensor input. @@ -631,8 +697,10 @@ def _multilabel_stat_scores_tensor_validation( f" the following values {[0, 1] if ignore_index is None else [ignore_index]}." ) - # If preds is label tensor, also check that it only contains [0,1] values - if not preds.is_floating_point(): + if multidim_average != "global" and preds.ndim < 3: + raise ValueError("Expected input to be at least 3D when multidim_average is set to `samplewise`") + + if not preds.is_floating_point() or input_format == "labels": unique_values = torch.unique(preds) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( @@ -640,12 +708,20 @@ def _multilabel_stat_scores_tensor_validation( " the following values [0,1] since preds is a label tensor." ) - if multidim_average != "global" and preds.ndim < 3: - raise ValueError("Expected input to be at least 3D when multidim_average is set to `samplewise`") + if input_format == "probs" and not torch.all((preds >= 0) * (preds <= 1)): + raise ValueError( + "Expected argument `preds` to be a tensor with values in the [0,1] range," + f" but got tensor with values {preds}" + ) def _multilabel_stat_scores_format( - preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. 
@@ -654,9 +730,11 @@ def _multilabel_stat_scores_format( - Mask all elements that should be ignored with negative numbers for later filtration """ - if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.sigmoid() + if input_format == "logits": + preds = preds.sigmoid() + if preds.is_floating_point() and input_format == "auto" and not torch.all((preds >= 0) * (preds <= 1)): + preds = preds.sigmoid() + if input_format not in ("labels", "none"): preds = preds > threshold preds = preds.reshape(*preds.shape[:2], -1) target = target.reshape(*target.shape[:2], -1) @@ -717,6 +795,7 @@ def multilabel_stat_scores( multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the true positives, false positives, true negatives, false negatives and support for multilabel tasks. @@ -753,6 +832,20 @@ def multilabel_stat_scores( Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + input_format: str specifying the format of the input preds tensor. Can be one of: + + - ``'auto'``: automatically detect the format based on the values in the tensor. If all values + are in the [0,1] range, we consider the tensor to be probabilities and only threshold the values. + If all values are non-float, we consider the tensor to be labels and do nothing. Otherwise we consider the + tensor to be logits and will apply sigmoid to the tensor and threshold the values. + - ``'probs'``: preds tensor contains values in the [0,1] range and is considered to be probabilities. Only + thresholding will be applied to the tensor and values will be checked to be in the [0,1] range. + - ``'logits'``: preds tensor contains values outside the [0,1] range and is considered to be logits. We + will apply sigmoid to the tensor and threshold the values before calculating the metric. + - ``'labels'``: preds tensor contains integer values and is considered to be labels. No formatting will be + applied to the preds tensor. + - ``'none'``: will disable all input formatting. This is the fastest option but also the least safe.
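As a sketch of the multilabel formatting described above (illustrative values, assuming the signature added in this diff): ``'logits'`` applies sigmoid followed by thresholding, while ``'none'`` skips the formatting step entirely, so preds must already be valid 0/1 labels:

    import torch
    from torchmetrics.functional.classification import multilabel_stat_scores

    target = torch.tensor([[0, 1, 1], [1, 0, 0]])

    # (N, num_labels) logits: sigmoid is applied, then thresholding at `threshold`.
    logits = torch.tensor([[-0.5, 1.2, 0.3], [2.0, -1.0, -0.2]])
    multilabel_stat_scores(logits, target, num_labels=3, threshold=0.5, input_format="logits")

    # "none" disables all input formatting, the fastest but least safe option.
    hard = torch.tensor([[0, 1, 1], [1, 0, 0]])
    multilabel_stat_scores(hard, target, num_labels=3, input_format="none")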
+ Returns: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds @@ -810,9 +903,15 @@ """ if validate_args: - _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + _multilabel_stat_scores_arg_validation( + num_labels, threshold, average, multidim_average, ignore_index, input_format=input_format + ) + _multilabel_stat_scores_tensor_validation( + preds, target, num_labels, multidim_average, ignore_index, input_format=input_format + ) + preds, target = _multilabel_stat_scores_format( + preds, target, num_labels, threshold, ignore_index, input_format=input_format + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _multilabel_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) @@ -1086,6 +1185,7 @@ def stat_scores( top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, + input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto", ) -> Tensor: r"""Compute the number of true positives, false positives, true negatives, false negatives and the support. @@ -1111,19 +1211,37 @@ task = ClassificationTask.from_str(task) assert multidim_average is not None # noqa: S101 # needed for mypy if task == ClassificationTask.BINARY: - return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) + return binary_stat_scores( + preds, target, threshold, multidim_average, ignore_index, validate_args, input_format=input_format + ) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") if not isinstance(top_k, int): raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") return multiclass_stat_scores( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args + preds, + target, + num_classes, + average, + top_k, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") return multilabel_stat_scores( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args + preds, + target, + num_labels, + threshold, + average, + multidim_average, + ignore_index, + validate_args, + input_format=input_format, ) raise ValueError(f"Unsupported task `{task}`") diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 7d7d67784f5..8843abb9590 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -44,6 +44,15 @@ def _check_same_shape(preds: Tensor, target: Tensor) -> None: ) +def _check_valid_input_format_type( + input_format: Optional[str], options: Tuple[str, ...] = ("auto", "probs", "logits", "labels", "none") +) -> None: + if not input_format: + return + if input_format not in options: + raise ValueError(f"The `input_format` should be one of {options}, got {input_format}.") + + def _basic_input_validation( preds: Tensor, target: Tensor, threshold: float, multiclass: Optional[bool],
ignore_index: Optional[int] ) -> None: diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/inputs.py index c660e625214..dba0b91fa7c 100644 --- a/tests/unittests/classification/inputs.py +++ b/tests/unittests/classification/inputs.py @@ -227,6 +227,20 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): ) +def check_input_format_matches_data(input_format, request): + """Check that the input format matches the data type, else we skip the test.""" + test_id = request.node.callspec.id + test_id = "-".join(test_id.split("-")[1:]) # remove the first part of the id which is the input_format + if input_format == "labels" and "labels" not in test_id: + pytest.skip("input format labels only works with labels data") + if input_format == "logits" and "logits" not in test_id: + pytest.skip("input format logits only works with logits data") + if input_format == "probs" and "probs" not in test_id: + pytest.skip("input format probs only works with probs data") + if input_format == "none" and "labels" not in test_id: + pytest.skip("input format none only works with labels data") + + _group_cases = ( pytest.param( _GroupInput( diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index f2954f8b620..df49314a5ba 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -29,7 +29,13 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases +from unittests.classification.inputs import ( + _binary_cases, + _input_binary, + _multiclass_cases, + _multilabel_cases, + check_input_format_matches_data, +) from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -92,8 +98,10 @@ class TestBinaryAccuracy(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_accuracy(self, ddp, inputs, ignore_index, multidim_average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_accuracy(self, ddp, inputs, ignore_index, multidim_average, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -110,13 +118,20 @@ def test_binary_accuracy(self, ddp, inputs, ignore_index, multidim_average): reference_metric=partial( _sklearn_accuracy_binary, ignore_index=ignore_index, multidim_average=multidim_average ), - metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "input_format": input_format, + }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - def test_binary_accuracy_functional(self, inputs, ignore_index, multidim_average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_accuracy_functional(self, inputs, ignore_index, multidim_average, input_format, request): 
"""Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -134,6 +149,7 @@ def test_binary_accuracy_functional(self, inputs, ignore_index, multidim_average "threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average, + "input_format": input_format, }, ) @@ -236,8 +252,10 @@ class TestMulticlassAccuracy(MetricTester): @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multiclass_accuracy(self, ddp, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multiclass_accuracy(self, ddp, inputs, ignore_index, multidim_average, average, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -262,14 +280,19 @@ def test_multiclass_accuracy(self, ddp, inputs, ignore_index, multidim_average, "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "input_format": input_format, }, ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_multiclass_accuracy_functional(self, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multiclass_accuracy_functional( + self, inputs, ignore_index, multidim_average, average, input_format, request + ): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -291,6 +314,7 @@ def test_multiclass_accuracy_functional(self, inputs, ignore_index, multidim_ave "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "input_format": input_format, }, ) @@ -433,8 +457,10 @@ class TestMultilabelAccuracy(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_multilabel_accuracy(self, ddp, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multilabel_accuracy(self, ddp, inputs, ignore_index, multidim_average, average, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -460,14 +486,19 @@ def test_multilabel_accuracy(self, ddp, inputs, ignore_index, multidim_average, "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "input_format": input_format, }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def 
test_multilabel_accuracy_functional(self, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multilabel_accuracy_functional( + self, inputs, ignore_index, multidim_average, average, input_format, request + ): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -490,6 +521,7 @@ def test_multilabel_accuracy_functional(self, inputs, ignore_index, multidim_ave "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "input_format": input_format, }, ) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index a6c30271388..be36f730714 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -25,7 +25,12 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification.inputs import ( + _binary_cases, + _multiclass_cases, + _multilabel_cases, + check_input_format_matches_data, +) from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -48,8 +53,10 @@ class TestBinaryAUROC(MetricTester): @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_auroc(self, inputs, ddp, max_fpr, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_binary_auroc(self, inputs, ddp, max_fpr, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -63,13 +70,16 @@ def test_binary_auroc(self, inputs, ddp, max_fpr, ignore_index): "max_fpr": max_fpr, "thresholds": None, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) @pytest.mark.parametrize("ignore_index", [None, -1]) - def test_binary_auroc_functional(self, inputs, max_fpr, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_binary_auroc_functional(self, inputs, max_fpr, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -82,6 +92,7 @@ def test_binary_auroc_functional(self, inputs, max_fpr, ignore_index): "max_fpr": max_fpr, "thresholds": None, "ignore_index": ignore_index, + "input_format": input_format, }, ) @@ -156,8 +167,10 @@ class TestMulticlassAUROC(MetricTester): @pytest.mark.parametrize("average", ["macro", "weighted"]) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multiclass_auroc(self, inputs, average, ddp, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multiclass_auroc(self, inputs, average, ddp, ignore_index, input_format, 
request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -172,13 +185,16 @@ def test_multiclass_auroc(self, inputs, average, ddp, ignore_index): "num_classes": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("average", ["macro", "weighted"]) @pytest.mark.parametrize("ignore_index", [None, -1]) - def test_multiclass_auroc_functional(self, inputs, average, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multiclass_auroc_functional(self, inputs, average, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -192,6 +208,7 @@ def test_multiclass_auroc_functional(self, inputs, average, ignore_index): "num_classes": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) @@ -285,8 +302,10 @@ class TestMultilabelAUROC(MetricTester): @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multilabel_auroc(self, inputs, ddp, average, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multilabel_auroc(self, inputs, ddp, average, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -301,13 +320,16 @@ def test_multilabel_auroc(self, inputs, ddp, average, ignore_index): "num_labels": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) - def test_multilabel_auroc_functional(self, inputs, average, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multilabel_auroc_functional(self, inputs, average, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -321,6 +343,7 @@ def test_multilabel_auroc_functional(self, inputs, average, ignore_index): "num_labels": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index cdb76ffce31..d950ee8f5b9 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -34,7 +34,12 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification.inputs import ( + _binary_cases, + _multiclass_cases, + _multilabel_cases, + check_input_format_matches_data, +) from 
unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -56,8 +61,10 @@ class TestBinaryAveragePrecision(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_average_precision(self, inputs, ddp, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_binary_average_precision(self, inputs, ddp, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -70,12 +77,15 @@ def test_binary_average_precision(self, inputs, ddp, ignore_index): metric_args={ "thresholds": None, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - def test_binary_average_precision_functional(self, inputs, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_binary_average_precision_functional(self, inputs, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -87,6 +97,7 @@ def test_binary_average_precision_functional(self, inputs, ignore_index): metric_args={ "thresholds": None, "ignore_index": ignore_index, + "input_format": input_format, }, ) @@ -172,8 +183,10 @@ class TestMulticlassAveragePrecision(MetricTester): @pytest.mark.parametrize("average", ["macro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multiclass_average_precision(self, inputs, average, ddp, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multiclass_average_precision(self, inputs, average, ddp, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -188,13 +201,16 @@ def test_multiclass_average_precision(self, inputs, average, ddp, ignore_index): "num_classes": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("average", ["macro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) - def test_multiclass_average_precision_functional(self, inputs, average, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multiclass_average_precision_functional(self, inputs, average, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -208,6 +224,7 @@ def test_multiclass_average_precision_functional(self, inputs, average, ignore_i "num_classes": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) @@ -288,8 +305,10 @@ class TestMultilabelAveragePrecision(MetricTester): @pytest.mark.parametrize("average", 
["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multilabel_average_precision(self, inputs, ddp, average, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multilabel_average_precision(self, inputs, ddp, average, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -304,13 +323,16 @@ def test_multilabel_average_precision(self, inputs, ddp, average, ignore_index): "num_labels": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ignore_index", [None, -1]) - def test_multilabel_average_precision_functional(self, inputs, average, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"]) + def test_multilabel_average_precision_functional(self, inputs, average, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -324,6 +346,7 @@ def test_multilabel_average_precision_functional(self, inputs, average, ignore_i "num_labels": NUM_CLASSES, "average": average, "ignore_index": ignore_index, + "input_format": input_format, }, ) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 5265f9a64eb..89c9662fa0a 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -32,7 +32,12 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification.inputs import ( + _binary_cases, + _multiclass_cases, + _multilabel_cases, + check_input_format_matches_data, +) from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -57,8 +62,10 @@ class TestBinaryConfusionMatrix(MetricTester): @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_confusion_matrix(self, inputs, ddp, normalize, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_confusion_matrix(self, inputs, ddp, normalize, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -72,13 +79,16 @@ def test_binary_confusion_matrix(self, inputs, ddp, normalize, ignore_index): "threshold": THRESHOLD, "normalize": normalize, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - def 
test_binary_confusion_matrix_functional(self, inputs, normalize, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_confusion_matrix_functional(self, inputs, normalize, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -91,6 +101,7 @@ def test_binary_confusion_matrix_functional(self, inputs, normalize, ignore_inde "threshold": THRESHOLD, "normalize": normalize, "ignore_index": ignore_index, + "input_format": input_format, }, ) @@ -154,8 +165,10 @@ class TestMulticlassConfusionMatrix(MetricTester): @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multiclass_confusion_matrix(self, inputs, ddp, normalize, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multiclass_confusion_matrix(self, inputs, ddp, normalize, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -171,13 +184,16 @@ def test_multiclass_confusion_matrix(self, inputs, ddp, normalize, ignore_index) "num_classes": NUM_CLASSES, "normalize": normalize, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - def test_multiclass_confusion_matrix_functional(self, inputs, normalize, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multiclass_confusion_matrix_functional(self, inputs, normalize, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -192,6 +208,7 @@ def test_multiclass_confusion_matrix_functional(self, inputs, normalize, ignore_ "num_classes": NUM_CLASSES, "normalize": normalize, "ignore_index": ignore_index, + "input_format": input_format, }, ) @@ -271,8 +288,10 @@ class TestMultilabelConfusionMatrix(MetricTester): @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multilabel_confusion_matrix(self, inputs, ddp, normalize, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multilabel_confusion_matrix(self, inputs, ddp, normalize, ignore_index, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -288,13 +307,16 @@ def test_multilabel_confusion_matrix(self, inputs, ddp, normalize, ignore_index) "num_labels": NUM_CLASSES, "normalize": normalize, "ignore_index": ignore_index, + "input_format": input_format, }, ) @pytest.mark.parametrize("normalize", ["true", "pred", "all", 
None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - def test_multilabel_confusion_matrix_functional(self, inputs, normalize, ignore_index): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multilabel_confusion_matrix_functional(self, inputs, normalize, ignore_index, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index is not None: target = inject_ignore_index(target, ignore_index) @@ -309,6 +331,7 @@ def test_multilabel_confusion_matrix_functional(self, inputs, normalize, ignore_ "num_labels": NUM_CLASSES, "normalize": normalize, "ignore_index": ignore_index, + "input_format": input_format, }, ) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 4534a1b9259..571a7040527 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -42,7 +42,12 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification.inputs import ( + _binary_cases, + _multiclass_cases, + _multilabel_cases, + check_input_format_matches_data, +) from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -90,8 +95,12 @@ class TestBinaryFBetaScore(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_fbeta_score(self, ddp, inputs, module, functional, compare, ignore_index, multidim_average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_fbeta_score( + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, input_format, request + ): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -108,13 +117,22 @@ def test_binary_fbeta_score(self, ddp, inputs, module, functional, compare, igno reference_metric=partial( _sklearn_fbeta_score_binary, sk_fn=compare, ignore_index=ignore_index, multidim_average=multidim_average ), - metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "input_format": input_format, + }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - def test_binary_fbeta_score_functional(self, inputs, module, functional, compare, ignore_index, multidim_average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_fbeta_score_functional( + self, inputs, module, functional, compare, ignore_index, multidim_average, input_format, request + ): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -132,6 +150,7 @@ def test_binary_fbeta_score_functional(self, inputs, module, 
functional, compare "threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average, + "input_format": input_format, }, ) @@ -218,10 +237,12 @@ class TestMulticlassFBetaScore(MetricTester): @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) def test_multiclass_fbeta_score( - self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request ): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -247,16 +268,19 @@ def test_multiclass_fbeta_score( "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "input_format": input_format, }, ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) def test_multiclass_fbeta_score_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average, average + self, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request ): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -279,6 +303,7 @@ def test_multiclass_fbeta_score_functional( "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "input_format": input_format, }, ) @@ -464,10 +489,12 @@ class TestMultilabelFBetaScore(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) def test_multilabel_fbeta_score( - self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request ): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -494,16 +521,19 @@ def test_multilabel_fbeta_score( "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "input_format": input_format, }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) def test_multilabel_fbeta_score_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average, average + self, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request ): """Test 
functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -527,6 +557,7 @@ def test_multilabel_fbeta_score_functional( "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "input_format": input_format, }, ) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming.py similarity index 90% rename from tests/unittests/classification/test_hamming_distance.py rename to tests/unittests/classification/test_hamming.py index ad6c2e199b4..2eecfa7a81b 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming.py @@ -33,7 +33,12 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification.inputs import ( + _binary_cases, + _multiclass_cases, + _multilabel_cases, + check_input_format_matches_data, +) from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -78,8 +83,10 @@ class TestBinaryHammingDistance(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_hamming_distance(self, ddp, inputs, ignore_index, multidim_average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_hamming_distance(self, ddp, inputs, ignore_index, multidim_average, input_format, request): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -96,13 +103,20 @@ def test_binary_hamming_distance(self, ddp, inputs, ignore_index, multidim_avera reference_metric=partial( _sklearn_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average ), - metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "input_format": input_format, + }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) - def test_binary_hamming_distance_functional(self, inputs, ignore_index, multidim_average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_hamming_distance_functional(self, inputs, ignore_index, multidim_average, input_format, request): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -120,6 +134,7 @@ def test_binary_hamming_distance_functional(self, inputs, ignore_index, multidim "threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average, + "input_format": input_format, }, ) @@ -228,8 +243,12 @@ class TestMulticlassHammingDistance(MetricTester): @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", 
"macro", "weighted", None]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_multiclass_hamming_distance(self, ddp, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multiclass_hamming_distance( + self, ddp, inputs, ignore_index, multidim_average, average, input_format, request + ): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -254,14 +273,19 @@ def test_multiclass_hamming_distance(self, ddp, inputs, ignore_index, multidim_a "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "input_format": input_format, }, ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) - def test_multiclass_hamming_distance_functional(self, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multiclass_hamming_distance_functional( + self, inputs, ignore_index, multidim_average, average, input_format, request + ): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -283,6 +307,7 @@ def test_multiclass_hamming_distance_functional(self, inputs, ignore_index, mult "multidim_average": multidim_average, "average": average, "num_classes": NUM_CLASSES, + "input_format": input_format, }, ) @@ -410,8 +435,12 @@ class TestMultilabelHammingDistance(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", None]) - def test_multilabel_hamming_distance(self, ddp, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multilabel_hamming_distance( + self, ddp, inputs, ignore_index, multidim_average, average, input_format, request + ): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -437,14 +466,19 @@ def test_multilabel_hamming_distance(self, ddp, inputs, ignore_index, multidim_a "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "input_format": input_format, }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", None]) - def test_multilabel_hamming_distance_functional(self, inputs, ignore_index, multidim_average, average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_multilabel_hamming_distance_functional( + self, inputs, ignore_index, multidim_average, average, input_format, request + ): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -467,6 +501,7 @@ def 
test_multilabel_hamming_distance_functional(self, inputs, ignore_index, mult "ignore_index": ignore_index, "multidim_average": multidim_average, "average": average, + "input_format": input_format, }, ) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index f438dd50e21..e375cba2bda 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -42,7 +42,12 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.classification.inputs import ( + _binary_cases, + _multiclass_cases, + _multilabel_cases, + check_input_format_matches_data, +) from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -90,8 +95,12 @@ class TestBinaryPrecisionRecall(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, ignore_index, multidim_average): + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) + def test_binary_precision_recall( + self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, input_format, request + ): """Test class implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -111,15 +120,22 @@ def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, ignore_index=ignore_index, multidim_average=multidim_average, ), - metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "input_format": input_format, + }, ) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) def test_binary_precision_recall_functional( - self, inputs, module, functional, compare, ignore_index, multidim_average + self, inputs, module, functional, compare, ignore_index, multidim_average, input_format, request ): """Test functional implementation of metric.""" + check_input_format_matches_data(input_format, request) preds, target = inputs if ignore_index == -1: target = inject_ignore_index(target, ignore_index) @@ -140,6 +156,7 @@ def test_binary_precision_recall_functional( "threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average, + "input_format": input_format, }, ) @@ -223,10 +240,12 @@ class TestMulticlassPrecisionRecall(MetricTester): @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"]) def test_multiclass_precision_recall( - self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average + self, 
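Besides the rename, the hamming tests thread the keyword through `metric_args`, which implies it also lands on the user-facing API. A hedged usage sketch with `BinaryHammingDistance` (that this class accepts `input_format` is inferred from the `metric_args` above, not shown directly in this diff):

```python
import torch
from torchmetrics.classification import BinaryHammingDistance

logits = torch.tensor([2.3, -1.2, 0.7, -0.4])  # raw scores straddling [0, 1]
target = torch.tensor([1, 0, 1, 0])

# Declaring the format skips auto-detection: sigmoid + thresholding is
# applied to the logits before the distance is computed.
metric = BinaryHammingDistance(threshold=0.5, input_format="logits")
print(metric(logits, target))
```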
diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py
index f438dd50e21..e375cba2bda 100644
--- a/tests/unittests/classification/test_precision_recall.py
+++ b/tests/unittests/classification/test_precision_recall.py
@@ -42,7 +42,12 @@ from torchmetrics.metric import Metric

 from unittests import NUM_CLASSES, THRESHOLD
-from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests.classification.inputs import (
+    _binary_cases,
+    _multiclass_cases,
+    _multilabel_cases,
+    check_input_format_matches_data,
+)
 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
@@ -90,8 +95,12 @@ class TestBinaryPrecisionRecall(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_binary_precision_recall(self, ddp, inputs, module, functional, compare, ignore_index, multidim_average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_binary_precision_recall(
+        self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, input_format, request
+    ):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -111,15 +120,22 @@ def test_binary_precision_recall(self, ddp, inputs, module, functional, compare,
                 ignore_index=ignore_index,
                 multidim_average=multidim_average,
             ),
-            metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average},
+            metric_args={
+                "threshold": THRESHOLD,
+                "ignore_index": ignore_index,
+                "multidim_average": multidim_average,
+                "input_format": input_format,
+            },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
     def test_binary_precision_recall_functional(
-        self, inputs, module, functional, compare, ignore_index, multidim_average
+        self, inputs, module, functional, compare, ignore_index, multidim_average, input_format, request
     ):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -140,6 +156,7 @@ def test_binary_precision_recall_functional(
                 "threshold": THRESHOLD,
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
+                "input_format": input_format,
             },
         )
@@ -223,10 +240,12 @@ class TestMulticlassPrecisionRecall(MetricTester):
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
     def test_multiclass_precision_recall(
-        self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average
+        self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request
     ):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -252,16 +271,19 @@ def test_multiclass_precision_recall(
                 "multidim_average": multidim_average,
                 "average": average,
                 "num_classes": NUM_CLASSES,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, 0, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
     def test_multiclass_precision_recall_functional(
-        self, inputs, module, functional, compare, ignore_index, multidim_average, average
+        self, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request
     ):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -284,6 +306,7 @@ def test_multiclass_precision_recall_functional(
                 "multidim_average": multidim_average,
                 "average": average,
                 "num_classes": NUM_CLASSES,
+                "input_format": input_format,
             },
         )
@@ -460,10 +483,12 @@ class TestMultilabelPrecisionRecall(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
     def test_multilabel_precision_recall(
-        self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average
+        self, ddp, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request
     ):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -490,16 +515,19 @@ def test_multilabel_precision_recall(
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
                 "average": average,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
     def test_multilabel_precision_recall_functional(
-        self, inputs, module, functional, compare, ignore_index, multidim_average, average
+        self, inputs, module, functional, compare, ignore_index, multidim_average, average, input_format, request
     ):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -523,6 +551,7 @@ def test_multilabel_precision_recall_functional(
             "ignore_index": ignore_index,
             "multidim_average": multidim_average,
             "average": average,
+            "input_format": input_format,
         },
     )
diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py
index 9c1d4263b99..26ea9308665 100644
--- a/tests/unittests/classification/test_precision_recall_curve.py
+++ b/tests/unittests/classification/test_precision_recall_curve.py
@@ -33,7 +33,12 @@ from torchmetrics.metric import Metric

 from unittests import NUM_CLASSES
-from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests.classification.inputs import (
+    _binary_cases,
+    _multiclass_cases,
+    _multilabel_cases,
+    check_input_format_matches_data,
+)
 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
@@ -55,8 +60,10 @@ class TestBinaryPrecisionRecallCurve(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_binary_precision_recall_curve(self, inputs, ddp, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_binary_precision_recall_curve(self, inputs, ddp, ignore_index, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -69,12 +76,15 @@ def test_binary_precision_recall_curve(self, inputs, ddp, ignore_index):
             metric_args={
                 "thresholds": None,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
-    def test_binary_precision_recall_curve_functional(self, inputs, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_binary_precision_recall_curve_functional(self, inputs, ignore_index, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -86,6 +96,7 @@ def test_binary_precision_recall_curve_functional(self, inputs, ignore_index):
             metric_args={
                 "thresholds": None,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )
@@ -179,8 +190,10 @@ class TestMulticlassPrecisionRecallCurve(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_multiclass_precision_recall_curve(self, inputs, ddp, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multiclass_precision_recall_curve(self, inputs, ddp, ignore_index, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -194,12 +207,15 @@ def test_multiclass_precision_recall_curve(self, inputs, ddp, ignore_index):
                 "thresholds": None,
                 "num_classes": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1])
-    def test_multiclass_precision_recall_curve_functional(self, inputs, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multiclass_precision_recall_curve_functional(self, inputs, ignore_index, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -212,6 +228,7 @@ def test_multiclass_precision_recall_curve_functional(self, inputs, ignore_index
                 "thresholds": None,
                 "num_classes": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )
@@ -318,8 +335,10 @@ class TestMultilabelPrecisionRecallCurve(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_multilabel_precision_recall_curve(self, inputs, ddp, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multilabel_precision_recall_curve(self, inputs, ddp, ignore_index, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -333,12 +352,15 @@ def test_multilabel_precision_recall_curve(self, inputs, ddp, ignore_index):
                 "thresholds": None,
                 "num_labels": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
-    def test_multilabel_precision_recall_curve_functional(self, inputs, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multilabel_precision_recall_curve_functional(self, inputs, ignore_index, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -351,6 +373,7 @@ def test_multilabel_precision_recall_curve_functional(self, inputs, ignore_index
                 "thresholds": None,
                 "num_labels": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )
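Note the reduced parameter grid in the curve tests: `"labels"` is dropped, since hard class predictions carry no continuous scores to sweep a threshold over. A hedged functional sketch (the `input_format` keyword on `binary_precision_recall_curve` is inferred from the `metric_args` above):

```python
import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

logits = torch.tensor([-1.0, 0.5, 2.0, -0.3])
target = torch.tensor([0, 1, 1, 0])

# Curves need continuous scores, so only "auto"/"probs"/"logits" make sense here.
precision, recall, thresholds = binary_precision_recall_curve(
    logits, target, thresholds=None, input_format="logits"
)
```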
diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py
index b69dfc0c74b..0ff7203c4ce 100644
--- a/tests/unittests/classification/test_roc.py
+++ b/tests/unittests/classification/test_roc.py
@@ -24,7 +24,12 @@ from torchmetrics.metric import Metric

 from unittests import NUM_CLASSES
-from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests.classification.inputs import (
+    _binary_cases,
+    _multiclass_cases,
+    _multilabel_cases,
+    check_input_format_matches_data,
+)
 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
@@ -48,8 +53,10 @@ class TestBinaryROC(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_binary_roc(self, inputs, ddp, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_binary_roc(self, inputs, ddp, ignore_index, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -62,12 +69,15 @@ def test_binary_roc(self, inputs, ddp, ignore_index):
             metric_args={
                 "thresholds": None,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
-    def test_binary_roc_functional(self, inputs, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_binary_roc_functional(self, inputs, ignore_index, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -79,6 +89,7 @@ def test_binary_roc_functional(self, inputs, ignore_index):
             metric_args={
                 "thresholds": None,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )
@@ -162,8 +173,10 @@ class TestMulticlassROC(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_multiclass_roc(self, inputs, ddp, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multiclass_roc(self, inputs, ddp, ignore_index, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -177,12 +190,15 @@ def test_multiclass_roc(self, inputs, ddp, ignore_index):
                 "thresholds": None,
                 "num_classes": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
-    def test_multiclass_roc_functional(self, inputs, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multiclass_roc_functional(self, inputs, ignore_index, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -195,6 +211,7 @@ def test_multiclass_roc_functional(self, inputs, ignore_index):
                 "thresholds": None,
                 "num_classes": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )
@@ -285,8 +302,10 @@ class TestMultilabelROC(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_multilabel_roc(self, inputs, ddp, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multilabel_roc(self, inputs, ddp, ignore_index, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -300,12 +319,15 @@ def test_multilabel_roc(self, inputs, ddp, ignore_index):
                 "thresholds": None,
                 "num_labels": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1, 0])
-    def test_multilabel_roc_functional(self, inputs, ignore_index):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
+    def test_multilabel_roc_functional(self, inputs, ignore_index, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index is not None:
             target = inject_ignore_index(target, ignore_index)
@@ -318,6 +340,7 @@ def test_multilabel_roc_functional(self, inputs, ignore_index):
                 "thresholds": None,
                 "num_labels": NUM_CLASSES,
                 "ignore_index": ignore_index,
+                "input_format": input_format,
             },
         )
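Every file follows the same recipe: one more stacked `@pytest.mark.parametrize` plus the built-in `request` fixture. Stacked decorators multiply, which is why the new axis inflates the collected-test count considerably. A self-contained illustration of that behavior (the test name below is for demonstration only):

```python
import pytest


@pytest.mark.parametrize("ignore_index", [None, -1, 0])
@pytest.mark.parametrize("input_format", ["auto", "probs", "logits"])
def test_cross_product_demo(ignore_index, input_format):
    # Stacked parametrize decorators take the cross product:
    # 3 ignore_index values x 3 formats = 9 collected test cases.
    assert input_format in ("auto", "probs", "logits")
```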
diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py
index 824e8667e92..89f86a289c3 100644
--- a/tests/unittests/classification/test_specificity.py
+++ b/tests/unittests/classification/test_specificity.py
@@ -33,7 +33,12 @@ from torchmetrics.metric import Metric

 from unittests import NUM_CLASSES, THRESHOLD
-from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests.classification.inputs import (
+    _binary_cases,
+    _multiclass_cases,
+    _multilabel_cases,
+    check_input_format_matches_data,
+)
 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester, inject_ignore_index
@@ -91,8 +96,10 @@ class TestBinarySpecificity(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_binary_specificity(self, ddp, inputs, ignore_index, multidim_average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_binary_specificity(self, ddp, inputs, ignore_index, multidim_average, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -109,13 +116,20 @@ def test_binary_specificity(self, ddp, inputs, ignore_index, multidim_average):
             reference_metric=partial(
                 _baseline_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average
             ),
-            metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average},
+            metric_args={
+                "threshold": THRESHOLD,
+                "ignore_index": ignore_index,
+                "multidim_average": multidim_average,
+                "input_format": input_format,
+            },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
-    def test_binary_specificity_functional(self, inputs, ignore_index, multidim_average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_binary_specificity_functional(self, inputs, ignore_index, multidim_average, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -133,6 +147,7 @@ def test_binary_specificity_functional(self, inputs, ignore_index, multidim_aver
                 "threshold": THRESHOLD,
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
+                "input_format": input_format,
             },
         )
@@ -255,8 +270,10 @@ class TestMulticlassSpecificity(MetricTester):
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_multiclass_specificity(self, ddp, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multiclass_specificity(self, ddp, inputs, ignore_index, multidim_average, average, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -281,14 +298,19 @@ def test_multiclass_specificity(self, ddp, inputs, ignore_index, multidim_averag
                 "multidim_average": multidim_average,
                 "average": average,
                 "num_classes": NUM_CLASSES,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, 0, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
-    def test_multiclass_specificity_functional(self, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multiclass_specificity_functional(
+        self, inputs, ignore_index, multidim_average, average, input_format, request
+    ):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -310,6 +332,7 @@ def test_multiclass_specificity_functional(self, inputs, ignore_index, multidim_
                 "multidim_average": multidim_average,
                 "average": average,
                 "num_classes": NUM_CLASSES,
+                "input_format": input_format,
             },
         )
@@ -457,8 +480,10 @@ class TestMultilabelSpecificity(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
-    def test_multilabel_specificity(self, ddp, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multilabel_specificity(self, ddp, inputs, ignore_index, multidim_average, average, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -484,14 +509,19 @@ def test_multilabel_specificity(self, ddp, inputs, ignore_index, multidim_averag
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
                 "average": average,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
-    def test_multilabel_specificity_functional(self, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multilabel_specificity_functional(
+        self, inputs, ignore_index, multidim_average, average, input_format, request
+    ):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -514,6 +544,7 @@ def test_multilabel_specificity_functional(self, inputs, ignore_index, multidim_
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
                 "average": average,
+                "input_format": input_format,
             },
         )
diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py
index ef6f25bd7bf..6fc09bcc4ed 100644
--- a/tests/unittests/classification/test_stat_scores.py
+++ b/tests/unittests/classification/test_stat_scores.py
@@ -32,7 +32,12 @@ from torchmetrics.metric import Metric

 from unittests import NUM_CLASSES, THRESHOLD
-from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests.classification.inputs import (
+    _binary_cases,
+    _multiclass_cases,
+    _multilabel_cases,
+    check_input_format_matches_data,
+)
 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
@@ -74,8 +79,10 @@ class TestBinaryStatScores(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, 0, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_binary_stat_scores(self, ddp, inputs, ignore_index, multidim_average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_binary_stat_scores(self, ddp, inputs, ignore_index, multidim_average, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -92,13 +99,20 @@ def test_binary_stat_scores(self, ddp, inputs, ignore_index, multidim_average):
             reference_metric=partial(
                 _sklearn_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average
             ),
-            metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average},
+            metric_args={
+                "threshold": THRESHOLD,
+                "ignore_index": ignore_index,
+                "multidim_average": multidim_average,
+                "input_format": input_format,
+            },
         )

     @pytest.mark.parametrize("ignore_index", [None, 0, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
-    def test_binary_stat_scores_functional(self, inputs, ignore_index, multidim_average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_binary_stat_scores_functional(self, inputs, ignore_index, multidim_average, input_format, request):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -116,6 +130,7 @@ def test_binary_stat_scores_functional(self, inputs, ignore_index, multidim_aver
                 "threshold": THRESHOLD,
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
+                "input_format": input_format,
             },
         )
@@ -226,8 +241,10 @@ class TestMulticlassStatScores(MetricTester):
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_multiclass_stat_scores(self, ddp, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multiclass_stat_scores(self, ddp, inputs, ignore_index, multidim_average, average, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -252,14 +269,19 @@ def test_multiclass_stat_scores(self, ddp, inputs, ignore_index, multidim_averag
                 "multidim_average": multidim_average,
                 "average": average,
                 "num_classes": NUM_CLASSES,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, 0, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
-    def test_multiclass_stat_scores_functional(self, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multiclass_stat_scores_functional(
+        self, inputs, ignore_index, multidim_average, average, input_format, request
+    ):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -281,6 +303,7 @@ def test_multiclass_stat_scores_functional(self, inputs, ignore_index, multidim_
                 "multidim_average": multidim_average,
                 "average": average,
                 "num_classes": NUM_CLASSES,
+                "input_format": input_format,
             },
         )
@@ -441,8 +464,10 @@ class TestMultilabelStatScores(MetricTester):
     @pytest.mark.parametrize("ignore_index", [None, 0, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
-    def test_multilabel_stat_scores(self, ddp, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multilabel_stat_scores(self, ddp, inputs, ignore_index, multidim_average, average, input_format, request):
         """Test class implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -468,14 +493,19 @@ def test_multilabel_stat_scores(self, ddp, inputs, ignore_index, multidim_averag
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
                 "average": average,
+                "input_format": input_format,
             },
         )

     @pytest.mark.parametrize("ignore_index", [None, 0, -1])
     @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
     @pytest.mark.parametrize("average", ["micro", "macro", None])
-    def test_multilabel_stat_scores_functional(self, inputs, ignore_index, multidim_average, average):
+    @pytest.mark.parametrize("input_format", ["auto", "probs", "logits", "labels"])
+    def test_multilabel_stat_scores_functional(
+        self, inputs, ignore_index, multidim_average, average, input_format, request
+    ):
         """Test functional implementation of metric."""
+        check_input_format_matches_data(input_format, request)
         preds, target = inputs
         if ignore_index == -1:
             target = inject_ignore_index(target, ignore_index)
@@ -498,6 +528,7 @@ def test_multilabel_stat_scores_functional(self, inputs, ignore_index, multidim_
                 "ignore_index": ignore_index,
                 "multidim_average": multidim_average,
                 "average": average,
+                "input_format": input_format,
             },
         )