Skip to content

Commit

Permalink
Improved datetime encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsMoll committed Mar 25, 2024
1 parent 6eaab9a commit 246ce6d
Show file tree
Hide file tree
Showing 24 changed files with 443 additions and 215 deletions.
19 changes: 10 additions & 9 deletions aligned/feature_view/tests/test_combined_view.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from aligned import FeatureStore, feature_view, Int32, FileSource
from aligned import FeatureStore, feature_view, Int32, Int64, FileSource


@pytest.mark.asyncio
Expand Down Expand Up @@ -44,19 +44,19 @@ async def test_new_combined_solution() -> None:

@feature_view(name='test', source=FileSource.csv_at('test_data/test.csv'))
class Test:
some_id = Int32().as_entity()
some_id = Int64().as_entity()

feature = Int32()
feature = Int64()

derived_feature = feature * 10

@feature_view(name='other', source=FileSource.csv_at('test_data/other.csv'))
class Other:

other_id = Int32().as_entity()
some_id = Int32()
other_id = Int64().as_entity()
some_id = Int64()

other_feature = Int32()
other_feature = Int64()

test_feature = other_feature * 10

Expand All @@ -65,13 +65,14 @@ class Other:

@feature_view(name='combined', source=Test.join(other, on=test.some_id)) # type: ignore
class Combined:
some_id = Int32().as_entity()
some_id = Int64().as_entity()

new_feature = test.derived_feature * other.test_feature

result = await Combined.query().all().to_pandas() # type: ignore
result['new_feature'] = result['new_feature'].astype('int64')
assert result[expected_df.columns].equals(expected_df)

new_df = result.sort_values('some_id', ascending=True)[expected_df.columns].reset_index(drop=True)
assert new_df.equals(expected_df)


@pytest.mark.asyncio
Expand Down
22 changes: 10 additions & 12 deletions aligned/local/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from aligned.request.retrival_request import AggregatedFeature, AggregateOver, RetrivalRequest
from aligned.retrival_job import RequestResult, RetrivalJob
from aligned.schemas.date_formatter import DateFormatter
from aligned.schemas.feature import Feature, FeatureType
from aligned.schemas.feature import Feature
from aligned.sources.local import DataFileReference
from aligned.schemas.constraints import Optional
import logging

logger = logging.getLogger(__name__)


class LiteralRetrivalJob(RetrivalJob):
Expand Down Expand Up @@ -133,21 +136,23 @@ def decode_timestamps(df: pl.LazyFrame, request: RetrivalRequest, formatter: Dat
and feature.name in df.columns
and not isinstance(dtypes[feature.name], pl.Datetime)
):
columns.add((feature.name, None))
columns.add((feature.name, feature.dtype.datetime_timezone))

if (
request.event_timestamp
and request.event_timestamp.name in df.columns
and not isinstance(dtypes[request.event_timestamp.name], pl.Datetime)
):
columns.add((request.event_timestamp.name, None))
columns.add((request.event_timestamp.name, request.event_timestamp.dtype.datetime_timezone))

if not columns:
return df

exprs = []

for column, time_zone in columns:
logger.info(f'Decoding column {column} with timezone {time_zone}')

if time_zone is None:
exprs.append(formatter.decode_polars(column).alias(column))
else:
Expand Down Expand Up @@ -379,7 +384,7 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
did_rename_event_timestamp = True

row_id_name = 'row_id'
result = result.with_row_count(row_id_name)
result = result.with_row_index(row_id_name)
for request in self.requests:

entity_names = request.entity_names
Expand Down Expand Up @@ -468,11 +473,6 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
field = request.event_timestamp.name
ttl = request.event_timestamp.ttl

if new_result.select(field).dtypes[0] == pl.Utf8():
new_result = new_result.with_columns(
pl.col(field).str.strptime(pl.Datetime, '%+').alias(field)
)

if ttl:
ttl_request = (pl.col(field) <= pl.col(event_timestamp_col)) & (
pl.col(field) >= pl.col(event_timestamp_col) - ttl
Expand All @@ -484,9 +484,7 @@ async def file_transformations(self, df: pl.LazyFrame) -> pl.LazyFrame:
)
new_result = new_result.sort(field, descending=True).select(pl.exclude(field))
elif request.event_timestamp:
new_result = new_result.sort(
[row_id_name, request.event_timestamp.name], descending=True
).select(pl.exclude(request.event_timestamp.name))
new_result = new_result.sort([row_id_name, request.event_timestamp.name], descending=True)

