
Classification: option to disable input formatting [wip] #1676

Open
wants to merge 46 commits into base: master
Commits (46)
16a6ee9
initial idea
SkafteNicki Mar 31, 2023
0875df2
Merge branch 'master' into bugfix/disable_input_format
Borda Mar 31, 2023
99f2488
Merge branch 'master' into bugfix/disable_input_format
Borda Mar 31, 2023
af183a3
Merge branch 'master' into bugfix/disable_input_format
Borda Apr 3, 2023
f5a883b
Merge branch 'master' into bugfix/disable_input_format
SkafteNicki Apr 13, 2023
424104d
Merge branch 'master' into bugfix/disable_input_format
Borda Apr 17, 2023
d937d90
Merge branch 'master' into bugfix/disable_input_format
Borda Apr 18, 2023
b52a09a
Merge branch 'master' into bugfix/disable_input_format
Borda Apr 26, 2023
1c14524
Merge branch 'master' into bugfix/disable_input_format
Borda Aug 7, 2023
54e52b5
Merge branch 'master' into bugfix/disable_input_format
SkafteNicki Aug 9, 2023
5cd443b
new interface
SkafteNicki Aug 9, 2023
22f06fa
fix
SkafteNicki Aug 9, 2023
f18dc30
Merge branch 'master' into bugfix/disable_input_format
SkafteNicki Aug 17, 2023
5de39d3
base functional implementation
SkafteNicki Aug 17, 2023
9f535db
base module implementation
SkafteNicki Aug 17, 2023
963bc3c
Merge branch 'master' into bugfix/disable_input_format
Borda Sep 19, 2023
2e9673c
Merge branch 'master' into bugfix/disable_input_format
justusschock Oct 20, 2023
a3f9f3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2023
67abaa6
merge master
SkafteNicki Dec 21, 2023
25e3c46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
a0775f6
confmat working and being tested
SkafteNicki Dec 21, 2023
20ed714
confmat working and being tested
SkafteNicki Dec 21, 2023
b8a5b21
accuracy
SkafteNicki Dec 21, 2023
3ce162b
specificity
SkafteNicki Dec 21, 2023
76a5da1
precision and recall
SkafteNicki Dec 21, 2023
c7bc98a
f1 + fbeta
SkafteNicki Dec 21, 2023
cd52af8
precision recall curve
SkafteNicki Dec 21, 2023
be59da4
auroc
SkafteNicki Dec 21, 2023
eb4f9ba
average precision
SkafteNicki Dec 21, 2023
083b1a1
roc
SkafteNicki Dec 21, 2023
d59793d
missing parametrization
SkafteNicki Dec 21, 2023
51d7d30
Merge branch 'master' into bugfix/disable_input_format
Borda Dec 31, 2023
265bd03
Merge branch 'master' into bugfix/disable_input_format
Borda Jan 9, 2024
5654e11
_check_valid_input_format_type(input_format)
Borda Jan 9, 2024
d3a929f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
db7ce70
hamming and more
Borda Jan 9, 2024
5434c4d
Merge branch 'bugfix/disable_input_format' of https://github.com/PyTo…
Borda Jan 9, 2024
e716faf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
4fc84a2
test hamming
Borda Jan 9, 2024
783e7aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
bd27107
Merge branch 'master' into bugfix/disable_input_format
Borda Jan 12, 2024
7b0f1e7
Merge branch 'master' into bugfix/disable_input_format
Borda Jan 30, 2024
1c39d86
Merge branch 'master' into bugfix/disable_input_format
SkafteNicki Feb 4, 2024
ea4d8bf
Merge branch 'master' into bugfix/disable_input_format
Borda Feb 6, 2024
3495ba2
Merge branch 'master' into bugfix/disable_input_format
Borda Feb 15, 2024
a7d719b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2024
57 changes: 52 additions & 5 deletions src/torchmetrics/classification/accuracy.py
@@ -65,6 +65,21 @@ class BinaryAccuracy(BinaryStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
input_format: str specifying the format of the input preds tensor. Can be one of:

- ``'auto'``: automatically detect the format based on the values in the tensor. If all values
are in the [0,1] range, we consider the tensor to contain probabilities and only threshold the
values. If all values are non-float, we consider the tensor to contain labels and do nothing.
Otherwise, we consider the tensor to contain logits and apply sigmoid before thresholding.
- ``'probs'``: preds tensor contains probabilities in the [0,1] range. Only thresholding is
applied, and values are checked to lie in the [0,1] range.
- ``'logits'``: preds tensor contains values outside the [0,1] range and is treated as logits.
Sigmoid is applied and the values are thresholded before the metric is calculated.
- ``'labels'``: preds tensor contains integer values and is treated as labels. No formatting is
applied to the preds tensor.
- ``'none'``: disables all input formatting. This is the fastest option but also the least safe.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example (preds is int tensor):
>>> from torch import tensor
@@ -205,6 +220,21 @@ class MulticlassAccuracy(MulticlassStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
input_format: str specifying the format of the input preds tensor. Can be one of:

- ``'auto'``: automatically detect the format based on the values in the tensor. If all values
are in the [0,1] range, we consider the tensor to contain probabilities and only threshold the
values. If all values are non-float, we consider the tensor to contain labels and do nothing.
Otherwise, we consider the tensor to contain logits and apply sigmoid before thresholding.
- ``'probs'``: preds tensor contains probabilities in the [0,1] range. Only thresholding is
applied, and values are checked to lie in the [0,1] range.
- ``'logits'``: preds tensor contains values outside the [0,1] range and is treated as logits.
Sigmoid is applied and the values are thresholded before the metric is calculated.
- ``'labels'``: preds tensor contains integer values and is treated as labels. No formatting is
applied to the preds tensor.
- ``'none'``: disables all input formatting. This is the fastest option but also the least safe.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example (preds is int tensor):
>>> from torch import tensor
@@ -356,6 +386,21 @@ class MultilabelAccuracy(MultilabelStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
input_format: str specifying the format of the input preds tensor. Can be one of:

- ``'auto'``: automatically detect the format based on the values in the tensor. If all values
are in the [0,1] range, we consider the tensor to contain probabilities and only threshold the
values. If all values are non-float, we consider the tensor to contain labels and do nothing.
Otherwise, we consider the tensor to contain logits and apply sigmoid before thresholding.
- ``'probs'``: preds tensor contains probabilities in the [0,1] range. Only thresholding is
applied, and values are checked to lie in the [0,1] range.
- ``'logits'``: preds tensor contains values outside the [0,1] range and is treated as logits.
Sigmoid is applied and the values are thresholded before the metric is calculated.
- ``'labels'``: preds tensor contains integer values and is treated as labels. No formatting is
applied to the preds tensor.
- ``'none'``: disables all input formatting. This is the fastest option but also the least safe.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example (preds is int tensor):
>>> from torch import tensor
@@ -497,31 +542,33 @@ def __new__( # type: ignore[misc]
top_k: Optional[int] = 1,
ignore_index: Optional[int] = None,
validate_args: bool = True,
input_format: Literal["auto", "probs", "logits", "labels", "none"] = "auto",
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)

kwargs.update({
kwargs_extra = kwargs.copy()
kwargs_extra.update({
"multidim_average": multidim_average,
"ignore_index": ignore_index,
"validate_args": validate_args,
"input_format": input_format,
})

if task == ClassificationTask.BINARY:
return BinaryAccuracy(threshold, **kwargs)
return BinaryAccuracy(threshold, **kwargs_extra)
if task == ClassificationTask.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(
f"Optional arg `num_classes` must be type `int` when task is {task}. Got {type(num_classes)}"
)
if not isinstance(top_k, int):
raise ValueError(f"Optional arg `top_k` must be type `int` when task is {task}. Got {type(top_k)}")
return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
return MulticlassAccuracy(num_classes, top_k, average, **kwargs_extra)
if task == ClassificationTask.MULTILABEL:
if not isinstance(num_labels, int):
raise ValueError(
f"Optional arg `num_labels` must be type `int` when task is {task}. Got {type(num_labels)}"
)
return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
return MultilabelAccuracy(num_labels, threshold, average, **kwargs_extra)
raise ValueError(f"Not handled value: {task}")
90 changes: 68 additions & 22 deletions src/torchmetrics/classification/auroc.py
@@ -83,6 +83,18 @@ class BinaryAUROC(BinaryPrecisionRecallCurve):

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
input_format: str specifying the format of the input preds tensor. Can be one of:

- ``'auto'``: automatically detect the format based on the values in the tensor. If all values
are in the [0,1] range, we consider the tensor to contain probabilities and only threshold the
values. If all values are non-float, we consider the tensor to contain labels and do nothing.
Otherwise, we consider the tensor to contain logits and apply sigmoid before thresholding.
- ``'probs'``: preds tensor contains probabilities in the [0,1] range. Only thresholding is
applied, and values are checked to lie in the [0,1] range.
- ``'logits'``: preds tensor contains values outside the [0,1] range and is treated as logits.
Sigmoid is applied and the values are thresholded before the metric is calculated.
- ``'none'``: disables all input formatting. This is the fastest option but also the least safe.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
@@ -111,11 +123,14 @@ def __init__(
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
input_format: Literal["auto", "probs", "logits", "none"] = "auto",
**kwargs: Any,
) -> None:
super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs)
super().__init__(
thresholds=thresholds, ignore_index=ignore_index, validate_args=False, input_format=input_format, **kwargs
)
if validate_args:
_binary_auroc_arg_validation(max_fpr, thresholds, ignore_index)
_binary_auroc_arg_validation(max_fpr, thresholds, ignore_index, input_format)
self.max_fpr = max_fpr

def compute(self) -> Tensor: # type: ignore[override]
@@ -221,6 +236,18 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve):

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
input_format: str specifying the format of the input preds tensor. Can be one of:

- ``'auto'``: automatically detect the format based on the values in the tensor. If all values
are in the [0,1] range, we consider the tensor to contain probabilities and only threshold the
values. If all values are non-float, we consider the tensor to contain labels and do nothing.
Otherwise, we consider the tensor to contain logits and apply sigmoid before thresholding.
- ``'probs'``: preds tensor contains probabilities in the [0,1] range. Only thresholding is
applied, and values are checked to lie in the [0,1] range.
- ``'logits'``: preds tensor contains values outside the [0,1] range and is treated as logits.
Sigmoid is applied and the values are thresholded before the metric is calculated.
- ``'none'``: disables all input formatting. This is the fastest option but also the least safe.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
@@ -260,13 +287,19 @@ def __init__(
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
input_format: Literal["auto", "probs", "logits", "none"] = "auto",
**kwargs: Any,
) -> None:
super().__init__(
num_classes=num_classes, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
num_classes=num_classes,
thresholds=thresholds,
ignore_index=ignore_index,
validate_args=False,
input_format=input_format,
**kwargs,
)
if validate_args:
_multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index)
_multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index, input_format)
self.average = average # type: ignore[assignment]
self.validate_args = validate_args

@@ -373,6 +406,18 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve):

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
input_format: str specifying the format of the input preds tensor. Can be one of:

- ``'auto'``: automatically detect the format based on the values in the tensor. If all values
are in the [0,1] range, we consider the tensor to contain probabilities and only threshold the
values. If all values are non-float, we consider the tensor to contain labels and do nothing.
Otherwise, we consider the tensor to contain logits and apply sigmoid before thresholding.
- ``'probs'``: preds tensor contains probabilities in the [0,1] range. Only thresholding is
applied, and values are checked to lie in the [0,1] range.
- ``'logits'``: preds tensor contains values outside the [0,1] range and is treated as logits.
Sigmoid is applied and the values are thresholded before the metric is calculated.
- ``'none'``: disables all input formatting. This is the fastest option but also the least safe.

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
@@ -415,13 +460,19 @@ def __init__(
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
input_format: Literal["auto", "probs", "logits", "none"] = "auto",
**kwargs: Any,
) -> None:
super().__init__(
num_labels=num_labels, thresholds=thresholds, ignore_index=ignore_index, validate_args=False, **kwargs
num_labels=num_labels,
thresholds=thresholds,
ignore_index=ignore_index,
validate_args=False,
input_format=input_format,
**kwargs,
)
if validate_args:
_multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index)
_multilabel_auroc_arg_validation(num_labels, average, thresholds, ignore_index, input_format)
self.average = average
self.validate_args = validate_args

@@ -516,31 +567,26 @@ def __new__( # type: ignore[misc]
max_fpr: Optional[float] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
input_format: Literal["auto", "probs", "logits", "none"] = "auto",
**kwargs: Any,
) -> Metric:
"""Initialize task metric."""
task = ClassificationTask.from_str(task)
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
kwargs_extra = kwargs.copy()
kwargs_extra.update({
"thresholds": thresholds,
"ignore_index": ignore_index,
"validate_args": validate_args,
"input_format": input_format,
})
if task == ClassificationTask.BINARY:
return BinaryAUROC(max_fpr, **kwargs)
return BinaryAUROC(max_fpr, **kwargs_extra)
if task == ClassificationTask.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
return MulticlassAUROC(num_classes, average, **kwargs)
return MulticlassAUROC(num_classes, average, **kwargs_extra)
if task == ClassificationTask.MULTILABEL:
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelAUROC(num_labels, average, **kwargs)
return MultilabelAUROC(num_labels, average, **kwargs_extra)
raise ValueError(f"Task {task} not supported!")

def update(self, *args: Any, **kwargs: Any) -> None:
"""Update metric state."""
raise NotImplementedError(
f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric."
)

def compute(self) -> None:
"""Compute metric."""
raise NotImplementedError(
f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric."
)
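Both ``__new__`` dispatchers in this diff replace ``kwargs.update(...)`` with a copy-then-update on a local ``kwargs_extra``. A minimal sketch of why defensive copying matters when merging option dictionaries (the function and dict names here are illustrative, not from the PR):

```python
def build_config(defaults: dict, overrides: dict) -> dict:
    """Merge two option dicts without mutating either input."""
    # Calling defaults.update(overrides) directly would leak the merged
    # keys back into the caller's dict; copying first keeps the input
    # untouched, mirroring the ``kwargs_extra = kwargs.copy()`` pattern.
    merged = defaults.copy()
    merged.update(overrides)
    return merged


base = {"validate_args": True, "ignore_index": None}
cfg = build_config(base, {"input_format": "none"})
# ``base`` is unchanged; ``cfg`` additionally carries ``input_format``.
```

The same dict can then be reused safely across the branches of the task dispatch, since no branch observes another branch's mutations.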