diff --git a/aligned/compiler/feature_factory.py b/aligned/compiler/feature_factory.py
index dc83aff6..43658ce1 100644
--- a/aligned/compiler/feature_factory.py
+++ b/aligned/compiler/feature_factory.py
@@ -547,7 +547,7 @@ def transform_pandas(self, transformation: Callable[[pd.DataFrame], pd.Series],
     def transformed_using_features_polars(
         self: T,
         using_features: list[FeatureFactory],
-        transformation: Callable[[pl.LazyFrame, str], pl.LazyFrame],
+        transformation: Callable[[pl.LazyFrame, str], pl.LazyFrame] | pl.Expr,
     ) -> T:
         from aligned.compiler.transformation_factory import PolarsTransformationFactory
 
@@ -1841,3 +1841,40 @@ def percentile(self, percentile: float) -> Float:
             offset_interval=self.offset_interval,
         )
         return feat
+
+
+def transform_polars(
+    using_features: list[FeatureFactory], return_type: T
+) -> Callable[[Callable[[Any, pl.LazyFrame, str], pl.LazyFrame]], T]:
+    def wrapper(method: Callable[[Any, pl.LazyFrame, str], pl.LazyFrame]) -> T:
+        return return_type.transformed_using_features_polars(
+            using_features=using_features, transformation=method  # type: ignore
+        )
+
+    return wrapper
+
+
+def transform_pandas(
+    using_features: list[FeatureFactory], return_type: T
+) -> Callable[[Callable[[Any, pd.DataFrame], pd.Series]], T]:
+    def wrapper(method: Callable[[Any, pd.DataFrame], pd.Series]) -> T:
+        return return_type.transformed_using_features_pandas(
+            using_features=using_features, transformation=method  # type: ignore
+        )
+
+    return wrapper
+
+
+def transform_row(
+    using_features: list[FeatureFactory], return_type: T
+) -> Callable[[Callable[[Any, dict[str, Any]], Any]], T]:
+    def wrapper(method: Callable[[Any, dict[str, Any]], Any]) -> T:
+        from aligned.compiler.transformation_factory import MapRowTransformation
+
+        new_value = return_type.copy_type()
+        new_value.transformation = MapRowTransformation(
+            dtype=new_value, method=method, _using_features=using_features  # type: ignore
+        )
+        return new_value
+
+    return wrapper
diff --git a/aligned/compiler/model.py b/aligned/compiler/model.py
index 3360373a..1eca053f 100644
--- a/aligned/compiler/model.py
+++ b/aligned/compiler/model.py
@@ -584,10 +584,9 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
     if not probability_features and inference_view.classification_targets:
         inference_view.features.update({target.feature for target in inference_view.classification_targets})
 
-    schema_hash = inference_view.schema_hash()
-
+    view = inference_view.as_view(metadata.name)
     if inference_view.source:
-        inference_view.source = inference_view.source.with_schema_version(schema_hash)
+        inference_view.source = inference_view.source.with_view(view)
 
     return ModelSchema(
         name=metadata.name,
diff --git a/aligned/compiler/transformation_factory.py b/aligned/compiler/transformation_factory.py
index 6cba0aed..2d8b890c 100644
--- a/aligned/compiler/transformation_factory.py
+++ b/aligned/compiler/transformation_factory.py
@@ -629,6 +629,56 @@ def compile(self) -> Transformation:
         return Absolute(self.feature.name)
 
 
+@dataclass
+class MapRowTransformation(TransformationFactory):
+
+    dtype: FeatureFactory
+    method: Callable[[dict], Any]
+    _using_features: list[FeatureFactory]
+
+    @property
+    def using_features(self) -> list[FeatureFactory]:
+        return self._using_features
+
+    def compile(self) -> Transformation:
+        import inspect
+        import types
+        import dill
+        from aligned.schemas.transformation import PolarsMapRowTransformation
+
+        if isinstance(self.method, types.LambdaType) and self.method.__name__ == '<lambda>':
+            raise NotImplementedError(type(self))
+
+        function_name = dill.source.getname(self.method)
+        assert isinstance(function_name, str), 'Need a function name'
+        raw_code = inspect.getsource(self.method)
+
+        code = ''
+
+        indents: int | None = None
+        start_signature = f"def {function_name}"
+
+        for line in raw_code.splitlines(keepends=True):
+
+            if indents:
+                if len(line) > indents:
+                    code += line[:indents].lstrip() + line[indents:]
+                else:
+                    code += line
+
+            if start_signature in line:
+                stripped = line.lstrip()
+                indents = len(line) - len(stripped)
+                stripped = stripped.replace(f"{start_signature}(self,", f"{start_signature}(")
+                code += stripped
+
+        return PolarsMapRowTransformation(
+            code=code,
+            function_name=function_name,
+            dtype=self.dtype.dtype,
+        )
+
+
 @dataclass
 class PandasTransformationFactory(TransformationFactory):
 
@@ -655,10 +705,31 @@ def compile(self) -> Transformation:
                 dtype=self.dtype.dtype,
             )
         else:
-            function_name = (dill.source.getname(self.method),)
-            assert isinstance(function_name, str), f"Expected string got {type(function_name)}"
+            function_name = dill.source.getname(self.method)
+            assert isinstance(function_name, str), 'Need a function name'
+            raw_code = inspect.getsource(self.method)
+
+            code = ''
+
+            indents: int | None = None
+            start_signature = f"def {function_name}"
+
+            for line in raw_code.splitlines(keepends=True):
+
+                if indents:
+                    if len(line) > indents:
+                        code += line[:indents].lstrip() + line[indents:]
+                    else:
+                        code += line
+
+                if start_signature in line:
+                    stripped = line.lstrip()
+                    indents = len(line) - len(stripped)
+                    stripped = stripped.replace(f"{start_signature}(self,", f"{start_signature}(")
+                    code += stripped
+
             return PandasFunctionTransformation(
-                code=inspect.getsource(self.method),
+                code=code,
                 function_name=function_name,
                 dtype=self.dtype.dtype,
             )
@@ -690,14 +761,35 @@ def method(df: pl.DataFrame, alias: str) -> pl.Expr:
 
             return PolarsLambdaTransformation(method=dill.dumps(method), code='', dtype=self.dtype.dtype)
         else:
-            code = inspect.getsource(self.method)
+            function_name = dill.source.getname(self.method)
+            assert isinstance(function_name, str), 'Need a function name'
+            raw_code = inspect.getsource(self.method)
+
+            code = ''
+
+            indents: int | None = None
+            start_signature = f"def {function_name}"
+
+            for line in raw_code.splitlines(keepends=True):
+
+                if indents:
+                    if len(line) > indents:
+                        code += line[:indents].lstrip() + line[indents:]
+                    else:
+                        code += line
+
+                if start_signature in line:
+                    stripped = line.lstrip()
+                    indents = len(line) - len(stripped)
+                    stripped = stripped.replace(f"{start_signature}(self,", f"{start_signature}(")
+                    code += stripped
 
             if isinstance(self.method, types.LambdaType) and self.method.__name__ == '<lambda>':
                 return PolarsLambdaTransformation(
                     method=dill.dumps(self.method), code=code.strip(), dtype=self.dtype.dtype
                 )
             else:
-                function_name = (dill.source.getname(self.method),)
+                function_name = dill.source.getname(self.method)
                 assert isinstance(function_name, str), f"Expected string got {type(function_name)}"
                 return PolarsFunctionTransformation(
                     code=code,
diff --git a/aligned/data_source/batch_data_source.py b/aligned/data_source/batch_data_source.py
index 33595c80..26de2452 100644
--- a/aligned/data_source/batch_data_source.py
+++ b/aligned/data_source/batch_data_source.py
@@ -18,10 +18,12 @@
 import logging
 
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from aligned.retrival_job import RetrivalJob
+    from aligned.schemas.feature_view import CompiledFeatureView
 
 T = TypeVar('T')
 
@@ -113,7 +115,7 @@ def source_id(self) -> str:
         """
         return self.job_group_key()
 
-    def with_schema_version(self: T, schema_hash: bytes) -> T:
+    def with_view(self: T, view: CompiledFeatureView) -> T:
         return self
 
     def __hash__(self) -> int:
@@ -1051,9 +1053,11 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
         )
 
     def all_between_dates(
-        self, request: RetrivalRequest, start_date: datetime, end_date: datetime
+        self,
+        request: RetrivalRequest,
+        start_date: datetime,
+        end_date: datetime,
     ) -> RetrivalJob:
-
         right_job = self.right_source.all_data(self.right_request, limit=None).derive_features(
             [self.right_request]
         )
diff --git a/aligned/data_source/model_predictor.py b/aligned/data_source/model_predictor.py
index 8eec8afe..6e876417 100644
--- a/aligned/data_source/model_predictor.py
+++ b/aligned/data_source/model_predictor.py
@@ -39,13 +39,18 @@ def all_data(self, request: RetrivalRequest, limit: int | None = None) -> Retriv
                 f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
             )
 
-        location = reqs.needed_requests[0].location
+        req = reqs.needed_requests[0]
+        location = req.location
         if location.location_type != 'feature_view':
             raise NotImplementedError(
                 f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
             )
 
-        entities = self.store.store.feature_view(location.name).all_columns(limit=limit)
+        entities = (
+            self.store.store.feature_view(location.name)
+            .select(req.features_to_include)
+            .all_columns(limit=limit)
+        )
         return self.store.predict_over(entities).with_request([request])
 
     def all_between_dates(
@@ -57,13 +62,18 @@ def all_between_dates(
                 f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
             )
 
-        location = reqs.needed_requests[0].location
+        req = reqs.needed_requests[0]
+        location = req.location
         if location.location_type != 'feature_view':
             raise NotImplementedError(
                 f'Type: {type(self)} have not implemented how to load fact data with multiple sources.'
             )
 
-        entities = self.store.store.feature_view(location.name).between_dates(start_date, end_date)
+        entities = (
+            self.store.store.feature_view(location.name)
+            .select(req.features_to_include)
+            .between_dates(start_date, end_date)
+        )
         return self.store.predict_over(entities).with_request([request])
 
     def features_for(self, facts: RetrivalJob, request: RetrivalRequest) -> RetrivalJob:
diff --git a/aligned/exposed_model/interface.py b/aligned/exposed_model/interface.py
index 49e8f3c4..034866bb 100644
--- a/aligned/exposed_model/interface.py
+++ b/aligned/exposed_model/interface.py
@@ -373,7 +373,7 @@ def ab_test_model(models: list[tuple[ExposedModel, float]]) -> ABTestModel:
 
 
 @dataclass
-class DillFunction(ExposedModel):
+class DillFunction(ExposedModel, VersionedModel):
 
     function: bytes
 
@@ -387,6 +387,11 @@ def exposed_at_url(self) -> str | None:
     def as_markdown(self) -> str:
         return 'A function stored in a dill file.'
 
+    async def model_version(self) -> str:
+        from hashlib import md5
+
+        return md5(self.function, usedforsecurity=False).hexdigest()
+
     async def needed_features(self, store: ModelFeatureStore) -> list[FeatureReference]:
         default = store.model.features.default_version
         return store.feature_references_for(store.selected_version or default)
diff --git a/aligned/exposed_model/tests/test_model.py b/aligned/exposed_model/tests/test_model.py
index f559b5a9..03c1a8cb 100644
--- a/aligned/exposed_model/tests/test_model.py
+++ b/aligned/exposed_model/tests/test_model.py
@@ -1,6 +1,8 @@
 from contextlib import suppress
 import pytest
 from aligned import ExposedModel, model_contract, String, Int32, EventTimestamp, feature_view, FileSource
+from aligned.feature_store import ContractStore
+from aligned.sources.in_mem_source import InMemorySource
 from aligned.sources.random_source import RandomDataSource
 from aligned.exposed_model.interface import python_function
 
@@ -136,3 +138,91 @@ class MyModelContract2:
     ).to_polars()
 
     assert 'other_pred' in preds
+
+
+@pytest.mark.asyncio
+async def test_pipeline_model() -> None:
+    @feature_view(
+        name='input',
+        source=InMemorySource.from_values(
+            {'entity_id': ['a', 'b', 'c'], 'x': [1, 2, 3], 'other': [9, 8, 7]}  # type: ignore
+        ),
+    )
+    class InputFeatureView:
+        entity_id = String().as_entity()
+        x = Int32()
+        other = Int32()
+
+    input = InputFeatureView()
+
+    @model_contract(
+        input_features=[InputFeatureView().x],
+        exposed_model=python_function(lambda df: df['x'] * 2),
+        output_source=InMemorySource.from_values(
+            {'entity_id': ['a', 'b'], 'prediction': [2, 4]}  # type: ignore
+        ),
+    )
+    class MyModelContract:
+        entity_id = String().as_entity()
+
+        prediction = input.x.as_regression_target()
+
+        model_version = String().as_model_version()
+
+    @model_contract(
+        input_features=[InputFeatureView().x, MyModelContract().prediction],
+        exposed_model=python_function(lambda df: df['prediction'] * 3 + df['x']),
+    )
+    class MyModelContract2:
+        entity_id = String().as_entity()
+
+        other_pred = input.other.as_regression_target()
+
+        model_version = String().as_model_version()
+
+    store = ContractStore.empty()
+    store.add_view(InputFeatureView)
+    store.add_model(MyModelContract)
+    store.add_model(MyModelContract2)
+
+    without_cache = store.without_model_cache()
+
+    first_preds = await store.model(MyModelContract).predict_over({'entity_id': ['a', 'c']}).to_polars()
+    assert first_preds['prediction'].null_count() == 0
+
+    preds = (
+        await store.model(MyModelContract2)
+        .predict_over(
+            {
+                'entity_id': ['a', 'c'],
+            }
+        )
+        .to_polars()
+    )
+    assert preds['other_pred'].null_count() == 1
+    assert not first_preds['model_version'].series_equal(preds['model_version'])
+
+    preds = (
+        await store.model(MyModelContract2)
+        .predict_over(
+            {
+                'entity_id': ['a', 'c'],
+                'prediction': [2, 6],
+            }
+        )
+        .to_polars()
+    )
+    assert preds['other_pred'].null_count() == 0
+    assert not first_preds['model_version'].series_equal(preds['model_version'])
+
+    preds = (
+        await without_cache.model(MyModelContract2)
+        .predict_over(
+            {
+                'entity_id': ['a', 'c'],
+            }
+        )
+        .to_polars()
+    )
+    assert preds['other_pred'].null_count() == 0
+    assert not first_preds['model_version'].series_equal(preds['model_version'])
diff --git a/aligned/feature_source.py b/aligned/feature_source.py
index b81c4608..978a05c5 100644
--- a/aligned/feature_source.py
+++ b/aligned/feature_source.py
@@ -76,31 +76,57 @@ def features_for(self, facts: RetrivalJob, request: FeatureRequest) -> RetrivalJ
             if request.location.identifier in self.sources
         }
 
+        loaded_columns = set(facts.loaded_columns)
+
+        def needs_to_load_source(requests: list[RetrivalRequest]) -> bool:
+            for req in requests:
+                if set(req.feature_names) - loaded_columns:
+                    return True
+
+                for feat in req.derived_features:
+                    if (
+                        set(
+                            depends_on.name
+                            for depends_on in feat.depending_on
+                            if depends_on.location != req.location
+                        )
+                        - loaded_columns
+                    ):
+                        return True
+
+                for feat in req.aggregated_features:
+                    if set(feat.depending_on_names) - loaded_columns:
+                        return True
+            return False
+
         # The combined views basicly, as they have no direct
         combined_requests = [
             request for request in request.needed_requests if request.location.identifier not in self.sources
         ]
 
         jobs = []
         for source_group in source_groupes:
-            requests = [
+            requests_with_source = [
                 (source, req) for source, req in core_requests if source.job_group_key() == source_group
             ]
-            has_derived_features = any(req.derived_features for _, req in requests)
-            job = (
-                self.source_types[source_group]
-                .multi_source_features_for(facts=facts, requests=requests)
-                .ensure_types([req for _, req in requests])
-            )
-            if has_derived_features:
-                job = job.derive_features()
+            requests = [req for _, req in requests_with_source]
+
+            if needs_to_load_source(requests):
+                job = (
+                    self.source_types[source_group]
+                    .multi_source_features_for(facts=facts, requests=requests_with_source)
+                    .ensure_types(requests)
+                    .derive_features()
+                )
+            else:
+                job = facts.derive_features(requests)
 
-            if len(requests) == 1 and requests[0][1].aggregated_features:
-                req = requests[0][1]
+            if len(requests) == 1 and requests_with_source[0][1].aggregated_features:
+                req = requests_with_source[0][1]
                 job = job.aggregate(req)
 
             jobs.append(job)
 
-        fact_features = set(facts.loaded_columns) - set(request.request_result.entity_columns)
+        fact_features = loaded_columns - set(request.request_result.entity_columns)
         if fact_features:
             jobs.append(facts)
diff --git a/aligned/feature_store.py b/aligned/feature_store.py
index 6cfdcc61..78b56c45 100644
--- a/aligned/feature_store.py
+++ b/aligned/feature_store.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from copy import copy
 
 import polars as pl
 from aligned.lazy_imports import pandas as pd
@@ -199,6 +200,21 @@ def dummy_store(self) -> ContractStore:
             self.feature_views, self.models, BatchFeatureSource(sources), self.vector_indexes
         )
 
+    def without_model_cache(self) -> ContractStore:
+        from aligned.data_source.model_predictor import PredictModelSource
+
+        new_store = self
+
+        for model_name, model in self.models.items():
+            if not model.exposed_model:
+                continue
+
+            new_store = new_store.update_source_for(
+                FeatureLocation.model(model_name), PredictModelSource(new_store.model(model_name))
+            )
+
+        return new_store
+
     def repo_definition(self) -> RepoDefinition:
         return RepoDefinition(
             metadata=RepoMetadata(datetime.utcnow(), name='feature_store_location.py'),
@@ -757,7 +773,7 @@ def update_source_for(self, location: FeatureLocation | str, source: BatchDataSo
         if isinstance(location, str):
             location = FeatureLocation.from_string(location)
 
-        new_source = self.feature_source
+        new_source = BatchFeatureSource(copy(self.feature_source.sources))
         assert isinstance(new_source.sources, dict)
         new_source.sources[location.identifier] = source
diff --git a/aligned/feature_view/feature_view.py b/aligned/feature_view/feature_view.py
index 192f8664..5e33b8da 100644
--- a/aligned/feature_view/feature_view.py
+++ b/aligned/feature_view/feature_view.py
@@ -722,9 +722,6 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
         else:
             view.features.add(compiled_feature)
 
-    if not view.entities:
-        raise ValueError(f'FeatureView {metadata.name} must contain at least one Entity')
-
     loc = FeatureLocation.feature_view(view.name)
     aggregation_group_by = [FeatureReference(entity.name, loc, entity.dtype) for entity in view.entities]
     event_timestamp_ref = (
@@ -746,11 +743,10 @@ def sort_key(x: tuple[int, FeatureFactory]) -> int:
         )
         view.aggregated_features.add(feat)
 
-    schema_hash = view.schema_hash()
-    view.source = view.source.with_schema_version(schema_hash)
+    view.source = view.source.with_view(view)
 
     if view.materialized_source:
-        view.materialized_source = view.materialized_source.with_schema_version(schema_hash)
+        view.materialized_source = view.materialized_source.with_view(view)
 
     return view
diff --git a/aligned/jobs/tests/test_combined_job.py b/aligned/jobs/tests/test_combined_job.py
index 3d70b371..ba194223 100644
--- a/aligned/jobs/tests/test_combined_job.py
+++ b/aligned/jobs/tests/test_combined_job.py
@@ -1,6 +1,41 @@
 import pytest
 
+from aligned import feature_view, String, Bool
+from aligned.sources.in_mem_source import InMemorySource
 from aligned.retrival_job import CombineFactualJob, RetrivalJob, RetrivalRequest
+from aligned.compiler.feature_factory import transform_polars, transform_pandas, transform_row
+
+import polars as pl
+from aligned.lazy_imports import pandas as pd
+
+
+@feature_view(source=InMemorySource.empty())
+class CombinedData:
+    query = String()
+    contains_mr = query.contains('mr')
+
+    @transform_polars(using_features=[query], return_type=Bool())
+    def contains_something(self, df: pl.LazyFrame, return_value: str) -> pl.LazyFrame:
+        return df.with_columns((pl.col('query').str.len_chars() > 5).alias(return_value))
+
+    @transform_pandas(using_features=[query], return_type=String())
+    def append_something(self, df: pd.DataFrame) -> pd.Series:
+        return df['query'] + ' something'
+
+    @transform_row(using_features=[query], return_type=String())
+    def using_row(self, row: dict) -> str:
+        return row['query'] + ' something'
+
+    not_contains = contains_something.not_equals(True)
+
+
+@pytest.mark.asyncio
+async def test_feature_view_without_entity():
+
+    job = CombinedData.query().features_for({'query': ['Hello', 'Hello mr']})
+    df = await job.to_polars()
+
+    assert df['contains_mr'].sum() == 1
 
 
 @pytest.mark.asyncio
diff --git a/aligned/retrival_job.py b/aligned/retrival_job.py
index 4eaeb473..86b2794d 100644
--- a/aligned/retrival_job.py
+++ b/aligned/retrival_job.py
@@ -762,7 +762,7 @@ def derive_features(self, requests: list[RetrivalRequest] | None = None) -> Retr
         requests = requests or self.retrival_requests
 
         for request in requests:
-            if len(request.derived_features) > 0:
+            if request.derived_features:
                 return DerivedFeatureJob(job=self, requests=requests)
         return self
diff --git a/aligned/schemas/model.py b/aligned/schemas/model.py
index fd4249dc..1ad69f74 100644
--- a/aligned/schemas/model.py
+++ b/aligned/schemas/model.py
@@ -93,13 +93,13 @@ def logged_features(self) -> Feature | None:
                 return feature
         return None
 
-    def as_view(self, name: str) -> CompiledFeatureView | None:
-        if not self.source:
-            return None
+    def as_view(self, name: str) -> CompiledFeatureView:
+        from aligned.sources.in_mem_source import InMemorySource
+        import polars as pl
 
         return CompiledFeatureView(
             name=name,
-            source=self.source,
+            source=self.source or InMemorySource(pl.DataFrame()),
             entities=self.entities,
             features=self.features,
             derived_features=self.derived_features,
diff --git a/aligned/schemas/transformation.py b/aligned/schemas/transformation.py
index 8de9e839..afdac191 100644
--- a/aligned/schemas/transformation.py
+++ b/aligned/schemas/transformation.py
@@ -235,6 +235,7 @@ def __init__(self) -> None:
             ArrayContains,
             ArrayAtIndex,
             OllamaEmbedding,
+            PolarsMapRowTransformation,
         ]:
             self.add(tran_type)
 
@@ -249,6 +250,43 @@ def shared(cls) -> SupportedTransformations:
         return cls._shared
 
 
+@dataclass
+class PolarsMapRowTransformation(Transformation):
+    """
+    This will encode a custom method that is not a lambda function.
+    Therefore, we will store the actual code, and dynamically load it at runtime.
+
+    This is unsafe, but will remove the ModuleImportError for custom methods.
+    """
+
+    code: str
+    function_name: str
+    dtype: FeatureType
+    name: str = 'pol_map_row'
+
+    async def transform_pandas(self, df: pd.DataFrame) -> pd.Series:
+        return (await self.transform_polars(pl.from_pandas(df).lazy(), 'value')).collect()[
+            'value'
+        ]  # type: ignore
+
+    async def transform_polars(self, df: pl.LazyFrame, alias: str) -> pl.LazyFrame | pl.Expr:
+        if self.function_name not in locals():
+            exec(self.code)
+
+        loaded = locals()[self.function_name]
+
+        polars_df = df.collect()
+        columns = polars_df.columns
+        new_cols = polars_df.columns
+        new_cols.append(alias)
+
+        return (
+            polars_df.map_rows(lambda values: (*values, loaded(dict(zip(columns, values)))))
+            .rename(lambda col: new_cols[int(col.split('_')[1])])
+            .lazy()
+        )
+
+
 @dataclass
 class PandasFunctionTransformation(Transformation):
     """
diff --git a/aligned/sources/azure_blob_storage.py b/aligned/sources/azure_blob_storage.py
index a33cbc29..d4c70b59 100644
--- a/aligned/sources/azure_blob_storage.py
+++ b/aligned/sources/azure_blob_storage.py
@@ -15,6 +15,7 @@
 from aligned.retrival_job import RetrivalJob, RetrivalRequest
 from aligned.schemas.date_formatter import DateFormatter
 from aligned.schemas.feature import FeatureType, EventTimestamp
+from aligned.schemas.feature_view import CompiledFeatureView
 from aligned.sources.local import (
     CsvConfig,
     DataFileReference,
@@ -352,7 +353,8 @@ def job_group_key(self) -> str:
     def storage(self) -> Storage:
         return self.config.storage
 
-    def with_schema_version(self, schema_hash: bytes) -> AzureBlobCsvDataSource:
+    def with_view(self, view: CompiledFeatureView) -> AzureBlobCsvDataSource:
+        schema_hash = view.schema_hash()
         return AzureBlobCsvDataSource(
             config=self.config,
             path=self.path.replace(AzureBlobDirectory.schema_placeholder(), schema_hash.hex()),
@@ -483,7 +485,8 @@ def job_group_key(self) -> str:
     def __hash__(self) -> int:
         return hash(self.job_group_key())
 
-    def with_schema_version(self, schema_hash: bytes) -> AzureBlobPartitionedParquetDataSource:
+    def with_view(self, view: CompiledFeatureView) -> AzureBlobPartitionedParquetDataSource:
+        schema_hash = view.schema_hash()
         return AzureBlobPartitionedParquetDataSource(
             config=self.config,
             directory=self.directory.replace(AzureBlobDirectory.schema_placeholder(), schema_hash.hex()),
@@ -688,10 +691,10 @@ def job_group_key(self) -> str:
     def __hash__(self) -> int:
         return hash(self.job_group_key())
 
-    def with_schema_version(self, schema_hash: bytes) -> AzureBlobParquetDataSource:
+    def with_view(self, view: CompiledFeatureView) -> AzureBlobParquetDataSource:
         return AzureBlobParquetDataSource(
             config=self.config,
-            path=self.path.replace(AzureBlobDirectory.schema_placeholder(), schema_hash.hex()),
+            path=self.path.replace(AzureBlobDirectory.schema_placeholder(), view.schema_hash().hex()),
             mapping_keys=self.mapping_keys,
             parquet_config=self.parquet_config,
             date_formatter=self.date_formatter,
diff --git a/aligned/sources/in_mem_source.py b/aligned/sources/in_mem_source.py
index da139dad..039333fd 100644
--- a/aligned/sources/in_mem_source.py
+++ b/aligned/sources/in_mem_source.py
@@ -1,13 +1,21 @@
+from typing import TYPE_CHECKING
+
 import uuid
 import polars as pl
 
 from aligned.data_file import DataFileReference, upsert_on_column
-from aligned.data_source.batch_data_source import BatchDataSource
+from aligned.data_source.batch_data_source import BatchDataSource, CodableBatchDataSource
 from aligned.feature_source import WritableFeatureSource
 from aligned.retrival_job import RetrivalJob, RetrivalRequest
 
+if TYPE_CHECKING:
+    from aligned.schemas.feature_view import CompiledFeatureView
+
+
+class InMemorySource(CodableBatchDataSource, DataFileReference, WritableFeatureSource):
+
+    type_name = 'in_mem_source'
 
-class InMemorySource(BatchDataSource, DataFileReference, WritableFeatureSource):
     def __init__(self, data: pl.DataFrame) -> None:
         self.data = data
         self.job_key = str(uuid.uuid4())
@@ -38,6 +46,11 @@ async def overwrite(self, job: RetrivalJob, request: RetrivalRequest) -> None:
     async def write_polars(self, df: pl.LazyFrame) -> None:
         self.data = df.collect()
 
+    def with_view(self, view: 'CompiledFeatureView') -> 'InMemorySource':
+        if self.data.is_empty():
+            return InMemorySource.from_values({feat.name: [] for feat in view.features})
+        return self
+
     @classmethod
     def multi_source_features_for(  # type: ignore
         cls: type['InMemorySource'],
diff --git a/aligned/sources/local.py b/aligned/sources/local.py
index 25311cdc..dc6a86c9 100644
--- a/aligned/sources/local.py
+++ b/aligned/sources/local.py
@@ -27,6 +27,7 @@
 if TYPE_CHECKING:
     from datetime import datetime
     from aligned.schemas.repo_definition import RepoDefinition
+    from aligned.schemas.feature_view import CompiledFeatureView
     from aligned.feature_store import ContractStore
 
 
@@ -184,7 +185,8 @@ def job_group_key(self) -> str:
     def __hash__(self) -> int:
         return hash(self.job_group_key())
 
-    def with_schema_version(self, schema_hash: bytes) -> CsvFileSource:
+    def with_view(self, view: CompiledFeatureView) -> CsvFileSource:
+        schema_hash = view.schema_hash()
         return CsvFileSource(
             path=self.path.replace(FileDirectory.schema_placeholder(), schema_hash.hex()),
             mapping_keys=self.mapping_keys,
@@ -435,7 +437,8 @@ def job_group_key(self) -> str:
     def __hash__(self) -> int:
         return hash(self.job_group_key())
 
-    def with_schema_version(self, schema_hash: bytes) -> PartitionedParquetFileSource:
+    def with_view(self, view: CompiledFeatureView) -> PartitionedParquetFileSource:
+        schema_hash = view.schema_hash()
         return PartitionedParquetFileSource(
             directory=self.directory.replace(FileDirectory.schema_placeholder(), schema_hash.hex()),
             partition_keys=self.partition_keys,
@@ -588,7 +591,8 @@ def to_markdown(self) -> str:
 
 [Go to file]({self.path})'''  # noqa
 
-    def with_schema_version(self, schema_hash: bytes) -> ParquetFileSource:
+    def with_view(self, view: CompiledFeatureView) -> ParquetFileSource:
+        schema_hash = view.schema_hash()
         return ParquetFileSource(
             path=self.path.replace(FileDirectory.schema_placeholder(), schema_hash.hex()),
             mapping_keys=self.mapping_keys,
diff --git a/pyproject.toml b/pyproject.toml
index 7d990f75..4a3e456c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aligned"
-version = "0.0.105"
+version = "0.0.106"
 description = "A data managment and lineage tool for ML applications."
 authors = ["Mats E. Mollestad "]
 license = "Apache-2.0"
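
Usage sketch for the three decorators added to aligned/compiler/feature_factory.py. This is a hedged illustration, not part of the patch: the `Queries` view and its feature names are hypothetical, chosen to mirror the new test in aligned/jobs/tests/test_combined_job.py.

import polars as pl

from aligned import feature_view, String, Bool
from aligned.lazy_imports import pandas as pd
from aligned.sources.in_mem_source import InMemorySource
from aligned.compiler.feature_factory import transform_polars, transform_pandas, transform_row


@feature_view(source=InMemorySource.empty())
class Queries:
    query = String()

    # Polars: the method receives the lazy frame and the output column name to write.
    @transform_polars(using_features=[query], return_type=Bool())
    def is_long(self, df: pl.LazyFrame, return_value: str) -> pl.LazyFrame:
        return df.with_columns((pl.col('query').str.len_chars() > 5).alias(return_value))

    # Pandas: the method receives the frame and returns the new column as a Series.
    @transform_pandas(using_features=[query], return_type=String())
    def shouted(self, df: pd.DataFrame) -> pd.Series:
        return df['query'].str.upper()

    # Row-wise: the method receives one row as a dict and returns a single value.
    # This compiles through MapRowTransformation into a PolarsMapRowTransformation.
    @transform_row(using_features=[query], return_type=String())
    def first_word(self, row: dict) -> str:
        return row['query'].split(' ')[0]

Note that the view declares no entity, which relies on the at-least-one-Entity check removed from feature_view.py in this patch.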
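
A second sketch, for the cache bypass added in aligned/feature_store.py, assuming a `store` built like the one in test_pipeline_model above (the contract names come from that test):

# By default MyModelContract2 reads MyModelContract's predictions from its
# configured output_source. without_model_cache() swaps every exposed model's
# source for a PredictModelSource, so upstream predictions are recomputed
# on demand instead of being read from the cache.
fresh_store = store.without_model_cache()
preds = await fresh_store.model(MyModelContract2).predict_over({'entity_id': ['a', 'c']}).to_polars()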