[WIP] Free energy fitting #54

Draft
wants to merge 49 commits into
base: main
Changes from 5 commits
Commits
79d125f
Create run on openff-1.2 training dataset.ipynb
maxentile Oct 22, 2020
e3454ac
use offmol_indices, handle case of no propers / impropers, and adhere…
maxentile Oct 22, 2020
d3f236d
allow to .sum(dim=1) even when there are no impropers
maxentile Oct 22, 2020
10268d5
allow impropers to have length 0 in valencemodel
maxentile Oct 24, 2020
9c350ee
add gbsa port
maxentile Oct 25, 2020
cf14a20
oof, avoid pytorch in-place modification of input argument tensors
maxentile Oct 25, 2020
0c20732
gahh, careless variable-name typo
maxentile Oct 25, 2020
32f7d2d
wip demo notebook for fitting to a hydration free energy calculation …
maxentile Oct 25, 2020
dfce74c
port @proteneer's GBSA implementation instead
maxentile Oct 25, 2020
6a02086
repeat fitting-to-free-energies notebook with less-likely-to-be-buggy…
maxentile Oct 25, 2020
b5fa1e9
add reference implementation from bayes-implicit-solvent
maxentile Oct 29, 2020
b77794c
refactor gbsa_obc2_energy into a function in openmm unit system, and …
maxentile Oct 29, 2020
eb73857
increase descriptiveness in gbsa implementation, add thorough shape a…
maxentile Oct 29, 2020
745c6ac
import FreeSolv database v0.52
maxentile Oct 29, 2020
a88fe91
remove tensor-shape-printing statements
maxentile Oct 29, 2020
3773dd5
create pandas dataframe with serialized openmm systems for freesolv s…
maxentile Oct 29, 2020
22d314f
ooooof. fix silly mistake in fitting-to-free-energies notebook
maxentile Oct 29, 2020
dfdeac0
also handle cases like methane where len(propers) == 0
maxentile Oct 29, 2020
5866029
save also xyz coordinates from brief md
maxentile Oct 29, 2020
47931bd
remove **kwargs to try to play nice with torchscript jit
maxentile Oct 29, 2020
d812bdd
remove temporary assert statements in _gbsa_obc2_energy_omm
maxentile Oct 29, 2020
fc2fbe7
must be a remaining sign-flip error -- seems like it's unable to make…
maxentile Oct 29, 2020
b564fe6
add missing conversion from nm/(proton_charge**2) to kJ/mol
maxentile Oct 30, 2020
08a2bcf
update gbsa docstring
maxentile Oct 30, 2020
d5ae749
update freesolv-fitting notebook
maxentile Oct 30, 2020
bacb705
re-run demo notebook with increased stepsize and decreased network si…
maxentile Oct 30, 2020
6d4c831
ipynb --> py
maxentile Oct 30, 2020
9ffa304
refine fit_freesolv.py script
maxentile Oct 30, 2020
7115efb
add pdf figures from fit_freesolv
maxentile Oct 30, 2020
c63bc05
notebook reporting on element coverage in freesolv
maxentile Oct 31, 2020
6db4016
oops, forgot nitrogen!
maxentile Oct 31, 2020
9857f52
notebook fitting to {C, H, O} mini-freesolv
maxentile Oct 31, 2020
8e50eec
oodles o' vacuum samples
maxentile Oct 31, 2020
1f1f520
set openmm_cpu_threads to 1
maxentile Oct 31, 2020
38bbf03
Create fit to {C, H, O, N, Cl} subset of freesolv (n=529).ipynb
maxentile Oct 31, 2020
7de7b6c
oops fix plot labels
maxentile Oct 31, 2020
2643694
merge freesolv vacuum sample records
maxentile Oct 31, 2020
d24b5bf
git lfs track freesolv_vacuum_samples.npz (279MB)
maxentile Oct 31, 2020
e6d31d9
add xyz column to freesolv_with_samples.h5
maxentile Oct 31, 2020
98f424f
update {C, H, O} subset experiment to use thorough equilibrium sampling
maxentile Oct 31, 2020
a38812d
add experiment script for k-fold cv
maxentile Nov 2, 2020
ac80468
oops, don't indent all the relevant stuff out of the training loop!
maxentile Nov 2, 2020
7c76087
git lfs track each of the K=10-fold CV trajectories
maxentile Nov 2, 2020
f56999c
add notebook to plot k-fold cv results
maxentile Nov 2, 2020
e01ab88
add a horizontal line depicting RMSE of FreeSolv's explicit-solvent c…
maxentile Nov 2, 2020
0a06d43
Add todos
maxentile Sep 3, 2021
3c7eee4
Add SAGEConv
maxentile Sep 3, 2021
fad31a5
Address GraphSAGE todo
maxentile Sep 3, 2021
ccea6ba
Update PDF figures
maxentile Sep 3, 2021
68 changes: 21 additions & 47 deletions espaloma/mm/implicit.py
@@ -36,50 +36,22 @@ def _gbsa_obc2_energy_omm(

N = len(charges)
eye = torch.eye(N, dtype=distance_matrix.dtype)
assert (eye.shape == (N, N))

r = distance_matrix + eye
assert (r.shape == (N, N))
or1 = radii.reshape((N, 1)) - dielectric_offset
or2 = radii.reshape((1, N)) - dielectric_offset

assert (or1.shape == (N, 1))
assert (or2.shape == (1, N))

sr2 = scales.reshape((1, N)) * or2
assert (sr2.shape == (1, N))

r_sr2 = abs(r - sr2)
assert (r_sr2.shape == (N, N))

L = torch.max(or1, r_sr2)
# TODO: check if elementwise
#print('L.shape', L.shape)
#print('or1.shape', or1.shape)
#print('r_sr2.shape', r_sr2.shape)
# TODO: this is fishy
# assert(L.shape == or1.shape)

L = torch.max(or1, abs(r - sr2))
U = r + sr2
assert (U.shape == (N, N))

I = 1 / L - 1 / U + 0.25 * (r - sr2 ** 2 / r) * (
1 / (U ** 2) - 1 / (L ** 2)) + 0.5 * torch.log(
L / U) / r
assert (I.shape == (N, N))

# handle the interior case
condition = or1 < (sr2 - r)
if_true = I + 2 * (1 / or1 - 1 / L)
if_false = I
I = torch.where(condition, if_true, if_false)

I = torch.where(or1 < (sr2 - r), I + 2 * (1 / or1 - 1 / L), I)
I = step(r + sr2 - or1) * 0.5 * I # note the extra 0.5 here

# intention: zero out values on the diagonal
# TODO: possibly replace this diag(diag) with scatter update
I -= torch.diag(torch.diag(I))

I = torch.sum(I, dim=1)

# okay, next compute born radii
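
The Born-radius step that follows can be illustrated with a minimal scalar sketch. It assumes the standard OBC polynomial `alpha * psi - beta * psi**2 + gamma * psi**3` (the diff shows only the cubic term and the final expression for `B`), and the function name `obc2_born_radius` is hypothetical. Defaults are the ones listed in the `gbsa_obc2_energy` signature in this PR.

```python
import math


def obc2_born_radius(radius, I, alpha=0.8, beta=0.0, gamma=2.909125,
                     dielectric_offset=0.009):
    """Born radius (nm) of one atom from its summed pairwise integral I (1/nm).

    Hypothetical scalar stand-in for the tensor code in the diff.
    """
    offset_radius = radius - dielectric_offset
    psi = I * offset_radius
    # Assumed standard OBC polynomial; the diff only shows its cubic term.
    psi_term = alpha * psi - beta * psi ** 2 + gamma * psi ** 3
    return 1.0 / (1.0 / offset_radius - math.tanh(psi_term) / radius)


# An atom with no neighbours (I = 0) keeps its offset radius;
# a more buried atom (larger I) gets a larger Born radius.
isolated = obc2_born_radius(radius=0.15, I=0.0)
buried = obc2_born_radius(radius=0.15, I=2.0)
```

With `I = 0` the `tanh` term vanishes, so the result reduces to `radius - dielectric_offset`, which is a cheap sanity check for a port like this one.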
@@ -95,32 +67,22 @@ def _gbsa_obc2_energy_omm(
psi3_coefficient * psi ** 3)

B = 1 / (1 / offset_radius - torch.tanh(psi_term) / radii)
assert (B.shape == (N,))

E = 0.0
# single particle
# ACE
ACE_individual_terms = surface_tension * (radii + probe_radius) ** 2 * (
radii / B) ** 6
ACE = torch.sum(ACE_individual_terms)
assert (ACE_individual_terms.shape == (N,))
E += ACE
E += torch.sum(
surface_tension * (radii + probe_radius) ** 2 * (radii / B) ** 6)

# on-diagonal
on_diagonal = -0.5 * (
1 / solute_dielectric - 1 / solvent_dielectric) * charges ** 2 / B
assert (on_diagonal.shape == (N,))
E += torch.sum(on_diagonal)
E += torch.sum(-0.5 * (
1 / solute_dielectric - 1 / solvent_dielectric) * charges ** 2 / B)

# particle pair
# note: np.outer --> torch.ger
assert (B.ndim == 1)

f = torch.sqrt(r ** 2 + torch.ger(B, B) * torch.exp(
-r ** 2 / (4 * torch.ger(B, B))))
charge_products = torch.ger(charges, charges)
assert (f.shape == (N, N))
assert (charge_products.shape == (N, N))

ixns = - (
Member

Isn't this missing a -138.935485 conversion from nm/(proton_charge**2) to kJ/mol? The docstring says "everything is in OpenMM native units".

Member
It looks like you might pre-multiply the charges by sqrt(138.935485)? If so, you should probably document that in the docstring.

Member Author
Gahh -- you're right -- I had dropped this in the current conversion! Thank you for catching this. Charges are not assumed to be premultiplied by sqrt(138.935485); will clarify the docstring...

(This conversion was present but poorly labeled in the numpy/jax implementation in bayes-implicit-solvent.)
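
To make the two conventions discussed in this thread concrete, here is a rough single-pair sketch. It assumes distances and Born radii in nm and charges in units of the proton charge, and it uses the constant 138.935485 quoted above; the function names are hypothetical and the 0.5 double-counting factor of the full pairwise sum is left out.

```python
import math

# Coulomb constant in kJ * nm / (mol * e**2), the value quoted in the review.
COULOMB_CONSTANT = 138.935485


def gb_pair_energy(q_i, q_j, r, B_i, B_j,
                   solute_dielectric=1.0, solvent_dielectric=78.5):
    """Generalized Born energy (kJ/mol) of one pair, conversion applied here."""
    f = math.sqrt(r ** 2 + B_i * B_j * math.exp(-r ** 2 / (4 * B_i * B_j)))
    prefactor = -(1 / solute_dielectric - 1 / solvent_dielectric)
    return COULOMB_CONSTANT * prefactor * q_i * q_j / f


def gb_pair_energy_prescaled(q_i, q_j, r, B_i, B_j,
                             solute_dielectric=1.0, solvent_dielectric=78.5):
    """Same term, but expecting charges already multiplied by sqrt(138.935485)."""
    f = math.sqrt(r ** 2 + B_i * B_j * math.exp(-r ** 2 / (4 * B_i * B_j)))
    prefactor = -(1 / solute_dielectric - 1 / solvent_dielectric)
    return prefactor * q_i * q_j / f


scale = math.sqrt(COULOMB_CONSTANT)
explicit = gb_pair_energy(0.4, -0.3, r=0.35, B_i=0.18, B_j=0.21)
prescaled = gb_pair_energy_prescaled(0.4 * scale, -0.3 * scale,
                                     r=0.35, B_i=0.18, B_j=0.21)
```

Both calls give the same energy, which is why the docstring needs to say which convention the function expects: passing unscaled charges to the prescaled variant silently shrinks the polar term by a factor of about 139.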

1 / solute_dielectric - 1 / solvent_dielectric) * charge_products / f
@@ -133,13 +95,25 @@ def gbsa_obc2_energy(
distance_matrix_in_bohr,
radii_in_bohr, scales, charges,
alpha=0.8, beta=0.0, gamma=2.909125,
**kwargs
dielectric_offset=0.009,
surface_tension=28.3919551,
solute_dielectric=1.0,
solvent_dielectric=78.5,
probe_radius=0.14
):
# convert distances and radii into units of nanometers before proceeding
distance_matrix = distance_matrix_in_bohr * distance_to_nm
radii = radii_in_bohr * distance_to_nm

E = _gbsa_obc2_energy_omm(distance_matrix, radii, scales, charges, alpha,
beta, gamma, **kwargs)
E = _gbsa_obc2_energy_omm(
distance_matrix,
radii, scales, charges,
alpha, beta, gamma,
dielectric_offset=dielectric_offset,
surface_tension=surface_tension,
solute_dielectric=solute_dielectric,
solvent_dielectric=solvent_dielectric,
probe_radius=probe_radius,
)

return E * energy_from_kjmol # return E in espaloma energy unit
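
The wrapper pattern in this hunk (convert units, forward explicitly named keyword arguments, convert the result back) can be sketched in plain Python for the nonpolar term alone. This is an illustration under assumptions: the espaloma units are taken to be bohr and hartree, the two conversion constants are standard values rather than the module's own `distance_to_nm` and `energy_from_kjmol`, and both function names are hypothetical.

```python
BOHR_TO_NM = 0.0529177210903       # Bohr radius in nanometers
KJMOL_PER_HARTREE = 2625.49963948  # 1 hartree in kJ/mol (assumed energy unit)


def _ace_energy_omm(radii, born_radii, surface_tension, probe_radius):
    """Nonpolar (ACE) term in kJ/mol; all lengths in nm."""
    return sum(
        surface_tension * (r + probe_radius) ** 2 * (r / b) ** 6
        for r, b in zip(radii, born_radii)
    )


def ace_energy(radii_in_bohr, born_radii_in_bohr,
               surface_tension=28.3919551, probe_radius=0.14):
    """Wrapper: bohr in, hartree out, with every option spelled out."""
    radii = [r * BOHR_TO_NM for r in radii_in_bohr]
    born_radii = [b * BOHR_TO_NM for b in born_radii_in_bohr]
    energy_kjmol = _ace_energy_omm(
        radii, born_radii,
        surface_tension=surface_tension,
        probe_radius=probe_radius,
    )
    return energy_kjmol / KJMOL_PER_HARTREE


energy = ace_energy([3.0, 2.5], [3.2, 2.9])
```

Naming each option in the signature, instead of passing `**kwargs` through, keeps the defaults visible at the call site and gives a compiler such as TorchScript a fixed argument list to work with.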
9 changes: 6 additions & 3 deletions espaloma/redux/symmetry.py
@@ -89,9 +89,12 @@ def symmetry_pool(f, interactions, permutations):

# proper torsions: sum over (abcd, dcba)
proper_perms = [(0, 1, 2, 3), (3, 2, 1, 0)]
propers = symmetry_pool(
self.readouts.propers, indices.propers, proper_perms
)
if len(indices.propers) > 0:
propers = symmetry_pool(
self.readouts.propers, indices.propers, proper_perms
)
else:
propers = torch.zeros((0, 6))

# improper torsions: sum over 3 cyclic permutations of non-central
# atoms, following smirnoff trefoil convention:
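
The idea behind the pooling and the empty-input guard can be shown with a plain-Python stand-in. This is a sketch, not the tensor implementation in `espaloma/redux/symmetry.py`: `symmetry_pool_sketch` and `toy_readout` are hypothetical names, and lists replace tensors.

```python
def symmetry_pool_sketch(f, interactions, permutations):
    """Sum f over the given index permutations of each interaction.

    Returns one pooled output per interaction, or [] when there are none.
    """
    pooled = []
    for atoms in interactions:
        total = None
        for perm in permutations:
            out = f(tuple(atoms[i] for i in perm))
            total = out if total is None else [a + b for a, b in zip(total, out)]
        pooled.append(total)
    return pooled


def toy_readout(atoms):
    # Deliberately order-dependent, with six outputs per torsion.
    return [float(atoms[0] * (k + 1) + atoms[-1]) for k in range(6)]


proper_perms = [(0, 1, 2, 3), (3, 2, 1, 0)]
forward = symmetry_pool_sketch(toy_readout, [(1, 2, 3, 4)], proper_perms)
backward = symmetry_pool_sketch(toy_readout, [(4, 3, 2, 1)], proper_perms)
no_propers = symmetry_pool_sketch(toy_readout, [], proper_perms)
```

Summing over `(abcd, dcba)` makes the pooled output identical for a torsion and its reverse, and a molecule without proper torsions, such as methane, yields an empty result rather than an error, which is the case the `torch.zeros((0, 6))` branch handles in the diff.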
1 change: 1 addition & 0 deletions scripts/free_energy_fitting/.gitattributes
@@ -1 +1,2 @@
freesolv.h5 filter=lfs diff=lfs merge=lfs -text
freesolv_with_samples.h5 filter=lfs diff=lfs merge=lfs -text