Skip to content

Commit

Permalink
Added better custom aggregation support
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Jan 18, 2024
1 parent b12b722 commit 84203c6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 6 deletions.
12 changes: 8 additions & 4 deletions aligned/compiler/aggregation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ class PolarsTransformationFactoryAggregation(TransformationFactory, AggregationT
def using_features(self) -> list[FeatureFactory]:
    """Return the feature factories held in ``self._using_features``."""
    features = self._using_features
    return features

def aggregate_over(
    self,
    group_by: list[FeatureReferance],
    time_column: FeatureReferance | None,
) -> AggregateOver:
    """Describe what this aggregation is computed over.

    Forwards ``group_by`` and ``time_column`` to the module-scope
    ``aggregate_over`` callable and passes ``None`` for its four remaining
    positional arguments.

    Args:
        group_by: Feature references forwarded as the first argument.
        time_column: Feature reference forwarded as the second argument,
            or ``None``.

    Returns:
        Whatever the module-scope ``aggregate_over`` returns.
    """
    # Inside a method body the bare name resolves to the module-scope
    # callable, not to this method, so this is not a recursive call.
    unset_options = (None, None, None, None)
    return aggregate_over(group_by, time_column, *unset_options)

def compile(self) -> Transformation:
import inspect
import types
Expand All @@ -358,10 +363,9 @@ def compile(self) -> Transformation:
from aligned.schemas.transformation import PolarsFunctionTransformation, PolarsLambdaTransformation

if isinstance(self.method, pl.Expr):
code = str(self.method)
return PolarsLambdaTransformation(
method=dill.dumps(self.method), code=code.strip(), dtype=self.dtype.dtype
)
method = lambda df, alias: self.method # type: ignore
code = ''
return PolarsLambdaTransformation(method=dill.dumps(method), code=code, dtype=self.dtype.dtype)
else:
code = inspect.getsource(self.method)

Expand Down
20 changes: 20 additions & 0 deletions aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,26 @@ def transform_polars(
dtype.transformation = PolarsTransformationFactory(dtype, expression, using_features or [self])
return dtype # type: ignore [return-value]

def polars_aggregation(self, aggregation: pl.Expr, as_type: T) -> T:
    """Create a feature of type ``as_type`` backed by a custom Polars aggregation.

    Args:
        aggregation: The Polars expression handed to the aggregation factory.
        as_type: The feature whose type is copied for the returned feature.

    Returns:
        A copy of ``as_type`` (via ``copy_type()``) whose ``transformation``
        is a ``PolarsTransformationFactoryAggregation``.
    """
    # Imported locally, as in the original block.
    from aligned.compiler.aggregation_factory import (
        PolarsTransformationFactoryAggregation as _AggregationFactory,
    )

    feature = as_type.copy_type()  # type: ignore [assignment]
    # ``[self]`` is presumably the list of features the expression reads -- TODO confirm.
    factory = _AggregationFactory(as_type, aggregation, [self])
    feature.transformation = factory
    return feature

def polars_aggregation_using_features(
    self: T,
    using_features: list[FeatureFactory],
    aggregation: pl.Expr,
) -> T:
    """Create a copy of this feature backed by a custom Polars aggregation.

    Args:
        using_features: Feature factories handed to the aggregation factory.
        aggregation: The Polars expression handed to the aggregation factory.

    Returns:
        A copy of ``self`` (via ``copy_type()``) whose ``transformation``
        is a ``PolarsTransformationFactoryAggregation``.
    """
    # Imported locally, as in the original block.
    from aligned.compiler.aggregation_factory import (
        PolarsTransformationFactoryAggregation as _AggregationFactory,
    )

    feature = self.copy_type()  # type: ignore [assignment]
    factory = _AggregationFactory(self, aggregation, using_features)
    feature.transformation = factory
    return feature

def is_required(self: T) -> T:
    """Return this same instance, unchanged."""
    return self

Expand Down
21 changes: 20 additions & 1 deletion aligned/tests/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from aligned.compiler.feature_factory import EventTimestamp, Int32, String
from aligned.compiler.feature_factory import EventTimestamp, Int32, String, Float

from aligned.feature_store import FeatureStore
from aligned.feature_view.feature_view import feature_view
Expand Down Expand Up @@ -54,6 +54,8 @@ class TestAgg:

@pytest.mark.asyncio
async def test_aggregations_on_all_no_window() -> None:
import polars as pl

@feature_view(name='test_agg', source=FileSource.parquet_at('test_data/credit_history.parquet'))
class TestAgg:
dob_ssn = String().as_entity()
Expand All @@ -66,9 +68,26 @@ class TestAgg:
credit_card_due_sum = credit_card_due.aggregate().sum()
student_loan_due_mean = student_loan_due.aggregate().mean()

custom_mean_aggregation = student_loan_due.polars_aggregation(
pl.col('student_loan_due').mean(),
as_type=Float(),
)
custom_mean_aggregation_using_features = Float().polars_aggregation_using_features(
using_features=[student_loan_due],
aggregation=pl.col('student_loan_due').mean(),
)
custom_sum_aggregation = credit_card_due.polars_aggregation(
pl.col('credit_card_due').sum(),
as_type=Float(),
)

df = await TestAgg.query().all().to_pandas() # type: ignore
assert df.shape[0] == 3

assert df['custom_mean_aggregation'].equals(df['student_loan_due_mean'])
assert df['custom_mean_aggregation'].equals(df['custom_mean_aggregation_using_features'])
assert df['custom_sum_aggregation'].equals(df['credit_card_due_sum'])


@pytest.mark.asyncio
async def test_aggregations_on_all_no_window_materialised() -> None:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.62"
version = "0.0.63"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <[email protected]>"]
license = "Apache-2.0"
Expand Down

0 comments on commit 84203c6

Please sign in to comment.