From f2533ea0e13dc60ef4ed6571df6cd52b76aa3bad Mon Sep 17 00:00:00 2001 From: danielsf Date: Tue, 10 Sep 2024 12:00:49 -0700 Subject: [PATCH] update_uns can now handle nested keys --- src/cell_type_mapper/utils/anndata_utils.py | 88 ++++++++++--- tests/utils/test_anndata_utils.py | 131 +++++++++++++++++++- 2 files changed, 203 insertions(+), 16 deletions(-) diff --git a/src/cell_type_mapper/utils/anndata_utils.py b/src/cell_type_mapper/utils/anndata_utils.py index f555386d..d20439fb 100644 --- a/src/cell_type_mapper/utils/anndata_utils.py +++ b/src/cell_type_mapper/utils/anndata_utils.py @@ -74,21 +74,79 @@ def update_uns(h5ad_path, new_uns, clobber=False): Otherwise, raise an exception of there are duplicate keys. """ - uns = read_uns_from_h5ad(h5ad_path) - if not clobber: - new_keys = set(new_uns.keys()) - old_keys = set(uns.keys()) - duplicates = new_keys.intersection(old_keys) - if len(duplicates) > 0: - duplicates = list(duplicates) - duplicates.sort() - msg = ( - "Cannot update uns. The following keys already exist:\n" - f"{duplicates}" - ) - raise RuntimeError(msg) - uns.update(new_uns) - write_uns_to_h5ad(h5ad_path, uns_value=uns) + updated_uns = read_uns_from_h5ad(h5ad_path) + + for new_key in new_uns: + updated_uns = _update_uns_key( + old_uns=updated_uns, + key=new_key, + new_data=new_uns[new_key], + clobber=clobber + ) + + write_uns_to_h5ad(h5ad_path, uns_value=updated_uns) + + +def _update_uns_key( + old_uns, + key, + new_data, + clobber, + key_str=None): + """ + Update one key, value pair in the uns dict. + + Parameters + ---------- + old_uns: + a dict. The uns dict being updated + key: + any. The key in old_uns being updated + new_data: + any. The data that goes with key in the new uns dict + clobber: + a boolean. If True and key already exists in old_uns, overwrite. + If False and key already exists in old_uns, raise an exception. + key_str: + a str. Used to keep track of the chain of nested keys we + are updating for more helpful error messages) + + Returns + ------- + The updated old_uns dict (note that old_uns is also updated in place) + """ + if key_str is None: + key_str = f'{key}' + else: + key_str = f'{key_str}:{key}' + + if key not in old_uns: + old_uns[key] = new_data + else: + if isinstance(new_data, dict): + if not isinstance(old_uns[key], dict): + raise RuntimeError( + f"Cannot update uns. '{key_str}' points to a dict in the " + "new data, but not in the original data" + ) + for inner_key in new_data: + old_uns[key] = _update_uns_key( + old_uns=old_uns[key], + key=inner_key, + new_data=new_data[inner_key], + clobber=clobber, + key_str=key_str + ) + else: + if not clobber: + msg = ( + "Cannot update uns. The following key already exists:\n" + f"{key_str}" + ) + raise RuntimeError(msg) + old_uns[key] = new_data + + return old_uns def does_obsm_have_key(h5ad_path, obsm_key): diff --git a/tests/utils/test_anndata_utils.py b/tests/utils/test_anndata_utils.py index 58ada4f1..f41cd727 100644 --- a/tests/utils/test_anndata_utils.py +++ b/tests/utils/test_anndata_utils.py @@ -279,9 +279,18 @@ def test_update_uns(tmp_dir_fixture, which_test): assert actual.uns[k] == new_uns[k] elif which_test == 'error': - with pytest.raises(RuntimeError, match="keys already exist"): + with pytest.raises(RuntimeError, match="key already exists"): update_uns(h5ad_path, new_uns={'a':2, 'f': 6}, clobber=False) + # make sure uns was unchanged + actual = anndata.read_h5ad(h5ad_path) + assert set(actual.uns.keys()) == set(original_uns.keys()) + for k in original_uns: + if isinstance(actual.uns[k], np.ndarray): + np.testing.assert_array_equal(actual.uns[k], original_uns[k]) + else: + assert actual.uns[k] == original_uns[k] + elif which_test == 'clobber': update_uns(h5ad_path, new_uns={'a': 2, 'f': 6}, clobber=True) actual = anndata.read_h5ad(h5ad_path, backed='r') @@ -293,6 +302,126 @@ def test_update_uns(tmp_dir_fixture, which_test): else: raise RuntimeError(f"cannot parse which_test = {which_test}") + +@pytest.mark.parametrize( + 'clobber', [True, ] +) +def test_compound_update_uns(tmp_dir_fixture, clobber): + + original_uns = { + 'a': 1, + 'b': 9, + 'c': { + 'd': 4, + 'other_dict': { + 'z': 88, + 'u': 14 + } + } + } + + a_data = anndata.AnnData( + uns=original_uns) + h5ad_path = mkstemp_clean( + dir=tmp_dir_fixture, + prefix='update_uns_', + suffix='.h5ad') + a_data.write_h5ad(h5ad_path) + + new_uns = { + 'c': { + 'e': 2, + 'f': 3, + 'other_dict': { + 'x': 17, + }, + 'still_another_dict': { + 'y': 55 + } + } + } + + update_uns(h5ad_path, new_uns=new_uns, clobber=False) + roundtrip = anndata.read_h5ad(h5ad_path) + expected = { + 'a': 1, + 'b': 9, + 'c': { + 'd': 4, + 'e': 2, + 'f': 3, + 'other_dict': { + 'z': 88, + 'x': 17, + 'u': 14 + }, + 'still_another_dict': { + 'y': 55 + } + } + } + assert roundtrip.uns == expected + + a_data = anndata.AnnData( + uns=original_uns) + h5ad_path = mkstemp_clean( + dir=tmp_dir_fixture, + prefix='update_uns_', + suffix='.h5ad') + a_data.write_h5ad(h5ad_path) + + new_uns = { + 'c': { + 'd': 3, + 'e': 4, + 'still_another_dict': { + 'y': 45 + }, + 'other_dict': { + 'x': 66, + 'z': 13 + } + } + } + + if clobber: + update_uns(h5ad_path, new_uns=new_uns, clobber=clobber) + expected = { + 'a': 1, + 'b': 9, + 'c': { + 'd': 3, + 'e': 4, + 'other_dict': { + 'z': 13, + 'x': 66, + 'u': 14 + }, + 'still_another_dict': { + 'y': 45 + } + } + } + roundtrip = anndata.read_h5ad(h5ad_path) + if roundtrip.uns != expected: + import json + from cell_type_mapper.utils.utils import ( + clean_for_json + ) + msg = ( + f"{json.dumps(clean_for_json(roundtrip.uns), indent=2)}\n" + "=======n" + f"{json.dumps(clean_for_json(expected), indent=2)}\n" + ) + raise RuntimeError(msg) + else: + with pytest.raises(RuntimeError, match="key already exists"): + update_uns(h5ad_path, new_uns=new_uns, clobber=clobber) + # make sure uns was unchanged + actual = anndata.read_h5ad(h5ad_path) + assert actual.uns == original_uns + + def test_read_empty_uns(tmp_dir_fixture): """ Make sure that reading uns from an h5ad file that