Skip to content

Commit

Permalink
fix: join source bug
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Dec 21, 2023
1 parent 9b5f943 commit a0bbbce
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 36 deletions.
96 changes: 89 additions & 7 deletions aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, TypeVar, Any
from dataclasses import dataclass

Expand Down Expand Up @@ -66,7 +65,7 @@ def wrap_job(self, job: RetrivalJob) -> RetrivalJob:
raise NotImplementedError()


class BatchDataSource(ABC, Codable, SerializableType):
class BatchDataSource(Codable, SerializableType):
"""
A definition of where a specific piece of data can be found.
E.g: A database table, a file, a web service, etc.
Expand All @@ -76,12 +75,11 @@ class BatchDataSource(ABC, Codable, SerializableType):

type_name: str

@abstractmethod
def job_group_key(self) -> str:
"""
A key defining which sources can be grouped together in one request.
"""
pass
raise NotImplementedError(type(self))

def source_id(self) -> str:
"""
Expand Down Expand Up @@ -305,6 +303,24 @@ def multi_source_features_for(

return source.source.features_for(facts, request).filter(source.condition)

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
    """Report how fresh the underlying data is.

    Pure delegation: the wrapped source answers; filtering does not add
    any freshness information of its own.
    """
    inner_freshness = await self.source.freshness(event_timestamp)
    return inner_freshness

def all_between_dates(
    self, request: RetrivalRequest, start_date: datetime, end_date: datetime
) -> RetrivalJob:
    """Load rows in ``[start_date, end_date]`` from the wrapped source and
    apply this source's filter condition.

    The condition is first registered on the request — as a plain feature
    or a derived feature depending on its type — so the pipeline knows it
    must be materialised before filtering.
    """
    # NOTE(review): mutates the caller's request in place — confirm intended.
    if isinstance(self.condition, Feature):
        feature_set = request.features
    else:
        feature_set = request.derived_features
    feature_set.add(self.condition)

    job = self.source.all_between_dates(request, start_date, end_date)
    job = job.filter(self.condition)
    return job.derive_features([request])

def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:

if isinstance(self.condition, Feature):
Expand Down Expand Up @@ -479,6 +495,39 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
.derive_features([request])
)

def all_between_dates(
    self, request: RetrivalRequest, start_date: datetime, end_date: datetime
) -> RetrivalJob:
    """As-of join the left and right sources over a date window.

    The left side is restricted to ``[start_date, end_date]``; the right
    side is loaded in full (no limit) so every left row can find its
    closest-in-time match. Derived features are computed per side before
    the join and for the combined request afterwards.
    """
    left_job = self.source.all_between_dates(self.left_request, start_date, end_date)
    left_job = left_job.derive_features([self.left_request])

    right_job = self.right_source.all_data(self.right_request, limit=None)
    right_job = right_job.derive_features([self.right_request])

    joined = left_job.join_asof(
        right_job,
        left_event_timestamp=self.left_event_timestamp,
        right_event_timestamp=self.right_event_timestamp,
        left_on=self.left_on,
        right_on=self.right_on,
    )
    return joined.derive_features([request])

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
    """Freshness of the joined source: the older of the two sides.

    If either side cannot report freshness, the combined result is
    unknown and ``None`` is returned.
    """
    left_freshness = await self.source.freshness(event_timestamp)
    right_freshness = await self.right_source.freshness(event_timestamp)

    if left_freshness is None or right_freshness is None:
        return None

    return min(left_freshness, right_freshness)

def join(
self,
view: Any,
Expand All @@ -504,7 +553,7 @@ def join_asof(
)

def depends_on(self) -> set[FeatureLocation]:
return self.source.depends_on().intersection(self.right_source.depends_on())
return self.source.depends_on().union(self.right_source.depends_on())


@dataclass
Expand Down Expand Up @@ -547,6 +596,39 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
.derive_features([request])
)

