
Commit e106a8d

A lot of minor updates
MatsMoll committed Jan 16, 2024
1 parent 7ffc51c commit e106a8d
Showing 12 changed files with 128 additions and 63 deletions.
4 changes: 2 additions & 2 deletions aligned/__init__.py
@@ -18,7 +18,7 @@
     feature_view,
     combined_feature_view,
 )
-from aligned.schemas.text_vectoriser import TextVectoriserModel
+from aligned.schemas.text_vectoriser import EmbeddingModel
 from aligned.sources.kafka import KafkaConfig
 from aligned.sources.local import FileSource
 from aligned.sources.psql import PostgreSQLConfig
@@ -53,7 +53,7 @@
     'EventTimestamp',
     'Timestamp',
     'Json',
-    'TextVectoriserModel',
+    'EmbeddingModel',
     'feature_view',
     'combined_feature_view',
     'model_contract',
6 changes: 3 additions & 3 deletions aligned/compiler/feature_factory.py
@@ -30,7 +30,7 @@
 from aligned.schemas.target import ClassificationTarget as ClassificationTargetSchemas
 from aligned.schemas.target import ClassTargetProbability
 from aligned.schemas.target import RegressionTarget as RegressionTargetSchemas
-from aligned.schemas.transformation import TextVectoriserModel, Transformation
+from aligned.schemas.transformation import EmbeddingModel, Transformation
 from aligned.schemas.vector_storage import VectorStorage

 if TYPE_CHECKING:
@@ -969,15 +969,15 @@ def contains(self, value: str) -> Bool:
         feature.transformation = ContainsFactory(value, self)
         return feature

-    def sentence_vector(self, model: TextVectoriserModel) -> Embedding:
+    def sentence_vector(self, model: EmbeddingModel) -> Embedding:
         from aligned.compiler.transformation_factory import WordVectoriserFactory

         feature = Embedding()
         feature.transformation = WordVectoriserFactory(self, model)
         feature.embedding_size = model.embedding_size
         return feature

-    def embedding(self, model: TextVectoriserModel) -> Embedding:
+    def embedding(self, model: EmbeddingModel) -> Embedding:
         return self.sentence_vector(model)

     def append(self, other: FeatureFactory | str) -> String:
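In practice the rename is import-compatible with existing call sites. A minimal sketch of the updated usage; the view, its columns, and the `huggingface` loader are illustrative assumptions, not taken from this commit:

```python
from aligned import feature_view, EmbeddingModel, FileSource, Int32, String

# Assumed loader: the old TextVectoriserModel exposed pretrained-model
# constructors, and EmbeddingModel is assumed to be used the same way.
model = EmbeddingModel.huggingface('all-MiniLM-L6-v2')

@feature_view(name='articles', source=FileSource.csv_at('articles.csv'))
class Articles:
    article_id = Int32().as_entity()
    body = String()
    # embedding() delegates to sentence_vector(), per the diff above
    body_embedding = body.embedding(model)
```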
4 changes: 2 additions & 2 deletions aligned/compiler/transformation_factory.py
@@ -8,7 +8,7 @@

 from aligned import AwsS3Config
 from aligned.compiler.feature_factory import FeatureFactory, Transformation, TransformationFactory
-from aligned.schemas.transformation import FillNaValuesColumns, LiteralValue, TextVectoriserModel
+from aligned.schemas.transformation import FillNaValuesColumns, LiteralValue, EmbeddingModel

 logger = logging.getLogger(__name__)
@@ -670,7 +670,7 @@ def copy(self) -> 'MeanTransfomrationFactory':
 class WordVectoriserFactory(TransformationFactory):

     feature: FeatureFactory
-    model: TextVectoriserModel
+    model: EmbeddingModel

     @property
     def using_features(self) -> list[FeatureFactory]:
12 changes: 11 additions & 1 deletion aligned/data_source/batch_data_source.py
@@ -324,6 +324,7 @@ def all_between_dates(
         return (
             self.source.all_between_dates(request, start_date, end_date)
             .filter(self.condition)
+            .aggregate(request)
             .derive_features([request])
         )

@@ -334,7 +335,12 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
         else:
             request.derived_features.add(self.condition)

-        return self.source.all_data(request, limit).filter(self.condition).derive_features([request])
+        return (
+            self.source.all_data(request, limit)
+            .filter(self.condition)
+            .aggregate(request)
+            .derive_features([request])
+        )

     def depends_on(self) -> set[FeatureLocation]:
         return self.source.depends_on()
@@ -498,6 +504,7 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
                 left_on=self.left_on,
                 right_on=self.right_on,
             )
+            .aggregate(request)
             .derive_features([request])
         )

@@ -519,6 +526,7 @@ def all_between_dates(
                 left_on=self.left_on,
                 right_on=self.right_on,
             )
+            .aggregate(request)
             .derive_features([request])
         )

@@ -599,6 +607,7 @@ def all_data(self, request: RetrivalRequest, limit: int | None) -> RetrivalJob:
             self.source.all_data(self.left_request, limit=limit)
             .derive_features([self.left_request])
             .join(right_job, method=self.method, left_on=self.left_on, right_on=self.right_on)
+            .aggregate(request)
             .derive_features([request])
         )

@@ -620,6 +629,7 @@ def all_between_dates(
                 left_on=self.left_on,
                 right_on=self.right_on,
             )
+            .aggregate(request)
             .derive_features([request])
         )
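Every source variant above now slots `.aggregate(request)` between the filter/join step and `.derive_features([request])`. A toy polars illustration (not the library's API) of why that ordering matters when a derived feature reads an aggregate:

```python
import polars as pl

rows = pl.DataFrame({'user': ['a', 'a', 'b'], 'amount': [10.0, 20.0, -5.0]})

filtered = rows.filter(pl.col('amount') > 0)        # .filter(self.condition)
aggregated = filtered.group_by('user').agg(         # .aggregate(request)
    pl.col('amount').sum().alias('total_amount')
)
derived = aggregated.with_columns(                  # .derive_features([request])
    (pl.col('total_amount') > 15).alias('is_big_spender')
)
print(derived)
```

Thanks to the guard added in retrival_job.py below, `.aggregate(request)` is a no-op when the request carries no aggregated features, so plain views pay nothing for the extra step.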
26 changes: 16 additions & 10 deletions aligned/feature_store.py
@@ -31,9 +31,8 @@
     StreamAggregationJob,
     SupervisedJob,
     ConvertableToRetrivalJob,
-    SupervisedTrainJob,
 )
-from aligned.schemas.feature import FeatureLocation, Feature
+from aligned.schemas.feature import FeatureLocation, Feature, FeatureReferance
 from aligned.schemas.feature_view import CompiledFeatureView
 from aligned.schemas.model import EventTrigger
 from aligned.schemas.model import Model as ModelSchema
@@ -602,7 +601,11 @@ def model_features_for(self, view_name: str) -> set[str]:
         all_model_features: set[str] = set()
         for model in self.models.values():
             all_model_features.update(
-                {feature.name for feature in model.features if feature.location.name == view_name}
+                {
+                    feature.name
+                    for feature in model.features.default_features
+                    if feature.location.name == view_name
+                }
             )
         return all_model_features

@@ -880,7 +883,8 @@ def cached_at(self, location: DataFileReference) -> RetrivalJob:
"""
from aligned.local.job import FileFullJob

features = {f'{feature.location.identifier}:{feature.name}' for feature in self.model.features}
references = self.model.feature_references(self.selected_version)
features = {f'{feature.location.identifier}:{feature.name}' for feature in references}
request = self.store.requests_for(RawStringFeatureRequest(features))

return FileFullJob(location, RetrivalRequest.unsafe_combine(request.needed_requests)).select_columns(
@@ -1065,15 +1069,14 @@ class TaxiEta:
"""
await self.store.insert_into(FeatureLocation.model(self.model.name), predictions)

async def store_train_test_dataset(self, job: SupervisedTrainJob) -> SupervisedTrainJob:
pass


@dataclass
class SupervisedModelFeatureStore:

model: ModelSchema
store: FeatureStore
labels_estimates_refs: set[FeatureReferance]

selected_version: str | None = None

def features_for(
@@ -1114,7 +1117,7 @@ def features_for(
         features = {f'{feature.location.identifier}:{feature.name}' for feature in feature_refs}
         pred_view = self.model.predictions_view

-        target_feature_refs = pred_view.labels_estimates_refs()
+        target_feature_refs = self.labels_estimates_refs
         target_features = {feature.identifier for feature in target_feature_refs}

         targets = set()
@@ -1201,8 +1204,9 @@ def predictions_for(
         request = pred_view.request(self.model.name)

         target_features = pred_view.labels_estimates_refs()
-        labels = pred_view.labels()
         target_features = {feature.identifier for feature in target_features}
+
+        labels = pred_view.labels()
         pred_features = {f'model:{self.model.name}:{feature.name}' for feature in labels}
         request = self.store.requests_for(
             RawStringFeatureRequest(pred_features), event_timestamp_column=event_timestamp_column
@@ -1411,7 +1415,9 @@ def process_input(self, values: ConvertableToRetrivalJob) -> RetrivalJob:

         job = RetrivalJob.from_convertable(values, request)

-        return job.fill_missing_columns().ensure_types([request]).derive_features([request])
+        return (
+            job.fill_missing_columns().ensure_types([request]).aggregate(request).derive_features([request])
+        )

     async def batch_write(self, values: ConvertableToRetrivalJob | RetrivalJob) -> None:
         """Takes a set of features, computes the derived features, and store them in the source
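With `.aggregate(request)` in the `process_input` chain, pushing raw convertible values through a view now evaluates aggregated features as well as derived ones. A hedged usage sketch; the view name and columns are assumptions:

```python
from aligned import FeatureStore

async def preview_features(store: FeatureStore) -> None:
    # Assumed: the store defines a 'transactions' view with aggregated features.
    values = {'user_id': [1, 1, 2], 'amount': [10.0, 20.0, 5.0]}
    job = store.feature_view('transactions').process_input(values)
    print(await job.to_polars())  # aggregates computed before derived features
```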
7 changes: 5 additions & 2 deletions aligned/feature_view/feature_view.py
@@ -289,7 +289,10 @@ class SomeView(FeatureView):
         store.add_compiled_view(self.compile())
         return store.feature_view(self.metadata.name)

-    async def process(self, data: dict[str, list[Any]]) -> list[dict]:
+    def process_input(self, data: ConvertableToRetrivalJob) -> RetrivalJob:
+        return self.query().process_input(data)
+
+    async def process(self, data: ConvertableToRetrivalJob) -> list[dict]:
         df = await self.query().process_input(data).to_polars()
         return df.collect().to_dicts()

@@ -673,7 +676,7 @@ class MyView:

return f"""
from aligned import feature_view, {all_types}
{imports or ""}
{imports or ''}
@feature_view(
name="{view_name}",
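`process` now accepts any `ConvertableToRetrivalJob` rather than only a dict of lists, and the new `process_input` exposes the underlying job without collecting it. A hedged sketch; `view` stands in for an instance of a `FeatureView` subclass like the `SomeView` example in this file's docstring:

```python
async def preview(view) -> None:
    # process() runs the view's transformations and collects to dicts:
    rows = await view.process({'movie_id': [1, 2], 'title': ['Heat', 'Shrek 2']})

    # process_input() returns the lazy RetrivalJob instead:
    job = view.process_input({'movie_id': [1, 2], 'title': ['Heat', 'Shrek 2']})
    print(await job.to_polars())
```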
43 changes: 41 additions & 2 deletions aligned/retrival_job.py
@@ -263,9 +263,14 @@ class SupervisedJob:

     job: RetrivalJob
     target_columns: set[str]
+    should_filter_out_null_targets: bool = True

     async def to_pandas(self) -> SupervisedDataSet[pd.DataFrame]:
         data = await self.job.to_pandas()
+
+        if self.should_filter_out_null_targets:
+            data = data.dropna(subset=list(self.target_columns))
+
         features = {
             feature.name
             for feature in self.job.request_result.features
@@ -278,6 +283,9 @@ async def to_pandas(self) -> SupervisedDataSet[pd.DataFrame]:

     async def to_polars(self) -> SupervisedDataSet[pl.LazyFrame]:
         data = await self.job.to_polars()
+        if self.should_filter_out_null_targets:
+            data = data.drop_nulls([column for column in self.target_columns])
+
         features = [
             feature.name
             for feature in self.job.request_result.features
@@ -288,6 +296,10 @@ async def to_polars(self) -> SupervisedDataSet[pl.LazyFrame]:
             data, set(entities), set(features), self.target_columns, self.job.request_result.event_timestamp
         )

+    def should_filter_null_targets(self, should_filter: bool) -> SupervisedJob:
+        self.should_filter_out_null_targets = should_filter
+        return self
+
     @property
     def request_result(self) -> RequestResult:
         return self.job.request_result
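Null-target filtering is on by default and can be toggled fluently, since `should_filter_null_targets` mutates and returns the job. A hypothetical usage, assuming a `SupervisedJob` obtained elsewhere (e.g. from a model store):

```python
async def load_frames(supervised_job: SupervisedJob) -> None:
    # Default: rows whose target columns are null are dropped.
    labeled = await supervised_job.to_polars()

    # Opt out to keep unlabeled rows, e.g. for later inspection or scoring:
    everything = await supervised_job.should_filter_null_targets(False).to_polars()
```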
@@ -608,6 +620,8 @@ def select_columns(self, include_features: set[str]) -> RetrivalJob:
         return SelectColumnsJob(include_features, self)

     def aggregate(self, request: RetrivalRequest) -> RetrivalJob:
+        if not request.aggregated_features:
+            return self
         return AggregateJob(self, request)

     def with_request(self, requests: list[RetrivalRequest]) -> RetrivalJob:
@@ -650,6 +664,9 @@ def ignore_event_timestamp(self) -> RetrivalJob:
             return self.copy_with(self.job.ignore_event_timestamp())
         raise NotImplementedError('Not implemented ignore_event_timestamp')

+    def polars_method(self, polars_method: Callable[[pl.LazyFrame], pl.LazyFrame]) -> RetrivalJob:
+        return CustomPolarsJob(self, polars_method)
+
     @staticmethod
     def from_dict(data: dict[str, list], request: list[RetrivalRequest] | RetrivalRequest) -> RetrivalJob:
         if isinstance(request, RetrivalRequest):
@@ -724,6 +741,21 @@ def copy_with(self: JobType, job: RetrivalJob) -> JobType:
         return self


+@dataclass
+class CustomPolarsJob(RetrivalJob, ModificationJob):
+
+    job: RetrivalJob
+    polars_method: Callable[[pl.LazyFrame], pl.LazyFrame]
+
+    async def to_polars(self) -> pl.LazyFrame:
+        df = await self.job.to_polars()
+        return self.polars_method(df)
+
+    async def to_pandas(self) -> pd.DataFrame:
+        df = await self.job.to_polars()
+        return df.collect().to_pandas()
+
+
 @dataclass
 class SubsetJob(RetrivalJob, ModificationJob):

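`polars_method` wraps any job in a `CustomPolarsJob`, so one-off frame-level transforms compose with the rest of the chain. A hedged sketch; the store, view name, and `.all()` accessor are assumptions:

```python
import polars as pl

async def positive_amounts(store) -> pl.DataFrame:
    # Assumed: a 'transactions' view exists and .all() returns a RetrivalJob.
    job = store.feature_view('transactions').all().polars_method(
        lambda df: df.filter(pl.col('amount') > 0)  # applied to the LazyFrame
    )
    frame = await job.to_polars()
    return frame.collect()
```

Note that, as written, `to_pandas` on a `CustomPolarsJob` collects the wrapped job's frame without applying `polars_method`, so the custom transform only takes effect on the polars path.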
@@ -1213,6 +1245,13 @@ def retrival_requests(self) -> list[RetrivalRequest]:
     async def compute_derived_features_polars(self, df: pl.LazyFrame) -> pl.LazyFrame:

         for request in self.requests:
+
+            missing_features = request.features_to_include - set(df.columns)
+
+            if len(missing_features) == 0:
+                logger.debug('Skipping to compute derived features as they are already computed')
+                continue
+
             for feature_round in request.derived_features_order():

                 round_expressions: list[pl.Expr] = []
@@ -1687,7 +1726,7 @@ async def to_pandas(self) -> pd.DataFrame:
                 df[feature.name] = pd.to_datetime(df[feature.name], infer_datetime_format=True, utc=True)
             elif feature.dtype == FeatureType.datetime() or feature.dtype == FeatureType.string():
                 continue
-            elif feature.dtype == FeatureType.array():
+            elif (feature.dtype == FeatureType.array()) or (feature.dtype == FeatureType.embedding()):
                 import json

                 if df[feature.name].dtype == 'object':
@@ -1738,7 +1777,7 @@ async def to_polars(self) -> pl.LazyFrame:
                     .cast(pl.Datetime(time_zone='UTC'))
                     .alias(feature.name)
                 )
-            elif feature.dtype == FeatureType.array():
+            elif (feature.dtype == FeatureType.array()) or (feature.dtype == FeatureType.embedding()):
                 dtype = df.select(feature.name).dtypes[0]
                 if dtype == pl.Utf8:
                     df = df.with_columns(pl.col(feature.name).str.json_extract(pl.List(pl.Utf8)))
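Treating `FeatureType.embedding()` like `array()` lets embeddings persisted as JSON strings round-trip back into list columns. A toy polars illustration mirroring the `str.json_extract` branch above (the diff pins the dtype to `pl.List(pl.Utf8)` for generic arrays; here it is left to inference):

```python
import polars as pl

# Embeddings serialised as JSON strings, e.g. read back from a CSV cache:
df = pl.DataFrame({'embedding': ['[0.1, 0.2]', '[0.3, 0.4]']})

# Decoded back into a list column, as in the to_polars branch above:
decoded = df.with_columns(pl.col('embedding').str.json_extract())
print(decoded.schema)  # {'embedding': List(Float64)}
```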
13 changes: 5 additions & 8 deletions aligned/schemas/feature_view.py
@@ -43,9 +43,9 @@ class CompiledFeatureView(Codable):

     def __pre_serialize__(self) -> CompiledFeatureView:
         assert isinstance(self.name, str)
-        assert isinstance(self.description, str)
         assert isinstance(self.tags, dict)
         assert isinstance(self.source, BatchDataSource)
+
         for entity in self.entities:
             assert isinstance(entity, Feature)
         for feature in self.features:
@@ -54,6 +54,9 @@ def __pre_serialize__(self) -> CompiledFeatureView:
             assert isinstance(derived_feature, DerivedFeature)
         for aggregated_feature in self.aggregated_features:
             assert isinstance(aggregated_feature, AggregatedFeature)
+
+        if self.description is not None:
+            assert isinstance(self.description, str)
         if self.event_timestamp is not None:
             assert isinstance(self.event_timestamp, EventTimestamp)
         if self.stream_data_source is not None:
@@ -164,12 +167,6 @@ def dependent_features_for(
             derived_features.update(intermediate)
             aggregated_features.update(aggregated)

-        all_features = features.union(derived_features).union(
-            {feature.derived_feature for feature in aggregated_features}
-        )
-
-        exclude_names = {feature.name for feature in all_features} - feature_names
-
         return FeatureRequest(
             FeatureLocation.feature_view(self.name),
             feature_names,
@@ -182,7 +179,7 @@
                     derived_features=derived_features,
                     aggregated_features=aggregated_features,
                     event_timestamp=self.event_timestamp,
-                    features_to_include=exclude_names,
+                    features_to_include=feature_names,
                 )
             ],
         )