Skip to content

Commit

Permalink
Passing the store to custom transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Oct 13, 2024
1 parent 9bc1f29 commit 903e729
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 21 deletions.
10 changes: 10 additions & 0 deletions aligned/compiler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ class FeatureFactory(FeatureReferencable):
_location: FeatureLocation | None = None
_description: str | None = None
_default_value: LiteralValue | None = None
_loads_feature: FeatureReference | None = None

tags: set[str] | None = None
transformation: TransformationFactory | None = None
Expand Down Expand Up @@ -477,6 +478,7 @@ def compile(self) -> DerivedFeature:
description=self._description,
tags=list(self.tags) if self.tags else None,
constraints=self.constraints,
loads_feature=self._loads_feature,
)

def depth(self) -> int:
Expand Down Expand Up @@ -635,6 +637,14 @@ def referencing(self: T, entity: FeatureFactory) -> T:
self._add_constraint(ReferencingColumn(entity.feature_reference()))
return self

def for_entities(self: T, entities: dict[str, FeatureFactory]) -> T:
    """Create a feature that loads this feature for the given entities.

    Args:
        entities: Feature factories handed to the `LoadFeature`
            transformation factory, keyed by string. Presumably the keys
            are the entity names needed to look this feature up -- confirm
            against `LoadFeature`.

    Returns:
        A new factory produced by `copy_type()` whose transformation is a
        `LoadFeature` and whose `_loads_feature` is this feature's reference.
    """
    # Function-scope import; presumably avoids a circular import between the
    # two compiler modules -- TODO confirm.
    from aligned.compiler.transformation_factory import LoadFeature

    loaded = self.copy_type()
    loaded.transformation = LoadFeature(entities, self.feature_reference())
    loaded._loads_feature = self.feature_reference()
    return loaded


class CouldBeModelVersion:
def as_model_version(self) -> ModelVersion:
Expand Down
27 changes: 26 additions & 1 deletion aligned/compiler/transformation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aligned import AwsS3Config
from aligned.lazy_imports import pandas as pd
from aligned.compiler.feature_factory import FeatureFactory, Transformation, TransformationFactory
from aligned.schemas.feature import FeatureReference
from aligned.schemas.transformation import FillNaValuesColumns, LiteralValue, EmbeddingModel

if TYPE_CHECKING:
Expand Down Expand Up @@ -759,7 +760,7 @@ def compile(self) -> Transformation:

if isinstance(self.method, pl.Expr):

def method(df: pl.DataFrame, alias: str) -> pl.Expr:
def method(df: pl.DataFrame, alias: str, store: ContractStore) -> pl.Expr:
return self.method # type: ignore

return PolarsLambdaTransformation(method=dill.dumps(method), code='', dtype=self.dtype.dtype)
Expand Down Expand Up @@ -1005,3 +1006,27 @@ def compile(self) -> Transformation:
return MultiplyValue(self.first.name, self.behind)
else:
return Multiply(self.first.name, self.behind.name)


@dataclass
class LoadFeature(TransformationFactory):
    """Factory for a transformation that loads another feature by entity.

    Compiles into the schema-level `LoadFeature` transformation, passing
    along the name of each entity feature and the reference of the feature
    to load.
    """

    # Feature factories that supply the entity values, keyed by string.
    entities: dict[str, FeatureFactory]
    # Reference to the feature that will be loaded.
    feature: FeatureReference

    @property
    def using_features(self) -> list[FeatureFactory]:
        """Return the entity feature factories this transformation uses."""
        return [*self.entities.values()]

    def compile(self) -> Transformation:
        """Build the schema-level `LoadFeature` transformation.

        If any entity feature is a `List`, the name of the last such feature
        is passed along as the explode key; otherwise the key is `None`.
        """
        from aligned.compiler.feature_factory import List
        from aligned.schemas.transformation import LoadFeature

        list_entity_names = [
            factory.name for factory in self.entities.values() if isinstance(factory, List)
        ]
        # The last list-typed entity wins when several are present.
        explode_key: str | None = list_entity_names[-1] if list_entity_names else None

        entity_columns = {key: factory.name for key, factory in self.entities.items()}
        return LoadFeature(entity_columns, self.feature, explode_key)
