Skip to content

Commit

Permalink
style: format with ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasberbuer committed Jul 22, 2024
1 parent 631bf90 commit 9e15c36
Show file tree
Hide file tree
Showing 18 changed files with 201 additions and 156 deletions.
10 changes: 5 additions & 5 deletions examples/ex1_read_pridb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
==========
"""

#%%
# %%
from pathlib import Path

import matplotlib.pyplot as plt
Expand All @@ -12,7 +12,7 @@
HERE = Path(__file__).parent if "__file__" in locals() else Path.cwd()
PRIDB = HERE / "steel_plate" / "sample.pridb"

#%%
# %%
# Open pridb
# ----------
pridb = vae.io.PriDatabase(PRIDB)
Expand All @@ -21,7 +21,7 @@
print("Number of rows in data table (ae_data): ", pridb.rows())
print("Set of all channels: ", pridb.channel())

#%%
# %%
# Read hits to Pandas DataFrame
# -----------------------------
df_hits = pridb.read_hits()
Expand All @@ -39,13 +39,13 @@
plt.tight_layout()
plt.show()

#%%
# %%
# Read markers
# ------------
df_markers = pridb.read_markers()
print(df_markers)

#%%
# %%
# Read parametric data
# --------------------
df_parametric = pridb.read_parametric()
Expand Down
19 changes: 11 additions & 8 deletions examples/ex3_timepicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SAMPLES = 2000


#%%
# %%
# Read waveform from tradb
# ------------------------
tradb = vae.io.TraDatabase(TRADB)
Expand All @@ -31,7 +31,8 @@
t *= 1e6 # convert to µs
y *= 1e3 # convert to mV

#%%

# %%
# Prepare plotting with time-picker results
# -----------------------------------------
def plot(t_wave, y_wave, y_picker, index_picker, name_picker):
Expand All @@ -49,38 +50,39 @@ def plot(t_wave, y_wave, y_picker, index_picker, name_picker):
plt.axvline(t_wave[index_picker], color="k", linestyle=":")
plt.show()

#%%

# %%
# Hinkley Criterion
# -----------------
hc_arr, hc_index = vae.timepicker.hinkley(y, alpha=5)
plot(t, y, hc_arr, hc_index, "Hinkley Criterion")

#%%
# %%
# The negative trend correlates to the chosen alpha value
# and can influence the results strongly.
# Results with **alpha = 50** (less negative trend):
hc_arr, hc_index = vae.timepicker.hinkley(y, alpha=50)
plot(t, y, hc_arr, hc_index, "Hinkley Criterion")

#%%
# %%
# Akaike Information Criterion (AIC)
# ----------------------------------
aic_arr, aic_index = vae.timepicker.aic(y)
plot(t, y, aic_arr, aic_index, "Akaike Information Criterion")

#%%
# %%
# Energy Ratio
# ------------
er_arr, er_index = vae.timepicker.energy_ratio(y)
plot(t, y, er_arr, er_index, "Energy Ratio")

#%%
# %%
# Modified Energy Ratio
# ---------------------
mer_arr, mer_index = vae.timepicker.modified_energy_ratio(y)
plot(t, y, mer_arr, mer_index, "Modified Energy Ratio")

#%%
# %%
# Performance comparison
# ----------------------
# All timepicker implementations use Numba for just-in-time (JIT) compilation.
Expand All @@ -95,6 +97,7 @@ def timeit(func, loops=100):
func()
return 1e6 * (time.perf_counter() - time_start) / loops # elapsed time in µs


timer_results = {
"Hinkley": timeit(lambda: vae.timepicker.hinkley(y, 5)),
"AIC": timeit(lambda: vae.timepicker.aic(y)),
Expand Down
26 changes: 14 additions & 12 deletions examples/ex4_timepicker_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,21 @@
TRFDB = HERE / "steel_plate" / "sample.trfdb"
TRFDB_TMP = Path(gettempdir()) / "sample.trfdb"

#%%
# %%
# Open tradb (readonly) and trfdb (readwrite)
# -------------------------------------------
copyfile(TRFDB, TRFDB_TMP) # copy trfdb, so we don't overwrite it

tradb = vae.io.TraDatabase(TRADB)
trfdb = vae.io.TrfDatabase(TRFDB_TMP, mode="rw") # allow writing

#%%
# %%
# Read current trfdb
# ------------------
print(trfdb.read())

#%%

# %%
# Compute arrival time offsets with different timepickers
# -------------------------------------------------------
# To improve localisation, time of arrival estimates
Expand All @@ -50,7 +51,8 @@ def dt_from_timepicker(timepicker_func, tra: vae.io.TraRecord):
# Compute offset in µs
return (index_timepicker - index_ref) * 1e6 / tra.samplerate

#%%

# %%
# Transient data is streamed from the database row by row using `vallenae.io.TraDatabase.iread`.
# Only one transient data set is loaded into memory at a time.
# That makes the streaming interface ideal for batch processing.
Expand All @@ -65,25 +67,25 @@ def dt_from_timepicker(timepicker_func, tra: vae.io.TraRecord):
"ATO_AIC": dt_from_timepicker(vae.timepicker.aic, tra),
"ATO_ER": dt_from_timepicker(vae.timepicker.energy_ratio, tra),
"ATO_MER": dt_from_timepicker(vae.timepicker.modified_energy_ratio, tra),
}
},
)
# Save results to trfdb
trfdb.write(feature_set)

#%%
# %%
# Read results from trfdb
# -----------------------
print(trfdb.read().filter(regex="ATO"))

#%%
# %%
# Plot results
# ------------
ax = trfdb.read()[["ATO_Hinkley", "ATO_AIC", "ATO_ER", "ATO_MER"]].plot.barh()
ax.invert_yaxis()
ax.set_xlabel("Arrival time offset [µs]")
plt.show()

#%%
# %%
# Plot waveforms and arrival times
# --------------------------------
_, axes = plt.subplots(4, 1, tight_layout=True, figsize=(8, 8))
Expand All @@ -108,7 +110,7 @@ def dt_from_timepicker(timepicker_func, tra: vae.io.TraRecord):
axes[0].legend(["Waveform", "Hinkley", "AIC", "ER", "MER"])
plt.show()

#%%
# %%
# Use results in VisualAE
# -----------------------
# The computed arrival time offsets can be directly used in VisualAE.
Expand All @@ -117,11 +119,11 @@ def dt_from_timepicker(timepicker_func, tra: vae.io.TraRecord):
# Field infos can be retrieved with `vallenae.io.TrfDatabase.fieldinfo`:
print(trfdb.fieldinfo())

#%%
# %%
# Show results as table:
print(pd.DataFrame(trfdb.fieldinfo()))

#%%
# %%
# Write units to trfdb
# ~~~~~~~~~~~~~~~~~~~~
# Field infos can be written with `vallenae.io.TrfDatabase.write_fieldinfo`:
Expand All @@ -134,7 +136,7 @@ def dt_from_timepicker(timepicker_func, tra: vae.io.TraRecord):
print(pd.DataFrame(trfdb.fieldinfo()).filter(regex="ATO"))


#%%
# %%
# Load results in VisualAE
# ~~~~~~~~~~~~~~~~~~~~~~~~
# Arrival time offsets can be specified in the settings of `Location Processors` - `Channel Positions` - `Arrival Time Offset`.
Expand Down
3 changes: 2 additions & 1 deletion examples/ex5_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def get_channel_positions(setup_file: str) -> Dict[int, Tuple[float, float]]:
raise RuntimeError("Can not retrieve channel positions from %s", setup_file)
return {
int(elem.get("Chan")): (float(elem.get("X")), float(elem.get("Y"))) # type: ignore
for elem in nodes if elem is not None
for elem in nodes
if elem is not None
}


Expand Down
17 changes: 10 additions & 7 deletions examples/ex6_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
HERE = Path(__file__).parent if "__file__" in locals() else Path.cwd()
TRADB = HERE / "steel_plate" / "sample_plain.tradb"

#%%

# %%
# Prepare streaming reads
# -----------------------
# Our sample tradb only contains four data sets. That is not enough data for demonstrating batch processing.
Expand All @@ -33,15 +34,16 @@ def tra_generator(loops: int = 1000) -> Iterable[vae.io.TraRecord]:
break
yield tra

#%%

# %%
# Define feature extraction function
# ----------------------------------
# A simple function from the module `__feature_extraction` is applied to all data sets and returns computed features.
# The function is defined in another module to work with `multiprocessing.Pool`: https://bugs.python.org/issue25053
from __feature_extraction import feature_extraction # noqa


#%%
# %%
# Compute with single thread/core
# -------------------------------
# .. note::
Expand All @@ -53,6 +55,7 @@ def tra_generator(loops: int = 1000) -> Iterable[vae.io.TraRecord]:
def time_elapsed_ms(t0):
    """Return the time elapsed since the `time.perf_counter()` reading `t0` in milliseconds."""
    elapsed_seconds = time.perf_counter() - t0
    return elapsed_seconds * 1000.0


if __name__ == "__main__": # guard needed for multiprocessing on Windows
time_start = time.perf_counter()
for tra in tra_generator():
Expand All @@ -62,13 +65,13 @@ def time_elapsed_ms(t0):

print(f"Time single thread: {time_single_thread:.2f} ms")

#%%
# %%
# Compute with multiple processes/cores
# -------------------------------------
# First get number of available cores in your machine:
print(f"Available CPU cores: {os.cpu_count()}")

#%%
# %%
# But how can we utilize those cores? The common answer for most programming languages is multithreading.
# Threads run in the same process and heap, so data can be shared between them (with care).
# Sadly, Python uses a global interpreter lock (GIL) that locks heap memory, because Python objects are not thread-safe.
Expand All @@ -78,7 +81,7 @@ def time_elapsed_ms(t0):
# Multiprocessing will introduce overhead for interprocess communication and data serialization/deserialization.
# To reduce the overhead, data is sent in bigger chunks.

#%%
# %%
# Run computation on 4 cores with chunks of 128 data sets and get the time / speedup:
if __name__ == "__main__": # guard needed for multiprocessing on Windows
with multiprocessing.Pool(4) as pool:
Expand All @@ -90,7 +93,7 @@ def time_elapsed_ms(t0):
print(f"Time multiprocessing: {time_multiprocessing:.2f} ms")
print(f"Speedup: {(time_single_thread / time_multiprocessing):.2f}")

#%%
# %%
# Variation of the chunksize
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# The following results show how the chunksize impacts the overall performance.
Expand Down
27 changes: 16 additions & 11 deletions examples/ex7_custom_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
# TRFDB = HERE / "bearing" / "bearing.trfdb"
TRFDB_TMP = Path(gettempdir()) / "bearing_custom.trfdb" # use a temp file for demo

#%%

# %%
# Custom feature extraction algorithms
# ------------------------------------
def rms(data: np.ndarray) -> float:
    """
    Root mean square (RMS).

    Args:
        data: Array of samples

    Returns:
        Square root of the mean of the squared samples
    """
    samples = np.asarray(data)
    # Squaring integer (or boolean) samples in their own dtype can wrap around,
    # e.g. 200**2 does not fit into int16 -> promote to float before squaring.
    # Floating-point input is left untouched, so its dtype and result are preserved.
    if samples.dtype.kind in "iub":
        samples = samples.astype(np.float64)
    return np.sqrt(np.mean(samples**2))


def crest_factor(data: np.ndarray) -> float:
Expand All @@ -47,27 +48,31 @@ def spectral_peak_frequency(spectrum_: np.ndarray, samplerate: int) -> float:
Returns:
Peak frequency in Hz
"""

