From 5af8dc2fe7e7f84660e9e481897de7540a684c7f Mon Sep 17 00:00:00 2001 From: David Wallace Date: Sun, 10 Mar 2024 21:39:23 +0100 Subject: [PATCH] tests: add test for base peak --- src/raman_fitting/config/base_settings.py | 8 +++++ .../models/deconvolution/base_peak.py | 34 ------------------- tests/conftest.py | 5 +++ tests/models/test_base_peak.py | 34 +++++++++++++++++++ 4 files changed, 47 insertions(+), 34 deletions(-) create mode 100644 tests/models/test_base_peak.py diff --git a/src/raman_fitting/config/base_settings.py b/src/raman_fitting/config/base_settings.py index 424fb6e..f8b6d6b 100644 --- a/src/raman_fitting/config/base_settings.py +++ b/src/raman_fitting/config/base_settings.py @@ -16,6 +16,7 @@ ) from .default_models import load_config_from_toml_files from .path_settings import create_default_package_dir_or_ask, InternalPathSettings +from types import MappingProxyType def get_default_models_and_peaks_from_definitions(): @@ -36,5 +37,12 @@ class Settings(BaseSettings): init_var=False, validate_default=False, ) + default_definitions: MappingProxyType | None = Field( + default_factory=load_config_from_toml_files, + alias="my_default_definitions", + init_var=False, + validate_default=False, + ) + destination_dir: Path = Field(default_factory=create_default_package_dir_or_ask) internal_paths: InternalPathSettings = Field(default_factory=InternalPathSettings) diff --git a/src/raman_fitting/models/deconvolution/base_peak.py b/src/raman_fitting/models/deconvolution/base_peak.py index 4d05344..4649b34 100644 --- a/src/raman_fitting/models/deconvolution/base_peak.py +++ b/src/raman_fitting/models/deconvolution/base_peak.py @@ -217,37 +217,3 @@ def get_peaks_from_peak_definitions( for peak_name, peak_def in peak_type_defs.items(): peak_models[peak_name] = BasePeak(**peak_def) return peak_models - - -def _main(): - model_definitions = load_config_from_toml_files() - print(model_definitions["first_order"]["models"]) - peaks = {} - peak_items = { - **model_definitions["first_order"]["peaks"], - **model_definitions["second_order"]["peaks"], - }.items() - for k, v in peak_items: - peaks.update({k: BasePeak(**v)}) - - peak_d = BasePeak(**model_definitions["first_order"]["peaks"]["D"]) - print(peak_d) - model_items = { - **model_definitions["first_order"]["models"], - **model_definitions["second_order"]["models"], - }.items() - models = {} - for model_name, model_comp in model_items: - print(k, v) - comps = model_comp.split("+") - peak_comps = [peaks[i] for i in comps] - lmfit_comp_model = sum( - map(lambda x: x.lmfit_model, peak_comps), peak_comps.pop().lmfit_model - ) - models[model_name] = lmfit_comp_model - print(lmfit_comp_model) - # breakpoint() - - -if __name__ == "__main__": - _main() diff --git a/tests/conftest.py b/tests/conftest.py index cef9296..9f95487 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,11 @@ def example_files(internal_paths): return example_files +@pytest.fixture(autouse=True) +def default_definitions(internal_paths): + return settings.default_definitions + + @pytest.fixture(autouse=True) def default_models(internal_paths): return settings.default_models diff --git a/tests/models/test_base_peak.py b/tests/models/test_base_peak.py new file mode 100644 index 0000000..4be455c --- /dev/null +++ b/tests/models/test_base_peak.py @@ -0,0 +1,34 @@ +from raman_fitting.models.deconvolution.base_peak import BasePeak + + +def test_initialize_base_peaks( + default_definitions, default_models_first_order, default_models_second_order +): + peaks = {} + + peak_items = { + **default_definitions["first_order"]["peaks"], + **default_definitions["second_order"]["peaks"], + }.items() + for k, v in peak_items: + peaks.update({k: BasePeak(**v)}) + + peak_d = BasePeak(**default_definitions["first_order"]["peaks"]["D"]) + assert ( + peak_d.peak_name + == default_definitions["first_order"]["peaks"]["D"]["peak_name"] + ) + assert ( + peak_d.peak_type + == default_definitions["first_order"]["peaks"]["D"]["peak_type"] + ) + assert ( + peak_d.lmfit_model.components[0].prefix + == default_definitions["first_order"]["peaks"]["D"]["peak_name"] + "_" + ) + assert ( + peak_d.param_hints["center"].value + == default_definitions["first_order"]["peaks"]["D"]["param_hints"]["center"][ + "value" + ] + )