diff --git a/hummingbird/ml/operator_converters/sklearn/poly_features.py b/hummingbird/ml/operator_converters/sklearn/poly_features.py index 1e87eb3e..7855010c 100644 --- a/hummingbird/ml/operator_converters/sklearn/poly_features.py +++ b/hummingbird/ml/operator_converters/sklearn/poly_features.py @@ -10,6 +10,7 @@ from .._physical_operator import PhysicalOperator from onnxconverter_common.registration import register_converter import torch +import itertools class PolynomialFeatures(PhysicalOperator, torch.nn.Module): @@ -19,40 +20,37 @@ class PolynomialFeatures(PhysicalOperator, torch.nn.Module): # TODO extend this class to support higher orders """ - def __init__(self, operator, n_features, degree, interaction_only, include_bias, device): + def __init__(self, operator, n_features, degree=2, interaction_only=False, include_bias=True, device=None): super(PolynomialFeatures, self).__init__(operator) self.transformer = True - + self.degree = degree self.n_features = n_features self.interaction_only = interaction_only self.include_bias = include_bias - indices = [i for j in range(n_features) for i in range(j * n_features + j, (j + 1) * n_features)] - self.n_poly_features = len(indices) - self.n_features = n_features - self.indices = torch.nn.Parameter(torch.LongTensor(indices), requires_grad=False) + def forward(self, x): + if self.degree < 0: + raise ValueError("Degree should be greater than or equal to 0.") - self.bias = torch.nn.Parameter(torch.FloatTensor([1.0]), requires_grad=False) + features = [] - def forward(self, x): - x_orig = x - x = x.view(-1, self.n_features, 1) * x.view(-1, 1, self.n_features) - x = x.view(-1, self.n_features ** 2) - x = torch.index_select(x, 1, self.indices) - - # TODO: This gives mismatched elements - # if self.interaction_only: - # if self.include_bias: - # bias = self.bias.expand(x_orig.size()[0], 1) - # return torch.cat([bias, x], dim=1) - # else: - # return x + # Move input to GPU if available + device = x.device + # Add bias term if include_bias is True if self.include_bias: - bias = self.bias.expand(x_orig.size()[0], 1) - return torch.cat([bias, x_orig, x], dim=1) - else: - return torch.cat([x_orig, x], dim=1) + bias = torch.ones(x.size()[0], 1, device=device) + features.append(bias) + + # Generate polynomial features + for d in range(1, self.degree + 1): + for combo in itertools.combinations_with_replacement(range(self.n_features), d): + if self.interaction_only and len(set(combo)) != d: + continue + new_feature = torch.prod(torch.stack([x[:, idx] for idx in combo], dim=1), dim=1, keepdim=True) + features.append(new_feature) + + return torch.cat(features, dim=1).to(device=device) def convert_sklearn_poly_features(operator, device, extra_config): @@ -71,11 +69,8 @@ def convert_sklearn_poly_features(operator, device, extra_config): """ assert operator is not None, "Cannot convert None operator" - if operator.raw_operator.interaction_only: - raise NotImplementedError("Hummingbird does not currently support interaction_only flag for PolynomialFeatures") - - if operator.raw_operator.degree != 2: - raise NotImplementedError("Hummingbird currently only supports degree 2 for PolynomialFeatures") + if operator.raw_operator.degree < 0: + raise NotImplementedError("Hummingbird does not supports negtive degree for PolynomialFeatures") return PolynomialFeatures( operator, operator.raw_operator.n_features_in_, diff --git a/tests/test_sklearn_poly_features_converter.py b/tests/test_sklearn_poly_features_converter.py index 3d8a60c3..85721556 100644 --- a/tests/test_sklearn_poly_features_converter.py +++ b/tests/test_sklearn_poly_features_converter.py @@ -46,13 +46,11 @@ def test_sklearn_poly_feat_with_no_bias(self): def test_sklearn_poly_featurizer_raises(self): data = np.array([[1.2, 3.2, 1.3, -5.6], [4.3, -3.2, 5.7, 1.0], [0, 3.2, 4.7, -8.9]], dtype=np.float32) - # TODO: delete when implemented model = PolynomialFeatures(degree=4, include_bias=True, order="F").fit(data) - self.assertRaises(NotImplementedError, hummingbird.ml.convert, model, "torch") + self._test_sklearn_polynomial_featurizer(data, model) - # TODO: delete when implemented model = PolynomialFeatures(degree=2, interaction_only=True, order="F").fit(data) - self.assertRaises(NotImplementedError, hummingbird.ml.convert, model, "torch") + self._test_sklearn_polynomial_featurizer(data, model) if __name__ == "__main__":