Transform decorator (#29)
* Updated for UDFs

* Updated poetry version

* Removed print in test
MatsMoll authored Oct 12, 2024
1 parent c30258d commit b441458
Showing 18 changed files with 419 additions and 51 deletions.
39 changes: 38 additions & 1 deletion aligned/compiler/feature_factory.py
@@ -547,7 +547,7 @@ def transform_pandas(self, transformation: Callable[[pd.DataFrame], pd.Series],
     def transformed_using_features_polars(
         self: T,
         using_features: list[FeatureFactory],
-        transformation: Callable[[pl.LazyFrame, str], pl.LazyFrame],
+        transformation: Callable[[pl.LazyFrame, str], pl.LazyFrame] | pl.Expr,
     ) -> T:
         from aligned.compiler.transformation_factory import PolarsTransformationFactory

@@ -1841,3 +1841,40 @@ def percentile(self, percentile: float) -> Float:
             offset_interval=self.offset_interval,
         )
         return feat
+
+
+def transform_polars(
+    using_features: list[FeatureFactory], return_type: T
+) -> Callable[[Callable[[Any, pl.LazyFrame, str], pl.LazyFrame]], T]:
+    def wrapper(method: Callable[[Any, pl.LazyFrame, str], pl.LazyFrame]) -> T:
+        return return_type.transformed_using_features_polars(
+            using_features=using_features, transformation=method  # type: ignore
+        )
+
+    return wrapper
+
+
+def transform_pandas(
+    using_features: list[FeatureFactory], return_type: T
+) -> Callable[[Callable[[Any, pd.DataFrame], pd.Series]], T]:
+    def wrapper(method: Callable[[Any, pd.DataFrame], pd.Series]) -> T:
+        return return_type.transformed_using_features_pandas(
+            using_features=using_features, transformation=method  # type: ignore
+        )
+
+    return wrapper
+
+
+def transform_row(
+    using_features: list[FeatureFactory], return_type: T
+) -> Callable[[Callable[[Any, dict[str, Any]], Any]], T]:
+    def wrapper(method: Callable[[Any, dict[str, Any]], Any]) -> T:
+        from aligned.compiler.transformation_factory import MapRowTransformation
+
+        new_value = return_type.copy_type()
+        new_value.transformation = MapRowTransformation(
+            dtype=new_value, method=method, _using_features=using_features  # type: ignore
+        )
+        return new_value
+
+    return wrapper
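
Note: the three module-level helpers above are the transform decorators this commit is named for. A minimal usage sketch, not code from this commit; the view, the feature names, and the Float import location are illustrative assumptions:

from typing import Any

import polars as pl

from aligned import feature_view, Int32, String
from aligned.compiler.feature_factory import Float, transform_polars, transform_row
from aligned.sources.in_mem_source import InMemorySource


@feature_view(
    name='example',
    source=InMemorySource.from_values({'entity_id': ['a', 'b'], 'x': [1, 2], 'y': [3, 4]}),
)
class Example:
    entity_id = String().as_entity()
    x = Int32()
    y = Int32()

    # Compiled into a polars transformation over the whole frame.
    @transform_polars(using_features=[x, y], return_type=Float())
    def ratio(self, df: pl.LazyFrame, alias: str) -> pl.LazyFrame:
        return df.with_columns((pl.col('x') / pl.col('y')).alias(alias))

    # Compiled into a per-row mapping over a plain dict.
    @transform_row(using_features=[x, y], return_type=Float())
    def total(self, row: dict[str, Any]) -> float:
        return row['x'] + row['y']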
5 changes: 2 additions & 3 deletions aligned/compiler/model.py
@@ -584,10 +584,9 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
     if not probability_features and inference_view.classification_targets:
         inference_view.features.update({target.feature for target in inference_view.classification_targets})

-    schema_hash = inference_view.schema_hash()
-
+    view = inference_view.as_view(metadata.name)
     if inference_view.source:
-        inference_view.source = inference_view.source.with_schema_version(schema_hash)
+        inference_view.source = inference_view.source.with_view(view)

     return ModelSchema(
         name=metadata.name,
102 changes: 97 additions & 5 deletions aligned/compiler/transformation_factory.py
@@ -629,6 +629,56 @@ def compile(self) -> Transformation:
         return Absolute(self.feature.name)


+@dataclass
+class MapRowTransformation(TransformationFactory):
+
+    dtype: FeatureFactory
+    method: Callable[[dict], Any]
+    _using_features: list[FeatureFactory]
+
+    @property
+    def using_features(self) -> list[FeatureFactory]:
+        return self._using_features
+
+    def compile(self) -> Transformation:
+        import inspect
+        import types
+        import dill
+        from aligned.schemas.transformation import PolarsMapRowTransformation
+
+        if isinstance(self.method, types.LambdaType) and self.method.__name__ == '<lambda>':
+            raise NotImplementedError(type(self))
+
+        function_name = dill.source.getname(self.method)
+        assert isinstance(function_name, str), 'Need a function name'
+        raw_code = inspect.getsource(self.method)
+
+        code = ''
+
+        indents: int | None = None
+        start_signature = f"def {function_name}"
+
+        for line in raw_code.splitlines(keepends=True):
+
+            if indents:
+                if len(line) > indents:
+                    code += line[:indents].lstrip() + line[indents:]
+                else:
+                    code += line
+
+            if start_signature in line:
+                stripped = line.lstrip()
+                indents = len(line) - len(stripped)
+                stripped = stripped.replace(f"{start_signature}(self,", f"{start_signature}(")
+                code += stripped
+
+        return PolarsMapRowTransformation(
+            code=code,
+            function_name=function_name,
+            dtype=self.dtype.dtype,
+        )
+
+
 @dataclass
 class PandasTransformationFactory(TransformationFactory):

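Note: the source-dedent loop in compile above (repeated in the pandas and polars factories below) re-indents the decorated method's source to column zero and drops the leading self parameter, so the stored code can later be executed as a free function. A standalone sketch of the same idea, assuming a plain single-def method source:

import inspect
from typing import Callable


def dedent_method_source(method: Callable, function_name: str) -> str:
    """Mirror of the loop above: strip class-level indentation and self."""
    raw_code = inspect.getsource(method)
    code = ''
    indents: int | None = None
    start_signature = f"def {function_name}"

    for line in raw_code.splitlines(keepends=True):
        if indents:
            # Inside the function body: cut the class-level indent prefix.
            code += line[indents:] if len(line) > indents else line
        if start_signature in line:
            stripped = line.lstrip()
            indents = len(line) - len(stripped)
            # Turn def f(self, ...) into def f(...).
            code += stripped.replace(f"{start_signature}(self,", f"{start_signature}(")
    return code


class Example:
    def doubled(self, row: dict) -> int:
        return row['x'] * 2


print(dedent_method_source(Example.doubled, 'doubled'))
# def doubled( row: dict) -> int:
#     return row['x'] * 2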
@@ -655,10 +705,31 @@ def compile(self) -> Transformation:
                 dtype=self.dtype.dtype,
             )
         else:
-            function_name = (dill.source.getname(self.method),)
-            assert isinstance(function_name, str), f"Expected string got {type(function_name)}"
+            function_name = dill.source.getname(self.method)
+            assert isinstance(function_name, str), 'Need a function name'
+            raw_code = inspect.getsource(self.method)
+
+            code = ''
+
+            indents: int | None = None
+            start_signature = f"def {function_name}"
+
+            for line in raw_code.splitlines(keepends=True):
+
+                if indents:
+                    if len(line) > indents:
+                        code += line[:indents].lstrip() + line[indents:]
+                    else:
+                        code += line
+
+                if start_signature in line:
+                    stripped = line.lstrip()
+                    indents = len(line) - len(stripped)
+                    stripped = stripped.replace(f"{start_signature}(self,", f"{start_signature}(")
+                    code += stripped
+
             return PandasFunctionTransformation(
-                code=inspect.getsource(self.method),
+                code=code,
                 function_name=function_name,
                 dtype=self.dtype.dtype,
             )
@@ -690,14 +761,35 @@ def method(df: pl.DataFrame, alias: str) -> pl.Expr:

             return PolarsLambdaTransformation(method=dill.dumps(method), code='', dtype=self.dtype.dtype)
         else:
-            code = inspect.getsource(self.method)
+            function_name = dill.source.getname(self.method)
+            assert isinstance(function_name, str), 'Need a function name'
+            raw_code = inspect.getsource(self.method)
+
+            code = ''
+
+            indents: int | None = None
+            start_signature = f"def {function_name}"
+
+            for line in raw_code.splitlines(keepends=True):
+
+                if indents:
+                    if len(line) > indents:
+                        code += line[:indents].lstrip() + line[indents:]
+                    else:
+                        code += line
+
+                if start_signature in line:
+                    stripped = line.lstrip()
+                    indents = len(line) - len(stripped)
+                    stripped = stripped.replace(f"{start_signature}(self,", f"{start_signature}(")
+                    code += stripped

             if isinstance(self.method, types.LambdaType) and self.method.__name__ == '<lambda>':
                 return PolarsLambdaTransformation(
                     method=dill.dumps(self.method), code=code.strip(), dtype=self.dtype.dtype
                 )
             else:
-                function_name = (dill.source.getname(self.method),)
+                function_name = dill.source.getname(self.method)
                 assert isinstance(function_name, str), f"Expected string got {type(function_name)}"
                 return PolarsFunctionTransformation(
                     code=code,
10 changes: 7 additions & 3 deletions aligned/data_source/batch_data_source.py
@@ -18,10 +18,12 @@

 import logging

+
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from aligned.retrival_job import RetrivalJob
+    from aligned.schemas.feature_view import CompiledFeatureView

 T = TypeVar('T')

@@ -113,7 +115,7 @@ def source_id(self) -> str:
         """
         return self.job_group_key()

-    def with_schema_version(self: T, schema_hash: bytes) -> T:
+    def with_view(self: T, view: CompiledFeatureView) -> T:
         return self

     def __hash__(self) -> int:
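
Note: with_view replaces the narrower with_schema_version hook. The default is a no-op, and an overriding source now receives the whole compiled view instead of only a schema hash. A hypothetical override; the subclass, its fields, and the view.name attribute are assumptions, not code from this repo:

from dataclasses import dataclass, replace


@dataclass
class VersionedDirSource:
    # Hypothetical source; in the real module this would subclass the
    # source base class that defines the no-op with_view above.
    directory: str
    view_name: str = ''

    def with_view(self, view) -> 'VersionedDirSource':
        # Version the stored data per compiled view instead of per schema hash.
        # Assumes the compiled view exposes a name attribute.
        return replace(self, view_name=view.name)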
@@ -1051,9 +1053,11 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
         )

     def all_between_dates(
-        self, request: RetrivalRequest, start_date: datetime, end_date: datetime
+        self,
+        request: RetrivalRequest,
+        start_date: datetime,
+        end_date: datetime,
     ) -> RetrivalJob:
-
         right_job = self.right_source.all_data(self.right_request, limit=None).derive_features(
             [self.right_request]
         )
18 changes: 14 additions & 4 deletions aligned/data_source/model_predictor.py
@@ -39,13 +39,18 @@ def all_data(self, request: RetrivalRequest, limit: int | None = None) -> RetrivalJob:
                 f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
             )

-        location = reqs.needed_requests[0].location
+        req = reqs.needed_requests[0]
+        location = req.location
         if location.location_type != 'feature_view':
             raise NotImplementedError(
                 f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
             )

-        entities = self.store.store.feature_view(location.name).all_columns(limit=limit)
+        entities = (
+            self.store.store.feature_view(location.name)
+            .select(req.features_to_include)
+            .all_columns(limit=limit)
+        )
         return self.store.predict_over(entities).with_request([request])

def all_between_dates(
Expand All @@ -57,13 +62,18 @@ def all_between_dates(
f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
)

location = reqs.needed_requests[0].location
req = reqs.needed_requests[0]
location = req.location
if location.location_type != 'feature_view':
raise NotImplementedError(
f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
)

entities = self.store.store.feature_view(location.name).between_dates(start_date, end_date)
entities = (
self.store.store.feature_view(location.name)
.select(req.features_to_include)
.between_dates(start_date, end_date)
)
return self.store.predict_over(entities).with_request([request])

def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
Expand Down
7 changes: 6 additions & 1 deletion aligned/exposed_model/interface.py
@@ -373,7 +373,7 @@ def ab_test_model(models: list[tuple[ExposedModel, float]]) -> ABTestModel:


 @dataclass
-class DillFunction(ExposedModel):
+class DillFunction(ExposedModel, VersionedModel):

     function: bytes

@@ -387,6 +387,11 @@ def exposed_at_url(self) -> str | None:
     def as_markdown(self) -> str:
         return 'A function stored in a dill file.'

+    async def model_version(self) -> str:
+        from hashlib import md5
+
+        return md5(self.function, usedforsecurity=False).hexdigest()
+
     async def needed_features(self, store: ModelFeatureStore) -> list[FeatureReference]:
         default = store.model.features.default_version
         return store.feature_references_for(store.selected_version or default)
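
Note: since a DillFunction's behaviour is fully determined by its serialized function bytes, hashing those bytes yields a stable version identifier. A quick standalone sketch of the idea, not code from this commit:

from hashlib import md5

import dill


def double(df):
    return df['x'] * 2


function_bytes = dill.dumps(double)
# Identical bytes always hash to the same version; editing the
# function body yields new bytes and therefore a new version.
print(md5(function_bytes, usedforsecurity=False).hexdigest())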
90 changes: 90 additions & 0 deletions aligned/exposed_model/tests/test_model.py
@@ -1,6 +1,8 @@
 from contextlib import suppress
 import pytest
 from aligned import ExposedModel, model_contract, String, Int32, EventTimestamp, feature_view, FileSource
+from aligned.feature_store import ContractStore
+from aligned.sources.in_mem_source import InMemorySource
 from aligned.sources.random_source import RandomDataSource
 from aligned.exposed_model.interface import python_function

@@ -136,3 +138,91 @@ class MyModelContract2:
     ).to_polars()

     assert 'other_pred' in preds
+
+
+@pytest.mark.asyncio
+async def test_pipeline_model() -> None:
+    @feature_view(
+        name='input',
+        source=InMemorySource.from_values(
+            {'entity_id': ['a', 'b', 'c'], 'x': [1, 2, 3], 'other': [9, 8, 7]}  # type: ignore
+        ),
+    )
+    class InputFeatureView:
+        entity_id = String().as_entity()
+        x = Int32()
+        other = Int32()
+
+    input = InputFeatureView()
+
+    @model_contract(
+        input_features=[InputFeatureView().x],
+        exposed_model=python_function(lambda df: df['x'] * 2),
+        output_source=InMemorySource.from_values(
+            {'entity_id': ['a', 'b'], 'prediction': [2, 4]}  # type: ignore
+        ),
+    )
+    class MyModelContract:
+        entity_id = String().as_entity()
+
+        prediction = input.x.as_regression_target()
+
+        model_version = String().as_model_version()
+
+    @model_contract(
+        input_features=[InputFeatureView().x, MyModelContract().prediction],
+        exposed_model=python_function(lambda df: df['prediction'] * 3 + df['x']),
+    )
+    class MyModelContract2:
+        entity_id = String().as_entity()
+
+        other_pred = input.other.as_regression_target()
+
+        model_version = String().as_model_version()
+
+    store = ContractStore.empty()
+    store.add_view(InputFeatureView)
+    store.add_model(MyModelContract)
+    store.add_model(MyModelContract2)
+
+    without_cache = store.without_model_cache()
+
+    first_preds = await store.model(MyModelContract).predict_over({'entity_id': ['a', 'c']}).to_polars()
+    assert first_preds['prediction'].null_count() == 0
+
+    preds = (
+        await store.model(MyModelContract2)
+        .predict_over(
+            {
+                'entity_id': ['a', 'c'],
+            }
+        )
+        .to_polars()
+    )
+    assert preds['other_pred'].null_count() == 1
+    assert not first_preds['model_version'].series_equal(preds['model_version'])
+
+    preds = (
+        await store.model(MyModelContract2)
+        .predict_over(
+            {
+                'entity_id': ['a', 'c'],
+                'prediction': [2, 6],
+            }
+        )
+        .to_polars()
+    )
+    assert preds['other_pred'].null_count() == 0
+    assert not first_preds['model_version'].series_equal(preds['model_version'])
+
+    preds = (
+        await without_cache.model(MyModelContract2)
+        .predict_over(
+            {
+                'entity_id': ['a', 'c'],
+            }
+        )
+        .to_polars()
+    )
+    assert preds['other_pred'].null_count() == 0
+    assert not first_preds['model_version'].series_equal(preds['model_version'])
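
To run just the new test locally, assuming dev dependencies are installed through poetry as this repo uses:

poetry run pytest aligned/exposed_model/tests/test_model.py -k test_pipeline_model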