def all_between_dates(
    self, request: RetrivalRequest, start_date: datetime, end_date: datetime
) -> RetrivalJob:
    """Join the two sources as-of within a date window.

    Only the left side is windowed to ``[start_date, end_date]``; the
    right side is read in full so that each left row can be matched to
    the nearest preceding right row. Per-side derived features are
    computed before the join, and the joined result derives the features
    of the combined request.
    """
    left_job = (
        self.source.all_between_dates(self.left_request, start_date, end_date)
        .derive_features([self.left_request])
    )
    right_job = (
        self.right_source.all_data(self.right_request, limit=None)
        .derive_features([self.right_request])
    )
    return left_job.join_asof(
        right_job,
        left_event_timestamp=self.left_event_timestamp,
        right_event_timestamp=self.right_event_timestamp,
        left_on=self.left_on,
        right_on=self.right_on,
    ).derive_features([request])

async def freshness(self, event_timestamp: EventTimestamp) -> datetime | None:
    """Return the freshness of the combined source.

    The combined data is only as fresh as its stalest side, so the
    minimum of the two freshness values is returned; ``None`` if either
    side is unknown.
    """
    per_side = [
        await self.source.freshness(event_timestamp),
        await self.right_source.freshness(event_timestamp),
    ]
    if any(value is None for value in per_side):
        return None
    return min(per_side)

