Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #331 #340

Merged
merged 4 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ The rules for this file:
*/*/2023 hl2500

* 2.2.0

orbeckst marked this conversation as resolved.
Show resolved Hide resolved
Changes
orbeckst marked this conversation as resolved.
Show resolved Hide resolved
- For pandas>=2.1, metadata will be loaded from the parquet file (issue #331, PR #340).

Enhancements
- Add a TI estimator using gaussian quadrature to calculate the free energy.
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
dependencies:
- python
- numpy
- pandas
- pandas>=2.1
- pymbar>=4
- scipy
- scikit-learn
Expand Down
39 changes: 36 additions & 3 deletions src/alchemlyb/parsing/parquet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,42 @@
import pandas as pd
from loguru import logger

from . import _init_attrs


@_init_attrs
orbeckst marked this conversation as resolved.
Show resolved Hide resolved
def _check_metadata(path: str, T: float) -> pd.DataFrame:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is a bit misleading because it does not only check but also return a df. I would choose a name such as _read_parquet_with_metadata or similar.

"""
Check if the metadata is included in the Dataframe and has the correct
temperature.

Parameters
----------
path : str
Path to parquet file to extract dataframe from.
T : float
Temperature in Kelvin of the simulations.

Returns
-------
DataFrame
"""
df = pd.read_parquet(path)
if "temperature" not in df.attrs:
logger.warning(
f"No temperature metadata found in {path}. "
f"Serialise the Dataframe with pandas>=2.1 to preserve the metadata."
)
df.attrs["temperature"] = T
df.attrs["energy_unit"] = "kT"
else:
if df.attrs["temperature"] != T:
raise ValueError(
f"Temperature in the input ({T}) doesn't match the temperature "
f"in the dataframe ({df.attrs['temperature']})."
)
return df


def extract_u_nk(path, T):
r"""Return reduced potentials `u_nk` (unit: kT) from a pandas parquet file.

Expand Down Expand Up @@ -36,7 +69,7 @@ def extract_u_nk(path, T):
.. versionadded:: 2.1.0

"""
u_nk = pd.read_parquet(path)
u_nk = _check_metadata(path, T)
columns = list(u_nk.columns)
if isinstance(columns[0], str) and columns[0][0] == "(":
new_columns = []
Expand Down Expand Up @@ -81,4 +114,4 @@ def extract_dHdl(path, T):
.. versionadded:: 2.1.0

"""
return pd.read_parquet(path)
return _check_metadata(path, T)
2 changes: 1 addition & 1 deletion src/alchemlyb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def gmx_ABFE():


@pytest.fixture
def gmx_ABFE_complex_n_uk(gmx_ABFE):
def gmx_ABFE_complex_u_nk(gmx_ABFE):
return [gmx.extract_u_nk(file, T=300) for file in gmx_ABFE["complex"]]


Expand Down
35 changes: 34 additions & 1 deletion src/alchemlyb/tests/parsing/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,45 @@ def test_extract_dHdl(dHdl_list, request, tmp_path):
new_dHdl = extract_dHdl(str(tmp_path / "dhdl.parquet"), T=300)
assert (new_dHdl.columns == dHdl.columns).all()
assert (new_dHdl.index == dHdl.index).all()
assert new_dHdl.attrs["temperature"] == 300
assert new_dHdl.attrs["energy_unit"] == "kT"


@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_n_uk"])
@pytest.mark.parametrize("u_nk_list", ["gmx_benzene_VDW_u_nk", "gmx_ABFE_complex_u_nk"])
def test_extract_dHdl(u_nk_list, request, tmp_path):
u_nk = request.getfixturevalue(u_nk_list)[0]
u_nk.to_parquet(path=str(tmp_path / "u_nk.parquet"), index=True)
new_u_nk = extract_u_nk(str(tmp_path / "u_nk.parquet"), T=300)
assert (new_u_nk.columns == u_nk.columns).all()
assert (new_u_nk.index == u_nk.index).all()
assert new_u_nk.attrs["temperature"] == 300
assert new_u_nk.attrs["energy_unit"] == "kT"


@pytest.fixture()
def u_nk(gmx_ABFE_complex_u_nk):
return gmx_ABFE_complex_u_nk[0]


def test_no_T(u_nk, tmp_path, caplog):
u_nk.attrs = {}
u_nk.to_parquet(path=str(tmp_path / "temp.parquet"), index=True)
extract_u_nk(str(tmp_path / "temp.parquet"), 300)
assert (
"Serialise the Dataframe with pandas>=2.1 to preserve the metadata."
in caplog.text
)


def test_wrong_T(u_nk, tmp_path, caplog):
u_nk.to_parquet(path=str(tmp_path / "temp.parquet"), index=True)
with pytest.raises(ValueError, match="doesn't match the temperature"):
extract_u_nk(str(tmp_path / "temp.parquet"), 400)


def test_metadata_unchanged(u_nk, tmp_path):
u_nk.attrs = {"temperature": 400, "energy_unit": "kcal/mol"}
u_nk.to_parquet(path=str(tmp_path / "temp.parquet"), index=True)
new_u_nk = extract_u_nk(str(tmp_path / "temp.parquet"), 400)
assert new_u_nk.attrs["temperature"] == 400
assert new_u_nk.attrs["energy_unit"] == "kcal/mol"
6 changes: 3 additions & 3 deletions src/alchemlyb/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def u_nk(gmx_benzene_Coulomb_u_nk):


@pytest.fixture()
def multi_index_u_nk(gmx_ABFE_complex_n_uk):
return gmx_ABFE_complex_n_uk[0]
def multi_index_u_nk(gmx_ABFE_complex_u_nk):
return gmx_ABFE_complex_u_nk[0]


@pytest.fixture()
Expand Down Expand Up @@ -470,7 +470,7 @@ def test_decorrelate_dhdl_multiple_l(multi_index_dHdl):
)


def test_raise_non_uk(multi_index_dHdl):
def test_raise_nou_nk(multi_index_dHdl):
with pytest.raises(ValueError):
decorrelate_u_nk(
multi_index_dHdl,
Expand Down
2 changes: 1 addition & 1 deletion src/alchemlyb/tests/test_workflow_ABFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_single_estimator_ti(self, workflow, monkeypatch):
summary = workflow.generate_result()
assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 21.51472826028906, 0.1)

def test_unprocessed_n_uk(self, workflow, monkeypatch):
def test_unprocessed_u_nk(self, workflow, monkeypatch):
monkeypatch.setattr(workflow, "u_nk_sample_list", None)
monkeypatch.setattr(workflow, "estimator", dict())
workflow.estimate()
Expand Down
Loading