Metadata cleanup #272

Closed
wants to merge 19 commits
2 changes: 1 addition & 1 deletion examples/dlt_tempo.py
@@ -26,7 +26,7 @@ def ts_bronze():
@dlt.expect_or_drop("User_check","User in ('a','c','i')")
def ts_ft():
phone_accel_df = dlt.read("ts_bronze")
phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", partition_cols = ["User"])
phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts")
ts_ft_df = phone_accel_tsdf.fourier_transform(timestep=1, valueCol="x").df
return ts_ft_df
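For orientation, the constructor change that runs through this PR, shown side by side (arguments mirror the examples elsewhere in this diff):

```python
# before: series columns were passed as partition_cols
phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", partition_cols=["User"])

# after: pass them as series_ids, or omit them entirely when the DataFrame
# holds a single series (as in this DLT example)
phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", series_ids=["User"])
```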

6 changes: 3 additions & 3 deletions examples/financial_services_quickstart.py
@@ -92,8 +92,8 @@
# DBTITLE 1,Define TSDF Time Series Data Structure
from tempo import *

trades_tsdf = TSDF(trades_df, partition_cols = ['date', 'symbol'], ts_col = 'event_ts')
quotes_tsdf = TSDF(quotes_df, partition_cols = ['date', 'symbol'], ts_col = 'event_ts')
trades_tsdf = TSDF(trades_df, ts_col='event_ts')
quotes_tsdf = TSDF(quotes_df, ts_col='event_ts')

# COMMAND ----------

@@ -178,7 +178,7 @@
from tempo import *
from pyspark.sql.functions import *

minute_bars = TSDF(spark.table("time_test"), partition_cols=['ticker'], ts_col="ts").calc_bars(freq = '1 minute', func= 'ceil')
minute_bars = TSDF(spark.table("time_test"), ts_col="ts").calc_bars(freq ='1 minute', func='ceil')

display(minute_bars)

45 changes: 22 additions & 23 deletions python/README.md
@@ -51,7 +51,7 @@ phone_accel_df = spark.read.format("csv").option("header", "true").load("dbfs:/h

from tempo import *

phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", partition_cols = ["User"])
phone_accel_tsdf = TSDF(phone_accel_df, ts_col="event_ts", series_ids = ["User"])

display(phone_accel_tsdf)
```
@@ -65,7 +65,7 @@ Note: You can upsample any missing values by using an option in the resample int

```python
# ts_col = timestamp column on which to sort fact and source table
# partition_cols - columns to use for partitioning the TSDF into more granular time series for windowing and sorting
# series_ids - columns to use for partitioning the TSDF into more granular time series for windowing and sorting

resampled_sdf = phone_accel_tsdf.resample(freq='min', func='floor')
resampled_pdf = resampled_sdf.df.filter(col('event_ts').cast("date") == "2015-02-23").toPandas()
@@ -97,7 +97,7 @@ from pyspark.sql.functions import *

watch_accel_df = spark.read.format("csv").option("header", "true").load("dbfs:/home/tempo/Watch_accelerometer").withColumn("event_ts", (col("Arrival_Time").cast("double")/1000).cast("timestamp")).withColumn("x", col("x").cast("double")).withColumn("y", col("y").cast("double")).withColumn("z", col("z").cast("double")).withColumn("event_ts_dbl", col("event_ts").cast("double"))

watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", partition_cols = ["User"])
watch_accel_tsdf = TSDF(watch_accel_df, ts_col="event_ts", series_ids = ["User"])

# Applying AS OF join to TSDF datasets
joined_df = watch_accel_tsdf.asofJoin(phone_accel_tsdf, right_prefix="phone_accel")
@@ -107,12 +107,12 @@ display(joined_df)

#### 3. Skew Join Optimized AS OF Join

The purpose of the skew optimized as of join is to bucket each set of `partition_cols` to get the latest source record merged onto the fact table
The purpose of the skew optimized as of join is to bucket each set of `series_ids` to get the latest source record merged onto the fact table

Parameters:

ts_col = timestamp column for sorting
partition_cols = partition columns for defining granular time series for windowing and sorting
series_ids = partition columns for defining granular time series for windowing and sorting
tsPartitionVal = value to break up each partition into time brackets
fraction = overlap fraction
right_prefix = prefix used for source columns when merged into fact table
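A minimal sketch of how these parameters are passed to `asofJoin` (values are illustrative; they mirror the time-partitioned join exercised in the tests later in this diff):

```python
# Skew-optimized AS OF join: bucket each series into time brackets of 10
# (in ts_col units) with a 10% overlap, prefixing the source columns.
skew_joined_tsdf = watch_accel_tsdf.asofJoin(
    phone_accel_tsdf,
    left_prefix="left",
    right_prefix="phone_accel",
    tsPartitionVal=10,
    fraction=0.1,
)

display(skew_joined_tsdf)
```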
@@ -185,11 +185,10 @@ Valid columns data types for interpolation are: `["int", "bigint", "float", "dou
```python
# Create instance of the TSDF class
input_tsdf = TSDF(
input_df,
partition_cols=["partition_a", "partition_b"],
ts_col="event_ts",
)

input_df,
series_ids=["partition_a", "partition_b"],
ts_col="event_ts",
)

# What the following chain of operation does is:
# 1. Aggregate all valid numeric columns using mean into 30 second intervals
@@ -205,32 +204,32 @@ interpolated_tsdf = input_tsdf.resample(freq="30 seconds", func="mean").interpol
interpolated_tsdf = input_tsdf.interpolate(
freq="30 seconds",
func="mean",
target_cols= ["columnA","columnB"],
target_cols=["columnA", "columnB"],
method="linear"

)

# Alternatively it's also possible to override default TSDF parameters.
# e.g. partition_cols, ts_col a
# e.g. series_ids, ts_col a
interpolated_tsdf = input_tsdf.interpolate(
partition_cols=["partition_c"],
series_ids=["partition_c"],
ts_col="other_event_ts"
freq="30 seconds",
func="mean",
target_cols= ["columnA","columnB"],
method="linear"
freq = "30 seconds",
func = "mean",
target_cols = ["columnA", "columnB"],
method = "linear"
)

# The show_interpolated flag can be set to `True` to show additional boolean columns
# for a given row that shows if a column has been interpolated.
interpolated_tsdf = input_tsdf.interpolate(
partition_cols=["partition_c"],
series_ids=["partition_c"],
ts_col="other_event_ts"
freq="30 seconds",
func="mean",
method="linear",
target_cols= ["columnA","columnB"],
show_interpolated=True,
freq = "30 seconds",
func = "mean",
method = "linear",
target_cols = ["columnA", "columnB"],
show_interpolated = True,
)

```
130 changes: 47 additions & 83 deletions python/tempo/interpol.py
@@ -1,14 +1,10 @@
from __future__ import annotations

from typing import List, Optional, Union, Callable
from typing import List

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col, expr, last, lead, lit, when
from pyspark.sql.window import Window

import tempo.utils as t_utils
import tempo.resample as t_resample
import tempo.tsdf as t_tsdf
from tempo.utils import calculate_time_horizon
from tempo.resample import checkAllowableFreq, freq_dict

# Interpolation fill options
method_options = ["zero", "null", "bfill", "ffill", "linear"]
@@ -19,7 +15,7 @@ class Interpolation:
def __init__(self, is_resampled: bool):
self.is_resampled = is_resampled

def __validate_fill(self, method: str) -> None:
def __validate_fill(self, method: str):
"""
Validate if the fill provided is within the allowed list of values.
@@ -33,11 +29,10 @@ def __validate_fill(self, method: str) -> None:
def __validate_col(
self,
df: DataFrame,
partition_cols: Optional[List[str]],
partition_cols: List[str],
target_cols: List[str],
ts_col: str,
ts_col_dtype: Optional[str] = None, # NB: added for testing purposes only
) -> None:
):
"""
Validate if target column exists and is of numeric type, and validates if partition column exists.
@@ -47,12 +42,11 @@ def __validate_col(
:param ts_col: Timestamp column to be validated
"""

if partition_cols is not None:
for column in partition_cols:
if column not in str(df.columns):
raise ValueError(
f"Partition Column: '{column}' does not exist in DataFrame."
)
for column in partition_cols:
if column not in str(df.columns):
raise ValueError(
f"Partition Column: '{column}' does not exist in DataFrame."
)
for column in target_cols:
if column not in str(df.columns):
raise ValueError(
@@ -68,14 +62,10 @@ def __validate_col(
f"Timestamp Column: '{ts_col}' does not exist in DataFrame."
)

if ts_col_dtype is None:
ts_col_dtype = df.select(ts_col).dtypes[0][1]
if ts_col_dtype != "timestamp":
if df.select(ts_col).dtypes[0][1] != "timestamp":
raise ValueError("Timestamp Column needs to be of timestamp type.")

def __calc_linear_spark(
self, df: DataFrame, ts_col: str, target_col: str
) -> DataFrame:
def __calc_linear_spark(self, df: DataFrame, ts_col: str, target_col: str):
"""
Native Spark function for calculating linear interpolation on a DataFrame.
@@ -194,7 +184,7 @@ def __interpolate_column(
return output_df

def __generate_time_series_fill(
self, df: DataFrame, partition_cols: Optional[List[str]], ts_col: str
self, df: DataFrame, partition_cols: List[str], ts_col: str
) -> DataFrame:
"""
Create additional timeseries columns for previous and next timestamps
@@ -203,20 +193,13 @@ def __generate_time_series_fill(
:param partition_cols: partition column names
:param ts_col: timestamp column name
"""
return df.withColumn(
"previous_timestamp",
col(ts_col),
).withColumn(
return df.withColumn("previous_timestamp", col(ts_col),).withColumn(
"next_timestamp",
lead(df[ts_col]).over(Window.partitionBy(*partition_cols).orderBy(ts_col)),
)

def __generate_column_time_fill(
self,
df: DataFrame,
partition_cols: Optional[List[str]],
ts_col: str,
target_col: str,
self, df: DataFrame, partition_cols: List[str], ts_col: str, target_col: str
) -> DataFrame:
"""
Create timeseries columns for previous and next timestamps for a specific target column
@@ -226,30 +209,24 @@ def __generate_column_time_fill(
:param ts_col: timestamp column name
:param target_col: target column name
"""
window = Window
if partition_cols is not None:
window = Window.partitionBy(*partition_cols)

return df.withColumn(
f"previous_timestamp_{target_col}",
last(col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
window.orderBy(ts_col).rowsBetween(Window.unboundedPreceding, 0)
Window.partitionBy(*partition_cols)
.orderBy(ts_col)
.rowsBetween(Window.unboundedPreceding, 0)
),
).withColumn(
f"next_timestamp_{target_col}",
last(col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
window.orderBy(col(ts_col).desc()).rowsBetween(
Window.unboundedPreceding, 0
)
Window.partitionBy(*partition_cols)
.orderBy(col(ts_col).desc())
.rowsBetween(Window.unboundedPreceding, 0)
),
)

def __generate_target_fill(
self,
df: DataFrame,
partition_cols: Optional[List[str]],
ts_col: str,
target_col: str,
self, df: DataFrame, partition_cols: List[str], ts_col: str, target_col: str
) -> DataFrame:
"""
Create columns for previous and next value for a specific target column
@@ -259,39 +236,39 @@ def __generate_target_fill(
:param ts_col: timestamp column name
:param target_col: target column name
"""
window = Window

if partition_cols is not None:
window = Window.partitionBy(*partition_cols)
return (
df.withColumn(
f"previous_{target_col}",
last(df[target_col], ignorenulls=True).over(
window.orderBy(ts_col).rowsBetween(Window.unboundedPreceding, 0)
Window.partitionBy(*partition_cols)
.orderBy(ts_col)
.rowsBetween(Window.unboundedPreceding, 0)
),
)
# Handle if subsequent value is null
.withColumn(
f"next_null_{target_col}",
last(df[target_col], ignorenulls=True).over(
window.orderBy(col(ts_col).desc()).rowsBetween(
Window.unboundedPreceding, 0
)
Window.partitionBy(*partition_cols)
.orderBy(col(ts_col).desc())
.rowsBetween(Window.unboundedPreceding, 0)
),
).withColumn(
f"next_{target_col}",
lead(df[target_col]).over(window.orderBy(ts_col)),
lead(df[target_col]).over(
Window.partitionBy(*partition_cols).orderBy(ts_col)
),
)
)

def interpolate(
self,
tsdf: t_tsdf.TSDF,
tsdf,
ts_col: str,
partition_cols: Optional[List[str]],
series_ids: List[str],
target_cols: List[str],
freq: Optional[str],
func: Optional[Union[Callable | str]],
freq: str,
func: str,
method: str,
show_interpolated: bool,
perform_checks: bool = True,
@@ -302,7 +279,7 @@ def interpolate(
:param tsdf: input TSDF
:param ts_col: timestamp column name
:param target_cols: numeric columns to interpolate
:param partition_cols: partition columns names
:param series_ids: partition columns names
:param freq: frequency at which to sample
:param func: aggregate function used for sampling to the specified interval
:param method: interpolation function used to fill missing values
@@ -312,42 +289,29 @@ def interpolate(
"""
# Validate input parameters
self.__validate_fill(method)
self.__validate_col(tsdf.df, partition_cols, target_cols, ts_col)

if freq is None:
raise ValueError("freq cannot be None")

if func is None:
raise ValueError("func cannot be None")

if callable(func):
raise ValueError("func must be a string")
self.__validate_col(tsdf.df, series_ids, target_cols, ts_col)

# Convert Frequency using resample dictionary
parsed_freq = t_resample.checkAllowableFreq(freq)
period, unit = parsed_freq[0], parsed_freq[1]
freq = f"{period} {t_resample.freq_dict[unit]}" # type: ignore[literal-required]
parsed_freq = checkAllowableFreq(freq)
freq = f"{parsed_freq[0]} {freq_dict[parsed_freq[1]]}"

# Throw warning for user to validate that the expected number of output rows is valid.
if perform_checks:
t_utils.calculate_time_horizon(tsdf.df, ts_col, freq, partition_cols)
calculate_time_horizon(tsdf.df, ts_col, freq, series_ids)

# Only select required columns for interpolation
input_cols: List[str] = [ts_col, *target_cols]
if partition_cols is not None:
input_cols += [*partition_cols]

input_cols: List[str] = [*series_ids, ts_col, *target_cols]
sampled_input: DataFrame = tsdf.df.select(*input_cols)

if self.is_resampled is False:
# Resample and Normalize Input
sampled_input = tsdf.resample(
sampled_input: DataFrame = tsdf.resample(
freq=freq, func=func, metricCols=target_cols
).df

# Fill timeseries for nearest values
time_series_filled = self.__generate_time_series_fill(
sampled_input, partition_cols, ts_col
sampled_input, series_ids, ts_col
)

# Generate surrogate timestamps for each target column
@@ -359,7 +323,7 @@ def interpolate(
when(col(column).isNull(), None).otherwise(col(ts_col)),
)
add_column_time = self.__generate_column_time_fill(
add_column_time, partition_cols, ts_col, column
add_column_time, series_ids, ts_col, column
)

# Handle edge case if last value (latest) is null
@@ -374,7 +338,7 @@ def interpolate(
target_column_filled = edge_filled
for column in target_cols:
target_column_filled = self.__generate_target_fill(
target_column_filled, partition_cols, ts_col, column
target_column_filled, series_ids, ts_col, column
)

# Generate missing timeseries values
@@ -400,7 +364,7 @@ def interpolate(
interpolated_result: DataFrame = flagged_series
for target_col in target_cols:
# Interpolate target columns
interpolated_result = self.__interpolate_column(
interpolated_result: DataFrame = self.__interpolate_column(
interpolated_result, ts_col, target_col, method
)
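The `previous_*`/`next_*` columns built above feed the linear fill. A minimal, self-contained sketch of that arithmetic (window and column names here are illustrative, not the module's internals, and real usage partitions the windows by the series ids):

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, last, lead, when

spark = SparkSession.builder.getOrCreate()

# One missing reading between two known values.
df = spark.createDataFrame(
    [(0.0, 10.0), (15.0, None), (30.0, 40.0)],
    "ts double, value double",
)

w_prev = Window.orderBy("ts").rowsBetween(Window.unboundedPreceding, 0)
w_next = Window.orderBy("ts")

filled = (
    df.withColumn("prev_value", last("value", ignorenulls=True).over(w_prev))
    .withColumn(
        "prev_ts",
        last(when(col("value").isNotNull(), col("ts")), ignorenulls=True).over(w_prev),
    )
    .withColumn("next_value", lead("value").over(w_next))
    .withColumn("next_ts", lead("ts").over(w_next))
    .withColumn(
        "value",
        when(
            col("value").isNull(),
            # linear interpolation: previous value plus the elapsed fraction
            # of the gap between the neighbouring known values
            col("prev_value")
            + (col("next_value") - col("prev_value"))
            * (col("ts") - col("prev_ts"))
            / (col("next_ts") - col("prev_ts")),
        ).otherwise(col("value")),
    )
)

filled.select("ts", "value").show()  # the t=15.0 row interpolates to 25.0
```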

57 changes: 28 additions & 29 deletions python/tempo/intervals.py
@@ -3,10 +3,9 @@
from typing import Optional
from functools import cached_property

import pyspark.sql
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import NumericType, BooleanType, StructField
import pyspark.sql.functions as f
import pyspark.sql.functions as Fn
from pyspark.sql.window import Window


@@ -32,11 +31,7 @@ class IntervalsDF:
"""

def __init__(
self,
df: DataFrame,
start_ts: str,
end_ts: str,
series_ids: Optional[list[str]] = None,
self, df: DataFrame, start_ts: str, end_ts: str, series_ids: list[str] = None
) -> None:
"""
Constructor for :class:`IntervalsDF`.
@@ -105,7 +100,7 @@ def metric_columns(self) -> list[str]:
return [col.name for col in self.df.schema.fields if is_metric_col(col)]

@cached_property
def window(self) -> pyspark.sql.window:
def window(self):
return Window.partitionBy(*self.series_ids).orderBy(*self.interval_boundaries)

@classmethod
@@ -186,7 +181,7 @@ def fromStackedMetrics(

df = (
df.groupBy(start_ts, end_ts, *series)
.pivot(metrics_name_col, values=metric_names)
.pivot(metrics_name_col, values=metric_names) # type: ignore
.max(metrics_value_col)
)

@@ -210,10 +205,10 @@ def __get_adjacent_rows(self, df: DataFrame) -> DataFrame:
for c in self.interval_boundaries + self.metric_columns:
df = df.withColumn(
f"_lead_1_{c}",
f.lead(c, 1).over(self.window),
Fn.lead(c, 1).over(self.window),
).withColumn(
f"_lag_1_{c}",
f.lag(c, 1).over(self.window),
Fn.lag(c, 1).over(self.window),
)

return df
@@ -236,8 +231,8 @@ def __identify_subset_intervals(self, df: DataFrame) -> tuple[DataFrame, str]:

df = df.withColumn(
subset_indicator,
(f.col(f"_lag_1_{self.start_ts}") <= f.col(self.start_ts))
& (f.col(f"_lag_1_{self.end_ts}") >= f.col(self.end_ts)),
(Fn.col(f"_lag_1_{self.start_ts}") <= Fn.col(self.start_ts))
& (Fn.col(f"_lag_1_{self.end_ts}") >= Fn.col(self.end_ts)),
)

# NB: the first record cannot be a subset of the previous and
@@ -271,12 +266,12 @@ def __identify_overlaps(self, df: DataFrame) -> tuple[DataFrame, list[str]]:
for ts in self.interval_boundaries:
df = df.withColumn(
f"_lead_1_{ts}_overlaps",
(f.col(f"_lead_1_{ts}") > f.col(self.start_ts))
& (f.col(f"_lead_1_{ts}") < f.col(self.end_ts)),
(Fn.col(f"_lead_1_{ts}") > Fn.col(self.start_ts))
& (Fn.col(f"_lead_1_{ts}") < Fn.col(self.end_ts)),
).withColumn(
f"_lag_1_{ts}_overlaps",
(f.col(f"_lag_1_{ts}") > f.col(self.start_ts))
& (f.col(f"_lag_1_{ts}") < f.col(self.end_ts)),
(Fn.col(f"_lag_1_{ts}") > Fn.col(self.start_ts))
& (Fn.col(f"_lag_1_{ts}") < Fn.col(self.end_ts)),
)

overlap_indicators.extend(
@@ -321,9 +316,9 @@ def __merge_adjacent_subset_and_superset(
for c in self.metric_columns:
df = df.withColumn(
c,
f.when(
f.col(subset_indicator), f.coalesce(f.col(c), f"_lag_1_{c}")
).otherwise(f.col(c)),
Fn.when(
Fn.col(subset_indicator), Fn.coalesce(Fn.col(c), f"_lag_1_{c}")
).otherwise(Fn.col(c)),
)

return df
@@ -351,12 +346,14 @@ def __merge_adjacent_overlaps(
"""

if how == "left":

# new boundary for interval end will become the start of the next
# interval
new_boundary_col = self.end_ts
new_boundary_val = f"_lead_1_{self.start_ts}"

else:

# new boundary for interval start will become the end of the
# previous interval
new_boundary_col = self.start_ts
@@ -385,22 +382,23 @@ def __merge_adjacent_overlaps(

df = df.withColumn(
new_boundary_col,
f.expr(new_interval_boundaries),
Fn.expr(new_interval_boundaries),
)

if how == "left":

for c in self.metric_columns:
df = df.withColumn(
c,
# needed when intervals have same start but different ends
# in this case, merge metrics since they overlap
f.when(
f.col(f"_lag_1_{self.end_ts}_overlaps"),
f.coalesce(f.col(c), f.col(f"_lag_1_{c}")),
Fn.when(
Fn.col(f"_lag_1_{self.end_ts}_overlaps"),
Fn.coalesce(Fn.col(c), Fn.col(f"_lag_1_{c}")),
)
# general case when constructing left disjoint interval
# just want new boundary without merging metrics
.otherwise(f.col(c)),
.otherwise(Fn.col(c)),
)

return df
@@ -423,7 +421,7 @@ def __merge_equal_intervals(self, df: DataFrame) -> DataFrame:
"""

merge_expr = tuple(f.max(c).alias(c) for c in self.metric_columns)
merge_expr = tuple(Fn.max(c).alias(c) for c in self.metric_columns)

return df.groupBy(*self.interval_boundaries, *self.series_ids).agg(*merge_expr)

@@ -469,7 +467,7 @@ def disjoint(self) -> "IntervalsDF":

(df, subset_indicator) = self.__identify_subset_intervals(df)

subset_df = df.filter(f.col(subset_indicator))
subset_df = df.filter(Fn.col(subset_indicator))

subset_df = self.__merge_adjacent_subset_and_superset(
subset_df, subset_indicator
@@ -479,7 +477,7 @@ def disjoint(self) -> "IntervalsDF":
*self.interval_boundaries, *self.series_ids, *self.metric_columns
)

non_subset_df = df.filter(~f.col(subset_indicator))
non_subset_df = df.filter(~Fn.col(subset_indicator))

(non_subset_df, overlap_indicators) = self.__identify_overlaps(non_subset_df)

@@ -601,6 +599,7 @@ def toDF(self, stack: bool = False) -> DataFrame:
"""

if stack:

n_cols = len(self.metric_columns)
metric_cols_expr = ",".join(
tuple(f"'{col}', {col}" for col in self.metric_columns)
@@ -611,7 +610,7 @@ def toDF(self, stack: bool = False) -> DataFrame:
)

return self.df.select(
*self.interval_boundaries, *self.series_ids, f.expr(stack_expr)
*self.interval_boundaries, *self.series_ids, Fn.expr(stack_expr)
).dropna(subset="metric_value")

else:
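For orientation, a minimal sketch of exercising the `IntervalsDF` API touched in this file (schema, column names, and values are made up for the example):

```python
from pyspark.sql import SparkSession
import pyspark.sql.functions as Fn
from tempo.intervals import IntervalsDF

spark = SparkSession.builder.getOrCreate()

df = (
    spark.createDataFrame(
        [
            ("2020-08-01 00:00:09", "2020-08-01 00:00:14", "v1", 5, None),
            ("2020-08-01 00:00:10", "2020-08-01 00:00:11", "v1", None, 0),
        ],
        "start_ts string, end_ts string, series_1 string, metric_1 int, metric_2 int",
    )
    .withColumn("start_ts", Fn.to_timestamp("start_ts"))
    .withColumn("end_ts", Fn.to_timestamp("end_ts"))
)

idf = IntervalsDF(df, start_ts="start_ts", end_ts="end_ts", series_ids=["series_1"])

# Break overlapping intervals into disjoint ones, merging their metric values,
# then flatten back out to a stacked Spark DataFrame.
idf.disjoint().toDF(stack=True).show()
```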
20 changes: 9 additions & 11 deletions python/tempo/io.py
@@ -3,22 +3,20 @@
import os
import logging
from collections import deque
from typing import Optional

import tempo.tsdf as t_tsdf
import pyspark.sql.functions as f
import tempo
import pyspark.sql.functions as Fn
from pyspark.sql import SparkSession
from pyspark.sql.utils import ParseException

logger = logging.getLogger(__name__)


def write(
tsdf: t_tsdf.TSDF,
tsdf: tempo.TSDF,
spark: SparkSession,
tabName: str,
optimizationCols: Optional[list[str]] = None,
) -> None:
optimizationCols: list[str] = None,
):
"""
param: tsdf: input TSDF object to write
param: tabName Delta output table name
@@ -29,17 +27,17 @@ def write(

df = tsdf.df
ts_col = tsdf.ts_col
partitionCols = tsdf.partitionCols
series_ids = tsdf.series_ids
if optimizationCols:
optimizationCols = optimizationCols + ["event_time"]
else:
optimizationCols = ["event_time"]

useDeltaOpt = os.getenv("DATABRICKS_RUNTIME_VERSION") is not None

view_df = df.withColumn("event_dt", f.to_date(f.col(ts_col))).withColumn(
view_df = df.withColumn("event_dt", Fn.to_date(Fn.col(ts_col))).withColumn(
"event_time",
f.translate(f.split(f.col(ts_col).cast("string"), " ")[1], ":", "").cast(
Fn.translate(Fn.split(Fn.col(ts_col).cast("string"), " ")[1], ":", "").cast(
"double"
),
)
@@ -55,7 +53,7 @@ def write(
try:
spark.sql(
"optimize {} zorder by {}".format(
tabName, "(" + ",".join(partitionCols + optimizationCols) + ")"
tabName, "(" + ",".join(series_ids + optimizationCols) + ")"
)
)
except ParseException as e:
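A minimal usage sketch for the `write` helper above (the source table, TSDF, and optimization column are illustrative; the `OPTIMIZE ... ZORDER BY` step shown in the diff only runs when `DATABRICKS_RUNTIME_VERSION` is set):

```python
from pyspark.sql import SparkSession
from tempo.tsdf import TSDF
import tempo.io as tio

spark = SparkSession.builder.getOrCreate()

trades_tsdf = TSDF(spark.table("trades"), ts_col="event_ts", series_ids=["symbol"])

# Persists the TSDF to a table; on Databricks the helper then z-orders by the
# series ids plus the optimization columns (with the derived event_time appended).
tio.write(trades_tsdf, spark, tabName="trades_delta", optimizationCols=["trade_pr"])
```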
212 changes: 67 additions & 145 deletions python/tempo/resample.py
@@ -1,22 +1,13 @@
from __future__ import annotations

from typing import (
Union,
Optional,
Tuple,
Any,
TypedDict,
List,
Callable,
get_type_hints,
)
from typing import Union, Optional

import pyspark.sql.functions as f
import tempo

import pyspark.sql.functions as Fn
from pyspark.sql.window import Window
from pyspark.sql import DataFrame

import tempo.tsdf as t_tsdf

# define global frequency options
MUSEC = "microsec"
MS = "ms"
@@ -32,92 +23,48 @@
average = "mean"
ceiling = "ceil"


class FreqDict(TypedDict):
musec: str
microsec: str
microsecond: str
microseconds: str
ms: str
millisecond: str
milliseconds: str
sec: str
second: str
seconds: str
min: str
minute: str
minutes: str
hr: str
hour: str
hours: str
day: str
days: str


freq_dict: FreqDict = {
"musec": "microseconds",
freq_dict = {
"microsec": "microseconds",
"microsecond": "microseconds",
"microseconds": "microseconds",
"ms": "milliseconds",
"millisecond": "milliseconds",
"milliseconds": "milliseconds",
"sec": "seconds",
"second": "seconds",
"seconds": "seconds",
"min": "minutes",
"minute": "minutes",
"minutes": "minutes",
"hr": "hours",
"hour": "hours",
"hours": "hours",
"day": "days",
"days": "days",
"hour": "hours",
}

ALLOWED_FREQ_KEYS: List[str] = list(get_type_hints(FreqDict).keys())


def is_valid_allowed_freq_keys(val: str, literal_constant: List[str]) -> bool:
return val in literal_constant


allowableFreqs = [MUSEC, MS, SEC, MIN, HR, DAY]
allowableFuncs = [floor, min, max, average, ceiling]


def _appendAggKey(
tsdf: t_tsdf.TSDF, freq: Optional[str] = None
) -> Tuple[t_tsdf.TSDF, int | str, Any]:
def _appendAggKey(tsdf: tempo.TSDF, freq: str = None):
"""
:param tsdf: TSDF object as input
:param freq: frequency at which to upsample
:return: triple - 1) return a TSDF with a new aggregate key (called agg_key) 2) return the period for use in interpolation, 3) return the time increment (also necessary for interpolation)
"""
df = tsdf.df
parsed_freq = checkAllowableFreq(freq)
period, unit = parsed_freq[0], parsed_freq[1]

agg_window = f.window(
f.col(tsdf.ts_col), "{} {}".format(period, freq_dict[unit]) # type: ignore[literal-required]
agg_window = Fn.window(
Fn.col(tsdf.ts_col), "{} {}".format(parsed_freq[0], freq_dict[parsed_freq[1]])
)

df = df.withColumn("agg_key", agg_window)

return (
t_tsdf.TSDF(df, tsdf.ts_col, partition_cols=tsdf.partitionCols),
period,
freq_dict[unit], # type: ignore[literal-required]
tempo.TSDF(df, ts_col=tsdf.ts_col, series_ids=tsdf.series_ids),
parsed_freq[0],
freq_dict[parsed_freq[1]],
)


def aggregate(
tsdf: t_tsdf.TSDF,
tsdf: tempo.TSDF,
freq: str,
func: Union[Callable, str],
metricCols: Optional[List[str]] = None,
prefix: Optional[str] = None,
fill: Optional[bool] = None,
func: str,
metricCols: list[str] = None,
prefix: str = None,
fill: bool = None,
) -> DataFrame:
"""
aggregate a data frame by a coarser timestamp than the initial TSDF ts_col
@@ -132,8 +79,7 @@ def aggregate(

df = tsdf.df

groupingCols = tsdf.partitionCols + ["agg_key"]

groupingCols = tsdf.series_ids + ["agg_key"]
if metricCols is None:
metricCols = list(set(df.columns).difference(set(groupingCols + [tsdf.ts_col])))

@@ -142,93 +88,87 @@ def aggregate(
else:
prefix = prefix + "_"

groupingCols = [f.col(column) for column in groupingCols]
groupingCols = [Fn.col(column) for column in groupingCols]

if func == floor:
metricCol = f.struct([tsdf.ts_col] + metricCols)
metricCol = Fn.struct([tsdf.ts_col] + metricCols)
res = df.withColumn("struct_cols", metricCol).groupBy(groupingCols)
res = res.agg(f.min("struct_cols").alias("closest_data")).select(
*groupingCols, f.col("closest_data.*")
res = res.agg(Fn.min("struct_cols").alias("closest_data")).select(
*groupingCols, Fn.col("closest_data.*")
)
new_cols = [f.col(tsdf.ts_col)] + [
f.col(c).alias("{}".format(prefix) + c) for c in metricCols
new_cols = [Fn.col(tsdf.ts_col)] + [
Fn.col(c).alias("{}".format(prefix) + c) for c in metricCols
]
res = res.select(*groupingCols, *new_cols)
elif func == average:
exprs = {x: "avg" for x in metricCols}
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
set(res.columns).difference(
set(tsdf.partitionCols + [tsdf.ts_col, "agg_key"])
)
set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
f.col(c).alias("{}".format(prefix) + (c.split("avg(")[1]).replace(")", ""))
Fn.col(c).alias("{}".format(prefix) + (c.split("avg(")[1]).replace(")", ""))
for c in agg_metric_cls
]
res = res.select(*groupingCols, *new_cols)
elif func == min:
exprs = {x: "min" for x in metricCols}
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
set(res.columns).difference(
set(tsdf.partitionCols + [tsdf.ts_col, "agg_key"])
)
set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
f.col(c).alias("{}".format(prefix) + (c.split("min(")[1]).replace(")", ""))
Fn.col(c).alias("{}".format(prefix) + (c.split("min(")[1]).replace(")", ""))
for c in agg_metric_cls
]
res = res.select(*groupingCols, *new_cols)
elif func == max:
exprs = {x: "max" for x in metricCols}
res = df.groupBy(groupingCols).agg(exprs)
agg_metric_cls = list(
set(res.columns).difference(
set(tsdf.partitionCols + [tsdf.ts_col, "agg_key"])
)
set(res.columns).difference(set(tsdf.series_ids + [tsdf.ts_col, "agg_key"]))
)
new_cols = [
f.col(c).alias("{}".format(prefix) + (c.split("max(")[1]).replace(")", ""))
Fn.col(c).alias("{}".format(prefix) + (c.split("max(")[1]).replace(")", ""))
for c in agg_metric_cls
]
res = res.select(*groupingCols, *new_cols)
elif func == ceiling:
metricCol = f.struct([tsdf.ts_col] + metricCols)
metricCol = Fn.struct([tsdf.ts_col] + metricCols)
res = df.withColumn("struct_cols", metricCol).groupBy(groupingCols)
res = res.agg(f.max("struct_cols").alias("ceil_data")).select(
*groupingCols, f.col("ceil_data.*")
res = res.agg(Fn.max("struct_cols").alias("ceil_data")).select(
*groupingCols, Fn.col("ceil_data.*")
)
new_cols = [f.col(tsdf.ts_col)] + [
f.col(c).alias("{}".format(prefix) + c) for c in metricCols
new_cols = [Fn.col(tsdf.ts_col)] + [
Fn.col(c).alias("{}".format(prefix) + c) for c in metricCols
]
res = res.select(*groupingCols, *new_cols)

# aggregate by the window and drop the end time (use start time as new ts_col)
res = (
res.drop(tsdf.ts_col)
.withColumnRenamed("agg_key", tsdf.ts_col)
.withColumn(tsdf.ts_col, f.col(tsdf.ts_col).start)
.withColumn(tsdf.ts_col, Fn.col(tsdf.ts_col).start)
)

# sort columns so they are consistent
non_part_cols = set(set(res.columns) - set(tsdf.partitionCols)) - set([tsdf.ts_col])
sel_and_sort = tsdf.partitionCols + [tsdf.ts_col] + sorted(non_part_cols)
non_part_cols = set(set(res.columns) - set(tsdf.series_ids)) - {tsdf.ts_col}
sel_and_sort = tsdf.series_ids + [tsdf.ts_col] + sorted(non_part_cols)
res = res.select(sel_and_sort)

fillW = Window.partitionBy(tsdf.partitionCols)
fillW = Window.partitionBy(tsdf.series_ids)

imputes = (
res.select(
*tsdf.partitionCols,
f.min(tsdf.ts_col).over(fillW).alias("from"),
f.max(tsdf.ts_col).over(fillW).alias("until"),
*tsdf.series_ids,
Fn.min(tsdf.ts_col).over(fillW).alias("from"),
Fn.max(tsdf.ts_col).over(fillW).alias("until"),
)
.distinct()
.withColumn(
tsdf.ts_col,
f.explode(
f.expr("sequence(from, until, interval {} {})".format(period, unit))
Fn.explode(
Fn.expr("sequence(from, until, interval {} {})".format(period, unit))
),
)
.drop("from", "until")
@@ -240,66 +180,48 @@ def aggregate(
metrics.append(col[0])

if fill:
res = imputes.join(
res, tsdf.partitionCols + [tsdf.ts_col], "leftouter"
).na.fill(0, metrics)
res = imputes.join(res, tsdf.series_ids + [tsdf.ts_col], "leftouter").na.fill(
0, metrics
)

return res


def checkAllowableFreq(freq: Optional[str]) -> Tuple[Union[int | str], str]:
def checkAllowableFreq(freq: Optional[str]) -> tuple[Union[int | str], Optional[str]]:
"""
Parses frequency and checks against allowable frequencies
:param freq: frequency at which to upsample/downsample, declared in resample function
:return: list of parsed frequency value and time suffix
"""
if not isinstance(freq, str):
raise TypeError(f"Invalid type for `freq` argument: {freq}.")

# TODO - return either int OR str for first argument
allowable_freq: Tuple[Union[int | str], str] = (
0,
"will_always_fail_if_not_overwritten",
)

if is_valid_allowed_freq_keys(
freq.lower(),
ALLOWED_FREQ_KEYS,
):
allowable_freq = 1, freq
return allowable_freq

try:
periods = freq.lower().split(" ")[0].strip()
units = freq.lower().split(" ")[1].strip()
except IndexError:
raise ValueError(
"Allowable grouping frequencies are microsecond (musec), millisecond (ms), sec (second), min (minute), hr (hour), day. Reformat your frequency as <integer> <day/hour/minute/second>"
)

if is_valid_allowed_freq_keys(
units.lower(),
ALLOWED_FREQ_KEYS,
):
elif freq in allowableFreqs:
return 1, freq
else:
try:
periods = freq.lower().split(" ")[0].strip()
units = freq.lower().split(" ")[1].strip()
except IndexError:
raise ValueError(
"Allowable grouping frequencies are microsecond (musec), millisecond (ms), sec (second), min (minute), hr (hour), day. Reformat your frequency as <integer> <day/hour/minute/second>"
)
if units.startswith(MUSEC):
allowable_freq = periods, MUSEC
return periods, MUSEC
elif units.startswith(MS) | units.startswith("millis"):
allowable_freq = periods, MS
return periods, MS
elif units.startswith(SEC):
allowable_freq = periods, SEC
return periods, SEC
elif units.startswith(MIN):
allowable_freq = periods, MIN
return periods, MIN
elif units.startswith("hour") | units.startswith(HR):
allowable_freq = periods, "hour"
return periods, "hour"
elif units.startswith(DAY):
allowable_freq = periods, DAY
else:
raise ValueError(f"Invalid value for `freq` argument: {freq}.")

return allowable_freq
return periods, DAY
else:
raise ValueError(f"Invalid value for `freq` argument: {freq}.")


def validateFuncExists(func: Union[Callable | str]) -> None:
def validateFuncExists(func: str):
if func is None:
raise TypeError(
"Aggregate function missing. Provide one of the allowable functions: "
1,290 changes: 604 additions & 686 deletions python/tempo/tsdf.py

Large diffs are not rendered by default.

511 changes: 511 additions & 0 deletions python/tempo/tsschema.py

Large diffs are not rendered by default.

91 changes: 17 additions & 74 deletions python/tempo/utils.py
@@ -1,18 +1,16 @@
from __future__ import annotations

from typing import List, Union, Optional, overload
import logging
import os
import warnings
from typing import List

from IPython import get_ipython
from IPython.core.display import HTML
from IPython.display import display as ipydisplay
from pandas.core.frame import DataFrame as pandasDataFrame
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import expr, max, min, sum, percentile_approx

import tempo.tsdf as t_tsdf
import tempo.resample as t_resample
from tempo.resample import checkAllowableFreq, freq_dict

logger = logging.getLogger(__name__)
IS_DATABRICKS = "DB_HOME" in os.environ.keys()
@@ -33,7 +31,7 @@ class ResampleWarning(Warning):
pass


def _is_capable_of_html_rendering() -> bool:
def _is_capable_of_html_rendering():
"""
This method returns a boolean value signifying whether the environment is a notebook environment
capable of rendering HTML or not.
@@ -51,24 +49,11 @@ def _is_capable_of_html_rendering() -> bool:


def calculate_time_horizon(
df: DataFrame,
ts_col: str,
freq: str,
partition_cols: Optional[List[str]],
local_freq_dict: Optional[t_resample.FreqDict] = None,
) -> None:
df: DataFrame, ts_col: str, freq: str, partition_cols: List[str]
):
# Convert Frequency using resample dictionary
if local_freq_dict is None:
local_freq_dict = t_resample.freq_dict
parsed_freq = t_resample.checkAllowableFreq(freq)
period, unit = parsed_freq[0], parsed_freq[1]
if t_resample.is_valid_allowed_freq_keys(
unit,
t_resample.ALLOWED_FREQ_KEYS,
):
freq = f"{period} {local_freq_dict[unit]}" # type: ignore[literal-required]
else:
raise ValueError(f"Frequency {unit} not supported")
parsed_freq = checkAllowableFreq(freq)
freq = f"{parsed_freq[0]} {freq_dict[parsed_freq[1]]}"

# Get max and min timestamp per partition
partitioned_df: DataFrame = df.groupBy(*partition_cols).agg(
@@ -134,17 +119,7 @@ def calculate_time_horizon(
)


@overload
def display_html(df: pandasDataFrame) -> None:
...


@overload
def display_html(df: DataFrame) -> None:
...


def display_html(df: Union[pandasDataFrame, DataFrame]) -> None:
def display_html(df):
"""
Display method capable of displaying the dataframe in a formatted HTML structured output
"""
@@ -157,7 +132,7 @@ def display_html(df: Union[pandasDataFrame, DataFrame]) -> None:
logger.error("'display' method not available for this object")


def display_unavailable() -> None:
def display_unavailable(df):
"""
This method is called when display method is not available in the environment.
"""
@@ -166,13 +141,8 @@ def display_unavailable() -> None:
)


def get_display_df(tsdf: t_tsdf.TSDF, k: int) -> DataFrame:
# let's show the n most recent records per series, in order:
orderCols = tsdf.partitionCols.copy()
orderCols.append(tsdf.ts_col)
if tsdf.sequence_col:
orderCols.append(tsdf.sequence_col)
return tsdf.latest(k).df.orderBy(orderCols)
def get_display_df(tsdf, k):
return tsdf.latest(k).withNaturalOrdering().df


ENV_CAN_RENDER_HTML = _is_capable_of_html_rendering()
@@ -183,24 +153,11 @@ def get_display_df(tsdf: t_tsdf.TSDF, k: int) -> DataFrame:
and ("display" in get_ipython().user_ns.keys())
):
method = get_ipython().user_ns["display"]

# Under 'display' key in user_ns the original databricks display method is present
# to know more refer: /databricks/python_shell/scripts/db_ipykernel_launcher.py

@overload
def display_improvised(obj: t_tsdf.TSDF) -> None:
...

@overload
def display_improvised(obj: pandasDataFrame) -> None:
...

@overload
def display_improvised(obj: DataFrame) -> None:
...

def display_improvised(obj: Union[t_tsdf.TSDF, pandasDataFrame, DataFrame]) -> None:
if isinstance(obj, t_tsdf.TSDF):
def display_improvised(obj):
if type(obj).__name__ == "TSDF":
method(get_display_df(obj, k=5))
else:
method(obj)
@@ -209,30 +166,16 @@ def display_improvised(obj: Union[t_tsdf.TSDF, pandasDataFrame, DataFrame]) -> N

elif ENV_CAN_RENDER_HTML:

@overload
def display_html_improvised(obj: Optional[t_tsdf.TSDF]) -> None:
...

@overload
def display_html_improvised(obj: Optional[pandasDataFrame]) -> None:
...

@overload
def display_html_improvised(obj: Optional[DataFrame]) -> None:
...

def display_html_improvised(
obj: Union[t_tsdf.TSDF, pandasDataFrame, DataFrame]
) -> None:
if isinstance(obj, t_tsdf.TSDF):
def display_html_improvised(obj):
if type(obj).__name__ == "TSDF":
display_html(get_display_df(obj, k=5))
else:
display_html(obj)

display = display_html_improvised

else:
display = display_unavailable # type: ignore
display = display_unavailable

"""
display method's equivalent for TSDF object
93 changes: 21 additions & 72 deletions python/tests/as_of_join_tests.py
@@ -1,5 +1,4 @@
import unittest
from unittest.mock import patch

from tests.base import SparkTest

@@ -74,93 +73,43 @@ def test_sequence_number_sort(self):

def test_partitioned_asof_join(self):
"""AS-OF Join with a time-partition"""
with self.assertLogs(level="WARNING") as warning_captured:
# fetch test data
tsdf_left = self.get_data_as_tsdf("left")
tsdf_right = self.get_data_as_tsdf("right")
dfExpected = self.get_data_as_sdf("expected")

joined_df = tsdf_left.asofJoin(
tsdf_right,
left_prefix="left",
right_prefix="right",
tsPartitionVal=10,
fraction=0.1,
).df

self.assertDataFrameEquality(joined_df, dfExpected)
self.assertEqual(
warning_captured.output,
[
"WARNING:tempo.tsdf:You are using the skew version of the AS OF join. This "
"may result in null values if there are any values outside of the maximum "
"lookback. For maximum efficiency, choose smaller values of maximum lookback, "
"trading off performance and potential blank AS OF values for sparse keys"
],
)

def test_asof_join_nanos(self):
"""As of join with nanosecond timestamps"""

# fetch test data
tsdf_left = self.get_data_as_tsdf("left")
tsdf_right = self.get_data_as_tsdf("right")
dfExpected = self.get_data_as_sdf("expected")

# perform join
joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
tsdf_right,
left_prefix="left",
right_prefix="right",
tsPartitionVal=10,
fraction=0.1,
).df

# compare
self.assertDataFrameEquality(joined_df, dfExpected)

def test_asof_join_tolerance(self):
"""As of join with tolerance band"""
def test_asof_join_nanos(self):
"""As of join with nanosecond timestamps"""

# fetch test data
tsdf_left = self.get_data_as_tsdf("left")
tsdf_right = self.get_data_as_tsdf("right")
dfExpected = self.get_data_as_sdf("expected")

tolerance_test_values = [None, 0, 5.5, 7, 10]
for tolerance in tolerance_test_values:
# perform join
joined_df = tsdf_left.asofJoin(
tsdf_right,
left_prefix="left",
right_prefix="right",
tolerance=tolerance,
).df

# compare
expected_tolerance = self.get_data_as_sdf(f"expected_tolerance_{tolerance}")
self.assertDataFrameEquality(joined_df, expected_tolerance)

def test_asof_join_sql_join_opt_and_bytes_threshold(self):
"""AS-OF Join with out a time-partition test"""
with patch("tempo.tsdf.TSDF._TSDF__getBytesFromPlan", return_value=1000):
# Construct dataframes
tsdf_left = self.get_data_as_tsdf("left")
tsdf_right = self.get_data_as_tsdf("right")
dfExpected = self.get_data_as_sdf("expected")
noRightPrefixdfExpected = self.get_data_as_sdf("expected_no_right_prefix")

# perform the join
joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right", sql_join_opt=True
).df
non_prefix_joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="", sql_join_opt=True
).df

# joined dataframe should equal the expected dataframe
self.assertDataFrameEquality(joined_df, dfExpected)
self.assertDataFrameEquality(non_prefix_joined_df, noRightPrefixdfExpected)

spark_sql_joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df
self.assertDataFrameEquality(spark_sql_joined_df, dfExpected)
# perform join
joined_df = tsdf_left.asofJoin(
tsdf_right, left_prefix="left", right_prefix="right"
).df

print("joined_df:")
joined_df.printSchema()

print("defExpected:")
dfExpected.printSchema()

# compare
self.assertDataFrameEquality(joined_df, dfExpected)


# MAIN
52 changes: 26 additions & 26 deletions python/tests/base.py
@@ -1,18 +1,17 @@
import re
import os
import unittest
import warnings
from typing import Union

import jsonref

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from tempo.tsdf import TSDF
from tempo.intervals import IntervalsDF
import pyspark.sql.functions as Fn
from chispa import assert_df_equality
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

from tempo.intervals import IntervalsDF
from tempo.tsdf import TSDF


class SparkTest(unittest.TestCase):
#
@@ -77,12 +76,12 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True):
def get_data_as_tsdf(self, name: str, convert_ts_col=True):
df = self.get_data_as_sdf(name, convert_ts_col)
td = self.test_data[name]
tsdf = TSDF(
df,
ts_col=td["ts_col"],
partition_cols=td.get("partition_cols", None),
sequence_col=td.get("sequence_col", None),
)
if "sequence_col" in td:
tsdf = TSDF.fromSubsequenceCol(
df, td["ts_col"], td["sequence_col"], td.get("series_ids", None)
)
else:
tsdf = TSDF(df, ts_col=td["ts_col"], series_ids=td.get("series_ids", None))
return tsdf

def get_data_as_idf(self, name: str, convert_ts_col=True):
@@ -156,17 +155,18 @@ def buildTestDF(self, schema, data, ts_cols=["event_ts"]):
# build dataframe
df = self.spark.createDataFrame(data, schema)

# check if ts_col follows standard timestamp format, then check if timestamp has micro/nanoseconds
# convert timestamp fields to timestamp type
for tsc in ts_cols:
ts_value = str(df.select(ts_cols).limit(1).collect()[0][0])
ts_pattern = r"^\d{4}-\d{2}-\d{2}| \d{2}:\d{2}:\d{2}\.\d*$"
decimal_pattern = r"[.]\d+"
if re.match(ts_pattern, str(ts_value)) is not None:
if (
re.search(decimal_pattern, ts_value) is None
or len(re.search(decimal_pattern, ts_value)[0]) <= 4
):
df = df.withColumn(tsc, F.to_timestamp(F.col(tsc)))
# check if the column is nested in a struct or not
if "." in tsc:
# we're changing a field nested in a struct
(struct, field) = tsc.split(".")
df = df.withColumn(
struct, Fn.col(struct).withField(field, Fn.to_timestamp(tsc))
)
else:
# standard column
df = df.withColumn(tsc, Fn.to_timestamp(Fn.col(tsc)))
return df

#
@@ -178,8 +178,8 @@ def assertFieldsEqual(self, fieldA, fieldB):
Test that two fields are equivalent
"""
self.assertEqual(
fieldA.name.lower(),
fieldB.name.lower(),
fieldA.colname.lower(),
fieldB.colname.lower(),
msg=f"Field {fieldA} has different name from {fieldB}",
)
self.assertEqual(
@@ -195,9 +195,9 @@ def assertSchemaContainsField(self, schema, field):
"""
# the schema must contain a field with the right name
lc_fieldNames = [fc.lower() for fc in schema.fieldNames()]
self.assertTrue(field.name.lower() in lc_fieldNames)
self.assertTrue(field.colname.lower() in lc_fieldNames)
# the attributes of the fields must be equal
self.assertFieldsEqual(field, schema[field.name])
self.assertFieldsEqual(field, schema[field.colname])

@staticmethod
def assertDataFrameEquality(
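The nested-field branch in `buildTestDF` above relies on `Column.withField`; a standalone illustration (the struct and field names are invented for the example):

```python
import pyspark.sql.functions as Fn
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [("2020-08-01 00:00:10",)], "event_ts string"
).select(Fn.struct("event_ts").alias("ts_idx"))

# "ts_idx.event_ts" is nested in a struct, so rebuild the struct with the
# converted field instead of replacing the whole column.
tsc = "ts_idx.event_ts"
struct, field = tsc.split(".")
df = df.withColumn(struct, Fn.col(struct).withField(field, Fn.to_timestamp(tsc)))
df.printSchema()  # ts_idx.event_ts is now a timestamp
```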
240 changes: 76 additions & 164 deletions python/tests/interpol_tests.py
@@ -1,10 +1,8 @@
import unittest

from pyspark.sql.dataframe import DataFrame

from tempo.interpol import Interpolation
from tempo.tsdf import TSDF

from tempo.utils import *
from tests.tsdf_tests import SparkTest


@@ -73,18 +71,21 @@ def test_fill_validation(self):
input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
input_tsdf,
["partition_a", "partition_b"],
["value_a", "value_b"],
"30 seconds",
"event_ts",
"mean",
"fill_wrong",
True,
)
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
func="mean",
method="abcd",
show_interpolated=True,
)
except ValueError as e:
self.assertEqual(type(e), ValueError)
else:
self.fail("ValueError not raised")

def test_target_column_validation(self):
"""Test target columns exist in schema, and are of the right type (numeric)."""
@@ -93,18 +94,21 @@ def test_target_column_validation(self):
input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
input_tsdf,
["partition_a", "partition_b"],
["target_column_wrong", "value_b"],
"30 seconds",
"event_ts",
"mean",
"zero",
True,
)
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
series_ids=["partition_a", "partition_b"],
target_cols=["partition_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
func="mean",
method="zero",
show_interpolated=True,
)
except TypeError as e:
self.assertEqual(type(e), TypeError)
else:
self.fail("ValueError not raised")

def test_partition_column_validation(self):
"""Test partition columns exist in schema."""
@@ -113,18 +117,21 @@ def test_partition_column_validation(self):
input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
input_tsdf,
["partition_c", "partition_column_wrong"],
["value_a", "value_b"],
"30 seconds",
"event_ts",
"mean",
"zero",
True,
)
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
series_ids=["partition_c", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
func="mean",
method="zero",
show_interpolated=True,
)
except ValueError as e:
self.assertEqual(type(e), ValueError)
else:
self.fail("ValueError not raised")

def test_ts_column_validation(self):
"""Test time series column exist in schema."""
@@ -133,18 +140,21 @@ def test_ts_column_validation(self):
input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
input_tsdf,
["partition_a", "partition_b"],
["value_a", "value_b"],
"30 seconds",
"event_ts_wrong",
"mean",
"zero",
True,
)
try:
self.interpolate_helper.interpolate(
tsdf=input_tsdf,
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="value_a",
func="mean",
method="zero",
show_interpolated=True,
)
except ValueError as e:
self.assertEqual(type(e), ValueError)
else:
self.fail("ValueError not raised")

def test_zero_fill_interpolation(self):
"""Test zero fill interpolation.
@@ -161,7 +171,7 @@ def test_zero_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -187,7 +197,7 @@ def test_zero_fill_interpolation_no_perform_checks(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -214,7 +224,7 @@ def test_null_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -241,7 +251,7 @@ def test_back_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -268,7 +278,7 @@ def test_forward_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -295,7 +305,7 @@ def test_linear_fill_interpolation(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -320,7 +330,7 @@ def test_different_freq_abbreviations(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 sec",
ts_col="event_ts",
@@ -347,7 +357,7 @@ def test_show_interpolated(self):
# interpolate
actual_df: DataFrame = self.interpolate_helper.interpolate(
tsdf=simple_input_tsdf,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a", "value_b"],
freq="30 seconds",
ts_col="event_ts",
@@ -358,99 +368,6 @@ def test_show_interpolated(self):

self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_validate_ts_col_data_type_is_not_timestamp(self):
input_df: DataFrame = self.get_data_as_sdf("input_data")

self.assertRaises(
ValueError,
self.interpolate_helper._Interpolation__validate_col,
input_df,
["partition_a", "partition_b"],
["value_a", "value_b"],
"event_ts",
"not_timestamp",
)

def test_interpolation_freq_is_none(self):
"""Test a ValueError is raised when freq is None."""

# load test data
simple_input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
simple_input_tsdf,
"event_ts",
["partition_a", "partition_b"],
["value_a", "value_b"],
None,
"mean",
"zero",
True,
)

def test_interpolation_func_is_none(self):
"""Test a ValueError is raised when func is None."""

# load test data
simple_input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
simple_input_tsdf,
"event_ts",
["partition_a", "partition_b"],
["value_a", "value_b"],
"30 seconds",
None,
"zero",
True,
)

def test_interpolation_func_is_callable(self):
"""Test ValueError is raised when func is callable."""

# load test data
simple_input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
simple_input_tsdf,
"event_ts",
["partition_a", "partition_b"],
["value_a", "value_b"],
"30 seconds",
sum,
"zero",
True,
)

def test_interpolation_freq_is_not_supported_type(self):
"""Test ValueError is raised when func is callable."""

# load test data
simple_input_tsdf: TSDF = self.get_data_as_tsdf("input_data")

# interpolate
self.assertRaises(
ValueError,
self.interpolate_helper.interpolate,
simple_input_tsdf,
"event_ts",
["partition_a", "partition_b"],
["value_a", "value_b"],
"30 not_supported_type",
"mean",
"zero",
True,
)


class InterpolationIntegrationTest(SparkTest):
def test_interpolation_using_default_tsdf_params(self):
@@ -472,23 +389,19 @@ def test_interpolation_using_default_tsdf_params(self):
self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_interpolation_using_custom_params(self):
"""Verify that by specifying optional paramters it will change the result of the interpolation based on those
modified params."""
"""Verify that by specifying optional paramters it will change the result of the interpolation based on those modified params."""

# Modify input DataFrame using different ts_col
simple_input_tsdf: TSDF = self.get_data_as_tsdf("simple_input_data")
expected_df: DataFrame = self.get_data_as_sdf("expected")

input_tsdf = TSDF(
simple_input_tsdf.df.withColumnRenamed("event_ts", "other_ts_col"),
partition_cols=["partition_a", "partition_b"],
ts_col="other_ts_col",
)
input_tsdf = TSDF(simple_input_tsdf.df.withColumnRenamed("event_ts", "other_ts_col"), ts_col="other_ts_col",
series_ids=["partition_a", "partition_b"])

actual_df: DataFrame = input_tsdf.interpolate(
ts_col="other_ts_col",
show_interpolated=True,
partition_cols=["partition_a", "partition_b"],
series_ids=["partition_a", "partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
@@ -498,24 +411,23 @@ def test_interpolation_using_custom_params(self):
self.assertDataFrameEquality(expected_df, actual_df, ignore_nullable=True)

def test_tsdf_constructor_params_are_updated(self):
"""Verify that resulting TSDF class has the correct values for ts_col and partition_col based on the
interpolation."""
"""Verify that resulting TSDF class has the correct values for ts_col and partition_col based on the interpolation."""

# load test data
simple_input_tsdf: TSDF = self.get_data_as_tsdf("simple_input_data")

actual_tsdf: TSDF = simple_input_tsdf.interpolate(
ts_col="event_ts",
show_interpolated=True,
partition_cols=["partition_b"],
series_ids=["partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
method="linear",
)

self.assertEqual(actual_tsdf.ts_col, "event_ts")
self.assertEqual(actual_tsdf.partitionCols, ["partition_b"])
self.assertEqual(actual_tsdf.series_ids, ["partition_b"])

def test_interpolation_on_sampled_data(self):
"""Verify interpolation can be chained with resample within the TSDF class"""
18 changes: 8 additions & 10 deletions python/tests/intervals_tests.py
@@ -1,9 +1,7 @@
from pyspark.sql.dataframe import DataFrame

from tempo.intervals import IntervalsDF
from tempo.intervals import *
from tests.tsdf_tests import SparkTest
from pyspark.sql.utils import AnalysisException
import pyspark.sql.functions as f
import pyspark.sql.functions as Fn


class IntervalsDFTests(SparkTest):
@@ -143,8 +141,8 @@ def test_fromStackedMetrics_series_list(self):
idf_expected = self.get_data_as_idf("expected")

df_input = df_input.withColumn(
"start_ts", f.to_timestamp("start_ts")
).withColumn("end_ts", f.to_timestamp("end_ts"))
"start_ts", Fn.to_timestamp("start_ts")
).withColumn("end_ts", Fn.to_timestamp("end_ts"))

idf = IntervalsDF.fromStackedMetrics(
df_input,
@@ -164,8 +162,8 @@ def test_fromStackedMetrics_metric_names(self):
idf_expected = self.get_data_as_idf("expected")

df_input = df_input.withColumn(
"start_ts", f.to_timestamp("start_ts")
).withColumn("end_ts", f.to_timestamp("end_ts"))
"start_ts", Fn.to_timestamp("start_ts")
).withColumn("end_ts", Fn.to_timestamp("end_ts"))

idf = IntervalsDF.fromStackedMetrics(
df_input,
@@ -338,8 +336,8 @@ def test_toDF_stack(self):
expected_df = self.get_data_as_sdf("expected")

expected_df = expected_df.withColumn(
"start_ts", f.to_timestamp("start_ts")
).withColumn("end_ts", f.to_timestamp("end_ts"))
"start_ts", Fn.to_timestamp("start_ts")
).withColumn("end_ts", Fn.to_timestamp("end_ts"))

actual_df = idf_input.toDF(stack=True)

10 changes: 0 additions & 10 deletions python/tests/resample_tests.py
@@ -27,16 +27,6 @@ def test_appendAggKey_freq_microsecond(self):
self.assertEqual(appendAggKey_tuple[1], "1")
self.assertEqual(appendAggKey_tuple[2], "microseconds")

def test_appendAggKey_freq_is_invalid(self):
input_tsdf = self.get_data_as_tsdf("input_data")

self.assertRaises(
ValueError,
_appendAggKey,
input_tsdf,
"1 invalid",
)

def test_aggregate_floor(self):
input_tsdf = self.get_data_as_tsdf("input_data")
expected_data = self.get_data_as_sdf("expected_data")
578 changes: 242 additions & 336 deletions python/tests/tsdf_tests.py

Large diffs are not rendered by default.

126 changes: 18 additions & 108 deletions python/tests/unit_test_data/as_of_join_tests.json
@@ -3,7 +3,7 @@
"shared_left": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21],
["S1", "2020-08-01 00:01:12", 351.32],
@@ -26,7 +26,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:01:05", 348.10, 353.13],
@@ -37,7 +37,7 @@
"expected": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": {
"$ref": "#/__SharedData/test_asof_expected_data"
@@ -46,7 +46,7 @@
"expected_no_right_prefix": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, event_ts string, bid_pr float, ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"other_ts_cols": ["event_ts"],
"data": {
"$ref": "#/__SharedData/test_asof_expected_data"
@@ -60,7 +60,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:01:05", null, 353.13],
@@ -71,7 +71,7 @@
"expected_skip_nulls": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:01", 345.11, 351.12],
@@ -83,7 +83,7 @@
"expected_skip_nulls_disabled": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:01", 345.11, 351.12],
@@ -97,7 +97,7 @@
"left": {
"schema": "symbol string, event_ts string, trade_pr float, trade_id int",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, 1],
["S1", "2020-08-01 00:00:10", 350.21, 5],
@@ -109,7 +109,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float, seq_nb long",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"sequence_col": "seq_nb",
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12, 1],
@@ -123,7 +123,7 @@
"expected": {
"schema": "symbol string, event_ts string, trade_pr float, trade_id int, right_event_ts string, right_bid_pr float, right_ask_pr float, right_seq_nb long",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, 1, "2020-08-01 00:00:10", 19.11, 20.12, 1],
@@ -138,7 +138,7 @@
"left": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:02", 349.21],
["S1", "2020-08-01 00:00:08", 351.32],
@@ -152,7 +152,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:00:09", 348.10, 353.13],
@@ -163,7 +163,7 @@
"expected": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:02", 349.21, "2020-08-01 00:00:01", 345.11, 351.12],
@@ -180,7 +180,7 @@
"left": {
"schema": "symbol string, event_ts string, trade_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2022-01-01 09:59:59.123456789", 349.21],
["S1", "2022-01-01 10:00:00.123456788", 351.32],
@@ -191,7 +191,7 @@
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"series_ids": ["symbol"],
"data": [
["S1", "2022-01-01 10:00:00.1234567", 345.11, 351.12],
["S1", "2022-01-01 10:00:00.12345671", 348.10, 353.13],
@@ -203,105 +203,15 @@
"expected": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_ask_pr float, right_bid_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"series_ids": ["symbol"],
"data": [
["S1", "2022-01-01 09:59:59.123456789", 349.21, null, null, null],
["S1", "2022-01-01 10:00:00.123456788", 351.32, "2022-01-01 10:00:00.12345677", 365.33, 358.91],
["S1", "2022-01-01 10:00:00.123456789", 361.12, "2022-01-01 10:00:00.12345677", 365.33, 358.91],
["S1", "2022-01-01 10:00:01.123456789", 364.31, "2022-01-01 10:00:01.10000001", 365.31, 359.21]
]
}
},
"test_asof_join_tolerance": {
"left": {
"$ref": "#/__SharedData/shared_left"
},
"right": {
"schema": "symbol string, event_ts string, bid_pr float, ask_pr float",
"ts_col": "event_ts",
"partition_cols": ["symbol"],
"data": [
["S1", "2020-08-01 00:00:01", 345.11, 351.12],
["S1", "2020-08-01 00:00:10", 345.22, 351.33],
["S1", "2020-08-01 00:01:05", 348.10, 353.13],
["S1", "2020-09-01 00:02:01", 358.93, 365.12],
["S1", "2020-09-01 00:15:01", 359.21, 365.31]
]
},
"expected_tolerance_None": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:10", 345.22, 351.33],
["S1", "2020-08-01 00:01:12", 351.32, "2020-08-01 00:01:05", 348.10, 353.13],
["S1", "2020-09-01 00:02:10", 361.1, "2020-09-01 00:02:01", 358.93, 365.12],
["S1", "2020-09-01 00:19:12", 362.1, "2020-09-01 00:15:01", 359.21, 365.31]
]
},
"expected_tolerance_0": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:10", 345.22, 351.33],
["S1", "2020-08-01 00:01:12", 351.32, null, null, null],
["S1", "2020-09-01 00:02:10", 361.1, null, null, null],
["S1", "2020-09-01 00:19:12", 362.1, null, null, null]
]
},
"expected_tolerance_5.5": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:10", 345.22, 351.33],
["S1", "2020-08-01 00:01:12", 351.32, null, null, null],
["S1", "2020-09-01 00:02:10", 361.1, null, null, null],
["S1", "2020-09-01 00:19:12", 362.1, null, null, null]
]
},
"expected_tolerance_7": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:10", 345.22, 351.33],
["S1", "2020-08-01 00:01:12", 351.32, "2020-08-01 00:01:05", 348.10, 353.13],
["S1", "2020-09-01 00:02:10", 361.1, null, null, null],
["S1", "2020-09-01 00:19:12", 362.1, null, null, null]
]
},
"expected_tolerance_10": {
"schema": "symbol string, left_event_ts string, left_trade_pr float, right_event_ts string, right_bid_pr float, right_ask_pr float",
"ts_col": "left_event_ts",
"partition_cols": ["symbol"],
"other_ts_cols": ["right_event_ts"],
"data": [
["S1", "2020-08-01 00:00:10", 349.21, "2020-08-01 00:00:10", 345.22, 351.33],
["S1", "2020-08-01 00:01:12", 351.32, "2020-08-01 00:01:05", 348.10, 353.13],
["S1", "2020-09-01 00:02:10", 361.1, "2020-09-01 00:02:01", 358.93, 365.12],
["S1", "2020-09-01 00:19:12", 362.1, null, null, null]
]
}
},
"test_asof_join_sql_join_opt_and_bytes_threshold": {
"left": {
"$ref": "#/__SharedData/shared_left"
},
"right": {
"$ref": "#/AsOfJoinTest/test_asof_join/right"
},
"expected": {
"$ref": "#/AsOfJoinTest/test_asof_join/expected"
},
"expected_no_right_prefix": {
"$ref": "#/AsOfJoinTest/test_asof_join/expected_no_right_prefix"
}
}
}
}
}
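Reviewer note: the JSON changes above are metadata-only renames (`partition_cols` → `series_ids`), plus the dropped tolerance and join-option cases. As a rough, hedged sketch of how these fixtures are exercised — the symbols, prices, and `right_` prefix are lifted from the fixture, while the construction itself is an illustration assuming the renamed constructor keyword, not the actual test harness:

```python
import pyspark.sql.functions as Fn
from pyspark.sql import SparkSession
from tempo import TSDF

# Hedged sketch mirroring the as-of-join fixtures above: trades on the left,
# quotes on the right, keyed by `symbol` via the renamed `series_ids` field.
spark = SparkSession.builder.getOrCreate()

trades_df = spark.createDataFrame(
    [("S1", "2020-08-01 00:00:10", 349.21),
     ("S1", "2020-08-01 00:01:12", 351.32)],
    "symbol string, event_ts string, trade_pr float",
).withColumn("event_ts", Fn.to_timestamp("event_ts"))

quotes_df = spark.createDataFrame(
    [("S1", "2020-08-01 00:00:01", 345.11, 351.12),
     ("S1", "2020-08-01 00:01:05", 348.10, 353.13)],
    "symbol string, event_ts string, bid_pr float, ask_pr float",
).withColumn("event_ts", Fn.to_timestamp("event_ts"))

trades_tsdf = TSDF(trades_df, ts_col="event_ts", series_ids=["symbol"])
quotes_tsdf = TSDF(quotes_df, ts_col="event_ts", series_ids=["symbol"])

# Each trade is matched to the most recent quote at or before its timestamp;
# right-hand columns come back prefixed, as in the `expected` schemas above.
joined_df = trades_tsdf.asofJoin(quotes_tsdf, right_prefix="right").df
joined_df.show(truncate=False)
```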
25 changes: 0 additions & 25 deletions python/tests/unit_test_data/interpol_tests.json
@@ -1136,31 +1136,6 @@
]
]
}
},
"test_validate_ts_col_data_type_is_not_timestamp": {
"input_data": {
"$ref": "#/__SharedData/input_data"
}
},
"test_interpolation_freq_is_none": {
"input_data": {
"$ref": "#/__SharedData/input_data"
}
},
"test_interpolation_func_is_none": {
"input_data": {
"$ref": "#/__SharedData/input_data"
}
},
"test_interpolation_func_is_callable": {
"input_data": {
"$ref": "#/__SharedData/input_data"
}
},
"test_interpolation_freq_is_not_supported_type": {
"input_data": {
"$ref": "#/__SharedData/input_data"
}
}
},
"InterpolationIntegrationTest": {
2 changes: 1 addition & 1 deletion python/tests/unit_test_data/resample_tests.json
@@ -70,7 +70,7 @@
"$ref": "#/__SharedData/input_data"
}
},
"test_appendAggKey_freq_is_invalid": {
"test_appendAggKey_freq_day": {
"input_data": {
"$ref": "#/__SharedData/input_data"
}
643 changes: 75 additions & 568 deletions python/tests/unit_test_data/tsdf_tests.json

Large diffs are not rendered by default.

38 changes: 20 additions & 18 deletions python/tests/utils_tests.py
@@ -2,8 +2,7 @@
import sys
import unittest

from tempo.utils import * # noqa: F403

from tempo.utils import *
from tests.tsdf_tests import SparkTest
from unittest import mock

@@ -108,8 +107,10 @@ def test_display_html_pandas_dataframe(self):
)

def test_display_unavailable(self):
init_tsdf = self.get_data_as_tsdf("init")

with self.assertLogs(level="ERROR") as error_captured:
display_unavailable()
display_unavailable(init_tsdf)

self.assertEqual(len(error_captured.records), 1)
self.assertEqual(
@@ -120,21 +121,22 @@ def test_display_unavailable(self):
],
)

def test_get_display_df(self):
init_tsdf = self.get_data_as_tsdf("init")
expected_df = self.get_data_as_sdf("expected")

actual_df = get_display_df(init_tsdf, 2)

self.assertDataFrameEquality(actual_df, expected_df)

def test_get_display_df_sequence_col(self):
init_tsdf = self.get_data_as_tsdf("init")
expected_df = self.get_data_as_sdf("expected")

actual_df = get_display_df(init_tsdf, 2)

self.assertDataFrameEquality(actual_df, expected_df)
# TODO - replace with tests of natural ordering & show
# def test_get_display_df(self):
# init_tsdf = self.get_data_as_tsdf("init")
# expected_df = self.get_data_as_sdf("expected")
#
# actual_df = get_display_df(init_tsdf, 2)
#
# self.assertDataFrameEquality(actual_df, expected_df)
#
# def test_get_display_df_sequence_col(self):
# init_tsdf = self.get_data_as_tsdf("init")
# expected_df = self.get_data_as_sdf("expected")
#
# actual_df = get_display_df(init_tsdf, 2)
#
# self.assertDataFrameEquality(actual_df, expected_df)


# MAIN