40 changes: 28 additions & 12 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,7 @@ def features_for(
else:
features = self.raw_string_features(set())

job = self.store.features_for(entities, list(features), event_timestamp_column=event_timestamp_column)
job = None

if isinstance(entities, (dict, pl.DataFrame, pd.DataFrame)):

Expand All @@ -1117,7 +1117,16 @@ def features_for(
break

if not needs_core_features:
job = RetrivalJob.from_convertable(entities, request).derive_features(request.needed_requests)
job = (
RetrivalJob.from_convertable(entities, request)
.derive_features(request.needed_requests)
.inject_store(self.store)
)

if job is None:
job = self.store.features_for(
entities, list(features), event_timestamp_column=event_timestamp_column
)

return job

Expand All @@ -1129,23 +1138,30 @@ async def input_freshness(self) -> dict[FeatureLocation, datetime | None]:

locs: dict[FeatureLocation, EventTimestamp] = {}

other_locs: set[FeatureLocation] = set()

for req in self.request().needed_requests:
if req.event_timestamp:
locs[req.location] = req.event_timestamp

for feature in req.derived_features:
if feature.loads_feature:
other_locs.add(feature.loads_feature.location)

if self.model.exposed_model:
additional_model_deps = await self.model.exposed_model.depends_on()
for loc in additional_model_deps:
if loc in locs:
continue
other_locs.update(await self.model.exposed_model.depends_on())

if loc.location_type == 'model':
event_timestamp = self.store.model(loc.name).prediction_request().event_timestamp
else:
event_timestamp = self.store.feature_view(loc.name).request.event_timestamp
for loc in other_locs:
if loc in locs:
continue

if loc.location_type == 'model':
event_timestamp = self.store.model(loc.name).prediction_request().event_timestamp
else:
event_timestamp = self.store.feature_view(loc.name).request.event_timestamp

if event_timestamp:
locs[loc] = event_timestamp
if event_timestamp:
locs[loc] = event_timestamp

return await self.store.feature_source.freshness_for(locs)

Expand Down
11 changes: 7 additions & 4 deletions aligned/schemas/derivied_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class DerivedFeature(Feature):

depending_on: set[FeatureReference]
transformation: Transformation
loads_feature: FeatureReference | None
depth: int = 1

def __init__(
Expand All @@ -24,15 +25,17 @@ def __init__(
description: str | None = None,
tags: list[str] | None = None,
constraints: set[Constraint] | None = None,
loads_feature: FeatureReference | None = None,
):
self.name = name
self.dtype = dtype
self.depending_on = depending_on
self.transformation = transformation
self.tags = tags
self.depth = depth
self.dtype = dtype
self.description = description
self.tags = tags
self.constraints = constraints
self.depending_on = depending_on
self.transformation = transformation
self.loads_feature = loads_feature
self.default_value = None

def __pre_serialize__(self) -> DerivedFeature:
Expand Down
48 changes: 47 additions & 1 deletion aligned/schemas/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from aligned.lazy_imports import pandas as pd
from aligned.schemas.codable import Codable
from aligned.schemas.feature import FeatureType
from aligned.schemas.feature import FeatureReference, FeatureType
from aligned.schemas.literal_value import LiteralValue
from aligned.schemas.text_vectoriser import EmbeddingModel

Expand Down Expand Up @@ -2543,3 +2543,49 @@ async def transform_polars(
self, df: pl.LazyFrame, alias: str, store: ContractStore
) -> pl.LazyFrame | pl.Expr | pl.Expr:
return pl.col(self.key).str.split(self.separator)


@dataclass
class LoadFeature(Transformation):
    """Transformation that fetches a feature from the store for each row.

    The columns named in `entities` are sent to `store.features_for`, and
    the column named `feature.name` in the response becomes the result.
    """

    # Maps the key sent to `store.features_for` to the column in the incoming
    # frame that holds the values for that key.
    entities: dict[str, str]
    # Reference to the feature that is requested from the store.
    feature: FeatureReference
    # Column to explode before the lookup (polars path only), or None.
    explode_key: str | None

    async def transform_pandas(self, df: pd.DataFrame, store: ContractStore) -> pd.Series:
        """Load the feature using columns of a pandas frame as entities.

        Args:
            df: Frame containing every column named in `self.entities`.
            store: Store used to fetch the feature.

        Returns:
            The `self.feature.name` column of the store response.
        """
        entities = {key: df[df_key] for key, df_key in self.entities.items()}

        values = await store.features_for(entities, features=[self.feature.identifier]).to_pandas()
        return values[self.feature.name]  # type: ignore

    async def transform_polars(
        self, df: pl.LazyFrame, alias: str, store: ContractStore
    ) -> pl.LazyFrame | pl.Expr:
        """Load the feature using columns of a polars frame as entities.

        When `explode_key` is set, that column is exploded before the lookup
        and the loaded values are aggregated back to one row per input row.

        Args:
            df: Lazy frame containing every column named in `self.entities`.
            alias: Name given to the loaded column in the output.
            store: Store used to fetch the feature.

        Returns:
            `df` with the loaded column appended horizontally.
        """
        group_keys: list[str] = []

        if self.explode_key:
            # Tag each input row so exploded rows can be folded back together.
            group_keys = ['row_nr']
            entity_df = df.with_row_count('row_nr').explode(self.explode_key)
        else:
            entity_df = df

        # Rename the frame's columns to the keys the store is queried with.
        entities = entity_df.rename({df_key: key for key, df_key in self.entities.items()})

        values = (
            await store.features_for(entities.collect(), features=[self.feature.identifier])
            .with_subfeatures()
            .to_polars()
        )

        if group_keys:
            # `group_by` does not guarantee group order by default. The result
            # is joined to `df` by position below, so restore the input row
            # order explicitly to keep values aligned with their rows.
            values = (
                values.group_by(group_keys)
                .agg([pl.col(col) for col in values.columns if col not in group_keys])
                .sort(group_keys)
            )

        values = values.select(pl.col(self.feature.name).alias(alias))

        return pl.concat([df, values.lazy()], how='horizontal')
2 changes: 1 addition & 1 deletion aligned/sources/in_mem_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def with_view(self, view: 'CompiledFeatureView') -> 'InMemorySource':
self._vector_index_name = view.name

if self.data.is_empty():
return InMemorySource.from_values({feat.name: [] for feat in view.features})
return InMemorySource.from_values({feat.name: [] for feat in view.entities.union(view.features)})
return self

@classmethod
Expand Down
29 changes: 28 additions & 1 deletion aligned/tests/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
from aligned.compiler.feature_factory import EventTimestamp, Int32, String, Float
from aligned.compiler.feature_factory import EventTimestamp, Int32, String, Float, List

from aligned.feature_store import ContractStore
from aligned.feature_view.feature_view import feature_view
from aligned.schemas.transformation import SupportedTransformations
from aligned.sources.in_mem_source import InMemorySource
from aligned.sources.local import FileSource, CsvFileSource


Expand Down Expand Up @@ -160,3 +161,29 @@ class TestFill:

assert df['some_new_column'].is_null().sum() == 0
assert df['some_string'].is_null().sum() == 0


@pytest.mark.asyncio
async def test_load_features() -> None:
    """Check that a feature can be loaded from another view via `for_entities`."""
    import polars as pl

    passenger_source = InMemorySource.from_values({'passenger_id': [1, 2, 3], 'age': [24, 20, 30]})

    @feature_view(source=passenger_source)
    class Test:
        passenger_id = Int32().as_entity()
        age = Int32()

    @feature_view(source=InMemorySource.empty())
    class Other:
        some_value = Int32()

        # Every row looks up the passengers with id 2 and 1.
        lookup_id = some_value.transform_polars(pl.lit([2, 1]), as_dtype=List(Int32()))
        age_value = Test().age.for_entities({'passenger_id': lookup_id})

    store = ContractStore.empty()
    for view in (Test, Other):
        store.add_feature_view(view)

    job = store.feature_view(Other).features_for({'some_value': [1, 1.5, 0.5]})
    df = await job.to_polars()

    assert Other().age_value._loads_feature is not None
    assert df['age_value'].null_count() == 0
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aligned"
version = "0.0.107"
version = "0.0.108"
description = "A data management and lineage tool for ML applications."
authors = ["Mats E. Mollestad <[email protected]>"]
license = "Apache-2.0"
Expand Down

0 comments on commit 903e729

Please sign in to comment.