def bin_to_hz(samplerate: int, samples: int, index: int):
return 0.5 * samplerate * index / (samples - 1)

peak_index = np.argmax(spectrum_)
return bin_to_hz(samplerate, len(spectrum_), peak_index)

#%%

# %%
# Open tradb and trfdb
# --------------------
tradb = vae.io.TraDatabase(TRADB)
trfdb = vae.io.TrfDatabase(TRFDB_TMP, mode="rwc")

#%%

# %%
# Helper function to notify VisualAE that the transient feature database is active/closed
def set_file_status(trfdb_: vae.io.TrfDatabase, status: int):
    """
    Notify VisualAE that trfdb is active/closed.

    Updates the ``FileStatus`` entry of the ``trf_globalinfo`` table using the
    connection returned by ``trfdb_.connection()``.

    Args:
        trfdb_: Transient feature database to update
        status: New value of the ``FileStatus`` entry (e.g. 0 = closed)
    """
    # Bind the value as a query parameter instead of formatting it into the SQL text
    trfdb_.connection().execute(
        "UPDATE trf_globalinfo SET Value = ? WHERE Key == 'FileStatus'",
        (status,),
    )

#%%

# %%
# Read tra records, compute features and save to trfdb
# ----------------------------------------------------
# The `vallenae.io.TraDatabase.listen` method will read the tradb row by row and can be used during
Expand All @@ -83,27 +88,27 @@ def set_file_status(trfdb_: vae.io.TrfDatabase, status: int):
"RMS": rms(tra.data),
"CrestFactor": crest_factor(tra.data),
"SpectralPeakFreq": spectral_peak_frequency(spectrum, tra.samplerate),
}
},
)
trfdb.write(features)

