Skip to content

Commit

Permalink
feat!: migrate from named tuples to dataclasses (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasberbuer authored Sep 16, 2024
1 parent 04d5034 commit e7975e8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
10 changes: 6 additions & 4 deletions src/vallenae/io/_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import pandas as pd
from tqdm import tqdm

Expand All @@ -23,7 +25,7 @@ def _convert_to_nullable_types(df: pd.DataFrame):


def iter_to_dataframe(
iterable: SizedIterable[tuple],
iterable: SizedIterable[Any],
show_progress: bool = True,
desc: str = "",
index_column: str | None = None,
Expand All @@ -32,10 +34,10 @@ def iter_to_dataframe(
Helper function to save iterator results in Pandas DataFrame.
Args:
iterable: Iterable generating `NamedTuple`s
show_progress: Show progress bar. Default: `True`
iterable: Iterable generating dataclasses
show_progress: Show progress bar
desc: Description shown left to the progress bar
index_column: Set column as index. Default: `None`
index_column: Set column as index
Returns:
Pandas DataFrame
"""
Expand Down
29 changes: 18 additions & 11 deletions src/vallenae/io/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations

from dataclasses import dataclass, field
from enum import IntEnum, IntFlag
from typing import Any, NamedTuple
from typing import Any

import numpy as np

Expand Down Expand Up @@ -66,7 +67,8 @@ class StatusFlags(IntFlag):
# fmt: on


class HitRecord(NamedTuple):
@dataclass
class HitRecord:
"""
Hit record in pridb (`SetType.HIT`).
"""
Expand All @@ -81,7 +83,7 @@ class HitRecord(NamedTuple):
rms: float #: RMS of the noise before the hit in volts
# optional for creating:
set_id: int | None = None #: Unique identifier for data set in pridb
status: HitFlags = HitFlags(0) #: Status flags
status: HitFlags = field(default=HitFlags(0)) #: Status flags
threshold: float | None = None #: Threshold amplitude in volts
rise_time: float | None = None #: Rise time in seconds
signal_strength: float | None = None #: Signal strength in nVs (1e-9 Vs)
Expand Down Expand Up @@ -123,7 +125,8 @@ def from_sql(cls, row: dict[str, Any]) -> "HitRecord":
)


class MarkerRecord(NamedTuple):
@dataclass
class MarkerRecord:
"""
Marker record in pridb (`SetType.LABEL`, `SetType.DATETIME`, `SetType.SECTION`).
"""
Expand Down Expand Up @@ -152,7 +155,8 @@ def from_sql(cls, row: dict[str, Any]) -> "MarkerRecord":
)


class StatusRecord(NamedTuple):
@dataclass
class StatusRecord:
"""
Status data record in pridb (`SetType.STATUS`).
"""
Expand All @@ -164,7 +168,7 @@ class StatusRecord(NamedTuple):
rms: float #: RMS in volts
# optional for creating:
set_id: int | None = None #: Unique identifier for data set in pridb
status: StatusFlags = StatusFlags(0) #: Status flags
status: StatusFlags = field(default=StatusFlags(0)) #: Status flags
threshold: float | None = None #: Threshold amplitude in volts
signal_strength: float | None = None #: Signal strength in nVs (1e-9 Vs)

Expand All @@ -189,7 +193,8 @@ def from_sql(cls, row: dict[str, Any]) -> "StatusRecord":
)


class ParametricRecord(NamedTuple):
@dataclass
class ParametricRecord:
"""
Parametric data record in pridb (`SetType.PARAMETRIC`).
"""
Expand All @@ -198,7 +203,7 @@ class ParametricRecord(NamedTuple):
param_id: int #: Parameter ID of table ae_params for ADC value conversion
# optional for creating:
set_id: int | None = None #: Unique identifier for data set in pridb
status: StatusFlags = StatusFlags(0) #: Status flags
status: StatusFlags = field(default=StatusFlags(0)) #: Status flags
pctd: int | None = None #: Digital counter value
pcta: int | None = None #: Analog hysteresis counter
pa0: int | None = None #: Amplitude of parametric input 0 in volts
Expand Down Expand Up @@ -236,7 +241,8 @@ def from_sql(cls, row: dict[str, Any]) -> "ParametricRecord":
)


class TraRecord(NamedTuple):
@dataclass
class TraRecord:
"""Transient data record in tradb."""

time: float #: Time in seconds
Expand All @@ -248,7 +254,7 @@ class TraRecord(NamedTuple):
samples: int #: Number of samples
data: np.ndarray #: Transient signal in volts or ADC values if `raw` = `True`
# optional for writing
status: HitFlags = HitFlags(0) #: Status flags
status: HitFlags = field(default=HitFlags(0)) #: Status flags
trai: int | None = None #: Transient recorder index (foreign key between pridb and tradb)
rms: float | None = None #: RMS of the noise before the hit
# optional
Expand Down Expand Up @@ -278,7 +284,8 @@ def from_sql(cls, row: dict[str, Any], *, raw: bool = False) -> "TraRecord":
)


class FeatureRecord(NamedTuple):
@dataclass
class FeatureRecord:
"""
Transient feature record in trfdb.
"""
Expand Down

0 comments on commit e7975e8

Please sign in to comment.