diff --git a/pruning/architecture/pruning_methods/schedulers.py b/pruning/architecture/pruning_methods/schedulers.py index c97d416..5238e3a 100644 --- a/pruning/architecture/pruning_methods/schedulers.py +++ b/pruning/architecture/pruning_methods/schedulers.py @@ -26,16 +26,23 @@ def __iter__(self) -> Generator[list[float], None, None]: class IterativeStepScheduler(BasePruningStepScheduler): - def __iter__(self) -> Generator[float, None, None]: + def __iter__(self) -> Generator[list[float], None, None]: nonpruned_percent = 1 if self.start != 0: yield [self.start] nonpruned_percent -= round(self.start * nonpruned_percent, 8) - + dummy_one = 1 + # stop if pruned more than target pruning percentage - 0.1% while nonpruned_percent - (1 - self.end) > 0.001: - current_step = round(self.step * nonpruned_percent, 8) + if self.start != 0: + current_step = round(self.step * dummy_one, 8) + dummy_one -= current_step + dummy_one = round(dummy_one, 8) + else: + current_step = round(self.step * nonpruned_percent, 8) + nonpruned_percent -= current_step nonpruned_percent = round(nonpruned_percent, 8)