unique = new_result.unique(subset=row_id_name, keep='first')
column_selects.remove('row_id')
Expand Down
44 changes: 2 additions & 42 deletions aligned/retrival_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import timeit
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import suppress
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Callable, Union, TypeVar, Coroutine, Any
Expand Down Expand Up @@ -1889,47 +1888,8 @@ def retrival_requests(self) -> list[RetrivalRequest]:
return self.requests

async def to_pandas(self) -> pd.DataFrame:
df = await self.job.to_pandas()
for request in self.requests:
features_to_check = request.all_required_features

if request.aggregated_features:
features_to_check = {feature.derived_feature for feature in request.aggregated_features}

for feature in features_to_check:

mask = ~df[feature.name].isnull()

with suppress(AttributeError, TypeError):
df[feature.name] = df[feature.name].mask(
~mask, other=df.loc[mask, feature.name].str.strip('"')
)

if feature.dtype.is_datetime:
df[feature.name] = pd.to_datetime(df[feature.name], infer_datetime_format=True, utc=True)
elif feature.dtype == FeatureType.string():
continue
elif (feature.dtype.is_array) or (feature.dtype == FeatureType.embedding()):
import json

if df[feature.name].dtype == 'object':
df[feature.name] = df[feature.name].apply(
lambda x: json.loads(x) if isinstance(x, str) else x
)
elif (feature.dtype == FeatureType.json()) or feature.dtype.is_datetime:
pass
else:
if feature.dtype.is_numeric:
df[feature.name] = pd.to_numeric(df[feature.name], errors='coerce').astype(
feature.dtype.pandas_type
)
else:
df[feature.name] = df[feature.name].astype(feature.dtype.pandas_type)

if request.event_timestamp and request.event_timestamp.name in df.columns:
feature = request.event_timestamp
df[feature.name] = pd.to_datetime(df[feature.name], infer_datetime_format=True, utc=True)
return df
df = await self.to_polars()
return df.to_pandas()

async def to_lazy_polars(self) -> pl.LazyFrame:
df = await self.job.to_lazy_polars()
Expand Down
33 changes: 26 additions & 7 deletions aligned/schemas/date_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ class AllDateFormatters:
@classmethod
def shared(cls) -> AllDateFormatters:
if cls._shared is None:
formatters = [
Timestamp,
StringDateFormatter,
]
formatters = [Timestamp, StringDateFormatter, NoopFormatter]
cls._shared = AllDateFormatters({formatter.name(): formatter for formatter in formatters})
return cls._shared

Expand Down Expand Up @@ -58,12 +55,34 @@ def string_format(format: str) -> StringDateFormatter:

@staticmethod
def iso_8601() -> StringDateFormatter:
return StringDateFormatter('yyyy-MM-ddTHH:mm:ssZ')
return StringDateFormatter('%Y-%m-%dT%H:%M:%S%.f+%Z')

@staticmethod
def unix_timestamp(time_unit: TimeUnit = 'us', time_zone: str | None = 'UTC') -> Timestamp:
return Timestamp(time_unit, time_zone)

@staticmethod
def noop() -> DateFormatter:
    """Return a formatter that leaves timestamp columns untouched.

    Use this for sources whose storage format already holds native
    timestamps, so no string/integer encoding or decoding is needed
    (the returned ``NoopFormatter`` passes columns through as-is).
    """
    return NoopFormatter()


@dataclass
class NoopFormatter(DateFormatter):
    """Identity date formatter.

    Intended for sources whose underlying format can already store
    timestamps natively, so no encoding or decoding step is required:
    both directions simply pass the column through unchanged.
    """

    @classmethod
    def name(cls) -> str:
        """Serialization key identifying this formatter type."""
        return 'noop'

    def encode_polars(self, column: str) -> pl.Expr:
        """Return *column* as-is; there is nothing to encode."""
        return pl.col(column)

    def decode_polars(self, column: str) -> pl.Expr:
        """Return *column* as-is; there is nothing to decode."""
        return pl.col(column)


@dataclass
class Timestamp(DateFormatter):
Expand Down Expand Up @@ -97,8 +116,8 @@ def name(cls) -> str:

def decode_polars(self, column: str) -> pl.Expr:
return pl.col(column).str.to_datetime(
self.date_format, time_unit=self.time_unit, time_zone=self.time_zone
format=self.date_format, time_unit=self.time_unit, time_zone=self.time_zone
)

def encode_polars(self, column: str) -> pl.Expr:
return pl.col(column).dt.strftime(self.date_format)
return pl.col(column).dt.to_string(self.date_format)
Loading

0 comments on commit 246ce6d

Please sign in to comment.