diff --git a/aligned/compiler/aggregation_factory.py b/aligned/compiler/aggregation_factory.py index 72e5d822..d701bfff 100644 --- a/aligned/compiler/aggregation_factory.py +++ b/aligned/compiler/aggregation_factory.py @@ -349,6 +349,11 @@ class PolarsTransformationFactoryAggregation(TransformationFactory, AggregationT def using_features(self) -> list[FeatureFactory]: return self._using_features + def aggregate_over( + self, group_by: list[FeatureReferance], time_column: FeatureReferance | None + ) -> AggregateOver: + return aggregate_over(group_by, time_column, None, None, None, None) + def compile(self) -> Transformation: import inspect import types @@ -358,10 +363,9 @@ def compile(self) -> Transformation: from aligned.schemas.transformation import PolarsFunctionTransformation, PolarsLambdaTransformation if isinstance(self.method, pl.Expr): - code = str(self.method) - return PolarsLambdaTransformation( - method=dill.dumps(self.method), code=code.strip(), dtype=self.dtype.dtype - ) + method = lambda df, alias: self.method # type: ignore + code = '' + return PolarsLambdaTransformation(method=dill.dumps(method), code=code, dtype=self.dtype.dtype) else: code = inspect.getsource(self.method) diff --git a/aligned/compiler/feature_factory.py b/aligned/compiler/feature_factory.py index bdcf0ae2..710e2950 100644 --- a/aligned/compiler/feature_factory.py +++ b/aligned/compiler/feature_factory.py @@ -517,6 +517,26 @@ def transform_polars( dtype.transformation = PolarsTransformationFactory(dtype, expression, using_features or [self]) return dtype # type: ignore [return-value] + def polars_aggregation(self, aggregation: pl.Expr, as_type: T) -> T: + from aligned.compiler.aggregation_factory import PolarsTransformationFactoryAggregation + + value = as_type.copy_type() # type: ignore [assignment] + value.transformation = PolarsTransformationFactoryAggregation(as_type, aggregation, [self]) + + return value + + def polars_aggregation_using_features( + self: T, + using_features: list[FeatureFactory], + aggregation: pl.Expr, + ) -> T: + from aligned.compiler.aggregation_factory import PolarsTransformationFactoryAggregation + + value = self.copy_type() # type: ignore [assignment] + value.transformation = PolarsTransformationFactoryAggregation(self, aggregation, using_features) + + return value + def is_required(self: T) -> T: return self diff --git a/aligned/tests/test_transformations.py b/aligned/tests/test_transformations.py index ea009bab..d2894fbb 100644 --- a/aligned/tests/test_transformations.py +++ b/aligned/tests/test_transformations.py @@ -1,5 +1,5 @@ import pytest -from aligned.compiler.feature_factory import EventTimestamp, Int32, String +from aligned.compiler.feature_factory import EventTimestamp, Int32, String, Float from aligned.feature_store import FeatureStore from aligned.feature_view.feature_view import feature_view @@ -54,6 +54,8 @@ class TestAgg: @pytest.mark.asyncio async def test_aggregations_on_all_no_window() -> None: + import polars as pl + @feature_view(name='test_agg', source=FileSource.parquet_at('test_data/credit_history.parquet')) class TestAgg: dob_ssn = String().as_entity() @@ -66,9 +68,26 @@ class TestAgg: credit_card_due_sum = credit_card_due.aggregate().sum() student_loan_due_mean = student_loan_due.aggregate().mean() + custom_mean_aggregation = student_loan_due.polars_aggregation( + pl.col('student_loan_due').mean(), + as_type=Float(), + ) + custom_mean_aggregation_using_features = Float().polars_aggregation_using_features( + using_features=[student_loan_due], + aggregation=pl.col('student_loan_due').mean(), + ) + custom_sum_aggregation = credit_card_due.polars_aggregation( + pl.col('credit_card_due').sum(), + as_type=Float(), + ) + df = await TestAgg.query().all().to_pandas() # type: ignore assert df.shape[0] == 3 + assert df['custom_mean_aggregation'].equals(df['student_loan_due_mean']) + assert df['custom_mean_aggregation'].equals(df['custom_mean_aggregation_using_features']) + assert df['custom_sum_aggregation'].equals(df['credit_card_due_sum']) + @pytest.mark.asyncio async def test_aggregations_on_all_no_window_materialised() -> None: diff --git a/pyproject.toml b/pyproject.toml index 47dddf10..f1783979 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aligned" -version = "0.0.62" +version = "0.0.63" description = "A data managment and lineage tool for ML applications." authors = ["Mats E. Mollestad "] license = "Apache-2.0"