Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Csse layout 536 input_data #358

Merged
merged 6 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ New Features

Enhancements
++++++++++++
- (536b) ``v1.AtomicResult.convert_v`` learned an ``external_input_data`` option to inject that field (if known) rather than using incomplete reconstruction from the v1 Result. This may not be the final solution.
- (536b) ``v2.FailedOperation`` gained schema_name and schema_version=2.
- (536b) ``v2.AtomicResult`` no longer inherits from ``v2.AtomicInput``. It gained an ``input_data`` field for the corresponding ``AtomicInput`` and independent ``id`` and ``molecule`` fields (the latter being equivalent to ``v1.AtomicResult.molecule`` with the frame of the results; ``v2.AtomicResult.input_data.molecule`` is new, preserving the input frame). It also gained an independent ``extras`` field.
- (536b) Both v1/v2 ``AtomicResult.convert_v()`` learned to handle the new ``input_data`` layout.
- (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``success`` field enforced to ``True``. Previously it could be set T/F. Now validation errors if not T. Likewise ``v2.FailedOperation.success`` is enforced to ``False``.
- (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``error`` field removed. This isn't used now that ``success=True`` and failure should be routed to ``FailedOperation``.
- (:pr:`357`) ``v1.Molecule`` had its schema_version changed to a Literal[2] (remember Mol is one-ahead of general numbering scheme) so new instances will be 2 even if another value is passed in. Ditto ``v2.BasisSet.schema_version=2``. Ditto ``v1.BasisSet.schema_version=1``. Ditto ``v1.QCInputSpecification.schema_version=1`` and ``v1.OptimizationSpecification.schema_version=1``.
Expand Down
10 changes: 5 additions & 5 deletions qcelemental/models/v1/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ def convert_v(
dself = self.dict()
if version == 2:
# remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it.
if dself.pop("error", None):
pass
dself.pop("error", None)

dself["trajectory"] = [trajectory_class(**atres).convert_v(version) for atres in dself["trajectory"]]
dself["input_specification"].pop("schema_version", None)
Expand Down Expand Up @@ -356,11 +355,12 @@ def convert_v(
dself = self.dict()
if version == 2:
# remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it.
if dself.pop("error", None):
pass
dself.pop("error", None)

dself["input_specification"].pop("schema_version", None)
dself["optimization_spec"].pop("schema_version", None)
dself["optimization_history"] = {
(k, [opthist_class(**res).convert_v(version) for res in lst])
k: [opthist_class(**res).convert_v(version) for res in lst]
for k, lst in dself["optimization_history"].items()
}

Expand Down
47 changes: 43 additions & 4 deletions qcelemental/models/v1/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,29 @@ def _native_file_protocol(cls, value, values):
return ret

def convert_v(
self, version: int
self,
version: int,
*,
external_input_data: Optional[Any] = None,
) -> Union["qcelemental.models.v1.AtomicResult", "qcelemental.models.v2.AtomicResult"]:
"""Convert to instance of particular QCSchema version."""
"""Convert to instance of particular QCSchema version.

Parameters
----------
version
The version to convert to.
external_input_data
Since self contains data merged from input, this allows passing in the original input, particularly for `molecule` and `extras` fields.
Can be model or dictionary and should be *already* converted to the desired version.
Replaces ``input_data`` field entirely (not merges with extracts from self) and w/o consistency checking.

Returns
-------
AtomicResult
Returns self (not a copy) if ``version`` already satisfied.
Returns a new AtomicResult of ``version`` otherwise.

"""
import qcelemental as qcel

if check_convertible_version(version, error="AtomicResult") == "self":
Expand All @@ -808,8 +828,27 @@ def convert_v(
dself = self.dict()
if version == 2:
# remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it.
if dself.pop("error", None):
pass
dself.pop("error", None)

input_data = {
k: dself.pop(k) for k in list(dself.keys()) if k in ["driver", "keywords", "model", "protocols"]
}
input_data["molecule"] = dself["molecule"] # duplicate since input mol has been overwritten
# any input provenance has been overwritten
input_data["extras"] = {
k: dself["extras"].pop(k) for k in list(dself["extras"].keys()) if k in []
} # sep any merged extras
if external_input_data:
# Note: overwriting with external, not updating. reconsider?
dself["input_data"] = external_input_data
in_extras = (
external_input_data.get("extras", {})
if isinstance(external_input_data, dict)
else external_input_data.extras
)
dself["extras"] = {k: v for k, v in dself["extras"].items() if (k, v) not in in_extras.items()}
else:
dself["input_data"] = input_data

self_vN = qcel.models.v2.AtomicResult(**dself)

Expand Down
17 changes: 17 additions & 0 deletions qcelemental/models/v2/common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ class FailedOperation(ProtoModel):
and containing the reason and input data which generated the failure.
"""

schema_name: Literal["qcschema_failed_operation"] = Field(
"qcschema_failed_operation",
description=(
f"The QCSchema specification this model conforms to. Explicitly fixed as qcschema_failed_operation."
),
)
schema_version: Literal[2] = Field(
2,
description="The version number of :attr:`~qcelemental.models.FailedOperation.schema_name` to which this model conforms.",
)
id: Optional[str] = Field( # type: ignore
None,
description="A unique identifier which links this FailedOperation, often of the same Id of the operation "
Expand Down Expand Up @@ -132,6 +142,10 @@ class FailedOperation(ProtoModel):
def __repr_args__(self) -> "ReprArgs":
return [("error", self.error)]

@field_validator("schema_version", mode="before")
@classmethod
def _version_stamp(cls, v):
    """Stamp ``schema_version`` as 2 regardless of the value supplied.

    Registered as a "before" validator, so the incoming value ``v`` is
    discarded and replaced with the fixed version number.

    Parameters
    ----------
    v
        The value passed in for ``schema_version``; ignored.

    Returns
    -------
    int
        Always ``2``.

    """
    # Explicit @classmethod matches the other field validators in these models.
    return 2

def convert_v(
self, version: int
) -> Union["qcelemental.models.v1.FailedOperation", "qcelemental.models.v2.FailedOperation"]:
Expand All @@ -143,6 +157,9 @@ def convert_v(

dself = self.model_dump()
if version == 1:
dself.pop("schema_name")
dself.pop("schema_version")

self_vN = qcel.models.v1.FailedOperation(**dself)

return self_vN
Expand Down
12 changes: 12 additions & 0 deletions qcelemental/models/v2/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,12 @@ def convert_v(
if check_convertible_version(version, error="OptimizationResult") == "self":
return self

trajectory_class = self.trajectory[0].__class__
dself = self.model_dump()
if version == 1:
dself["trajectory"] = [trajectory_class(**atres).convert_v(version) for atres in dself["trajectory"]]
loriab marked this conversation as resolved.
Show resolved Hide resolved
dself["input_specification"].pop("schema_version", None)

self_vN = qcel.models.v1.OptimizationResult(**dself)

return self_vN
Expand Down Expand Up @@ -297,6 +300,9 @@ def convert_v(

dself = self.model_dump()
if version == 1:
if dself["optimization_spec"].pop("extras", None):
pass

self_vN = qcel.models.v1.TorsionDriveInput(**dself)

return self_vN
Expand Down Expand Up @@ -348,11 +354,17 @@ def convert_v(
if check_convertible_version(version, error="TorsionDriveResult") == "self":
return self

opthist_class = next(iter(self.optimization_history.values()))[0].__class__

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can the opthist_class variable be moved into the if-clause given it seems to only be used there?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I placed it outside since it was likely to be of use to all conversions (besides the no-op), but multiple versions are a long way off.

dself = self.model_dump()
if version == 1:
if dself["optimization_spec"].pop("extras", None):
pass

dself["optimization_history"] = {
k: [opthist_class(**res).convert_v(version) for res in lst]
for k, lst in dself["optimization_history"].items()
}

self_vN = qcel.models.v1.TorsionDriveResult(**dself)

return self_vN
43 changes: 32 additions & 11 deletions qcelemental/models/v2/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def convert_v(
return self_vN


class AtomicResult(AtomicInput):
class AtomicResult(ProtoModel):
r"""Results from a CMS program execution."""

schema_name: constr(strip_whitespace=True, pattern=r"^(qc\_?schema_output)$") = Field( # type: ignore
Expand All @@ -736,6 +736,9 @@ class AtomicResult(AtomicInput):
2,
description="The version number of :attr:`~qcelemental.models.AtomicResult.schema_name` to which this model conforms.",
)
id: Optional[str] = Field(None, description="The optional ID for the computation.")
input_data: AtomicInput = Field(..., description=str(AtomicInput.__doc__))
molecule: Molecule = Field(..., description="The molecule with frame and orientation of the results.")
properties: AtomicResultProperties = Field(..., description=str(AtomicResultProperties.__doc__))
wavefunction: Optional[WavefunctionProperties] = Field(None, description=str(WavefunctionProperties.__doc__))

Expand All @@ -755,6 +758,10 @@ class AtomicResult(AtomicInput):
True, description="The success of program execution. If False, other fields may be blank."
)
provenance: Provenance = Field(..., description=str(Provenance.__doc__))
extras: Dict[str, Any] = Field(
{},
description="Additional information to bundle with the computation. Use for schema development and scratch space.",
)

@field_validator("schema_name", mode="before")
@classmethod
Expand All @@ -774,12 +781,16 @@ def _version_stamp(cls, v):
@field_validator("return_result")
@classmethod
def _validate_return_result(cls, v, info):
if info.data["driver"] == "energy":
# Do not propagate validation errors
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")
driver = info.data["input_data"].driver
if driver == "energy":
if isinstance(v, np.ndarray) and v.size == 1:
v = v.item(0)
elif info.data["driver"] == "gradient":
elif driver == "gradient":
v = np.asarray(v).reshape(-1, 3)
elif info.data["driver"] == "hessian":
elif driver == "hessian":
v = np.asarray(v)
nsq = int(v.size**0.5)
v.shape = (nsq, nsq)
Expand All @@ -800,8 +811,8 @@ def _wavefunction_protocol(cls, value, info):
raise ValueError("wavefunction must be None, a dict, or a WavefunctionProperties object.")

# Do not propagate validation errors
if "protocols" not in info.data:
raise ValueError("Protocols was not properly formed.")
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

# Handle restricted
restricted = wfn.get("restricted", None)
Expand All @@ -814,7 +825,7 @@ def _wavefunction_protocol(cls, value, info):
wfn.pop(k)

# Handle protocols
wfnp = info.data["protocols"].wavefunction
wfnp = info.data["input_data"].protocols.wavefunction
return_keep = None
if wfnp == "all":
pass
Expand Down Expand Up @@ -861,10 +872,10 @@ def _wavefunction_protocol(cls, value, info):
@classmethod
def _stdout_protocol(cls, value, info):
# Do not propagate validation errors
if "protocols" not in info.data:
raise ValueError("Protocols was not properly formed.")
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

outp = info.data["protocols"].stdout
outp = info.data["input_data"].protocols.stdout
if outp is True:
return value
elif outp is False:
Expand All @@ -875,7 +886,11 @@ def _stdout_protocol(cls, value, info):
@field_validator("native_files")
@classmethod
def _native_file_protocol(cls, value, info):
ancp = info.data["protocols"].native_files
# Do not propagate validation errors
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

ancp = info.data["input_data"].protocols.native_files
if ancp == "all":
return value
elif ancp == "none":
Expand Down Expand Up @@ -905,6 +920,12 @@ def convert_v(

dself = self.model_dump()
if version == 1:
# input_data = self.input_data.convert_v(1) # TODO probably later
input_data = dself.pop("input_data")
input_data.pop("molecule", None) # discard
input_data.pop("provenance", None) # discard
dself["extras"] = {**input_data.pop("extras", {}), **dself.pop("extras", {})} # merge
dself = {**input_data, **dself}
self_vN = qcel.models.v1.AtomicResult(**dself)

return self_vN
Loading