Standardize imports (#340)
* standardizing package imports

* black reformatting

* simplify the TSDF imports

* reorganize imports

* Revert "simplify the TSDF imports"

This reverts commit 0cefd1569f110c4e7f27db23bfa33db3a1bc730e.

* refactoring sql_fn to sfn based on popular demand

* Describe module import standards

* black formatting

* restoring dlt asofjoin fix from #334

* Update python/tests/intervals_tests.py

hmmm - guess I missed this one :D

Co-authored-by: Lorin Dawson <[email protected]>

* Update python/tests/intervals_tests.py

good catch

Co-authored-by: Lorin Dawson <[email protected]>

* Update python/tests/intervals_tests.py

Co-authored-by: Lorin Dawson <[email protected]>

* Update python/tests/intervals_tests.py

Co-authored-by: Lorin Dawson <[email protected]>

---------

Co-authored-by: Lorin Dawson <[email protected]>
tnixon and R7L208 authored Jun 7, 2023
1 parent 7f93b9a commit 776218e
Showing 14 changed files with 385 additions and 326 deletions.
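Across the changed files the refactor follows a single pattern: members of `pyspark.sql.functions` are no longer imported individually or under ad-hoc aliases such as `f` or `sql_fn`; instead the package is imported once and referenced through the `sfn` alias. The sketch below illustrates the pattern with a hypothetical helper and column names that do not appear in this diff:

```python
import pyspark.sql.functions as sfn
from pyspark.sql.dataframe import DataFrame


# Before (old style): from pyspark.sql.functions import col, lit, when
#   df.withColumn("flag", when(col("value").isNull(), lit(1)).otherwise(lit(0)))
# After (this commit's style): reference everything through the sfn alias.
def flag_missing_values(df: DataFrame, value_col: str) -> DataFrame:
    """Add a 'flag' column that is 1 where value_col is null and 0 otherwise."""
    return df.withColumn(
        "flag",
        sfn.when(sfn.col(value_col).isNull(), sfn.lit(1)).otherwise(sfn.lit(0)),
    )
```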
30 changes: 30 additions & 0 deletions CONTRIBUTING.md
@@ -71,3 +71,33 @@ These environments are also defined in the `tox.ini` file and skip installing dep
* lint
* type-check
* coverage-report

# Code Style & Standards

The tempo project follows [`black`](https://black.readthedocs.io/en/stable/index.html) formatting standards
and uses [`flake8`](https://flake8.pycqa.org/en/latest/) and [`mypy`](https://mypy.readthedocs.io/en/stable/)
to check for code style issues, type errors, and common bad practices.
To test your code against these standards, run the following command:
```bash
tox -e lint,type-check
```
To have `black` automatically format your code, run the following command:
```bash
tox -e format
```

In addition, we apply some project-specific standards:

## Module imports

We organize import statements at the top of each module in the following order, with each section separated by a blank line:
1. Standard Python library imports
2. Third-party library imports
3. PySpark library imports
4. Tempo library imports

Within each section, imports are sorted alphabetically. While it is acceptable to directly import commonly used
classes and functions, for readability it is generally preferred to import a package under an alias and then use
the alias to reference the package's classes and functions.

By convention, we import `pyspark.sql.functions` under the alias `sfn`, which is both distinctive and short.
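
For illustration, a minimal module following these conventions might look like the sketch below (the helper function and its parameters are hypothetical, not part of the tempo codebase); third-party and tempo imports would form their own alphabetized sections in the same way:

```python
# Standard Python library imports
from typing import Optional

# PySpark library imports
from pyspark.sql.dataframe import DataFrame
import pyspark.sql.functions as sfn
from pyspark.sql.window import Window


def most_recent_rows(
    df: DataFrame, ts_col: str, series_id: Optional[str] = None
) -> DataFrame:
    """Keep the latest row per series (or overall), using the sfn alias throughout."""
    partition_cols = [series_id] if series_id else []
    window = Window.partitionBy(*partition_cols).orderBy(sfn.col(ts_col).desc())
    return (
        df.withColumn("_row_num", sfn.row_number().over(window))
        .filter(sfn.col("_row_num") == 1)
        .drop("_row_num")
    )
```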
87 changes: 46 additions & 41 deletions python/tempo/interpol.py
@@ -1,14 +1,14 @@
from __future__ import annotations

from typing import List, Optional, Union, Callable
from typing import Callable, List, Optional, Union

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col, expr, last, lead, lit, when
import pyspark.sql.functions as sfn
from pyspark.sql.window import Window

import tempo.utils as t_utils
import tempo.resample as t_resample
import tempo.tsdf as t_tsdf
import tempo.utils as t_utils

# Interpolation fill options
method_options = ["zero", "null", "bfill", "ffill", "linear"]
@@ -130,56 +130,56 @@ def __interpolate_column(
END AS is_interpolated_{target_col}
"""
output_df = output_df.withColumn(
f"is_interpolated_{target_col}", expr(flag_expr)
f"is_interpolated_{target_col}", sfn.expr(flag_expr)
)

# Handle zero fill
if method == "zero":
output_df = output_df.withColumn(
target_col,
when(
col(f"is_interpolated_{target_col}") == False, # noqa: E712
col(target_col),
).otherwise(lit(0)),
sfn.when(
sfn.col(f"is_interpolated_{target_col}") == False, # noqa: E712
sfn.col(target_col),
).otherwise(sfn.lit(0)),
)

# Handle null fill
if method == "null":
output_df = output_df.withColumn(
target_col,
when(
col(f"is_interpolated_{target_col}") == False, # noqa: E712
col(target_col),
sfn.when(
sfn.col(f"is_interpolated_{target_col}") == False, # noqa: E712
sfn.col(target_col),
).otherwise(None),
)

# Handle forward fill
if method == "ffill":
output_df = output_df.withColumn(
target_col,
when(
col(f"is_interpolated_{target_col}") == True, # noqa: E712
col(f"previous_{target_col}"),
).otherwise(col(target_col)),
sfn.when(
sfn.col(f"is_interpolated_{target_col}") == True, # noqa: E712
sfn.col(f"previous_{target_col}"),
).otherwise(sfn.col(target_col)),
)
# Handle backwards fill
if method == "bfill":
output_df = output_df.withColumn(
target_col,
# Handle case when subsequent value is null
when(
(col(f"is_interpolated_{target_col}") == True) # noqa: E712
sfn.when(
(sfn.col(f"is_interpolated_{target_col}") == True) # noqa: E712
& (
col(f"next_{target_col}").isNull()
& (col(f"{ts_col}_{target_col}").isNull())
sfn.col(f"next_{target_col}").isNull()
& (sfn.col(f"{ts_col}_{target_col}").isNull())
),
col(f"next_null_{target_col}"),
sfn.col(f"next_null_{target_col}"),
).otherwise(
# Handle standard backwards fill
when(
col(f"is_interpolated_{target_col}") == True, # noqa: E712
col(f"next_{target_col}"),
).otherwise(col(f"{target_col}"))
sfn.when(
sfn.col(f"is_interpolated_{target_col}") == True, # noqa: E712
sfn.col(f"next_{target_col}"),
).otherwise(sfn.col(f"{target_col}"))
),
)

@@ -205,10 +205,12 @@ def __generate_time_series_fill(
"""
return df.withColumn(
"previous_timestamp",
col(ts_col),
sfn.col(ts_col),
).withColumn(
"next_timestamp",
lead(df[ts_col]).over(Window.partitionBy(*partition_cols).orderBy(ts_col)),
sfn.lead(df[ts_col]).over(
Window.partitionBy(*partition_cols).orderBy(ts_col)
),
)

def __generate_column_time_fill(
@@ -232,13 +234,13 @@ def __generate_column_time_fill(

return df.withColumn(
f"previous_timestamp_{target_col}",
last(col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
sfn.last(sfn.col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
window.orderBy(ts_col).rowsBetween(Window.unboundedPreceding, 0)
),
).withColumn(
f"next_timestamp_{target_col}",
last(col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
window.orderBy(col(ts_col).desc()).rowsBetween(
sfn.last(sfn.col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
window.orderBy(sfn.col(ts_col).desc()).rowsBetween(
Window.unboundedPreceding, 0
)
),
@@ -266,21 +268,21 @@ def __generate_target_fill(
return (
df.withColumn(
f"previous_{target_col}",
last(df[target_col], ignorenulls=True).over(
sfn.last(df[target_col], ignorenulls=True).over(
window.orderBy(ts_col).rowsBetween(Window.unboundedPreceding, 0)
),
)
# Handle if subsequent value is null
.withColumn(
f"next_null_{target_col}",
last(df[target_col], ignorenulls=True).over(
window.orderBy(col(ts_col).desc()).rowsBetween(
sfn.last(df[target_col], ignorenulls=True).over(
window.orderBy(sfn.col(ts_col).desc()).rowsBetween(
Window.unboundedPreceding, 0
)
),
).withColumn(
f"next_{target_col}",
lead(df[target_col]).over(window.orderBy(ts_col)),
sfn.lead(df[target_col]).over(window.orderBy(ts_col)),
)
)

@@ -356,7 +358,7 @@ def interpolate(
for column in target_cols:
add_column_time = add_column_time.withColumn(
f"{ts_col}_{column}",
when(col(column).isNull(), None).otherwise(col(ts_col)),
sfn.when(sfn.col(column).isNull(), None).otherwise(sfn.col(ts_col)),
)
add_column_time = self.__generate_column_time_fill(
add_column_time, partition_cols, ts_col, column
@@ -365,9 +367,10 @@ def interpolate(
# Handle edge case if last value (latest) is null
edge_filled = add_column_time.withColumn(
"next_timestamp",
when(
col("next_timestamp").isNull(), expr(f"{ts_col}+ interval {freq}")
).otherwise(col("next_timestamp")),
sfn.when(
sfn.col("next_timestamp").isNull(),
sfn.expr(f"{ts_col}+ interval {freq}"),
).otherwise(sfn.col("next_timestamp")),
)

# Fill target column for nearest values
@@ -380,7 +383,7 @@ def interpolate(
# Generate missing timeseries values
exploded_series = target_column_filled.withColumn(
f"new_{ts_col}",
expr(
sfn.expr(
f"explode(sequence({ts_col}, next_timestamp - interval {freq}, interval {freq} )) as timestamp"
),
)
@@ -390,10 +393,12 @@ def interpolate(
flagged_series = (
exploded_series.withColumn(
"is_ts_interpolated",
when(col(f"new_{ts_col}") != col(ts_col), True).otherwise(False),
sfn.when(sfn.col(f"new_{ts_col}") != sfn.col(ts_col), True).otherwise(
False
),
)
.withColumn(ts_col, col(f"new_{ts_col}"))
.drop(col(f"new_{ts_col}"))
.withColumn(ts_col, sfn.col(f"new_{ts_col}"))
.drop(sfn.col(f"new_{ts_col}"))
)

# # Perform interpolation on each target column
52 changes: 26 additions & 26 deletions python/tempo/intervals.py
@@ -1,13 +1,12 @@
from __future__ import annotations

from typing import Optional
from functools import cached_property
from typing import Optional

import pyspark.sql
import pyspark.sql.functions as sfn
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import NumericType, BooleanType, StructField
import pyspark.sql.functions as f
from pyspark.sql.window import Window
from pyspark.sql.types import BooleanType, NumericType, StructField
from pyspark.sql.window import Window, WindowSpec


def is_metric_col(col: StructField) -> bool:
@@ -105,7 +104,7 @@ def metric_columns(self) -> list[str]:
return [col.name for col in self.df.schema.fields if is_metric_col(col)]

@cached_property
def window(self) -> pyspark.sql.window:
def window(self) -> WindowSpec:
return Window.partitionBy(*self.series_ids).orderBy(*self.interval_boundaries)

@classmethod
@@ -210,10 +209,10 @@ def __get_adjacent_rows(self, df: DataFrame) -> DataFrame:
for c in self.interval_boundaries + self.metric_columns:
df = df.withColumn(
f"_lead_1_{c}",
f.lead(c, 1).over(self.window),
sfn.lead(c, 1).over(self.window),
).withColumn(
f"_lag_1_{c}",
f.lag(c, 1).over(self.window),
sfn.lag(c, 1).over(self.window),
)

return df
@@ -236,8 +235,8 @@ def __identify_subset_intervals(self, df: DataFrame) -> tuple[DataFrame, str]:

df = df.withColumn(
subset_indicator,
(f.col(f"_lag_1_{self.start_ts}") <= f.col(self.start_ts))
& (f.col(f"_lag_1_{self.end_ts}") >= f.col(self.end_ts)),
(sfn.col(f"_lag_1_{self.start_ts}") <= sfn.col(self.start_ts))
& (sfn.col(f"_lag_1_{self.end_ts}") >= sfn.col(self.end_ts)),
)

# NB: the first record cannot be a subset of the previous and
@@ -271,12 +270,12 @@ def __identify_overlaps(self, df: DataFrame) -> tuple[DataFrame, list[str]]:
for ts in self.interval_boundaries:
df = df.withColumn(
f"_lead_1_{ts}_overlaps",
(f.col(f"_lead_1_{ts}") > f.col(self.start_ts))
& (f.col(f"_lead_1_{ts}") < f.col(self.end_ts)),
(sfn.col(f"_lead_1_{ts}") > sfn.col(self.start_ts))
& (sfn.col(f"_lead_1_{ts}") < sfn.col(self.end_ts)),
).withColumn(
f"_lag_1_{ts}_overlaps",
(f.col(f"_lag_1_{ts}") > f.col(self.start_ts))
& (f.col(f"_lag_1_{ts}") < f.col(self.end_ts)),
(sfn.col(f"_lag_1_{ts}") > sfn.col(self.start_ts))
& (sfn.col(f"_lag_1_{ts}") < sfn.col(self.end_ts)),
)

overlap_indicators.extend(
@@ -321,9 +320,10 @@ def __merge_adjacent_subset_and_superset(
for c in self.metric_columns:
df = df.withColumn(
c,
f.when(
f.col(subset_indicator), f.coalesce(f.col(c), f"_lag_1_{c}")
).otherwise(f.col(c)),
sfn.when(
sfn.col(subset_indicator),
sfn.coalesce(sfn.col(c), f"_lag_1_{c}"),
).otherwise(sfn.col(c)),
)

return df
@@ -385,7 +385,7 @@ def __merge_adjacent_overlaps(

df = df.withColumn(
new_boundary_col,
f.expr(new_interval_boundaries),
sfn.expr(new_interval_boundaries),
)

if how == "left":
@@ -394,13 +394,13 @@
c,
# needed when intervals have same start but different ends
# in this case, merge metrics since they overlap
f.when(
f.col(f"_lag_1_{self.end_ts}_overlaps"),
f.coalesce(f.col(c), f.col(f"_lag_1_{c}")),
sfn.when(
sfn.col(f"_lag_1_{self.end_ts}_overlaps"),
sfn.coalesce(sfn.col(c), sfn.col(f"_lag_1_{c}")),
)
# general case when constructing left disjoint interval
# just want new boundary without merging metrics
.otherwise(f.col(c)),
.otherwise(sfn.col(c)),
)

return df
@@ -423,7 +423,7 @@ def __merge_equal_intervals(self, df: DataFrame) -> DataFrame:
"""

merge_expr = tuple(f.max(c).alias(c) for c in self.metric_columns)
merge_expr = tuple(sfn.max(c).alias(c) for c in self.metric_columns)

return df.groupBy(*self.interval_boundaries, *self.series_ids).agg(*merge_expr)

@@ -469,7 +469,7 @@ def disjoint(self) -> "IntervalsDF":

(df, subset_indicator) = self.__identify_subset_intervals(df)

subset_df = df.filter(f.col(subset_indicator))
subset_df = df.filter(sfn.col(subset_indicator))

subset_df = self.__merge_adjacent_subset_and_superset(
subset_df, subset_indicator
@@ -479,7 +479,7 @@
*self.interval_boundaries, *self.series_ids, *self.metric_columns
)

non_subset_df = df.filter(~f.col(subset_indicator))
non_subset_df = df.filter(~sfn.col(subset_indicator))

(non_subset_df, overlap_indicators) = self.__identify_overlaps(non_subset_df)

@@ -611,7 +611,7 @@ def toDF(self, stack: bool = False) -> DataFrame:
)

return self.df.select(
*self.interval_boundaries, *self.series_ids, f.expr(stack_expr)
*self.interval_boundaries, *self.series_ids, sfn.expr(stack_expr)
).dropna(subset="metric_value")

else:
