
Metadata cleanup #272

Closed
wants to merge 19 commits
Changes from 1 commit
created new TSIndex and TSSchema classes to represent TSDF metadata.
First round of TSDF code changes to use the new classes
tnixon committed Aug 11, 2022
commit 6a175697339cca91ad6979b9776e15b130354906
245 changes: 137 additions & 108 deletions python/tempo/tsdf.py
@@ -3,7 +3,7 @@
import logging
import operator
from functools import reduce
from typing import List, Union, Callable
from typing import List, Union, Callable, Collection, Set

import numpy as np
import pyspark.sql.functions as f
@@ -18,6 +18,7 @@
import tempo.io as tio
import tempo.resample as rs
from tempo.interpol import Interpolation
from tempo.tsschema import TSSchema
from tempo.utils import (
ENV_CAN_RENDER_HTML,
IS_DATABRICKS,
@@ -33,33 +34,84 @@ class TSDF:
This object is the main wrapper over a Spark data frame which allows a user to parallelize time series computations on a Spark data frame by various dimensions. The two dimensions required are partition_cols (list of columns by which to summarize) and ts_col (timestamp column, which can be epoch or TimestampType).
"""

def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
"""
Constructor
:param df:
:param ts_col:
:param partitionCols:
:sequence_col every tsdf allows for a tie-breaker secondary sort key
"""
self.ts_col = self.__validated_column(df, ts_col)
self.partitionCols = (
[]
if partition_cols is None
else self.__validated_columns(df, partition_cols.copy())
)

def __init__(
self,
df: DataFrame,
ts_schema: TSSchema = None,
ts_col: str = None,
series_ids: Collection[str] = None,
validate_schema=True,
) -> None:
self.df = df
self.sequence_col = "" if sequence_col is None else sequence_col
# construct schema if we don't already have one
if ts_schema:
self.ts_schema = ts_schema
else:
self.ts_schema = TSSchema.fromDFSchema(self.df.schema, ts_col, series_ids)
# validate that this schema works for this DataFrame
if validate_schema:
self.ts_schema.validate(df.schema)
Contributor

What's the scenario where we would not want to validate the schema?

Contributor

I see there are some protected methods where we don't validate the schema, but it seems like exposing this arg could cause issues if it's set to False when users initialize a TSDF.

I also don't think it hurts to validate the schema each time we manipulate the underlying DF in any way, even in the protected methods.

Contributor Author

Most TSDF transformer methods make some changes to the underlying DF and then return it wrapped in a new TSDF object. I think of this validation as primarily for end users who might need guidance on how they're building a TSDF. Internal transformations should already be safe, so they shouldn't require validation.

However, I'm open to doing validation in every constructor call. I don't think it'll be a hugely heavy function.
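
For illustration, here is a minimal sketch of the pattern described above (the `MiniTSDF` name and the simplified method bodies are hypothetical, not the code in this diff): the public constructor validates by default, while an internal factory reuses an already-validated schema and skips the check.

```python
from pyspark.sql import DataFrame


class MiniTSDF:
    """Minimal sketch of the validate-once pattern (illustrative only)."""

    def __init__(self, df: DataFrame, ts_schema, validate_schema: bool = True) -> None:
        self.df = df
        self.ts_schema = ts_schema
        # End users hit this path: validate the schema against the DataFrame.
        if validate_schema:
            ts_schema.validate(df.schema)

    @classmethod
    def _with_transformed_df(cls, new_df: DataFrame, ts_schema) -> "MiniTSDF":
        # Internal transformer methods reuse an already-validated schema,
        # so the redundant validation step is skipped.
        return cls(new_df, ts_schema, validate_schema=False)

    def after(self, ts) -> "MiniTSDF":
        # Example transformer: filter rows; the schema is unchanged.
        sliced_df = self.df.where(self.df[self.ts_schema.ts_index] > ts)
        return MiniTSDF._with_transformed_df(sliced_df, self.ts_schema)
```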


@property
def ts_index(self) -> str:
return self.ts_schema.ts_index

@property
def ts_col(self) -> str:
if self.ts_schema.user_ts_col:
return self.ts_schema.user_ts_col
return self.ts_index

@property
def series_ids(self) -> List[str]:
return self.ts_schema.series_ids

@property
def structural_cols(self) -> Set[str]:
return self.ts_schema.structural_columns

@property
def observational_cols(self) -> List[str]:
return [
col.name
for col in self.ts_schema.find_observational_columns(self.df.schema)
]

# Add customized check for string type for the timestamp. If we see a string, we will proactively created a double version of the string timestamp for sorting purposes and rename to ts_col
if df.schema[ts_col].dataType == "StringType":
sample_ts = df.limit(1).collect()[0][0]
self.__validate_ts_string(sample_ts)
self.__add_double_ts().withColumnRenamed("double_ts", self.ts_col)
@property
def metric_cols(self) -> List[str]:
return [col.name for col in self.ts_schema.find_metric_columns(self.df.schema)]

"""
Make sure DF is ordered by its respective ts_col and partition columns.
"""
#
# Class Factory Methods
#

@classmethod
def __withTransformedDF(cls, new_df: DataFrame, ts_schema: TSSchema) -> "TSDF":
return cls(new_df, ts_schema=ts_schema, validate_schema=False)

# def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col=None):
# """
# Constructor
# :param df:
# :param ts_col:
# :param partitionCols:
# :sequence_col every tsdf allows for a tie-breaker secondary sort key
# """
# self.ts_col = self.__validated_column(df, ts_col)
# self.partitionCols = (
# []
# if partition_cols is None
# else self.__validated_columns(df, partition_cols.copy())
# )
#
# self.df = df
# self.sequence_col = "" if sequence_col is None else sequence_col
#
# # Add customized check for string type for the timestamp. If we see a string, we will proactively created a double version of the string timestamp for sorting purposes and rename to ts_col
# if df.schema[ts_col].dataType == "StringType":
# sample_ts = df.limit(1).collect()[0][0]
# self.__validate_ts_string(sample_ts)
# self.__add_double_ts().withColumnRenamed("double_ts", self.ts_col)

#
# Helper functions
@@ -119,7 +171,7 @@ def __validated_columns(self, df, colnames):
return colnames

def __checkPartitionCols(self, tsdf_right):
for left_col, right_col in zip(self.partitionCols, tsdf_right.partitionCols):
for left_col, right_col in zip(self.series_ids, tsdf_right.series_ids):
if left_col != right_col:
raise ValueError(
"left and right dataframe partition columns should have same name in same order"
@@ -158,7 +210,7 @@ def __addPrefixToColumns(self, col_list, prefix):
if self.sequence_col
else self.sequence_col
)
return TSDF(df, ts_col, self.partitionCols, sequence_col=seq_col)
return TSDF(df, ts_col, self.series_ids, sequence_col=seq_col)

def __addColumnsFromOtherDF(self, other_cols):
"""
@@ -170,14 +222,14 @@ def __addColumnsFromOtherDF(self, other_cols):
self.df,
)

return TSDF(new_df, self.ts_col, self.partitionCols)
return TSDF(new_df, self.ts_col, self.series_ids)

def __combineTSDF(self, ts_df_right, combined_ts_col):
combined_df = self.df.unionByName(ts_df_right.df).withColumn(
combined_ts_col, f.coalesce(self.ts_col, ts_df_right.ts_col)
)

return TSDF(combined_df, combined_ts_col, self.partitionCols)
return TSDF(combined_df, combined_ts_col, self.series_ids)

def __getLastRightRow(
self,
@@ -197,7 +249,7 @@ def __getLastRightRow(
sort_keys = [f.col(col_name) for col_name in ptntl_sort_keys if col_name != ""]

window_spec = (
Window.partitionBy(self.partitionCols)
Window.partitionBy(self.series_ids)
.orderBy(sort_keys)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
@@ -275,7 +327,7 @@ def __getLastRightRow(
)
df = df.drop(column)

return TSDF(df, left_ts_col, self.partitionCols)
return TSDF(df, left_ts_col, self.series_ids)

def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
"""
@@ -316,7 +368,7 @@ def __getTimePartitions(self, tsPartitionVal, fraction=0.1):
df = partition_df.union(remainder_df).drop(
"partition_remainder", "ts_col_double"
)
return TSDF(df, self.ts_col, self.partitionCols + ["ts_partition"])
return TSDF(df, self.ts_col, self.series_ids + ["ts_partition"])

#
# Slicing & Selection
@@ -342,12 +394,12 @@ def select(self, *cols):
"""
# The columns which will be a mandatory requirement while selecting from TSDFs
seq_col_stub = [] if bool(self.sequence_col) is False else [self.sequence_col]
mandatory_cols = [self.ts_col] + self.partitionCols + seq_col_stub
mandatory_cols = [self.ts_col] + self.series_ids + seq_col_stub
if set(mandatory_cols).issubset(set(cols)):
return TSDF(
self.df.select(*cols),
self.ts_col,
self.partitionCols,
self.series_ids,
self.sequence_col,
)
else:
@@ -369,12 +421,7 @@ def __slice(self, op: str, target_ts):
target_expr = f"'{target_ts}'" if isinstance(target_ts, str) else target_ts
slice_expr = f.expr(f"{self.ts_col} {op} {target_expr}")
sliced_df = self.df.where(slice_expr)
return TSDF(
sliced_df,
ts_col=self.ts_col,
partition_cols=self.partitionCols,
sequence_col=self.sequence_col,
)
return TSDF.__withTransformedDF(sliced_df, self.ts_schema)

def at(self, ts):
"""
@@ -456,12 +503,7 @@ def __top_rows_per_series(self, win: WindowSpec, n: int):
.where(f.col(row_num_col) <= f.lit(n))
.drop(row_num_col)
)
return TSDF(
prev_records_df,
ts_col=self.ts_col,
partition_cols=self.partitionCols,
sequence_col=self.sequence_col,
)
return TSDF.__withTransformedDF(prev_records_df, self.ts_schema)

def earliest(self, n: int = 1):
"""
@@ -579,7 +621,7 @@ def describe(self):

# describe stats
desc_stats = this_df.describe().union(missing_vals)
unique_ts = this_df.select(*self.partitionCols).distinct().count()
unique_ts = this_df.select(*self.series_ids).distinct().count()

max_ts = this_df.select(f.max(f.col(self.ts_col)).alias("max_ts")).collect()[0][
0
@@ -707,10 +749,10 @@ def asofJoin(
(left_bytes < bytes_threshold) | (right_bytes < bytes_threshold)
):
spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", 60)
partition_cols = right_tsdf.partitionCols
left_cols = list(set(left_df.columns).difference(set(self.partitionCols)))
partition_cols = right_tsdf.series_ids
left_cols = list(set(left_df.columns).difference(set(self.series_ids)))
right_cols = list(
set(right_df.columns).difference(set(right_tsdf.partitionCols))
set(right_df.columns).difference(set(right_tsdf.series_ids))
)

left_prefix = (
@@ -753,7 +795,7 @@ def asofJoin(
)
.drop("lead_" + right_tsdf.ts_col)
)
return TSDF(res, partition_cols=self.partitionCols, ts_col=new_left_ts_col)
return TSDF(res, series_ids=self.series_ids, ts_col=new_left_ts_col)

# end of block checking to see if standard Spark SQL join will work

@@ -772,11 +814,9 @@ def asofJoin(
# validate timestamp datatypes match
self.__validateTsColMatch(right_tsdf)

orig_left_col_diff = list(
set(left_df.columns).difference(set(self.partitionCols))
)
orig_left_col_diff = list(set(left_df.columns).difference(set(self.series_ids)))
orig_right_col_diff = list(
set(right_df.columns).difference(set(self.partitionCols))
set(right_df.columns).difference(set(self.series_ids))
)

left_tsdf = (
@@ -789,10 +829,10 @@ def asofJoin(
)

left_nonpartition_cols = list(
set(left_tsdf.df.columns).difference(set(self.partitionCols))
set(left_tsdf.df.columns).difference(set(self.series_ids))
)
right_nonpartition_cols = list(
set(right_tsdf.df.columns).difference(set(self.partitionCols))
set(right_tsdf.df.columns).difference(set(self.series_ids))
)

# For both dataframes get all non-partition columns (including ts_col)
@@ -836,29 +876,24 @@ def asofJoin(
"ts_partition", "is_original"
)

asofDF = TSDF(df, asofDF.ts_col, combined_df.partitionCols)
asofDF = TSDF(df, asofDF.ts_col, combined_df.series_ids)

return asofDF

def __baseWindow(self, sort_col=None, reverse=False):
# figure out our sorting columns
primary_sort_col = self.ts_col if not sort_col else sort_col
sort_cols = (
[primary_sort_col, self.sequence_col]
if self.sequence_col
else [primary_sort_col]
)

# are we ordering forwards (default) or reveresed?
col_fn = f.col
if reverse:
col_fn = lambda colname: f.col(colname).desc() # noqa E731

# our window will be sorted on our sort_cols in the appropriate direction
w = Window().orderBy([col_fn(col) for col in sort_cols])
if reverse:
w = Window().orderBy(f.col(self.ts_index).desc())
else:
w = Window().orderBy(f.col(self.ts_index).asc())
# and partitioned by any series IDs
if self.partitionCols:
w = w.partitionBy([f.col(elem) for elem in self.partitionCols])
if self.series_ids:
w = w.partitionBy([f.col(sid) for sid in self.series_ids])
return w

def __rangeBetweenWindow(self, range_from, range_to, sort_col=None, reverse=False):
@@ -900,8 +935,8 @@ def vwap(self, frequency="m", volume_col="volume", price_col="price"):
)

group_cols = ["time_group"]
if self.partitionCols:
group_cols.extend(self.partitionCols)
if self.series_ids:
group_cols.extend(self.series_ids)
vwapped = (
pre_vwap.withColumn("dllr_value", f.col(price_col) * f.col(volume_col))
.groupby(group_cols)
@@ -913,7 +948,7 @@ def vwap(self, frequency="m", volume_col="volume", price_col="price"):
.withColumn("vwap", f.col("dllr_value") / f.col(volume_col))
)

return TSDF(vwapped, self.ts_col, self.partitionCols)
return TSDF(vwapped, self.ts_col, self.series_ids)

def EMA(self, colName, window=30, exp_factor=0.2):
"""
@@ -940,7 +975,7 @@ def EMA(self, colName, window=30, exp_factor=0.2):
).drop(lagColName)
# Nulls are currently removed

return TSDF(df, self.ts_col, self.partitionCols)
return TSDF(df, self.ts_col, self.series_ids)

def withLookbackFeatures(
self, featureCols, lookbackWindowSize, exactSize=True, featureColName="features"
@@ -974,7 +1009,7 @@ def withLookbackFeatures(
if exactSize:
return lookback_tsdf.where(f.size(featureColName) == lookbackWindowSize)

return TSDF(lookback_tsdf, self.ts_col, self.partitionCols)
return TSDF(lookback_tsdf, self.ts_col, self.series_ids)

def withRangeStats(
self, type="range", colsToSummarize=[], rangeBackWindowSecs=1000
@@ -1000,8 +1035,8 @@ def withRangeStats(
if not colsToSummarize:
# columns we should never summarize
prohibited_cols = [self.ts_col.lower()]
if self.partitionCols:
prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
if self.series_ids:
prohibited_cols.extend([pc.lower() for pc in self.series_ids])
# types that can be summarized
summarizable_types = ["int", "bigint", "float", "double"]
# filter columns to find summarizable columns
@@ -1045,7 +1080,7 @@ def withRangeStats(
"double_ts"
)

return TSDF(summary_df, self.ts_col, self.partitionCols)
return TSDF(summary_df, self.ts_col, self.series_ids)

def withGroupedStats(self, metricCols=[], freq=None):
"""
@@ -1062,8 +1097,8 @@ def withGroupedStats(self, metricCols=[], freq=None):
if not metricCols:
# columns we should never summarize
prohibited_cols = [self.ts_col.lower()]
if self.partitionCols:
prohibited_cols.extend([pc.lower() for pc in self.partitionCols])
if self.series_ids:
prohibited_cols.extend([pc.lower() for pc in self.series_ids])
# types that can be summarized
summarizable_types = ["int", "bigint", "float", "double"]
# filter columns to find summarizable columns
@@ -1097,16 +1132,14 @@ def withGroupedStats(self, metricCols=[], freq=None):
]
)

selected_df = self.df.groupBy(self.partitionCols + [agg_window]).agg(
*selectedCols
)
selected_df = self.df.groupBy(self.series_ids + [agg_window]).agg(*selectedCols)
summary_df = (
selected_df.select(*selected_df.columns)
.withColumn(self.ts_col, f.col("window").start)
.drop("window")
)

return TSDF(summary_df, self.ts_col, self.partitionCols)
return TSDF(summary_df, self.ts_col, self.series_ids)

def write(self, spark, tabName, optimizationCols=None):
tio.write(self, spark, tabName, optimizationCols)
@@ -1134,15 +1167,15 @@ def resample(

# Throw warning for user to validate that the expected number of output rows is valid.
if fill is True and perform_checks is True:
calculate_time_horizon(self.df, self.ts_col, freq, self.partitionCols)
calculate_time_horizon(self.df, self.ts_col, freq, self.series_ids)

enriched_df: DataFrame = rs.aggregate(
self, freq, func, metricCols, prefix, fill
)
return _ResampledTSDF(
enriched_df,
ts_col=self.ts_col,
partition_cols=self.partitionCols,
series_ids=self.series_ids,
freq=freq,
func=func,
)
@@ -1177,7 +1210,7 @@ def interpolate(
if ts_col is None:
ts_col = self.ts_col
if partition_cols is None:
partition_cols = self.partitionCols
partition_cols = self.series_ids
if target_cols is None:
prohibited_cols: List[str] = partition_cols + [ts_col]
summarizable_types = ["int", "bigint", "float", "double"]
@@ -1193,7 +1226,7 @@ def interpolate(
]

interpolate_service: Interpolation = Interpolation(is_resampled=False)
tsdf_input = TSDF(self.df, ts_col=ts_col, partition_cols=partition_cols)
tsdf_input = TSDF(self.df, ts_col=ts_col, series_ids=partition_cols)
interpolated_df: DataFrame = interpolate_service.interpolate(
tsdf_input,
ts_col,
@@ -1206,7 +1239,7 @@ def interpolate(
perform_checks,
)

return TSDF(interpolated_df, ts_col=ts_col, partition_cols=partition_cols)
return TSDF(interpolated_df, ts_col=ts_col, series_ids=partition_cols)

def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):

@@ -1223,21 +1256,21 @@ def calc_bars(tsdf, freq, func=None, metricCols=None, fill=None):
freq=freq, func="ceil", metricCols=metricCols, prefix="close", fill=fill
)

join_cols = resample_open.partitionCols + [resample_open.ts_col]
join_cols = resample_open.series_ids + [resample_open.ts_col]
bars = (
resample_open.df.join(resample_high.df, join_cols)
.join(resample_low.df, join_cols)
.join(resample_close.df, join_cols)
)
non_part_cols = set(set(bars.columns) - set(resample_open.partitionCols)) - set(
non_part_cols = set(set(bars.columns) - set(resample_open.series_ids)) - set(
[resample_open.ts_col]
)
sel_and_sort = (
resample_open.partitionCols + [resample_open.ts_col] + sorted(non_part_cols)
resample_open.series_ids + [resample_open.ts_col] + sorted(non_part_cols)
)
bars = bars.select(sel_and_sort)

return TSDF(bars, resample_open.ts_col, resample_open.partitionCols)
return TSDF(bars, resample_open.ts_col, resample_open.series_ids)

def fourier_transform(self, timestep, valueCol):
"""
@@ -1267,7 +1300,7 @@ def tempo_fourier_util(pdf):
valueCol = self.__validated_column(self.df, valueCol)
data = self.df
if self.sequence_col:
if self.partitionCols == []:
if self.series_ids == []:
data = data.withColumn("dummy_group", f.lit("dummy_val"))
data = (
data.select(
@@ -1288,7 +1321,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("dummy_group", "tdval", "tpoints")
else:
group_cols = self.partitionCols
group_cols = self.series_ids
data = (
data.select(
*group_cols, self.ts_col, self.sequence_col, f.col(valueCol)
@@ -1305,7 +1338,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("tdval", "tpoints")
else:
if self.partitionCols == []:
if self.series_ids == []:
data = data.withColumn("dummy_group", f.lit("dummy_val"))
data = (
data.select(f.col("dummy_group"), self.ts_col, f.col(valueCol))
@@ -1321,7 +1354,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("dummy_group", "tdval", "tpoints")
else:
group_cols = self.partitionCols
group_cols = self.series_ids
data = (
data.select(*group_cols, self.ts_col, f.col(valueCol))
.withColumn("tdval", f.col(valueCol))
@@ -1336,7 +1369,7 @@ def tempo_fourier_util(pdf):
)
result = result.drop("tdval", "tpoints")

return TSDF(result, self.ts_col, self.partitionCols, self.sequence_col)
return TSDF(result, self.ts_col, self.series_ids, self.sequence_col)

def extractStateIntervals(
self,
@@ -1447,7 +1480,7 @@ def state_comparison_fn(a, b):

# Find the start and end timestamp of the interval
result = (
data.groupBy(*self.partitionCols, "state_incrementer")
data.groupBy(*self.series_ids, "state_incrementer")
.agg(
f.min("previous_ts").alias("start_ts"),
f.max(self.ts_col).alias("end_ts"),
@@ -1463,12 +1496,12 @@ def __init__(
self,
df,
ts_col="event_ts",
partition_cols=None,
series_ids=None,
sequence_col=None,
freq=None,
func=None,
):
super(_ResampledTSDF, self).__init__(df, ts_col, partition_cols, sequence_col)
super(_ResampledTSDF, self).__init__(df, ts_col, series_ids, sequence_col)
self.__freq = freq
self.__func = func

@@ -1491,7 +1524,7 @@ def interpolate(

# Set defaults for target columns, timestamp column and partition columns when not provided
if target_cols is None:
prohibited_cols: List[str] = self.partitionCols + [self.ts_col]
prohibited_cols: List[str] = self.series_ids + [self.ts_col]
summarizable_types = ["int", "bigint", "float", "double"]

# get summarizable find summarizable columns
@@ -1505,13 +1538,11 @@ def interpolate(
]

interpolate_service: Interpolation = Interpolation(is_resampled=True)
tsdf_input = TSDF(
self.df, ts_col=self.ts_col, partition_cols=self.partitionCols
)
tsdf_input = TSDF(self.df, ts_col=self.ts_col, series_ids=self.series_ids)
interpolated_df = interpolate_service.interpolate(
tsdf=tsdf_input,
ts_col=self.ts_col,
partition_cols=self.partitionCols,
series_ids=self.series_ids,
target_cols=target_cols,
freq=self.__freq,
func=self.__func,
@@ -1520,6 +1551,4 @@ def interpolate(
perform_checks=perform_checks,
)

return TSDF(
interpolated_df, ts_col=self.ts_col, partition_cols=self.partitionCols
)
return TSDF(interpolated_df, ts_col=self.ts_col, series_ids=self.series_ids)
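
For context, a hedged sketch of how the reworked constructor in this diff would be called (the DataFrame and column names are invented for the example, and the exact behavior at this intermediate commit may differ):

```python
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

from tempo.tsdf import TSDF
from tempo.tsschema import TSSchema

spark = SparkSession.builder.getOrCreate()

# Toy time series DataFrame: one series ID column and one metric column.
df = spark.createDataFrame(
    [("AAPL", "2022-08-11 10:00:00", 167.5)],
    "symbol STRING, event_ts STRING, price DOUBLE",
).withColumn("event_ts", f.to_timestamp("event_ts"))

# Old call style: TSDF(df, ts_col="event_ts", partition_cols=["symbol"])
# New call style in this commit: series_ids replaces partition_cols ...
tsdf = TSDF(df, ts_col="event_ts", series_ids=["symbol"])

# ... or a pre-built TSSchema can be passed in directly.
schema = TSSchema.fromDFSchema(df.schema, ts_col="event_ts", series_ids=["symbol"])
tsdf_from_schema = TSDF(df, ts_schema=schema)
```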
99 changes: 99 additions & 0 deletions python/tempo/tsschema.py
@@ -0,0 +1,99 @@
from typing import Collection

from pyspark.sql.types import *


class TSIndex:
# Valid types for time index columns
__valid_ts_types = (
DateType(),
TimestampType(),
ByteType(),
ShortType(),
IntegerType(),
LongType(),
DecimalType(),
FloatType(),
DoubleType(),
)

def __init__(self, name: str, dataType: DataType) -> None:
if dataType not in self.__valid_ts_types:
raise TypeError(f"DataType {dataType} is not valid for a Timeseries Index")
self.name = name
self.dataType = dataType

@classmethod
def fromField(cls, ts_field: StructField) -> "TSIndex":
return cls(ts_field.name, ts_field.dataType)


class TSSchema:
"""
Schema type for a :class:`TSDF` class.
"""

# Valid types for metric columns
__metric_types = (
BooleanType(),
ByteType(),
ShortType(),
IntegerType(),
LongType(),
DecimalType(),
FloatType(),
DoubleType(),
)

def __init__(
self,
ts_idx: TSIndex,
series_ids: Collection[str] = None,
user_ts_col: str = None,
subsequence_col: str = None,
) -> None:
self.ts_idx = ts_idx
self.series_ids = list(series_ids)
self.user_ts_col = user_ts_col
self.subsequence_col = subsequence_col

@classmethod
def fromDFSchema(
cls, df_schema: StructType, ts_col: str, series_ids: Collection[str] = None
) -> "TSSchema":
# construct a TSIndex for the given ts_col
ts_idx = TSIndex.fromField(df_schema[ts_col])
return cls(ts_idx, series_ids)

@property
def ts_index(self) -> str:
return self.ts_idx.name

@property
def structural_columns(self) -> set[str]:
"""
Structural columns are those that define the structure of the :class:`TSDF`. This includes the timeseries column,
a timeseries index (if different), any subsequence column (if present), and the series ID columns.
:return: a set of column names corresponding the structural columns of a :class:`TSDF`
"""
struct_cols = {self.ts_index, self.user_ts_col, self.subsequence_col}.union(
self.series_ids
)
struct_cols.discard(None)
return struct_cols

def validate(self, df_schema: StructType) -> None:
pass

def find_observational_columns(self, df_schema: StructType) -> list[StructField]:
return [
col for col in df_schema.fields if col.name not in self.structural_columns
]

def find_metric_columns(self, df_schema: StructType) -> list[StructField]:
return [
col
for col in self.find_observational_columns(df_schema)
if col.dataType in self.__metric_types
]
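
To make the new metadata classes concrete, a short hedged sketch of how `TSSchema` classifies columns (the schema below is invented for the example; printed set order may vary):

```python
from pyspark.sql.types import (
    DoubleType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

from tempo.tsschema import TSSchema

# Toy schema: a timestamp index, one series ID, one metric column,
# and one non-metric observational column.
df_schema = StructType([
    StructField("event_ts", TimestampType()),
    StructField("symbol", StringType()),
    StructField("price", DoubleType()),
    StructField("exchange", StringType()),
])

ts_schema = TSSchema.fromDFSchema(df_schema, ts_col="event_ts", series_ids=["symbol"])

ts_schema.ts_index                                                  # 'event_ts'
ts_schema.structural_columns                                        # {'event_ts', 'symbol'}
[c.name for c in ts_schema.find_observational_columns(df_schema)]   # ['price', 'exchange']
[c.name for c in ts_schema.find_metric_columns(df_schema)]          # ['price']
```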