From ca58d4ef77f32036bd8fa7faa12bc3d32ccdba83 Mon Sep 17 00:00:00 2001 From: Jonathan Schilling Date: Thu, 6 Feb 2025 15:31:33 +0100 Subject: [PATCH] Add a method to load a VmecWout object from a NetCDF file. --- src/vmecpp/__init__.py | 26 +++++++++++++++++++++++++- tests/test_init.py | 6 +++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/vmecpp/__init__.py b/src/vmecpp/__init__.py index 595f889e..db52934a 100644 --- a/src/vmecpp/__init__.py +++ b/src/vmecpp/__init__.py @@ -830,7 +830,31 @@ def _to_cpp_wout(self) -> _vmecpp.WOutFileContents: return cpp_wout - # TODO(eguiraud): implement from_wout_file + @staticmethod + def from_wout_file(wout_filename: str | Path) -> VmecWOut: + """Load wout contents in NetCDF format. + + This is the format used by Fortran VMEC implementations and the one expected by + SIMSOPT. + """ + with netCDF4.Dataset(wout_filename, "r") as fnc: + fnc.set_auto_mask(False) + attrs = {} + for key in fnc.variables: + if key.endswith("__logical__"): + attrs[key[:-11]] = fnc[key][()] != 0 + elif key == "volume_p": + attrs["volume"] = fnc[key][()] + elif key in ["xm", "xn", "xm_nyq", "xn_nyq"]: + attrs[key] = np.array(fnc[key][()], dtype=int) + elif key in ["pmass_type", "piota_type", "pcurr_type", "mgrid_file"]: + attrs[key] = fnc[key][()].tobytes().decode("ascii") + else: + attrs[key] = fnc[key][()] + if "lmns_full" not in fnc.variables: + attrs["lmns_full"] = None + return VmecWOut(**attrs) + raise RuntimeError("Failed to load NetCDF wout file " + str(wout_filename)) class Threed1Volumetrics(pydantic.BaseModel): diff --git a/tests/test_init.py b/tests/test_init.py index b44fe1f9..a49109f6 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -88,12 +88,16 @@ def cma_output() -> vmecpp.VmecOutput: return vmec_output -def test_vmecwout_save(cma_output): +def test_vmecwout_io(cma_output): with tempfile.NamedTemporaryFile() as tmp_file: cma_output.wout.save(tmp_file.name) assert Path(tmp_file.name).exists() + # check that from_wout_file can load the file as well + loaded_wout = vmecpp.VmecWOut.from_wout_file(tmp_file.name) + assert loaded_wout is not None + test_dataset = netCDF4.Dataset(tmp_file.name, "r") expected_dataset = netCDF4.Dataset(TEST_DATA_DIR / "wout_cma.nc", "r")