From 831516cccf30736aa3f3367b6b2bb6e810fc65c5 Mon Sep 17 00:00:00 2001 From: Alistair White Date: Tue, 26 Nov 2024 17:17:48 -0800 Subject: [PATCH] Fix default save dest --- baseline_models/HSR/training/hsr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baseline_models/HSR/training/hsr.py b/baseline_models/HSR/training/hsr.py index 9908bfb..fa63e18 100644 --- a/baseline_models/HSR/training/hsr.py +++ b/baseline_models/HSR/training/hsr.py @@ -80,7 +80,7 @@ def sample(self, x, random=True): else: return mu, torch.exp(logprec)**(-0.5) - def trainer(self, data, epochs=20, save="models/vae.cp", plot=True, loss_type='mle', + def trainer(self, data, epochs=20, save="models/hsr.cp", plot=True, loss_type='mle', optimizer='adam', lr=0.0001, gamma=0.01, rho=None): """ Train the Heteroskedastic Regression model