diff --git a/openfisca_core/taxscales.py b/openfisca_core/taxscales.py index eadbd5d535..15b57ad3af 100644 --- a/openfisca_core/taxscales.py +++ b/openfisca_core/taxscales.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import copy import itertools import logging @@ -11,7 +9,6 @@ from numpy import ( around, array, - asarray, digitize, dot, finfo, @@ -44,7 +41,7 @@ def __init__( class_name: str, method_name: str, arg_name: str, - arg_value: Union[List, ndarray] + arg_value: ndarray ) -> None: message = [ f"'{class_name}.{method_name}' can't be run with an empty '{arg_name}':\n", @@ -109,7 +106,7 @@ def __repr__(self) -> Any: f"{self.__class__.__name__}", ) - def calc(self, _tax_base: Union[ndarray[int], ndarray[float]], _right: bool) -> Any: + def calc(self, _tax_base: ndarray, _right: bool) -> Any: raise NotImplementedError( "Method 'calc' is not implemented for " f"{self.__class__.__name__}", @@ -129,7 +126,7 @@ def multiply_thresholds( def bracket_indices( self, - tax_base: Union[ndarray[int], ndarray[float]], + tax_base: ndarray, factor: float = 1.0, round_base_decimals: Optional[int] = None, ) -> Any: @@ -247,20 +244,20 @@ def multiply_thresholds( def bracket_indices( self, - tax_base: Union[ndarray[int], ndarray[float]], + tax_base: ndarray, factor: float = 1.0, round_decimals: Optional[int] = None, - ) -> ndarray[int]: + ) -> ndarray: """ Compute the relevant bracket indices for the given tax bases. - :param ndarray tax_base: Array of the tax bases. - :param float factor: Factor to apply to the thresholds of the tax scales. - :param int round_decimals: Decimals to keep when rounding thresholds. + :param tax_base: Array of the tax bases. + :param factor: Factor to apply to the thresholds of the tax scales. + :param round_decimals: Decimals to keep when rounding thresholds. :returns: Int array with relevant bracket indices for the given tax bases. - For instance: + :example: >>> tax_scale = AbstractRateTaxScale() >>> tax_scale.add_bracket(0, 0) @@ -278,7 +275,7 @@ def bracket_indices( self.thresholds, ) - if not size(asarray(tax_base)): + if not size(array(tax_base)): raise EmptyArgumentError( self.__class__.__name__, "bracket_indices", @@ -342,11 +339,7 @@ def add_bracket(self, threshold: int, amount: Union[int, float]) -> None: self.thresholds.insert(i, threshold) self.amounts.insert(i, amount) - def calc( - self, - tax_base: Union[ndarray[int], ndarray[float]], - right: bool = False, - ) -> ndarray[float]: + def calc(self, tax_base: ndarray, right: bool = False) -> ndarray: guarded_thresholds = array([-inf] + self.thresholds + [inf]) bracket_indices = digitize(tax_base, guarded_thresholds, right = right) guarded_amounts = array([0] + self.amounts + [0]) @@ -366,11 +359,7 @@ class MarginalAmountTaxScale(SingleAmountTaxScale): containing the input. """ - def calc( - self, - tax_base: Union[ndarray[int], ndarray[float]], - _right: bool = False, - ) -> ndarray[float]: + def calc(self, tax_base: ndarray, _right: bool = False) -> ndarray: base1 = tile(tax_base, (len(self.thresholds), 1)).T thresholds1 = tile(hstack((self.thresholds, inf)), (len(tax_base), 1)) a = max_(min_(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0) @@ -378,11 +367,7 @@ def calc( class LinearAverageRateTaxScale(AbstractRateTaxScale): - def calc( - self, - tax_base: Union[ndarray[int], ndarray[float]], - _right: bool = False, - ) -> ndarray[float]: + def calc(self, tax_base: ndarray, _right: bool = False) -> ndarray: if len(self.rates) == 1: return tax_base * self.rates[0] @@ -460,20 +445,20 @@ def add_tax_scale(self, tax_scale: AbstractRateTaxScale) -> None: def calc( self, - tax_base: Union[ndarray[int], ndarray[float]], + tax_base: ndarray, factor: float = 1.0, round_base_decimals: Optional[int] = None, - ) -> ndarray[float]: + ) -> ndarray: """ Compute the tax amount for the given tax bases by applying the taxscale. - :param ndarray tax_base: Array of the tax bases. - :param float factor: Factor to apply to the thresholds of the tax scale. - :param int round_base_decimals: Decimals to keep when rounding thresholds. + :param tax_base: Array of the tax bases. + :param factor: Factor to apply to the thresholds of the tax scale. + :param round_base_decimals: Decimals to keep when rounding thresholds. :returns: Float array with tax amount for the given tax bases. - For instance: + :example: >>> tax_scale = MarginalRateTaxScale() >>> tax_scale.add_bracket(0, 0) @@ -529,20 +514,20 @@ def combine_bracket( def marginal_rates( self, - tax_base: Union[ndarray[int], ndarray[float]], + tax_base: ndarray, factor: float = 1.0, round_base_decimals: Optional[int] = None, - ) -> ndarray[float]: + ) -> ndarray: """ Compute the marginal tax rates relevant for the given tax bases. - :param ndarray tax_base: Array of the tax bases. - :param float factor: Factor to apply to the thresholds of the tax scale. - :param int round_base_decimals: Decimals to keep when rounding thresholds. + :param tax_base: Array of the tax bases. + :param factor: Factor to apply to the thresholds of the tax scale. + :param round_base_decimals: Decimals to keep when rounding thresholds. :returns: Float array with relevant marginal tax rate for the given tax bases. - For instance: + :example: >>> tax_scale = MarginalRateTaxScale() >>> tax_scale.add_bracket(0, 0)