Skip to content

Commit

Permalink
feat: added custom source
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Feb 23, 2024
1 parent 6b264b4 commit 388983f
Show file tree
Hide file tree
Showing 16 changed files with 369 additions and 151 deletions.
40 changes: 0 additions & 40 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,46 +199,6 @@ if freshness < datetime.now() - timedelta(days=2):
    raise ValueError("Too old data to create an ML model")
```


## Data Enrichers

In many cases, extra data will be needed in order to generate some features.
We therefore need some way of enriching the data.
This can easily be done with Aligned's `DataEnricher`s.

```python
my_db = PostgreSQLConfig.localhost()
redis = RedisConfig.localhost()

user_location = my_db.data_enricher( # Fetch all user locations
sql="SELECT * FROM user_location"
).cache( # Cache them for one day
ttl=timedelta(days=1),
cache_key="user_location_cache"
).lock( # Make sure only one processor fetches the data at a time
lock_name="user_location_lock",
redis_config=redis
)


async def distance_to_users(df: DataFrame) -> Series:
user_location_df = await user_location.load()
...
return distances

@feature_view(...)
class SomeFeatures:

latitude = Float()
longitude = Float()

distance_to_users = Float().transformed_using_features_pandas(
[latitude, longitude],
distance_to_users
)
```


## Access Data

You can easily create a feature store that contains all your feature definitions.
Expand Down
6 changes: 6 additions & 0 deletions aligned/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
String,
Timestamp,
CustomAggregation,
List,
Embedding,
)
from aligned.compiler.model import model_contract, FeatureInputVersions
from aligned.data_source.stream_data_source import HttpStreamSource
from aligned.data_source.batch_data_source import CustomMethodDataSource
from aligned.feature_store import FeatureStore
from aligned.feature_view import (
feature_view,
Expand All @@ -35,6 +38,7 @@
'FileSource',
'AwsS3Config',
'RedshiftSQLConfig',
'CustomMethodDataSource',
# Stream Data Source
'HttpStreamSource',
# Online Source
Expand All @@ -52,6 +56,8 @@
'Float',
'EventTimestamp',
'Timestamp',
'List',
'Embedding',
'Json',
'EmbeddingModel',
'feature_view',
Expand Down
60 changes: 59 additions & 1 deletion aligned/compiler/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from aligned.data_source.batch_data_source import BatchDataSource
from aligned.data_source.stream_data_source import StreamDataSource
from aligned.feature_view.feature_view import FeatureView
from aligned.feature_view.feature_view import FeatureView, FeatureViewWrapper
from aligned.schemas.derivied_feature import DerivedFeature
from aligned.schemas.feature import Feature, FeatureLocation, FeatureReferance, FeatureType
from aligned.schemas.feature_view import CompiledFeatureView
Expand Down Expand Up @@ -139,6 +139,64 @@ def filter(

return ModelContractWrapper(metadata=meta, contract=self.contract)

def as_source(self) -> BatchDataSource:
    """Expose this model contract as a batch data source.

    Requires the contract to also be compilable as a feature view.

    Raises:
        ValueError: If the contract can not be represented as a view.
    """
    from aligned.schemas.model import ModelSource

    model = self.compile()
    view = self.as_view()

    if view is None:
        raise ValueError(f"Model {model.name} is not compiled as a view")

    return ModelSource(model, view)

def join(
    self,
    view: FeatureViewWrapper,
    on_left: str | FeatureFactory | list[str] | list[FeatureFactory],
    on_right: str | FeatureFactory | list[str] | list[FeatureFactory],
    how: str = 'inner',
) -> BatchDataSource:
    """Join this model's output (as a source) with another feature view.

    Args:
        view: The feature view to join against.
        on_left: Join key(s) on this model's side.
        on_right: Join key(s) on the other view's side.
        how: Join strategy, defaults to ``'inner'``.

    Raises:
        ValueError: If the contract can not be represented as a view.
    """
    from aligned.data_source.batch_data_source import join_source
    from aligned.schemas.model import ModelSource

    model = self.compile()
    model_view = self.as_view()

    if model_view is None:
        raise ValueError(f"Model {model.name} is not compiled as a view")

    # The first needed request describes the columns this model exposes.
    left_request = model_view.request_all.needed_requests[0]

    return join_source(
        ModelSource(model, model_view),
        view=view,
        on_left=on_left,
        on_right=on_right,
        left_request=left_request,
        how=how,
    )

def join_asof(self, view: FeatureViewWrapper, on_left: list[str], on_right: list[str]) -> BatchDataSource:
    """As-of join this model's output (as a source) with another feature view.

    Args:
        view: The feature view to join against.
        on_left: Join key(s) on this model's side.
        on_right: Join key(s) on the other view's side.

    Raises:
        ValueError: If the contract can not be represented as a view.
    """
    from aligned.data_source.batch_data_source import join_asof_source
    from aligned.schemas.model import ModelSource

    model = self.compile()
    model_view = self.as_view()

    if model_view is None:
        raise ValueError(f"Model {model.name} is not compiled as a view")

    # The first needed request describes the columns this model exposes.
    left_request = model_view.request_all.needed_requests[0]

    return join_asof_source(
        ModelSource(model, model_view),
        view=view,
        left_on=on_left,
        right_on=on_right,
        left_request=left_request,
    )


def resolve_dataset_store(dataset_store: DatasetStore | StorageFileReference) -> DatasetStore:

Expand Down
16 changes: 16 additions & 0 deletions aligned/compiler/transformation_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,22 @@ def compile(self) -> Transformation:
return LTETransformation(self.in_feature.name, self.value)


@dataclass
class Split(TransformationFactory):
    """Factory producing a transformation that splits a string feature on a pattern."""

    # The pattern to split on. NOTE(review): whether this is treated as a
    # literal separator or a regex depends on the Split transformation's
    # implementation — confirm in aligned.schemas.transformation.
    pattern: str
    # The source string feature whose values will be split.
    from_feature: FeatureFactory

    @property
    def using_features(self) -> list[FeatureFactory]:
        # The compiled transformation depends only on the feature being split.
        return [self.from_feature]

    def compile(self) -> Transformation:
        # Local import, matching the pattern used by the sibling factories
        # in this module.
        from aligned.schemas.transformation import Split as SplitTransformation

        return SplitTransformation(self.from_feature.name, self.pattern)


# @dataclass
# class Split(TransformationFactory):

Expand Down
77 changes: 76 additions & 1 deletion aligned/data_source/batch_data_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Any
from typing import TYPE_CHECKING, TypeVar, Any, Callable, Coroutine
from dataclasses import dataclass

from mashumaro.types import SerializableType
Expand All @@ -16,6 +16,7 @@
if TYPE_CHECKING:
from aligned.retrival_job import RetrivalJob
from datetime import datetime
import polars as pl


class BatchDataSourceFactory:
Expand All @@ -30,6 +31,7 @@ def __init__(self) -> None:
from aligned.sources.redshift import RedshiftSQLDataSource
from aligned.sources.s3 import AwsS3CsvDataSource, AwsS3ParquetDataSource
from aligned.schemas.feature_view import FeatureViewReferenceSource
from aligned.schemas.model import ModelSource

source_types = [
PostgreSQLDataSource,
Expand All @@ -43,6 +45,8 @@ def __init__(self) -> None:
JoinAsofDataSource,
FilteredDataSource,
FeatureViewReferenceSource,
CustomMethodDataSource,
ModelSource,
]

self.supported_data_sources = {source.type_name: source for source in source_types}
Expand Down Expand Up @@ -278,6 +282,77 @@ def depends_on(self) -> set[FeatureLocation]:
return set()


@dataclass
class CustomMethodDataSource(BatchDataSource):
    """A batch data source backed by user-supplied (async) load callables.

    The callables are serialized with ``dill`` so the source definition can
    be stored and shipped like any other source. Use ``from_methods`` to
    build an instance from plain Python callables; any method left out will
    raise ``NotImplementedError`` when used.

    NOTE(security): ``dill.loads`` executes arbitrary code — only load
    source definitions from trusted storage.
    """

    # dill-serialized coroutine functions. Expected signatures are documented
    # on `from_methods`.
    all_data_method: bytes
    all_between_dates_method: bytes
    features_for_method: bytes

    type_name: str = 'custom_method'

    def job_group_key(self) -> str:
        # All custom-method sources share one job group key.
        return 'custom_method'

    def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
        from aligned.retrival_job import CustomLazyPolarsJob
        import dill

        # Deserialization is deferred to the lambda, so it only happens when
        # the job is actually executed.
        return CustomLazyPolarsJob(
            request=request, method=lambda: dill.loads(self.all_data_method)(request, limit)
        )

    def all_between_dates(
        self, request: RetrivalRequest, start_date: datetime, end_date: datetime
    ) -> RetrivalJob:
        from aligned.retrival_job import CustomLazyPolarsJob
        import dill

        return CustomLazyPolarsJob(
            request=request,
            method=lambda: dill.loads(self.all_between_dates_method)(request, start_date, end_date),
        )

    def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
        from aligned.retrival_job import CustomLazyPolarsJob
        import dill

        return CustomLazyPolarsJob(
            request=request, method=lambda: dill.loads(self.features_for_method)(facts, request)
        )

    @staticmethod
    def from_methods(
        all_data: Callable[[RetrivalRequest, int | None], Coroutine[None, None, pl.LazyFrame]] | None = None,
        all_between_dates: Callable[
            [RetrivalRequest, datetime, datetime], Coroutine[None, None, pl.LazyFrame]
        ]
        | None = None,
        features_for: Callable[[RetrivalJob, RetrivalRequest], Coroutine[None, None, pl.LazyFrame]]
        | None = None,
    ) -> 'CustomMethodDataSource':
        """Create a source from plain callables.

        Any method not provided defaults to ``default_throw``, which raises
        ``NotImplementedError`` if that capability is requested.
        """
        import dill

        # `or` replaces the previous if-not-assign blocks (and their
        # `type: ignore` comments) with the same behavior.
        return CustomMethodDataSource(
            all_data_method=dill.dumps(all_data or CustomMethodDataSource.default_throw),
            all_between_dates_method=dill.dumps(all_between_dates or CustomMethodDataSource.default_throw),
            features_for_method=dill.dumps(features_for or CustomMethodDataSource.default_throw),
        )

    @staticmethod
    def default_throw(*args: Any, **kwargs: Any) -> pl.LazyFrame:
        # Bug fix: the stored methods are invoked with *positional* arguments
        # (e.g. `(request, limit)`), so the previous `**kwargs`-only signature
        # raised a confusing TypeError instead of this explicit error.
        raise NotImplementedError('No method is defined for this data source.')


@dataclass
class FilteredDataSource(BatchDataSource):

Expand Down
25 changes: 25 additions & 0 deletions aligned/feature_view/tests/test_custom_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from aligned import feature_view, Int32, CustomMethodDataSource
from aligned.request.retrival_request import RetrivalRequest
import polars as pl


async def all_data(request: RetrivalRequest, limit: int | None = None) -> pl.LazyFrame:
    """Static custom-source method: always yields the same three rows."""
    frame = pl.DataFrame({'some_id': [1, 2, 3], 'feature': [2, 3, 4]})
    return frame.lazy()


# Feature view whose rows are produced by the custom `all_data` method above,
# rather than by a database or file source.
@feature_view(name='right', source=CustomMethodDataSource.from_methods(all_data=all_data))
class CustomSourceData:

    # Entity (join key) column for the view.
    some_id = Int32().as_entity()

    # Plain integer feature returned by the custom source.
    feature = Int32()


@pytest.mark.asyncio
async def test_custom_source() -> None:
    """The view should surface exactly the rows produced by `all_data`."""
    loaded = await CustomSourceData.query().all().to_polars()

    expected = pl.DataFrame({'some_id': [1, 2, 3], 'feature': [2, 3, 4]})
    # Reorder the expected columns to match the loaded frame before comparing.
    assert loaded.equals(expected.select(loaded.columns))
21 changes: 21 additions & 0 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,3 +2051,24 @@ async def to_lazy_polars(self) -> pl.LazyFrame:

def remove_derived_features(self) -> RetrivalJob:
return self.job.remove_derived_features()


@dataclass
class CustomLazyPolarsJob(RetrivalJob):
    """Retrieval job that defers loading to a user-supplied async callable."""

    # The request this job fulfils; also defines its result schema.
    request: RetrivalRequest
    # Zero-argument coroutine factory producing the lazy frame with the data.
    method: Callable[[], Coroutine[None, None, pl.LazyFrame]]

    @property
    def retrival_requests(self) -> list[RetrivalRequest]:
        # A custom job serves exactly the one request it was created with.
        return [self.request]

    @property
    def request_result(self) -> RequestResult:
        return RequestResult.from_request(self.request)

    async def to_lazy_polars(self) -> pl.LazyFrame:
        return await self.method()

    async def to_pandas(self) -> pd.DataFrame:
        # NOTE(review): `to_polars` is presumably provided by the RetrivalJob
        # base class (collecting the lazy frame) — confirm, since only
        # `to_lazy_polars` is defined here.
        return (await self.to_polars()).to_pandas()
Loading

0 comments on commit 388983f

Please sign in to comment.