From 92d3ccfb9f2549d816cb2546d7aa4c033dffadd7 Mon Sep 17 00:00:00 2001 From: Lily Wang Date: Thu, 17 Oct 2024 13:46:47 +1100 Subject: [PATCH 1/5] add failing test --- openff/nagl/tests/nn/test_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/openff/nagl/tests/nn/test_model.py b/openff/nagl/tests/nn/test_model.py index 1e04c72..323bc27 100644 --- a/openff/nagl/tests/nn/test_model.py +++ b/openff/nagl/tests/nn/test_model.py @@ -426,6 +426,12 @@ def test_compute_property( assert_allclose(charges, expected_charges, atol=1e-5) + def test_lookup_table_before_chemical_domain(self, model): + hcl = Molecule.from_mapped_smiles("[Cl:1][H:2]") + expected_charges = [-0.1680, 0.1680] + charges = model.compute_property(hcl, as_numpy=True, check_lookup_table=True) + assert_allclose(charges, expected_charges, atol=1e-5) + @pytest.mark.xfail(reason="Model does not include 0 bonds as feature") def test_assign_partial_charges_to_ion(self, model): mol = Molecule.from_smiles("[Cl-]") From 85aebf605ed0ebdc51c5af95f247f45ef10038d4 Mon Sep 17 00:00:00 2001 From: Lily Wang Date: Thu, 17 Oct 2024 15:03:23 +1100 Subject: [PATCH 2/5] fix failing test --- openff/nagl/tests/nn/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openff/nagl/tests/nn/test_model.py b/openff/nagl/tests/nn/test_model.py index 323bc27..05e72ba 100644 --- a/openff/nagl/tests/nn/test_model.py +++ b/openff/nagl/tests/nn/test_model.py @@ -429,7 +429,7 @@ def test_compute_property( def test_lookup_table_before_chemical_domain(self, model): hcl = Molecule.from_mapped_smiles("[Cl:1][H:2]") expected_charges = [-0.1680, 0.1680] - charges = model.compute_property(hcl, as_numpy=True, check_lookup_table=True) + charges = model.compute_property(hcl, as_numpy=True, check_domains=True) assert_allclose(charges, expected_charges, atol=1e-5) @pytest.mark.xfail(reason="Model does not include 0 bonds as feature") From 145352c3f6b1a18ea8631cd32774c415e9349eaf Mon Sep 17 00:00:00 2001 From: Lily Wang Date: Thu, 17 Oct 2024 13:54:47 +1100 Subject: [PATCH 3/5] swap order of checking look up tables and domains --- openff/nagl/nn/_models.py | 42 ++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/openff/nagl/nn/_models.py b/openff/nagl/nn/_models.py index 6e6d160..969bfbc 100644 --- a/openff/nagl/nn/_models.py +++ b/openff/nagl/nn/_models.py @@ -278,24 +278,13 @@ def _compute_properties( ------- result: Dict[str, torch.Tensor] or Dict[str, numpy.ndarray] """ - if check_domains: - is_supported, error = self.chemical_domain.check_molecule( - molecule, return_error_message=True - ) - if not is_supported: - if error_if_unsupported: - raise ValueError(error) - else: - warnings.warn(error) - try: - values = self._compute_properties_dgl(molecule) - except (MissingOptionalDependencyError, TypeError): - values = self._compute_properties_nagl(molecule) - + values = {} + + expected_value_keys = list(self.readout_modules.keys()) + if check_lookup_table and self.lookup_tables: - property_names = list(values) - for property_name in property_names: + for property_name in expected_value_keys: try: value = self._check_property_lookup_table( molecule=molecule, @@ -312,6 +301,27 @@ def _compute_properties( ) values[property_name] = value + + computed_value_keys = set(values.keys()) + if computed_value_keys == set(expected_value_keys): + return values + + if check_domains: + is_supported, error = self.chemical_domain.check_molecule( + molecule, return_error_message=True + ) + if not is_supported: + if error_if_unsupported: + raise ValueError(error) + else: + warnings.warn(error) + + try: + values = self._compute_properties_dgl(molecule) + except (MissingOptionalDependencyError, TypeError): + values = self._compute_properties_nagl(molecule) + + if as_numpy: values = {k: v.detach().numpy().flatten() for k, v in values.items()} return values From 1962bba352342c7f78f42f8ea1c4dc6fa4006f66 Mon Sep 17 00:00:00 2001 From: Lily Wang Date: Thu, 17 Oct 2024 14:55:28 +1100 Subject: [PATCH 4/5] fix typing --- openff/nagl/nn/_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/openff/nagl/nn/_models.py b/openff/nagl/nn/_models.py index 969bfbc..a451af6 100644 --- a/openff/nagl/nn/_models.py +++ b/openff/nagl/nn/_models.py @@ -217,6 +217,7 @@ def compute_properties( else: tensor = torch.empty for property_name, value in results[0].items(): + print(as_numpy, value, value.dtype) combined_results[property_name] = tensor( molecule.n_atoms, dtype=value.dtype @@ -304,6 +305,8 @@ def _compute_properties( computed_value_keys = set(values.keys()) if computed_value_keys == set(expected_value_keys): + if as_numpy: + values = {k: v.detach().numpy().flatten() for k, v in values.items()} return values if check_domains: From 48cca1e1b64b0bead7853c611f5a4c662c93b51b Mon Sep 17 00:00:00 2001 From: Lily Wang Date: Thu, 17 Oct 2024 16:44:47 +1100 Subject: [PATCH 5/5] update changelog [skip ci] --- docs/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 20ee08a..e352236 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -26,6 +26,9 @@ The rules for this file: ### Changed - Removed unused, undocumented code paths, and updated docs (PR #132) +### Fixed +- Check lookup tables for allowed molecules before ChemicalDomain for forbidden ones (PR #145, Issue #144) + ## v0.4.0 -- 2024-07-18