🐛 Bug

Hi,
First of all, thank you for developing such a versatile and efficient library! I suspect I came across a bug while working with BoTorch, but I believe it originates from GPyTorch.
When working with a custom mean function that depends on parameters $\theta$ that we wish to optimize (e.g., an NN feature extractor), the derivative of the predictive mean is wrong.
The predictive mean $\mu(\cdot)$ (given a custom mean function $m_\theta$) at $x^*$ is given by

$$
\mu(x^* \mid X, y) = m_\theta(x^*) + K(x^*, X)\left[K(X, X) + \sigma^2 I\right]^{-1}\left(y - m_\theta(X)\right)
$$
An easy test for the derivative is to predict at the observed data points $X$, which gives (when the observational noise is small, $\sigma \approx 0$)
$$
\mu(X | X, y) \approx m_\theta(X) + y - m_\theta(X) = y
$$
whose derivative w.r.t. the mean module's parameters $\theta$ should be zero.
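To make the cancellation explicit: $y$ carries no dependence on $\theta$, and $K(X, X)\left[K(X, X) + \sigma^2 I\right]^{-1} \to I$ as $\sigma \to 0$, so

$$
\frac{\partial \mu(X \mid X, y)}{\partial \theta} = \left(I - K(X, X)\left[K(X, X) + \sigma^2 I\right]^{-1}\right) \frac{\partial m_\theta(X)}{\partial \theta} \approx 0.
$$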
At least in the following specific case this does not happen, and the issue seems to be related to an incorrect detach somewhere (see the `detach_bug` flag in the snippet below for a hypothetical location where this could occur).
To reproduce
**Code snippet to reproduce**
```python
from botorch.models import SingleTaskGP
import torch

torch.manual_seed(123)

# define a model for the mean
class LinearModel(torch.nn.Module):
    def __init__(self, D=1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.beta = torch.nn.Parameter(torch.randn(1, D))

    def forward(self, X):
        return (X * self.beta).sum(-1)

lm = LinearModel(1)
opt = torch.optim.Adam(lm.parameters(), lr=0.424242)  # for zeroing grads

# generate some data
N = 50
x_data = torch.linspace(0, 1, N).view(-1, 1)
y_data = 2 * torch.sin(10 * x_data) + 0.01 * torch.randn_like(x_data)

# gp prediction manually
def gp_pred(xstar, obs_X, obs_Y, prior_mean, gp, detach_bug=False):
    samples_prior_mean = prior_mean(xstar)
    obs_prior_mean = prior_mean(obs_X)
    gp_y = obs_Y - obs_prior_mean.unsqueeze(-1)
    # gp pred
    K = gp.covar_module(obs_X, obs_X).to_dense().detach()
    likelihood_additive_noise = gp.likelihood.noise_covar.raw_noise.detach()
    KplusNoise = K + likelihood_additive_noise * torch.eye(K.shape[0])
    Kstar = gp.covar_module(xstar, obs_X).to_dense().detach()
    KpNinv = torch.linalg.inv(KplusNoise)
    # pred_mean = Kstar @ KpNinv @ gp_y
    pred_mean = Kstar @ torch.linalg.solve(KplusNoise, gp_y)
    if detach_bug:  # the bug is here
        pred_mean = pred_mean.detach()
    pred_mean_og_scale = pred_mean.squeeze() + samples_prior_mean
    pred_var = gp.covar_module.outputscale.detach() - (Kstar @ KpNinv @ Kstar.T).diag()
    return pred_mean_og_scale, pred_var

# define the GPyTorch model
import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, mean_module, covar_module):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# botorch model
gp_botorch = SingleTaskGP(x_data, y_data, mean_module=lm)
gp_botorch.likelihood.noise_covar.noise = 0.0001  # to simulate near noiseless prediction
gp_botorch.eval()
gp_botorch.likelihood.eval()

# gpytorch model
gp_gpytorch = ExactGPModel(x_data, y_data.squeeze(), gp_botorch.likelihood, lm, gp_botorch.covar_module)
gp_gpytorch.eval()

# test that manual, botorch and gpytorch actually output the same predictions
x_star = torch.linspace(0, 1, 100).view(-1, 1)
pred_mean, pred_var = gp_pred(x_star, x_data, y_data, lm, gp_botorch)
pred_botorch = gp_botorch(x_star)
pred_gpytorch = gp_gpytorch(x_star)

# check botorch and gpytorch are equal
print("botorch vs. gpytorch")
print(torch.abs(pred_botorch.mean - pred_gpytorch.mean).max())
print(torch.abs(pred_botorch.variance - pred_gpytorch.variance).max())

# check predictions botorch and manual are approximately equal
print("botorch vs manual")
print(torch.abs(pred_mean - gp_botorch(x_star).mean).max())
print(torch.abs(pred_var - gp_botorch(x_star).variance).max())

# plotting
# import matplotlib.pyplot as plt
# plt.scatter(x_data, y_data, label="data")
# plt.plot(x_star, pred_mean.detach(), label="manual")
# plt.plot(x_star, pred_botorch.mean.detach(), label="botorch")
# plt.plot(x_star, pred_gpytorch.mean.detach(), label="gpytorch")
# plt.legend()
# plt.show()

# prediction at the training points
opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch)
pred_mean.mean().backward()
print(lm.beta.grad)  # correct: ≈ 0

opt.zero_grad()
pred_mean, pred_var = gp_pred(x_data, x_data, y_data, lm, gp_botorch, detach_bug=True)
pred_mean.mean().backward()
print(lm.beta.grad)  # incorrect: ≠ 0

opt.zero_grad()
pred_botorch = gp_botorch(x_data)
pred_botorch.mean.mean().backward()
print(lm.beta.grad)  # incorrect: ≠ 0

opt.zero_grad()
pred_gpytorch = gp_gpytorch(x_data)
pred_gpytorch.mean.mean().backward()
print(lm.beta.grad)  # incorrect: ≠ 0
```
The last three predictions agree with each other in both outputs and gradients (incorrectly, in my opinion), while the first one matches the outputs and produces the correct gradient.
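To isolate the mechanism I suspect, here is a minimal sketch independent of GPyTorch/BoTorch (all names in it are toy placeholders, not library code): in the near-noiseless case $K(X, X)\left[K(X, X) + \sigma^2 I\right]^{-1} \approx I$, so detaching the correction term $y - m_\theta(X)$ reproduces exactly this gradient signature.

```python
import torch

torch.manual_seed(0)
beta = torch.nn.Parameter(torch.randn(1))  # stand-in for the mean module's parameters
x = torch.randn(5)
y = torch.randn(5)

def mean_fn(t):
    return beta * t  # toy parametric prior mean m_theta

# correct: the two m_theta(X) terms cancel in the autograd graph, so d(mu)/d(beta) = 0
mu = mean_fn(x) + (y - mean_fn(x))
mu.mean().backward()
print(beta.grad)  # exactly 0

beta.grad = None
# buggy: detaching the correction term breaks the cancellation
mu = mean_fn(x) + (y - mean_fn(x)).detach()
mu.mean().backward()
print(beta.grad)  # ≈ mean(x), i.e. != 0
```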
Expected Behavior
The gradient of the predictive mean at the observation locations $X$ should be $0$; see the `# correct` comment in the snippet above.
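For concreteness, assuming the repro snippet above has been run, this is the check I would expect to pass (using `torch.autograd.grad` just to read off the gradient):

```python
# gradient of the posterior mean at the training inputs w.r.t. the
# mean module's parameter; expected ≈ 0 for near-noiseless observations
loss = gp_botorch(x_data).mean.mean()
(grad,) = torch.autograd.grad(loss, lm.beta)
print(grad)  # expected: ≈ 0, currently observed: != 0
```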
System information
Please complete the following information:
- GPyTorch 1.11
- PyTorch 2.1.0
- macOS Sonoma 14.4.1
I could not locate the bug in the GPyTorch code myself, but hopefully this report helps you track it down.