From 26f7cd3a77f71c1dd1f49d29eabf64258f4bea8c Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 8 Jan 2024 20:29:52 +0100 Subject: [PATCH] Add `material_id: str | int` to `PhononBSDOSDoc` schema (#669) * add material_id: str | int to PhononBSDOSDoc schema * set extra="allow" on PhononBSDOSDoc, drop material_id --- src/atomate2/common/schemas/phonons.py | 5 ++-- tests/common/schemas/test_elastic.py | 10 +++----- tests/common/schemas/test_phonons.py | 9 ++++++++ tests/vasp/flows/test_phonons.py | 32 ++++---------------------- 4 files changed, 18 insertions(+), 38 deletions(-) diff --git a/src/atomate2/common/schemas/phonons.py b/src/atomate2/common/schemas/phonons.py index 357fcc4289..e0ac0ddacd 100644 --- a/src/atomate2/common/schemas/phonons.py +++ b/src/atomate2/common/schemas/phonons.py @@ -100,12 +100,11 @@ class PhononJobDirs(BaseModel): ) -class PhononBSDOSDoc(StructureMetadata): +class PhononBSDOSDoc(StructureMetadata, extra="allow"): # type: ignore[call-arg] """Collection of all data produced by the phonon workflow.""" structure: Optional[Structure] = Field( - None, - description="Structure of Materials Project.", + None, description="Structure of Materials Project." ) phonon_bandstructure: Optional[PhononBandStructureSymmLine] = Field( diff --git a/tests/common/schemas/test_elastic.py b/tests/common/schemas/test_elastic.py index 4c0314a4ab..f218e83034 100644 --- a/tests/common/schemas/test_elastic.py +++ b/tests/common/schemas/test_elastic.py @@ -16,18 +16,14 @@ def test_elastic_document(test_dir): schema_ref = json.loads(schema_path.read_text()) doc = ElasticDocument(**schema_ref) - ElasticDocument.model_validate_json(json.dumps(doc, cls=MontyEncoder)) + validated = ElasticDocument.model_validate_json(json.dumps(doc, cls=MontyEncoder)) + assert isinstance(validated, ElasticDocument) # schemas where all fields have default values @pytest.mark.parametrize( "model_cls", - [ - ElasticDocument, - ElasticTensorDocument, - DerivedProperties, - FittingData, - ], + [ElasticDocument, ElasticTensorDocument, DerivedProperties, FittingData], ) def test_model_validate(model_cls): model_cls.model_validate_json(json.dumps(model_cls(), cls=MontyEncoder)) diff --git a/tests/common/schemas/test_phonons.py b/tests/common/schemas/test_phonons.py index 4cd02e69bd..649aa28c23 100644 --- a/tests/common/schemas/test_phonons.py +++ b/tests/common/schemas/test_phonons.py @@ -40,9 +40,18 @@ def test_phonon_bs_dos_doc(): validated = PhononBSDOSDoc.model_validate_json(json.dumps(doc, cls=MontyEncoder)) assert isinstance(validated, PhononBSDOSDoc) + # test invalid supercell_matrix type fails with pytest.raises(ValidationError): doc = PhononBSDOSDoc(**kwargs | {"supercell_matrix": (1, 1, 1)}) + # test optional material_id + doc = PhononBSDOSDoc(**kwargs | {"material_id": 1234}) + assert doc.material_id == 1234 + + # test extra="allow" option + doc = PhononBSDOSDoc(**kwargs | {"extra_field": "test"}) + assert doc.extra_field == "test" + # schemas where all fields have default values @pytest.mark.parametrize("model_cls", [PhononJobDirs, PhononUUIDs]) diff --git a/tests/vasp/flows/test_phonons.py b/tests/vasp/flows/test_phonons.py index b57c3e7cfb..fbb6da4758 100644 --- a/tests/vasp/flows/test_phonons.py +++ b/tests/vasp/flows/test_phonons.py @@ -180,44 +180,20 @@ def test_phonon_wf_only_displacements_no_structural_transformation( assert_allclose( responses[job.jobs[-1].uuid][1].output.free_energies, - [ - 5774.56699647, - 5616.29786373, - 4724.73684926, - 3044.19341280, - 696.34353154, - ], + [5774.56699647, 5616.29786373, 4724.73684926, 3044.19341280, 696.34353154], ) assert_allclose( responses[job.jobs[-1].uuid][1].output.entropies, - [ - 0.0, - 4.78666294, - 13.02533234, - 20.36075467, - 26.39807246, - ], + [0.0, 4.78666294, 13.02533234, 20.36075467, 26.39807246], ) assert_allclose( responses[job.jobs[-1].uuid][1].output.heat_capacities, - [ - 0.0, - 8.04749769, - 15.97101906, - 19.97032648, - 21.87475268, - ], + [0.0, 8.04749769, 15.97101906, 19.97032648, 21.87475268], ) assert_allclose( responses[job.jobs[-1].uuid][1].output.internal_energies, - [ - 5774.56699647, - 6094.96415750, - 7329.80331668, - 9152.41981241, - 11255.57251541, - ], + [5774.56699647, 6094.96415750, 7329.80331668, 9152.41981241, 11255.57251541], ) assert isinstance(