From e7975e869900507dc6d226e3b9609092b9087b64 Mon Sep 17 00:00:00 2001 From: Lukas Berbuer <36054362+lukasberbuer@users.noreply.github.com> Date: Mon, 16 Sep 2024 19:28:47 +0200 Subject: [PATCH] feat!: migrate from named tuples to dataclasses (#31) --- src/vallenae/io/_dataframe.py | 10 ++++++---- src/vallenae/io/datatypes.py | 29 ++++++++++++++++++----------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/vallenae/io/_dataframe.py b/src/vallenae/io/_dataframe.py index 07115df..a149267 100644 --- a/src/vallenae/io/_dataframe.py +++ b/src/vallenae/io/_dataframe.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import pandas as pd from tqdm import tqdm @@ -23,7 +25,7 @@ def _convert_to_nullable_types(df: pd.DataFrame): def iter_to_dataframe( - iterable: SizedIterable[tuple], + iterable: SizedIterable[Any], show_progress: bool = True, desc: str = "", index_column: str | None = None, @@ -32,10 +34,10 @@ def iter_to_dataframe( Helper function to save iterator results in Pandas DataFrame. Args: - iterable: Iterable generating `NamedTuple`s - show_progress: Show progress bar. Default: `True` + iterable: Iterable generating dataclasses + show_progress: Show progress bar desc: Description shown left to the progress bar - index_column: Set column as index. Default: `None` + index_column: Set column as index Returns: Pandas DataFrame """ diff --git a/src/vallenae/io/datatypes.py b/src/vallenae/io/datatypes.py index 47944e0..78b756d 100644 --- a/src/vallenae/io/datatypes.py +++ b/src/vallenae/io/datatypes.py @@ -2,8 +2,9 @@ from __future__ import annotations +from dataclasses import dataclass, field from enum import IntEnum, IntFlag -from typing import Any, NamedTuple +from typing import Any import numpy as np @@ -66,7 +67,8 @@ class StatusFlags(IntFlag): # fmt: on -class HitRecord(NamedTuple): +@dataclass +class HitRecord: """ Hit record in pridb (`SetType.HIT`). """ @@ -81,7 +83,7 @@ class HitRecord(NamedTuple): rms: float #: RMS of the noise before the hit in volts # optional for creating: set_id: int | None = None #: Unique identifier for data set in pridb - status: HitFlags = HitFlags(0) #: Status flags + status: HitFlags = field(default=HitFlags(0)) #: Status flags threshold: float | None = None #: Threshold amplitude in volts rise_time: float | None = None #: Rise time in seconds signal_strength: float | None = None #: Signal strength in nVs (1e-9 Vs) @@ -123,7 +125,8 @@ def from_sql(cls, row: dict[str, Any]) -> "HitRecord": ) -class MarkerRecord(NamedTuple): +@dataclass +class MarkerRecord: """ Marker record in pridb (`SetType.LABEL`, `SetType.DATETIME`, `SetType.SECTION`). """ @@ -152,7 +155,8 @@ def from_sql(cls, row: dict[str, Any]) -> "MarkerRecord": ) -class StatusRecord(NamedTuple): +@dataclass +class StatusRecord: """ Status data record in pridb (`SetType.STATUS`). """ @@ -164,7 +168,7 @@ class StatusRecord(NamedTuple): rms: float #: RMS in volts # optional for creating: set_id: int | None = None #: Unique identifier for data set in pridb - status: StatusFlags = StatusFlags(0) #: Status flags + status: StatusFlags = field(default=StatusFlags(0)) #: Status flags threshold: float | None = None #: Threshold amplitude in volts signal_strength: float | None = None #: Signal strength in nVs (1e-9 Vs) @@ -189,7 +193,8 @@ def from_sql(cls, row: dict[str, Any]) -> "StatusRecord": ) -class ParametricRecord(NamedTuple): +@dataclass +class ParametricRecord: """ Parametric data record in pridb (`SetType.PARAMETRIC`). """ @@ -198,7 +203,7 @@ class ParametricRecord(NamedTuple): param_id: int #: Parameter ID of table ae_params for ADC value conversion # optional for creating: set_id: int | None = None #: Unique identifier for data set in pridb - status: StatusFlags = StatusFlags(0) #: Status flags + status: StatusFlags = field(default=StatusFlags(0)) #: Status flags pctd: int | None = None #: Digital counter value pcta: int | None = None #: Analog hysteresis counter pa0: int | None = None #: Amplitude of parametric input 0 in volts @@ -236,7 +241,8 @@ def from_sql(cls, row: dict[str, Any]) -> "ParametricRecord": ) -class TraRecord(NamedTuple): +@dataclass +class TraRecord: """Transient data record in tradb.""" time: float #: Time in seconds @@ -248,7 +254,7 @@ class TraRecord(NamedTuple): samples: int #: Number of samples data: np.ndarray #: Transient signal in volts or ADC values if `raw` = `True` # optional for writing - status: HitFlags = HitFlags(0) #: Status flags + status: HitFlags = field(default=HitFlags(0)) #: Status flags trai: int | None = None #: Transient recorder index (foreign key between pridb and tradb) rms: float | None = None #: RMS of the noise before the hit # optional @@ -278,7 +284,8 @@ def from_sql(cls, row: dict[str, Any], *, raw: bool = False) -> "TraRecord": ) -class FeatureRecord(NamedTuple): +@dataclass +class FeatureRecord: """ Transient feature record in trfdb. """