set_file_status(trfdb, 0) # 0 = closed

#%%
# %%
# Write field infos to trfdb
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# Field infos can be written with `vallenae.io.TrfDatabase.write_fieldinfo`:
trfdb.write_fieldinfo("RMS", {"Unit": "[V]", "LongName": "Root mean square"})
trfdb.write_fieldinfo("CrestFactor", {"Unit": "[]", "LongName": "Crest factor"})
trfdb.write_fieldinfo("SpectralPeakFreq", {"Unit": "[Hz]", "LongName": "Spectral peak frequency"})

#%%
# %%
# Read results from trfdb
# -----------------------
df_trfdb = trfdb.read()
print(df_trfdb)

#%%
# %%
# Plot AE features and custom features
# ------------------------------------
# Read pridb and join it with trfdb:
Expand All @@ -113,7 +118,7 @@ def set_file_status(trfdb_: vae.io.TrfDatabase, status: int):
df_combined = df_pridb.join(df_trfdb, on="trai", how="left")
print(df_combined)

#%%
# %%
# Plot joined features from pridb and trfdb
features = [
# from pridb
Expand All @@ -138,7 +143,7 @@ def set_file_status(trfdb_: vae.io.TrfDatabase, status: int):
plt.tight_layout()
plt.show()

#%%
# %%
# Display custom features in VisualAE
# -----------------------------------
#
Expand Down
Loading

0 comments on commit 9e15c36

Please sign in to comment.