Skip to content

Commit

Permalink
Polynomial features will support degree from 0 to n (#763)
Browse files Browse the repository at this point in the history
* PolynomialFeatures doesn	 support negtive degree

* polyfeatures for n degree

---------

Co-authored-by: Pusunuru <[email protected]>
  • Loading branch information
giriprasad51 and Pusunuru authored Feb 21, 2024
1 parent 2a56526 commit 3eb8393
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 33 deletions.
53 changes: 24 additions & 29 deletions hummingbird/ml/operator_converters/sklearn/poly_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .._physical_operator import PhysicalOperator
from onnxconverter_common.registration import register_converter
import torch
import itertools


class PolynomialFeatures(PhysicalOperator, torch.nn.Module):
Expand All @@ -19,40 +20,37 @@ class PolynomialFeatures(PhysicalOperator, torch.nn.Module):
# TODO extend this class to support higher orders
"""

def __init__(self, operator, n_features, degree, interaction_only, include_bias, device):
def __init__(self, operator, n_features, degree=2, interaction_only=False, include_bias=True, device=None):
super(PolynomialFeatures, self).__init__(operator)
self.transformer = True

self.degree = degree
self.n_features = n_features
self.interaction_only = interaction_only
self.include_bias = include_bias

indices = [i for j in range(n_features) for i in range(j * n_features + j, (j + 1) * n_features)]
self.n_poly_features = len(indices)
self.n_features = n_features
self.indices = torch.nn.Parameter(torch.LongTensor(indices), requires_grad=False)
def forward(self, x):
if self.degree < 0:
raise ValueError("Degree should be greater than or equal to 0.")

self.bias = torch.nn.Parameter(torch.FloatTensor([1.0]), requires_grad=False)
features = []

def forward(self, x):
x_orig = x
x = x.view(-1, self.n_features, 1) * x.view(-1, 1, self.n_features)
x = x.view(-1, self.n_features ** 2)
x = torch.index_select(x, 1, self.indices)

# TODO: This gives mismatched elements
# if self.interaction_only:
# if self.include_bias:
# bias = self.bias.expand(x_orig.size()[0], 1)
# return torch.cat([bias, x], dim=1)
# else:
# return x
# Move input to GPU if available
device = x.device

# Add bias term if include_bias is True
if self.include_bias:
bias = self.bias.expand(x_orig.size()[0], 1)
return torch.cat([bias, x_orig, x], dim=1)
else:
return torch.cat([x_orig, x], dim=1)
bias = torch.ones(x.size()[0], 1, device=device)
features.append(bias)

# Generate polynomial features
for d in range(1, self.degree + 1):
for combo in itertools.combinations_with_replacement(range(self.n_features), d):
if self.interaction_only and len(set(combo)) != d:
continue
new_feature = torch.prod(torch.stack([x[:, idx] for idx in combo], dim=1), dim=1, keepdim=True)
features.append(new_feature)

return torch.cat(features, dim=1).to(device=device)


def convert_sklearn_poly_features(operator, device, extra_config):
Expand All @@ -71,11 +69,8 @@ def convert_sklearn_poly_features(operator, device, extra_config):
"""
assert operator is not None, "Cannot convert None operator"

if operator.raw_operator.interaction_only:
raise NotImplementedError("Hummingbird does not currently support interaction_only flag for PolynomialFeatures")

if operator.raw_operator.degree != 2:
raise NotImplementedError("Hummingbird currently only supports degree 2 for PolynomialFeatures")
if operator.raw_operator.degree < 0:
raise NotImplementedError("Hummingbird does not supports negtive degree for PolynomialFeatures")
return PolynomialFeatures(
operator,
operator.raw_operator.n_features_in_,
Expand Down
6 changes: 2 additions & 4 deletions tests/test_sklearn_poly_features_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,11 @@ def test_sklearn_poly_feat_with_no_bias(self):
def test_sklearn_poly_featurizer_raises(self):
data = np.array([[1.2, 3.2, 1.3, -5.6], [4.3, -3.2, 5.7, 1.0], [0, 3.2, 4.7, -8.9]], dtype=np.float32)

# TODO: delete when implemented
model = PolynomialFeatures(degree=4, include_bias=True, order="F").fit(data)
self.assertRaises(NotImplementedError, hummingbird.ml.convert, model, "torch")
self._test_sklearn_polynomial_featurizer(data, model)

# TODO: delete when implemented
model = PolynomialFeatures(degree=2, interaction_only=True, order="F").fit(data)
self.assertRaises(NotImplementedError, hummingbird.ml.convert, model, "torch")
self._test_sklearn_polynomial_featurizer(data, model)


if __name__ == "__main__":
Expand Down

0 comments on commit 3eb8393

Please sign in to comment.