diff --git a/.github/workflows/Lint.yml b/.github/workflows/Lint.yml index bc51d594..61b300eb 100644 --- a/.github/workflows/Lint.yml +++ b/.github/workflows/Lint.yml @@ -17,6 +17,8 @@ jobs: python-version: "3.8" - name: Install black run: pip install "black>=22.1.0,<23.0a0" + - name: Print code formatting with black + run: black --diff . - name: Check code formatting with black run: black --check . diff --git a/docs/changelog.rst b/docs/changelog.rst index badf58f9..b55c6acb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -37,6 +37,10 @@ New Features Enhancements ++++++++++++ +- (536b) ``v1.AtomicResult.convert_v`` learned an ``external_input_data`` option to inject that field (if known) rather than using incomplete reconstruction from the v1 Result. May not be the final solution. +- (536b) ``v2.FailedOperation`` gained schema_name and schema_version=2. +- (536b) ``v2.AtomicResult`` no longer inherits from ``v2.AtomicInput``. It gained an ``input_data`` field for the corresponding ``AtomicInput`` and independent ``id`` and ``molecule`` fields (the latter being equivalent to ``v1.AtomicResult.molecule`` with the frame of the results; ``v2.AtomicResult.input_data.molecule`` is new, preserving the input frame). It also gained an independent ``extras`` field. +- (536b) Both v1/v2 ``AtomicResult.convert_v()`` learned to handle the new ``input_data`` layout. - (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``success`` field enforced to ``True``. Previously it could be set T/F. Now validation errors if not T. Likewise ``v2.FailedOperation.success`` is enforced to ``False``. - (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``error`` field removed. This isn't used now that ``success=True`` and failure should be routed to ``FailedOperation``. 
- (:pr:`357`) ``v1.Molecule`` had its schema_version changed to a Literal[2] (remember Mol is one-ahead of general numbering scheme) so new instances will be 2 even if another value is passed in. Ditto ``v2.BasisSet.schema_version=2``. Ditto ``v1.BasisSet.schema_version=1`` Ditto ``v1.QCInputSpecification.schema_version=1`` and ``v1.OptimizationSpecification.schema_version=1``. diff --git a/qcelemental/models/v1/procedures.py b/qcelemental/models/v1/procedures.py index f8ae852d..7a7eaa76 100644 --- a/qcelemental/models/v1/procedures.py +++ b/qcelemental/models/v1/procedures.py @@ -180,8 +180,7 @@ def convert_v( dself = self.dict() if version == 2: # remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it. - if dself.pop("error", None): - pass + dself.pop("error", None) dself["trajectory"] = [trajectory_class(**atres).convert_v(version) for atres in dself["trajectory"]] dself["input_specification"].pop("schema_version", None) @@ -356,11 +355,12 @@ def convert_v( dself = self.dict() if version == 2: # remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it. 
- if dself.pop("error", None): - pass + dself.pop("error", None) + dself["input_specification"].pop("schema_version", None) + dself["optimization_spec"].pop("schema_version", None) dself["optimization_history"] = { - (k, [opthist_class(**res).convert_v(version) for res in lst]) + k: [opthist_class(**res).convert_v(version) for res in lst] for k, lst in dself["optimization_history"].items() } diff --git a/qcelemental/models/v1/results.py b/qcelemental/models/v1/results.py index c809fb1c..f43a0103 100644 --- a/qcelemental/models/v1/results.py +++ b/qcelemental/models/v1/results.py @@ -797,9 +797,29 @@ def _native_file_protocol(cls, value, values): return ret def convert_v( - self, version: int + self, + version: int, + *, + external_input_data: Optional[Any] = None, ) -> Union["qcelemental.models.v1.AtomicResult", "qcelemental.models.v2.AtomicResult"]: - """Convert to instance of particular QCSchema version.""" + """Convert to instance of particular QCSchema version. + + Parameters + ---------- + version + The version to convert to. + external_input_data + Since self contains data merged from input, this allows passing in the original input, particularly for `molecule` and `extras` fields. + Can be model or dictionary and should be *already* converted to the desired version. + Replaces ``input_data`` field entirely (not merges with extracts from self) and w/o consistency checking. + + Returns + ------- + AtomicResult + Returns self (not a copy) if ``version`` already satisfied. + Returns a new AtomicResult of ``version`` otherwise. + + """ import qcelemental as qcel if check_convertible_version(version, error="AtomicResult") == "self": @@ -808,8 +828,27 @@ def convert_v( dself = self.dict() if version == 2: # remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it. 
- if dself.pop("error", None): - pass + dself.pop("error", None) + + input_data = { + k: dself.pop(k) for k in list(dself.keys()) if k in ["driver", "keywords", "model", "protocols"] + } + input_data["molecule"] = dself["molecule"] # duplicate since input mol has been overwritten + # any input provenance has been overwritten + input_data["extras"] = { + k: dself["extras"].pop(k) for k in list(dself["extras"].keys()) if k in [] + } # sep any merged extras + if external_input_data: + # Note: overwriting with external, not updating. reconsider? + dself["input_data"] = external_input_data + in_extras = ( + external_input_data.get("extras", {}) + if isinstance(external_input_data, dict) + else external_input_data.extras + ) + dself["extras"] = {k: v for k, v in dself["extras"].items() if (k, v) not in in_extras.items()} + else: + dself["input_data"] = input_data self_vN = qcel.models.v2.AtomicResult(**dself) diff --git a/qcelemental/models/v2/common_models.py b/qcelemental/models/v2/common_models.py index 0c660591..8bcac02f 100644 --- a/qcelemental/models/v2/common_models.py +++ b/qcelemental/models/v2/common_models.py @@ -101,6 +101,16 @@ class FailedOperation(ProtoModel): and containing the reason and input data which generated the failure. """ + schema_name: Literal["qcschema_failed_operation"] = Field( + "qcschema_failed_operation", + description=( + f"The QCSchema specification this model conforms to. Explicitly fixed as qcschema_failed_operation." 
+ ), + ) + schema_version: Literal[2] = Field( + 2, + description="The version number of :attr:`~qcelemental.models.FailedOperation.schema_name` to which this model conforms.", + ) id: Optional[str] = Field( # type: ignore None, description="A unique identifier which links this FailedOperation, often of the same Id of the operation " @@ -132,6 +142,10 @@ class FailedOperation(ProtoModel): def __repr_args__(self) -> "ReprArgs": return [("error", self.error)] + @field_validator("schema_version", mode="before") + def _version_stamp(cls, v): + return 2 + def convert_v( self, version: int ) -> Union["qcelemental.models.v1.FailedOperation", "qcelemental.models.v2.FailedOperation"]: @@ -143,6 +157,9 @@ def convert_v( dself = self.model_dump() if version == 1: + dself.pop("schema_name") + dself.pop("schema_version") + self_vN = qcel.models.v1.FailedOperation(**dself) return self_vN diff --git a/qcelemental/models/v2/procedures.py b/qcelemental/models/v2/procedures.py index 0463ec7c..d3ddc8d4 100644 --- a/qcelemental/models/v2/procedures.py +++ b/qcelemental/models/v2/procedures.py @@ -177,7 +177,11 @@ def convert_v( dself = self.model_dump() if version == 1: + trajectory_class = self.trajectory[0].__class__ + + dself["trajectory"] = [trajectory_class(**atres).convert_v(version) for atres in dself["trajectory"]] dself["input_specification"].pop("schema_version", None) + self_vN = qcel.models.v1.OptimizationResult(**dself) return self_vN @@ -297,6 +301,9 @@ def convert_v( dself = self.model_dump() if version == 1: + if dself["optimization_spec"].pop("extras", None): + pass + self_vN = qcel.models.v1.TorsionDriveInput(**dself) return self_vN @@ -350,9 +357,16 @@ def convert_v( dself = self.model_dump() if version == 1: + opthist_class = next(iter(self.optimization_history.values()))[0].__class__ + if dself["optimization_spec"].pop("extras", None): pass + dself["optimization_history"] = { + k: [opthist_class(**res).convert_v(version) for res in lst] + for k, lst in 
dself["optimization_history"].items() + } + self_vN = qcel.models.v1.TorsionDriveResult(**dself) return self_vN diff --git a/qcelemental/models/v2/results.py b/qcelemental/models/v2/results.py index 782e6169..d70453a4 100644 --- a/qcelemental/models/v2/results.py +++ b/qcelemental/models/v2/results.py @@ -723,7 +723,7 @@ def convert_v( return self_vN -class AtomicResult(AtomicInput): +class AtomicResult(ProtoModel): r"""Results from a CMS program execution.""" schema_name: constr(strip_whitespace=True, pattern=r"^(qc\_?schema_output)$") = Field( # type: ignore @@ -736,6 +736,9 @@ class AtomicResult(AtomicInput): 2, description="The version number of :attr:`~qcelemental.models.AtomicResult.schema_name` to which this model conforms.", ) + id: Optional[str] = Field(None, description="The optional ID for the computation.") + input_data: AtomicInput = Field(..., description=str(AtomicInput.__doc__)) + molecule: Molecule = Field(..., description="The molecule with frame and orientation of the results.") properties: AtomicResultProperties = Field(..., description=str(AtomicResultProperties.__doc__)) wavefunction: Optional[WavefunctionProperties] = Field(None, description=str(WavefunctionProperties.__doc__)) @@ -755,6 +758,10 @@ class AtomicResult(AtomicInput): True, description="The success of program execution. If False, other fields may be blank." ) provenance: Provenance = Field(..., description=str(Provenance.__doc__)) + extras: Dict[str, Any] = Field( + {}, + description="Additional information to bundle with the computation. 
Use for schema development and scratch space.", + ) @field_validator("schema_name", mode="before") @classmethod @@ -774,12 +781,16 @@ def _version_stamp(cls, v): @field_validator("return_result") @classmethod def _validate_return_result(cls, v, info): - if info.data["driver"] == "energy": + # Do not propagate validation errors + if "input_data" not in info.data: + raise ValueError("Input_data was not properly formed.") + driver = info.data["input_data"].driver + if driver == "energy": if isinstance(v, np.ndarray) and v.size == 1: v = v.item(0) - elif info.data["driver"] == "gradient": + elif driver == "gradient": v = np.asarray(v).reshape(-1, 3) - elif info.data["driver"] == "hessian": + elif driver == "hessian": v = np.asarray(v) nsq = int(v.size**0.5) v.shape = (nsq, nsq) @@ -800,8 +811,8 @@ def _wavefunction_protocol(cls, value, info): raise ValueError("wavefunction must be None, a dict, or a WavefunctionProperties object.") # Do not propagate validation errors - if "protocols" not in info.data: - raise ValueError("Protocols was not properly formed.") + if "input_data" not in info.data: + raise ValueError("Input_data was not properly formed.") # Handle restricted restricted = wfn.get("restricted", None) @@ -814,7 +825,7 @@ def _wavefunction_protocol(cls, value, info): wfn.pop(k) # Handle protocols - wfnp = info.data["protocols"].wavefunction + wfnp = info.data["input_data"].protocols.wavefunction return_keep = None if wfnp == "all": pass @@ -861,10 +872,10 @@ def _wavefunction_protocol(cls, value, info): @classmethod def _stdout_protocol(cls, value, info): # Do not propagate validation errors - if "protocols" not in info.data: - raise ValueError("Protocols was not properly formed.") + if "input_data" not in info.data: + raise ValueError("Input_data was not properly formed.") - outp = info.data["protocols"].stdout + outp = info.data["input_data"].protocols.stdout if outp is True: return value elif outp is False: @@ -875,7 +886,11 @@ def _stdout_protocol(cls, 
value, info): @field_validator("native_files") @classmethod def _native_file_protocol(cls, value, info): - ancp = info.data["protocols"].native_files + # Do not propagate validation errors + if "input_data" not in info.data: + raise ValueError("Input_data was not properly formed.") + + ancp = info.data["input_data"].protocols.native_files if ancp == "all": return value elif ancp == "none": @@ -905,6 +920,12 @@ def convert_v( dself = self.model_dump() if version == 1: + # input_data = self.input_data.convert_v(1) # TODO probably later + input_data = dself.pop("input_data") + input_data.pop("molecule", None) # discard + input_data.pop("provenance", None) # discard + dself["extras"] = {**input_data.pop("extras", {}), **dself.pop("extras", {})} # merge + dself = {**input_data, **dself} self_vN = qcel.models.v1.AtomicResult(**dself) return self_vN diff --git a/qcelemental/tests/test_model_results.py b/qcelemental/tests/test_model_results.py index ff55262e..9ee325ff 100644 --- a/qcelemental/tests/test_model_results.py +++ b/qcelemental/tests/test_model_results.py @@ -93,7 +93,7 @@ @pytest.fixture(scope="function") -def result_data_fixture(schema_versions): +def result_data_fixture(schema_versions, request): Molecule = schema_versions.Molecule mol = Molecule.from_data( @@ -104,25 +104,39 @@ def result_data_fixture(schema_versions): """ ) - return { - "molecule": mol, - "driver": "energy", - "model": {"method": "UFF"}, - "return_result": 5, - "success": True, - "properties": {}, - "provenance": {"creator": "qcel"}, - "stdout": "I ran.", - } + if "v2" in request.node.name: + return { + "molecule": mol, + "input_data": {"molecule": mol, "model": {"method": "UFF"}, "driver": "energy"}, + "return_result": 5, + "success": True, + "properties": {}, + "provenance": {"creator": "qcel"}, + "stdout": "I ran.", + } + else: + return { + "molecule": mol, + "driver": "energy", + "model": {"method": "UFF"}, + "return_result": 5, + "success": True, + "properties": {}, + "provenance": 
{"creator": "qcel"}, + "stdout": "I ran.", + } @pytest.fixture(scope="function") -def wavefunction_data_fixture(result_data_fixture, schema_versions): +def wavefunction_data_fixture(result_data_fixture, schema_versions, request): BasisSet = schema_versions.basis.BasisSet bas = BasisSet(name="custom_basis", center_data=center_data, atom_map=["bs_sto3g_o", "bs_sto3g_h", "bs_sto3g_h"]) c_matrix = np.random.rand(bas.nbf, bas.nbf) - result_data_fixture["protocols"] = {"wavefunction": "all"} + if "v2" in request.node.name: + result_data_fixture["input_data"]["protocols"] = {"wavefunction": "all"} + else: + result_data_fixture["protocols"] = {"wavefunction": "all"} result_data_fixture["wavefunction"] = { "basis": bas, "restricted": True, @@ -134,8 +148,11 @@ def wavefunction_data_fixture(result_data_fixture, schema_versions): @pytest.fixture(scope="function") -def native_data_fixture(result_data_fixture): - result_data_fixture["protocols"] = {"native_files": "all"} +def native_data_fixture(result_data_fixture, request): + if "v2" in request.node.name: + result_data_fixture["input_data"]["protocols"] = {"native_files": "all"} + else: + result_data_fixture["protocols"] = {"native_files": "all"} result_data_fixture["native_files"] = { "input": """ echo @@ -368,7 +385,10 @@ def test_result_build(result_data_fixture, request, schema_versions): def test_result_build_wavefunction_delete(wavefunction_data_fixture, request, schema_versions): AtomicResult = schema_versions.AtomicResult - del wavefunction_data_fixture["protocols"] + if "v2" in request.node.name: + del wavefunction_data_fixture["input_data"]["protocols"] + else: + del wavefunction_data_fixture["protocols"] ret = AtomicResult(**wavefunction_data_fixture) drop_qcsk(ret, request.node.name) assert ret.wavefunction is None @@ -444,9 +464,15 @@ def test_wavefunction_protocols( wfn_data = wavefunction_data_fixture["wavefunction"] if protocol is None: - wavefunction_data_fixture.pop("protocols") + if "v2" in 
request.node.name: + wavefunction_data_fixture["input_data"].pop("protocols") + else: + wavefunction_data_fixture.pop("protocols") else: - wavefunction_data_fixture["protocols"]["wavefunction"] = protocol + if "v2" in request.node.name: + wavefunction_data_fixture["input_data"]["protocols"]["wavefunction"] = protocol + else: + wavefunction_data_fixture["protocols"]["wavefunction"] = protocol wfn_data["restricted"] = restricted bas = wfn_data["basis"] @@ -486,9 +512,15 @@ def test_native_protocols(protocol, provided, expected, native_data_fixture, req native_data = native_data_fixture["native_files"] if protocol is None: - native_data_fixture.pop("protocols") + if "v2" in request.node.name: + native_data_fixture["input_data"].pop("protocols") + else: + native_data_fixture.pop("protocols") else: - native_data_fixture["protocols"]["native_files"] = protocol + if "v2" in request.node.name: + native_data_fixture["input_data"]["protocols"]["native_files"] = protocol + else: + native_data_fixture["protocols"]["native_files"] = protocol for name in list(native_data.keys()): if name not in provided: @@ -534,12 +566,16 @@ def test_error_correction_protocol( policy["default_policy"] = default if defined is not None: policy["policies"] = defined - result_data_fixture["protocols"] = {"error_correction": policy} + if "v2" in request.node.name: + result_data_fixture["input_data"]["protocols"] = {"error_correction": policy} + else: + result_data_fixture["protocols"] = {"error_correction": policy} res = AtomicResult(**result_data_fixture) drop_qcsk(res, request.node.name) - assert res.protocols.error_correction.default_policy == default_result - assert res.protocols.error_correction.policies == defined_result + base = res.input_data if "v2" in request.node.name else res + assert base.protocols.error_correction.default_policy == default_result + assert base.protocols.error_correction.policies == defined_result def test_error_correction_logic(schema_versions): @@ -566,7 +602,10 @@ 
def test_error_correction_logic(schema_versions): def test_result_build_stdout_delete(result_data_fixture, request, schema_versions): AtomicResult = schema_versions.AtomicResult - result_data_fixture["protocols"] = {"stdout": False} + if "v2" in request.node.name: + result_data_fixture["input_data"]["protocols"] = {"stdout": False} + else: + result_data_fixture["protocols"] = {"stdout": False} ret = AtomicResult(**result_data_fixture) drop_qcsk(ret, request.node.name) assert ret.stdout is None @@ -675,7 +714,10 @@ def every_model_fixture(request): smodel = "AtomicInput" data = request.getfixturevalue("result_data_fixture") - data = {k: data[k] for k in ["molecule", "model", "driver"]} + if "v2" in request.node.name: + data = data["input_data"] + else: + data = {k: data[k] for k in ["molecule", "model", "driver"]} datas[smodel] = data smodel = "QCInputSpecification" # TODO "AtomicSpecification" @@ -889,7 +931,7 @@ def test_model_survey_schema_version(smodel1, smodel2, every_model_fixture, requ "v1-Mol-A" : 2, "v2-Mol-A" : 2, # TODO 3 "v1-Mol-B" : 2, "v2-Mol-B" : 2, # TODO 3 "v1-BasisSet" : 1, "v2-BasisSet" : 2, # TODO change for v2? 
- "v1-FailedOp" : None, "v2-FailedOp" : None, # TODO 2 + "v1-FailedOp" : None, "v2-FailedOp" : 2, "v1-AtIn" : 1, "v2-AtIn" : 2, "v1-AtSpec" : 1, "v2-AtSpec" : None, # WAS 1, # TODO 2 "v1-AtPtcl" : None, "v2-AtPtcl" : None, @@ -1062,6 +1104,64 @@ def test_model_survey_dictable(smodel1, smodel2, every_model_fixture, request, s assert instance +@pytest.mark.parametrize("smodel1,smodel2", _model_classes_struct) +def test_model_survey_convertable(smodel1, smodel2, every_model_fixture, request, schema_versions): + anskey = request.node.callspec.id.replace("None", "v1") + # fmt: off + ans = { + # "v1-Mol-A" , "v2-Mol-A" , + # "v1-Mol-B" , "v2-Mol-B" , + # "v1-BasisSet" , "v2-BasisSet", + "v1-FailedOp" , "v2-FailedOp", + "v1-AtIn" , "v2-AtIn" , + # "v1-AtSpec" , "v2-AtSpec" , + # "v1-AtPtcl" , "v2-AtPtcl" , + "v1-AtRes" , "v2-AtRes" , + # "v1-AtProp" , "v2-AtProp" , + # "v1-WfnProp" , "v2-WfnProp" , + "v1-OptIn" , "v2-OptIn" , + # "v1-OptSpec" , "v2-OptSpec" , + # "v1-OptPtcl" , "v2-OptPtcl" , + "v1-OptRes" , "v2-OptRes" , + # "v1-OptProp" , "v2-OptProp" , + "v1-TDIn" , "v2-TDIn" , + # "v1-TDSpec" , "v2-TDSpec" , + # "v1-TDKw" , "v2-TDKw" , + # "v1-TDPtcl" , "v2-TDPtcl" , + "v1-TDRes" , "v2-TDRes" , + # "v1-TDProp" , "v2-TDProp" , + # "v1-MBIn" , "v2-MBIn" , + # "v1-MBSpec" , "v2-MBSpec" , + # "v1-MBKw" , "v2-MBKw" , + # "v1-MBPtcl" , "v2-MBPtcl" , + # "v1-MBRes" . 
"v2-MBRes" , + # "v1-MBProp" , "v2-MBProp" , + } + # fmt: on + + smodel_fro = smodel2 if "v2" in anskey else smodel1 + smodel_to = smodel1 if "v2" in anskey else smodel2 + if smodel_fro is None or smodel_to is None: + pytest.skip("model not available for this schema version") + if anskey not in ans: + pytest.skip("model not yet convert_v()-able") + if "ManyBody" in smodel_fro: + import qcmanybody + + # TODO + model = getattr(qcmanybody.models, smodel_fro.split("-")[0]) + else: + model_fro = getattr(schema_versions, smodel_fro.split("-")[0]) + models_to = qcel.models.v1 if "v2" in anskey else qcel.models.v2 + model_to = getattr(models_to, smodel_to.split("-")[0]) + data = every_model_fixture[smodel_fro] + + # check converts and converts to expected class + instance_fro = model_fro(**data) + instance_to = instance_fro.convert_v(1 if "v2" in anskey else 2) + assert isinstance(instance_to, model_to), f"instance {model_fro} failed to convert to {model_to}" + + def test_result_model_deprecations(result_data_fixture, optimization_data_fixture, request): if "v1" not in request.node.name: # schema_versions coming from fixtures despite not being explicitly present diff --git a/qcelemental/tests/test_utils.py b/qcelemental/tests/test_utils.py index 289beceb..b8d5f2c4 100644 --- a/qcelemental/tests/test_utils.py +++ b/qcelemental/tests/test_utils.py @@ -313,7 +313,7 @@ def test_serialization(obj, encoding): @pytest.fixture -def atomic_result_data(): +def atomic_result_data(request): """Mock AtomicResult output which can be tested against for complex serialization methods""" data = { @@ -384,6 +384,15 @@ def atomic_result_data(): "native_files": {}, "success": True, } + if "v2" in request.node.name: + data["input_data"] = { + "molecule": data["molecule"], + "driver": data.pop("driver"), + "model": data.pop("model"), + "keywords": data.pop("keywords"), + "protocols": data.pop("protocols"), + } + return data diff --git a/qcelemental/tests/test_zqcschema.py 
b/qcelemental/tests/test_zqcschema.py index d84324d6..49ef3140 100644 --- a/qcelemental/tests/test_zqcschema.py +++ b/qcelemental/tests/test_zqcschema.py @@ -17,7 +17,9 @@ def qcschema_models(): @pytest.mark.parametrize("fl", files, ids=ids) -def test_qcschema(fl, qcschema_models): +def test_qcschema(fl, qcschema_models, request): + if "v2" in request.node.name: + pytest.skip() # TODO v2 schema above import jsonschema model = fl.parent.stem