diff --git a/examples/harmonic-oscillators/harmonic-oscillators-prof.py b/examples/harmonic-oscillators/harmonic-oscillators-prof.py
new file mode 100644
index 00000000..6ee41813
--- /dev/null
+++ b/examples/harmonic-oscillators/harmonic-oscillators-prof.py
@@ -0,0 +1,171 @@
+#!/usr/bin/python
+
+#=============================================================================================
+# Test MBAR by performing statistical tests on a set of 1D harmonic oscillators, for which
+# the true free energy differences can be computed analytically.
+#
+# A number of replications of an experiment in which i.i.d. samples are drawn from a set of
+# K harmonic oscillators are produced. For each replicate, we estimate the dimensionless free
+# energy differences and mean-square displacements (an observable), as well as their uncertainties.
+#
+# For a 1D harmonic oscillator, the potential is given by
+#   V(x;K) = (K/2) * (x-x_0)**2
+# where K denotes the spring constant.
+#
+# The equilibrium distribution is given analytically by
+#   p(x;beta,K) = sqrt[(beta K) / (2 pi)] exp[-beta K (x-x_0)**2 / 2]
+# The dimensionless free energy is therefore
+#   f(beta,K) = - (1/2) * ln[ (2 pi) / (beta K) ]
+#
+#=============================================================================================
+
+#=============================================================================================
+# IMPORTS
+#=============================================================================================
+import sys
+import numpy as np
+from pymbar import testsystems, exp, exp_gauss, bar, MBAR
+from pymbar.utils import ParameterError
+
+#=============================================================================================
+# HELPER FUNCTIONS
+#=============================================================================================
+
+def stddev_away(namex, errorx, dx):
+
+    if dx > 0:
+        print("%s differs by %.3f standard deviations from analytical" % (namex, errorx/dx))
+    else:
+        print("%s differs by an undefined number of standard deviations" % (namex))
+
+def GetAnalytical(beta, K, O, observables):
+
+    # For a harmonic oscillator with spring constant K,
+    # x ~ Normal(x_0, sigma^2), where sigma = 1/sqrt(beta K)
+
+    # Compute the absolute dimensionless free energies of each oscillator analytically.
+    # f = - ln(sqrt((2 pi)/(beta K)) )
+    print('Computing dimensionless free energies analytically...')
+
+    sigma = (beta * K)**-0.5
+    f_k_analytical = -np.log(np.sqrt(2 * np.pi) * sigma)
+
+    Delta_f_ij_analytical = np.matrix(f_k_analytical) - np.matrix(f_k_analytical).transpose()
+
+    A_k_analytical = dict()
+    A_ij_analytical = dict()
+
+    for observe in observables:
+        if observe == 'RMS displacement':
+            A_k_analytical[observe] = sigma  # RMS displacement of a harmonic oscillator is sigma
+        if observe == 'potential energy':
+            A_k_analytical[observe] = 1/(2*beta)*np.ones(len(K), float)  # by equipartition
+        if observe == 'position':
+            A_k_analytical[observe] = O  # observable is the position
+        if observe == 'position^2':
+            A_k_analytical[observe] = (1 + beta*K*O**2)/(beta*K)  # observable is the position^2
+
+        A_ij_analytical[observe] = A_k_analytical[observe] - np.transpose(np.matrix(A_k_analytical[observe]))
+
+    return f_k_analytical, Delta_f_ij_analytical, A_k_analytical, A_ij_analytical
+
+#=============================================================================================
+# PARAMETERS
+#=============================================================================================
+
+copies = 1
+K_k = copies*[2.5, 1.6, 9, 4, 1, 1]
+K_k = np.array(K_k)  # spring constants for each state
+O_i = [0, 1, 2, 3, 4, 5]
+O_k = np.array(copies*O_i)  # offsets (equilibrium positions) for each state
+for c in range(copies):
+    O_k[len(O_i)*c:len(O_i)*(c+1)] += c*len(O_i)*np.ones(len(O_i), int)
+N_k = copies*[1000, 1000, 1000, 1000, 0, 1000]
+N_k = 5000*np.array(N_k)  # number of samples from each state (can be zero for some states)
+Nk_ne_zero = (N_k != 0)
+beta = 1.0  # inverse temperature for all simulations
+K_extra = np.array([20, 12, 6, 2, 1])
+O_extra = np.array([0.5, 1.5, 2.5, 3.5, 4.5])
+observables = ['position', 'position^2', 'potential energy', 'RMS displacement']
+
+# The random number generator is seeded here to produce reproducible output;
+# set seed = None for non-reproducible output.
+seed = 0
+np.random.seed(seed)
+
+#=============================================================================================
+# MAIN
+#=============================================================================================
+
+# Determine number of simulations.
+K = np.size(N_k)
+if np.shape(K_k) != np.shape(N_k):
+    raise ParameterError("K_k %s and N_k %s must have same dimensions." % (np.shape(K_k), np.shape(N_k)))
+if np.shape(O_k) != np.shape(N_k):
+    raise ParameterError("O_k %s and N_k %s must have same dimensions." % (np.shape(O_k), np.shape(N_k)))
+
+# Determine maximum number of samples to be drawn for any state.
+N_max = np.max(N_k)
+
+(f_k_analytical, Delta_f_ij_analytical, A_k_analytical, A_ij_analytical) = GetAnalytical(beta, K_k, O_k, observables)
+
+print("This script will draw samples from %d harmonic oscillators." % (K))
+print("The harmonic oscillators have equilibrium positions")
+print(O_k)
+print("and spring constants")
+print(K_k)
+print("and the following number of samples will be drawn from each (can be zero if no samples drawn):")
+print(N_k)
+print("")
+
+#=============================================================================================
+# Generate independent data samples from K one-dimensional harmonic oscillators centered at q = 0.
+#=============================================================================================
+
+print('generating samples...')
+randomsample = testsystems.harmonic_oscillators.HarmonicOscillatorsTestCase(O_k=O_k, K_k=K_k, beta=beta)
+[x_kn, u_kn, N_k, s_n] = randomsample.sample(N_k, mode='u_kn')
+
+# get the unreduced energies
+U_kn = u_kn/beta
+
+#=============================================================================================
+# Estimate free energies and expectations.
+#=============================================================================================
+
+print("======================================")
+print(" Initializing MBAR ")
+print("======================================")
+
+# Estimate free energies from simulation using MBAR.
+print("Estimating relative free energies from simulation (this may take a while)...")
+
+# Initialize the MBAR class, determining the free energies.
+mbar = MBAR(u_kn, N_k, relative_tolerance=1.0e-10, verbose=True)
+
+# Get matrix of dimensionless free energy differences and uncertainty estimate.
+print("=============================================")
+print(" Testing compute_free_energy_differences ")
+print("=============================================")
+results = mbar.compute_free_energy_differences()
+Delta_f_ij_estimated = results['Delta_f']
+dDelta_f_ij_estimated = results['dDelta_f']
+
+# Compute error from analytical free energy differences.
+Delta_f_ij_error = Delta_f_ij_estimated - Delta_f_ij_analytical
+
+print("Error in free energies is:")
+print(Delta_f_ij_error)
+print("Uncertainty in free energies is:")
+print(dDelta_f_ij_estimated)
+
+print("Standard deviations away is:")
+# mathematical manipulation to avoid dividing-by-zero errors; we don't care
+# about the diagonals, since they are identically zero.
+df_ij_mod = dDelta_f_ij_estimated + np.identity(K)
+stdevs = np.abs(Delta_f_ij_error/df_ij_mod)
+for k in range(K):
+    stdevs[k, k] = 0
+print(stdevs)
diff --git a/examples/harmonic-oscillators/harmonic-oscillators-testprof.py b/examples/harmonic-oscillators/harmonic-oscillators-testprof.py
new file mode 100644
index 00000000..20f87bac
--- /dev/null
+++ b/examples/harmonic-oscillators/harmonic-oscillators-testprof.py
@@ -0,0 +1,862 @@
+#!/usr/bin/python
+
+#=============================================================================================
+# Test MBAR by performing statistical tests on a set of 1D harmonic oscillators, for which
+# the true free energy differences can be computed analytically.
+#
+# A number of replications of an experiment in which i.i.d. samples are drawn from a set of
+# K harmonic oscillators are produced. For each replicate, we estimate the dimensionless free
+# energy differences and mean-square displacements (an observable), as well as their uncertainties.
+#
+# For a 1D harmonic oscillator, the potential is given by
+#   V(x;K) = (K/2) * (x-x_0)**2
+# where K denotes the spring constant.
+#
+# The equilibrium distribution is given analytically by
+#   p(x;beta,K) = sqrt[(beta K) / (2 pi)] exp[-beta K (x-x_0)**2 / 2]
+# The dimensionless free energy is therefore
+#   f(beta,K) = - (1/2) * ln[ (2 pi) / (beta K) ]
+#
+#=============================================================================================
+
+#=============================================================================================
+# IMPORTS
+#=============================================================================================
+from __future__ import print_function
+import sys
+import numpy
+from pymbar import testsystems, EXP, EXPGauss, BAR, MBAR
+from pymbar.utils import ParameterError
+
+#=============================================================================================
+# HELPER FUNCTIONS
+#=============================================================================================
+
+def stddev_away(namex, errorx, dx):
+
+    if dx > 0:
+        print("%s differs by %.3f standard deviations from analytical" % (namex, errorx/dx))
+    else:
+        print("%s differs by an undefined number of standard deviations" % (namex))
+
+def GetAnalytical(beta, K, O, observables):
+
+    # For a harmonic oscillator with spring constant K,
+    # x ~ Normal(x_0, sigma^2), where sigma = 1/sqrt(beta K)
+
+    # Compute the absolute dimensionless free energies of each oscillator analytically.
+    # f = - ln(sqrt((2 pi)/(beta K)) )
+    print('Computing dimensionless free energies analytically...')
+
+    sigma = (beta * K)**-0.5
+    f_k_analytical = -numpy.log(numpy.sqrt(2 * numpy.pi) * sigma)
+
+    Delta_f_ij_analytical = numpy.matrix(f_k_analytical) - numpy.matrix(f_k_analytical).transpose()
+
+    A_k_analytical = dict()
+    A_ij_analytical = dict()
+
+    for observe in observables:
+        if observe == 'RMS displacement':
+            A_k_analytical[observe] = sigma  # RMS displacement of a harmonic oscillator is sigma
+        if observe == 'potential energy':
+            A_k_analytical[observe] = 1/(2*beta)*numpy.ones(len(K), float)  # by equipartition
+        if observe == 'position':
+            A_k_analytical[observe] = O  # observable is the position
+        if observe == 'position^2':
+            A_k_analytical[observe] = (1 + beta*K*O**2)/(beta*K)  # observable is the position^2
+
+        A_ij_analytical[observe] = A_k_analytical[observe] - numpy.transpose(numpy.matrix(A_k_analytical[observe]))
+
+    return f_k_analytical, Delta_f_ij_analytical, A_k_analytical, A_ij_analytical
+
+#=============================================================================================
+# PARAMETERS
+#=============================================================================================
+
+K_k = numpy.array([25, 16, 9, 4, 1, 1])  # spring constants for each state
+O_k = numpy.array([0, 1, 2, 3, 4, 5])  # offsets (equilibrium positions) for each state
+N_k = 1000*numpy.array([1000, 1000, 1000, 1000, 0, 1000])  # number of samples from each state (can be zero for some states)
+Nk_ne_zero = (N_k != 0)
+beta = 1.0  # inverse temperature for all simulations
+K_extra = numpy.array([20, 12, 6, 2, 1])
+O_extra = numpy.array([0.5, 1.5, 2.5, 3.5, 4.5])
+observables = ['position', 'position^2', 'potential energy', 'RMS displacement']
+
+# The random number generator is seeded here to produce reproducible output;
+# set seed = None for non-reproducible output.
+seed = 0
+numpy.random.seed(seed)
+
+#=============================================================================================
+# MAIN
+#=============================================================================================
+
+# Determine number of simulations.
+K = numpy.size(N_k)
+if numpy.shape(K_k) != numpy.shape(N_k):
+    raise ParameterError("K_k %s and N_k %s must have same dimensions." % (numpy.shape(K_k), numpy.shape(N_k)))
+if numpy.shape(O_k) != numpy.shape(N_k):
+    raise ParameterError("O_k %s and N_k %s must have same dimensions." % (numpy.shape(O_k), numpy.shape(N_k)))
+
+# Determine maximum number of samples to be drawn for any state.
+N_max = numpy.max(N_k)
+
+(f_k_analytical, Delta_f_ij_analytical, A_k_analytical, A_ij_analytical) = GetAnalytical(beta, K_k, O_k, observables)
+
+print("This script will draw samples from %d harmonic oscillators." % (K))
+print("The harmonic oscillators have equilibrium positions")
+print(O_k)
+print("and spring constants")
+print(K_k)
+print("and the following number of samples will be drawn from each (can be zero if no samples drawn):")
+print(N_k)
+print("")
+
+#=============================================================================================
+# Generate independent data samples from K one-dimensional harmonic oscillators centered at q = 0.
+#=============================================================================================
+
+print('generating samples...')
+randomsample = testsystems.harmonic_oscillators.HarmonicOscillatorsTestCase(O_k=O_k, K_k=K_k, beta=beta)
+[x_kn, u_kln, N_k] = randomsample.sample(N_k, mode='u_kln')
+
+# get the unreduced energies
+U_kln = u_kln/beta
+
+#=============================================================================================
+# Estimate free energies and expectations.
+#=============================================================================================
+
+print("======================================")
+print(" Initializing MBAR ")
+print("======================================")
+
+# Estimate free energies from simulation using MBAR.
+print("Estimating relative free energies from simulation (this may take a while)...")
+
+# Initialize the MBAR class, determining the free energies.
+mbar = MBAR(u_kln, N_k, relative_tolerance=1.0e-10, verbose=True)
+
+# Get matrix of dimensionless free energy differences and uncertainty estimate.
+print("=============================================")
+print(" Testing getFreeEnergyDifferences ")
+print("=============================================")
+
+results = mbar.getFreeEnergyDifferences(return_dict=True)
+Delta_f_ij_estimated = results['Delta_f']
+dDelta_f_ij_estimated = results['dDelta_f']
+
+# Compute error from analytical free energy differences.
+Delta_f_ij_error = Delta_f_ij_estimated - Delta_f_ij_analytical
+
+print("Error in free energies is:")
+print(Delta_f_ij_error)
+print("Uncertainty in free energies is:")
+print(dDelta_f_ij_estimated)
+
+print("Standard deviations away is:")
+# mathematical manipulation to avoid dividing-by-zero errors; we don't care
+# about the diagonals, since they are identically zero.
+df_ij_mod = dDelta_f_ij_estimated + numpy.identity(K)
+stdevs = numpy.abs(Delta_f_ij_error/df_ij_mod)
+for k in range(K):
+    stdevs[k, k] = 0
+print(stdevs)
+
+# NOTE: profiling stops here; everything below this exit() is skipped.
+exit()
+
+print("==============================================")
+print(" Testing computeBAR ")
+print("==============================================")
+
+nonzero_indices = numpy.array(list(range(K)))[Nk_ne_zero]
+Knon = len(nonzero_indices)
+for i in range(Knon-1):
+    k = nonzero_indices[i]
+    k1 = nonzero_indices[i+1]
+    w_F = u_kln[k, k1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]  # forward work
+    w_R = u_kln[k1, k, 0:N_k[k1]] - u_kln[k1, k1, 0:N_k[k1]]  # reverse work
+    results = BAR(w_F, w_R, return_dict=True)
+    df_bar = results['Delta_f']
+    ddf_bar = results['dDelta_f']
+    bar_analytical = (f_k_analytical[k1] - f_k_analytical[k])
+    bar_error = bar_analytical - df_bar
+    print("BAR estimator for reduced free energy from states %d to %d is %f +/- %f" % (k, k1, df_bar, ddf_bar))
+    stddev_away("BAR estimator", bar_error, ddf_bar)
+
+print("==============================================")
+print(" Testing computeEXP ")
+print("==============================================")
+
+print("EXP forward free energy")
+for k in range(K-1):
+    if N_k[k] != 0:
+        w_F = u_kln[k, k+1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]  # forward work
+        results = EXP(w_F, return_dict=True)
+        df_exp = results['Delta_f']
+        ddf_exp = results['dDelta_f']
+        exp_analytical = (f_k_analytical[k+1] - f_k_analytical[k])
+        exp_error = exp_analytical - df_exp
+        print("df from states %d to %d is %f +/- %f" % (k, k+1, df_exp, ddf_exp))
+        stddev_away("df", exp_error, ddf_exp)
+
+print("EXP reverse free energy")
+for k in range(1, K):
+    if N_k[k] != 0:
+        w_R = u_kln[k, k-1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]  # reverse work
+        results = EXP(w_R, return_dict=True)
+        df_exp = -results['Delta_f']  # negate to compare in the forward direction
+        ddf_exp = results['dDelta_f']
+        exp_analytical = (f_k_analytical[k] - f_k_analytical[k-1])
+        exp_error = exp_analytical - df_exp
+        print("df from states %d to %d is %f +/- %f" % (k, k-1, df_exp, ddf_exp))
+        stddev_away("df", exp_error, ddf_exp)
+
+print("==============================================")
+print(" Testing computeGauss ")
+print("==============================================")
+
+print("Gaussian forward estimate")
+for k in range(K-1):
+    if N_k[k] != 0:
+        w_F = u_kln[k, k+1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]  # forward work
+        results = EXPGauss(w_F, return_dict=True)
+        df_gauss = results['Delta_f']
+        ddf_gauss = results['dDelta_f']
+        gauss_analytical = (f_k_analytical[k+1] - f_k_analytical[k])
+        gauss_error = gauss_analytical - df_gauss
+        print("df for reduced free energy from states %d to %d is %f +/- %f" % (k, k+1, df_gauss, ddf_gauss))
+        stddev_away("df", gauss_error, ddf_gauss)
+
+print("Gaussian reverse estimate")
+for k in range(1, K):
+    if N_k[k] != 0:
+        w_R = u_kln[k, k-1, 0:N_k[k]] - u_kln[k, k, 0:N_k[k]]  # reverse work
+        results = EXPGauss(w_R, return_dict=True)
+        df_gauss = -results['Delta_f']  # negate, as for EXP above, to compare in the forward direction
+        ddf_gauss = results['dDelta_f']
+        gauss_analytical = (f_k_analytical[k] - f_k_analytical[k-1])
+        gauss_error = gauss_analytical - df_gauss
+        print("df for reduced free energy from states %d to %d is %f +/- %f" % (k, k-1, df_gauss, ddf_gauss))
+        stddev_away("df", gauss_error, ddf_gauss)
+
+print("======================================")
+print(" Testing computeExpectations")
+print("======================================")
+
+A_kn_all = dict()
+A_k_estimated_all = dict()
+A_kl_estimated_all = dict()
+N = numpy.sum(N_k)
+
+for observe in observables:
+    print("============================================")
observable %s" % (observe)) + print("============================================") + + if observe == 'RMS displacement': + state_dependent = True + A_kn = numpy.zeros([K,N], dtype = numpy.float64) + n = 0 + for k in range(0,K): + for nk in range(0,N_k[k]): + A_kn[:,n] = (x_kn[k,nk] - O_k[:])**2 # observable is the squared displacement + n += 1 + + # observable is the potential energy, a 3D array since the potential energy is a function of + # thermodynamic state + elif observe == 'potential energy': + state_dependent = True + A_kn = numpy.zeros([K,N], dtype = numpy.float64) + n = 0 + for k in range(0,K): + for nk in range(0,N_k[k]): + A_kn[:,n] = U_kln[k,:,nk] + n += 1 + + # observable for estimation is the position + elif observe == 'position': + state_dependent = False + A_kn = numpy.zeros([K,N_max], dtype = numpy.float64) + for k in range(0,K): + A_kn[k,0:N_k[k]] = x_kn[k,0:N_k[k]] + + # observable for estimation is the position^2 + elif observe == 'position^2': + state_dependent = False + A_kn = numpy.zeros([K,N_max], dtype = numpy.float64) + for k in range(0,K): + A_kn[k,0:N_k[k]] = x_kn[k,0:N_k[k]]**2 + + results = mbar.computeExpectations(A_kn, state_dependent = state_dependent, return_dict=True) + A_k_estimated = results['mu'] + dA_k_estimated = results['sigma'] + + # need to additionally transform to get the square root + if observe == 'RMS displacement': + A_k_estimated = numpy.sqrt(A_k_estimated) + # Compute error from analytical observable estimate. + dA_k_estimated = dA_k_estimated/(2*A_k_estimated) + + As_k_estimated = numpy.zeros([K],numpy.float64) + dAs_k_estimated = numpy.zeros([K],numpy.float64) + + # 'standard' expectation averages - not defined if no samples + nonzeros = numpy.arange(K)[Nk_ne_zero] + + totaln = 0 + for k in nonzeros: + if (observe == 'position') or (observe == 'position^2'): + As_k_estimated[k] = numpy.average(A_kn[k,0:N_k[k]]) + dAs_k_estimated[k] = numpy.sqrt(numpy.var(A_kn[k,0:N_k[k]])/(N_k[k]-1)) + elif (observe == 'RMS displacement' ) or (observe == 'potential energy'): + totalp = totaln + N_k[k] + As_k_estimated[k] = numpy.average(A_kn[k,totaln:totalp]) + dAs_k_estimated[k] = numpy.sqrt(numpy.var(A_kn[k,totaln:totalp])/(N_k[k]-1)) + totaln = totalp + if observe == 'RMS displacement': + As_k_estimated[k] = numpy.sqrt(As_k_estimated[k]) + dAs_k_estimated[k] = dAs_k_estimated[k]/(2*As_k_estimated[k]) + + A_k_error = A_k_estimated - A_k_analytical[observe] + As_k_error = As_k_estimated - A_k_analytical[observe] + + print("------------------------------") + print("Now testing 'averages' mode") + print("------------------------------") + + print("Analytical estimator of %s is" % (observe)) + print(A_k_analytical[observe]) + + print("MBAR estimator of the %s is" % (observe)) + print(A_k_estimated) + + print("MBAR estimators differ by X standard deviations") + stdevs = numpy.abs(A_k_error/dA_k_estimated) + print(stdevs) + + print("Standard estimator of %s is (states with samples):" % (observe)) + print(As_k_estimated[Nk_ne_zero]) + + print("Standard estimators differ by X standard deviations (states with samples)") + stdevs = numpy.abs(As_k_error[Nk_ne_zero]/dAs_k_estimated[Nk_ne_zero]) + print(stdevs) + + results = mbar.computeExpectations(A_kn, state_dependent = state_dependent, output = 'differences', return_dict=True) + A_kl_estimated = results['mu'] + dA_kl_estimated = results['sigma'] + + + print("------------------------------") + print("Now testing 'differences' mode") + print("------------------------------") + + if 'RMS displacement' != 
+    if 'RMS displacement' != observe:  # can't test this, because we're actually computing the expectation of
+                                       # the mean square displacement, and so the differences are
+                                       # <d^2>_i - <d^2>_j, not sqrt(<d^2>_i) - sqrt(<d^2>_j)
+        A_kl_analytical = numpy.matrix(A_k_analytical[observe]) - numpy.matrix(A_k_analytical[observe]).transpose()
+        A_kl_error = A_kl_estimated - A_kl_analytical
+
+        print("Analytical estimator of differences of %s is" % (observe))
+        print(A_kl_analytical)
+
+        print("MBAR estimator of the differences of %s is" % (observe))
+        print(A_kl_estimated)
+
+        print("MBAR estimators differ by X standard deviations")
+        stdevs = numpy.abs(A_kl_error/(dA_kl_estimated + numpy.identity(K)))
+        for k in range(K):
+            stdevs[k, k] = 0
+        print(stdevs)
+
+    # save up the A_k for use in computeMultipleExpectations
+    A_kn_all[observe] = A_kn
+    A_k_estimated_all[observe] = A_k_estimated
+    A_kl_estimated_all[observe] = A_kl_estimated
+
+print("=============================================")
+print(" Testing computeMultipleExpectations")
+print("=============================================")
+
+# have to exclude the potential and RMS displacement for now, not functions of a single state
+observables_single = ['position', 'position^2']
+
+A_ikn = numpy.zeros([len(observables_single), K, N_k.max()], numpy.float64)
+for i, observe in enumerate(observables_single):
+    A_ikn[i, :, :] = A_kn_all[observe]
+for i in range(K):
+    results = mbar.computeMultipleExpectations(A_ikn, u_kln[:, i, :], compute_covariance=True, return_dict=True)
+    A_i = results['mu']
+    dA_ij = results['sigma']
+    Ca_ij = results['covariances']
+    print("Averages for state %d" % (i))
+    print(A_i)
+    print("Uncertainties for state %d" % (i))
+    print(dA_ij)
+    print("Correlation matrix between observables for state %d" % (i))
+    print(Ca_ij)
+
+print("============================================")
+print(" Testing computeEntropyAndEnthalpy")
+print("============================================")
+
+results = mbar.computeEntropyAndEnthalpy(u_kn=u_kln, verbose=True, return_dict=True)
+Delta_f_ij = results['Delta_f']
+dDelta_f_ij = results['dDelta_f']
+Delta_u_ij = results['Delta_u']
+dDelta_u_ij = results['dDelta_u']
+Delta_s_ij = results['Delta_s']
+dDelta_s_ij = results['dDelta_s']
+
+print("Free energies")
+print(Delta_f_ij)
+print(dDelta_f_ij)
+diffs1 = Delta_f_ij - Delta_f_ij_estimated
+print("maximum difference between values computed here and in getFreeEnergyDifferences is %g" % (numpy.max(diffs1)))
+if (numpy.max(numpy.abs(diffs1)) > 1.0e-10):
+    print("Difference in values from getFreeEnergyDifferences")
+    print(diffs1)
+diffs2 = dDelta_f_ij - dDelta_f_ij_estimated
+print("maximum difference between uncertainties computed here and in getFreeEnergyDifferences is %g" % (numpy.max(diffs2)))
+if (numpy.max(numpy.abs(diffs2)) > 1.0e-10):
+    print("Difference in expectations from getFreeEnergyDifferences")
+    print(diffs2)
+
+print("Energies")
+print(Delta_u_ij)
+print(dDelta_u_ij)
+U_k = numpy.matrix(A_k_estimated_all['potential energy'])
+expectations = U_k - U_k.transpose()
+diffs1 = Delta_u_ij - expectations
+print("maximum difference between values computed here and in computeExpectations is %g" % (numpy.max(diffs1)))
+if (numpy.max(numpy.abs(diffs1)) > 1.0e-10):
+    print("Difference in values from computeExpectations")
+    print(diffs1)
+
+print("Entropies")
+print(Delta_s_ij)
+print(dDelta_s_ij)
+
+# analytical entropy estimate (reduced entropy s = <u> - f, with <u> = 1/2 by equipartition)
+s_k_analytical = numpy.matrix(0.5 - f_k_analytical)
+Delta_s_ij_analytical = s_k_analytical - s_k_analytical.transpose()
+
+Delta_s_ij_error = Delta_s_ij_analytical - Delta_s_ij
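+# Note on the analytical entropy above (added for clarity): MBAR works in reduced
+# units, so computeEntropyAndEnthalpy returns Delta_s = Delta_u - Delta_f with
+# u = beta*U. By equipartition <U> = 1/(2*beta) for a 1D harmonic oscillator, so
+# <u> = 1/2 and the analytical reduced entropy is s_k = 1/2 - f_k.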
+print("Error in entropies is:") +print(Delta_f_ij_error) + +print("Standard deviations away is:") +# mathematical manipulation to avoid dividing by zero errors; we don't care +# about the diagnonals, since they are identically zero. +ds_ij_mod = dDelta_s_ij + numpy.identity(K) +stdevs = numpy.abs(Delta_s_ij_error/ds_ij_mod) +for k in range(K): + stdevs[k,k] = 0 +print(stdevs) + +print("============================================") +print(" Testing computePerturbedFreeEnergies") +print("============================================") + +L = numpy.size(K_extra) +(f_k_analytical, Delta_f_ij_analytical, A_k_analytical, A_ij_analytical) = GetAnalytical(beta,K_extra,O_extra,observables) + +if numpy.size(O_extra) != numpy.size(K_extra): + raise ParameterError("O_extra (%d) and K_extra (%d) must have the same dimensions." % (numpy.shape(K_k), numpy.shape(N_k))) + +unew_kln = numpy.zeros([K,L,numpy.max(N_k)],numpy.float64) +for k in range(K): + for l in range(L): + unew_kln[k,l,0:N_k[k]] = (K_extra[l]/2.0) * (x_kn[k,0:N_k[k]]-O_extra[l])**2 + +results = mbar.computePerturbedFreeEnergies(unew_kln, return_dict=True) +Delta_f_ij_estimated = results['Delta_f'] +dDelta_f_ij_estimated = results['dDelta_f'] + +Delta_f_ij_error = Delta_f_ij_estimated - Delta_f_ij_analytical + +print("Error in free energies is:") +print(Delta_f_ij_error) + +print("Standard deviations away is:") +# mathematical manipulation to avoid dividing by zero errors; we don't care +# about the diagnonals, since they are identically zero. +df_ij_mod = dDelta_f_ij_estimated + numpy.identity(L) +stdevs = numpy.abs(Delta_f_ij_error/df_ij_mod) +for l in range(L): + stdevs[l,l] = 0 +print(stdevs) + +print("============================================") +print(" Testing computeExpectation (new states) ") +print("============================================") + +nth = 3 +# test the nth "extra" states, O_extra[nth] & K_extra[nth] +for observe in observables: + print("============================================") + print(" Testing observable %s" % (observe)) + print("============================================") + + if observe == 'RMS displacement': + state_dependent = True + A_kn = numpy.zeros([K,1,N_max], dtype = numpy.float64) + for k in range(0,K): + A_kn[k,0,0:N_k[k]] = (x_kn[k,0:N_k[k]] - O_extra[nth])**2 # observable is the squared displacement + + # observable is the potential energy, a 3D array since the potential energy is a function of + # thermodynamic state + elif observe == 'potential energy': + state_dependent = True + A_kn = unew_kln[:,[nth],:]/beta + + # position and position^2 can use the same observables + # observable for estimation is the position + elif observe == 'position': + state_dependent = False + A_kn = A_kn_all['position'] + + elif observe == 'position^2': + state_dependent = False + A_kn = A_kn_all['position^2'] + + A_k_estimated, dA_k_estimated + results = mbar.computeExpectations(A_kn,unew_kln[:,[nth],:],state_dependent=state_dependent, return_dict=True) + A_k_estimated = results['mu'] + dA_k_estimated = results['sigma'] + # need to additionally transform to get the square root + if observe == 'RMS displacement': + A_k_estimated = numpy.sqrt(A_k_estimated) + dA_k_estimated = dA_k_estimated/(2*A_k_estimated) + + A_k_error = A_k_estimated - A_k_analytical[observe][nth] + + print("Analytical estimator of %s is" % (observe)) + print(A_k_analytical[observe][nth]) + + print("MBAR estimator of the %s is" % (observe)) + print(A_k_estimated) + + print("MBAR estimators differ by X standard deviations") + stdevs = 
+    stdevs = numpy.abs(A_k_error/dA_k_estimated)
+    print(stdevs)
+
+print("============================================")
+print(" Testing computeOverlap ")
+print("============================================")
+
+results = mbar.computeOverlap(return_dict=True)
+O = results['scalar']
+O_i = results['eigenvalues']
+O_ij = results['matrix']
+
+print("Overlap matrix output")
+print(O_ij)
+
+for k in range(K):
+    print("Sum of row %d is %f (should be 1)," % (k, numpy.sum(O_ij[k, :])), end=' ')
+    if (numpy.abs(numpy.sum(O_ij[k, :])-1) < 1.0e-10):
+        print("looks like it is.")
+    else:
+        print("but it's not.")
+
+print("Overlap eigenvalue output")
+print(O_i)
+
+print("Overlap scalar output")
+print(O)
+
+print("============================================")
+print(" Testing computeEffectiveSampleNumber ")
+print("============================================")
+
+N_eff = mbar.computeEffectiveSampleNumber(verbose=True)
+print("Effective sample number")
+print(N_eff)
+print("Compare the standard error estimate of <x> with the MBAR error estimate of <x>.")
+print("We should have that with MBAR, err_MBAR = sqrt(N_k/N_eff)*err_standard,")
+print("so standard (scaled) results should be very close to MBAR results.")
+print("No standard estimate exists for states that are not sampled.")
+A_kn = x_kn
+results = mbar.computeExpectations(A_kn, return_dict=True)
+val_mbar = results['mu']
+err_mbar = results['sigma']
+err_standard = numpy.zeros([K], dtype=numpy.float64)
+err_scaled = numpy.zeros([K], dtype=numpy.float64)
+
+for k in range(K):
+    if N_k[k] != 0:
+        # use position
+        err_standard[k] = numpy.std(A_kn[k, 0:N_k[k]])/numpy.sqrt(N_k[k]-1)
+        err_scaled[k] = numpy.std(A_kn[k, 0:N_k[k]])/numpy.sqrt(N_eff[k]-1)
+
+print(" ", end=' ')
+for k in range(K):
+    print(" %d " % (k), end=' ')
+print("")
+print("MBAR :", end=' ')
+print(err_mbar)
+print("standard :", end=' ')
+print(err_standard)
+print("sqrt N_k/N_eff :", end=' ')
+print(numpy.sqrt(N_k/N_eff))
+print("Standard (scaled):", end=' ')
+print(err_standard * numpy.sqrt(N_k/N_eff))
+
+print("============================================")
+print(" Testing computePMF ")
+print("============================================")
+
+# For 2-D, the equilibrium distribution is given analytically by
+#   p(x;beta,K) = sqrt[(beta K) / (2 pi)] exp[-beta K [(x-mu)^2] / 2]
+#
+# The dimensionless free energy is therefore
+#   f(beta,K) = - (1/2) * ln[ (2 pi) / (beta K) ]
+#
+# In this problem, we are investigating the sum of two Gaussians, one
+# centered at 0, and the others centered at grid points.
+#
+#   V(x;K) = (K0/2) * [(x-x_0)^2]
+#
+# For 1-D, the equilibrium distribution is given analytically by
+#   p(x;beta,K) = 1/N exp[-beta (K0 [x^2] / 2 + Ku [(x-mu)^2] / 2)]
+# where N is the normalization constant.
+#
+# The dimensionless free energy is the integral of this, and can be computed as:
+#   f(beta,K) = - ln [ (2 pi/(Ko+Ku))^(d/2) exp[ -Ku*Ko mu'mu / (2(Ko+Ku)) ] ]
+#   f(beta,K) - fzero = Ku*Ko mu'mu / (2(Ko+Ku)) = mu'mu / (1/(Ku/2) + 1/(K0/2))
+
+def generate_pmf_data(ndim=1, nbinsperdim=15, nsamples=1000, K0=20.0, Ku=100.0, gridscale=0.2, xrange=[[-3, 3]]):
+
+    x0 = numpy.zeros([ndim], numpy.float64)  # center of base potential
+    numbrellas = 1
+    nperdim = numpy.zeros([ndim], int)
+    for d in range(ndim):
+        nperdim[d] = xrange[d][1] - xrange[d][0] + 1
+        numbrellas *= nperdim[d]
+
+    print("There are a total of %d umbrellas." % numbrellas)
+
+    # Enumerate umbrella centers, and compute the analytical free energy of that umbrella
+    print("Constructing umbrellas...")
+    ksum = (Ku+K0)/beta
+    kprod = (Ku*K0)/(beta*beta)
+    f_k_analytical = numpy.zeros(numbrellas, numpy.float64)
+    xu_i = numpy.zeros([numbrellas, ndim], numpy.float64)  # xu_i[i,:] is the center of umbrella i
+
+    dp = numpy.zeros(ndim, int)
+    dp[0] = 1
+    for d in range(1, ndim):
+        dp[d] = nperdim[d]*dp[d-1]
+
+    umbrella_zero = 0
+    for i in range(numbrellas):
+        center = []
+        for d in range(ndim):
+            val = gridscale*((int(i//dp[d])) % nperdim[d] + xrange[d][0])
+            center.append(val)
+        center = numpy.array(center)
+        xu_i[i, :] = center
+        mu2 = numpy.dot(center, center)
+        # analytical f from the formula above: f = -ln[(2 pi/ksum)^(d/2) exp(-kprod mu'mu/(2 ksum))]
+        f_k_analytical[i] = -numpy.log((2.0*numpy.pi/ksum)**(ndim/2.0) * numpy.exp(-kprod*mu2/(2.0*ksum)))
+        if numpy.all(center == 0.0):  # assumes that we have one state that is at the zero.
+            umbrella_zero = i
+    f_k_analytical -= f_k_analytical[umbrella_zero]
+
+    print("Generating %d samples for each of %d umbrellas..." % (nsamples, numbrellas))
+    x_n = numpy.zeros([numbrellas * nsamples, ndim], numpy.float64)
+
+    for i in range(numbrellas):
+        for dim in range(ndim):
+            # Compute mu and sigma for this dimension for sampling from V0(x) + Vu(x).
+            # Product of Gaussians: N(x ; a, A) N(x ; b, B) = N(a ; b , A+B) x N(x ; c, C) where
+            #   C = 1/(1/A + 1/B)
+            #   c = C(a/A+b/B)
+            # A = 1/K0, B = 1/Ku
+            sigma = 1.0 / (K0 + Ku)
+            mu = sigma * (x0[dim]*K0 + xu_i[i, dim]*Ku)
+            # Generate normal deviates for this dimension.
+            x_n[i*nsamples:(i+1)*nsamples, dim] = numpy.random.normal(mu, numpy.sqrt(sigma), [nsamples])
+
+    u_kn = numpy.zeros([numbrellas, nsamples*numbrellas], numpy.float64)
+    # Compute reduced potential due to V0.
+    u_n = beta*(K0/2)*numpy.sum((x_n[:, :] - x0)**2, axis=1)
+    for k in range(numbrellas):
+        uu = beta*(Ku/2)*numpy.sum((x_n[:, :] - xu_i[k, :])**2, axis=1)  # reduced potential due to umbrella k
+        u_kn[k, :] = u_n + uu
+
+    return u_kn, u_n, x_n, f_k_analytical
+
+nbinsperdim = 15
+gridscale = 0.2
+nsamples = 1000
+ndim = 1
+K0 = 20.0
+Ku = 100.0
+print("============================================")
+print(" Test 1: 1D PMF ")
+print("============================================")
+
+xrange = [[-3, 3]]
+ndim = 1
+u_kn, u_n, x_n, f_k_analytical = generate_pmf_data(K0=K0, Ku=Ku, ndim=ndim, nbinsperdim=nbinsperdim, nsamples=nsamples, gridscale=gridscale, xrange=xrange)
+numbrellas = (numpy.shape(u_kn))[0]
+N_k = nsamples*numpy.ones([numbrellas], int)
+print("Solving for free energies of states...")
+mbar = MBAR(u_kn, N_k)
+
+# Histogram bins are indexed using the scheme:
+#   index = 1 + numpy.floor((x[0] - xmin)/dx) + nbins*numpy.floor((x[1] - xmin)/dy)
+# index = 0 is reserved for samples outside of the allowed domain
+xmin = gridscale*(numpy.min(xrange[0][0])-1/2.0)
+xmax = gridscale*(numpy.max(xrange[0][1])+1/2.0)
+dx = (xmax-xmin)/nbinsperdim
+nbins = 1 + nbinsperdim**ndim
+bin_centers = numpy.zeros([nbins, ndim], numpy.float64)
+
+ibin = 1
+pmf_analytical = numpy.zeros([nbins], numpy.float64)
+minmu2 = 1000000
+zeroindex = 0
+# construct the bins and the pmf
+for i in range(nbinsperdim):
+    xbin = xmin + dx * (i + 0.5)
+    bin_centers[ibin, 0] = xbin
+    mu2 = xbin*xbin
+    if (mu2 < minmu2):
+        minmu2 = mu2
+        zeroindex = ibin
+    pmf_analytical[ibin] = K0*mu2/2.0
+    ibin += 1
+fzero = pmf_analytical[zeroindex]
+pmf_analytical -= fzero
+pmf_analytical[0] = 0
+
+bin_n = numpy.zeros([numbrellas*nsamples], int)
+# Determine indices of those within bounds.
+within_bounds = (x_n[:, 0] >= xmin) & (x_n[:, 0] < xmax)
+# Determine states for these.
+bin_n[within_bounds] = 1 + numpy.floor((x_n[within_bounds, 0]-xmin)/dx)
+# Determine indices of bins that are not empty.
+bin_counts = numpy.zeros([nbins], int)
+for i in range(nbins):
+    bin_counts[i] = (bin_n == i).sum()
+
+# Compute PMF.
+print("Computing PMF ...")
+results = mbar.computePMF(u_n, bin_n, nbins, uncertainties='from-specified', pmf_reference=zeroindex, return_dict=True)
+f_i = results['f_i']
+df_i = results['df_i']
+
+# Show free energy and uncertainty of each occupied bin relative to lowest free energy
+
+print("1D PMF:")
+print("%d counts out of %d counts not in any bin" % (bin_counts[0], numbrellas*nsamples))
+print("%8s %6s %8s %10s %10s %10s %10s %8s" % ('bin', 'x', 'N', 'f', 'true', 'error', 'df', 'sigmas'))
+for i in range(1, nbins):
+    if (i == zeroindex):
+        stdevs = 0
+        error = 0
+        df_i[i] = 0
+    else:
+        error = pmf_analytical[i]-f_i[i]
+        stdevs = numpy.abs(error)/df_i[i]
+    print('%8d %6.2f %8d %10.3f %10.3f %10.3f %10.3f %8.2f' % (i, bin_centers[i, 0], bin_counts[i], f_i[i], pmf_analytical[i], error, df_i[i], stdevs))
+
+print("============================================")
+print(" Test 2: 2D PMF ")
+print("============================================")
+
+xrange = [[-3, 3], [-3, 3]]
+ndim = 2
+nsamples = 300
+u_kn, u_n, x_n, f_k_analytical = generate_pmf_data(K0=K0, Ku=Ku, ndim=ndim, nbinsperdim=nbinsperdim, nsamples=nsamples, gridscale=gridscale, xrange=xrange)
+numbrellas = (numpy.shape(u_kn))[0]
+N_k = nsamples*numpy.ones([numbrellas], int)
+print("Solving for free energies of states...")
+mbar = MBAR(u_kn, N_k)
+
+# The dimensionless free energy is the integral of this, and can be computed as:
+#   f(beta,K) = - ln [ (2 pi/(Ko+Ku))^(d/2) exp[ -Ku*Ko mu'mu / (2(Ko+Ku)) ] ]
+#   f(beta,K) - fzero = Ku*Ko mu'mu / (2(Ko+Ku)) = mu'mu / (1/(Ku/2) + 1/(K0/2))
+# for computing harmonic samples
+
+# Can compare the free energies computed with MBAR if desired with f_k_analytical
+
+# Histogram bins are indexed using the scheme:
+#   index = 1 + numpy.floor((x[0] - xmin)/dx) + nbins*numpy.floor((x[1] - xmin)/dy)
+# index = 0 is reserved for samples outside of the allowed domain
+
+xmin = gridscale*(numpy.min(xrange[0][0])-1/2.0)
+xmax = gridscale*(numpy.max(xrange[0][1])+1/2.0)
+ymin = gridscale*(numpy.min(xrange[1][0])-1/2.0)
+ymax = gridscale*(numpy.max(xrange[1][1])+1/2.0)
+dx = (xmax-xmin)/nbinsperdim
+dy = (ymax-ymin)/nbinsperdim
+nbins = 1 + nbinsperdim**ndim
+bin_centers = numpy.zeros([nbins, ndim], numpy.float64)
+
+ibin = 1  # first reserved for something outside.
+pmf_analytical = numpy.zeros([nbins], numpy.float64)
+minmu2 = 1000000
+zeroindex = 0
+# construct the bins and the pmf
+for i in range(nbinsperdim):
+    xbin = xmin + dx * (i + 0.5)
+    for j in range(nbinsperdim):
+        # Determine (x,y) of bin center.
+        ybin = ymin + dy * (j + 0.5)
+        bin_centers[ibin, 0] = xbin
+        bin_centers[ibin, 1] = ybin
+        mu2 = xbin*xbin + ybin*ybin
+        if (mu2 < minmu2):
+            minmu2 = mu2
+            zeroindex = ibin
+        pmf_analytical[ibin] = K0*mu2/2.0
+        ibin += 1
+fzero = pmf_analytical[zeroindex]
+pmf_analytical -= fzero
+pmf_analytical[0] = 0
+
+bin_n = numpy.zeros([numbrellas * nsamples], int)
+# Determine indices of those within bounds.
+within_bounds = (x_n[:, 0] >= xmin) & (x_n[:, 0] < xmax) & (x_n[:, 1] >= ymin) & (x_n[:, 1] < ymax)
+# Determine states for these.
+xgrid = (x_n[within_bounds, 0]-xmin)/dx
+ygrid = (x_n[within_bounds, 1]-ymin)/dy
+bin_n[within_bounds] = 1 + xgrid.astype(int) + nbinsperdim*ygrid.astype(int)
+
+# Determine indices of bins that are not empty.
+bin_counts = numpy.zeros([nbins], int)
+for i in range(nbins):
+    bin_counts[i] = (bin_n == i).sum()
+
+# Compute PMF.
+print("Computing PMF ...")
+results = mbar.computePMF(u_n, bin_n, nbins, uncertainties='from-specified', pmf_reference=zeroindex, return_dict=True)
+f_i = results['f_i']
+df_i = results['df_i']
+
+# Show free energy and uncertainty of each occupied bin relative to lowest free energy
+print("2D PMF:")
+
+print("%d counts out of %d counts not in any bin" % (bin_counts[0], numbrellas*nsamples))
+print("%8s %6s %6s %8s %10s %10s %10s %10s %8s" % ('bin', 'x', 'y', 'N', 'f', 'true', 'error', 'df', 'sigmas'))
+for i in range(1, nbins):
+    if (i == zeroindex):
+        stdevs = 0
+        error = 0
+        df_i[i] = 0
+    else:
+        error = pmf_analytical[i]-f_i[i]
+        stdevs = numpy.abs(error)/df_i[i]
+    print('%8d %6.2f %6.2f %8d %10.3f %10.3f %10.3f %10.3f %8.2f' % (i, bin_centers[i, 0], bin_centers[i, 1], bin_counts[i], f_i[i], pmf_analytical[i], error, df_i[i], stdevs))
+
+#=============================================================================================
+# TERMINATE
+#=============================================================================================
+
+# Signal successful execution.
+sys.exit(0)
diff --git a/examples/harmonic-oscillators/harmonic-oscillators.py b/examples/harmonic-oscillators/harmonic-oscillators.py
index 6caed243..0292c3dc 100644
--- a/examples/harmonic-oscillators/harmonic-oscillators.py
+++ b/examples/harmonic-oscillators/harmonic-oscillators.py
@@ -78,7 +78,6 @@ def get_analytical(beta, K, O, observables):
 # PARAMETERS
 # =============================================================================================
 
-
 K_k = np.array([25, 16, 9, 4, 1, 1])  # spring constants for each state
 O_k = np.array([0, 1, 2, 3, 4, 5])  # offsets for spring constants
 # number of samples from each state (can be zero for some states)
diff --git a/pymbar/mbar.py b/pymbar/mbar.py
index 5048d6ad..d8941f5d 100644
--- a/pymbar/mbar.py
+++ b/pymbar/mbar.py
@@ -342,9 +342,7 @@ def __init__(
         # which might involve passing in different combinations of options, and passing out other strings.
         solver["options"]["verbose"] = self.verbose
 
-        self.f_k = mbar_solvers.solve_mbar_for_all_states(
-            self.u_kn, self.N_k, self.f_k, solver_protocol
-        )
+        self.f_k = mbar_solvers.solve_mbar_for_all_states(self.u_kn, self.N_k, self.f_k, self.states_with_samples, solver_protocol)
 
         self.Log_W_nk = mbar_solvers.mbar_log_W_nk(self.u_kn, self.N_k, self.f_k)
 
         # Print final dimensionless free energies.
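For context, a minimal usage sketch of the updated solver entry point (not part of the patch; the array values and the single-entry solver protocol are illustrative assumptions, and `states_with_samples` mirrors the index array `MBAR.__init__` passes in):

import numpy as np
from pymbar import mbar_solvers

u_kn = np.random.rand(3, 300)  # 3 states, 300 pooled samples (placeholder data)
N_k = np.array([100, 200, 0])  # the third state contributes no samples
f_k = np.zeros(3)
states_with_samples = np.where(N_k > 0)[0]  # -> array([0, 1])

# solver protocol format assumed from pymbar's default (a tuple of solver dicts)
f_k = mbar_solvers.solve_mbar_for_all_states(
    u_kn, N_k, f_k, states_with_samples, ({"method": "hybr"},)
)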
diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py
index 5003ca68..c752774f 100644
--- a/pymbar/mbar_solvers.py
+++ b/pymbar/mbar_solvers.py
@@ -1,9 +1,13 @@
-from __future__ import division  # Ensure same division behavior in py2 and py3
 import logging
 import numpy as np
 import math
 import scipy.optimize
-from pymbar.utils import ensure_type, logsumexp, check_w_normalized
+from pymbar.utils import ensure_type, check_w_normalized
+import jax
+from jax.scipy.special import logsumexp
+from jax.ops import index_update, index
+from jax.config import config; config.update("jax_enable_x64", True)
+import jax.numpy as jnp
 import warnings
 
 logger = logging.getLogger(__name__)
@@ -40,16 +44,27 @@ def validate_inputs(u_kn, N_k, f_k):
     """
     n_states, n_samples = u_kn.shape
-    u_kn = ensure_type(u_kn, "float", 2, "u_kn or Q_kn", shape=(n_states, n_samples))
-    N_k = ensure_type(
-        N_k, "float", 1, "N_k", shape=(n_states,), warn_on_cast=False
-    )  # Autocast to float because will be eventually used in float calculations.
-    f_k = ensure_type(f_k, "float", 1, "f_k", shape=(n_states,))
+    u_kn = ensure_type(u_kn, 'float', 2, "u_kn or Q_kn", shape=(n_states, n_samples))
+    N_k = ensure_type(N_k, 'float', 1, "N_k", shape=(n_states,), warn_on_cast=False)  # Autocast to float because will be eventually used in float calculations.
+    f_k = ensure_type(f_k, 'float', 1, "f_k", shape=(n_states,))
 
     return u_kn, N_k, f_k
 
+def jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=None):
-def self_consistent_update(u_kn, N_k, f_k):
+
+    jNk = 1.0*N_k
+
+    # Only the states with samples can contribute to the denominator term.
+    if states_with_samples is not None:
+        log_denominator_n = logsumexp(f_k[states_with_samples] - u_kn[states_with_samples].T, b=jNk[states_with_samples], axis=1)
+    else:
+        log_denominator_n = logsumexp(f_k - u_kn.T, b=jNk, axis=1)
+    # All states can contribute to the numerator term.
+    return -1. * logsumexp(-log_denominator_n - u_kn, axis=1)  # check transpose
+
+jit_self_consistent_update = jax.jit(jax_self_consistent_update)
+
+def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None):
     """Return an improved guess for the dimensionless free energies
 
     Parameters
@@ -70,19 +85,16 @@ def self_consistent_update(u_kn, N_k, f_k):
     -----
     Equation C3 in MBAR JCP paper.
     """
+    return jit_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples)
 
-    u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k)
+def jax_mbar_gradient(u_kn, N_k, f_k):
 
-    states_with_samples = N_k > 0
-
-    # Only the states with samples can contribute to the denominator term.
-    log_denominator_n = logsumexp(
-        f_k[states_with_samples] - u_kn[states_with_samples].T, b=N_k[states_with_samples], axis=1
-    )
-
-    # All states can contribute to the numerator term.
-    return -1.0 * logsumexp(-log_denominator_n - u_kn, axis=1)
+    jNk = 1.0*N_k
+    log_denominator_n = logsumexp(f_k - u_kn.T, b=jNk, axis=1)
+    log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1)
+    return -1 * jNk * (1.0 - jnp.exp(f_k + log_numerator_k))
+
+jit_mbar_gradient = jax.jit(jax_mbar_gradient)
 
 def mbar_gradient(u_kn, N_k, f_k):
     """Gradient of MBAR objective function.
@@ -105,12 +117,19 @@ def mbar_gradient(u_kn, N_k, f_k):
     -----
     This is equation C6 in the JCP MBAR paper.
""" - u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k) + return jit_mbar_gradient(u_kn, N_k, f_k) + +def jax_mbar_objective_and_gradient(u_kn, N_k, f_l): log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1) - return -1 * N_k * (1.0 - np.exp(f_k + log_numerator_k)) + grad = -1 * N_k * (1.0 - jnp.exp(f_k + log_numerator_k)) + + obj = jnp.sum(log_denominator_n) - N_k.dot(f_k) + return obj, grad + +jit_mbar_objective_and_gradient = jax.jit(jax_mbar_objective_and_gradient) def mbar_objective_and_gradient(u_kn, N_k, f_k): """Calculates both objective function and gradient for MBAR. @@ -124,7 +143,6 @@ def mbar_objective_and_gradient(u_kn, N_k, f_k): f_k : np.ndarray, shape=(n_states), dtype='float' The reduced free energies of each state - Returns ------- obj : float @@ -139,22 +157,23 @@ def mbar_objective_and_gradient(u_kn, N_k, f_k): results, u_kn can be preconditioned by subtracting out a `n` dependent vector. - More optimal precision, the objective function uses math.fsum for the - outermost sum and logsumexp for the inner sum. - The gradient is equation C6 in the JCP MBAR paper; the objective function is its integral. """ - u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k) + return jit_mbar_objective_and_gradient(u_kn, N_k, f_k) - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1) - grad = -1 * N_k * (1.0 - np.exp(f_k + log_numerator_k)) +def jax_mbar_hessian(u_kn, N_k, f_k): - obj = math.fsum(log_denominator_n) - N_k.dot(f_k) - - return obj, grad + jNk = 1.0*N_k + log_denominator_n = logsumexp(f_k - u_kn.T, b=jNk, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, jnp.newaxis] + W = jnp.exp(logW) + H = W.T.dot(W) + H *= jNk + H *= jNk[:, jnp.newaxis] + H -= jnp.diag(W.sum(0) * jNk) + return -1.0 * H def mbar_hessian(u_kn, N_k, f_k): """Hessian of MBAR objective function. @@ -177,17 +196,14 @@ def mbar_hessian(u_kn, N_k, f_k): ----- Equation (C9) in JCP MBAR paper. """ - u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k) - - W = mbar_W_nk(u_kn, N_k, f_k) + return jax_mbar_hessian(u_kn, N_k, f_k) - H = W.T.dot(W) - H *= N_k - H *= N_k[:, np.newaxis] - H -= np.diag(W.sum(0) * N_k) - - return -1.0 * H +def jax_mbar_log_W_nk(u_kn, N_k, f_k): + jNk = 1.0*N_k + log_denominator_n = logsumexp(f_k - u_kn.T, b=jNk, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, jnp.newaxis] + return logW def mbar_log_W_nk(u_kn, N_k, f_k): """Calculate the log weight matrix. @@ -210,14 +226,11 @@ def mbar_log_W_nk(u_kn, N_k, f_k): ----- Equation (9) in JCP MBAR paper. """ - u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k) - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, np.newaxis] - return logW + return jit_mbar_log_W_nk(u_kn, N_k, f_k) +jit_mbar_log_W_nk = jax.jit(jax_mbar_log_W_nk) -def mbar_W_nk(u_kn, N_k, f_k): +def jax_mbar_W_nk(u_kn, N_k, f_k): """Calculate the weight matrix. Parameters @@ -238,8 +251,9 @@ def mbar_W_nk(u_kn, N_k, f_k): ----- Equation (9) in JCP MBAR paper. 
""" - return np.exp(mbar_log_W_nk(u_kn, N_k, f_k)) + return npj.exp(jax_mbar_log_W_nk(u_kn, N_k, f_k)) +jit_mbar_W_nk = jax.jit(jax_mbar_W_nk) def adaptive(u_kn, N_k, f_k, tol=1.0e-12, options=None): @@ -279,7 +293,6 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-12, options=None): options.setdefault("print_warning", False) options.setdefault("gamma", 1.0) - gamma = options["gamma"] doneIterating = False if options["verbose"] == True: logger.info( @@ -292,40 +305,10 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-12, options=None): nr_iter = 0 sci_iter = 0 - f_sci = np.zeros(len(f_k), dtype=np.float64) - f_nr = np.zeros(len(f_k), dtype=np.float64) - - # Perform Newton-Raphson iterations (with sci computed on the way) - - # usually calculated at the end of the loop and saved, but we need - # to calculate the first time. - g = mbar_gradient(u_kn, N_k, f_k) # Objective function gradient. - for iteration in range(0, options["maximum_iterations"]): - H = mbar_hessian(u_kn, N_k, f_k) # Objective function hessian - Hinvg = np.linalg.lstsq(H, g, rcond=-1)[0] - Hinvg -= Hinvg[0] - f_nr = f_k - gamma * Hinvg - - # self-consistent iteration gradient norm and saved log sums. - f_sci = self_consistent_update(u_kn, N_k, f_k) - f_sci = f_sci - f_sci[0] # zero out the minimum - g_sci = mbar_gradient(u_kn, N_k, f_sci) - gnorm_sci = np.dot(g_sci, g_sci) - - # newton raphson gradient norm and saved log sums. - g_nr = mbar_gradient(u_kn, N_k, f_nr) - gnorm_nr = np.dot(g_nr, g_nr) - - # we could save the gradient, for the next round, but it's not too expensive to - # compute since we are doing the Hessian anyway. - - if options["verbose"]: - logger.info( - "self consistent iteration gradient norm is %10.5g, Newton-Raphson gradient norm is %10.5g" - % (gnorm_sci, gnorm_nr) - ) + (f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr) = jit_core_adaptive(u_kn, N_k, f_k, options['gamma']) + # decide which directon to go depending on size of gradient norm f_old = f_k if gnorm_sci < gnorm_nr or sci_iter < 2: @@ -334,11 +317,10 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-12, options=None): sci_iter += 1 if options["verbose"]: if sci_iter < 2: - logger.info("Choosing self-consistent iteration on iteration %d" % iteration) + logger.info(f"Choosing self-consistent iteration on iteration {iteration}") else: logger.info( - "Choosing self-consistent iteration for lower gradient on iteration %d" - % iteration + f"Choosing self-consistent iteration for lower gradient on iteration {iteration}" ) else: f_k = f_nr @@ -346,48 +328,70 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-12, options=None): nr_iter += 1 if options["verbose"]: logger.info("Newton-Raphson used on iteration %d" % iteration) - - div = np.abs(f_k[1:]) # what we will divide by to get relative difference - zeroed = np.abs(f_k[1:]) < np.min( - [10 ** -8, tol] - ) # check which values are near enough to zero, hard coded max for now. - div[zeroed] = 1.0 # for these values, use absolute values. - max_delta = np.max(np.abs(f_k[1:] - f_old[1:]) / div) - if np.isnan(max_delta) or (max_delta < tol): + + div = jnp.abs(f_k[1:]) # what we will divide by to get relative difference + zeroed = jnp.abs(f_k[1:])< np.min([10**-8,tol]) # check which values are near enough to zero, hard coded max for now. 
+        div = jax.ops.index_update(div, index[zeroed], 1.0)  # for these values, use absolute values (JAX arrays are immutable, so keep the returned array)
+        max_delta = jnp.max(jnp.abs(f_k[1:] - f_old[1:])/div)
+        if jnp.isnan(max_delta) or (max_delta < tol):
             doneIterating = True
             break
 
     if doneIterating:
-        if options["verbose"]:
-            logger.info(
-                "Converged to tolerance of {:e} in {:d} iterations.".format(
-                    max_delta, iteration + 1
-                )
-            )
-            logger.info(
-                "Of {:d} iterations, {:d} were Newton-Raphson iterations and {:d} were self-consistent iterations".format(
-                    iteration + 1, nr_iter, sci_iter
-                )
-            )
-            if np.all(f_k == 0.0):
-                # all f_k appear to be zero
-                logger.info("WARNING: All f_k appear to be zero.")
+        if options["verbose"]:
+            logger.info(
+                f"Converged to tolerance of {max_delta:e} in {iteration+1} iterations."
+            )
+            logger.info(
+                f"Of {iteration+1} iterations, {nr_iter} were Newton-Raphson iterations and {sci_iter} were self-consistent iterations"
+            )
+            if np.all(f_k == 0.0):
+                # all f_k appear to be zero
+                logger.info("WARNING: All f_k appear to be zero.")
     else:
         logger.warning("WARNING: Did not converge to within specified tolerance.")
         if options["maximum_iterations"] <= 0:
             logger.warning(
-                "No iterations ran be cause maximum_iterations was <= 0 ({:s})!".format(
-                    options["maximum_iterations"]
-                )
+                f"No iterations ran because maximum_iterations was <= 0 ({options['maximum_iterations']})!"
             )
         else:
             logger.warning(
-                "max_delta = {:e}, tol = {:e}, maximum_iterations = {:d}, iterations completed = {:d}".format(
-                    max_delta, tol, options["maximum_iterations"], iteration
-                )
+                f"max_delta = {max_delta:e}, tol = {tol:e}, maximum_iterations = {options['maximum_iterations']}, iterations completed = {iteration}"
             )
+
     return f_k
 
+def jax_core_adaptive(u_kn, N_k, f_k, gamma):
+
+    # Perform Newton-Raphson iterations (with sci computed on the way)
+    g = mbar_gradient(u_kn, N_k, f_k)  # Objective function gradient
+    H = mbar_hessian(u_kn, N_k, f_k)  # Objective function hessian
+    Hinvg = jnp.linalg.lstsq(H, g, rcond=-1)[0]
+    Hinvg -= Hinvg[0]
+    f_nr = f_k - gamma * Hinvg
+
+    # self-consistent iteration gradient norm and saved log sums.
+    f_sci = self_consistent_update(u_kn, N_k, f_k)
+    f_sci = f_sci - f_sci[0]  # zero out the minimum
+    g_sci = mbar_gradient(u_kn, N_k, f_sci)
+    gnorm_sci = jnp.dot(g_sci, g_sci)
+
+    # newton raphson gradient norm and saved log sums.
+    g_nr = mbar_gradient(u_kn, N_k, f_nr)
+    gnorm_nr = jnp.dot(g_nr, g_nr)
+
+    return (f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr)
+
+jit_core_adaptive = jax.jit(jax_core_adaptive)
+
+def jax_precondition_u_kn(u_kn, N_k, f_k):
+
+    jNk = 1.0*N_k
+    u_kn = u_kn - u_kn.min(0)
+    u_kn += (logsumexp(f_k - u_kn.T, b=jNk, axis=1)) - jNk.dot(f_k) / jNk.sum()
+    return u_kn
+
+jit_precondition_u_kn = jax.jit(jax_precondition_u_kn)
 
 def precondition_u_kn(u_kn, N_k, f_k):
     """Subtract a sample-dependent constant from u_kn to improve precision
@@ -414,11 +418,7 @@ def precondition_u_kn(u_kn, N_k, f_k):
     x_n such that the current objective function value is zero, which
     should give maximum precision in the objective function.
""" - u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k) - u_kn = u_kn - u_kn.min(0) - u_kn += (logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - N_k.dot(f_k) / float(N_k.sum()) - return u_kn - + return jax_precondition_u_kn(u_kn, N_k, f_k) def solve_mbar_once( u_kn_nonzero, N_k_nonzero, f_k_nonzero, method="hybr", tol=1e-12, options=None @@ -468,7 +468,7 @@ def solve_mbar_once( u_kn_nonzero, N_k_nonzero, f_k_nonzero ) f_k_nonzero = f_k_nonzero - f_k_nonzero[0] # Work with reduced dimensions with f_k[0] := 0 - u_kn_nonzero = precondition_u_kn(u_kn_nonzero, N_k_nonzero, f_k_nonzero) + u_kn_nonzero = jit_precondition_u_kn(u_kn_nonzero, N_k_nonzero, f_k_nonzero) pad = lambda x: np.pad( x, (1, 0), mode="constant" @@ -607,7 +607,7 @@ def solve_mbar(u_kn_nonzero, N_k_nonzero, f_k_nonzero, solver_protocol=None): return f_k_nonzero, all_results -def solve_mbar_for_all_states(u_kn, N_k, f_k, solver_protocol): +def solve_mbar_for_all_states(u_kn, N_k, f_k, states_with_samples, solver_protocol): """Solve for free energies of states with samples, then calculate for empty states. @@ -628,7 +628,6 @@ def solve_mbar_for_all_states(u_kn, N_k, f_k, solver_protocol): f_k : np.ndarray, shape=(n_states), dtype='float' The free energies of states """ - states_with_samples = np.where(N_k > 0)[0] if len(states_with_samples) == 1: f_k_nonzero = np.array([0.0]) @@ -640,10 +639,10 @@ def solve_mbar_for_all_states(u_kn, N_k, f_k, solver_protocol): solver_protocol=solver_protocol, ) - f_k[states_with_samples] = f_k_nonzero + f_k[states_with_samples] = np.array(f_k_nonzero) # Update all free energies because those from states with zero samples are not correctly computed by solvers. - f_k = self_consistent_update(u_kn, N_k, f_k) + f_k = np.array(self_consistent_update(u_kn, N_k, f_k, states_with_samples)) # This is necessary because state 0 might have had zero samples, # but we still want that state to be the reference with free energy 0. f_k -= f_k[0]