diff --git a/aligned/data_source/batch_data_source.py b/aligned/data_source/batch_data_source.py index b48090e4..e47fe4b4 100644 --- a/aligned/data_source/batch_data_source.py +++ b/aligned/data_source/batch_data_source.py @@ -1,6 +1,5 @@ from __future__ import annotations -from abc import ABC, abstractmethod from typing import TYPE_CHECKING, TypeVar, Any from dataclasses import dataclass @@ -66,7 +65,7 @@ def wrap_job(self, job: RetrivalJob) -> RetrivalJob: raise NotImplementedError() -class BatchDataSource(ABC, Codable, SerializableType): +class BatchDataSource(Codable, SerializableType): """ A definition to where a specific pice of data can be found. E.g: A database table, a file, a web service, etc. @@ -76,12 +75,11 @@ class BatchDataSource(ABC, Codable, SerializableType): type_name: str - @abstractmethod def job_group_key(self) -> str: """ A key defining which sources can be grouped together in one request. """ - pass + raise NotImplementedError(type(self)) def source_id(self) -> str: """ @@ -305,6 +303,24 @@ def multi_source_features_for( return source.source.features_for(facts, request).filter(source.condition) + async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None: + return await self.source.freshness(event_timestamp) + + def all_between_dates( + self, request: RetrivalRequest, start_date: datetime, end_date: datetime + ) -> RetrivalJob: + + if isinstance(self.condition, Feature): + request.features.add(self.condition) + else: + request.derived_features.add(self.condition) + + return ( + self.source.all_between_dates(request, start_date, end_date) + .filter(self.condition) + .derive_features([request]) + ) + def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob: if isinstance(self.condition, Feature): @@ -479,6 +495,39 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob: .derive_features([request]) ) + def all_between_dates( + self, request: RetrivalRequest, start_date: 
datetime, end_date: datetime + ) -> RetrivalJob: + + right_job = self.right_source.all_data(self.right_request, limit=None).derive_features( + [self.right_request] + ) + + return ( + self.source.all_between_dates(self.left_request, start_date, end_date) + .derive_features([self.left_request]) + .join_asof( + right_job, + left_event_timestamp=self.left_event_timestamp, + right_event_timestamp=self.right_event_timestamp, + left_on=self.left_on, + right_on=self.right_on, + ) + .derive_features([request]) + ) + + async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None: + left_freshness = await self.source.freshness(event_timestamp) + right_freshness = await self.right_source.freshness(event_timestamp) + + if left_freshness is None: + return None + + if right_freshness is None: + return None + + return min(left_freshness, right_freshness) + def join( self, view: Any, @@ -504,7 +553,7 @@ def join_asof( ) def depends_on(self) -> set[FeatureLocation]: - return self.source.depends_on().intersection(self.right_source.depends_on()) + return self.source.depends_on().union(self.right_source.depends_on()) @dataclass @@ -547,6 +596,39 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob: .derive_features([request]) ) + def all_between_dates( + self, request: RetrivalRequest, start_date: datetime, end_date: datetime + ) -> RetrivalJob: + + right_job = self.right_source.all_data(self.right_request, limit=None).derive_features( + [self.right_request] + ) + + return ( + self.source.all_between_dates(self.left_request, start_date, end_date) + .derive_features([self.left_request]) + .join_asof( + right_job, + left_event_timestamp=self.left_event_timestamp, + right_event_timestamp=self.right_event_timestamp, + left_on=self.left_on, + right_on=self.right_on, + ) + .derive_features([request]) + ) + + async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None: + left_freshness = await 
self.source.freshness(event_timestamp) + right_freshness = await self.right_source.freshness(event_timestamp) + + if left_freshness is None: + return None + + if right_freshness is None: + return None + + return min(left_freshness, right_freshness) + def join( self, view: Any, @@ -563,7 +645,7 @@ def join( return join_source(self, view, on_left, on_right, how) def depends_on(self) -> set[FeatureLocation]: - return self.source.depends_on().intersection(self.right_source.depends_on()) + return self.source.depends_on().union(self.right_source.depends_on()) class ColumnFeatureMappable: @@ -571,7 +653,7 @@ class ColumnFeatureMappable: def with_renames(self: T, mapping_keys: dict[str, str]) -> T: self.mapping_keys = mapping_keys # type: ignore return self def columns_for(self, features: list[Feature]) -> list[str]: return [self.mapping_keys.get(feature.name, feature.name) for feature in features] diff --git a/aligned/feature_store.py b/aligned/feature_store.py index 4f686716..d43b1f83 100644 --- a/aligned/feature_store.py +++ b/aligned/feature_store.py @@ -894,6 +894,29 @@ def predictions_for( entities, features=[f'{location_id}:*'], event_timestamp_column=event_timestamp_column ) + def predictions_between(self, start_date: datetime, end_date: datetime) -> RetrivalJob: + + selected_source = self.store.feature_source + + if not isinstance(selected_source, BatchFeatureSource): + raise ValueError( + f'Unable to load all predictions for selected feature source {type(selected_source)}' + ) + + location = FeatureLocation.model(self.model.name) + if location.identifier not in selected_source.sources: + raise ValueError( + f'Unable to find source for {location.identifier}. Either set through a `prediction_source`' + ' in the model contract, or use the `using_source` method on the store object.' 
+ ) + + source = selected_source.sources[location.identifier] + request = self.model.predictions_view.request(self.model.name) + + return source.all_between_dates(request, start_date, end_date).select_columns( + set(request.all_returned_columns) + ) + def all_predictions(self, limit: int | None = None) -> RetrivalJob: selected_source = self.store.feature_source @@ -913,7 +936,7 @@ def all_predictions(self, limit: int | None = None) -> RetrivalJob: source = selected_source.sources[location.identifier] request = self.model.predictions_view.request(self.model.name) - return source.all_data(request, limit=limit) + return source.all_data(request, limit=limit).select_columns(set(request.all_returned_columns)) def using_source(self, source: FeatureSourceable | BatchDataSource) -> ModelFeatureStore: @@ -1034,7 +1057,10 @@ class SupervisedModelFeatureStore: store: FeatureStore def features_for( - self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None + self, + entities: ConvertableToRetrivalJob | RetrivalJob, + event_timestamp_column: str | None = None, + target_event_timestamp_column: str | None = None, ) -> SupervisedJob: """Loads the features and labels for a model @@ -1079,12 +1105,23 @@ def features_for( else: raise ValueError('Found no targets in the model') + if event_timestamp_column == target_event_timestamp_column: + request = self.store.requests_for( + RawStringFeatureRequest(features.union(target_features)), + event_timestamp_column=event_timestamp_column, + ) + job = self.store.features_for_request(request, entities, request.features_to_include) + return SupervisedJob( + job.select_columns(request.features_to_include), + target_columns=targets, + ) + request = self.store.requests_for( RawStringFeatureRequest(features), event_timestamp_column=event_timestamp_column ) target_request = self.store.requests_for( - RawStringFeatureRequest(target_features), event_timestamp_column=event_timestamp_column - 
).without_event_timestamp(name_sufix='target') + RawStringFeatureRequest(target_features), event_timestamp_column=target_event_timestamp_column + ).with_sufix('target') total_request = FeatureRequest( FeatureLocation.model(self.model.name), @@ -1098,7 +1135,10 @@ def features_for( ) def predictions_for( - self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None + self, + entities: ConvertableToRetrivalJob | RetrivalJob, + event_timestamp_column: str | None = None, + target_event_timestamp_column: str | None = None, ) -> RetrivalJob: """Loads the predictions and labels / ground truths for a model @@ -1148,8 +1188,9 @@ def predictions_for( RawStringFeatureRequest(pred_features), event_timestamp_column=event_timestamp_column ) target_request = self.store.requests_for( - RawStringFeatureRequest(target_features) - ).without_event_timestamp(name_sufix='target') + RawStringFeatureRequest(target_features), + event_timestamp_column=target_event_timestamp_column, + ).with_sufix('target') total_request = FeatureRequest( FeatureLocation.model(self.model.name), @@ -1268,24 +1309,16 @@ def previous(self, days: int = 0, minutes: int = 0, seconds: int = 0) -> Retriva def features_for( self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None ) -> RetrivalJob: - - request = self.view.request_all + features = {'*'} if self.feature_filter: - request = self.view.request_for(self.feature_filter) - - if not event_timestamp_column: - request = request.without_event_timestamp() - - if isinstance(entities, RetrivalJob): - entity_job = entities - else: - entity_job = RetrivalJob.from_convertable(entities, request.needed_requests) + features = self.feature_filter - job = self.source.features_for(entity_job, request) - if self.feature_filter: - return job.select_columns(self.feature_filter) - else: - return job + feature_refs = [f'{self.view.name}:{feature}' for feature in features] + return 
self.store.features_for( + entities, + feature_refs, + event_timestamp_column=event_timestamp_column, + ) def select(self, features: set[str]) -> FeatureViewStore: logger.info(f'Selecting features {features}') diff --git a/aligned/feature_view/feature_view.py b/aligned/feature_view/feature_view.py index 9590329d..837aa5aa 100644 --- a/aligned/feature_view/feature_view.py +++ b/aligned/feature_view/feature_view.py @@ -457,6 +457,7 @@ def compile_with_metadata(feature_view: Any, metadata: FeatureViewMetadata) -> C name=metadata.name, description=metadata.description, tags=metadata.tags, + contacts=metadata.contacts, source=metadata.source, entities=set(), features=set(), diff --git a/aligned/request/retrival_request.py b/aligned/request/retrival_request.py index 8d3e24fd..9626654c 100644 --- a/aligned/request/retrival_request.py +++ b/aligned/request/retrival_request.py @@ -187,6 +187,18 @@ def aggregate_over(self) -> dict[AggregateOver, set[AggregatedFeature]]: features[feature.aggregate_over].add(feature) return features + def with_sufix(self, sufix: str) -> 'RetrivalRequest': + + return RetrivalRequest( + name=f'{self.name}{sufix}', + location=self.location, + entities=self.entities, + features=self.features, + derived_features=self.derived_features, + aggregated_features=self.aggregated_features, + event_timestamp_request=self.event_timestamp_request, + ) + def without_event_timestamp(self, name_sufix: str | None = None) -> 'RetrivalRequest': request = None @@ -250,6 +262,18 @@ def combine(requests: list['RetrivalRequest']) -> list['RetrivalRequest']: return list(grouped_requests.values()) + def rename_entities(self, mapping: dict[str, str]) -> 'RetrivalRequest': + + return RetrivalRequest( + name=self.name, + location=self.location, + entities={entity.renamed(mapping.get(entity.name, entity.name)) for entity in self.entities}, + features=self.features, + derived_features=self.derived_features, + aggregated_features=self.aggregated_features, + 
event_timestamp_request=self.event_timestamp_request, + ) + @staticmethod def unsafe_combine(requests: list['RetrivalRequest']) -> 'RetrivalRequest': @@ -420,6 +444,13 @@ def without_event_timestamp(self, name_sufix: str | None = None) -> 'FeatureRequ needed_requests=[request.without_event_timestamp(name_sufix) for request in self.needed_requests], ) + def with_sufix(self, sufix: str) -> 'FeatureRequest': + return FeatureRequest( + location=self.location, + features_to_include=self.features_to_include, + needed_requests=[request.with_sufix(sufix) for request in self.needed_requests], + ) + def rename_entities(self, mappings: dict[str, str]) -> 'FeatureRequest': return FeatureRequest( location=self.location, diff --git a/aligned/schemas/feature.py b/aligned/schemas/feature.py index 59c68f67..e223ab04 100644 --- a/aligned/schemas/feature.py +++ b/aligned/schemas/feature.py @@ -269,6 +269,15 @@ def __pre_serialize__(self) -> Feature: return self + def renamed(self, new_name: str) -> Feature: + return Feature( + name=new_name, + dtype=self.dtype, + description=self.description, + tags=self.tags, + constraints=self.constraints, + ) + def __hash__(self) -> int: return hash(self.name) diff --git a/aligned/schemas/feature_view.py b/aligned/schemas/feature_view.py index 0c9d8d64..876bea60 100644 --- a/aligned/schemas/feature_view.py +++ b/aligned/schemas/feature_view.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from datetime import datetime from dataclasses import dataclass, field @@ -388,5 +389,22 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob: else: return core_job.derive_features().derive_features([request]) + def all_between_dates( + self, request: RetrivalRequest, start_date: datetime, end_date: datetime + ) -> RetrivalJob: + sub_source = self.view.materialized_source or self.view.source + + sub_req = self.sub_request(request) + + core_job = sub_source.all_between_dates(sub_req, start_date, end_date) + if 
request.aggregated_features: + return core_job.aggregate(request).derive_features([request]) + else: + return core_job.derive_features().derive_features([request]) + def depends_on(self) -> set[FeatureLocation]: return {FeatureLocation.feature_view(self.view.name)} + + async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None: + source = self.view.materialized_source or self.view.source + return await source.freshness(event_timestamp) diff --git a/aligned/sources/local.py b/aligned/sources/local.py index fac11a0f..3c600447 100644 --- a/aligned/sources/local.py +++ b/aligned/sources/local.py @@ -62,7 +62,12 @@ async def as_repo_definition(self) -> RepoDefinition: async def data_file_freshness(reference: DataFileReference, column_name: str) -> datetime | None: try: file = await reference.to_polars() - return file.select(column_name).max().collect()[0, column_name] + if isinstance(reference, ColumnFeatureMappable): + source_column = reference.feature_identifier_for([column_name])[0] + else: + source_column = column_name + + return file.select(source_column).max().collect()[0, source_column] except UnableToFindFileException: return None @@ -314,11 +319,6 @@ async def feature_view_code(self, view_name: str) -> str: schema, data_source_code, view_name, 'from aligned import FileSource' ) - async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None: - df = await self.to_polars() - et_name = event_timestamp.name - return df.select(et_name).max().collect()[0, et_name] - @dataclass class DeltaFileConfig(Codable):