Skip to content

Commit

Permalink
Simplify path to saving to H/ZSPY from EBSD.save()
Browse files Browse the repository at this point in the history
Signed-off-by: Håkon Wiik Ånes <[email protected]>
  • Loading branch information
hakonanes committed Oct 28, 2024
1 parent b4fb343 commit 099349e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 23 deletions.
13 changes: 3 additions & 10 deletions src/kikuchipy/io/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import kikuchipy.signals

PLUGINS: list = []
write_extensions = []
WRITE_EXTENSIONS = []
specification_paths = list(Path(__file__).parent.rglob("specification.yaml"))
for path in specification_paths:
with open(path) as file:
Expand All @@ -40,7 +40,7 @@
PLUGINS.append(spec)
if spec["writes"]:
for ext in spec["file_extensions"]:
write_extensions.append(ext)
WRITE_EXTENSIONS.append(ext)


if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -399,17 +399,10 @@ def _save(
writer = plugin
break

if writer is None and ext.lower() in ["hspy", "zspy"]:
try:
super(type(signal), signal).save(filename, overwrite=overwrite, **kwargs)
return
except () as e:
raise ValueError("Attempt to write with HyperSpy failed:") from e

if writer is None:
raise ValueError(
f"{ext!r} does not correspond to any supported format. Supported file "
f"extensions are: {write_extensions!r}"
f"extensions are: {WRITE_EXTENSIONS!r}"
)
else:
sig_dim = signal.axes_manager.signal_dimension
Expand Down
23 changes: 14 additions & 9 deletions src/kikuchipy/signals/ebsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2604,8 +2604,7 @@ def save(
filename
If not given and 'tmp_parameters.filename' and
'tmp_parameters.folder' are defined in signal metadata, the
filename and path will be taken from there. Alternatively,
the extension can be passed to *extension*.
filename and path will be taken from there.
overwrite
If not given and the file exists, it will query the user. If
True (False) it (does not) overwrite the file if it exists.
Expand Down Expand Up @@ -2637,19 +2636,25 @@ def save(
This method is a modified version of HyperSpy's function
:meth:`hyperspy.signal.BaseSignal.save`.
"""
if filename is None:
if filename is not None:
fname = Path(filename)
else:
tmp_params = self.tmp_parameters
if tmp_params.has_item("filename") and tmp_params.has_item("folder"):
filename = os.path.join(tmp_params.folder, tmp_params.filename)
fname = Path(tmp_params.folder) / tmp_params.filename
extension = tmp_params.extension if not extension else extension
elif self.metadata.has_item("General.original_filename"):
filename = self.metadata.General.original_filename
fname = Path(self.metadata.General.original_filename)
else:
raise ValueError("Filename not defined")
raise ValueError("filename not given")
if extension is not None:
basename, _ = os.path.splitext(filename)
filename = basename + "." + extension
_save(filename, self, overwrite=overwrite, **kwargs)
ext = extension
else:
ext = fname.suffix[1:]
if ext.lower() in ["hspy", "zspy"]:
super().save(filename, overwrite=overwrite, **kwargs)
return
_save(fname, self, overwrite=overwrite, **kwargs)

def get_decomposition_model(
self,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import os

import hyperspy.api as hs
import numpy as np
import pytest

Expand Down Expand Up @@ -120,3 +121,12 @@ def test_save_to_existing_file(self, save_path_hdf5, kikuchipy_h5ebsd_path):
s.save(save_path_hdf5, scan_number=2, overwrite=False, add_scan=False)
with pytest.raises(OSError, match="Scan 'Scan 2' is not among the"):
_ = kp.load(save_path_hdf5, scan_group_names="Scan 2")


class TestHSPYWrite:
def test_hspy_write(self, tmpdir, dummy_signal):
file_path = str(tmpdir / "test.hspy")
s = dummy_signal.as_lazy()
s.save(file_path)
s2 = hs.load(file_path, signal_type="EBSD")
assert np.allclose(dummy_signal.data, s2.data)
7 changes: 3 additions & 4 deletions tests/test_io/test_kikuchipy_h5ebsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import dask.array as da
import h5py
from hyperspy.api import load as hs_load
import hyperspy.api as hs
import numpy as np
from orix.quaternion import Rotation
import pytest
Expand Down Expand Up @@ -174,8 +174,7 @@ def test_load_save_hyperspy_cycle(self, tmp_path, kikuchipy_h5ebsd_path):
s.save(file)

# Reload data and use HyperSpy's set_signal_type function
s_reload = hs_load(file)
s_reload.set_signal_type("EBSD")
s_reload = hs.load(file, signal_type="EBSD")

# Check signal type, patterns and learning results
assert isinstance(s_reload, kp.signals.EBSD)
Expand Down Expand Up @@ -255,7 +254,7 @@ def test_save_fresh(self, save_path_hdf5, tmp_path):

# Test writing of signal to file when no file name is passed to save()
del s.tmp_parameters.filename
with pytest.raises(ValueError, match="Filename not defined"):
with pytest.raises(ValueError, match="filename not given"):
s.save(overwrite=True)

s.metadata.General.original_filename = "an_original_filename"
Expand Down

0 comments on commit 099349e

Please sign in to comment.