Skip to content

Commit

Permalink
fixed paths of selfclean
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianGroeger96 committed Mar 19, 2024
1 parent 69b6e86 commit d6221b3
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/cleaner/auto_cleaning_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import scipy
import scipy.stats

from src.utils.plotting import (
from ..utils.plotting import (
plot_frac_cut,
plot_sensitivity,
subplot_frac_cut,
Expand Down
5 changes: 3 additions & 2 deletions src/cleaner/irrelevants/lad_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import numpy as np
from scipy.cluster.hierarchy import single

from src.cleaner.irrelevants.base_irrelevant_mixin import BaseIrrelevantMixin
from src.scoring.lad_scoring import LAD
from ssl_library.src.utils.logging import plot_dist

from ...cleaner.irrelevants.base_irrelevant_mixin import BaseIrrelevantMixin
from ...scoring.lad_scoring import LAD


class LADIrrelevantMixin(BaseIrrelevantMixin):
def __init__(self, global_leaves: bool = False, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions src/cleaner/irrelevants/quantile_irrelevant_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import numpy as np

from src.cleaner.irrelevants.base_irrelevant_mixin import BaseIrrelevantMixin
from src.utils.plotting import plot_irrelevant_samples
from ssl_library.src.utils.logging import plot_dist

from ...cleaner.irrelevants.base_irrelevant_mixin import BaseIrrelevantMixin
from ...utils.plotting import plot_irrelevant_samples


class QuantileIrrelevantMixin(BaseIrrelevantMixin):
def __init__(self, quantile: float = 0.01, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions src/cleaner/label_errors/intra_extra_distance_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import numpy as np

from src.cleaner.label_errors.base_label_error_mixin import BaseLabelErrorMixin
from src.utils.utils import has_same_label
from ssl_library.src.utils.logging import plot_dist

from ...cleaner.label_errors.base_label_error_mixin import BaseLabelErrorMixin
from ...utils.utils import has_same_label


class IntraExtraDistanceLabelErrorMixin(BaseLabelErrorMixin):
def labels_calc_scores(self) -> np.ndarray:
Expand Down
5 changes: 3 additions & 2 deletions src/cleaner/near_duplicates/embedding_distance_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import numpy as np
from tqdm import tqdm

from src.cleaner.near_duplicates.base_near_duplicate_mixin import BaseNearDuplicateMixin
from src.utils.utils import condensed_to_square
from ssl_library.src.utils.logging import plot_dist

from ...cleaner.near_duplicates.base_near_duplicate_mixin import BaseNearDuplicateMixin
from ...utils.utils import condensed_to_square


class EmbeddingDistanceMixin(BaseNearDuplicateMixin):
def get_near_duplicate_ranking(self) -> List[Tuple[float, int]]:
Expand Down
5 changes: 3 additions & 2 deletions src/cleaner/selfclean.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from torchvision.datasets import ImageFolder
from torchvision.transforms import InterpolationMode

from src.cleaner.selfclean_cleaner import SelfCleanCleaner
from src.utils.utils import set_dataset_transformation
from ssl_library.src.augmentations.ibot import iBOTDataAugmentation
from ssl_library.src.pkg import Embedder, embed_dataset
from ssl_library.src.trainers.dino_trainer import DINOTrainer
from ssl_library.src.utils.utils import cleanup, init_distributed_mode

from ..cleaner.selfclean_cleaner import SelfCleanCleaner
from ..utils.utils import set_dataset_transformation

DINO_STANDARD_HYPERPARAMETERS = {
"optim": "adamw",
"lr": 0.0005,
Expand Down
14 changes: 6 additions & 8 deletions src/cleaner/selfclean_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
import sklearn # noqa: F401
from tqdm import tqdm

import src.distances # noqa: F401
import src.distances.projective_distance # noqa: F401
from src.cleaner.auto_cleaning_mixin import AutoCleaningMixin
from src.cleaner.base_cleaner import BaseCleaner
from src.cleaner.irrelevants.lad_mixin import LADIrrelevantMixin
from src.cleaner.label_errors.intra_extra_distance_mixin import (
from ..cleaner.auto_cleaning_mixin import AutoCleaningMixin
from ..cleaner.base_cleaner import BaseCleaner
from ..cleaner.irrelevants.lad_mixin import LADIrrelevantMixin
from ..cleaner.label_errors.intra_extra_distance_mixin import (
IntraExtraDistanceLabelErrorMixin,
)
from src.cleaner.near_duplicates.embedding_distance_mixin import EmbeddingDistanceMixin
from src.utils.plotting import plot_inspection_result
from ..cleaner.near_duplicates.embedding_distance_mixin import EmbeddingDistanceMixin
from ..utils.plotting import plot_inspection_result


class SelfCleanCleaner(
Expand Down

0 comments on commit d6221b3

Please sign in to comment.