def join(
self,
view: Any,
Expand All @@ -563,15 +645,15 @@ def join(
return join_source(self, view, on_left, on_right, how)

def depends_on(self) -> set[FeatureLocation]:
return self.source.depends_on().intersection(self.right_source.depends_on())
return self.source.depends_on().union(self.right_source.depends_on())


class ColumnFeatureMappable:
mapping_keys: dict[str, str]

def with_renames(self: T, mapping_keys: dict[str, str]) -> T:
    """Register a feature-name -> column-name mapping on this source.

    Args:
        mapping_keys: Maps a feature's name to the column name it is
            stored under in the underlying source.

    Returns:
        The same instance, so the call can be chained fluently.
    """
    self.mapping_keys = mapping_keys  # type: ignore
    # Fix: return the instance itself. The line was corrupted with
    # unrelated trailing text (`selfFileSource.parquet_at(...)`), which
    # would raise NameError at runtime and break the `-> T` chaining
    # contract.
    return self

def columns_for(self, features: list[Feature]) -> list[str]:
    """Translate features into the column names used by the source.

    A feature whose name has no entry in ``mapping_keys`` falls back to
    its own name.
    """
    columns: list[str] = []
    for feature in features:
        columns.append(self.mapping_keys.get(feature.name, feature.name))
    return columns
Expand Down
79 changes: 56 additions & 23 deletions aligned/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,29 @@ def predictions_for(
entities, features=[f'{location_id}:*'], event_timestamp_column=event_timestamp_column
)

def predictions_between(self, start_date: datetime, end_date: datetime) -> RetrivalJob:
    """Load the stored predictions for this model within a date window.

    Requires the store's feature source to be a ``BatchFeatureSource``
    with a source registered for this model's prediction view; raises
    ``ValueError`` otherwise. The result is trimmed to exactly the
    columns the prediction view's request returns.
    """
    selected_source = self.store.feature_source

    if not isinstance(selected_source, BatchFeatureSource):
        raise ValueError(
            f'Unable to load all predictions for selected feature source {type(selected_source)}'
        )

    location = FeatureLocation.model(self.model.name)
    if location.identifier not in selected_source.sources:
        raise ValueError(
            f'Unable to find source for {location.identifier}. Either set through a `prediction_source`'
            'in the model contract, or use the `using_source` method on the store object.'
        )

    prediction_source = selected_source.sources[location.identifier]
    view_request = self.model.predictions_view.request(self.model.name)

    windowed = prediction_source.all_between_dates(view_request, start_date, end_date)
    return windowed.select_columns(set(view_request.all_returned_columns))

def all_predictions(self, limit: int | None = None) -> RetrivalJob:

selected_source = self.store.feature_source
Expand All @@ -913,7 +936,7 @@ def all_predictions(self, limit: int | None = None) -> RetrivalJob:
source = selected_source.sources[location.identifier]
request = self.model.predictions_view.request(self.model.name)

return source.all_data(request, limit=limit)
return source.all_data(request, limit=limit).select_columns(set(request.all_returned_columns))

def using_source(self, source: FeatureSourceable | BatchDataSource) -> ModelFeatureStore:

Expand Down Expand Up @@ -1034,7 +1057,10 @@ class SupervisedModelFeatureStore:
store: FeatureStore

def features_for(
self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None
self,
entities: ConvertableToRetrivalJob | RetrivalJob,
event_timestamp_column: str | None = None,
target_event_timestamp_column: str | None = None,
) -> SupervisedJob:
"""Loads the features and labels for a model
Expand Down Expand Up @@ -1079,12 +1105,23 @@ def features_for(
else:
raise ValueError('Found no targets in the model')

if event_timestamp_column == target_event_timestamp_column:
request = self.store.requests_for(
RawStringFeatureRequest(features.union(target_features)),
event_timestamp_column=event_timestamp_column,
)
job = self.store.features_for_request(request, entities, request.features_to_include)
return SupervisedJob(
job.select_columns(request.features_to_include),
target_columns=targets,
)

request = self.store.requests_for(
RawStringFeatureRequest(features), event_timestamp_column=event_timestamp_column
)
target_request = self.store.requests_for(
RawStringFeatureRequest(target_features), event_timestamp_column=event_timestamp_column
).without_event_timestamp(name_sufix='target')
RawStringFeatureRequest(target_features), event_timestamp_column=target_event_timestamp_column
).with_sufix('target')

total_request = FeatureRequest(
FeatureLocation.model(self.model.name),
Expand All @@ -1098,7 +1135,10 @@ def features_for(
)

def predictions_for(
self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None
self,
entities: ConvertableToRetrivalJob | RetrivalJob,
event_timestamp_column: str | None = None,
target_event_timestamp_column: str | None = None,
) -> RetrivalJob:
"""Loads the predictions and labels / ground truths for a model
Expand Down Expand Up @@ -1148,8 +1188,9 @@ def predictions_for(
RawStringFeatureRequest(pred_features), event_timestamp_column=event_timestamp_column
)
target_request = self.store.requests_for(
RawStringFeatureRequest(target_features)
).without_event_timestamp(name_sufix='target')
RawStringFeatureRequest(target_features),
event_timestamp_column=target_event_timestamp_column,
).with_sufix('target')

total_request = FeatureRequest(
FeatureLocation.model(self.model.name),
Expand Down Expand Up @@ -1268,24 +1309,16 @@ def previous(self, days: int = 0, minutes: int = 0, seconds: int = 0) -> Retriva
def features_for(
self, entities: ConvertableToRetrivalJob | RetrivalJob, event_timestamp_column: str | None = None
) -> RetrivalJob:

request = self.view.request_all
features = {'*'}
if self.feature_filter:
request = self.view.request_for(self.feature_filter)

if not event_timestamp_column:
request = request.without_event_timestamp()

if isinstance(entities, RetrivalJob):
entity_job = entities
else:
entity_job = RetrivalJob.from_convertable(entities, request.needed_requests)
features = self.feature_filter

job = self.source.features_for(entity_job, request)
if self.feature_filter:
return job.select_columns(self.feature_filter)
else:
return job
feature_refs = [f'{self.view.name}:{feature}' for feature in features]
return self.store.features_for(
entities,
feature_refs,
event_timestamp_column=event_timestamp_column,
)

def select(self, features: set[str]) -> FeatureViewStore:
logger.info(f'Selecting features {features}')
Expand Down
1 change: 1 addition & 0 deletions aligned/feature_view/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def compile_with_metadata(feature_view: Any, metadata: FeatureViewMetadata) -> C
name=metadata.name,
description=metadata.description,
tags=metadata.tags,
contacts=metadata.contacts,
source=metadata.source,
entities=set(),
features=set(),
Expand Down
31 changes: 31 additions & 0 deletions aligned/request/retrival_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ def aggregate_over(self) -> dict[AggregateOver, set[AggregatedFeature]]:
features[feature.aggregate_over].add(feature)
return features

def with_sufix(self, sufix: str) -> 'RetrivalRequest':
    """Return a copy of this request with ``sufix`` appended to its name.

    Every other field — location, entities, features, derived and
    aggregated features, and the event-timestamp request — is carried
    over unchanged.
    """
    suffixed_name = self.name + sufix
    return RetrivalRequest(
        name=suffixed_name,
        location=self.location,
        entities=self.entities,
        features=self.features,
        derived_features=self.derived_features,
        aggregated_features=self.aggregated_features,
        event_timestamp_request=self.event_timestamp_request,
    )

def without_event_timestamp(self, name_sufix: str | None = None) -> 'RetrivalRequest':

request = None
Expand Down Expand Up @@ -250,6 +262,18 @@ def combine(requests: list['RetrivalRequest']) -> list['RetrivalRequest']:

return list(grouped_requests.values())

def rename_entities(self, mapping: dict[str, str]) -> 'RetrivalRequest':
    """Return a copy of this request with its entities renamed.

    Each entity's name is looked up in ``mapping``; entities without an
    entry keep their current name. All other fields are unchanged.
    """
    renamed_entities = set()
    for entity in self.entities:
        new_name = mapping.get(entity.name, entity.name)
        renamed_entities.add(entity.renamed(new_name))

    return RetrivalRequest(
        name=self.name,
        location=self.location,
        entities=renamed_entities,
        features=self.features,
        derived_features=self.derived_features,
        aggregated_features=self.aggregated_features,
        event_timestamp_request=self.event_timestamp_request,
    )

@staticmethod
def unsafe_combine(requests: list['RetrivalRequest']) -> 'RetrivalRequest':

Expand Down Expand Up @@ -420,6 +444,13 @@ def without_event_timestamp(self, name_sufix: str | None = None) -> 'FeatureRequ
needed_requests=[request.without_event_timestamp(name_sufix) for request in self.needed_requests],
)

def with_sufix(self, sufix: str) -> 'FeatureRequest':
    """Return a copy where every needed request carries ``sufix`` in its name.

    Location and the set of features to include are unchanged; the rename
    is delegated to each underlying ``RetrivalRequest.with_sufix``.
    """
    suffixed_requests = [req.with_sufix(sufix) for req in self.needed_requests]
    return FeatureRequest(
        location=self.location,
        features_to_include=self.features_to_include,
        needed_requests=suffixed_requests,
    )

def rename_entities(self, mappings: dict[str, str]) -> 'FeatureRequest':
return FeatureRequest(
location=self.location,
Expand Down
9 changes: 9 additions & 0 deletions aligned/schemas/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,15 @@ def __pre_serialize__(self) -> Feature:

return self

def renamed(self, new_name: str) -> Feature:
    """Return a copy of this feature under ``new_name``.

    The dtype, description, tags, and constraints are all preserved;
    only the name differs.
    """
    copy = Feature(
        name=new_name,
        dtype=self.dtype,
        description=self.description,
        tags=self.tags,
        constraints=self.constraints,
    )
    return copy

def __hash__(self) -> int:
    # Hash on the name alone, so two Feature instances with the same name
    # land in the same bucket regardless of dtype/tags/constraints.
    # NOTE(review): assumes equality is also name-based — confirm against
    # the class's __eq__ (not visible here).
    return hash(self.name)

Expand Down
Loading

0 comments on commit a0bbbce

Please sign in to comment.