From e45cd17c361448426f13b98b760c8dbd8c9e6359 Mon Sep 17 00:00:00 2001 From: surkovvv Date: Wed, 20 Sep 2023 22:30:12 +0300 Subject: [PATCH] optimizer fix --- homework01/homework_part1.ipynb | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/homework01/homework_part1.ipynb b/homework01/homework_part1.ipynb index c278e70..147d551 100644 --- a/homework01/homework_part1.ipynb +++ b/homework01/homework_part1.ipynb @@ -320,11 +320,16 @@ "outputs": [], "source": [ "#!L\n", + "import multiprocessing\n", + "\n", + "\n", "batch_size = 64\n", + "num_workers = multiprocessing.cpu_count()\n", + "\n", "train_batch_gen = torch.utils.data.DataLoader(train_dataset, \n", " batch_size=batch_size,\n", " shuffle=True,\n", - " num_workers=12)" + " num_workers=num_workers)" ] }, { @@ -340,7 +345,7 @@ "val_batch_gen = torch.utils.data.DataLoader(val_dataset, \n", " batch_size=batch_size,\n", " shuffle=False,\n", - " num_workers=12)" + " num_workers=num_workers)" ] }, { @@ -555,7 +560,7 @@ " train_loss = []\n", " model.train(True) # enable dropout / batch_norm training behavior\n", " for (X_batch, y_batch) in tqdm.tqdm(train_data_generator):\n", - " opt.zero_grad()\n", + " optimizer.zero_grad()\n", "\n", " # forward\n", " # YOUR CODE: move X_batch, y_batch to 'device', compute model outputs on X_batch, \n",